Deep Learning Paper: MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer, with a PyTorch Implementation
MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer
PDF: https://arxiv.org/pdf/2110.02178.pdf
PyTorch code: https://github.com/shanglianlm0525/CvPytorch
PyTorch code: https://github.com/shanglianlm0525/PyTorch-Networks

1 Overview

MobileViT is a light-weight, general-purpose vision transformer for mobile devices that combines the strengths of CNNs and ViTs.
MobileNetV2:

[Figure: MobileNetV2 block]

Transformer:

[Figure: Transformer (ViT) block]

2 MobileViT

2-1 MobileViT architecture

[Figure: MobileViT architecture]

2-2 MobileViT design principles

  • MobileViT mixes convolutions and transformers: convolutions extract local features, while the transformer captures global dependencies. In theory the effective receptive field of the network is H×W, i.e. the full feature map.
  • The input to the transformer keeps the patch structure: every patch preserves the spatial order of its pixels, so the encoder input encodes both the order of the patches and the order of the pixels within each patch, and the unfold/fold operations are exactly invertible (see the sketch right after this list).
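
To make the second point concrete, here is a minimal round-trip sketch of the unfold/fold (toy tensor sizes; the same einops rearrange as in the listing below):

import torch
from einops import rearrange

x = torch.randn(1, 16, 8, 8)   # (B, d, H, W) feature map, 2x2 patches
ph = pw = 2
# group pixels by their position inside each patch: (B, P, N, d)
seq = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=ph, pw=pw)
print(seq.shape)               # torch.Size([1, 4, 16, 16])
# fold back; the round trip is lossless, so no positional encoding is required
y = rearrange(seq, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=8 // ph, w=8 // pw, ph=ph, pw=pw)
assert torch.equal(x, y)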

2-3 Multi-scale sampler

MobileViT is trained with a multi-scale sampler: at every iteration a spatial resolution is drawn from a predefined set, and the batch size is scaled inversely with the sampled resolution so that GPU utilization stays roughly constant across scales.
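
A minimal sketch of the idea (the resolution set and the resize-based implementation are illustrative, not the paper's exact sampler):

import random
import torch.nn.functional as F

# multiples of 64 keep the 32x-downsampled feature map divisible by the 2x2 patch size
RESOLUTIONS = [(128, 128), (192, 192), (256, 256), (320, 320)]

def to_random_scale(images):
    """Resize a batch (B, C, H, W) to a randomly sampled training resolution."""
    size = random.choice(RESOLUTIONS)
    return F.interpolate(images, size=size, mode='bilinear', align_corners=False)

# inside the training loop:
#   images = to_random_scale(images)
# The paper additionally scales the batch size inversely with H*W (omitted here).
# Note: the fixed AvgPool2d(kernel_size=8) head in the listing below assumes
# 256x256 inputs; swap in nn.AdaptiveAvgPool2d(1) for true multi-scale training.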

3 Network structure

  • The input image first goes through a standard 3×3 convolution with stride 2 (2× downsampling).
  • Four MobileNetV2 (MV2) blocks follow, one of them with stride 2 for another 2× downsampling.
  • MobileViT blocks then alternate with stride-2 MV2 blocks, each stage downsampling by a further 2×.
  • A final 1×1 convolution projects the features to the last channel width.
  • Global average pooling, dropout, and a linear classifier produce the logits (a per-stage shape trace appears after the listing below).

  • PyTorch code:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2021/10/15 16:59
# @Author : liumin
# @File : MobileViT.py

import torch
import torch.nn as nn
from einops import rearrange


def Conv3x3BN(in_channels, out_channels, stride=1, groups=1):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, groups=groups),
        nn.BatchNorm2d(out_channels)
    )


def Conv3x3BNActivation(in_channels, out_channels, stride=1, groups=1):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, groups=groups),
        nn.BatchNorm2d(out_channels),
        nn.SiLU()
    )


def Conv1x1BN(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
        nn.BatchNorm2d(out_channels)
    )


def Conv1x1BNActivation(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
        nn.BatchNorm2d(out_channels),
        nn.SiLU()
    )


class MV2Block(nn.Module):
    """MobileNetV2 inverted residual block (expand -> depthwise -> project)."""

    def __init__(self, in_channels, out_channels, stride, expansion_factor=6):
        super(MV2Block, self).__init__()
        self.stride = stride
        mid_channels = in_channels * expansion_factor

        self.bottleneck = nn.Sequential(
            Conv1x1BNActivation(in_channels, mid_channels),
            Conv3x3BNActivation(mid_channels, mid_channels, stride, groups=mid_channels),
            Conv1x1BN(mid_channels, out_channels)
        )

        if self.stride == 1:
            # 1x1 projection so the residual also works when channel counts differ
            self.shortcut = Conv1x1BN(in_channels, out_channels)

    def forward(self, x):
        out = self.bottleneck(x)
        out = (out + self.shortcut(x)) if self.stride == 1 else out
        return out


class PreNorm(nn.Module):
    """LayerNorm applied before the wrapped function (pre-norm transformer)."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head self-attention over (batch, patch, tokens, dim) tensors."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        # local representations: 3x3 conv + 1x1 projection to the transformer dim
        self.conv1 = Conv3x3BNActivation(channel, channel)
        self.conv2 = Conv1x1BNActivation(channel, dim)

        # single-head attention with dim_head=32 in this re-implementation
        self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)

        # fusion: project back to `channel`, concat with the input, fuse with 3x3 conv
        self.conv3 = Conv1x1BNActivation(dim, channel)
        self.conv4 = Conv3x3BNActivation(2 * channel, channel)

    def forward(self, x):
        y = x.clone()

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)

        # Global representations: unfold into patches, run the transformer, fold back
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
                      h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)

        # Fusion
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x


class MobileViT(nn.Module):
    def __init__(self, dims, channels, expansion=4, patch_size=(2, 2), num_classes=1000):
        super(MobileViT, self).__init__()
        depth = [2, 4, 3]  # transformer depths of the three MobileViT blocks

        self.conv1 = Conv3x3BNActivation(3, channels[0], 2)
        self.layer1 = MV2Block(in_channels=channels[0], out_channels=channels[1], stride=1, expansion_factor=expansion)

        self.layer2 = nn.Sequential(
            MV2Block(in_channels=channels[1], out_channels=channels[2], stride=2, expansion_factor=expansion),
            MV2Block(in_channels=channels[2], out_channels=channels[3], stride=1, expansion_factor=expansion),
            MV2Block(in_channels=channels[3], out_channels=channels[3], stride=1, expansion_factor=expansion)
        )

        # all predefined configs have channels[4] == channels[5] (likewise 6/7 and 8/9)
        self.layer3 = nn.Sequential(
            MV2Block(in_channels=channels[3], out_channels=channels[4], stride=2, expansion_factor=expansion),
            MobileViTBlock(dim=dims[0], depth=depth[0], channel=channels[5], patch_size=patch_size, mlp_dim=int(dims[0] * 2))
        )

        self.layer4 = nn.Sequential(
            MV2Block(in_channels=channels[5], out_channels=channels[6], stride=2, expansion_factor=expansion),
            MobileViTBlock(dim=dims[1], depth=depth[1], channel=channels[7], patch_size=patch_size, mlp_dim=int(dims[1] * 4))
        )

        self.layer5 = nn.Sequential(
            MV2Block(in_channels=channels[7], out_channels=channels[8], stride=2, expansion_factor=expansion),
            MobileViTBlock(dim=dims[2], depth=depth[2], channel=channels[9], patch_size=patch_size, mlp_dim=int(dims[2] * 4))
        )

        self.last_conv = Conv1x1BNActivation(channels[9], channels[10])
        self.avgpool = nn.AvgPool2d(kernel_size=8, stride=1)  # assumes a 256x256 input (8x8 final feature map)
        self.dropout = nn.Dropout(p=0.2)
        self.linear = nn.Linear(in_features=channels[10], out_features=num_classes)

        self.init_params()


    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


    def forward(self, x):
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.last_conv(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        out = self.linear(x)
        return out


def mobilevit_xxs():
    dims = [64, 80, 96]
    channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
    return MobileViT(dims, channels, expansion=2, num_classes=1000)


def mobilevit_xs():
    dims = [96, 120, 144]
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
    return MobileViT(dims, channels, num_classes=1000)


def mobilevit_s():
    dims = [144, 192, 240]
    channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
    return MobileViT(dims, channels, num_classes=1000)


if __name__ == '__main__':
    model = mobilevit_s()
    print(model)

    x = torch.randn(1, 3, 256, 256)
    out = model(x)
    print(out.shape)
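
A quick sanity check one can append to the __main__ block (a sketch; the parameter total is specific to this re-implementation and may differ from the paper's reported numbers, e.g. because the 1×1 shortcut in MV2Block adds parameters):

    # parameter count and per-stage shape trace for a 256x256 input
    num_params = sum(p.numel() for p in model.parameters())
    print('params: %.2fM' % (num_params / 1e6))

    x = torch.randn(1, 3, 256, 256)
    for name in ['conv1', 'layer1', 'layer2', 'layer3', 'layer4', 'layer5']:
        x = getattr(model, name)(x)
        print(name, tuple(x.shape))  # spatial size: 128, 128, 64, 32, 16, 8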