How do you make only specified variables propagate gradients backward in PyTorch?
(In other words: how do you keep a specified variable from participating in backpropagation?)
Consider the following formulas (reconstructed from the verification script below), and suppose we want the derivative of L with respect to xvar:

out1 = xvar * y1var
out2 = xvar * y2var
L = (out1 - out2)^2
In (1), the derivative of L with respect to xvar includes both the out1 branch and the out2 branch;
In (2), the derivative of L with respect to xvar includes only the out2 branch, because out1 has requires_grad=False;
In (3), the derivative of L with respect to xvar includes only the out1 branch, because out2 has requires_grad=False.
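Before running the script, the three gradients can be worked out by hand with the chain rule. This is a sketch using the values from the verification script (x=1, y1=2, y2=7); a detached branch is treated as a constant when differentiating:

```python
# L = (out1 - out2)**2 with out1 = x*y1, out2 = x*y2
x, y1, y2 = 1.0, 2.0, 7.0
out1, out2 = x * y1, x * y2          # 2.0 and 7.0

# (1) both branches flow: dL/dx = 2*(out1 - out2)*(y1 - y2)
g1 = 2 * (out1 - out2) * (y1 - y2)   # 2 * (-5) * (-5) = 50
# (2) out1 detached (constant): dL/dx = 2*(out1 - out2)*(-y2)
g2 = 2 * (out1 - out2) * (-y2)       # 2 * (-5) * (-7) = 70
# (3) out2 detached (constant): dL/dx = 2*(out1 - out2)*(y1)
g3 = 2 * (out1 - out2) * y1          # 2 * (-5) * 2 = -20
print(g1, g2, g3)                    # 50.0 70.0 -20.0
```

These are the values xvar.grad should show in cases (1), (2), and (3) respectively.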
Verification:
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed May 23 10:02:04 2018
@author: hy
"""
import torch
from torch.autograd import Variable

print("Pytorch version: {}".format(torch.__version__))

x = torch.Tensor([1])
xvar = Variable(x, requires_grad=True)
y1 = torch.Tensor([2])
y2 = torch.Tensor([7])
y1var = Variable(y1)
y2var = Variable(y2)

# (1) both branches stay in the graph
print("For (1)")
print("xvar requires_grad: {}".format(xvar.requires_grad))
print("y1var requires_grad: {}".format(y1var.requires_grad))
print("y2var requires_grad: {}".format(y2var.requires_grad))
out1 = xvar * y1var
print("out1 requires_grad: {}".format(out1.requires_grad))
out2 = xvar * y2var
print("out2 requires_grad: {}".format(out2.requires_grad))
L = torch.pow(out1 - out2, 2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()

# (2) detach out1, so only the out2 branch contributes
print("For (2)")
print("xvar requires_grad: {}".format(xvar.requires_grad))
print("y1var requires_grad: {}".format(y1var.requires_grad))
print("y2var requires_grad: {}".format(y2var.requires_grad))
out1 = xvar * y1var
print("out1 requires_grad: {}".format(out1.requires_grad))
out2 = xvar * y2var
print("out2 requires_grad: {}".format(out2.requires_grad))
out1 = out1.detach()
print("after out1.detach(), out1 requires_grad: {}".format(out1.requires_grad))
L = torch.pow(out1 - out2, 2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()

# (3) detach out2, so only the out1 branch contributes
print("For (3)")
print("xvar requires_grad: {}".format(xvar.requires_grad))
print("y1var requires_grad: {}".format(y1var.requires_grad))
print("y2var requires_grad: {}".format(y2var.requires_grad))
out1 = xvar * y1var
print("out1 requires_grad: {}".format(out1.requires_grad))
out2 = xvar * y2var
print("out2 requires_grad: {}".format(out2.requires_grad))
out2 = out2.detach()
# fixed: the original printed out1.requires_grad here by mistake
print("after out2.detach(), out2 requires_grad: {}".format(out2.requires_grad))
L = torch.pow(out1 - out2, 2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
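The script above targets the old Variable API. Since PyTorch 0.4, Variable is deprecated and tensors carry requires_grad directly; the same experiment can be sketched in modern PyTorch like this (grad_x is a helper name introduced here for illustration):

```python
import torch

def grad_x(detach_out1=False, detach_out2=False):
    """Return dL/dx for L = (x*y1 - x*y2)**2 at x=1, y1=2, y2=7,
    optionally detaching one branch from the autograd graph."""
    x = torch.tensor([1.0], requires_grad=True)
    y1 = torch.tensor([2.0])
    y2 = torch.tensor([7.0])
    out1 = x * y1
    out2 = x * y2
    if detach_out1:
        out1 = out1.detach()  # out1 becomes a constant: no gradient flows through it
    if detach_out2:
        out2 = out2.detach()  # out2 becomes a constant: no gradient flows through it
    L = torch.pow(out1 - out2, 2)
    L.backward()
    return x.grad.item()

print(grad_x())                   # (1) both branches:     50.0
print(grad_x(detach_out1=True))   # (2) out2 branch only:  70.0
print(grad_x(detach_out2=True))   # (3) out1 branch only: -20.0
```

Because a fresh x is created on each call, no manual xvar.grad.data.zero_() between cases is needed.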