算是动态图的一个坑吧。记录loss信息的时候直接使用了输出的Variable。 应该不止我经历过这个吧... 久久不用又会不小心掉到这个坑里去...

for data, label in trainloader:
    ......
    out = model(data)
    loss = criterion(out, label)
    loss_sum += loss     # <--- 这里
    ......

运行着就发现显存炸了

观察了一下发现随着每个batch显存消耗在不断增大.. 参考了别人的代码发现那句loss一般是这样写 /(ㄒoㄒ)/~~

loss_sum += loss.data[0]

这是因为输出的loss的数据类型是Variable。

而PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。

如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大~那么消耗的显存也就越来越大~~

总之使用Variable的数据时候要非常小心。不是必要的话尽量使用Tensor来进行计算...

包括数据的输入时候,如果“过早”把数据丢到Variable里面去,那么可能也会被系统视为网络的一部分。所以,要投入的时候再把数据丢到Variable里面去吧~ 题外话

想更多感受动态图的话,可以通过Variable的grad_fun来观察到该Variable是通过什么运算得到的(前提是前面的Variable的required_grad置为True)。 大概是这样

>> >> z = x + y
>> z.grad_fn
out:
    <AddBackward1 at 0x107286240>

推荐阅读:

【深度学习实战】pytorch中如何处理RNN输入变长序列padding 【机器学习基本理论】详解最大后验概率估计(MAP)的理解 【区块链】区块链最通俗入门教程

      欢迎关注公众号学习交流~