Pytorch的特色之一是提供构建动态计算图的框架,这样网络结构就不是一成不变的了,甚至可以在运行时修正它们。
【Tensor】
Tensor是Pytorch中的基本对象,意思为张量,表示多维的矩阵,是Pytorch中的基本操作对象之一,Tensor的声明和获取size如下:

import torch
x = torch.Tensor(5,3)
x.size()

Tensor与Numpy的array可以进行相互啊转换,转换函数为:

x = torch.rand(5,3)
# torch -> numpy:
y = x.numpy()

# numpy -> torch:
x = y.torch.from_numpy()

【Variable】
Variable是Pytorch的一个基本对象,可以把它理解为是对Tensor的一个封装,Variable用于放入计算图中以进行前向传播、反向传播和自动求导。在一个Variable中有三个重要属性:data、grad和creator。其中,data表示包含的Tensor数据部分;grad表示传播方向的梯度,这个属性是延迟分配的,而且仅允许进行一次;creator表示创建这个Variable的Function的引用,该引用用于回溯整个创建链路。

from torch.autograd import Variable
x = torch.rand(4)
x = variable(x, requires_grad = true)
y = x * 3
grad_variables = torch.floattensor([1,2,3,4])
y.backward(grad_variables)
print(x.grad)

-------tensor([3., 6., 9., 12.])--------

对于y.backward(grad_variables),grad_variables就是y求导时的梯度参数,由于autograd仅用于标量,因此当y不是标量且在声明时使用了requires_grad=true时,必须指定grad_variables参数,在完成原始的反向传播后得到的梯度会用这个grad_variables进行修正,然后将结果保存至Variable的grad中,grad_variables的长度要与y一致。在深度学习中求导与梯度有关,因此grad_variables一般会定义为类似[1, 0.1, 0.01, 0.001]表示梯度的方向,取较小的值不会对求导效率有影响。

【CUDA】
如果安装了支持CUDA版本的Pytorch,就可以启用显卡运算了。torch.cuda用于设置和运行CUDA操作,它会记录当前选择的GPU,并且分配所有CUDA张量将默认在上面创建,可以使用torch.cuda.device上下文管理器更改所选设备。不过,一旦张量被分配,可以直接对其进行操作,而不考虑所选择的设备,结果将始终放在与张量相关的设备上。默认情况下, 不支持跨GPU操作,唯一的例外时copy_()。除非启用对等存储器的访问,否则对于分布不同设备上的张量,任何启动操作的尝试豆浆引发错误。

torch.cuda.is_available()

【模型的保存与加载】
Python中对于模型数据的保存和加载操作都是引用Python内置的pickle包,使用pickle.dump()和pickle.load()方法。

torch.save(model, 'model.pkl')  #保存整个模型
model = torch.load('model.pkl')  #加载整个模型
torch.save(alexnet.state_dict(), 'params.pkl')  #保存网络中的参数
alexnet.load_state_dict(torch.load('params.pkl'))  #加载网络中的参数

在torchvision.models模块里,Pytorch提供了一些常用的模型:AlexNet、VGG、ResNet等,可以使用torch.util.model_zoo来预加载它们,具体设置通过参数pretrained=True来实现。

import torchvision.models as models
alexnet = models.alexnet(pretrained=True)

加载这类预训练模型的过程中,还可以进行微处理。

pretrained_dict = model_zoo.load_url(model_url['resnet134'])
model_dict = model.state_dict()
pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict}  #将pretrained_dict里不属于model_dict的键剔除掉
model_dict.update(pretrained_dict) #更新现有的model_dict
model.load_state_dict(model_dict)