原文:Sohn K, Berthelot D, Carlini N, et al. Fixmatch: Simplifying semi-supervised learning with consistency and confidence[J]. Advances in neural information processing systems, 2020, 33: 596-608.
源码:https://github.com/google-research/fixmatch
半监督学习(SSL)提供了一种利用无标记数据提高模型性能的有效方法。最近,这一领域取得了快速的进展,但代价是需要更复杂的方法。在本文中,我们提出了FixMatch方法,对现有的SSL方法进行了大幅简化。首先,FixMatch对无标记图像的弱增强视图进行预测,生成伪标签,并且只有当模型产生高置信度预测时,才会保留伪标签。然后,FixMatch对同一无标记图像的强增强视图进行预测,并且将预测结果与上述伪标签进行匹配,计算损失,以此来训练模型。FixMatch简单易行,并且在各种半监督学习基准上都达到了最先进的性能。我们进行了广泛的消融研究,以梳理出对FixMatch成功最重要的因素。
图1:FixMatch示意图。首先,将弱增强图像(上)输入模型以获得预测结果,当某一类的预测概率高于阈值(虚线)时,将预测结果转换为one-hot伪标签。然后,利用模型对同一图像的强增强视图进行预测(下),并且将预测结果与上述伪标签进行匹配,计算交叉熵损失,以此来训练模型。
表1:不同SSL算法的比较。
表2:不同SSL算法在CIFAR-10、CIFAR-100、SVHN和STL-10数据集上的错误率。FixMatch(RA)使用RandAugment数据增强,FixMatch(CTA)使用CTAugment数据增强。
图2:在CIFAR-10数据集上,FixMatch仅用上面10张标记图像就能达到78%的精度。
图3:FixMatch的消融研究结果。(a)改变伪标签的置信度阈值对错误率的影响。(b)当置信度阈值τ分别为0、0.8、0.95时,观测“sharpening”对错误率的影响。具有默认超参数的FixMatch的错误率用红色虚线表示。
表3:在CIFAR-10数据集上,FixMatch不同数据增强的消融研究结果。
算法1:FixMatch的伪代码。将有标记数据和无标记数据同时输入模型;计算有标记数据的交叉熵损失ls;计算无标记数据的交叉熵损失lu;计算总的损失:ls+λu×lu,并以此来训练模型,其中λu是lu的权重系数。
表4:FixMatch在CIFAR-10、CIFAR-100、SVHN和STL-10数据集上的超参数。
表5:在不同的置信度阈值下,模型训练结束时的mask rate和impurity,以及FixMatch的错误率。
表6:学习率衰减的消融研究结果。
表7:优化器的消融研究结果。
图4:优化器的消融研究结果。(a)改变β值对错误率的影响。(b)当β=0时,改变η值对错误率的影响。
图5:FixMatch的消融研究结果。(a)改变无标记数据的比例μ对错误率的影响。(b)改变权重衰减系数对错误率的影响。具有默认超参数的FixMatch的错误率用红色虚线表示。
图6:每类只有一个标记数据的实验。
表12:RandAugment数据增强策略。
表13:CTAugment数据增强策略。
半监督学习(SSL)最近取得了快速的进展。不幸的是,这种进步的代价是越来越复杂的学习算法,复杂的损失项和许多难以调整的超参数。我们提出了FixMatch,这是一种更简单的SSL算法,可以在许多数据集上获得最先进的结果。在每类只有一个标记数据的情况下,FixMatch依然能够达到惊人的高精度,这在一定程度上弥合了少量标签(low-label)半监督学习和小样本学习之间的差距。FixMatch在标记数据和无标记数据上计算标准交叉熵损失,它的训练目标只需几行代码就能编写出来。由于这种简单性,我们能够彻底研究FixMatch的工作原理。我们发现某些设计选项很重要(而且往往被低估)——最重要的是权重衰减和优化器的选择。总的来说,我们相信这种简单有效的半监督学习算法的存在,将有助于机器学习部署在越来越多的实际应用领域,而这些领域往往存在标注成本高、标签获取难等问题。
多模态人工智能