EMA的定义

在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。
指数移动平均(Exponential Moving Average)也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重的平均方法。
指数移动平均EMA以及Pytorch实现_深度学习

深度学习中的EMA

上面讲的是广义的ema定义和计算方法,特别的,在深度学习的优化过程中,θt 是t时刻的模型权重weights,vt是t时刻的影子权重(shadow weights)。在梯度下降的过程中,会一直维护着这个影子权重,但是这个影子权重并不会参与训练。基本的假设是,模型权重在最后的n步内,会在实际的最优点处抖动,所以我们取最后n步的平均,能使得模型更加的鲁棒。

代码

class ModelEMA(object):
def __init__(self, args, model, decay, device='', resume=''):
self.ema = deepcopy(model)
self.ema.eval()
self.decay = decay
self.device = device
self.wd = args.lr * args.wdecay
if device:
self.ema.to(device=device)
self.ema_has_module = hasattr(self.ema, 'module')
if resume:
self._load_checkpoint(resume)
for p in self.ema.parameters():

p.requires_grad_(False)

def _load_checkpoint(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path)
assert isinstance(checkpoint, dict)
if 'ema_state_dict' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['ema_state_dict'].items():
if self.ema_has_module:
name = 'module.' + k if not k.startswith('module') else k
else:
name = k
new_state_dict[name] = v
self.ema.load_state_dict(new_state_dict)

def update(self, model):
needs_module = hasattr(model, 'module') and not self.ema_has_module
with torch.no_grad():
msd = model.state_dict()
for k, ema_v in self.ema.state_dict().items():
if needs_module:
k = 'module.' + k
model_v = msd[k].detach()
if self.device:
model_v = model_v.to(device=self.device)
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
# weight decay
if 'bn' not in k:
msd[k] = msd[k] * (1. - self.wd)

Reference

​https://fyubang.com/2019/06/01/ema/​