在训练深度学习模型时,有时候由于各种原因(如断电、代码错误、计算机故障等),训练过程可能会被中断。为了能够从中断处继续训练,我们需要保存模型的当前状态(包括模型参数和优化器状态),并在恢复训练时加载这些状态。

在PyTorch中,我们可以使用torch.save()torch.load()函数来保存和加载模型。而为了保存和加载优化器状态,我们可以使用state_dict()函数和load_state_dict()函数。

下面是一个示例,展示了如何在中断后继续训练一个简单的神经网络模型:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(10, 1)  # 输入层和输出层都是一维的线性变换

    def forward(self, x):
        x = self.fc(x)
        return x

# 创建模型实例
model = Net()

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 检查是否存在之前保存的模型状态
if torch.cuda.is_available():
    checkpoint = torch.load('model_checkpoint.pth')
else:
    checkpoint = torch.load('model_checkpoint.pth', map_location=torch.device('cpu'))

# 加载模型参数和优化器状态
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
losses = checkpoint['losses']

# 模型继续训练
for epoch in range(start_epoch, num_epochs):
    # 训练代码...

    # 保存模型状态
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'losses': losses
    }
    torch.save(checkpoint, 'model_checkpoint.pth')

在上面的代码中,我们首先定义了一个简单的神经网络模型,并创建了损失函数和优化器。然后,我们检查是否存在之前保存的模型状态,如果存在,则加载模型参数和优化器状态。接下来,我们使用一个循环来继续训练模型,训练过程中我们可以保存一些状态,比如损失值。最后,我们在每个训练周期结束后保存模型的当前状态,以便下次训练时可以从中断处继续。

需要注意的是,在保存和加载模型时,我们需要指定模型所在的设备(如GPU或CPU)。如果模型在GPU上训练,但在加载模型时没有可用的GPU,则需要使用map_location参数将模型加载到CPU上。

通过这种方式,我们可以在中断后恢复训练,而无需重新开始。这对于大规模的深度学习任务尤其重要,因为这样可以节省时间和计算资源,同时保留了模型在中断前的训练进度。