深度学习之图像分类(十九)Bottleneck Transformer(BoTNet)网络详解


目录

  • 深度学习之图像分类(十九)Bottleneck Transformer(BoTNet)网络详解
  • 1. 前言
  • 2. Multi-Head Self-Attention
  • 3. Bottleneck Transformer
  • 4. BoTNet 网络结构
  • 5. 代码


上节有讲 ViT 结构。本节学习 CNN 与 Attention 的融合网络 BoTNet,即 Bottleneck Transformer。

etnet神经网络 神经网络 bottleneck_pytorch

1. 前言

BoTNet 为 UC Berkeley 和 谷歌团队合作发布在 2021 CVPR 的文章,Transformer 的一作 Ashish Vaswani 也在本文作者当中。BoTNet 原始论文为 Bottleneck Transformers for Visual Recognition。我们首先看一下 CV 领域 Self-Attention Model 分类,可分为纯 Attention Model 包括 ViT,此外则是 CNN + Attention。为什么要把 CNN 与 Attention 进行融合呢?其中一大原因就是因为纯 ViT 类型的结构对于输入尺寸特别敏感,不能变动,是 etnet神经网络 神经网络 bottleneck_神经网络_02 就是 etnet神经网络 神经网络 bottleneck_神经网络_02 。然而我们的目标检测,实例分割等任务的输入可能是 etnet神经网络 神经网络 bottleneck_pytorch_04 的大图。如果硬Train一发纯 Transformer,计算量怕你吃不消。与 CNN 融合的 Attention 结构也可分为 Transform 形式的,也可分为 Non-Local 形式的。 关于 NL 的知识可以参考知乎解答 Non-local neural networks,其实 NL 提到说 Self-Attention 是 NL 的一个特例,在 Non-Local 中没有位置编码 Position Embedding。CNN 融合的 Attention 可以放在 Backbone 外,将 CNN 的输出特征图进行映射后接入 Transformer,也可以直接将 Transformer block 嵌入 CNN Backbone 的单个 block 内。本次讲解的 BoTNet 则是将 Attention 模块融入到 CNN 原有的 block 中。

etnet神经网络 神经网络 bottleneck_pytorch_05

说难也难,说简单也简单,BoTNet 其实就是把 ResNet50 等使用的 Bottleneck 中间的 etnet神经网络 神经网络 bottleneck_etnet神经网络_06

etnet神经网络 神经网络 bottleneck_深度学习_07

ResNet50 变成 BoTNet 之后能做什么?很自然,ResNet50 能做的 BoTNet 都能做。直接作为分类模型进行 ImageNet 分类,作为 Backbone 嵌入 Faster R-CNN 做目标检测,作为 Backbone 嵌入 Mask R-CNN 做实例分割。

2. Multi-Head Self-Attention

首先我们来看一下作者使用的 Multi-Head Self-Attention 结构。Bottleneck Transformer 中使用的 MHSA 结构如下所示。在实验中,作者使用的是 head = 4,在下图中并没有进行展示。MHSA 的输入尺寸为 etnet神经网络 神经网络 bottleneck_深度学习_08,分别表示输入特征矩阵的高宽以及单个 token 的维度(特征图的通道转换到最后一维来就行了), token 的数量即为 etnet神经网络 神经网络 bottleneck_etnet神经网络_09此处的 relative position encodings 是我一直没有搞清楚的点。首先是初始化两个可学习的参数向量 etnet神经网络 神经网络 bottleneck_etnet神经网络_10etnet神经网络 神经网络 bottleneck_transformer_11,分别表示高度和宽度不同位置的位置编码,然后将它们通过广播机制加起来,那就是 etnet神经网络 神经网络 bottleneck_pytorch_12 位置的编码为 etnet神经网络 神经网络 bottleneck_深度学习_13 的两个 etnet神经网络 神经网络 bottleneck_深度学习_14 维向量相加。这样做将 etnet神经网络 神经网络 bottleneck_深度学习_08 个编码简化到了 etnet神经网络 神经网络 bottleneck_etnet神经网络_16。但是为什么这么来做呢,参考文献 [53, 4, 49] 是这么做的。此外,位置编码也不是直接加到输入上,而是与 query 矩阵进行矩阵乘法得到 Attention 的一部分,将其与 query 和 key 算出来的加和后经过 softmax 得到最终的 Attention。

etnet神经网络 神经网络 bottleneck_神经网络_17

对于位置编码作者进行了实验,R50 指 ResNet50 作为 backbone。作者在 COCO 数据集进行目标检测上发现,如果不加位置编码直接使用 etnet神经网络 神经网络 bottleneck_神经网络_18 计算 Attention 有提升,但是直接使用位置编码计算的 Attention 得到的 gain 还要大一些 etnet神经网络 神经网络 bottleneck_神经网络_19。如果结合而这提升更大。做这也尝试了绝对位置编码(不知道是不是不可学习的意思哈),发现效果其实并不如相对位置编码。

etnet神经网络 神经网络 bottleneck_深度学习_20

注意到,上图说没有高亮蓝色框之后就是 Non-Local Layer(highlighted blue boxes,position encodings and the value projection),然而我去看他的引文[63] Non-local Neural Networks 原始论文,别人有 value projection 啊,所以我不知道是他理解错了还是我理解错了。

etnet神经网络 神经网络 bottleneck_深度学习_21

[53] Peter Shaw, Jakob Uszkoreit, and Ashish Vaswani. Self-attention with relative position representations. arXiv preprint arXiv:1803.02155, 2018.

[4] wan Bello, Barret Zoph, Ashish Vaswani, Jonathon Shlens, and Quoc V Le. Attention augmented convolutional networks. In Proceedings of the IEEE International Conference on Computer Vision, pages 3286–3295, 2019.

[49] Prajit Ramachandran, Niki Parmar, Ashish Vaswani, Irwan Bello, Anselm Levskaya, and Jonathon Shlens. Stand-alone self-attention in vision models. arXiv preprint arXiv:1906.05909, 2019.

[63] Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. Non-local neural networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 7794–7803, 2018.

3. Bottleneck Transformer

在 Multi-Head Self-Attention 结构前后加上 etnet神经网络 神经网络 bottleneck_etnet神经网络_22 卷积就得到了 Bottleneck Transformer。Bottleneck Transformer 和 ViT 中的 Transformer block 其实是有亲属关系的,他们不是差异很大的结构。作者在论文中首先提到说,具有 MHSA 的 ResNet botteneck 块可以被视作具有 bottleneck 结构,其他方面具有微小差异(例如残差结构,正则化层等)的 Transformer 块。We point out that it is not the case. Rather, ResNet botteneck blocks with the MHSA layer can be viewed as Transformer blocks with a bottleneck structure, modulo minor differences such as the residual connections, choice of normalization layers, etc

etnet神经网络 神经网络 bottleneck_pytorch_23

这怎么理解呢?我觉得是这样的,在 ViT 中的 Transformer 块中 MHSA 前后有残差连接,MHSA 之后经过了 MLP block,这里的 MLP block 其实就是两层全连接,第一层全连接对特征进行扩充,然后第二层全连接对特征进行压缩。MLP block 的输出也和输入进行了残差相加。这里忽略掉 Dropout 层。整个结构就如上图最左边一样。如果我们把 MLP block 拆开,然后 MHSA 放到 Expansion 之后会这么样呢?那就成了上图中间的模样。这个不就和 ResNet 中的 Bottleneck block 一样了嘛,不同在于 ResNet 在中间使用 etnet神经网络 神经网络 bottleneck_etnet神经网络_06

etnet神经网络 神经网络 bottleneck_pytorch_25

作者也相当坦诚啊:We note that the architectural design of the BoT block is not our contribution. Rather, we point out the relationship between MHSA ResNet bottleneck blocks and the Transformer with the hope that it improves our understanding of architecture design spaces [47, 48] for self-attention in computer vision.作者的贡献不在于提出 BoTNet 的结构,而在于揭示它和 Transformer 原始结构的内在联系。

值得注意的是,既然把 MHSA 放到了 Resnet 中,原来的 LN 就被替换为了 BN。此外,在原始的 Transformer Block 中仅仅在 MLP block 中有一个 GELU 激活函数,而在 Bottleneck Transformer 中,有三个激活函数分别在一头一尾两个 etnet神经网络 神经网络 bottleneck_etnet神经网络_22

4. BoTNet 网络结构

BoTNet50 的网络结构如表所示,我们能够发现它其实就是把 ResNet50 最后一个阶段的 3 个 block 给替换上了 MHSA 结构。我们知道 Transformer 计算量是很大的,仅仅放在最后我们看到推理速度和计算操作个数都大幅增加了,尽管参数量少了点。这里有一个细节,c5 阶段第一个 block 是要进行下采样的,以前是 etnet神经网络 神经网络 bottleneck_etnet神经网络_06 那里的 stride = 2。由于 MHSA 不支持下采样,现在呢在 MHSA 之后增加一个 etnet神经网络 神经网络 bottleneck_etnet神经网络_28

etnet神经网络 神经网络 bottleneck_深度学习_29

BoTNet50 是用来做 Backbone 的,用在 R-CNN 框架中。如果是做 ImageNet 分类,则 c5 其实不需要再进行下采样,否则特征图太小了。所以 c5 第一个 block 不进行 etnet神经网络 神经网络 bottleneck_etnet神经网络_28

etnet神经网络 神经网络 bottleneck_神经网络_31

除去 BoT50 外,做这还能做到 BoT152 和 BoT200。再让 TPU 发展几年,这还不得到 BoT1k 啊。

etnet神经网络 神经网络 bottleneck_etnet神经网络_32

一般我不分析实验的,但今天还是站出来提出一个一直困扰我的疑惑!欢迎大家在评论区发表高见。

作者提到,BoTNet 对于小物体检测有提升,对于大物体没什么优势。但是 MHSA 是加在最后一个 stage 的,也就是最小分辨率的特征图上的。我们知道,层数越深尺度越小的特征图其实是被我们用来检测大物体的,而层数越浅尺度越大的特征图其实是被我们用来检测小物体的。BoTNet 在 c5 阶段扩大感受野到整个特征图,但是为什么提升了小物体的检测性能呢?DETR 就是因为将 Transformer 加到了 backbone 输出后面,对大特征图没有用,所以小物体检测也没有提升。而为什么 BoTNet 就能做到呢?

Significant boost from BoTNet on small objects (+2.4 Mask AP and +2.6 Box AP) (Appendix);

With visibly good gains on small objects in BoTNet, we believe there maybe an opportunity to address the lack of gain on small objects found in DETR, in future

5. 代码

代码出处见 此处。

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp


class MHSA(nn.Module):
    def __init__(self, n_dims, width=14, height=14, heads=4):
        super(MHSA, self).__init__()
        self.heads = heads

        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)

        self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, height]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims // heads, width, 1]), requires_grad=True)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, width, height = x.size()
        q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)
        k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
        v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)

        content_content = torch.matmul(q.permute(0, 1, 3, 2), k)

        content_position = (self.rel_h + self.rel_w).view(1, self.heads, C // self.heads, -1).permute(0, 1, 3, 2)
        content_position = torch.matmul(content_position, q)

        energy = content_content + content_position
        attention = self.softmax(energy)

        out = torch.matmul(v, attention.permute(0, 1, 3, 2))
        out = out.view(n_batch, C, width, height)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, heads=4, mhsa=False, resolution=None):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        if not mhsa:
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
        else:
            self.conv2 = nn.ModuleList()
            self.conv2.append(MHSA(planes, width=int(resolution[0]), height=int(resolution[1]), heads=heads))
            if stride == 2:
                self.conv2.append(nn.AvgPool2d(2, 2))
            self.conv2 = nn.Sequential(*self.conv2)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


# reference
# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000, resolution=(224, 224), heads=4):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.resolution = list(resolution)

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        if self.conv1.stride[0] == 2:
            self.resolution[0] /= 2
        if self.conv1.stride[1] == 2:
            self.resolution[1] /= 2
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # for ImageNet
        if self.maxpool.stride == 2:
            self.resolution[0] /= 2
            self.resolution[1] /= 2

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, heads=heads, mhsa=True)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Dropout(0.3), # All architecture deeper than ResNet-200 dropout_rate: 0.2
            nn.Linear(512 * block.expansion, num_classes)
        )

    def _make_layer(self, block, planes, num_blocks, stride=1, heads=4, mhsa=False):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for idx, stride in enumerate(strides):
            layers.append(block(self.in_planes, planes, stride, heads, mhsa, self.resolution))
            if stride == 2:
                self.resolution[0] /= 2
                self.resolution[1] /= 2
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out) # for ImageNet

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out


def ResNet50(num_classes=1000, resolution=(224, 224), heads=4):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, resolution=resolution, heads=heads)


def main():
    x = torch.randn([2, 3, 224, 224])
    model = ResNet50(resolution=tuple(x.shape[2:]), heads=8)
    print(model(x).size())
    print(get_n_params(model))