PyTorch has supported Stochastic Weight Averaging (SWA) since version 1.6.0.

That is, after the conventional training of an object detector with an initial learning rate lr1 and an ending learning rate lr2, train it for an extra 12 epochs using the cyclical learning rates (lr1, lr2) for each epoch, and then average these 12 checkpoints as the final detection model.
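The checkpoint-averaging step at the end is plain elementwise parameter averaging. Below is a minimal sketch of that step; the checkpoint paths (swa_epoch_1.pth and so on) are hypothetical names for illustration, and averaging every state-dict entry including buffers is an assumption, not the paper's exact procedure:

import torch

# Hypothetical paths for the 12 extra-epoch checkpoints (illustrative names)
ckpt_paths = [f"swa_epoch_{i}.pth" for i in range(1, 13)]

# Load every checkpoint's state dict onto the CPU
state_dicts = [torch.load(p, map_location="cpu") for p in ckpt_paths]

# Average each parameter/buffer elementwise across the 12 checkpoints
avg_state = {
    k: torch.stack([sd[k].float() for sd in state_dicts]).mean(dim=0)
    for k in state_dicts[0]
}

# detector = ...  # build the detector, then load the averaged weights
# detector.load_state_dict(avg_state)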

The theory behind SWA is that averaging multiple models sampled along an SGD optimization trajectory yields a final model with better generalization, as illustrated in the figure below.

[Figure: averaging weights sampled at multiple points along the SGD trajectory]


SGD tends to converge to flat regions of the loss surface. Because the weight space is high-dimensional, most of the volume of a flat region lies near its boundary, so SGD usually ends up near the boundary of such a region. By averaging the weights of multiple SGD iterates, SWA moves the solution toward the center of the flat region.

[Figure: SWA solution near the center of a flat region, SGD solutions near its boundary]
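Concretely, the averaging is just a running mean over the collected weight snapshots. This is the default update rule that PyTorch's AveragedModel applies on each update_parameters call; the function below is a sketch of the rule itself, not the library's internal code:

def swa_update(w_avg, w, n):
    # Running mean after n snapshots have already been averaged:
    # w_avg <- (n * w_avg + w) / (n + 1)
    return (n * w_avg + w) / (n + 1)

# e.g. folding a new weight tensor into the average of 3 earlier snapshots:
# new_avg = swa_update(w_avg, w, 3)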

Object Detection

SWA Object Detection experimented on detection tasks with different numbers of extra training epochs and with either a fixed learning rate or a cyclical cosine-annealing schedule, and found that 12 extra epochs with cyclical cosine annealing works best.

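For reference, the per-epoch cyclical cosine schedule can be written out directly. A minimal sketch, where lr1 and lr2 stand for the initial and ending learning rates of the original schedule; the 0.02/0.0002 defaults below are illustrative assumptions, not necessarily the paper's settings:

import math

def cyclical_cosine_lr(it, iters_per_epoch, lr1=0.02, lr2=0.0002):
    # Restart the cosine decay at the start of every epoch: within each
    # epoch the lr falls from lr1 to lr2 along a cosine curve.
    t = (it % iters_per_epoch) / max(iters_per_epoch - 1, 1)
    return lr2 + 0.5 * (lr1 - lr2) * (1.0 + math.cos(math.pi * t))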

PyTorch example code:

import torch
from torch.optim.swa_utils import SWALR

loader, optimizer, model, loss_fn = ...
# swa_model maintains the running average of the model's weights
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        # After swa_start, fold the current weights into the running
        # average and hold the learning rate at the constant SWA value
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)
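Note that the final update_bn call is needed because the BatchNorm running statistics accumulated during training track the online model's weights, not the averaged ones; it performs one pass over the loader to recompute those statistics for swa_model.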

References

1. Stochastic Weight Averaging blog
2. Stochastic Weight Averaging in PyTorch
3. Stochastic Weight Averaging docs
4. SWA Object Detection
5. Averaging Weights Leads to Wider Optima and Better Generalization