基本概念

小样本学习(Few-Shot Learning, FSL)任务,顾名思义,就是能够仅通过一个或几个示例就快速建立对新概念的认知能力。实现小样本学习的方式也有很多,比如:度量学习、数据增强、预训练模型、元学习等等,其中元学习是目前广泛使用的处理小样本学习问题的方法。

元学习(meta learning或learning to learn),也称学会学习,元学习算法能够在学习不同任务的过程中积累经验,从而使得模型能够快速适应新任务。

元学习与一般的监督学习的区别:

一般的监督学习是在训练集上训练出一个函数映射

小样本分类 数据增强 python 适合小样本的分类算法_小样本分类 数据增强 python

,这个函数可以识别出哪张图片是狗,哪张是猫。输入是某张图片,输出是标签。元学习算法则是让模型学会学习,即在训练集上学习出一个函数 

小样本分类 数据增强 python 适合小样本的分类算法_小样本分类 数据增强 python_02


小样本分类 数据增强 python 适合小样本的分类算法_小样本分类 数据增强 python_02

 可以自动学习出一个函数 

小样本分类 数据增强 python 适合小样本的分类算法_pytorch_04

,他可以分辨出哪张图片是狗,哪张是猫。

小样本分类 数据增强 python 适合小样本的分类算法_小样本分类 数据增强 python_02

 的输入是一个个的图片集合,输出是函数 

小样本分类 数据增强 python 适合小样本的分类算法_pytorch_04

, 

小样本分类 数据增强 python 适合小样本的分类算法_pytorch_04

 的输入是某张图片,输出是标签。 

小样本分类 数据增强 python 适合小样本的分类算法_小样本分类 数据增强 python_08

可以这么理解使用元学习算法的小样本学习任务:我有一个数据集{大象,老虎,狮子},小样本学习并非是让模型识别出哪个是老虎、大象或者狮子,而是学习出每个类别之间的差异,以便在新的数据集(比如:{汽车、电视、沙发、鼠标})中更好的分类。

小样本学习图片分类的基本思想

为了更形式化评估元学习算法,在分类问题上,元学习的数据形式和一般监督学习的数据形式也有所不同,最小的数据点不再是一张图片,而是一个一个的小任务。每个小任务中有

小样本分类 数据增强 python 适合小样本的分类算法_数据集_09

 个类别,每个类别有 

小样本分类 数据增强 python 适合小样本的分类算法_小样本分类 数据增强 python_10

 张图片,我们称这些任务为N-way K-shot图像分类任务,一共有

小样本分类 数据增强 python 适合小样本的分类算法_pytorch_11

 个小任务。当K值很小时(一般K<10),该任务就是小样本图像分类任务了。当K=1时,该任务即为单样本图像分类任务。

除此之外我们还需要知道两个重要概念:

  • 支持集(Support Set):相当于每个小任务中的训练集,包含N个分类标签,每个标签有K张图片。
  • 查询集(Query Set):相当于每个小任务中的测试集,包含Q张未分类的图片。

如下图,为一个3-way 2-shot图像分类任务,蓝色板块是支持集,绿色的是查询集:

小样本分类 数据增强 python 适合小样本的分类算法_数据集_12

 注意,对于元学习而言,上图的3-way 2-shot图像分类任务只是一个数据点,完整的数据集及其训练集-测试集划分如下图所示:

小样本分类 数据增强 python 适合小样本的分类算法_小样本分类 数据增强 python_13

  

元学习流派

Black Box / Model-based
为小数据集场景专门制定一个能够快速变化参数的模型。代表作有:MANN,MetaNet等。

Optimization Based
通过让模型快速优化自己的参数来实现小样本学习。代表作有:MAML,NAIL,Reptile等。

Metric Based(基于度量的方法)
也是目前主流的方法,通过学习一个Encoder,将数据映射到一个表征空间,然后使用无参的Decoder来进行分类。代表作有Matching Network,Prototypical Network等。

Prototypical Network基本原理

以一个episode为例,其中包含 {狗,虫子,鸟} 三个类别的图片。

1、首先,对支持集中的每张图片使用编码器  

小样本分类 数据增强 python 适合小样本的分类算法_小样本分类 数据增强 python_14

 进行信息提取,学习到每张图片的Embedding编码表示。(编码器可以选择常规的卷积操作、resnet系列、vit等等)

小样本分类 数据增强 python 适合小样本的分类算法_数据集_15

2、 然后对支持集的每个类别下的Embeddings做均值处理,得到每个类别的原型表示(class Prototype)。

小样本分类 数据增强 python 适合小样本的分类算法_分类_16

  3、对查询集中的图片进行分类。首先使用编码器 

小样本分类 数据增强 python 适合小样本的分类算法_小样本分类 数据增强 python_14

 将查询集图片进行编码,得到该图片的Embedding向量表示。

小样本分类 数据增强 python 适合小样本的分类算法_pytorch_18

 4、然后拿着这个Embedding表示和类别原型进行相似度计算,也就是无参的解码过程。(相似度计算的方式很多,可以是欧氏距离或者余弦相似度等)

 

小样本分类 数据增强 python 适合小样本的分类算法_数据集_19

 5、计算完相似度后,往往还需要使用softmax将相似度激活成概率分布。

最终得到查询集图片的分类标签,然后和真实值标签做交叉熵loss,然后梯度反向传播即可完成一个episode的训练。

Prototypical Network算法描述

1、假设原始数据集为D,对于每一个episode,包含一个支持集和一个查询集,即

小样本分类 数据增强 python 适合小样本的分类算法_数据集_20

  。

实现方法就是在原始数据集 D 中随机选取N个类别,每个类别选取K张图片,构成支持集,选取Q张图片,构成查询集,这样就组成了一个episode的小数据集。以此类推,构造

小样本分类 数据增强 python 适合小样本的分类算法_pytorch_11

个小数据集。

2、 对每张图片,利用Encoder进行特征提取,即

 

小样本分类 数据增强 python 适合小样本的分类算法_数据集_22

  

小样本分类 数据增强 python 适合小样本的分类算法_pytorch_23

 

3、计算出支持集中的每个类别的原型(prototype),即

小样本分类 数据增强 python 适合小样本的分类算法_深度学习_24

其中,

小样本分类 数据增强 python 适合小样本的分类算法_小样本分类 数据增强 python_25

 表示图片 

小样本分类 数据增强 python 适合小样本的分类算法_分类_26

 的类别标签。

 4、接下来计算每个查询集图片Embedding与每个类别的相似度,即

小样本分类 数据增强 python 适合小样本的分类算法_小样本分类 数据增强 python_27

5、训练用的损失函数,公式如下:

小样本分类 数据增强 python 适合小样本的分类算法_pytorch_28

Prototypical Network 的 Pytorch实现

prototypical network是有官方的论文实现的【prototypical-network源码】,而且很多框架里自带原型网络的包,直接调用即可。

但是官方的论文源码比较难看,而且某些场景需要拆解组合时也不方便,因此这里我自己实现了一个精简版的原型网络【我自己的代码复现】

总结

一般来说way与shot准确率的关系如下所示:

小样本分类 数据增强 python 适合小样本的分类算法_pytorch_29

 这个很好理解,一个episode中类别(way)越少,就越容易找出图片之间的异同,比如二分类就比十分类容易一些;一个episode中同一个类别的样本(shot)越多,就越容易找出图片之间的异同。

本质上来说,原型网络就是集成学习中的stacking思想。