Vision Transformers for Dense Prediction

论文链接:https://arxiv.org/abs/2103.13413v1 论文代码:https://github.com/isl-org/DPT

Abstract

本文引入dense vision transformers,它用vision transformers 代替卷积网络作为密集预测(dense prediction)任务的主干。将来自 Vision Transformer 各个阶段的token组装成各种分辨率的类似图像的表示,并使用卷积解码器将它们逐步组合成全分辨率预测。Transformer 主干以恒定且相对较高的分辨率处理表示,并且在每个阶段都具有全局感受野。与全卷积网络相比,这些属性允许密集视觉 Transformer 提供更细粒度和更全局连贯的预测。在深度估计和语义分割的任务中,模型都有更好的表现。

Introduction

卷积网络:

  • 将网络分为encoderdecoder,encoder通常基于图像分类网络,也称为主干,它是在一个大型语料库(如ImageNet)上进行预训练的。decoder聚合来自encoder的特征,并将其转换为最终的密集预测 。
  • 卷积骨干逐步下采样输入图像,提取多尺度特征。下采样逐步增加感受野,将低级特征分组为抽象的高级特征,同时确保网络的内存和计算需求保持易于处理。
  • 缺点:在encoder中进行下采样时,特征分辨率和粒度在模型的更深阶段丢失,因此很难在解码器中恢复。但理想情况下,应该以接近输入图像的分辨率来解析特征。

DPT(dense prediction

  • 基于编码器-解码器设计,利用Transformer作为编码器的基本计算构建块
  • 使用ViT作为主干架构
  • 将由ViT提供的bag-of-words representation重组为各种分辨率下的类图像特征表示,并使用卷积decoder逐步将这些特征表示组合到最终的密集预测中
  • ViT主干在计算初始图像嵌入后放弃显式的降采样操作,并在所有处理阶段中保持不变的维数表示
  • 在每个阶段都有一个全局感受野

Architecture

vision transformer pytorch源码 vision transformer for dense prediction_深度学习

Transformer encoder

在高层次上,ViT对图像的bag-of-words representation进行操作:

  • image patch被单独嵌入特征空间,或者从图像中提取深度特征,充当文字word的角色。把嵌入的word称为token(orange)。
  • Transformer使用串联的多头自注意(MHSA)转换token集合,token彼此关联以转换表示。

Embed

  • 方式1:ViT通过处理图像中所有大小为vision transformer pytorch源码 vision transformer for dense prediction_计算机视觉_02像素的非重叠正方形patch,从图像中提取一个patch嵌入。这些patch被压平成矢量,并使用线性投影分别嵌入。
  • 方式2:通过对图像应用ResNet50来提取嵌入,并使用生成的特征映射的像素特征作为token。
  • Transformer不保留单个token的空间位置信息。因此,图像嵌入与可学习的位置嵌入相连接,以将该信息添加到表示中。
  • ViT添加一个不是基于输入图像的特殊的token(readout token),作为用于分类的最终全局图像表示(red)。
  • 大小为H*W像素的图像embed后的输出是一个token的集合,记为vision transformer pytorch源码 vision transformer for dense prediction_深度学习_03vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_04,其中vision transformer pytorch源码 vision transformer for dense prediction_人工智能_05vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_06:readout token,D:每个token的特征维数。

Transformer

  • 输入token通过L个transformer层转换为新的表示vision transformer pytorch源码 vision transformer for dense prediction_计算机视觉_07,其中l表示第l个transformer层的输出。
  • 变体1:ViT -base,它使用了基于patch的嵌入程序,具有12个transformer层,D=768。
  • 变体2:ViT-Large,采用相同的嵌入程序,有24个transformer层,特征尺寸D更宽,为1024。
  • 变体3:Vit - hybrid,它使用ResNet50来计算图像嵌入,然后使用12个转换器层,以输入分辨率的1/16提取特征
  • 在所有实验中使用patch大小p = 16
  • ViT -base和ViT-Large中的特征维度D都大于输入patch中像素的数量,这意味着如果对任务有益,embed过程可以学会保留信息。从输入patch中得到的特征原则上可以以像素级的精度处理。

ViT编码器在所有transformer阶段都保持初始嵌入的空间分辨率:

  • 原因1:Transformer在所有计算中保持token的数量不变
  • 原因2:token与图像patch具有一对一的对应关系

在初始嵌入后的每一阶段,转换器都有一个全局的感受野:

  • 原因:每一个token都可以关注并影响每一个其他token

Convolutional decoder

decoder将一组token组合成各种分辨率下的类似图像的特征表示,特征表示逐渐fuse到最终的密集预测中。本文提出了一个简单的三阶段重组操作,从transformer编码器的任意层的输出token恢复类似图像的表示:

vision transformer pytorch源码 vision transformer for dense prediction_计算机视觉_08
s:输出表示的大小与输入图像大小的比例
vision transformer pytorch源码 vision transformer for dense prediction_编码器_09:输出特征的维度

  • vision transformer pytorch源码 vision transformer for dense prediction_人工智能_10个token映射到vision transformer pytorch源码 vision transformer for dense prediction_人工智能_11个token,这些token可以被连接成类图像表示Read:vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_12vision transformer pytorch源码 vision transformer for dense prediction_人工智能_13vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_14
    Read用于处理readout token,使用三种映射方式:
    ①忽略readout token:vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_15
    ②将representation相加:vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_16
    ③通过将readout连接到所有其他token,将信息传递给其他token,然后使用线性层和GELU非线性将表示投射到原始特征维度D。
    vision transformer pytorch源码 vision transformer for dense prediction_人工智能_17
  • 在read block之后,通过根据初始patch在图像中的位置放置每个token,将产生的vision transformer pytorch源码 vision transformer for dense prediction_人工智能_11个token重塑为类似图像的表示。使用一个空间连接操作,得到大小为vision transformer pytorch源码 vision transformer for dense prediction_计算机视觉_19,通道数为D的特征映射。
    Concatenate:vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_14vision transformer pytorch源码 vision transformer for dense prediction_人工智能_13vision transformer pytorch源码 vision transformer for dense prediction_人工智能_22
  • 将此表示传递到一个空间重采样层,缩放大小至vision transformer pytorch源码 vision transformer for dense prediction_编码器_23,每像素具有vision transformer pytorch源码 vision transformer for dense prediction_深度学习_24个特征
    ①首先vision transformer pytorch源码 vision transformer for dense prediction_编码器_25卷积将输入表示映射到vision transformer pytorch源码 vision transformer for dense prediction_计算机视觉_26,之后如果svision transformer pytorch源码 vision transformer for dense prediction_人工智能_27p,使用(strided) vision transformer pytorch源码 vision transformer for dense prediction_深度学习_28卷积下采样,如果 s < p, 使用(strided)vision transformer pytorch源码 vision transformer for dense prediction_深度学习_28 转置卷积上采样。
    ②在四个不同的阶段和四种不同的分辨率重新组合特征。我们以较低的分辨率聚合来自transformer较深层的特征,而以较高的分辨率聚合来自较浅层的特征。当使用Vit-large时,从层l= {5 12 18 24}重新集合token;而对于Vit-base,使用层l ={3 6 9 12}。使用ViT-Hybrid时,使用来自嵌入网络的第一个和第二个ResNet块的特征和l ={9 12}的token 。默认使用投影作为readout操作,并取vision transformer pytorch源码 vision transformer for dense prediction_计算机视觉_26=256。分别将这些架构称为DPT-Base、DPT-Large和DPT-Hybrid。
  • 利用基于refine-net的特征融合,将连续阶段提取的特征映射结合起来,并在每个融合阶段逐步以2倍向上取样。最终的表示尺寸只有输入图像分辨率的一半。我们附加一个特定于任务的输出头来产生最终的预测。
Handling varying image sizes

类似于完全卷积网络,DPT可以处理不同大小的图像。只要图像大小能被p整除,就可以应用嵌入程序,产生不同的图像token的数量vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_31。作为一种set-to-set的体系结构,transformer encoder可以简单地处理不同数量的token。但是,位置嵌入依赖于图像的大小,因为它对输入图像中的patch的位置进行编码。我们遵循ViT中提出的方法,将位置嵌入线性插值到适当的大小。上述操作可以在每个图像上动态完成。在嵌入过程和转换阶段之后,只要输入图像与卷积解码器(32像素)的步幅对齐,重组和融合模块都可以简单地处理不同数量的token。

Experiments

Monocular Depth Estimation

Semantic Segmentation

  • Loss: employ a cross-entropy loss and add an auxiliary output head together with an auxiliary loss to the output of the penultimate fusion layer

vision transformer pytorch源码 vision transformer for dense prediction_计算机视觉_32


vision transformer pytorch源码 vision transformer for dense prediction_人工智能_33


vision transformer pytorch源码 vision transformer for dense prediction_深度学习_34


vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_35


vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_36


vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_37

用开源模型跑得到的结果,左:输入原图,中:深度估计,右:语义分割

vision transformer pytorch源码 vision transformer for dense prediction_人工智能_38


vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_39


vision transformer pytorch源码 vision transformer for dense prediction_卷积网络_40