Contents
- Introduction
- Methods
- Re-balanced weighting after Re-sampling
- Negative-Tolerant Regularization
- Distribution-Balanced Loss (DB loss)
- Experiments
- Dataset Construction
- Experiments
- Benchmarking Results
- References
Introduction
- 针对具有长尾分布的多标签分类问题,作者提出了一种新的损失函数 Distribution-Balanced Loss。作者认为,多标签分类问题有两个难点:label co-occurrence 和 dominance of negative labels,Distribution-Balanced Loss 则通过两个对标准二分类交叉熵的改进来解决上述问题:(1) 考虑了 label co-occurrence 之后的权值重平衡策略 (re-balanced weighting);(2) 用于缓解负样本过度抑制 (over-suppression of negative labels) 的 negative tolerant regularization
- 在 Pascal VOC 和 MS COCO 上的实验证明,Distribution-Balanced Loss 与已有方法相比性能有较大提升
Related work: 其他的重加权工作:
Methods
为数据集的样本数, 为类别数, 为样本 对应的标签向量, 为类别 的正样本数 ()
Re-balanced weighting after Re-sampling
- 在单标签分类中,常用的重采样方法有 class-aware sampling,也就是先随机采样一个类别,再从该类别中随机采样一个样本。然而,在多标签分类中,label co-occurrence 是十分常见的。例如,一张图片中包含的 “老虎”、“猎豹” 等非常见标签经常会和 “树”、“河” 等常见标签一起出现。因此,(1) 重采样会引入类内不平衡,也就是同一类中的不同样本不再是以相同概率被采样,这是因为按照 class-aware sampling,多标签分类中样本的采样概率为 而非 ;(2) 更致命的是,对多标签分类数据集进行上述重采样未必会使得类别分布更加均衡,甚至可能加剧类别分布不均衡的情况,这是因为在增加稀有标签采样概率的同时也会增加常见标签的采样概率
- 基于上述观察,作者提出在重采样后增加一个重平衡加权策略。首先,对于样本 和标签 (),我们期望的 Class-level 重采样概率 为
然而,实际的 Instance-level 重采样概率 却为
为了缩小期望重采样概率 和实际重采样概率 之间的差距,作者在损失函数中引入了重平衡权重,用于缓解各个类别实际采样概率大于期望采样概率的问题:
但是, 有时会趋近于 0,这不利于模型的优化。为了使优化过程更加稳定,作者进一步使用了一个平滑函数来将 映射到 的值域内:
其中, 为偏置, 控制映射函数的形状, 为 sigmoid 函数。最终,可以得到如下的 Re-balanced-BCE:
其中, 为样本 输出的类别 的 logit。值得一提的是,尽管 是从正样本的重采样过程中推出的,但这里 不仅被用在了正样本上,还被用在了负样本上,作者表示这是为了维持类级别的一致性 (对于某一类别而言,经过重加权后正负样本的采样概率是一样的,均为 )
Negative-Tolerant Regularization
- 在多标签分类数据集中,每个样本通常只属于少数几个类别,因此给定一个样本,就会产生少数正样本和大量负样本,带来正负样本的不平衡。如果使用 BCE 这种对称损失进行训练,过多的负样本就会导致负样本的过度抑制 (我的理解是这里 “过度抑制” 就是指模型倾向于输出更小的 logit),进而使得分类边界带有显著的偏向性。具体而言,与单标签分类中的 CE+Softmax 相比,BCE+sigmoid 的优化过程更为剧烈。当遇到负样本时,CE 和 BCE 对 logit 的导数如下:
下图对导数进行了可视化:
可见,CE+Softmax 在优化时,如果样本的正类 logit 远高于负类 logit,损失对负类 logit 的梯度就会很小,但 BCE+sigmoid 在优化时,由于将多标签分类拆分为了多个二分类问题,因此不管样本的正类 logit 是多少,损失对负类 logit 的梯度总会使得负类 logit 远离 0,向比较小的负值靠拢 (which results in continuous suppression) (同理,也会使得正类 logit 远离 0,向比较大的正值靠拢)。上述负样本的过度抑制现象带来的最直接的后果就是模型容易对尾部类别过拟合 (因为对于一个特定类别 (尤其是尾部),数据集中绝大多数都是它的负样本,当分类器被海量负样本包围,且被要求对每一个负样本都输出一个足够低的预测值时,分类器向量在训练过程中将被迫远远偏离大量自然样本的分布,而仅仅过拟合在它的个别正样本上。可以想像分类器预测值在特征向量空间中的分布具有一个尖锐的波峰,泛化性能很差) - 为了解决上述问题,作者提出了一种正则化方式,使得模型在训练时不再对负样本持续施加过重的惩罚,而是点到为止。我们只需要对分类器的负类输出进行一个简单的线性变换就能够实现上述功能 (i.e. 新的负类输出 ) (不要忘记加上正则化系数 来约束梯度值的范围在 0 到 1 之间),变换后的损失函数对负类 logit 的梯度如下图所示,可以看到,当负类 logit 低于阈值 时,相比原来的梯度 (lambda=1),加入正则化后的梯度急剧降低,进而缓解了负类 logit 的过度优化:
最后的损失函数为:
其中, 为 class-specific bias,它与模型的内在偏置 有关。在模型训练过程中,模型的内在偏置会接近类概率先验 ,也就是最小化下式:
求解上式可得估计的 ,该偏置随样本频率递增:
即为一个用于模拟该内在偏置的量,这种手动初始化 bias 的方法可以把这种本征概率分布嵌入学习过程中,便于网络更多地学习频率分布之外的类别特征 (类似的思想在《Long-Tail Learning via Logit Adjustment》这篇文章里有更清楚的解释):
Distribution-Balanced Loss (DB loss)
- Distribution-Balanced Loss 就是将之前的 R-BCE 和 NT-BCE 结合起来,有助于平滑模型的输出分布:
Experiments
Dataset Construction
- 作者基于 Pascal VOC 和 MS COCO,按照 pareto distribution,以抽取的方式人工构造了两个长尾分布的多标签数据集用以训练,称为 VOC-MLT 和 COCO-MLT,其中 可以控制数据规模衰减的速度。具体而言,选定 后可以得到 pareto distribution,可以在 CDF 达到 0.99 时截断 pdf,然后将 pdf 的最大值 rescale 到 ,最后将 轴按照原始数据集中的样本数均匀分割就能得到 reference distribution
- 在多标签数据集中,假设我们随机选取一个类别 ,那么从该类别中采样得到的样本也属于类别 的概率为
可以认为 较大的类别即为 head class. 在构造长尾数据集 (subset) 时,作者首先将所有类别按照 降序排列,此时 subset 为空。然后从 head 到 tail,对每个类别 ,作者都比较现在 subset 内已有的类别 的样本数和 reference distribution 中期望的样本数并随机进行样本的增加或删减,这样就能保证 tail classes 具有比较少的数据量。如下图所示,长尾数据集的构造过程是递进的,并且各个类别包含的样本个数的排列顺序与测试集较为接近 - VOC-MLT:VOC-MLT 来自 VOC-2012 的 train-val set,. 它包括了来自 20 个类别的 1142 张图像,最多的类别有 775 张图像,最少的类别只有 4 张图片。测试集来自 VOC2007 test set,包含 4952 张图像
- COCO-MLT:COCO-MLT 来自 MS COCO-2017,. 它包括了来自 80 个类别的 1909 张图像,最多的类别有 1128 张图像,最少的类别只有 6 张图片。测试集来自 COCO2017 test set,包含 5000 张图像
Experiments
Benchmarking Results
- Evaluation Metrics: 作者以 mAP 为主要评价指标在原始测试集上进行验证。作者根据每个类别含有的训练样本数量 将其划分为头部 (head, ),中部 (medium, ) 和尾部 (tail, ) 三个子集,并在整体和各子集上都进行了结果对比
- Comparing Methods: (1) Empirical risk minimization (ERM): baseline; (2) Re-weighting (RW): a smooth version of re-weighting to be inversely proportional to the square root of class frequency, and we normalize the weights to be between 0 and 1 in a mini-batch; (3) Re-sampling (RS): class-aware re-sampling; …