文章目录

  • ​​Pytorch​​
  • ​​torch.nn​​
  • ​​torch.nn.modules​​

Pytorch

torch.nn

​torch.nn​​​包含两个比较重要的模块​​torch.nn.modules​​​和​​torch.nn.functional​​​,如果要扩展layer,建议使用​​modules​​​,因为modules保存着参数和buffer,如果不需要参数,建议使用​​funtional​​(如激活函数、pooling)。

torch.nn.modules

所有网络的基类,模型也应该继承这个类。

  • 可以将子模块赋予模块属性并调用如下​​self.conv1 = nn.Conv2d(1, 20, 5)​​。
  • 可以通过​​add_module(name, module)​​​增加新的​​child module​
  • ​children()​​返回当前模型 子模块的迭代器
  • ​.modules()​​返回一个包含 当前模型 所有模块的迭代器。
  • ​parameters(memo=None)​​返回一个 包含模型所有参数 的迭代器。
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)# submodule: Conv2d
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))

未完待续…