前言:

接触pytorch这几个月来,一开始就对计算图的奥妙模糊不清,不知道其内部如何传播。这几天有点时间,就去翻阅了Githubpytorch Forum,还有很多个人博客(后面会给出链接),再加上自己的原本一些见解,现在对它的计算图有了更深层次的理解。

pytorch是非常好用和容易上手的深度学习框架,因为它所构建的是动态图,极大的方便了coding and debug。可是对于初学者而言,计算图是一个需要深刻理解的概念,在后期的搭建的神经网络都是基于计算图而设计的。

一、构建计算图

pytorch是动态图机制,所以在训练模型时候,每迭代一次都会构建一个新的计算图。而计算图其实就是代表程序中变量之间的关系。举个列子:

pytorch 计算图 查找变量 pytorch的计算图理解_标量

在这个运算过程就会建立一个如下的计算图:

pytorch 计算图 查找变量 pytorch的计算图理解_pytorch 计算图 查找变量_02

pytorch 计算图 查找变量 pytorch的计算图理解_动态图_03

在这个计算图中,节点就是参与运算的变量,在pytorch中是用Variable()变量来包装的,而图中的边就是变量之间的运算关系,比如:torch.mul()torch.mm()torch.div() 等等。

注意图中的 leaf_node,叶子结点就是由用户自己创建的Variable变量,在这个图中仅有a,b,c leaf_node。为什么要关注leaf_node?因为在网络backward时候,需要用链式求导法则求出网络最后输出的梯度,然后再对网络进行优化,如下就是网络的求导过程。


pytorch 计算图 查找变量 pytorch的计算图理解_迭代_04

pytorch 计算图 查找变量 pytorch的计算图理解_动态图_05

二、图的细节。

pytoch构建的计算图是动态图,为了节约内存,所以每次一轮迭代完之后计算图就被在内存释放,所以当你想要多次backward时候就会报如下错:


net = nn.Linear(3, 4)  # 一层的网络,也可以算是一个计算图就构建好了
input = Variable(torch.randn(2, 3), requires_grad=True)  # 定义一个图的输入变量
output = net(input)  # 最后的输出
loss = torch.sum(output)  # 这边加了一个sum() ,因为被backward只能是标量
loss.backward() # 到这计算图已经结束,计算图被释放了


上面这个程序是能够正常运行的,但是下面就会报错


net = nn.Linear(3, 4)
input = Variable(torch.randn(2, 3), requires_grad=True)
output = net(input)
loss = torch.sum(output)


loss.backward()
loss.backward()

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.

之所以会报这个错,因为计算图在内存中已经被释放。但是,如果你需要多次backward只需要在第一次反向传播时候添加一个标识,如下:


net = nn.Linear(3, 4)
input = Variable(torch.randn(2, 3), requires_grad=True)
output = net(input)
loss = torch.sum(output)
loss.backward(retain_graph=True) # 添加retain_graph=True标识,让计算图不被立即释放
loss.backward()


这样在第一次backward之后,计算图并不会被立即释放。

读到这里,可能你对计算图中的backward还是一知半解。例如上面提过backward只能是标量。那么在实际运用中,如果我们只需要求图中某一节点的梯度,而不是整个图的,又该如何做呢?下面举个例子,列子下面会给出解释。


x = Variable(torch.FloatTensor([[1, 2]]), requires_grad=True)  # 定义一个输入变量
y = Variable(torch.FloatTensor([[3, 4],
[5, 6]]))
loss = torch.mm(x, y)    # 变量之间的运算
loss.backward(torch.FloatTensor([[1, 0]]), retain_graph=True)  # 求梯度,保留图                                    
print(x.grad.data)   # 求出 x_1 的梯度
x.grad.data.zero_()  # 最后的梯度会累加到叶节点,所以叶节点清零
loss.backward(torch.FloatTensor([[0, 1]]))   # 求出 x_2的梯度
print(x.grad.data)        # 求出 x_2的梯度


结果如下:


3  5

[torch.FloatTensor of size 1x2]