一、什么是Batch Normalization(BN)层
BN层是数据归一化的方法,一般都是在深度神经网络中,激活函数之前,我们在训练神经网络之前,都会对数据进行预处理,即减去均值和方差的归一化操作。但是随着网络深度的加深,函数变的越来越复杂,每一层的输出的数据分布变化越来越大。BN的作用就是把数据强行拉回我们想要的比较好的正态分布下。这样可以在一定程度上避免梯度爆炸或者梯度消失的问题,加快收敛的速度。
二、BN是如何操作的
BN工作流程:
1、计算当前batch_size数据的均值和方差;
2、将当前batch内的数据,normalize到均值为0,方差为1的分布上;
3、然后对normalized后的数据进行缩放和平移,缩放和平移的是可学习的。
BN层的状态包含4个参数:
- weight,即缩放操作的\gamma
- bias,缩放操作的\beta
- running_mean,训练阶段在全训练数据上统计的均值,测试阶段会用到
- running_var,训练阶段在全训练数据上统计的方差,测试阶段会用到
weight和bias这两个参数需要训练,而running_mean、running_val不需要训练,它们只是训练阶段的统计值。
训练时,均值、方差分别是该批次内数据相应维度的均值与方差;
推理时,均值、方差是基于所有批次的期望计算所得,
BN层的使用:
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
momentum:估计running_mean和 ruuning_var时使用
affine:如果为true,就学习参数,否则不学习。
track_running_stats:如果为true,持续跟踪running_mean,running_var
三、BN最大的作用
加快收敛。
四、为什么要freeze BN层
BN层在CNN网络中大量使用,可以看上面bn层的操作,第一步是计算当前batch的均值和方差,也就是bn依赖于均值和方差,如果batch_size太小,计算一个小batch_size的均值和方差,肯定没有计算大的batch_size的均值和方差稳定和有意义,这个时候,还不如不使用bn层,因此可以将bn层冻结。另外,我们使用的网络,几乎都是在imagenet上pre-trained,完全可以使用在imagenet上学习到的参数。
五、如何freeze BN层
有两种,一种是在训练阶段,将bn层变为eval(),即不更新统计running_mean和runn_val;另一种是需要将bn层的requires grad = False,BN层的参数weight和bias不优化,更新。
frozen: stop gradient update in norm layers
norm_eval: stop moving average statistics update in norm layers
def train(self, model=True):
freeze_bn = False
freeze_bn_affine = False
supper(myNet, self).train(mode)
if freeze_bn:
print ("Freezing Mean/Var of BatchNorm2D.")
for m in self.model.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if freeze_bn_affine:
print ("Freezeing Weight/Bias of BatchNorm2D.")
if freeze_bn_affine:
m.weight.requires_grad = False
m.bias.requires_grad = False
两种freeze BN的方式,如何使用,我们来看一下《MMDetection: Open MMLab Detection Toolbox and Benchmark》里面的相关实验,在mmdetection中eval = True, requires grad = True是默认设置,不更新BN层的统计信息,也就是running_var和running_mean,但是优化更新其weight和bias的学习参数。
当GPU内存限制时,batch_size只能设置很小,例如1或者2,因此会对BN层进行freeze。上面的table6 时eval和requires_grad不同组合时的效果,该实验使用的网络是Mask R-CNN。Table 6显示,lr schedulex1时,更新统计信息,即eval = False,会损害网络性能,当eval = True,对权重weight 和 bias是否更新,即requires_grad = False or True,影响不大;但是lr_schedulex2中,eval=True, requires_grad = True 效果最好。