CCNet: Criss-Cross Attention for Semantic Segmentation

Abstract

上下文信息对于语义分割和目标检测任务都很重要,这里提出CCNet。对于每个像素,criss-cross attention模块能获得其交叉路径上所有像素的上下文信息,通过进一步的递归操作,每个像素最终可以捕获全图像的依赖关系。此外,提出类别一致损失使得criss-cross attention模块生成更具判别性的特征。CCNet有以下优点:(1)GPU显存友好,比non-local block少11倍显存消耗 (2)高计算效率,比non-local block少85% (3)最先进性能,Cityscapes可达81.9%
开源地址:https://github.com/speedinghzl/CCNet 这里推荐一篇很有意思的工作,通过注意力机制处理预测head关联多目标跟踪领域的检测和ReID特征:Rethinking the competition between detection and ReID in Multi-Object Tracking,该论文不但构思了一种简单有效的注意力机制,还巧妙地利用注意力机制来交叉关联两个任务,避免了检测和ReID的竞争,并联合提升了彼此分支的性能。

Introduction

FCN是固定的几何结构(卷积的网格效应),局部感受野只能提供短距离上下文信息。为了弥补FCN的缺点,deeplab系列提出多尺度空洞卷积结构的ASPP模块聚合上下文信息,PSPNet引入金字塔池化模块捕获上下文信息。然而,基于空洞卷积的方法从一些周围像素收集信息,不能准确产生稠密的上下文信息;基于池化的方法以非自适应的相同上下文提取策略处理所有像素,不能满足不同像素需要不同上下文依赖的要求。Non-local的空间和时间复杂度高,需要改进。

Related work

UNet,Deeplabv3+,MSCI,SPGNet,RefineNet,DFN采取encoder-decoder结构,融合低层次和高层次信息做出稠密预测。Scale-adaptive Convolutions(SAC)和Deformable Convolutional Network(DCN)改善标准卷积处理目标形变和各种尺寸目标…

Approach

1.Network Architecture

crnn 如何训练_池化


上图是Non-local block和Criss-Cross Attention block结构的简化示意图,Non-local block通过计算任意两个位置之间的交互直接捕获远距离依赖,而不用局限于相邻点,但是由于每个位置对应的向量(共crnn 如何训练_Cross_02个)都要和crnn 如何训练_Cross_02个向量相乘,带来的计算量偏大;而Criss-Cross Attention block一次只用考虑"十字交叉"的同行同列的向量,即每个位置对应的向量(crnn 如何训练_Cross_02个)都要和crnn 如何训练_计算机视觉_05个向量相乘,这样捕获的是一个位置和”十字交叉“路径上其他位置的依赖,但是将输出结构果再次送入Criss-Cross Attention block即可获得一个位置与全局位置的依赖,进一步的理解可以看下文的Criss-cross Attention模块细节。

crnn 如何训练_计算机视觉_06

Backbone是全卷积网络,移去最后两个下采样操作,并且随后的卷积层都采取空洞卷积(带空洞卷积的FCN),输出特征图crnn 如何训练_卷积_07为输入图像的1/8。crnn 如何训练_卷积_07通过卷积层降低通道维度输出crnn 如何训练_crnn 如何训练_09crnn 如何训练_crnn 如何训练_09送入criss-cross attention模块聚合每个像素交叉路径上的上下文信息得到crnn 如何训练_crnn 如何训练_11crnn 如何训练_crnn 如何训练_11再次送入criss-cross attention模块输出crnn 如何训练_crnn 如何训练_13,则crnn 如何训练_crnn 如何训练_13的每个像素聚合了所有像素的信息。两个criss-cross attention模块共享参数,取名为recurrent Criss-Cross Attention(RCCA)模块。然后,crnn 如何训练_crnn 如何训练_13和特征crnn 如何训练_卷积_07进行concat,接着是一个或几个带BN的卷积层和激活层用于特征融合,最后融合的特征送入分割层预测最终的分割结果。

2.Criss-cross Attention

考虑一个局部特征图crnn 如何训练_Cross_17,首先通过两个crnn 如何训练_池化_18卷积生成两个特征图crnn 如何训练_池化_19crnn 如何训练_crnn 如何训练_20crnn 如何训练_Cross_21crnn 如何训练_池化_22是比crnn 如何训练_卷积_23小的通道数,形状为crnn 如何训练_crnn 如何训练_24的三维特征图可以很容易reshape成二维的crnn 如何训练_Cross_25的矩阵。通过Affinity操作生成注意力图crnn 如何训练_池化_26,对于特征图crnn 如何训练_池化_19的每一个位置crnn 如何训练_池化_28,拉出一条维度为crnn 如何训练_池化_22的向量crnn 如何训练_Cross_30。同时,从特征图crnn 如何训练_crnn 如何训练_20拉出同属于crnn 如何训练_池化_28位置的同行或同列的crnn 如何训练_计算机视觉_05条(crnn 如何训练_卷积_34会包括crnn 如何训练_池化_28位置两次)维度均为crnn 如何训练_Cross_36的向量集crnn 如何训练_计算机视觉_37crnn 如何训练_卷积_38crnn 如何训练_crnn 如何训练_39的第crnn 如何训练_卷积_40个元素(向量),则Affinity操作可用公式表达如下:crnn 如何训练_池化_41 crnn 如何训练_计算机视觉_42是特征crnn 如何训练_计算机视觉_43crnn 如何训练_crnn 如何训练_44之间的关联度,crnn 如何训练_池化_45,且crnn 如何训练_Cross_46,然后对crnn 如何训练_计算机视觉_47在通道维度上添加softmax层,输出注意力图A

输入特征图H通过另外一个crnn 如何训练_池化_18卷积生成特征图crnn 如何训练_卷积_49,对于特征图crnn 如何训练_池化_50的每一个位置crnn 如何训练_池化_28,同理拉出维度为crnn 如何训练_卷积_23的向量crnn 如何训练_卷积_53和向量集 crnn 如何训练_计算机视觉_54,然后给出聚合(Aggregation)操作的公式如下:crnn 如何训练_卷积_55
其中,crnn 如何训练_卷积_56crnn 如何训练_卷积_57中位置crnn 如何训练_池化_28的特征向量,crnn 如何训练_Cross_59是注意力图A中位置crnn 如何训练_池化_28对应的第crnn 如何训练_卷积_40个数值。最后是以残差的形式输出crnn 如何训练_crnn 如何训练_11,增强了像素级的表达能力,并聚合了全局上下文信息,提升了语义分割的性能。

crnn 如何训练_crnn 如何训练_63


Recurrent Criss-Cross Attention(RCCA)模块包含两个Criss-Cross Attention模块,且是共享参数的,RCCA可获得一个位置与全局位置的依赖,能够获得稠密丰富的下文信息。3.Learning Category Consistent Features

对于语义分割任务,同一类像素应该有相似的特征,不同类像素应该有差别大的特征,这被称作类别一致性。论文认为,RCCA模块聚合的特征可能会存在过度平滑的问题,这是图神经网络的常见问题,因此除了使用交叉熵损失crnn 如何训练_池化_64监督外,还提出了类别一致损失。RCCA模块输出后接crnn 如何训练_池化_18卷积降低特征图通道数(这里设置16),在这个低通道数特征图crnn 如何训练_卷积_66上添加类别一致性损失。假定crnn 如何训练_卷积_23是mini-batch images里存在的类别数,crnn 如何训练_卷积_68是属于类别crnn 如何训练_计算机视觉_69的有效元素数目,crnn 如何训练_crnn 如何训练_70是特征图M空间位置crnn 如何训练_卷积_40对应的特征向量(crnn 如何训练_卷积_40是属于类别crnn 如何训练_Cross_73的,是crnn 如何训练_卷积_68中的一个元素),crnn 如何训练_Cross_75是类别crnn 如何训练_Cross_73的平均特征向量(聚类中心),crnn 如何训练_计算机视觉_77计算两者之间的距离进行惩罚,希望同类别像素对应的特征向量具有相似性,靠近该类聚类中心最好;crnn 如何训练_池化_78crnn 如何训练_crnn 如何训练_79是两个不同类别的聚类中心(特征向量),crnn 如何训练_crnn 如何训练_80计算两个类别中心之间的距离进行惩罚,两两类别计算,希望不同类别像素的聚类中心越远越好。crnn 如何训练_计算机视觉_81是聚类中心向量的正则项损失,最终损失是所有损失的加权和:crnn 如何训练_卷积_82其中,设置crnn 如何训练_池化_83。总之,类别一致损失是从特征上,希望同类别像素特征具有相似性,不同类相似特征具有差异性。损失具体公式如下:

crnn 如何训练_计算机视觉_84