BN

There are a few aspects of BN that deserve special attention:

  1. The pros and cons of the train/test inconsistency.
  2. Pitfalls in inference: the moving average.
  3. Pitfalls in training: batch size and batch distribution.
  4. Pitfalls in fine-tuning: parameterization, data distribution, etc.
  5. Pitfalls in implementation: a multi-purpose BN implementation.
  6. Improvements such as GN and precise-BN.

BN behaves differently at training time and at test time.

During training, BN normalizes each batch with that batch's own statistics, and meanwhile updates its running statistics with an exponential moving average (EMA). At test time, BN does not compute batch statistics; it normalizes with the running (EMA) statistics accumulated during training.
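To make the train/test difference concrete, here is a minimal sketch (my own example, not from the post) of how a PyTorch BatchNorm layer behaves in the two modes:

import torch

# Minimal sketch (my own example): the same BN layer in train() vs eval() mode.
bn = torch.nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 3 + 5           # a batch whose statistics are far from N(0, 1)

bn.train()
y_train = bn(x)                          # normalized with this batch's own statistics;
                                         # running_mean / running_var are updated by EMA
bn.eval()
y_eval = bn(x)                           # normalized with the accumulated running statistics

print(y_train.mean(dim=0))               # roughly 0 per channel (batch stats were used)
print(y_eval.mean(dim=0))                # generally far from 0 (running stats were used)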

  1. When \(\lambda\) is too small, the EMA is not a reasonable approximation of the true statistics (see the update rule written out below).
  2. When \(\lambda\) is too large, the EMA needs many iterations to converge.
  3. When the model, or the data, is unstable, the EMA can also cause problems.
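To make the role of \(\lambda\) concrete (the explicit update rule is my own addition; the post does not write it out), the running statistics are typically updated once per batch as

\[\mu \leftarrow \lambda\,\mu + (1-\lambda)\,\mu_{\mathcal{B}}, \qquad \sigma^2 \leftarrow \lambda\,\sigma^2 + (1-\lambda)\,\sigma^2_{\mathcal{B}},\]

where \(\mu_{\mathcal{B}}, \sigma^2_{\mathcal{B}}\) are the current batch's statistics. A small \(\lambda\) makes the estimate depend almost entirely on the last few (noisy) batches; a large \(\lambda\) makes it adapt slowly, so many iterations are needed. (In PyTorch, the momentum argument plays the role of \(1-\lambda\).)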

Using Precise-BatchNorm

Keep using the EMA, but with a relatively large \(\lambda\): freeze the model weights and run forward passes for many iterations.

I have not really read the paper Rethinking "Batch" in BatchNorm, but I did go through the precise-BN code:

In case some of the functions used in it are unfamiliar, a quick note first.

itertools.islice() takes a slice of an iterator, and it consumes the iterator as it does so.
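A tiny sketch (my own example) of what "consumes the iterator" means here:

import itertools

# Minimal sketch: islice yields the first items and advances the underlying iterator.
it = iter(range(10))
print(list(itertools.islice(it, 3)))   # [0, 1, 2]
print(next(it))                        # 3 -- the first three elements are gone from `it`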

running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)

running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)

This is actually easy to see: it is equivalent to summing first and then averaging.
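Spelling that out (my own derivation, not in the original): write \(a_n\) for the accumulated value after batch \(n\) and \(x_n\) for that batch's statistic. The update is

\[a_n = a_{n-1} + \frac{x_n - a_{n-1}}{n} = \frac{(n-1)\,a_{n-1} + x_n}{n},\]

and by induction \(a_n = \frac{1}{n}\sum_{k=1}^{n} x_k\), i.e. exactly the plain average of the per-batch means/variances, computed incrementally without storing a separate sum.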

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import itertools

import torch

BN_MODULE_TYPES = (
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
)


@torch.no_grad()
def update_bn_stats(model, data_loader, num_iters: int = 200):
    """
    Recompute and update the batch norm stats to make them more precise. During
    training both BN stats and the weight are changing after every iteration, so
    the running average can not precisely reflect the actual stats of the
    current model.
    In this function, the BN stats are recomputed with fixed weights, to make
    the running average more precise. Specifically, it computes the true average
    of per-batch mean/variance instead of the running average.

    Args:
        model (nn.Module): the model whose bn stats will be recomputed.

            Note that:

            1. This function will not alter the training mode of the given model.
               Users are responsible for setting the layers that needs
               precise-BN to training mode, prior to calling this function.

            2. Be careful if your models contain other stateful layers in
               addition to BN, i.e. layers whose state can change in forward
               iterations. This function will alter their state. If you wish
               them unchanged, you need to either pass in a submodule without
               those layers, or backup the states.
        data_loader (iterator): an iterator. Produce data as inputs to the model.
        num_iters (int): number of iterations to compute the stats.
    """
    bn_layers = get_bn_modules(model)

    if len(bn_layers) == 0:
        return

    # In order to make the running stats only reflect the current batch, the
    # momentum is disabled.
    # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
    # Setting the momentum to 1.0 to compute the stats without momentum.
    momentum_actual = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.momentum = 1.0

    # Note that running_var actually means "running average of variance"
    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]

    for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
        model(inputs)

        for i, bn in enumerate(bn_layers):
            # Accumulates the bn stats.
            running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
            running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
            # We compute the "average of variance" across iterations.
    assert ind == num_iters - 1, (
        "update_bn_stats is meant to run for {} iterations, "
        "but the dataloader stops at {} iterations.".format(num_iters, ind)
    )

    for i, bn in enumerate(bn_layers):
        # Sets the precise bn stats.
        bn.running_mean = running_mean[i]
        bn.running_var = running_var[i]
        bn.momentum = momentum_actual[i]


def get_bn_modules(model):
    """
    Find all BatchNorm (BN) modules that are in training mode. See
    cvpack2.modeling.nn_utils.precise_bn.BN_MODULE_TYPES for a list of all modules that are
    included in this search.

    Args:
        model (nn.Module): a model possibly containing BN modules.

    Returns:
        list[nn.Module]: all BN modules in the model.
    """
    # Finds all the bn layers.
    bn_layers = [
        m
        for m in model.modules()
        if m.training and isinstance(m, BN_MODULE_TYPES)
    ]
    return bn_layers
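
For reference, a minimal usage sketch (my own example, not part of the original file; the toy model and the generator standing in for a data loader are placeholders):

import torch

# Hypothetical usage sketch: a toy model and a toy "data loader" just to show
# the calling convention of update_bn_stats.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.BatchNorm1d(32),
    torch.nn.ReLU(),
)
loader = (torch.randn(8, 16) for _ in range(200))    # any iterator yielding model inputs

model.train()                                         # BN layers must be in training mode
update_bn_stats(model, loader, num_iters=200)         # recompute running_mean / running_var
model.eval()                                          # evaluate with the precise statistics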