在训练深度学习模型时,有时候由于各种原因(如断电、代码错误、计算机故障等),训练过程可能会被中断。为了能够从中断处继续训练,我们需要保存模型的当前状态(包括模型参数和优化器状态),并在恢复训练时加载这些状态。
在PyTorch中,我们可以使用torch.save()
和torch.load()
函数来保存和加载模型。而为了保存和加载优化器状态,我们可以使用state_dict()
函数和load_state_dict()
函数。
下面是一个示例,展示了如何在中断后继续训练一个简单的神经网络模型:
import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(10, 1) # 输入层和输出层都是一维的线性变换
def forward(self, x):
x = self.fc(x)
return x
# 创建模型实例
model = Net()
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 检查是否存在之前保存的模型状态
if torch.cuda.is_available():
checkpoint = torch.load('model_checkpoint.pth')
else:
checkpoint = torch.load('model_checkpoint.pth', map_location=torch.device('cpu'))
# 加载模型参数和优化器状态
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
losses = checkpoint['losses']
# 模型继续训练
for epoch in range(start_epoch, num_epochs):
# 训练代码...
# 保存模型状态
checkpoint = {
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'losses': losses
}
torch.save(checkpoint, 'model_checkpoint.pth')
在上面的代码中,我们首先定义了一个简单的神经网络模型,并创建了损失函数和优化器。然后,我们检查是否存在之前保存的模型状态,如果存在,则加载模型参数和优化器状态。接下来,我们使用一个循环来继续训练模型,训练过程中我们可以保存一些状态,比如损失值。最后,我们在每个训练周期结束后保存模型的当前状态,以便下次训练时可以从中断处继续。
需要注意的是,在保存和加载模型时,我们需要指定模型所在的设备(如GPU或CPU)。如果模型在GPU上训练,但在加载模型时没有可用的GPU,则需要使用map_location
参数将模型加载到CPU上。
通过这种方式,我们可以在中断后恢复训练,而无需重新开始。这对于大规模的深度学习任务尤其重要,因为这样可以节省时间和计算资源,同时保留了模型在中断前的训练进度。