真实的国产乱ⅩXXX66竹夫人,五月香六月婷婷激情综合,亚洲日本VA一区二区三区,亚洲精品一区二区三区麻豆

成都創(chuàng)新互聯(lián)網(wǎng)站制作重慶分公司

pytorch如何在網(wǎng)絡(luò)中添加可訓(xùn)練參數(shù)和修改預(yù)訓(xùn)練權(quán)重文件-創(chuàng)新互聯(lián)

這篇文章主要介紹了pytorch如何在網(wǎng)絡(luò)中添加可訓(xùn)練參數(shù)和修改預(yù)訓(xùn)練權(quán)重文件,具有一定借鑒價(jià)值,感興趣的朋友可以參考下,希望大家閱讀完這篇文章之后大有收獲,下面讓小編帶著大家一起了解一下。

網(wǎng)站的建設(shè)成都創(chuàng)新互聯(lián)專(zhuān)注網(wǎng)站定制,經(jīng)驗(yàn)豐富,不做模板,主營(yíng)網(wǎng)站定制開(kāi)發(fā).小程序定制開(kāi)發(fā),H5頁(yè)面制作!給你煥然一新的設(shè)計(jì)體驗(yàn)!已為工商代辦等企業(yè)提供專(zhuān)業(yè)服務(wù)。

實(shí)踐中,針對(duì)不同的任務(wù)需求,我們經(jīng)常會(huì)在現(xiàn)成的網(wǎng)絡(luò)結(jié)構(gòu)上做一定的修改來(lái)實(shí)現(xiàn)特定的目的。

假如我們現(xiàn)在有一個(gè)簡(jiǎn)單的兩層感知機(jī)網(wǎng)絡(luò):

# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.optim as optim
 
x = Variable(torch.FloatTensor([1, 2, 3])).cuda()
y = Variable(torch.FloatTensor([4, 5])).cuda()
 
class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
 
    return x
 
model = MLP().cuda()
 
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
for t in range(500):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)
  print(t, loss.data[0])
  model.zero_grad()
  loss.backward()
  optimizer.step()
 
print(model(x))

現(xiàn)在想在前向傳播時(shí),在relu之后給x乘以一個(gè)可訓(xùn)練的系數(shù),只需要在__init__函數(shù)中添加一個(gè)nn.Parameter類(lèi)型變量,并在forward函數(shù)中乘以該變量即可:

class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
    # the para to be added and updated in train phase, note that NO cuda() at last
    self.coefficient = torch.nn.Parameter(torch.Tensor([1.55]))
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.coefficient * x
    x = self.linear2(x)
 
    return x

注意,Parameter變量和Variable變量的操作大致相同,但是不能手動(dòng)調(diào)用.cuda()方法將其加載在GPU上,事實(shí)上它會(huì)自動(dòng)在GPU上加載,可以通過(guò)model.state_dict()或者model.named_parameters()函數(shù)查看現(xiàn)在的全部可訓(xùn)練參數(shù)(包括通過(guò)繼承得到的父類(lèi)中的參數(shù)):

print(model.state_dict().keys())
for i, j in model.named_parameters():
  print(i)
  print(j)

輸出如下:

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
linear1.weight
Parameter containing:
-0.3582 -0.0283 0.2607
 0.5190 -0.2221 0.0665
-0.2586 -0.3311 0.1927
-0.2765 0.5590 -0.2598
 0.4679 -0.2923 -0.3379
[torch.cuda.FloatTensor of size 5x3 (GPU 0)]
 
linear1.bias
Parameter containing:
-0.2549
-0.5246
-0.1109
 0.5237
-0.1362
[torch.cuda.FloatTensor of size 5 (GPU 0)]
 
linear2.weight
Parameter containing:
-0.0286 -0.3045 0.1928 -0.2323 0.2966
 0.2601 0.1441 -0.2159 0.2484 0.0544
[torch.cuda.FloatTensor of size 2x5 (GPU 0)]
 
linear2.bias
Parameter containing:
-0.4038
 0.3129
[torch.cuda.FloatTensor of size 2 (GPU 0)]

這個(gè)參數(shù)會(huì)在反向傳播時(shí)與原有變量同時(shí)參與更新,這就達(dá)到了添加可訓(xùn)練參數(shù)的目的。

如果我們有原先網(wǎng)絡(luò)的預(yù)訓(xùn)練權(quán)重,現(xiàn)在添加了一個(gè)新的參數(shù),原有的權(quán)重文件自然就不能加載了,我們需要修改原權(quán)重文件,在其中添加我們的新變量的初始值。

調(diào)用model.state_dict查看我們添加的參數(shù)在參數(shù)字典中的完整名稱(chēng),然后打開(kāi)原先的權(quán)重文件:

a = torch.load("OldWeights.pth") a是一個(gè)collecitons.OrderedDict類(lèi)型變量,也就是一個(gè)有序字典,直接將新參數(shù)名稱(chēng)和初始值作為鍵值對(duì)插入,然后保存即可。

a = torch.load("OldWeights.pth")
 
a["layer1.0.coefficient"] = torch.FloatTensor([1.2])
a["layer1.1.coefficient"] = torch.FloatTensor([1.5])
 
torch.save(a, "Weights.pth")

現(xiàn)在權(quán)重就可以加載在修改后的模型上了。

感謝你能夠認(rèn)真閱讀完這篇文章,希望小編分享的“pytorch如何在網(wǎng)絡(luò)中添加可訓(xùn)練參數(shù)和修改預(yù)訓(xùn)練權(quán)重文件”這篇文章對(duì)大家有幫助,同時(shí)也希望大家多多支持創(chuàng)新互聯(lián)成都網(wǎng)站設(shè)計(jì)公司,關(guān)注創(chuàng)新互聯(lián)成都網(wǎng)站設(shè)計(jì)公司行業(yè)資訊頻道,更多相關(guān)知識(shí)等著你來(lái)學(xué)習(xí)!

另外有需要云服務(wù)器可以了解下創(chuàng)新互聯(lián)scvps.cn,海內(nèi)外云服務(wù)器15元起步,三天無(wú)理由+7*72小時(shí)售后在線,公司持有idc許可證,提供“云服務(wù)器、裸金屬服務(wù)器、網(wǎng)站設(shè)計(jì)器、香港服務(wù)器、美國(guó)服務(wù)器、虛擬主機(jī)、免備案服務(wù)器”等云主機(jī)租用服務(wù)以及企業(yè)上云的綜合解決方案,具有“安全穩(wěn)定、簡(jiǎn)單易用、服務(wù)可用性高、性價(jià)比高”等特點(diǎn)與優(yōu)勢(shì),專(zhuān)為企業(yè)上云打造定制,能夠滿足用戶豐富、多元化的應(yīng)用場(chǎng)景需求。


分享題目:pytorch如何在網(wǎng)絡(luò)中添加可訓(xùn)練參數(shù)和修改預(yù)訓(xùn)練權(quán)重文件-創(chuàng)新互聯(lián)
網(wǎng)頁(yè)路徑:http://weahome.cn/article/dgjigc.html

其他資訊

在線咨詢

微信咨詢

電話咨詢

028-86922220(工作日)

18980820575(7×24)

提交需求

返回頂部