例1
假设我们函数是,我们要求对的导数,应该如何用pytorch来求解。
上面的计算图表示,先计算括号内部的加法,再计算乘法。计算顺序是:,,。
用代码来表示:
import torch
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.mul(w, x) # a = w * x
b = torch.add(w, 1) # b = w + 1
y = torch.mul(a, b) # y = a * b
# a.retain_grad() 用于保持非叶子节点的梯度
y.backward() #反向传播求导
print(a.grad_fn)
print(x.grad)
print(w.grad)
print(a.is_leaf, b.is_leaf, y.is_leaf, w.is_leaf, x.is_leaf)
'''
<MulBackward0 object at 0x7fe3280a1160>
tensor([2.])
tensor([6.])
False False False True True
'''
pytorch创建的计算图分为叶子节点和非叶子节点。每一个节点都是一个tensor, tensor具有属性requires_grad
(记录该tensor是否要求梯度),is_leaf
(记录是否是叶子节点), grad_fn
(记录创建该tensor的方法)。
requires_grad
: 如果创建这个tensor的输入中,至少有一个tensor的requires_grad=True
,那么新创建的这个tensor的requires_grad=True
。在上面这个例子中,是由相加得到的,的requires_grad=True
,所以的requires_grad=True
。is_leaf
:叶子节点是你所有手动创建的tensor,在这个例子中,叶子节点是。注意,叶子节点的requires_grad
并不一定是True。在本例中,也是叶子节点,但是我们也可以将其requires_grad=False
。再比如你创建的神经网络的参数是叶子节点,其requires_grad=True
,比如你创建的模型的输入,虽然requires_grad=False
,但是也是叶子节点。grad_fn
:记录创建这个tensor的方法,比如本例中,的grad_fn
就是MulBackward0
,表示由乘法得到。
还有一点需要注意,只有叶子节点的梯度在backward()之后是不被销毁的,非叶子节点的梯度在backward()之后是被销毁的,可以在y.backward() 之后打印的梯度试试。如果想保持飞叶子节点的梯度,在backward()之前,使用a.retain_grad()
。
例2
上面这个例子我们看到,也是叶子节点,这似乎有点难以理解,如果我们再创建一个tensor,,那么是否是叶子节点?实践可知,c.is_leaf=False
。
例3
我们再举一个例子。观察下面的代码:
mport torch
w = torch.tensor([1.], requires_grad=False)
x = torch.tensor([2.], requires_grad=True)
print(x.is_leaf)
x = x + 1
# x.data.copy_(x.data+1)
print(x.is_leaf)
print(x.requires_grad)
# w = w+w
a = torch.add(w, x) # a = w + x
b = torch.add(w, 1) # b = w + 1
y = torch.mul(a, b) # y = a * b
y.backward() #对y进行反向传播
print(x.grad)
'''
True
False
True
None
'''
上面代码表达的意思是,先对进行自加1再算。但是打印结果发现,在自加1之后,就不是叶子节点了。
通过例2和例3是否可得到结论,对于requires_grad=True
的叶子节点来说,对其做任何改动,得到的新的tensor都不是叶子节点了。这个结论为时过早,将x=x+1替换为x.data.copy_(x.data+1),可发现,x还是叶子节点。所以如果想对叶子节点的值进行改变,应该用copy_
函数,而不是直接用等号改变。
tensor.data.copy_()
:只会改变tensor的data值,而不会改变is_leaf, requires_grad
等其他属性值。
参考:pytorch——计算图与动态图机制