pytorch中的BN层简介

  • 简介
  • pytorch里BN层的具体实现过程
  • momentum的定义
  • 冻结BN及其统计数据


简介

BN层在训练过程中,会将一个Batch的中的数据转变成正太分布,在推理过程中使用训练过程中的参数对数据进行处理,然而网络并不知道你是在训练还是测试阶段,因此,需要手动的加上,需要在测试和训练阶段使用如下函数。

model.train() or model.eval()

在Pytorch中,BN层的类的参数有:

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

其中affine定义了BN层的参数γ和β是否是可学习的(不可学习默认是常数1和0). BN层的统计数据更新是在每一次训练阶段model.train()后的forward()方法中自动实现的,而不是在梯度计算与反向传播中更新optim.step()中完成。

pytorch里BN层的具体实现过程

BN层的输出Y与输入X之间的关系是:
pytorch bn层 pytorch bn层 eps_测试阶段

  • 其中gamma、beta为可学习参数(在pytorch中分别改叫weight和bias),训练时通过反向传播更新;
  • 而running_mean、running_var则是在前向时先由X计算出mean和var,再由mean和var以动量momentum来更新running_mean和running_var。
  • 所以在训练阶段,running_mean和running_var在每次前向时更新一次;在测试阶段,则通过net.eval()固定该BN层的running_mean和running_var,此时这两个值即为训练阶段最后一次前向时确定的值,并在整个测试阶段保持不变。

momentum的定义

Pytorch中的BN层的动量平滑和常见的动量法计算方式是相反的,默认的momentum=0.1

冻结BN及其统计数据

冻结BN的方式是在模型训练时,把BN单独挑出来,重新设置其状态为eval(在model.train()之后覆盖training状态)

  • 方法一
def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
      m.eval()

model.apply(set_bn_eval)
  • 方法二,重写module中的train()方法
def train(self, mode=True):
        """  Override the default train() to freeze the BN parameters  """
        super(MyNet, self).train(mode)
        if self.freeze_bn:
            print("Freezing Mean/Var of BatchNorm2D.")
            if self.freeze_bn_affine:
                print("Freezing Weight/Bias of BatchNorm2D.")
        if self.freeze_bn:
            for m in self.backbone.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    if self.freeze_bn_affine:
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False