PyTorch has supported Stochastic Weight Averaging (SWA) since version 1.6.0.
That is, after the conventional training of an object detector with its initial and ending learning rates, train it for an extra 12 epochs using a cyclical learning rate within each epoch, and then average these 12 checkpoints to obtain the final detection model.
The theory behind SWA is that averaging multiple models sampled along an SGD optimization trajectory yields a final model with better generalization, as illustrated in the figure below.
SGD tends to converge to flat regions of the loss surface. Because the weight space is high-dimensional, most of a flat region's volume lies near its boundary, so SGD usually ends up at the edge of the flat region. By averaging the weights of several SGD iterates, SWA moves the solution toward the center of the flat region.
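The averaging itself is just an element-wise mean over checkpoint parameters. A minimal sketch (the tiny synthetic "checkpoints" below are hypothetical stand-ins for saved `state_dict`s):

```python
import torch

def average_checkpoints(state_dicts):
    """Element-wise mean of several checkpoints' parameters (the core of SWA)."""
    avg = {k: v.clone().float() for k, v in state_dicts[0].items()}
    for sd in state_dicts[1:]:
        for k in avg:
            avg[k] += sd[k].float()
    for k in avg:
        avg[k] /= len(state_dicts)
    return avg

# Hypothetical example: three tiny "checkpoints" of a single weight tensor
ckpts = [{"weight": torch.full((2, 2), float(i))} for i in (1, 2, 3)]
avg = average_checkpoints(ckpts)
print(avg["weight"])  # every entry is 2.0, the mean of 1, 2, 3
```

In practice the averaged `state_dict` is loaded back into the model, and any BatchNorm running statistics must be recomputed with a forward pass over the training data.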
Object Detection
SWA Object Detection experimented on detection tasks with different numbers of extra epochs and with either a fixed learning rate or a cyclic cosine-annealing learning rate, and found that 12 extra epochs with the cyclic cosine-annealing schedule worked best.
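The per-epoch cyclic cosine schedule can be sketched as follows; the `lr_max`/`lr_min` values here are illustrative defaults, not the paper's exact settings:

```python
import math

def cyclic_cosine_lr(epoch_frac, lr_max=0.01, lr_min=0.0001):
    """Cosine-annealed learning rate restarted at the start of every epoch.

    epoch_frac: progress within the current epoch, in [0, 1).
    lr_max / lr_min are illustrative values, not the paper's exact settings.
    """
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * epoch_frac))

# Each of the 12 extra epochs sweeps from lr_max down toward lr_min,
# then jumps back up to lr_max at the next epoch boundary; one checkpoint
# is saved at the end of each epoch and all 12 are averaged.
for epoch in range(12):
    start_lr = cyclic_cosine_lr(0.0)   # lr_max at the epoch start
    end_lr = cyclic_cosine_lr(0.999)   # approaches lr_min at the epoch end
```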
PyTorch example code:
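A minimal sketch using `torch.optim.swa_utils`; the toy linear model, synthetic data, and hyperparameters are illustrative placeholders, not settings from the paper:

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR

# Toy regression data and model (placeholders for a real training setup)
X = torch.randn(64, 10)
y = X.sum(dim=1, keepdim=True)
model = nn.Linear(10, 1)
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X, y), batch_size=16
)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
swa_model = AveragedModel(model)               # keeps the running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05)  # SWA learning-rate schedule
swa_start = 5                                  # epoch at which SWA kicks in
loss_fn = nn.MSELoss()

for epoch in range(10):
    for xb, yb in loader:
        optimizer.zero_grad()
        loss_fn(model(xb), yb).backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)  # fold current weights into the average
        swa_scheduler.step()
    else:
        scheduler.step()

# Recompute BatchNorm running statistics for the averaged model
# (a no-op here, since this toy model has no BN layers)
torch.optim.swa_utils.update_bn(loader, swa_model)
```

At inference time, `swa_model` is used in place of `model`; it wraps the original module and forwards calls to the averaged weights.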
References
1. Stochastic Weight Averaging blog
2. Stochastic Weight Averaging in PyTorch
3. Stochastic Weight Averaging docs
4. SWA Object Detection
5. Averaging Weights Leads to Wider Optima and Better Generalization