BBN: Bilateral-Branch Network with Cumulative Learning for Long-Tailed Visual Recognition

一、背景介绍

1.长尾效应

长尾数据目标检测 大数据长尾效应_特征提取

长尾分布比较常见,指的是数据集中少量类别占总数据集比重较大。如果使用带有长尾分布的数据集去直接训练分类网络的话,就会导致对于占比较大的类别能够较好的预测,占比较小的类别不能够较好的预测。这样做,模型整体性能就会下降。

2.问题

常见的解决长尾效应的方法有resampling(采样)和cost-sensitive re-weighting(给损失函数添加权重)——这两个统称class re-balancing方法。

  • resampling:over-sampling——重复采集少类数据,可能会导致对少数类的过拟合。under-sampling——放弃主要类别数据,削弱深度网络泛化能力。
  • re-weighting:通常在损失函数中为尾类的训练样本分配较大的权重。

通过采用class rebalancing方法可以更改训练集逼近测试集的分布并且使得训练更加关注尾部类别,这也就是为什么class rebalancing方法可以改善长尾效应。

然而,尽管重新平衡方法具有良好的最终预测,但我们认为这些方法仍然有不利影响,即它们也会在一定程度上意外地损害学习到的深层特征(即表示学习)的表示能力。具体来说,重采样存在尾部数据过拟合(通过过采样)的风险,并且在数据不平衡极度时也存在欠拟合整个数据分布(通过欠采样)的风险。对于重新加权,它会通过直接改变甚至反转数据呈现频率来扭曲原始分布。

长尾数据目标检测 大数据长尾效应_深度学习_02

作者通过实验论证了上述观点。具体做法为首先任意选取一个分类网络,将其分为两部分:特征提取部分+分类层,接着使用三种方法训练该网络,三种方法分别为正常训练、resampling、re-weighting。这样我们便得到了分别用三种方法训练得到的分类网络,选取其特征提取部分并固定住,然后从头开始训练整个分类网络(此时特征提取部分参数不再变化,就像迁移学习中的冻结网络)。然后继续使用三种方法训练得到上图的实验结果。从实验结果不难看出当时用resampling和re-weighting使得最终模型的分类性能提高但是特征提取能力变低。(感觉这里度量特征提取部分性能的方法有点怪怪的,作者采取的方式是如果特征提取层提取到的特征性能较好,那么使用相同的训练分类层方法训练得到的整个模型性能就越好,类似于特征≈数据。故横向对比时,使用CE方法训练得到的模型性能最好)。

从这个实验结果图来看,可以得出一个结论支撑作者的模型,也就是当我们使用CE方法训练整个模型,然后使用RS/RW方法从头开始训练分类层,此时我们可以得到最好的结果。从实验结果来看,这个训练方法是明显好于直接使用RW/RS来训练整个网络的性能。

具体来说,对于每个类,我们首先通过平均此类的表示来计算质心向量。然后,计算这些表示与其质心之间的 2 距离,然后将其平均作为类内表示的紧凑性的度量。如果一个类的平均距离很小,则意味着该类的表示在特征空间中聚集得很近。我们在训练阶段将表示的 2-范数归一化为 1,以避免特征尺度的影响。我们根据分别使用交叉熵 (CE)、重新加权 (RW) 和重新采样 (RS) 学习的表征来报告结果。

基于此,我们便可以知道,想要得到一个更好的分类模型去处理长尾问题,我们便需要充分利用上述这个实验结果。

除了这个,作者还在补充材料里面利用另一种方法论证了使用re-balancing方法会导致特征提取能力变差。这次使用的评价标准是类内特征之间的距离,如果类内特征之间的距离越近,即平均距离越近,得到的特征越好。

长尾数据目标检测 大数据长尾效应_机器学习_03

基于上图,我们便可以得出下图这个结论。

长尾数据目标检测 大数据长尾效应_机器学习_04

综上,在本文中,作者揭示了re-balancing的机制是显着促进分类器学习,但会在一定程度上意外地损害所学深层特征的表示能力。

二、How class re-balancing strategies work?

这个部分上面两个实验已经介绍过了。即re-balancing可以提高分类器的性能但是同时会退化特征提取器性能。

三、方法介绍

1.网络结构图

长尾数据目标检测 大数据长尾效应_深度学习_05

作者已经发现使用re-balancing方法可以提高模型性能,但是使用该方法会导致特征提取层模型性能下降。故作者想要结合这两个方法的优势,来进一步提高模型性能。作者的办法是使用一种累计学习策略,先学习通用模式,然后逐渐关注尾部数据。

这里简单介绍一下这个网络的流程。首先通过两个部分共享的双分支网络,输入一个是具有长尾分布的数据集长尾数据目标检测 大数据长尾效应_特征提取_06,另一个是通过reverse操作后的数据集长尾数据目标检测 大数据长尾效应_机器学习_07。特征提取网络采用的是残差网络,最后一个残差网络不共享权重。GAP指的是全局平均池化。最终两个层的输出特征分别为长尾数据目标检测 大数据长尾效应_特征提取_08长尾数据目标检测 大数据长尾效应_长尾数据目标检测_09。中间的长尾数据目标检测 大数据长尾效应_机器学习_10是一个参数,其随着epoch的增加而改变。最终的输出结果为长尾数据目标检测 大数据长尾效应_深度学习_11,其表达式如下
长尾数据目标检测 大数据长尾效应_机器学习_12
由于是分类,最终的输出要经过softmax操作得到概率值长尾数据目标检测 大数据长尾效应_论文阅读_13

长尾数据目标检测 大数据长尾效应_机器学习_10的大小变化来看,随着训练的增加模型参数更新越来越依赖于红色框的分支。

损失函数的表达式如下:
长尾数据目标检测 大数据长尾效应_长尾数据目标检测_15
这里简单说一下,第二个分支输入数据、权重共享策略以及如何获得以及长尾数据目标检测 大数据长尾效应_机器学习_10的更新策略。

  • 第二个分支主要通过每个类别的概率长尾数据目标检测 大数据长尾效应_特征提取_17来进行采集,长尾数据目标检测 大数据长尾效应_特征提取_17表达式如下。

长尾数据目标检测 大数据长尾效应_深度学习_19

这里长尾数据目标检测 大数据长尾效应_长尾数据目标检测_20表示所有类别中类别数量最大的样本数,长尾数据目标检测 大数据长尾效应_长尾数据目标检测_21表示第长尾数据目标检测 大数据长尾效应_深度学习_22类样本对应的样本数。

数据生成步骤如下:1.计算出长尾数据目标检测 大数据长尾效应_深度学习_23;2.根据长尾数据目标检测 大数据长尾效应_深度学习_23随机抽取一个类长尾数据目标检测 大数据长尾效应_深度学习_22;3.均匀地从第长尾数据目标检测 大数据长尾效应_深度学习_22类中抽取一个样本进行替换。重复这个过程,直到获得一个batch的样本。

  • 在 BBN 中,两个分支在经济上共享相同的残差网络结构,如图 3 所示。我们使用 ResNets [12] 作为我们的骨干网络,例如 ResNet-32 和 ResNet-50。详细地说,除了最后一个残差块之外,两个分支网络共享相同的权重。共享权重有两个好处:一方面,传统学习分支(蓝色框)的良好学习表示可以有利于重新平衡分支(红色框)的学习。另一方面,共享权重将大大降低推理阶段的计算复杂度。
  • 长尾数据目标检测 大数据长尾效应_长尾数据目标检测_27
    这里的长尾数据目标检测 大数据长尾效应_特征提取_28表示当前时刻的epoch数,长尾数据目标检测 大数据长尾效应_论文阅读_29表示最大的epoch数。

2.推理阶段

这里面涉及到一个超参长尾数据目标检测 大数据长尾效应_机器学习_10,这里作者直接令长尾数据目标检测 大数据长尾效应_深度学习_31

四、实验结果

1.数据集介绍

CIFAR数据集为常用实验数据集,作者用长尾数据目标检测 大数据长尾效应_机器学习_32来表示不平衡比例。iNaturalist数据集是真实世界的大尺度数据集,其数据类别分布极度不平衡且存在细粒度问题。

2.对比实验

长尾数据目标检测 大数据长尾效应_机器学习_33

长尾数据目标检测 大数据长尾效应_论文阅读_34

这里简单说一下2X scheduler的意思是允许两倍个epoch数。

3.消融实验

长尾数据目标检测 大数据长尾效应_长尾数据目标检测_35

4.观点的验证实验

长尾数据目标检测 大数据长尾效应_深度学习_36


长尾数据目标检测 大数据长尾效应_论文阅读_37

注意,这个图是越平坦+方差越小越好。纵坐标表示对于分类器第长尾数据目标检测 大数据长尾效应_深度学习_22类的倾向性。

疑惑点

这个长尾数据目标检测 大数据长尾效应_机器学习_10起作用原因是啥——个人觉得这东西很像混合学习,只不过取的混合系数分布不同。