Prototypical Networks for Few-shot Learning

摘要

我们提出了原型网络,用于解决少样本分类问题,在这种情况下,分类器必须对训练集中未见的新类进行归纳,而每个新类只有少量的例子。原型网络学习一个度量空间,在这个空间中,可以通过计算与每个类别的原型表示的距离来进行分类。与最近的少样本学习方法相比,它们反映了一种更简单的归纳偏见,在这种有限的数据制度中是有益的,并取得了很好的效果。我们提供了一个分析,表明一些简单的设计决定可以比最近涉及复杂的结构选择和元学习的方法产生实质性的改进。我们进一步将原型网络扩展到零样本学习,并在CU-Birds数据集上取得了最先进的结果。

1 引言

少样本分类[20, 16, 13]是一项任务,其中分类器必须适应训练中未见的新类别,而每个类别只有几个例子。原始的方法会严重地过拟合,例如在新的数据上重新训练模型。虽然这个问题相当困难,但已经证明人类有能力进行一次性分类,即只给每个新类的一个例子,并有很高的准确性[16]。

最近有两种方法在少样本学习方面取得了重大进展。Vinyals等人[29]提出了匹配网络,它在已标记的例子集(支持集)的学习嵌入上使用注意机制来预测未标记的点(查询集)的类别。匹配网络可以被解释为在嵌入空间内应用的加权近邻分类器。值得注意的是,这个模型在训练过程中利用了被称为 "迭代"的小批次抽样,每个迭代都是通过对类和数据点进行子抽样来模仿少数样本的任务。迭代的使用使训练问题更忠实于测试环境,从而提高了泛化能力。Ravi和Larochelle[22]进一步提出了情节训练的想法,并提出了一个元学习的方法来进行少样本学习。他们的方法包括训练一个LSTM[9]来产生对分类器的更新,给定一个情节,这样它就能很好地泛化到测试集上。在这里,LSTM元学习器不是在多个事件中训练一个单一的模型,而是为每个事件训练一个自定义模型

我们通过解决过度拟合的关键问题来解决少样本学习的问题。由于数据严重受限,我们的工作假设是分类器应该有一个非常简单的归纳偏见。我们的方法,即原型网络,是基于这样的想法:存在一个嵌入,其中的点围绕着每个类别的单一原型代表而聚集。为了做到这一点,我们使用神经网络学习输入到嵌入空间的非线性映射,并将一个类别的原型作为其在嵌入空间的支持集的平均值。然后,通过简单地寻找最近的类原型,对嵌入的查询点进行分类。我们采用同样的方法来处理零样本学习;在这里,每个类都有元数据,对该类进行高层次的描述,而不是少量的标记过的例子。因此,我们学习将元数据嵌入一个共享空间,作为每个类别的原型。分类是通过为嵌入的查询点寻找最接近的类原型来进行的,就像在少样本的情况下一样。

在本文中,我们为少样本和零样本的设置制定了原型网络。我们将其与少样本设置中的匹配网络联系起来,并分析了模型中使用的基本距离函数。特别是,我们将原型网络与聚类[4]联系起来,以证明当距离用Bregman发散计算时,使用类的手段作为原型,例如平方欧氏距离。我们从经验上发现,距离的选择至关重要,因为欧氏距离大大超过了更常用的余弦相似度。在几个基准任务上,我们取得了最先进的性能。原型网络比最近的元学习算法更简单、更有效,使它们成为一种有吸引力的少样本和零样本学习方法。

2 原型网络

2.1 符号

在少样本分类中,我们得到一个由N个已标记的例子组成的小型支持集S =

论文阅读:Prototypical Networks for Few-shot Learning_支持集

其中每个

论文阅读:Prototypical Networks for Few-shot Learning_数据集_02

是一个例子的D维特征向量,

论文阅读:Prototypical Networks for Few-shot Learning_欧氏距离_03

是相应的标签。

2.2 模型

原型网络通过嵌入函数,以可学习的参数φ,计算出每个类的M维表示,或称原型。每个原型是属于其类别的嵌入支持点的平均矢量

论文阅读:Prototypical Networks for Few-shot Learning_数据集_04

给定一个距离函数,原型网络根据嵌入空间中与原型的距离的softmax,为查询点x产生一个类别分布:

论文阅读:Prototypical Networks for Few-shot Learning_欧氏距离_05

学习过程中,通过SGD最小化真实类别k的负对数概率。训练情节是通过从训练集中随机选择几个类的一个子集,然后在每个类的样本中选择一个子集作为支持集,其余的子集作为查询点来形成。算法1中提供了计算训练集的损失J(φ)的伪代码。

论文阅读:Prototypical Networks for Few-shot Learning_支持集_06

2.3 原型网络作为混合密度估计

对于一类特殊的距离函数,即常规Bregman发散[4],原型网络算法等同于在支持集上用指数族密度进行混合密度估计。正则Bregman发散定义为:

论文阅读:Prototypical Networks for Few-shot Learning_数据集_07

其中φ是一个Legendre型的可微、严格凸函数。Bregman发散的例子包括平方欧氏距离和Mahalanobis距离。原型计算可以被看作是支持集上的硬聚类,每个类有一个聚类,每个支持点被分配到其相应的类聚类中。对于Bregman分歧,已经证明[4],实现与其分配点最小距离的聚类代表是聚类平均数。因此,当使用Bregman发散时,方程(1)中的原型计算产生了给定支持集标签的最佳集群代表。此外,任何具有参数θ和累积函数ψ的正则指数族分布pψ(z|θ)都可以写成唯一确定的正则Bregman发散[4]:

论文阅读:Prototypical Networks for Few-shot Learning_支持集_08

现在考虑一个有参数的常规指数族混合模型

论文阅读:Prototypical Networks for Few-shot Learning_支持集_09

考虑到Γ,对一个未标记的点z的聚类分配y的推断变成:

论文阅读:Prototypical Networks for Few-shot Learning_数据集_10

对于每类有一个聚类的等权混合模型,聚类分配推断(6)等同于查询类预测(2)

论文阅读:Prototypical Networks for Few-shot Learning_支持集_11

。在这种情况下,原型网络实际上是在进行混合物密度估计,其指数族分布由决定。因此,距离的选择指定了关于嵌入空间中的类别条件数据分布的建模假设。

2.4 作为线性模型的重新解释

一个简单的分析对于深入了解所学分类器的性质很有用。当我们使用欧氏距离时,那么方程(2)中的模型就等同于一个具有特定参数化的线性模型[19]。为了看到这一点,展开指数中的项:

论文阅读:Prototypical Networks for Few-shot Learning_支持集_12

方程(7)中的第一个项相对于类k来说是常数,所以它不影响softmax概率。我们可以把其余的项写成如下的线性模型。

论文阅读:Prototypical Networks for Few-shot Learning_欧氏距离_13

我们在这项工作中主要关注平方欧氏距离(对应于球形高斯密度)。我们的结果表明,尽管等同于线性模型,但欧氏距离是一个有效的选择。我们假设这是因为所有需要的非线性都可以在嵌入函数中学习。事实上,这是现代神经网络分类系统目前使用的方法,例如,[14, 28]

2.5 与匹配网络的比较

原型网络与匹配网络在少样本的情况下不同,在单样本的情况下是等同的。匹配网络[29]产生一个给定支持集的加权近邻分类器,而原型网络在使用平方欧氏距离时产生一个线性分类器。在单一样本学习的情况下,因为每个类别只有一个支持点,匹配网络和原型网络变得等价。

一个自然的问题是,每类使用多个原型而不是只有一个原型是否有意义。如果每个类的原型数量是固定的,并且大于1,那么这就需要一个分区方案来进一步对类中的支持点进行分组。这在Mensink等人[19]和Rippel等人[25]中已经提出;但是这两种方法都需要一个与权重更新解耦的单独的分区阶段,而我们的方法很容易用普通的梯度下降方法学习。

Vinyals等人[29]提出了一些扩展,包括解耦支持点和查询点的嵌入函数,以及使用第二级全条件嵌入(FCE),考虑到每个情节中的特定点。这些同样可以被纳入到原型网络中,然而它们增加了可学习参数的数量,而且FCE使用双向LSTM对支持集进行了任意的排序。相反,我们表明,使用简单的设计选择可以达到相同的性能水平,我们接下来将概述这一点。

2.6 设计选择

距离指标 Vinyals等人[29]和Ravi和Larochelle[22]使用余弦距离来应用匹配网络。然而对于原型网络和匹配网络来说,任何距离都是允许的,我们发现使用平方的欧氏距离可以大大改善两者的结果。我们猜测这主要是由于余弦距离不是布雷格曼发散,因此第2.3节中讨论的与混合密度估计的等价关系不成立。

情节的构成 在Vinyals等人[29]和Ravi和Larochelle[22]中使用的一种构建情节的直接方法是选择论文阅读:Prototypical Networks for Few-shot Learning_支持集_14个类和每个类的论文阅读:Prototypical Networks for Few-shot Learning_数据集_15个支持点,以匹配测试时的预期情况。也就是说,如果我们期望在测试时进行5-way和1-shot,那么训练情节可以由

论文阅读:Prototypical Networks for Few-shot Learning_支持集_16

组成。然而,我们发现,用比测试时更高的论文阅读:Prototypical Networks for Few-shot Learning_支持集_14或 "方式 "来训练是非常有益的。在我们的实验中,我们在保留的验证集上调整训练论文阅读:Prototypical Networks for Few-shot Learning_支持集_14。另一个考虑因素是在训练和测试时是否要匹配论文阅读:Prototypical Networks for Few-shot Learning_数据集_15,或者说 "shot"。对于原型网络,我们发现通常最好是用相同的 "shot "数进行训练和测试。

2.7 零样本学习

零样本学习与少样本学习的不同之处在于,我们不是得到一个训练点的支持集,而是为每个类别得到一个类别元数据向量vk。这些数据可以事先确定,也可以从例如原始文本中学习[7]。对原型网络进行修改以处理 "零样本"情况是很简单的:我们只需定义

论文阅读:Prototypical Networks for Few-shot Learning_数据集_20

为元数据向量的单独嵌入。图1显示了原型网络的零点程序与少量程序的关系。由于元数据向量和查询点来自不同的输入域,我们发现根据经验将原型嵌入g固定为单位长度是有帮助的,然而我们并不限制查询嵌入f。

论文阅读:Prototypical Networks for Few-shot Learning_欧氏距离_21

图1:少样本和零样本情况下的原型网络。左图:少样本的原型ck被计算为每个类别的嵌入式支持实例的平均值。右图。零样本的原型ck是通过嵌入类的元数据vk产生的。在任何一种情况下,嵌入的查询点都是通过与类原型的距离的softmax来分类的:

论文阅读:Prototypical Networks for Few-shot Learning_支持集_22

3 实验

对于少样本学习,我们在Omniglot[16]和ILSVRC-2012[26]的miniImageNet版本上进行了实验,并采用Ravi和Larochelle[22]提出的分割法。我们在2011年版本的加州理工学院UCSD鸟类数据集(CUB-200 2011)[31]上进行了零散的实验。

3.1 Omniglot Few-shot Classification

Omniglot[16]是一个由50个字母组成的1623个手写字符的数据集。每个字符有20个例子,每个例子都是由不同的人绘制的。我们遵循Vinyals等人[29]的程序,将灰度图像的大小调整为28×28,并用90度的倍数的旋转来增加字符类别。我们使用1200个字符加上旋转来进行训练(总共4,800个类),剩下的类,包括旋转,用于测试。我们的嵌入结构反映了Vinyals等人[29]所使用的结构,由四个卷积块组成。每个区块包括一个64个过滤器的3×3卷积,批量归一化层[10],一个ReLU非线性和一个2×2的最大集合层。当应用于28×28的Omniglot图像时,这种结构会产生一个64维的输出空间。我们使用相同的编码器来嵌入支持点和查询点。我们所有的模型都是通过SGD与Adam[11]训练的。我们使用了10-3的初始学习率,并在每2000个事件中把学习率减半。除了批量归一化之外,没有使用正则化。

我们在1次拍摄和5次拍摄的情况下使用欧氏距离训练原型网络,训练情节包含60个类和每个类5个查询点。我们发现,将训练镜头与测试镜头相匹配是有利的,并且每个训练情节使用更多的类(更高的 "方式")而不是更少的类。我们与各种基线进行比较,包括神经统计学家[6]和匹配网络的微调和非微调版本[29]。我们计算了我们的模型在测试集的1000个随机生成的事件中的平均分类精度。结果显示在表1中,据我们所知,它们代表了这个数据集的最先进水平。

论文阅读:Prototypical Networks for Few-shot Learning_支持集_23

3.2 miniImageNet 少样本分类

miniImageNet数据集最初由Vinyals等人[29]提出,来自更大的ILSVRC-12数据集[26]。Vinyals等人[29]使用的分片包括60,000张大小为84×84的彩色图像,分为100类,每类有600个例子。在我们的实验中,我们使用了Ravi和Larochelle[22]介绍的分片,以便直接与最先进的几率学习算法进行比较。他们的拆分使用了一个不同的100个类的集合,分为64个训练类、16个验证类和20个测试类。我们遵循他们的程序,对64个训练类进行训练,并使用16个验证类来监测泛化性能。

我们使用与Omniglot实验中相同的四块嵌入结构,尽管在这里由于图像大小的增加而导致了1600维的输出空间。我们还使用与Omniglot实验中相同的学习率计划,并训练到验证损失停止改善。我们使用30次的情节进行1次分类训练,使用20次的情节进行5次分类训练。我们将训练镜头与测试镜头相匹配,每个类别每集包含15个查询点。我们与Ravi和Larochelle[22]报告的基线进行比较,其中包括一个简单的近邻方法,在64个训练类的分类网络上学习的特征之上。其他基线是两个非微调的匹配网络变体(包括普通和FCE)和Meta-Learner LSTM。从表2中可以看出,原型网络在这里以很大的优势达到了最先进的水平。我们进行了进一步的分析,以确定距离指标和每集训练类的数量对原型网络和匹配网络性能的影响。

论文阅读:Prototypical Networks for Few-shot Learning_数据集_24

为了使这些方法具有可比性,我们使用了我们自己的匹配网络的实现,它利用了与我们的原型网络相同的嵌入结构。在图2中,我们比较了余弦与欧几里得距离,以及5路与20路在1次拍摄和5次拍摄情况下的训练情节,每集每类15个查询点。我们注意到,20路比5路取得了更高的准确率,并推测20路分类的难度增加有助于网络更好地泛化,因为它迫使模型在嵌入空间做出更精细的决定。另外,使用欧氏距离比余弦距离大大改善了性能。这种效果对于原型网络更加明显,在这种网络中,计算类原型作为嵌入支持点的平均值更自然地适合欧氏距离,因为余弦距离不是布雷格曼发散。

论文阅读:Prototypical Networks for Few-shot Learning_数据集_25

图2:比较显示了距离指标和每个训练情节的类数对miniImageNet上的匹配网络和原型网络的5路分类准确性的影响。X轴表示训练情节的配置(方式、距离和镜头),Y轴表示相应镜头的5向测试精度。误差条表示在600个测试情节中计算出的95%的置信区间。请注意,在1次拍摄的情况下,匹配网络和原型网络是相同的。

3.3 CUB零点分类

为了评估我们的方法对零点学习的适用性,我们还在Caltech-UCSD鸟类(CUB)200-2011数据集上进行了实验[31]。CUB数据集包含200种鸟类的11,788张图像。我们严格按照Reed等人[23]的程序来准备数据。我们使用他们的分割法,将类分为100个训练,50个验证,和50个测试。对于图像,我们使用通过应用GoogLeNet[28]对原始和水平翻转的图像的中间、左上、右上、左下和右下裁剪提取的1024维特征2。在测试时,我们只使用原始图像的中间部分。对于类元数据,我们使用CUB数据集提供的312维连续属性向量。这些属性编码了鸟类的各种特征,如它们的颜色、形状和羽毛图案。

我们在1024维图像特征和312维属性向量的基础上学习了一个简单的线性映射,以产生一个1024维的输出空间。对于这个数据集,我们发现将类别原型(嵌入的属性向量)归一化为单位长度是有帮助的,因为属性向量来自与图像不同的领域。我们用50个类和每个类的10张查询图像来构建训练集。嵌入是通过SGD优化的,Adam的固定学习率为10-4,权重衰减为10-5。验证损失的早期停止被用来确定在训练和验证集上重新训练的最佳历时数。

表3显示,与利用属性作为类元数据的方法相比,我们取得了最先进的结果,差距很大。我们将我们的方法与其他嵌入方法进行比较,如ALE[1]、SJE[2]、DS-SJE/DA-SJE[23]。我们还与最近的一个聚类方法[17]进行了比较,该方法在通过微调AlexNet[14]获得的学习特征空间上训练SVM。这些零点分类结果表明,即使数据点(图像)来自相对于类(属性)的不同领域,我们的方法也足够普遍,可以应用。

论文阅读:Prototypical Networks for Few-shot Learning_数据集_26

4 相关工作

关于度量学习的文献非常多[15, 5];我们在这里总结一下与我们提出的方法最相关的工作。邻近成分分析(NCA)[8]学习Mahalanobis距离,以最大化K-近邻(KNN)在转换空间中的留空精度。Salakhutdinov和Hinton[27]通过使用神经网络进行转换来扩展NCA。大边际近邻(LMNN)分类法[30]也试图优化KNN的准确性,但使用了铰链损失,鼓励一个点的本地邻居包含具有相同标签的其他点。DNet-KNN[21]是另一种基于边际的方法,它通过利用神经网络来执行嵌入,而不是简单的线性变换来改进LMNN。在这些方法中,我们的方法与NCA[27]的非线性扩展最为相似,因为我们使用神经网络来进行嵌入,并且我们根据转换空间中的欧几里得距离来优化一个softmax,而不是余量损失。我们的方法和非线性NCA之间的一个关键区别是,我们直接在类上形成一个softmax,而不是单个点,根据与每个类的原型代表的距离计算。这使得每个类别都有一个独立于数据点数量的简明表示,并避免了存储整个支持集来进行预测的需要。

我们的方法也类似于最近的类平均值方法[19],其中每个类由其例子的平均值来表示。这种方法是为了在不重新训练的情况下迅速将新的类纳入分类器,然而它依赖于线性嵌入,并被设计用来处理新的类带有大量例子的情况。相比之下,我们的方法利用神经网络来非线性地嵌入点,并将其与偶发训练结合起来,以处理少数几个例子的情况。Mensink等人试图扩展他们的方法来进行非线性分类,但他们是通过允许类有多个原型来实现的。他们在预处理步骤中通过在输入空间上使用k-means找到这些原型,然后执行其线性嵌入的多模式变体。另一方面,原型网络以端到端的方式学习非线性嵌入,没有这样的预处理,产生的非线性分类器仍然只需要每类一个原型。此外,我们的方法自然可以推广到其他距离函数,特别是Bregman发散。

另一个相关的少样本学习方法是Ravi和Larochelle[22]提出的元学习方法。这里的关键见解是,LSTM动力学和梯度下降可以以有效的相同方式编写。然后,一个LSTM可以被训练成自己从一个给定的情节中训练出一个模型,其性能目标是对查询点进行良好的泛化。匹配网络和原型网络也可以被看作是元学习的形式,因为它们从新的训练情节中动态地产生简单的分类器;但是它们所依赖的核心嵌入在训练后是固定的。FCE对匹配网的扩展涉及到一个依赖于支持集的二级嵌入。然而,在少样本的情况下,数据量非常小,简单的归纳偏见似乎很有效,不需要为每个情节学习一个自定义的嵌入。

原型网络也与生成模型文献中的神经统计学家[6]有关,它扩展了变异自动编码器[12, 24],以学习数据集而非单个点的生成模型。神经统计学家的一个组成部分是 "统计网络",它将一组数据点总结为一个统计向量。它通过对数据集中的每个点进行编码,取一个样本平均值,并应用一个后处理网络来获得统计向量的近似后验。Edwards和Storkey在Omniglot数据集上测试了他们的一次性分类模型,他们将每个字符视为一个单独的数据集,并根据统计向量上的近似后验与测试点推断的后验有最小的KL-分歧的类别进行预测。像神经统计学家一样,我们也为每个类别产生一个汇总统计。然而,我们的模型是一个鉴别性的模型,这与我们的少样本分类的鉴别性任务是相称的。

关于零样本学习,在原型网络中使用嵌入式元数据与[3]的方法相似,都是预测线性分类器的权重。[23]的DS-SJE和DA-SJE方法也学习了图像和类元数据的深度多模态嵌入函数。与我们不同的是,他们使用经验性的风险损失进行学习。无论是[3]还是[23]都没有使用偶发训练,这使得我们可以帮助加快训练速度并使模型正规化。

5 结论

我们提出了一种简单的方法,称为原型网络,用于少数次学习,其基础是我们可以通过神经网络学习的表示空间中其例子的平均值来表示每个类别。我们通过使用偶发训练来训练这些网络,使其在几率设置中表现良好。这种方法比最近的元学习方法简单得多,也高效得多,即使没有为匹配网络开发的复杂扩展,也能产生最先进的结果(尽管这些扩展也可以应用于原型网络)。我们展示了如何通过仔细考虑所选择的距离度量,以及通过修改偶发学习程序来大大改善性能。我们进一步证明了如何将原型网络推广到零点设置,并在CUB-200数据集上取得最先进的结果。未来工作的一个自然方向是利用欧氏距离平方以外的布雷格曼发散,对应于球面高斯以外的类条件分布。我们对此进行了初步探索,包括为每个类别学习每个维度的方差。这并没有带来任何经验上的收获,这表明嵌入网络本身就有足够的灵活性,而不需要为每个类增加拟合参数。总的来说,原型网络的简单性和有效性使其成为一种有前途的少数次学习方法。

A 额外的Omniglot结果

在表4中,我们显示了使用欧氏距离训练的原型网络的测试分类精度,每集有5、20和60个类。

表4:Omniglot上原型网络的其他分类精度结果。训练集的配置由每集的类数("方式")、每类的支持点数("镜头")和每类的查询点数("查询")表示。分类准确率是测试集中随机产生的1000个情节的平均数。

论文阅读:Prototypical Networks for Few-shot Learning_支持集_27

图3显示了原型网络学到的嵌入的t-SNE可视化样本[18]。我们将来自同一字母的测试字符子集可视化,以获得更好的洞察力,尽管实际测试情节中的类可能来自不同的字母。即使可视化的字符是彼此之间的微小变化,网络也能够将手绘的字符紧密地聚集在类的原型周围。

B 额外的miniImageNet结果

在表5中,我们展示了主文中图2中训练情节配置比较的全部结果。

我们还比较了用每集不同数量的类来训练的欧氏距离原型网络。在这里,我们将每个训练集的类数从5到30不等,而每个类的查询点数量固定为15。结果显示在图4中。我们的研究结果表明,为了获得良好的几率分类结果,训练集的构建是一个重要的考虑因素。表6包含了这组实验的全部结果。

论文阅读:Prototypical Networks for Few-shot Learning_欧氏距离_28

图3:原型网络在Omniglot数据集上学习的嵌入的t-SNE可视化。图中显示了Tengwar文字的一个子集(测试集中的一个字母)。类原型用黑色表示。几个被错误分类的字符用红色突出显示,并有箭头指向正确的原型。

论文阅读:Prototypical Networks for Few-shot Learning_数据集_29

图4:在miniImageNet上训练的原型网络的训练 "方式"(每集的类的数量)的效果比较。每个训练情节包含每个类别的15个查询点。误差条表示在600个测试情节中计算出的95%的置信区间。

论文阅读:Prototypical Networks for Few-shot Learning_欧氏距离_30

表5:miniImageNet上的匹配网络和原型网络在余弦与欧几里得距离、5路与20路、1次与5次的比较。所有的实验对支持点和查询点都使用了一个共享的编码器,嵌入维度为1,600(结构和训练细节在主论文的第3.2节中提供)。分类精度是测试集中随机产生的600个情节的平均值,并显示95%的置信区间。

论文阅读:Prototypical Networks for Few-shot Learning_欧氏距离_31

表6:在miniImageNet上使用欧氏距离的原型网络的训练 "方式"(每个训练集的类的数量)的影响。训练情节中每一类的查询点数量固定为15个。分类精度是测试集中600个随机生成的情节的平均值,并显示了95%的置信区间。