BatchNorm已经作为常用的手段应用在深度学习中,效果显著,加快了训练速度,保证了梯度的流动,防止过拟合,降低网络对初始化权重敏感程度,减少对调参的要求。今天自己就做个总结,记录一下BatchNorm,并从Pytorch源码来看BatchNorm。

BN的灵感来源

讲解BN之前,我们需要了解BN是怎么被提出的。在机器学习领域,数据分布是很重要的概念。如果训练集和测试集的分布很不相同,那么在训练集上训练好的模型,在测试集上应该不奏效(比如用ImageNet训练的分类网络去在灰度医学图像上finetue在测试,效果应该不好)。对于神经网络来说,如果每一层的数据分布都不一样,后一层的网络则需要去学习适应前一层的数据分布,这相当于去做了domian的adaptation,无疑增加了训练难度,尤其是网络越来越深的情况。

实际上,确实如此,不同层的输出的分布是有差异的。BN的那篇论文中指出,不同层的数据分布会往激活函数的上限或者下限偏移。论文称这种偏移为internal Covariate Shift,internal指的是网络内部。
BN就是为了解决偏移的,解决的方式也很简单,就是让每一层的分布都normalize到标准高斯分布。(这里的每一层并不准确,BN是根据划分数据的集合去做Normalization,不同的划分方式也就出现了不同的Normalization,如GN,LN,IN)

BN是如何做的

BN的行为根据训练和测试不同行为而不同。

在训练中使用BN

BN中的B是batchsize,就是说BN基于mini-batch SGD,首先训练数据必须是一个批次,含有多个样本。
假设特征pytorch conv2d源码 pytorch bn源码_Pytorch,在通道维度上求均值和方差
pytorch conv2d源码 pytorch bn源码_方差_02
pytorch conv2d源码 pytorch bn源码_pytorch conv2d源码_03

举个栗子。

pytorch conv2d源码 pytorch bn源码_深度学习_04


pytorch conv2d源码 pytorch bn源码_pytorch conv2d源码_05


以此类推

pytorch conv2d源码 pytorch bn源码_方差_06


计算得到一个四维的向量,作为这个层的pytorch conv2d源码 pytorch bn源码_方差_07

然后,

pytorch conv2d源码 pytorch bn源码_深度学习_08

pytorch conv2d源码 pytorch bn源码_方差_09

这里第二个公式绝对是有用的,pytorch conv2d源码 pytorch bn源码_深度学习_10是要学习的参数,参与训练。为什么需要这个公式呢。

因为我们在第一个公式中减去了均值和除以方差,降低了非线性能力。第二个公式就是去补偿非线性能力的。甚至通过学习均值和方差,BN是可以还原回原来的特征。

我们在一些源码中,可以看到带有BN的卷积层,bias设置为False,就是因为即便卷积之后加上了Bias,在BN中也是要减去的,所以加Bias带来的非线性就被BN一定程度上抵消了。需要补偿。

然后再接激活函数即可。这就是完成了BN训练过程

在测试中使用BN

在训练中使用BN是要计算均值和方差的,而这两个统计量是随着样本不同而变化的。如果在测试中依然遵循这样的方式,那么无疑同一个样本在不同的batch中预测会得到不一样的概率值,这显然是不对的。
在测试中,BN根据训练过程中计算的均值和方差,使用滑动平均去记录这些值。在测试的时候统一使用记录下来的滑动平均值,这一点可以从源码中看出来。所以在TensorFlow或者Pytorch中,BN的代码分别有is_training 和 self.training字段,就是为了区别使用行为的。

举个例子。
在训练过程的第t次迭代中,我们得到了均值u和方差sigma。那么u和sigma将使用如下方式记录下来。
pytorch conv2d源码 pytorch bn源码_BatchNorm_11

pytorch conv2d源码 pytorch bn源码_Pytorch_12
最后得到的pytorch conv2d源码 pytorch bn源码_pytorch conv2d源码_13pytorch conv2d源码 pytorch bn源码_方差_14作为最终的值保存下来。供测试环节使用。

BN的好处以及原因

加速训练

输出分布向着激活函数的上下限偏移,带来的问题就是梯度的降低,(比如说激活函数是sigmoid),通过normalization,数据在一一个合适的分布空间,经过激活函数,仍然得到不错的梯度。梯度好了自然加速训练。

降低参数初始化敏感

以往模型需要设置一个不错的初始化才适合训练,加了BN就不用管这些了,现在初始化方法中随便选择一个用,训练得到的模型就能收敛。

PyTorch中BN源码解析

nn.BatchNorm2d继承_BatchNorm,BatchNorm2d仅仅负责查看tensor的尺寸是否符合要求。直接跳到_BatchNorm中。

self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)

上面是构造函数的一部分,其中running_mean,running_var就是用来记录均值和方差的滑动平均值的。都是用buffer来申请储存空间不是用parameter,是因为这不参与训练。weight和bias就是pytorch conv2d源码 pytorch bn源码_方差_15,pytorch conv2d源码 pytorch bn源码_BatchNorm_16,是训练参数。

def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is self.momentum set to
        # (when it is available) only so that if gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

在前向过程中,可以看到self.training,如果是训练中使用BN,需要设置exponential_average_factor ,这个值就是我们上面讲解测试中使用bN用到的0.9。