论文:CANet: Class-Agnostic Segmentation Networks with Iterative Refinement and Attentive Few-Shot Learning(2019年CVPR)
论文要解决的问题:
- 利用小样本训练一个比较好的分割模型
语义分割任务需要对每个像素点进行分类,因此训练语义分割模型的数据集需要人为的对每个像素点进行标记,这是一项繁琐和成本高的工作。(data labeling for pixelwise segmentation is tedious and costly) - 一般的语义分割模型不能对未见过的新类别信息进行较好的预测
先前的语义分割模型一般只能预测训练数据集中包含(预先定义)的类别信息,泛化能力差(a trained model can only make predictions within a set of pre-defined classes.) - 所以论文提出一种类别无关的语义分割模型CANet,只需要几张/少量标注的新类别样本,模型就能够准确分割。模型包括三部分:密集比较,迭代优化和ASPP模块。整个模型可以单样本和多样本(k=5)学习,区别在于多样本学习时加了注意力机制,得到各个样本的权重。
一、Dense comparison module (DCM)
结构/过程:backbone提取特征网络为Res-50,只用了block2和block3中间层的特征,block2后经过dilatinotallow=2的卷积得到block3,然后拼接后再经过一层卷积重新表示图像特征信息。接着mask和图像特征点乘操作,得到的结果只包含了类别的信息,去除了图像中无关的背景信息。最后使用全局平均池化得到特定类别的特征向量(类别信息都包含在这个特征向量中)。
注意:只用了block2和block3中间层的特征,由于中间层特征更加class-level,保留了低层的颜色,边界等信息,也保留了类别等抽象语义信息。
dilation rate作用:在不降低空间分辨率下,也能够增大感受野;如果设置不同dilation rate,获取了多尺度信息。在这一层中,由于block2和block3的特征要拼接,分辨率要一样。
细节: mask和图像特征点乘,需要相同的分辨率,需对mask进行插值处理;由于要和query图像特征密集比较,需对全局平均池化得到特定类别的特征向量进行上采样扩充。对应代码如下:
mask = F.interpolate(mask, support.shape[-2:], mode=‘bilinear’, align_corners=True)
z = z.expand(-1, -1, feature_size[0], feature_size[1])
密集比较:support和query生成相同维度的特征表示,两者拼接一起,目的是对空间每一个位置都进行比较,实现时用一个33的卷积核进行卷积操作,最后得到是匹配后的结果。
二、迭代优化(iterative optimization module (IOM)
同一类别的不同图像,存在明显地外观差异,上一步密集比较可能只匹配了目标地一部分,所以需进行迭代优化来得到准确的分割结果。
实现:论文提到不能直接把预测的mask拼接到密集比较得到的feature map上,因为第一次并没有预测的mask,特征通道要保持一致。所以,采用残差的形式来融合预测地mask信息。
融合预测地mask信息,相当于自监督的方式,残差学习具有这种性质(恒等映射),至于后面又接了两个残差块(估计是做实验觉得好,或是和ResNet保持一致,没说这样做的目的)。这里还有一个实现细节,we alternatively use predicted masks in the last epoch and empty masks as the input to IOM. The predicted masks yt−1 is reset to empty masks with a probability of pr. 代码如下:
out_plus_history=torch.cat([out,history_mask],dim=1) #和历史的mask拼接[1,258,41,41]
#分别经过3个残差块的迭代
out = out + self.residule1(out_plus_history)
out = out + self.residule2(out)
out = out + self.residule3(out)
三、ASPP(Atrous Spatial Pyramid Pooling module)模块
主要是为了捕获多尺度信息,实现:包含3个atrous rates分别为6,12,18的33卷积以及一个11的卷积,值得注意的是论文在这一部分多了一个11的卷积操作(提取全局信息),实际拼接有5个部分。代码如下:
#ASPP模块
global_feature=F.avg_pool2d(out,kernel_size=feature_size) #[1,256,1,1]全局特征
global_feature=self.layer6_0(global_feature) #全卷积
global_feature=global_feature.expand(-1,-1,feature_size[0],feature_size[1]) #扩充[1,256,41,41]
#不同dialation卷积,然后拼接[1,1280,41,41]
out=torch.cat([global_feature,self.layer6_1(out),self.layer6_2(out),self.layer6_3(out),self.layer6_4(out)],dim=1)
最后通过两个1*1的卷积模块得到最终的输出。
四、Attention Mechanism for k-shot Segmentation
对于多个样本(k=5)的情况下,只是先计算了下各个样本权重,最后对个样本进行加权求和。计算权重的方式是2个卷积+全局平均池化。
总结下:全文的迭代优化这种思想可以借鉴,尤其是把分割预测的mask添加进来,实现自监督。