小杰深度學(xué)習(five)——正則化、神經(jīng)網(wǎng)絡(luò )的過(guò)擬合解決專(zhuān)業(yè)的方案
1. 正則化
正則化是通過(guò)在損失函數中添加正則化項來(lái)控制模型復雜度、防止過(guò)擬合的技術(shù)。機器學(xué)習中,復雜模型易過(guò)擬合(訓練表現好、新數據泛化差),正則化通過(guò)約束參數抑制模型復雜度,常見(jiàn)正則化類(lèi)型有 L1 和 L2 正則化。
損失函數的公式為:

1.當加上L1正則化后,損失函數變成:

2.當加上L2正則化后,損失函數變成:

1.1 為什么加入正則化可以解決過(guò)擬合?
加入正則化后,損失函數需同時(shí)最小化原損失(如 MSE)和正則化項。以 L1/L2 為例,正則化項迫使參數 w 盡可能小,參數 w 小和解決過(guò)擬合的關(guān)系:
- 過(guò)擬合本質(zhì):模型因參數過(guò)多或數值過(guò)大而復雜,過(guò)度捕捉訓練數據噪聲。
- 參數越小 = 模型越簡(jiǎn)單:小參數限制模型對細節的擬合能力,降低復雜度,抑制過(guò)擬合(如削弱噪聲特征的權重影響)。
總結:,正則化通過(guò) “懲罰大參數” 壓縮模型表達能力,使其從 “記憶訓練數據” 轉向 “學(xué)習通用規律”,從而提升泛化性。
舉個(gè)例子如下圖曲線(xiàn)擬合散點(diǎn):


1.2 正則化的基本思想
正則化的核心思想是在損失函數中引入與模型復雜度相關(guān)的額外項(如參數的 L1/L2 范數),通過(guò)調整正則化參數λ控制其權重,以懲罰模型復雜度,進(jìn)而避免過(guò)擬合。
1.3 為什么只在W添加懲罰?
參數 b 是偏置項,僅控制擬合曲線(xiàn)沿 y 軸平移,不改變曲線(xiàn)形狀。而正則化的目標是通過(guò)懲罰參數復雜度使曲線(xiàn)平滑,故對 b 施加正則化無(wú)實(shí)際意義。
1.4 L1正則化和L2正則化
以包含兩個(gè)參數 w1、w2 的模型為例,左圖為 L1 正則化(約束區域為菱形),右圖為 L2 正則化(約束區域為圓形),二者通過(guò)不同幾何形狀限制參數空間,實(shí)現復雜度懲罰。

其中,L1正則化的約束條件為:
![]()
L2正則化的約束條件為:

以雙參數 w1,w2 為例(彩色圈為損失等高線(xiàn),中心損失最小):
- L1 正則化(黑色菱形約束):最優(yōu)解易落在坐標軸上(如 w1=0),使參數稀疏(部分為 0),同時(shí)滿(mǎn)足損失最小化與正則約束。
- L2 正則化(黑色圓形約束):最優(yōu)解落在圓周與等高線(xiàn)切點(diǎn),參數平滑縮?。ǚ窍∈瑁?,損失函數常寫(xiě)為 Loss總=Loss原損失+1/22λw**2(含系數 1/2便于反向傳播求導計算)。
2. 基本原理
2.1 散點(diǎn)輸入
本實(shí)驗中提供了一些散點(diǎn),其分布如下圖所示:

現在我們需要根據這些散點(diǎn)來(lái)擬合一條線(xiàn),使我們可以根據這條線(xiàn)來(lái)預測新的散點(diǎn)的坐標。
2.2 定義前向模型
定義前向模型,定義一個(gè)具有三個(gè)隱藏層的網(wǎng)絡(luò ),來(lái)擬合這一條線(xiàn)。

2.3 定義損失函數和優(yōu)化器
由于是擬合線(xiàn)問(wèn)題,所以本實(shí)驗選擇的是MSE(均方誤差損失函數)。
定義好損失函數后,就需要定義反向傳播所需要的超參數了,需要對學(xué)習率和優(yōu)化器以及正則化率進(jìn)行選擇,優(yōu)化器這里選擇的是Adam。正則化率默認選擇的是0.001,不同的正則化率會(huì )對模型有較大的影響,如下圖所示。

2.4 開(kāi)始迭代
通過(guò)“開(kāi)始迭代”組件,設置模型的訓練次數。

2.5 顯示頻率設置
為了能夠更好的觀(guān)察迭代過(guò)程中的現象,可以通過(guò)“顯示頻率設置”組件來(lái)設置每隔多少次顯示一次當前的擬合狀態(tài)。

2.6 擬合線(xiàn)顯示與輸出
通過(guò)“擬合線(xiàn)顯示與輸出”組件,就可以觀(guān)察到迭代過(guò)程中曲線(xiàn)的擬合狀態(tài)了。

代碼:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
#確保初始化的值都一樣
seed=42
torch.manual_seed(seed)
# 1.創(chuàng )造數據,數據集
points = np.array([[-0.5, 7.7], [1.2, 65.8], [0.4, 39.2], [-1.4, -15.7],[1.5, 75.6], [0.4, 34.0], [0.8, 62.3]])
# 分離特征和標簽
x_train = points[:, 0]
y_train = points[:, 1]
# 2.定義前向模型
class Model(nn.Module):#定義初始化def __init__(self):super(Model,self).__init__()self.layer1=nn.Linear(1,16)self.layer2=nn.Linear(16,32)self.layer3=nn.Linear(32,16)self.layer4=nn.Linear(16,1)#前向過(guò)程def forward(self,x):#線(xiàn)性層后都跟著(zhù)激活函數,實(shí)現非線(xiàn)性化x=torch.relu(self.layer1(x))x=torch.relu(self.layer2(x))x=torch.relu(self.layer3(x))# 最后一層是擬合回歸不用激活x=self.layer4(x)return x
model=Model()
# 3.定義損失函數和優(yōu)化器
#定義學(xué)習率
lr=0.05
#定義損失函數,這里是回歸問(wèn)題用mse
cri=torch.nn.MSELoss()
#定義優(yōu)化器
#在梯度中添加正則化系數 weight_decay
optimizer=torch.optim.Adam(model.parameters(),lr=lr,weight_decay=0.2)
#7.畫(huà)圖
fig,(ax1,ax2) =plt.subplots(1,2,figsize=(12,6))
epoch_list=[]
loss_list=[]
# 4.開(kāi)始迭代
epoches=1000
for epoch in range(1,epoches+1):#數據轉化為tensorx_train_tensor=torch.tensor(x_train,dtype=torch.float32).unsqueeze(1)y_train_tensor=torch.tensor(y_train,dtype=torch.float32)#數據輸入模型前向傳播pre_result=model(x_train_tensor)#計算損失loss=cri(pre_result.squeeze(1),y_train_tensor)loss_list.append(loss.detach().numpy())epoch_list.append(epoch)#優(yōu)化更新#梯度清零optimizer.zero_grad()#反向傳播loss.backward()#參數更新optimizer.step()# 5.顯示頻率設置if epoch==1 or epoch%20==0:print(f"epoch:{epoch},loss:{loss}")# 6.繪圖ax1.cla()ax1.scatter(x_train,y_train)x_range=torch.tensor(np.linspace(-2,2,100),dtype=torch.float32)y_range=model(x_range.unsqueeze(1))ax1.plot(x_range.detach().numpy(),y_range.detach().numpy().squeeze(1))ax2.cla()ax2.plot(epoch_list,loss_list)plt.pause(1)
plt.show()
2.神經(jīng)網(wǎng)絡(luò )的過(guò)擬合解決方案
1. 過(guò)擬合解決方案
1.1 神經(jīng)網(wǎng)絡(luò )的欠擬合解決方案
欠擬合出現的原因通常是數據量不足、模型過(guò)于簡(jiǎn)單等因素導致的,那么可以通過(guò)適當的增加樣本數據集或增加神經(jīng)網(wǎng)絡(luò )隱藏層的層數來(lái)使神經(jīng)網(wǎng)絡(luò )復雜一點(diǎn)來(lái)解決欠擬合的問(wèn)題。
1.2 神經(jīng)網(wǎng)絡(luò )的過(guò)擬合解決方案
過(guò)擬合出現的原因通常是模型過(guò)于復雜或者數據量太少,導致過(guò)度學(xué)習訓練模型中的細節和噪聲。
在這里介紹兩種過(guò)擬合的解決方案。
1.2.1 正則化
L1 和 L2 正則化,其核心是通過(guò)向損失函數添加正則化項懲罰模型參數大小,抑制過(guò)擬合。
1.2.2 Dropout
Dropout 是神經(jīng)網(wǎng)絡(luò )訓練中通過(guò)隨機 “丟棄” 部分神經(jīng)元輸出(置為 0)的正則化技術(shù),可降低模型對特定神經(jīng)元的依賴(lài),減少復雜度,增強泛化能力以抑制過(guò)擬合。
以某層含 4 個(gè)神經(jīng)元的神經(jīng)網(wǎng)絡(luò )為例,來(lái)說(shuō)明 dropout,如下圖所示

若網(wǎng)絡(luò )過(guò)擬合,可對各隱藏層應用 Dropout(參數 0.5):該層每個(gè)節點(diǎn)以 50% 概率被隨機置為 0(即 類(lèi)似執行“刪除操作”),被置 0 的節點(diǎn)暫不參與前后層連接計算,以此降低模型復雜度,抑制過(guò)擬合。
如圖所示)。

假設某隱藏層輸出為 [0.7, 0.2, 0.9, 0.5],應用 dropout(參數 0.5)后,部分神經(jīng)元以 50% 概率被隨機置 0,如變?yōu)?[0.7, 0, 0.9, 0]。此過(guò)程通過(guò)減少神經(jīng)元間復雜依賴(lài),降低模型復雜度,進(jìn)而抑制過(guò)擬合。
1.2.2.1 Inverted Dropout(反向丟棄法)
訓練時(shí)按 Dropout 概率隨機舍棄神經(jīng)元,對保留神經(jīng)元的輸出按比例縮放;測試時(shí)保留所有神經(jīng)元。將權重按訓練概率縮放,此方法稱(chēng)為反向丟棄法(Inverted Dropout)
import torch
X=np.array([0.7,0.8,0.9,0.5])
# 在訓練時(shí)
drop_prob=0.5
keep_prob = 1 - drop_prob # keep_prob = 1-p即為保留率
mask = (torch.rand(X.shape) < keep_prob).float()
Y = mask * X / keep_prob
print(Y)
# 在測試時(shí)
# x
1.2.2.2 Dropout為什么能夠解決過(guò)擬合:
(1)抑制過(guò)擬合:標準神經(jīng)網(wǎng)絡(luò )易依賴(lài)特定神經(jīng)元導致過(guò)擬合,Dropout 通過(guò)隨機丟棄神經(jīng)元,迫使網(wǎng)絡(luò )學(xué)習對神經(jīng)元變化魯棒的特征,降低對訓練數據的過(guò)擬合。
(2)等效模型平均:訓練時(shí)隨機丟棄神經(jīng)元相當于訓練多個(gè)子網(wǎng)絡(luò ),測試時(shí)保留全連接結構,預測結果等效于對各子網(wǎng)絡(luò )輸出取平均,通過(guò) “綜合抵消” 減輕過(guò)擬合。
代碼展示:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
# 1.散點(diǎn)輸入
class1_points = np.array([[-0.7, 0.7], [3.9, 1.5], [1.7, 2.2], [1.9, -2.4], [0.9, 1.4], [4.2, 0.9], [1.7, 0.7], [0.2, -0.2], [3.1, -0.4],[-0.2, -0.9], [1.7, 0.2], [-0.6, -3.9], [-1.8, -4.0], [0.7, 3.8], [-0.7, -3.3], [0.8, 1.8], [-0.5, 1.5],[-0.6, -3.6], [-3.1, -3.0], [2.1, -2.5], [-2.5, -3.4], [-2.6, -0.8], [-0.2, 0.9], [-3.0, 3.3], [-0.7, 0.2],[0.3, 3.0], [0.6, 1.9], [-4.0, 2.4], [1.9, -2.2], [1.0, 0.3], [-0.9, -0.7], [-3.7, 0.6], [-2.7, -1.5], [0.9, -0.3],[0.8, -0.2], [-0.4, -4.4], [-0.3, 0.8], [4.1, 1.0], [-2.5, -3.5], [-0.8, 0.3], [0.6, 0.6], [2.6, -1.0], [1.8, 0.4],[1.5, -1.0], [3.2, 1.1], [3.3, -2.5], [-3.8, 2.5], [3.1, -0.9], [3.4, -1.1], [0.3, 0.8], [-0.1, 2.9], [-2.8, 1.9],[2.8, -3.3], [-1.0, 3.1], [-0.8, -0.6], [-2.5, -1.5], [0.3, 0.2], [-1.0, -2.9], [0.7, 0.2], [-0.5, 0.9],[-0.8, 0.7], [4.1, 0.5], [2.8, 2.3], [-3.9, 0.1], [2.2, -1.4], [-0.7, -3.5], [1.0, 1.2], [-0.7, -4.0], [1.3, 0.6],[-0.1, 3.3], [0.0, -0.3], [1.8, -3.0], [0.6, 0.0], [3.6, -2.8], [-3.9, -0.9], [-4.3, -0.9], [0.1, -0.8],[-1.6, -2.7], [-1.8, -3.3], [1.7, -3.5], [3.6, -3.1], [-2.4, 2.5], [-1.0, 1.8], [3.9, 2.5], [-3.9, -1.3],[3.4, 1.6], [-0.1, -0.6], [-3.7, -1.3], [-0.3, 3.4], [-3.7, -1.7], [4.0, 1.1], [3.4, 0.2], [0.1, -1.6],[-1.2, -0.5], [2.4, 1.7], [-4.4, -0.5], [-0.2, -3.6], [-0.8, 0.4], [-1.5, -2.2], [3.9, 2.5], [4.4, 1.4],[-3.5, -1.1], [-0.7, 1.5], [-3.0, -2.6], [0.2, -3.5], [0.0, 1.2], [-4.3, 0.1], [-1.8, 2.8], [1.1, -2.5],[0.2, 4.3], [-3.9, 2.2], [1.0, 1.6], [4.5, 0.2], [3.9, -1.6], [-0.4, -0.5], [0.3, -0.4], [-3.2, 1.7], [2.0, 4.1],[2.5, 2.2], [-1.1, -0.3], [-3.7, -1.9], [1.5, -1.1], [-2.1, -1.9], [-0.1, 4.5], [3.8, -0.3], [-0.9, -3.8],[-2.9, -1.6], [1.0, -1.2], [0.7, 0.0], [-0.8, 3.3], [-2.8, 3.1], [0.4, -3.2], [4.6, 1.0], [2.5, 3.1], [4.2, 0.8],[3.6, 1.8], [1.4, -3.0], [-0.4, -1.4], [-4.1, 1.1], [1.1, -0.2], [-2.9, -0.0], [-3.5, 1.3], [-1.4, 0.0],[-3.7, 2.2], [-2.9, 2.8], [1.7, 0.4], [-0.8, -0.6], [2.9, 1.1], [-2.3, 3.1], [-2.9, -2.0], [-2.7, -0.4],[2.6, -2.4], [-1.7, -2.8], [1.2, 3.1], [3.8, 1.3], [0.1, 1.9], [-0.5, -1.0], [0.0, -0.5], [3.9, -0.7],[-3.7, -2.5], [-3.1, 2.7], [-0.9, -1.0], [-0.7, -0.8], [-0.4, -0.1], [1.5, 1.0], [-2.6, 1.9], [-0.8, 1.7],[0.8, 1.8], [2.0, 3.6], [3.2, 1.4], [2.3, 1.4], [4.9, 0.5], [2.2, 1.8], [-1.4, -2.7], [3.1, 1.1], [-1.0, 3.8],[-0.4, -1.1], [3.3, 1.1], [2.2, -3.9], [1.0, 1.2], [2.6, 3.2], [-0.6, -3.0], [-1.9, -2.8], [1.2, -1.2],[-0.4, -2.7], [1.1, -4.3], [0.3, -0.8], [-1.0, -0.4], [-1.1, -0.2], [0.1, 1.2], [0.9, 0.6], [-2.7, 1.6],[1.0, -0.7], [0.3, -4.2], [-2.1, 3.2], [3.4, -1.2], [2.5, -4.0], [1.0, -0.8], [1.0, -0.9], [0.1, -0.6]])
class2_points = np.array([[-3.0, -3.8], [4.4, 2.5], [2.6, 4.1], [3.7, -2.7], [-3.7, -2.9], [5.3, 0.3], [3.9, 2.9], [-2.7, -4.5], [5.4, 0.2],[3.0, 4.8], [-4.2, -1.3], [-2.1, -5.4], [-3.2, -4.6], [0.7, 4.5], [-1.4, -5.7], [0.5, 5.9], [-2.1, 4.0],[-0.1, -5.1], [-3.4, -4.7], [3.3, -4.7], [-2.7, -4.1], [-4.5, -2.0], [4.3, 2.9], [-3.6, 4.0], [-0.5, 5.5],[0.2, 5.2], [5.3, -0.9], [-4.5, 3.6], [3.4, -2.8], [-3.4, -3.7], [1.6, -5.5], [-5.9, -0.1], [-4.8, -2.5],[-5.5, 0.3], [1.6, 4.4], [-0.9, -5.3], [-1.0, 5.4], [4.9, 0.8], [-3.1, -4.0], [2.3, 4.7], [4.0, -1.6], [4.9, -1.5],[4.2, -2.5], [-3.5, 3.7], [4.7, 0.5], [5.3, -2.6], [-5.0, 2.4], [5.5, -1.2], [5.6, -1.3], [3.3, -4.3], [-1.3, 4.4],[-4.1, 3.6], [3.3, -4.5], [-2.3, 5.2], [2.6, 4.6], [-4.4, -1.6], [4.7, -2.0], [-1.7, -4.9], [-5.1, -2.4],[4.5, 3.2], [-3.9, -3.4], [6.0, -0.4], [3.5, 4.3], [-4.9, -0.6], [3.3, -3.2], [-0.3, -4.8], [-1.6, -4.7],[-1.4, -4.6], [-3.1, 3.8], [-1.4, 4.9], [1.8, -4.5], [2.2, -5.5], [3.1, -3.4], [4.7, -2.8], [-5.3, -0.4],[-6.0, -0.1], [1.4, -4.5], [-3.1, -4.3], [-1.8, -5.7], [1.7, -5.6], [4.5, -3.7], [-2.6, 4.3], [-3.4, 3.4],[4.7, 3.1], [-5.2, -2.8], [5.4, 1.2], [-5.4, 1.2], [-4.9, -1.3], [-1.3, 5.6], [-4.1, -2.6], [5.0, 1.0], [5.2, 1.2],[2.4, -4.9], [-3.2, 3.8], [3.3, 3.4], [-5.5, -0.8], [0.6, -5.0], [1.2, 5.4], [-3.4, -3.3], [4.6, 2.8], [5.2, 1.7],[-4.4, -0.9], [-5.0, -1.3], [-3.1, -3.6], [-0.7, -4.5], [5.9, -0.9], [-5.1, -0.5], [-2.6, 5.2], [1.4, -4.8],[-0.7, 5.6], [-5.3, 2.1], [4.9, 2.6], [5.3, 0.9], [5.1, -1.2], [2.7, -4.4], [-2.0, -5.6], [-4.9, 3.2], [2.8, 5.3],[2.6, 3.9], [-0.0, 5.7], [-5.7, -1.8], [-1.1, -4.7], [-2.4, -3.8], [-1.1, 5.6], [5.3, -1.5], [-0.4, -5.8],[-4.5, -1.6], [-4.4, -3.7], [-4.3, 2.4], [0.1, 4.8], [-3.0, 3.8], [0.3, -5.8], [5.6, 0.5], [4.1, 3.6], [5.0, 1.5],[5.7, 1.5], [3.2, -4.1], [-1.7, -5.6], [-5.3, 0.9], [4.3, 3.0], [-5.4, 0.3], [-5.0, 0.8], [2.7, 5.1], [-5.0, 2.2],[-4.0, 3.0], [-4.4, -3.9], [-3.5, -3.9], [5.3, 1.5], [-4.2, 4.2], [-3.9, -4.0], [-4.7, -0.1], [3.7, -4.7],[-3.0, -4.7], [2.7, 4.4], [4.3, 2.0], [-3.6, -4.5], [5.5, 0.9], [-4.7, -2.8], [5.5, -2.2], [-5.1, -2.6],[-3.6, 3.1], [-3.2, -4.0], [-4.8, 1.3], [-5.5, -1.6], [4.1, -1.6], [-4.2, 3.6], [5.6, -1.4], [4.9, -3.3],[1.7, 4.9], [5.3, 2.5], [3.8, 2.8], [5.8, 0.7], [3.9, 2.6], [-2.1, -4.8], [5.2, 2.5], [-2.0, 4.3], [2.8, -4.1],[5.6, 0.8], [2.2, -5.2], [-1.1, 5.5], [4.2, 3.8], [-1.8, -5.2], [-3.4, -3.6], [3.7, -3.6], [-0.5, -4.8],[1.9, -5.6], [-1.1, 5.4], [2.3, 4.7], [0.0, -5.4], [2.1, -5.6], [4.8, -0.3], [-4.7, 2.9], [-3.8, 3.9], [0.9, -5.5],[-2.3, 3.6], [5.3, -2.5], [3.7, -4.6], [-5.0, 2.4], [0.0, -5.7], [0.2, -5.9]])
# 合并兩類(lèi)點(diǎn)
points = np.concatenate((class1_points, class2_points), axis=0)
print(points)
# 標簽0 表示類(lèi)別1 ,標簽1 表示類(lèi)別2
labels1 = np.zeros(len(class1_points))
labels2 = np.ones(len(class2_points))
labels = np.concatenate((labels1, labels2))
# 2.定義前向模型
class ModelClass(nn.Module):def __init__(self):super(ModelClass, self).__init__()self.layer1 = nn.Linear(2, 8)self.layer2 = nn.Linear(8, 32)self.layer3 = nn.Linear(32, 32)self.layer4 = nn.Linear(32, 2)self.dropout1 = nn.Dropout(p=0.1)self.dropout2 = nn.Dropout(p=0.1)self.dropout3 = nn.Dropout(p=0.1)def forward(self, x):x = torch.relu(self.layer1(x))x = self.dropout1(x)x = torch.relu(self.layer2(x))x = self.dropout2(x)x = torch.relu(self.layer3(x))x = self.dropout3(x)# 二分類(lèi)這里使用softmax加交叉熵x = torch.softmax(self.layer4(x), dim=1)return x
model = ModelClass()
# 3.定義損失函數和優(yōu)化器
lr = 0.001
# 定義交叉熵損失函數
cri = nn.CrossEntropyLoss()
# 定義優(yōu)化器
optimizer = optim.Adam(model.parameters(), lr=lr,weight_decay=0.01)
# 7.畫(huà)圖使用數據
x_min, x_max = points[:, 0].min() - 1, points[:, 0].max() + 1
y_min, y_max = points[:, 1].min() - 1, points[:, 1].max() + 1
step_size = 0.1
# 創(chuàng )建網(wǎng)格
xx, yy = np.meshgrid(np.arange(x_min, x_max, step_size),np.arange(y_min, y_max, step_size))
grid_points = np.c_[xx.ravel(), yy.ravel()]
print(grid_points)
# 7.創(chuàng )建三維圖形和右側的二維子圖
fig = plt.figure(figsize=(12, 8))
ax1_3d = fig.add_subplot(121, projection='3d')
ax2_2d = fig.add_subplot(122)
# 4.開(kāi)始迭代
epoches = 500
batch_size = 8
for epoch in range(1, epoches + 1):# 進(jìn)入訓練模式model.train()# 按照batch_size進(jìn)行迭代for batch_start in range(0, len(points), batch_size):batch_inputs = torch.tensor(points[batch_start:batch_start + batch_size, :], dtype=torch.float32)batch_labels = torch.tensor(labels[batch_start:batch_start + batch_size], dtype=torch.long)# 前向傳播outputs = model(batch_inputs)# 計算lossloss = cri(outputs, batch_labels)#添加正則化項# 反向傳播和優(yōu)化optimizer.zero_grad()loss.backward()optimizer.step()# 進(jìn)入驗證模式model.eval()# 5.顯示頻率設置fre_display = 20# 顯示與輸出if epoch % fre_display == 0 or epoch == 1:# 使用訓練好的模型預測網(wǎng)格點(diǎn)的標簽# 轉化為tensorgrid_points_tensor = torch.tensor(grid_points, dtype=torch.float32)# 模型預測pre_result = model(grid_points_tensor)# 取出第一類(lèi)的概率值pre_prob_one_class = pre_result[:, 0].reshape(xx.shape).detach().numpy()# 畫(huà)ax1_3d圖ax1_3d.cla()ax1_3d.scatter(class1_points[:, 0], class1_points[:, 1], np.ones_like(class1_points[:, 0]), c='blue',label='class 1')ax1_3d.scatter(class2_points[:, 0], class2_points[:, 1], np.zeros_like(class2_points[:, 0]), c='red',label='class 2')ax1_3d.legend()# 繪制三維表面圖ax1_3d.plot_surface(xx, yy, pre_prob_one_class, alpha=0.5)ax1_3d.contour(xx, yy, pre_prob_one_class, levels=[0.5], cmap='jet')ax1_3d.set_xlabel('feature 1')ax1_3d.set_ylabel('feature 2')ax1_3d.set_zlabel('label')ax1_3d.set_title('hyperplane')# 繪制2d圖ax2_2d.cla()ax2_2d.scatter(class1_points[:, 0], class1_points[:, 1], c='blue', label='Class 1')ax2_2d.scatter(class2_points[:, 0], class2_points[:, 1], c='red', label='Class 2')ax2_2d.contour(xx, yy, pre_prob_one_class, levels=[0.5], colors='black')plt.pause(1)
plt.show()
