训练过程中保存模型参数,就不怕断电了——沃资基·索德

防止断电重跑,另一方面可以观察不同迭代次数模型的表现;在训练完成以后,我们需要保存模型参数值用于后续的测试过程。所以,保存的对象包含网络参数值、优化器参数值、epoch值等等。

一、定义一个容易识别的网络

在正式介绍模型的保存和加载之前,我们首先定义一个基本的网络Net,它只包含一个全连接层:


class


将全连接的权重w和偏差b分别设置为10和1,全连接的计算方式如下:


pytorch tensor保存为RGB图像 pytorch保存参数_全连接


假设输入x=1,可以知道y值为11:


pytorch tensor保存为RGB图像 pytorch保存参数_8145v5 参数_02


测试一下输出是不是11,代码如下:


x


输出:tensor([[11.]], grad_fn=<AddmmBackward>),说明上述计算是正确的。不采用参数随机初始化,而是用特殊的数值初始化,是因为我们希望重载模型的时候,能够从特殊数值一眼判断出保存和重载过程是否正确,也可以把权重设置为一张图片数值,然后判断加载的参数值能不能恢复原图。

二、保存Net的参数值

Net的参数值存储在其state_dict(状态字典)属性中,我们查看一下net的state_dict包含哪些参数:


print


我们将会得到net包含的所有参数名称与参数值


pytorch tensor保存为RGB图像 pytorch保存参数_加载_03


包含一个weight和一个bias,对应的值分别是10和1,和我们之前定义的全连接层一致。我们需要保存的就是这个state_dict,保存的函数为“torch.save()”,参数是我们需要保存的dict和存储路径


torch.save(obj=net.state_dict(), f="models/net.pth")


现在,同级目录models下将会出现net.pth文件,pth文件中的内容就是net的参数名称和值对应的state_dict,如下:


pytorch tensor保存为RGB图像 pytorch保存参数_加载_04


三、加载Net参数值并用于新的模型

区别仅仅是Model参数初始值和Net不同,代码如下:


class


这里将Model的初始值权重w和偏差都设置为0,查看其state_dict:


model


得到的w和b值与预期相同,均为0,如下:


pytorch tensor保存为RGB图像 pytorch保存参数_全连接_05


现在,我们将model对象的参数值设置为net.pth中的值,需要使用“model.load_state_dict()”函数重置model的参数值为"torch.load(models/ net.pth)"中的参数值,如下:


model


至此,model的w和b值就不再是0了,而是net中w和b对应的10和1,如下:


pytorch tensor保存为RGB图像 pytorch保存参数_加载_03


其中参数值重载的核心函数为“model.load_state_dict()”,每个继承自nn.Module的网络都能通过这个函数设定参数值。

四、优化器与epoch的保存

保存优化器参数值和epoch值的主要目的是用于继续训练,保存的流程依旧是先“torch.save()”再“torch.load_state_dict()”,我们首先定义一个Adam优化器、一个任意的epoch值与net如下:


net = Net()
Adam = optim.Adam(params=net.parameters(), lr=0.001, betas=(0.5, 0.999))
epoch = 96


现在,创建一个字典来保存所有的对象,并用save函数保存这个字典


all_states


所有的对象都被保存到models文件夹下了:


pytorch tensor保存为RGB图像 pytorch保存参数_加载_07


可以使用load()函数把所有的对象再次提取出来:


reload_states


得到的所有参数如下:


pytorch tensor保存为RGB图像 pytorch保存参数_权重_08


五、总结

pytorch中state_dict()和load_state_dict()函数配合使用可以实现状态的获取与重载,load()和save()函数配合使用可以实现参数的存储与读取。其中最重要的部分是“字典”的概念,因为参数对象的存储是需要“名称”——“值”对应(即键值对),读取时也是通过键值对读取的。

参考:

https://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/