Pytorch是如果为函数中的每个变量求导的呢?这就引入了一个特殊的概念:计算图

 

假如我们有一个函数: z = 2*a + 2*b

我们不妨设:               f1=2*a,   f2=2*b

那么:                          z = f1 + f2

其函数的内部运算过程可以用下图表示出来:

PyTorch 计算量统计 pytorch查看计算图_PyTorch 计算量统计

 

 

当我们进行前向传播时(即通过自变量a,b的值计算函数值z时),pytorch会在内部存储结构中自动的构建出如上图这样的计算图,第一层的第一个圈为2,由于其是常数,没有梯度,故我们将其画成白色,a和b都是变量,我们反向传播时都要对其求导(故在创建a,b时我们都会设置requires_grad=True),故将其涂成绿色,第二层的两个圈涂成蓝色,这里涂成蓝色的意思是这两个圈表示的是操作,由图中我们可以看出

1. 常数a会和变量a相乘,相乘的结果会存在f1中

2. 变量a会和变量b相乘,相乘的结果会存在f2中

3.f1会和f2相加,相加的结果就是函数z

 

我们从z开始往回走,每一条线都可表示该线上的输出(该条线段右边的圈)对于该线上的输入(该条线段左边的圈)的求导,如图

PyTorch 计算量统计 pytorch查看计算图_子节点_02

 

 

 如果我们想要函数关于某个变量的导数,我们只需观察终点到这个变量的路径,同一条路径上的导数相乘,如果存在多条路径可到达,那么就将这些路径相加。

比如z到a共有两条路径:z先是到f1,f1再到a,那么我们就得到  (z对f1的导数)* (f1对a的导数)

                  还有一条是:z先是到f2,f2再到a,那么我们就得到  (z对f2的导数)* (f2对a的导数)

 共有两条路径,那么z整个对a的导数就是:    (z对f1的导数)* (f1对a的导数)   +    (z对f2的导数)* (f2对a的导数),

按照上述规则,可得:

    

PyTorch 计算量统计 pytorch查看计算图_反向传播_03

 

 

注意,在python代码中当我们创建变量a,b并将其requires_grad设置为True后,由a和b复合运算产生的新变量(如f1和f2)乃至往后所有基于a,b运算产生的变量,其requires_grad都会自动的默认为True,我们不用显式的调用函数去设置这些变量的requires_grad为True。

 

如果我们希望某些计算不在计算图中(因为计算图主要是用来做反向传播的,而有些变量我们并不要其导数,我们只要正向的计算过程,那么我们就可以把这些变量从计算图中分离出来,以加快反向传播的速度)那么分离某些计算就有两种方法:

方法1:使用with torch.no_grad(),在with torch.no_grad()代码块下,所有创建的变量requires_grad都会默认设置为False

 

方法2:调用tensor变量的detach()方法,调用后,该变量的requires_grad会设置为False

 

具体代码如下

import torch

a = torch.tensor(3. , requires_grad=True)
b = torch.tensor(4. , requires_grad=True)

f1 = 2 * a
f2 = a * b
z = f1 + f2

print(type(f1))           #f1由常量和张量a运算构成,故f1也是张量,输出为tensor
print(f1.requires_grad)   #输出为True,因为默认设置了requires_grad为True
print(f2.requires_grad)   #输出为True,原因与f1同理
z.backward()
print(f1.grad)            #输出为None,这里我怀疑反向传播只会计算叶子节点变量的梯度,而对中间变量的梯度会计算(因为叶子节点的梯度需要中间节点的梯度)但中间节点计算后的梯度并不会保留

print(a.grad)      #输出为6.0
print(b.grad)      #输出为3.0


#分离计算的方法:

#法1
with torch.no_grad():
    f3 = a * b
    print(f3.requires_grad)    #输出为False,因为f3变量是在with torch.no_grad()代码块中创建,因此其requires_grad默认为True
    print(f2.requires_grad)    #输出为True,因为f2变量是之间已经创建好的,因此不受with torch.no_grad()代码块影响

#法2
a1 = a.detach()
print(a.requires_grad)        #输出为True,a并不受影响
print(a1.requires_grad)       #输出为False,调用detach方法相当于a1把这个变量,从计算图中分离出来,因此其requires_grad默认为False

 

 

问题:既然分离计算后变量的特点都是requires_grad为False,那么我们直接用requires_grad_(False)函数将变量的 requires_grad设置为False是不是等价于将该变量分离计算呢?

 

 

TRANSLATE with x

English

Arabic

Hebrew

Polish

Bulgarian

Hindi

Portuguese

Catalan

Hmong Daw

Romanian

Chinese Simplified

Hungarian

Russian

Chinese Traditional

Indonesian

Slovak

Czech

Italian

Slovenian

Danish

Japanese

Spanish

Dutch

Klingon

Swedish

English

Korean

Thai

Estonian

Latvian

Turkish

Finnish

Lithuanian

Ukrainian

French

Malay

Urdu

German

Maltese

Vietnamese

Greek

Norwegian

Welsh

Haitian Creole

Persian

 

 

TRANSLATE with

COPY THE URL BELOW

Back

EMBED THE SNIPPET BELOW IN YOUR SITE


Enable collaborative features and customize widget: Bing Webmaster Portal

Back