有没有更高效的方式来学习不同空间位置的关系的权重?本文提出了一种更有效的自适应权重混合方法来实现与 Self-attention 相似的功能,在多种视觉任务中取得了不错的性能。对ViT与MLP反复探究后出来了无注意力自适应权重混合视觉模型

AMixer:无注意力自适应权重混合视觉模型

论文名称:AMixer: Adaptive Weight Mixing for Self-Attention Free Vision Transformers

论文地址:

https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136810051.pdf

背景和动机

视觉 Transformer (ViT) 及其变体利用自注意力机制,通过直接从原始数据中,捕捉不同空间位置之间的相互作用进行学习,与传统的 CNN 在很大程度上依赖于人类先验知识不同。ViT 相比 CNN,在架构设计中更少地融入归纳偏置,即卷积操作更多地在局部建模空间信息。纯多层感知器 (MLP) 模型可以进一步简化视觉 Transformer 的架构设计:通过将 Self-attention 模块替换为空间 MLP 模块,MLP 模型可以以一种更加简单和高效的方式,获取到视觉任务中需要的融合不同空间位置信息的能力。但是,一些先驱工作 (比如 MLP-Mixer[1],ResMLP[2]) 的实验结果发现这种无注意力的 MLP 视觉模型的识别效果略低于视觉 Transformer 模型 (比如 76.6% 的 ResMLP-12 vs. 79.8% 的 DeiT-S)。这自然而然存在一个问题:究竟是什么使得视觉 Transformer 比 MLP 更有效?

作为视觉 Transformer 的核心组件 Self-attention,其在前向传播的时候充分考虑了所有空间位置之间的相互关系 (通过计算 Query 和 Key 的 dot-product 结果,attention map,也就是不同空间位置的关系的权重) 。受益于所有 token 之间的关系都有建模,所以与卷积相比,ViT 可以用更少的 Blocks 来更好地模拟长距离的依赖关系。而作为视觉 MLP 的核心组件 Spatial MLP,其在前向传播的时候直接通过一个可学习的,固定的权重来建模所有 tokens 的关系。但是这种简单的修改可能会带来两个缺点:1) 这种类似 Memory 的 Spatial MLP 通常表现出比 Transformer 更弱的表达能力;2) 与 Transformer 不同的是,基于 MLP 的模型由于 Spatial MLP 的权重大小是固定的,因此很难扩展到新的分辨率。

因此,这自然而然存在另一个问题:有没有更高效的方式来学习不同空间位置的关系的权重?

本文作者受这个问题的启发,从一个 MLP 模型开始,逐渐增加一些视觉 Transformer 的设计,观察实验效果。受到 MLP 模型中 Spatial MLP (学习一个类似 Memory 的权重) 的启发,作者希望设计出一种不通过 dot-product 方式得到的动态权重的做法。作者的解决方案通过预测一个小的权重混合矩阵,线性混合来自权重库的一组静态权重,使静态的 Memory 权重可以自适应输入。

用统一的视角看待视觉 Transformer 和 MLP 模型

AMixer_Self

AMixer_自适应_02

重新思考注意力机制

具体来说,视觉 Transformer 包含了所有 MLP 模型所不具备的四个组成部分:

AMixer_Self_03

作者研究了这些组件是否可以为 MLP 模型带来改进,实验结果如下图1所示,是从 ResMLP 开始到超过 DeiT 的过程。 

AMixer_Self_04

图1:从 ResMLP 开始到超过 DeiT 的过程

多头机制的影响

多头机制最初是为了帮助 Self-attention 模型关注多个位置而提出的。多头方案的一个很好的特性是它不会引起额外的计算量。然而目前主流的 MLP 工作很少引入多头机制。作者从 ResMLP 开始,加入多头机制,发现在 ImageNet 上的精度可以从 76.6% 提高到 77.4%。

额外投影的影响

AMixer_Self_05

Softmax 层的影响

自我注意和 MLP 的另一个区别是 Softmax 操作,作者发现它可以将精度提高到78.2%。

AMixer_人工智能_06

自适应权重混合 

AMixer_人工智能_07

AMixer_权重_08

图2:自适应权重混合方法得到类似注意力的权重

相对注意力权重

AMixer_Self_09

自适应权重混合的 PyTorch 代码如下:

class AdaSpatialMLP(nn.Module):
    def __init__(self, dim, n=196, k=16, r=4, num_heads=1, mode='softmax', post_proj=False, pre_proj=False, relative=False):
        super().__init__()

        self.relative = relative
        if not relative:
            self.weight_bank = nn.Parameter(torch.randn(k, n, n, dtype=torch.float32) * 0.02)
        else:
            h = w = int(math.sqrt(n))
            assert h * w == n
            # define a parameter table of relative position bias
            self.weight_bank = nn.Parameter(torch.randn(k, (2 * h - 1) * (2 * w - 1), dtype=torch.float32) * 0.02)  # 2*Wh-1 * 2*Ww-1, nH

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(h)
            coords_w = torch.arange(w)
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += h - 1  # shift to start from 0
            relative_coords[:, :, 1] += w - 1
            relative_coords[:, :, 0] *= 2 * w - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

        self.adapter = nn.Sequential(
            nn.Linear(dim, dim//r),
            nn.GELU(),
            nn.Linear(dim//r, k * num_heads)
        )

        self.k = k
        self.dim = dim
        self.num_heads = num_heads
        self.n = n
        self.mode = mode


        if pre_proj:
            self.pre_proj = nn.Linear(dim, dim)
        else:
            self.pre_proj = None

        if post_proj:
            self.post_proj = nn.Linear(dim, dim)
        else:
            self.post_proj = None

        print('[AdaSpatialMLP layer] k=%d, num_heads=%d, mode=%s, pos/pre-proj=%s/%s, relative=%s' % (k, self.num_heads, mode, pre_proj, post_proj, relative))
        
        
    def forward(self, x, mask=None):
        B, n, C = x.shape
        mix_policy = self.adapter(x).reshape(B, n, self.k, self.num_heads)

        if not self.relative:
            weight_bank = self.weight_bank
        else:
            weight_bank = self.weight_bank[:, self.relative_position_index.view(-1)].view(self.k, n, n)  # k,Wh*Ww,Wh*Ww
        
        if self.mode == 'softmax':
            mix_policy = torch.softmax(mix_policy, dim=2)
            weight = torch.einsum('bnkh,knm->bnmh', mix_policy, weight_bank)
        elif self.mode == 'linear':
            weight = torch.einsum('bnkh,knm->bnmh', mix_policy, weight_bank)
        elif self.mode == 'softmax-softmax':
            mix_policy = torch.softmax(mix_policy, dim=2)
            weight = torch.einsum('bnkh,knm->bnmh', mix_policy, weight_bank)
            weight = torch.softmax(weight, dim=1)
        elif self.mode == 'linear-softmax':
            weight = torch.einsum('bnkh,knm->bnmh', mix_policy, weight_bank)
            weight = torch.softmax(weight, dim=1)
        elif self.mode == 'linear-sigmoid':
            weight = torch.einsum('bnkh,knm->bnmh', mix_policy, weight_bank)
            weight = torch.sigmoid(weight)
        elif self.mode == 'linear-normalize':
            weight = torch.einsum('bnkh,knm->bnmh', mix_policy, weight_bank)
            weight = torch.nn.functional.normalize(weight, dim=1, p=2)
        elif self.mode == 'sigmoid':
            mix_policy = torch.sigmoid(mix_policy)
            weight = torch.einsum('bnkh,knm->bnmh', mix_policy, weight_bank)
        else:
            raise NotImplementedError

        if self.pre_proj is not None:
            x = self.pre_proj(x)
        
        x = x.reshape(B, n, self.num_heads, -1)
        x = torch.einsum('bnhc,bnmh->bmhc', x, weight).reshape(B,n,C)

        if self.post_proj is not None:
            x = self.post_proj(x)
        return x

仔细看其实和 Self-attention 基本逻辑是一致的,只是 weight (attention map) 在获得的过程中是采取了1.1.4小节式5-式7的方法。

基于自适应权重混合构造的视觉主干模型 AMixer

作者接下来将自适应权值混合应用于 Swin Transformer 架构,得到与之对应的 AMixer-T,AMixer-S,AMixer-B 模型。作者遵循 Swin Transformer 中的分层设计,采用4个 Stage 的架构,每个阶段有 [2,2,n,2] 个 Block。作者调整了 Stage3 的 Block 数量、Head 数量和 MLP ratio,以缩放模型,使其具有与 Swin 系列相似的计算成本。与 Swin 不同的是,作者发现将 MLP ratio 设置为3可以更好地在模型的准确性和复杂性之间进行权衡。

AMixer_自适应_10

图3:视觉主干模型 AMixer 架构配置

实验结果

ImageNet-1K 图像分类任务与 Baseline 模型的对比结果

如下图4所示是 ImageNet-1K 图像分类任务与 Baseline 模型的对比结果实验结果。通过相似的网络架构和相同的训练配置,可以看到,AMixer 比 DeiT 模型提高了 1% 的性能,也可以将 Swin-T 模型提高约 0.7%。此外,作者发现多头机制可以显著改善 MLP,其中改进的 MLP 模型可以使 ResMLP-12 和 Swin-Mixer-T/D6 模型分别提高 2.6% 和 1.0%。AMixer 与MH-MLP 之间的性能差距、也清楚地表明了自适应权重混合带来的有效改善。这些结果有力地证明了自适应权重混合是一种比 Self-attention 更有效的产生注意力权重的方法。

AMixer_自适应_11

图4:ImageNet-1K 图像分类任务与 Baseline 模型的对比结果

ImageNet-1K 图像分类任务与 SOTA 模型的对比结果

如下图5所示是 ImageNet-1K 图像分类任务与 SOTA 模型的对比结果实验结果。通过进一步放大 AMixer 模型,作者建立了一系列基于 Swin transformer 的模型,以与最先进的视觉transformer和mlp类模型进行比较,如下图5所示。可以看到 AMixer 获得了非常有竞争力的结果。AMixer 的基本 Block 仅由 MLP 和自适应权重混合操作组成,而之前的许多工作增加了卷积以更好地捕获局部信息或使用更复杂的架构。AMixer 具有应用于大多数视觉 Transformer 变体并提高其效率的潜力。

AMixer_权重_12

图5:ImageNet-1K 图像分类任务与 SOTA 模型的对比结果

迁移学习实验结果

作者在 CIFAR-10、CIFAR-100、Stanford Cars 和 Flowers-102 上测试了 AMixer-T 模型和 AMixer-B 模型。在之前工作的设置之后,作者使用 ImageNet 预训练的权重初始化模型,并在新的数据集上对其进行微调。结果如下图6所示。AMixer 模型通常对各种下游数据集具有很强的迁移性。与最先进的 CNN 和复杂度相对较低的大型视觉 Transformer 相比,AMixer 也显示出具有竞争力的性能。

AMixer_自适应_13

图6:迁移学习实验结果

语义分割实验结果

语义分割是一种高输入分辨率密集预测任务中的通用性的下游任务。作者在 ADE20K 数据集上评估了 AMixer 模型,结果如下图7所示。可以看到,AMixer 优于具有相似复杂度水平的强 Swin 模型,这表明 AMixer 可以很好地推广到密集的预测任务。

AMixer_自适应_14

图7:语义分割实验结果

与已有的权重生成方法的对比

如下图8(a)所示,作者将 AMixer 与 Self-attention,Synthesizer,位置共享权生成与 DyConv 进行比较,在精心控制的设置下 (同样的遵循 DeiT 的训练方法和相同的网络配置),AMixer 获得了最好的性能,并且效率很高,这表明本文方法更适合和高效。

AMixer_权重_15

图8:与已有的权重生成方法的对比

与其他高效注意力机制的对比

如下图8(b)所示,作者将 AMixer 与与其他高效注意力机制进行了对比,对象包括 LinFormer,PerFormer,Nystroformer,PVT V2 等。为了比较公平,作者使用 DeiT-S 作为基本架构,直接用不同的高效注意力机制代替标准的注意力机制。作者还通过堆叠足够的层数来确保所有模型的 FLOPs 为 ~ 4.6G。作者发现这些高效的注意力机制都不能带来对 DeiT-S 的改善,而 AMixer-Deit-S 可以显著优于基线。

自适应权重可视化

为了研究矩阵如何随着输入图像而变化,作者将自适应权重 \textbf{M}\textbf{M} 可视化在图9中。作者首先找到在 GT 类别中分类分数最高的标记 (由红色框表示),并可视化与该标记对应的权重。首先作者发现权重在不同的图像中表现出显著的多样性。其次,可以发现 AMixer 倾向于关注图像中最具鉴别性的部分 (例如,在第一行中,靠近狗头的权重值更高)。这些可视化显示,AMixer 在没有令牌-令牌交互的情况下生成的自适应权重通过 Self-attention 获得的权重具有类似的行为。

AMixer_自适应_16

图9:自适应权重可视化

总结

作为视觉 Transformer 的核心组件 Self-attention,其在前向传播的时候充分考虑了所有空间位置之间的相互关系 (通过计算 Query 和 Key 的 dot-product 结果,attention map,也就是不同空间位置的关系的权重)。而作为视觉 MLP 的核心组件 Spatial MLP,其在前向传播的时候直接通过一个可学习的,固定的权重来建模所有 tokens 的关系。本文作者提出了一种更有效的自适应权重混合方法来实现与 Self-attention 相似的功能,在多种视觉任务中取得了不错的性能。