算是动态图的一个坑吧。记录loss信息的时候直接使用了输出的Variable。 应该不止我经历过这个吧... 久久不用又会不小心掉到这个坑里去...
for data, label in trainloader:
......
out = model(data)
loss = criterion(out, label)
loss_sum += loss # <--- 这里
......
运行着就发现显存炸了
观察了一下发现随着每个batch显存消耗在不断增大.. 参考了别人的代码发现那句loss一般是这样写 /(ㄒoㄒ)/~~
loss_sum += loss.data[0]
这是因为输出的loss的数据类型是Variable。
而PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。
如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大~那么消耗的显存也就越来越大~~
总之使用Variable的数据时候要非常小心。不是必要的话尽量使用Tensor来进行计算...
包括数据的输入时候,如果“过早”把数据丢到Variable里面去,那么可能也会被系统视为网络的一部分。所以,要投入的时候再把数据丢到Variable里面去吧~ 题外话
想更多感受动态图的话,可以通过Variable的grad_fun来观察到该Variable是通过什么运算得到的(前提是前面的Variable的required_grad置为True)。 大概是这样
>> >> z = x + y
>> z.grad_fn
out:
<AddBackward1 at 0x107286240>
推荐阅读:
【深度学习实战】pytorch中如何处理RNN输入变长序列padding 【机器学习基本理论】详解最大后验概率估计(MAP)的理解 【区块链】区块链最通俗入门教程
欢迎关注公众号学习交流~