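1. Introduction

Although Transformer-based methods have achieved impressive accuracy in semantic segmentation, pure CNN-based methods still dominate when real-time speed is required. The paper proposes RTFormer, an efficient dual-resolution Transformer for real-time semantic segmentation that achieves a better trade-off between accuracy and efficiency than CNN-based models.

To reach high inference efficiency on GPU devices, RTFormer uses a GPU-friendly attention module with linear complexity and drops the multi-head mechanism. In addition, the authors find that cross-resolution attention is very effective at aggregating global context. Experiments on several mainstream benchmarks (Cityscapes, CamVid, COCOStuff, ADE20K) validate the effectiveness of RTFormer. The figure below compares the accuracy and inference speed of different methods on CamVid; RTFormer offers the best accuracy-speed trade-off.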

[Figure: accuracy versus inference speed of different methods on CamVid]

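2. Motivation

After ViT proved effective in computer vision, Transformer-based methods quickly achieved strong results across many tasks. In semantic segmentation, for example, DPT, SegFormer, HRFormer and Segmenter all achieve excellent accuracy. Compared with CNNs, however, Transformer-based methods suffer from the high computational cost and memory footprint of self-attention, which makes their inference noticeably slower.

The authors argue that the inference bottleneck of attention comes mainly from two aspects:

1. The computational properties of existing attention mechanisms, such as quadratic complexity and the multi-head mechanism, are not GPU-friendly.

2. Applying attention only on high-resolution feature maps may not be the most effective way to capture long-range context, because each feature vector in a high-resolution map has a very limited receptive field.

To address these two limitations, the paper proposes a GPU-friendly attention module and a cross-resolution attention module, and builds RTFormer from them.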
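
To make the first bottleneck concrete, here is a rough count of the multiply-accumulate operations needed for the two attention matrix products. It compares standard self-attention (quadratic in the number of tokens) with an external-attention style module that attends to a fixed number of learnable key/value units (linear in the number of tokens). The function names and the example sizes (a 64x64 feature map, 256 channels, 64 units) are illustrative assumptions, not values from the paper, and projection layers are ignored.

```python
def self_attention_macs(num_tokens: int, dim: int) -> int:
    """MACs for Q @ K^T and A @ V when every token attends to every token."""
    return 2 * num_tokens * num_tokens * dim


def external_attention_macs(num_tokens: int, dim: int, num_units: int) -> int:
    """MACs for X @ K^T and A @ V when tokens attend to num_units learnable rows."""
    return 2 * num_tokens * num_units * dim


if __name__ == "__main__":
    # Assumed sizes for illustration only.
    height, width, dim, num_units = 64, 64, 256, 64
    num_tokens = height * width

    quadratic = self_attention_macs(num_tokens, dim)
    linear = external_attention_macs(num_tokens, dim, num_units)
    print(f"self-attention MACs:     {quadratic:,}")
    print(f"external-attention MACs: {linear:,}")
    print(f"ratio (num_tokens / num_units): {quadratic / linear:.1f}")
```

The ratio equals the number of tokens divided by the number of units, so the gap widens as the feature map resolution grows.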

3. Proposed Method

We first introduce the RTFormer block and its GPU-friendly attention, then describe how the block is used to build the RTFormer segmentation architecture.

[Figure: the RTFormer block]

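The figure above shows the RTFormer block. It is a dual-resolution module that contains two types of attention. In the low-resolution branch, the authors use GPU-friendly attention to capture high-level global context. In the high-resolution branch, they introduce cross-resolution attention to propagate that high-level global context; in other words, features from the two resolutions are aggregated through attention.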
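
A minimal NumPy sketch of the cross-resolution idea, under simplifying assumptions: high-resolution tokens act as queries, while keys and values are assumed to come from the low-resolution branch. How RTFormer actually derives those keys and values (pooling, projections) is not shown here, and the function and argument names are hypothetical. The scaling by the number of key rows mirrors `_act_sn` in the core code listing later in this post.

```python
import numpy as np


def cross_resolution_attention(x_high, k_low, v_low):
    """Aggregate low-resolution context into high-resolution tokens.

    x_high: (N_h, d) high-resolution tokens used as queries.
    k_low:  (M, d)   keys assumed to be derived from the low-resolution branch.
    v_low:  (M, d)   values assumed to be derived from the low-resolution branch.
    """
    num_keys = k_low.shape[0]
    logits = (x_high @ k_low.T) * (num_keys ** -0.5)       # (N_h, M)
    logits = logits - logits.max(axis=1, keepdims=True)    # numerical stability
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=1, keepdims=True)  # softmax over keys
    return weights @ v_low                                  # (N_h, d)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x_high = rng.standard_normal((1024, 64))  # e.g. a 32x32 high-resolution map
    k_low = rng.standard_normal((144, 64))    # e.g. 12x12 pooled low-resolution tokens
    v_low = rng.standard_normal((144, 64))
    print(cross_resolution_attention(x_high, k_low, v_low).shape)
```

Because the number of keys M is fixed and small, the cost grows linearly with the number of high-resolution tokens.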

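Main advantages of the GPU-friendly attention: (1) its matrix multiplications are merged into single large operations, which suits GPU devices well; (2) it retains the benefits of the multi-head mechanism to some extent.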
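
The following NumPy sketch illustrates one way to read the GPU-friendly attention: a single matrix product against learnable keys, a grouped double normalization, and a single product against learnable values. It follows the `_act_dn` routine in the core code listing later in this post (softmax over spatial positions, then L1 normalization inside each group). The function names, shapes and random inputs are assumptions for illustration, and the normalization layer and convolutions of the real module are left out.

```python
import numpy as np


def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def grouped_double_norm(attn, num_groups, eps=1e-6):
    """attn: (N, M) similarities between N tokens and M learnable units."""
    n, m = attn.shape
    a = attn.reshape(n, num_groups, m // num_groups)
    a = softmax(a, axis=0)                        # first norm: over the N tokens
    a = a / (a.sum(axis=2, keepdims=True) + eps)  # second norm: L1 inside each group
    return a.reshape(n, m)


def gpu_friendly_attention(x, k, v, num_groups):
    """x: (N, d) tokens, k: (M, d) learnable keys, v: (M, d_out) learnable values."""
    attn = x @ k.T                                # one large matmul, no per-head split
    attn = grouped_double_norm(attn, num_groups)
    return attn @ v                               # one large matmul


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = 0.1 * rng.standard_normal((256, 64))
    k = rng.standard_normal((32, 64))
    v = rng.standard_normal((32, 64))
    print(gpu_friendly_attention(x, k, v, num_groups=8).shape)
```

The groups play the role of heads only inside the normalization, so the two matrix products stay undivided.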

Overall architecture

[Figure: overall RTFormer architecture]

4. Experimental Results

[Figure: experimental results]

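Core code
# Attention module used in RTFormer
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class ExternalAttention(nn.Layer):
    def __init__(self, ...):  # constructor arguments and layer definitions are elided
        super().__init__()

    def _act_sn(self, x):
        # Single normalization (cross-attention path): scaled softmax over the key dimension
        x = x.reshape([-1, self.inter_channels, 0, 0]) * (self.inter_channels ** -0.5)
        x = F.softmax(x, axis=1)
        x = x.reshape([1, -1, 0, 0])
        return x

    def _act_dn(self, x):
        # Grouped double normalization: softmax over spatial positions,
        # then L1 normalization within each head's group of channels
        x_shape = paddle.shape(x)
        h, w = x_shape[2], x_shape[3]
        x = x.reshape([0, self.num_heads, self.inter_channels // self.num_heads, -1])
        x = F.softmax(x, axis=3)
        x = x / (paddle.sum(x, axis=2, keepdim=True) + 1e-06)
        x = x.reshape([0, self.inter_channels, h, w])
        return x

    def forward(self, x, cross_k=None, cross_v=None):
        x = self.norm(x)
        if not self.use_cross_kv:
            # Learnable keys/values applied as 1x1 convolutions
            x = F.conv2d(x, self.k, bias=None, stride=2 if not self.same_in_out_chs else 1, padding=0)  # n,c_in,h,w -> n,c_inter,h,w
            x = self._act_dn(x)  # n,c_inter,h,w
            x = F.conv2d(x, self.v, bias=None, stride=1, padding=0)  # n,c_inter,h,w -> n,c_out,h,w
        else:
            # Keys/values supplied by the other branch; grouped conv handles each sample separately
            B = x.shape[0]
            x = x.reshape([1, -1, 0, 0])  # n,c_in,h,w -> 1,n*c_in,h,w
            x = F.conv2d(x, cross_k, bias=None, stride=1, padding=0, groups=B)
            x = self._act_sn(x)
            x = F.conv2d(x, cross_v, bias=None, stride=1, padding=0, groups=B)
            x = x.reshape([-1, self.in_channels, 0, 0])  # 1,n*c_in,h,w -> n,c_in,h,w
        return x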
# Overall RTFormer structure
class RTFormer(nn.Layer):
    def __init__(self, ...):
        super().__init__()
        ...
    def forward(self, x):
        x1 = self.layer1(self.conv1(x))
        x2 = self.layer2(self.relu(x1))

        # stage-3: low-resolution branch, then fuse it into the high-resolution branch
        x3 = self.layer3(self.relu(x2))
        x3_ = x2 + F.interpolate(self.compression3(x3), size=paddle.shape(x2)[2:], mode='bilinear')
        x3_ = self.layer3_(self.relu(x3_))

        # stage-4 and stage-5 follow the same pattern as stage-3 (elided here);
        # they also produce the low-resolution feature x5 used below
        x4_ = ...
        x5_ = ...

        # DAPPM, SegHead
        x6 = self.spp(x5)
        # Upsample the pooled context to the spatial size of the high-resolution feature
        x6 = F.interpolate(x6, size=paddle.shape(x5_)[2:], mode='bilinear')
        x_out = self.seghead(paddle.concat([x5_, x6], axis=1))
        return F.interpolate(x_out, paddle.shape(x)[2:], mode='bilinear')

5. Validation on ADE20K (a 32 GB V100 is recommended; the 16 GB environment is unstable)

%cd data/
/home/aistudio/data
!unzip -d ./ data21637/ADE20K.zip
%cd /home/aistudio/PaddleSeg/
/home/aistudio/PaddleSeg
# Install PaddleSeg
!pip install -r requirements.txt
!python setup.py install
# Run validation
!python val.py --config configs/rtformer/rtformer_base_ade20k_512x512_160k.yml --model_path https://paddleseg.bj.bcebos.com/dygraph/ade20k/rtformer_base_ade20k_512x512_160k/model.pdparams
2022-10-24 23:54:30 [INFO]	
---------------Config Information---------------
batch_size: 4
export:
  transforms:
  - keep_ratio: true
    size_divisor: 32
    target_size:
    - 2048
    - 512
    type: Resize
  - mean:
    - 0.485
    - 0.456
    - 0.406
    std:
    - 0.229
    - 0.224
    - 0.225
    type: Normalize
iters: 160000
loss:
  coef:
  - 1
  - 0.4
  types:
  - type: CrossEntropyLoss
lr_scheduler:
  end_lr: 1.0e-07
  learning_rate: 0.0001
  power: 1.0
  type: PolynomialDecay
  warmup_iters: 1500
  warmup_start_lr: 1.0e-06
model:
  base_channels: 64
  drop_path_rate: 0.1
  head_channels: 128
  in_channels: 3
  num_classes: 150
  pretrained: https://paddleseg.bj.bcebos.com/dygraph/backbone/rtformer_base_backbone_imagenet_pretrained.zip
  type: RTFormer
  use_injection:
  - true
  - false
optimizer:
  beta1: 0.9
  beta2: 0.999
  type: AdamW
  weight_decay: 0.05
train_dataset:
  dataset_root: /home/aistudio/data/ADEChallengeData2016/
  mode: train
  transforms:
  - max_scale_factor: 2.0
    min_scale_factor: 0.5
    scale_step_size: 0.25
    type: ResizeStepScaling
  - crop_size:
    - 512
    - 512
    type: RandomPaddingCrop
  - type: RandomHorizontalFlip
  - brightness_range: 0.4
    contrast_range: 0.4
    saturation_range: 0.4
    type: RandomDistort
  - mean:
    - 0.485
    - 0.456
    - 0.406
    std:
    - 0.229
    - 0.224
    - 0.225
    type: Normalize
  type: ADE20K
val_dataset:
  dataset_root: /home/aistudio/data/ADEChallengeData2016/
  mode: val
  transforms:
  - keep_ratio: true
    size_divisor: 32
    target_size:
    - 2048
    - 512
    type: Resize
  - mean:
    - 0.485
    - 0.456
    - 0.406
    std:
    - 0.229
    - 0.224
    - 0.225
    type: Normalize
  type: ADE20K
------------------------------------------------
W1024 23:54:30.952924  2381 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1024 23:54:30.952975  2381 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
2022-10-24 23:54:32 [INFO]	Loading pretrained model from https://paddleseg.bj.bcebos.com/dygraph/backbone/rtformer_base_backbone_imagenet_pretrained.zip
Connecting to https://paddleseg.bj.bcebos.com/dygraph/backbone/rtformer_base_backbone_imagenet_pretrained.zip
Downloading rtformer_base_backbone_imagenet_pretrained.zip
[==================================================] 100.00%
Uncompress rtformer_base_backbone_imagenet_pretrained.zip
[==================================================] 100.00%
2022-10-24 23:54:34 [WARNING]	spp.scale1.1.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale1.1.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale1.1._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale1.1._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale1.3.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale2.1.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale2.1.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale2.1._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale2.1._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale2.3.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale3.1.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale3.1.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale3.1._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale3.1._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale3.3.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale4.1.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale4.1.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale4.1._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale4.1._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale4.3.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale0.0.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale0.0.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale0.0._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale0.0._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.scale0.2.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process1.0.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process1.0.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process1.0._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process1.0._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process1.2.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process2.0.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process2.0.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process2.0._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process2.0._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process2.2.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process3.0.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process3.0.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process3.0._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process3.0._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process3.2.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process4.0.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process4.0.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process4.0._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process4.0._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.process4.2.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.compression.0.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.compression.0.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.compression.0._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.compression.0._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.compression.2.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.shortcut.0.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.shortcut.0.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.shortcut.0._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.shortcut.0._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	spp.shortcut.2.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead.bn1.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead.bn1.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead.bn1._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead.bn1._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead.conv1.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead.bn2.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead.bn2.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead.bn2._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead.bn2._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead.conv2.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead.conv2.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead_extra.bn1.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead_extra.bn1.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead_extra.bn1._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead_extra.bn1._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead_extra.conv1.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead_extra.bn2.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead_extra.bn2.bias is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead_extra.bn2._mean is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead_extra.bn2._variance is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead_extra.conv2.weight is not in pretrained model
2022-10-24 23:54:34 [WARNING]	seghead_extra.conv2.bias is not in pretrained model
2022-10-24 23:54:35 [INFO]	There are 184/261 variables loaded into RTFormer.
2022-10-24 23:54:35 [INFO]	Loading pretrained model from https://paddleseg.bj.bcebos.com/dygraph/ade20k/rtformer_base_ade20k_512x512_160k/model.pdparams
Connecting to https://paddleseg.bj.bcebos.com/dygraph/ade20k/rtformer_base_ade20k_512x512_160k/model.pdparams
Downloading model.pdparams
[==================================================] 100.00%
2022-10-24 23:54:36 [INFO]	There are 261/261 variables loaded into RTFormer.
2022-10-24 23:54:36 [INFO]	Loaded trained params of model successfully
2022-10-24 23:54:36 [INFO]	Start evaluating (total_samples: 2000, total_iters: 2000)...
2000/2000 [==============================] - 233s 117ms/step - batch_cost: 0.1165 - reader cost: 2.6616e-04
2022-10-24 23:58:30 [INFO]	[EVAL] #Images: 2000 mIoU: 0.4202 Acc: 0.7998 Kappa: 0.7848 Dice: 0.5593
2022-10-24 23:58:30 [INFO]	[EVAL] Class IoU: 
[0.7296 0.7938 0.9383 0.7793 0.7149 0.8047 0.8167 0.8554 0.5648 0.654
 0.5531 0.617  0.757  0.3068 0.404  0.5337 0.5422 0.4496 0.6809 0.5076
 0.8116 0.5159 0.6482 0.5795 0.3588 0.3309 0.5502 0.5389 0.5428 0.2478
 0.3292 0.5039 0.3498 0.3988 0.3631 0.4947 0.562  0.6764 0.3366 0.4716
 0.2512 0.1639 0.4084 0.3022 0.3719 0.3542 0.2735 0.6217 0.59   0.6838
 0.6032 0.3909 0.1998 0.2861 0.6989 0.3906 0.9169 0.4607 0.4942 0.2799
 0.1224 0.483  0.3572 0.3354 0.5314 0.7955 0.2614 0.4505 0.0416 0.3651
 0.5572 0.651  0.4797 0.2642 0.5435 0.4192 0.5963 0.276  0.5266 0.2991
 0.7662 0.5612 0.424  0.1681 0.2262 0.6142 0.179  0.1722 0.3547 0.6268
 0.4608 0.0781 0.2804 0.1483 0.0057 0.0845 0.2348 0.3025 0.1298 0.3377
 0.2877 0.1106 0.3147 0.4934 0.1574 0.6972 0.1958 0.4509 0.0765 0.3081
 0.2897 0.3657 0.198  0.6256 0.8161 0.0888 0.6474 0.827  0.2447 0.4364
 0.5342 0.0459 0.2872 0.191  0.3214 0.3619 0.4993 0.5169 0.5379 0.5373
 0.6439 0.0945 0.3429 0.4754 0.3667 0.2653 0.2008 0.0411 0.29   0.4778
 0.3486 0.0043 0.3989 0.3567 0.3718 0.     0.4519 0.0771 0.1946 0.2957]
2022-10-24 23:58:30 [INFO]	[EVAL] Class Precision: 
[0.8293 0.8534 0.9677 0.8615 0.7968 0.8865 0.9119 0.9015 0.703  0.7697
 0.7205 0.7495 0.8201 0.521  0.6394 0.6801 0.6855 0.6969 0.8259 0.6811
 0.8798 0.6907 0.7501 0.6968 0.5208 0.5468 0.6318 0.7628 0.7908 0.385
 0.4991 0.6248 0.5658 0.5604 0.5218 0.6497 0.741  0.8571 0.534  0.6645
 0.4917 0.4177 0.6872 0.525  0.5066 0.6017 0.4195 0.7764 0.768  0.7569
 0.7383 0.5253 0.3847 0.593  0.7453 0.5458 0.9436 0.7504 0.6457 0.576
 0.2813 0.7471 0.5851 0.6653 0.6217 0.8773 0.3792 0.5884 0.1026 0.5766
 0.6891 0.8524 0.622  0.3301 0.7507 0.6057 0.8412 0.6392 0.8833 0.5756
 0.8503 0.7686 0.7812 0.3535 0.4904 0.7559 0.4981 0.4423 0.8261 0.8019
 0.6184 0.1149 0.4884 0.4074 0.0196 0.247  0.6647 0.6068 0.2687 0.7113
 0.5268 0.2914 0.6267 0.8315 0.9184 0.7498 0.7371 0.6113 0.2787 0.5241
 0.5363 0.4909 0.5245 0.8022 0.8225 0.3417 0.8096 0.8575 0.3379 0.6366
 0.7054 0.2658 0.7805 0.6391 0.7348 0.6842 0.7164 0.6935 0.7441 0.7517
 0.7496 0.4958 0.5997 0.8295 0.7244 0.4825 0.3627 0.0768 0.4863 0.708
 0.5538 0.0084 0.6492 0.6546 0.564  0.     0.7816 0.4388 0.4913 0.8068]
2022-10-24 23:58:30 [INFO]	[EVAL] Class Recall: 
[0.8586 0.9191 0.9686 0.8909 0.8744 0.8971 0.8866 0.9436 0.7419 0.8131
 0.7042 0.7772 0.9078 0.4275 0.5233 0.7126 0.7217 0.5589 0.795  0.6659
 0.9129 0.6708 0.8268 0.7749 0.5357 0.456  0.8099 0.6474 0.6338 0.4102
 0.4916 0.7226 0.4782 0.5803 0.5443 0.6746 0.6993 0.7623 0.4767 0.619
 0.3393 0.2124 0.5017 0.416  0.5831 0.4627 0.4399 0.7573 0.718  0.8762
 0.7673 0.6043 0.2936 0.3561 0.9182 0.5789 0.9701 0.544  0.678  0.3525
 0.1782 0.5774 0.4784 0.4035 0.7853 0.8951 0.4569 0.6578 0.0654 0.4989
 0.7443 0.7337 0.677  0.5696 0.6632 0.5764 0.6719 0.327  0.5659 0.3837
 0.8856 0.6753 0.4811 0.2427 0.2957 0.7662 0.2183 0.2199 0.3833 0.7417
 0.6439 0.1958 0.397  0.1891 0.0079 0.1139 0.2663 0.3763 0.2006 0.3913
 0.388  0.1513 0.3873 0.5482 0.1597 0.9086 0.2105 0.632  0.0953 0.4277
 0.3864 0.589  0.2412 0.7397 0.9906 0.1072 0.7637 0.9588 0.4701 0.5812
 0.6877 0.0525 0.3124 0.2142 0.3636 0.4345 0.6223 0.67   0.6599 0.6533
 0.8203 0.1045 0.4447 0.5269 0.4262 0.3709 0.3103 0.0811 0.418  0.5951
 0.4847 0.0087 0.5085 0.4394 0.5217 0.     0.5173 0.0855 0.2437 0.3182]
2022-10-24 23:58:30 [WARNING]	This `val.py`  will be removed in version 2.8, please use `tools/val.py`.

The mIoU is 42.02%, which matches the result reported in the paper.
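
For reference, mIoU is the mean of the per-class IoU values listed in the log above. The sketch below computes per-class IoU and mIoU from a confusion matrix on a toy example. It illustrates the metric and is not PaddleSeg's implementation: ignore labels are not handled, and counting a class that is absent from both prediction and label as IoU 0 is an assumption made here.

```python
import numpy as np


def mean_iou(pred, label, num_classes):
    """Return (per-class IoU, mIoU) for integer arrays of the same shape."""
    pred = np.asarray(pred).ravel()
    label = np.asarray(label).ravel()
    # Confusion matrix: rows are ground-truth classes, columns are predictions.
    conf = np.bincount(label * num_classes + pred,
                       minlength=num_classes ** 2).reshape(num_classes, num_classes)
    intersection = np.diag(conf).astype(float)
    union = conf.sum(axis=0) + conf.sum(axis=1) - intersection
    iou = np.zeros(num_classes)
    present = union > 0
    iou[present] = intersection[present] / union[present]
    return iou, float(iou.mean())


if __name__ == "__main__":
    label = np.array([0, 0, 1, 1, 2, 2])
    pred = np.array([0, 1, 1, 1, 2, 0])
    class_iou, miou = mean_iou(pred, label, num_classes=3)
    print(class_iou, miou)
```

In this toy case the per-class IoUs are 1/3, 2/3 and 1/2, so the mIoU is 0.5.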

# Training
!python train.py --config configs/rtformer/rtformer_base_ade20k_512x512_160k.yml --do_eval --use_vdl --save_interval 500 --save_dir output
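
The configuration printed by the validation run lists a `PolynomialDecay` schedule with warmup (`learning_rate: 0.0001`, `end_lr: 1.0e-07`, `power: 1.0`, `warmup_iters: 1500`, `warmup_start_lr: 1.0e-06`, `iters: 160000`). The sketch below shows what such a schedule could look like. It assumes a linear warmup followed by polynomial decay over the remaining iterations; how PaddleSeg combines the two phases internally is not confirmed here, and the function name `lr_at` is hypothetical.

```python
def lr_at(it, base_lr=1e-4, end_lr=1e-7, power=1.0,
          total_iters=160000, warmup_iters=1500, warmup_start_lr=1e-6):
    """Learning rate at iteration `it` (assumed warmup + polynomial decay)."""
    if it < warmup_iters:
        # Linear warmup from warmup_start_lr to base_lr.
        frac = it / warmup_iters
        return warmup_start_lr + (base_lr - warmup_start_lr) * frac
    # Polynomial decay from base_lr to end_lr over the remaining iterations.
    decay_iters = total_iters - warmup_iters
    progress = min(it - warmup_iters, decay_iters) / decay_iters
    return (base_lr - end_lr) * (1.0 - progress) ** power + end_lr


if __name__ == "__main__":
    for it in (0, 750, 1500, 80000, 160000):
        print(it, lr_at(it))
```

With `power: 1.0` the decay phase is a straight line from the base rate down to `end_lr`.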
Prediction on a single image
!python tools/predict.py --config configs/rtformer/rtformer_base_ade20k_512x512_160k.yml \
                         --image_path /home/aistudio/ADE_val_00000010.jpg \
                         --model_path https://paddleseg.bj.bcebos.com/dygraph/ade20k/rtformer_base_ade20k_512x512_160k/model.pdparams
Visualizing the segmentation result
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np

# Load the original image and the pseudo-color prediction saved by tools/predict.py
ori_image = np.asarray(Image.open("/home/aistudio/ADE_val_00000010.jpg"))
pred_result = np.asarray(Image.open("/home/aistudio/PaddleSeg/output/result/pseudo_color_prediction/ADE_val_00000010.png").convert("RGB"))

# Original image
plt.imshow(ori_image)
plt.show()

[Figure: original image]

# Prediction
plt.imshow(pred_result)
plt.show()

[Figure: predicted segmentation]

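6. Summary

1. Replacing multi-head attention with external attention speeds up computation while preserving accuracy.

2. The network uses a multi-resolution fusion structure, similar to HRNet.

3. The segmentation results still show shortcomings, so there is considerable room for improvement.

4. Overall, this is a work worth exploring for practical use. Nice!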