Arxiv 2106 - CAT: Cross Attention in Vision Transformer


This post only gives a rough description of the core modules, aiming to fully present the key differences of this work; for details, refer to the write-up linked above.

Main Content


  • Cross Attention Block (CAB) = Inner-Patch Self-Attention Block (IPSA) + Cross-Patch Self-Attention Block (CPSA):
  • IPSA: standard patch-based attention. The attention input is a tensor of shape `B*nph*npw, ph*pw, C`, producing an attention matrix of spatial size `ph*pw, ph*pw`. This module models the global relations inside each patch.
  • CPSA: handled differently from previous variants. Here the attention input has shape `B*C, nph*npw, ph*pw`, and the corresponding attention matrix has size `nph*npw, nph*npw`. In this computation, the spatial dimension of a single channel inside each patch is taken as that patch's representation, and this dimension is absorbed by the similarity computation. Built on this channel-independent design, the module provides a lightweight form of global information exchange between patches.
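The two reshapes above can be contrasted directly. A minimal sketch with toy sizes (assumed here: `B=2`, an 8x8 feature map, 4x4 patches, `C=16`), using a raw dot-product similarity in place of the full QKV attention, just to show where each attention matrix lives:

```python
import torch

# Toy sizes (assumed for illustration): batch 2, 8x8 map, 4x4 patches, 16 channels.
B, H, W, C, p = 2, 8, 8, 16, 4
nph, npw = H // p, W // p

x = torch.randn(B, H, W, C)

# --- IPSA view: one attention map per patch, over the p*p positions inside it ---
ipsa_in = (x.view(B, nph, p, npw, p, C)
            .permute(0, 1, 3, 2, 4, 5)          # B, nph, npw, p, p, C
            .reshape(B * nph * npw, p * p, C))   # B*nph*npw, p*p, C
ipsa_attn = ipsa_in @ ipsa_in.transpose(-2, -1)  # B*nph*npw, p*p, p*p
print(ipsa_attn.shape)  # torch.Size([8, 16, 16])

# --- CPSA view: one attention map per channel, over the nph*npw patches ---
# The p*p spatial values of one channel act as that patch's feature vector.
cpsa_in = (x.view(B, nph, p, npw, p, C)
            .permute(0, 5, 1, 3, 2, 4)           # B, C, nph, npw, p, p
            .reshape(B * C, nph * npw, p * p))   # B*C, nph*npw, p*p
cpsa_attn = cpsa_in @ cpsa_in.transpose(-2, -1)  # B*C, nph*npw, nph*npw
print(cpsa_attn.shape)  # torch.Size([32, 4, 4])
```

Note how the patch count and the spatial dimension swap roles: IPSA attends within a patch, CPSA attends across patches per channel, which is what makes CPSA cheap.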

Core Code

x = x.view(B, H, W, C)

# partition
patches = partition(x, self.patch_size) # nP*B, patch_size, patch_size, C
patches = patches.view(-1, self.patch_size * self.patch_size, C) # nP*B, patch_size*patch_size, C

# IPSA or CPSA
if self.attn_type == "ipsa":
    attn = self.attn(patches) # nP*B, patch_size*patch_size, C
elif self.attn_type == "cpsa":
    patches = patches.view(B, (H // self.patch_size) * (W // self.patch_size), self.patch_size ** 2, C).permute(0, 3, 1, 2).contiguous()
    patches = patches.view(-1, (H // self.patch_size) * (W // self.patch_size), self.patch_size ** 2) # nP*B*C, nP*nP, patch_size*patch_size
    attn = self.attn(patches).view(B, C, (H // self.patch_size) * (W // self.patch_size), self.patch_size ** 2)
    attn = attn.permute(0, 2, 3, 1).contiguous().view(-1, self.patch_size ** 2, C) # nP*B, patch_size*patch_size, C
else:
    raise NotImplementedError(f"Unknown attention type: {self.attn_type}")

# reverse operation of partition
attn = attn.view(-1, self.patch_size, self.patch_size, C)
x = reverse(attn, self.patch_size, H, W) # B H' W' C

x = x.view(B, H * W, C)
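The snippet above calls `partition` and `reverse` without showing them. A plausible implementation consistent with the shape comments (the same window-splitting pattern used in Swin Transformer; these exact bodies are my assumption, not taken from the CAT repository):

```python
import torch

def partition(x, patch_size):
    # (B, H, W, C) -> (nP*B, patch_size, patch_size, C), non-overlapping patches
    B, H, W, C = x.shape
    x = x.view(B, H // patch_size, patch_size, W // patch_size, patch_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, patch_size, patch_size, C)

def reverse(patches, patch_size, H, W):
    # (nP*B, patch_size, patch_size, C) -> (B, H, W, C), inverse of partition
    B = patches.shape[0] // ((H // patch_size) * (W // patch_size))
    x = patches.view(B, H // patch_size, W // patch_size, patch_size, patch_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

# Round trip sanity check on a toy tensor
x = torch.randn(2, 8, 8, 16)
assert torch.equal(reverse(partition(x, 4), 4, 8, 8), x)
```

`reverse` simply undoes the `permute`/`view` of `partition`, so composing the two is the identity, which the assertion verifies.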

Experimental Results
