CBAM: Convolutional Block Attention Module
文章目录
- CBAM: Convolutional Block Attention Module
- 参考
- 个人理解
- Channel Attention
- Spatial Attention
- 如何融入ResBlock中?
- 效果如何?
- 实现
参考
- 【注意力机制】CBAM详解
- CBAM–卷积层中的注意力模块
- attention-module
个人理解
- 由于懒得系统介绍所以就长话短说,个人理解CBAM就是给Feature Map在不同维度加权重(channel和HxW),通过改变不同channel和区域的值,从而使得网络更加关注某些区域,从而实现性能上的提升,文中用的Res50为例,且经过大佬们的验证,确实有效。
- 大概的样子就长这样:
Channel Attention
- 它长这样
- 别的都好理解,就是把Feature Map分两路走,分别作MaxPool和AvgPool(注意是全局的MaxPool和AvgPool),然后进入MLP(这里下面会说明),之后相加再激活就成了
- 然后一开始我没太理解MLP,参考了这篇【注意力机制】CBAM详解,现在我是这么理解的它的结构的,MaxPool/AvgPool的输入进入MLP之后,首先做一次1x1xC的全连接,输出C,也就是第一个长条进过一层FC,出来就是中间的较短的条,接下来就是的第二层全连接(也就是第二根长条),和第一层基本一样
- 对于这个结构的作用我的理解是,给Pooling后的结果一个非线性变换,使它们在channel纬度上的特征被可以被提取出来,从而实现表示channel重要程度的功能
- 公式的话就是下面这个,就直接截过来了,应该和图片比较对于,需要提的是$\delta $表示的是sigmoid函数
Spatial Attention
- 它长这样
- 这个就跟公式一起介绍好了,属于是比较直白,看下面的公式,我们就channel维度对Feature Map作全局Pooling后得到HxWx2的输出,此时再用7x7的卷积撸一遍,得到的Feature Map做一次sigmoid就是输出的结果了,这里需要注意的是只有一个7x7卷积核,因为HxW方向上的注意力只需要一层就够了
如何融入ResBlock中?
- ResBlock的内部直接嵌入就行了,其实我的猜想是既然是个通用结构,理论上放在任何位置都OK,即你放俩ResBlock之间也不会有啥问题。
效果如何?
- 我的评价是少量参数增加+卓越的性能提升,且个人猜测参数量的增加应该主要是在channel 那里的FC处
实现
- 这个是抄这篇的实现:CBAM–卷积层中的注意力模块
- 官方代码是这份是BAM+CBAM都有:attention-module
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
class BasicConv(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
super(BasicConv, self).__init__()
self.out_channels = out_planes
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
self.relu = nn.ReLU() if relu else None
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
class ChannelGate(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
super(ChannelGate, self).__init__()
self.gate_channels = gate_channels
self.mlp = nn.Sequential(
Flatten(),
nn.Linear(gate_channels, gate_channels // reduction_ratio),
nn.ReLU(),
nn.Linear(gate_channels // reduction_ratio, gate_channels)
)
self.pool_types = pool_types
def forward(self, x):
channel_att_sum = None
for pool_type in self.pool_types:
if pool_type=='avg':
avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( avg_pool )
elif pool_type=='max':
max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( max_pool )
elif pool_type=='lp':
lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( lp_pool )
elif pool_type=='lse':
# LSE pool only
lse_pool = logsumexp_2d(x)
channel_att_raw = self.mlp( lse_pool )
if channel_att_sum is None:
channel_att_sum = channel_att_raw
else:
channel_att_sum = channel_att_sum + channel_att_raw
scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
return x * scale
def logsumexp_2d(tensor):
tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
return outputs
class ChannelPool(nn.Module):
def forward(self, x):
return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
class SpatialGate(nn.Module):
def __init__(self):
super(SpatialGate, self).__init__()
kernel_size = 7
self.compress = ChannelPool()
self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
def forward(self, x):
x_compress = self.compress(x)
x_out = self.spatial(x_compress)
scale = F.sigmoid(x_out) # broadcasting
return x * scale
class CBAM(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
super(CBAM, self).__init__()
self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
self.no_spatial=no_spatial
if not no_spatial:
self.SpatialGate = SpatialGate()
def forward(self, x):
x_out = self.ChannelGate(x)
if not self.no_spatial:
x_out = self.SpatialGate(x_out)
return x_out