Arxiv 2112 | SIMVIT: EXPLORING A SIMPLE VISION TRANSFORMER WITH SLIDING WINDOWS

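Problems with existing methods

  • Existing methods mainly compute global attention over the tokens inside an image or a window, which breaks the spatial and local correlations between patches in the 2D structure.
  • In addition, because each position encoding is unique, current vision Transformers lack translation invariance.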

Contributions of this work

Core computation logic

Sliding-window logic

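By combining the local-structure preservation that comes from the sliding-window hierarchy used in CNNs with the information aggregation of self-attention in Transformers, this work bridges the gap between CNNs and Transformers in how they model visual data.

  • It proposes multi-head central self-attention to replace the standard MSA.
  • It uses overlapping sliding windows to integrate spatial information and cross-window connections while preserving local structure.
  • Within each window, only the correlation between the central patch and its surrounding patches is computed.
  • Central self-attention is used only in the first three stages. In the fourth stage, MSA is used to build global dependencies.
  • Because the interactions are restricted to a local range, translation invariance is introduced, so position encodings are no longer used.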
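To make the third bullet concrete, here is a minimal loop-based sketch of central attention for a single head. It is an illustration written for this note, not the authors' code: the function name `central_attention_naive` is hypothetical, the q/k/v projections are left out, and zero padding at the border is an assumption taken from the implementation shown later in this post.

```python
import torch
import torch.nn.functional as F


def central_attention_naive(q, k, v, ks=3):
    """Loop-based central attention for a single head (illustrative only).

    q, k, v: tensors of shape (H, W, C). For every position, the query at the
    window centre attends to the ks*ks keys/values around it (zero-padded at
    the border), and the result is written back to that position.
    """
    H, W, C = q.shape
    p = ks // 2
    # F.pad pads from the last dim backwards: (C, C, W_left, W_right, H_top, H_bottom)
    k_pad = F.pad(k, (0, 0, p, p, p, p))
    v_pad = F.pad(v, (0, 0, p, p, p, p))
    out = torch.empty_like(q)
    scale = C ** -0.5
    for i in range(H):
        for j in range(W):
            k_win = k_pad[i:i + ks, j:j + ks].reshape(ks * ks, C)  # (ks*ks, C)
            v_win = v_pad[i:i + ks, j:j + ks].reshape(ks * ks, C)  # (ks*ks, C)
            attn = torch.softmax((q[i, j] * scale) @ k_win.T, dim=-1)  # (ks*ks,)
            out[i, j] = attn @ v_win  # (C,)
    return out


torch.manual_seed(0)
x = torch.randn(5, 6, 8)
y = central_attention_naive(x, x, x)
print(y.shape)  # torch.Size([5, 6, 8])
```

Note that the zero-padded keys still receive softmax weight (their score is 0, not minus infinity), so border positions mix in some zero values. Neighbouring windows overlap because the window moves one patch at a time.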

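Core code

As the code shows, an unfold operation gathers the k and v tokens that fall inside each sliding window into a separate dimension, giving a tensor of size (ks*ks, hc), while q has size (1, hc). The qk product contracts hc and yields (1, ks*ks), and multiplying the result by v yields (1, hc).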
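The shape bookkeeping can be checked in isolation. The snippet below is a small sketch with made-up sizes (`B`, `heads`, `hc`, `H`, `W` are arbitrary values chosen for illustration) that applies `nn.ZeroPad2d` and `nn.Unfold` to a key tensor and follows the shapes through the two matrix products. For brevity it reuses the unfolded keys as values.

```python
import torch
import torch.nn as nn

B, heads, hc, H, W, ks = 2, 4, 8, 5, 6, 3

k_img = torch.randn(B, heads * hc, H, W)             # keys laid out as an image
k_pad = nn.ZeroPad2d(ks // 2)(k_img)                 # (B, heads*hc, H+2, W+2)
k_win = nn.Unfold(kernel_size=ks, stride=1)(k_pad)   # (B, heads*hc*ks*ks, H*W)
k_win = k_win.reshape(B, heads, hc, ks * ks, H * W)  # channel index is the outer one
k_win = k_win.permute(0, 1, 4, 3, 2)                 # (B, heads, H*W, ks*ks, hc)

q = torch.randn(B, heads, H * W, 1, hc)              # one query per window centre
attn = (q @ k_win.transpose(-2, -1)).softmax(dim=-1)  # (B, heads, H*W, 1, ks*ks)
v_win = k_win                                        # stand-in for the unfolded values
out = attn @ v_win                                   # (B, heads, H*W, 1, hc)

print(k_win.shape, attn.shape, out.shape)
```

With stride 1 and padding ks // 2, the number of windows equals H*W, which is why each query can be paired with exactly one window.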

import torch.nn as nn


class CenterAttention(nn.Module):
    def __init__(self,
                 dim,
                 num_heads=1,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 stride=1,
                 padding=True,
                 kernel_size=3):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}."
        self.k_size = kernel_size  # kernel (window) size
        self.stride = stride  # window stride
        self.in_channels = dim
        self.num_heads = num_heads
        self.head_channel = dim // num_heads
        # padding must be True (with stride 1) so that the number of unfolded
        # windows matches the number of queries, H*W
        self.pad_size = kernel_size // 2 if padding is True else 0  # padding size
        self.pad = nn.ZeroPad2d(self.pad_size)  # zero padding around the input
        self.scale = qk_scale or (dim // num_heads)**-0.5
        self.unfold = nn.Unfold(kernel_size=self.k_size, stride=self.stride, padding=0, dilation=1)

        self.qkv_bias = qkv_bias
        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv_proj = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.softmax = nn.Softmax(dim=-1)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.reshape(B, H, W, C)
        assert C == self.in_channels

        # number of sliding windows along each axis
        self.pat_size_h = (H + 2 * self.pad_size - self.k_size) // self.stride + 1
        self.pat_size_w = (W + 2 * self.pad_size - self.k_size) // self.stride + 1
        self.num_patch = self.pat_size_h * self.pat_size_w

        # (B, NumHeads, H, W, HeadC)
        q = self.q_proj(x).reshape(B, H, W, self.num_heads, self.head_channel).permute(0, 3, 1, 2, 4)
        # add a singleton window axis; the query is later matched against the ks*ks keys
        q = q.unsqueeze(dim=4)  # (B, NumHeads, H, W, 1, HeadC)
        q = q * self.scale
        # if stride is not 1, q would have to be subsampled to match the windows

        # (2, B, NumHeads, HeadC, H, W)
        kv = self.kv_proj(x).reshape(B, H, W, 2, self.num_heads, self.head_channel).permute(3, 0, 4, 5, 1, 2)
        kv = self.pad(kv)  # pads the last two dims: (2, B, NumH, HeadC, H+2p, W+2p)
        kv = kv.permute(0, 1, 2, 4, 5, 3)  # (2, B, NumH, H+2p, W+2p, HeadC)

        H, W = H + self.pad_size * 2, W + self.pad_size * 2  # padded height and width
        # unfold plays the role of conv2d's sliding window to gather per-window data
        kv = kv.permute(0, 1, 2, 5, 3, 4).reshape(2 * B, -1, H, W)  # (2*B, NumH*HeadC, H, W)
        kv = self.unfold(kv)  # (2*B, NumH*HeadC*ks*ks, NumPatch)
        kv = kv.reshape(2, B, self.num_heads, self.head_channel, self.k_size**2,
                        self.num_patch)  # (2, B, NumH, HC, ks*ks, NumPatch)
        kv = kv.permute(0, 1, 2, 5, 4, 3)  # (2, B, NumH, NumPatch, ks*ks, HC)
        k, v = kv[0], kv[1]

        # (B, NumH, NumPatch, 1, HeadC)
        q = q.reshape(B, self.num_heads, self.num_patch, 1, self.head_channel)
        attn = (q @ k.transpose(-2, -1))  # (B, NumH, NumPatch, 1, ks*ks)
        attn = self.softmax(attn)  # softmax over the ks*ks window positions
        attn = self.attn_drop(attn)

        out = (attn @ v).squeeze(3)  # (B, NumH, NumPatch, HeadC)
        out = out.permute(0, 2, 1, 3).reshape(B, self.pat_size_h, self.pat_size_w, C)  # (B, Ph, Pw, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        out = out.reshape(B, -1, C)
        return out
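
The claim that local interaction brings translation invariance can be probed with a small experiment. The sketch below is not the module above: it assumes a single head and identity q/k/v projections, and the function name `central_attention` is hypothetical. It shifts the input by one pixel with `torch.roll` and compares the outputs away from the left and right borders, where zero padding and the wrap-around column play no role.

```python
import torch
import torch.nn.functional as F


def central_attention(x, ks=3):
    """Single-head central attention with identity q/k/v projections.

    x: (B, C, H, W). Returns a tensor of the same shape.
    """
    B, C, H, W = x.shape
    kv = F.unfold(x, kernel_size=ks, padding=ks // 2)          # (B, C*ks*ks, H*W)
    kv = kv.reshape(B, C, ks * ks, H * W).permute(0, 3, 2, 1)  # (B, H*W, ks*ks, C)
    q = x.flatten(2).transpose(1, 2).unsqueeze(2) * C ** -0.5  # (B, H*W, 1, C)
    attn = (q @ kv.transpose(-2, -1)).softmax(dim=-1)          # (B, H*W, 1, ks*ks)
    out = (attn @ kv).squeeze(2)                               # (B, H*W, C)
    return out.transpose(1, 2).reshape(B, C, H, W)


torch.manual_seed(0)
x = torch.randn(1, 4, 8, 8)
x_shift = torch.roll(x, shifts=1, dims=3)  # move the content one pixel to the right

y = central_attention(x)
y_shift = central_attention(x_shift)

# Compare columns whose 3x3 windows see the same content in both runs.
interior_match = torch.allclose(y_shift[..., 2:-1], y[..., 1:-2], atol=1e-6)
print(interior_match)
```

Strictly speaking, what this checks is that the output feature map moves together with the input (equivariance) in the interior. No position encoding is involved, so nothing ties the result to absolute coordinates.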

Experimental results

Comparison experiments

Classification

Detection

Segmentation

Ablation study
