Arxiv 2205 - EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers


Core Ideas


  • Still follows the pyramid-structured Transformer paradigm.
  • Replaces the Transformer block with the proposed Local-Global-Local (LGL) bottleneck, which computes self-attention on only a subset of the input tokens while still supporting full spatial interaction (a composed sketch follows this list).
  • Local aggregation: aggregates signals only from spatially nearby tokens. For each token, depth-wise and point-wise convolutions aggregate information within a local k×k window (Figure 3(a)).
  • Global sparse attention: models long-range relations among a set of representative tokens, each of which is treated as the delegate of one local window. A sparse set of representative tokens, uniformly distributed over space, is sampled with one representative per r×r window, where r denotes the sub-sampling rate. Self-attention is then applied only to these selected tokens (Figure 3(b)). This differs from all existing ViTs, where all spatial tokens are involved as queries in the self-attention computation.
  • Local propagation: diffuses the global context learned by the delegates to the non-representative tokens within the same window. The global context encoded in the representative tokens is propagated to their neighboring tokens through a transposed convolution (Figure 3(c)).
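
Putting the three stages together, the sketch below shows how they might compose into one block. This is a minimal sketch rather than the paper's exact code: LGLBlock is a hypothetical name, LocalAgg, GlobalSparseAttn, and FFN refer to the pseudocode classes listed later in this post, and the residual connections follow the usual Transformer-block convention.

import torch
from torch import nn

class LGLBlock(nn.Module):
    # Hypothetical composition of the three LGL stages, each with a residual.
    def __init__(self, dim, sample_rate):
        super().__init__()
        self.local_agg = LocalAgg(dim)          # (a) local aggregation
        # (b) global sparse attention and (c) local propagation live in one
        # module, matching the pseudocode, since the transposed convolution
        # directly follows the attention.
        self.sparse_attn = GlobalSparseAttn(dim, sample_rate, dim ** -0.5)
        self.ffn = FFN(dim)

    def forward(self, x):                       # x: [B, C, H, W]
        x = x + self.local_agg(x)
        x = x + self.sparse_attn(x)
        x = x + self.ffn(x)
        return x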

The LGL bottleneck shares a similar goal with the recent PVTs and Twins-SVTs models, which also aim to reduce the cost of self-attention, but it differs from them in its core design.

  • PVTs perform self-attention where the number of keys and values is reduced by strided convolutions while the number of queries is kept unchanged; in other words, PVTs still compute self-attention at every grid position. In this work, the authors question the necessity of position-wise self-attention and explore to what extent the information exchange enabled by the LGL bottleneck can approximate standard MHSA (a contrast sketch follows this list).
  • Twins-SVTs combine local-window self-attention with PVTs' global pooled attention. This differs from the LGL bottleneck's hybrid design, which uses both self-attention and convolution operations distributed over a local-global-local sequence of operations.
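
For contrast, here is a minimal sketch of PVT-style spatial-reduction attention: keys and values are spatially reduced by a strided convolution, while every grid position still issues a query. SpatialReductionAttn is a hypothetical name for illustration, not PVT's actual class.

import torch
from torch import nn

class SpatialReductionAttn(nn.Module):
    # Q comes from every grid position; K and V come from a reduced grid.
    def __init__(self, dim, sr_ratio):
        super().__init__()
        self.scale = dim ** -0.5
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)

    def forward(self, x):                                  # x: [B, C, H, W]
        B, C, H, W = x.shape
        q = self.q(x.flatten(2).transpose(1, 2))           # [B, H*W, C]
        x_r = self.sr(x).flatten(2).transpose(1, 2)        # [B, H*W/r^2, C]
        k, v = self.kv(x_r).chunk(2, dim=-1)
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        out = attn @ v                                     # [B, H*W, C]
        return out.transpose(1, 2).reshape(B, C, H, W)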

Performance

[Figure: performance comparison results from the paper.]

Pseudocode

import torch
from torch import nn

class LocalAgg(nn.Module):
    # Local aggregation: point-wise -> depth-wise 3x3 -> point-wise convs,
    # mixing information within a local window around each token.
    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, 1)                         # point-wise
        self.conv2 = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)  # depth-wise
        self.conv3 = nn.Conv2d(dim, dim, 1)                         # point-wise
        self.norm1 = nn.BatchNorm2d(dim)
        self.norm2 = nn.BatchNorm2d(dim)

    def forward(self, x):
        # x: [B, C, H, W]
        x = self.conv1(self.norm1(x))
        x = self.conv2(x)
        x = self.conv3(self.norm2(x))
        return x

class GlobalSparseAttn(nn.Module):
    # Global sparse attention over one representative token per r x r window,
    # followed by local propagation back to the full resolution.
    def __init__(self, dim, sample_rate, scale):
        super().__init__()
        self.scale = scale
        self.qkv = nn.Linear(dim, dim * 3)
        # Sub-sample one representative token per sample_rate x sample_rate window.
        self.sampler = nn.AvgPool2d(1, stride=sample_rate)
        # Depth-wise transposed conv spreads global context back to neighbors.
        self.LocalProp = nn.ConvTranspose2d(dim, dim, kernel_size=sample_rate,
                                            stride=sample_rate, groups=dim)
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # x: [B, C, H, W]; H and W are assumed divisible by sample_rate
        B, C, H, W = x.shape
        x = self.sampler(x)                       # [B, C, H/r, W/r]
        h, w = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)          # [B, N, C] token sequence
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = attn @ v                              # attention only over sparse tokens
        x = x.transpose(1, 2).reshape(B, C, h, w)
        x = self.LocalProp(x)                     # back to [B, C, H, W]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(self.norm(x))
        return x.transpose(1, 2).reshape(B, C, H, W)

class DownSampleLayer(nn.Module):
    # Strided conv reduces resolution and changes the channel dim between stages.
    def __init__(self, dim_in, dim_out, downsample_rate):
        super().__init__()
        self.downsample = nn.Conv2d(dim_in, dim_out,
                                    kernel_size=downsample_rate, stride=downsample_rate)
        self.norm = nn.LayerNorm(dim_out)

    def forward(self, x):
        # x: [B, C_in, H, W] -> [B, C_out, H/r, W/r]
        x = self.downsample(x)
        # LayerNorm normalizes over channels, so move them last and back.
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x

class PatchEmbed(nn.Module):
    # Depth-wise 3x3 conv with a residual connection.
    def __init__(self, dim):
        super().__init__()
        self.embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)

    def forward(self, x):
        return x + self.embed(x)

class FFN(nn.Module):
    # Standard two-layer MLP with 4x expansion, applied per spatial position.
    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * 4)
        self.fc2 = nn.Linear(dim * 4, dim)
        self.act = nn.GELU()

    def forward(self, x):
        # x: [B, C, H, W] -> channels-last so Linear acts on the channel dim
        x = x.permute(0, 2, 3, 1)
        x = self.fc2(self.act(self.fc1(x)))
        return x.permute(0, 3, 1, 2)
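
As a quick shape check of the classes above (hypothetical usage; the sizes are arbitrary, and H and W are assumed divisible by the sampling and down-sampling rates):

x = torch.randn(2, 64, 32, 32)                               # [B, C, H, W]
x = PatchEmbed(64)(x)                                        # [2, 64, 32, 32]
x = LocalAgg(64)(x)                                          # [2, 64, 32, 32]
x = GlobalSparseAttn(64, sample_rate=4, scale=64 ** -0.5)(x) # [2, 64, 32, 32]
x = FFN(64)(x)                                               # [2, 64, 32, 32]
x = DownSampleLayer(64, 128, downsample_rate=2)(x)           # [2, 128, 16, 16]
print(x.shape)                                               # torch.Size([2, 128, 16, 16])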