LR_scheduler
LR_scheduler是用于调节学习率lr的,在代码中,我们经常看到这样的一行代码
scheduler.step()
通过这行代码来实现lr的更新的,那么其中的底层原理是什么呢?我们就进去看看
在pytorch代码中,各种类型scheduler大多基于_LRScheduler类
我们就看看这个类的step()函数到底干了什么
def step(self, epoch=None):
# Raise a warning if old pattern is detected
# https://github.com/pytorch/pytorch/issues/20124
if self._step_count == 1:
if not hasattr(self.optimizer.step, "_with_counter"):
warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
"initialization. Please, make sure to call `optimizer.step()` before "
"`lr_scheduler.step()`. See more details at "
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
# Just check if there were two first lr_scheduler.step() calls before optimizer.step()
elif self.optimizer._step_count < 1:
warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
"In PyTorch 1.1.0 and later, you should call them in the opposite order: "
"`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
"will result in PyTorch skipping the first value of the learning rate schedule. "
"See more details at "
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
self._step_count += 1
class _enable_get_lr_call:
def __init__(self, o):
self.o = o
def __enter__(self):
self.o._get_lr_called_within_step = True
return self
def __exit__(self, type, value, traceback):
self.o._get_lr_called_within_step = False
with _enable_get_lr_call(self):
if epoch is None:
self.last_epoch += 1 # 表示上一个epoch
values = self.get_lr() # 计算学习率lr
else:
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
self.last_epoch = epoch # 直接跳转到参数epoch
if hasattr(self, "_get_closed_form_lr"):
values = self._get_closed_form_lr()
else:
values = self.get_lr()
# 对所有参数权重对应的lr进行修改
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
param_group, lr = data
param_group['lr'] = lr # 修改学习率
self.print_lr(self.verbose, i, lr, epoch)
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
由上代码可知,step()的目的是计算计算新的学习率并对旧学习率进行修改,其中最重要的函数是get_lr(),我们接下来对这个函数进行分析
def get_lr(self):
# Compute learning rate using chainable form of the scheduler
raise NotImplementedError
由于_LRScheduler类是一个基类,不表示任何学习率策略,我们选择最简单的StepLR学习策略(学习率阶梯式下降)来分析
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.", UserWarning)
if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): # 表示在一个阶梯上,不改变学习率
return [group['lr'] for group in self.optimizer.param_groups]
return [group['lr'] * self.gamma # 对所有学习率乘以一个小于1的小数,减小学习率
for group in self.optimizer.param_groups]
如果step()函数中有epoch参数,需要直接跳转到指定epoch,那么直接乘以固定的小数就不对了,这时候就需要函数_get_closed_form_lr()
def _get_closed_form_lr(self):
return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
for base_lr in self.base_lrs]
其中self.last_epoch之前在基类_LRScheduler中已经被赋值了self.last_epoch = epoch ,所以直接根据学习率变化公式计算处理
由上可知,get_lr()和_get_closed_form_lr()就是具体的学习率计算方法
这样,我们就可以根据不同的学习率计算方式设计自己的scheduler类了。
warmup
初始训练阶段,直接使用较大学习率会导致权重变化较大,出现振荡现象,使得模型不稳定,加大训练难度。而使用Warmup预热学习率,在开始的几个epoch,逐步增大学习率,如下图所示,使得模型逐渐趋于稳定,等模型相对稳定后再选择预先设置的基础学习率进行训练,使得模型收敛速度变得更快,模型效果更佳
上图中的0-10epoch阶段就是一个warmup操作,学习率缓慢增加,10之后就是常规的学习率递减算法
原理上很简单,接下来从代码上进行分析,warmup可以有两种构成方式:
- 对已有的scheduler类进行包装重构
- 直接编写新的类
对于第一种情况,我们以CosineAnnealingLR类为例
scheduler = CosineAnnealingLR( # pytorch自带的类
optimizer=optimizer,
eta_min=0.000001,
T_max=(epochs - warmup_epoch) * n_iter_per_epoch)
scheduler = GradualWarmupScheduler( # 重构的类
optimizer,
multiplier=args.warmup_multiplier,
after_scheduler=scheduler,
warmup_epoch=warmup_epoch * n_iter_per_epoch)
其中,GradualWarmupScheduler就是基于CosineAnnealingLR重构的类,我们首先查看类中step()函数
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.last_epoch = epoch
if epoch > self.warmup_epoch: # 超过warmup范围,使用自带的类,也就是CosineAnnealingLR
self.after_scheduler.step(epoch - self.warmup_epoch) # 注意CosineAnnealingLR要从0epoch开始,所以需要减去
else:
super(GradualWarmupScheduler, self).step(epoch) # warmup范围,使用当前重构类的()
对于超过warmup范围,直接使用CosineAnnealingLR类,比较简单
对于warmup范围类,使用当前重构类的step()函数,因为也是继承于_LRScheduler类,所以step()同样是运用到get_lr()
def get_lr(self):
if self.last_epoch > self.warmup_epoch: # 超过warmup范围,使用CosineAnnealingLR类的get_lr()
return self.after_scheduler.get_lr()
else: # warmup范围,编写线性变化,也就是上图中0-10区间内的直线
return [base_lr / self.multiplier * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epoch + 1.)
for base_lr in self.base_lrs]
对于第二种情况,step()无需构造,直接继承_LRScheduler,需要构造的是get_lr()函数,其中warmup范围外的代码与自带的CosineAnnealingLR类中get_lr()代码一样。