1. Problems addressed


  • ResNet's multi-branch topology is unfriendly to inference
  • RepVGG blocks have to place the ReLU layer outside the block (i.e. outside the residual connection), which limits the effective model depth
    Although residual connections make it possible to train very deep neural networks, their multi-branch topology is unfriendly to online inference. This has motivated many researchers to design DNNs without residual connections. For example, RepVGG re-parameterizes the multi-branch topology into a VGG-like (single-branch) structure at deployment time and performs well when the network is relatively shallow. However, RepVGG cannot equivalently convert a ResNet into a VGG, because re-parameterization can only be applied to linear blocks, while the non-linear layers (ReLU) have to be placed outside the residual connection. This limits the representational ability, especially for deeper networks. A minimal sketch of this "linear blocks only" constraint is given right after this list.
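To make the "linear blocks only" constraint concrete, here is a minimal sketch (my own example, not code from the paper or the RepVGG repository) that merges a 3x3 conv and an identity shortcut into a single 3x3 conv. The merge is exact only because both branches are linear; if a ReLU sat on either branch before the addition, no single conv could reproduce it, which is exactly why RepVGG keeps ReLU outside the residual connection.

import torch
import torch.nn as nn

# conv3x3(x) + x  ==  merged_conv3x3(x), because convolution is linear in its weights
channels = 8
conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
merged = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)

with torch.no_grad():
    identity_kernel = torch.zeros_like(conv.weight)
    nn.init.dirac_(identity_kernel)          # 3x3 kernel that just copies its input channel
    merged.weight.copy_(conv.weight + identity_kernel)

x = torch.randn(1, channels, 16, 16)
with torch.no_grad():
    print(torch.allclose(conv(x) + x, merged(x), atol=1e-5))  # True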

As a plug-in method, the RM operation has three main advantages:


  • Its implementation makes it friendly to high-ratio network pruning
  • It breaks through the depth limitation of RepVGG
  • Compared with ResNet and RepVGG, RMNet achieves a better accuracy-speed trade-off

2. Model and code

  1. The problem with RepVGG
    [Figure 2: accuracy of ResNet and RepVGG at different depths on CIFAR-100]

From Figure 2 we can see that ResNet gets better accuracy as depth increases, which is consistent with the analysis above. In contrast, ResNet-133 reaches 79.57% accuracy on CIFAR-100, while RepVGG-133 reaches only 41.38%.

  2. The RM operation
    Figure 3 shows how the RM operation equivalently removes the residual connection. For simplicity, the BN layers are omitted in the figure, and the numbers of input, middle, and output channels are all equal, denoted C.
    [Figure 3: the RM operation on a ResBlock]
    ResBlock implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    def __init__(self, in_planes, mid_planes, out_planes, stride=1):
        super(ResBlock, self).__init__()

        assert mid_planes > in_planes

        self.in_planes = in_planes
        self.mid_planes = mid_planes - out_planes + in_planes
        self.out_planes = out_planes
        self.stride = stride

        self.conv1 = nn.Conv2d(in_planes, self.mid_planes - in_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.mid_planes - in_planes)

        self.conv2 = nn.Conv2d(self.mid_planes - in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)

        self.relu = nn.ReLU(inplace=True)

        self.downsample = nn.Sequential()
        if self.in_planes != self.out_planes or self.stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes))
        # affine=False BN layers that only track the running statistics of the block
        # input/output; they are used at deploy time to build exact identity BNs
        self.running1 = nn.BatchNorm2d(in_planes, affine=False)
        self.running2 = nn.BatchNorm2d(out_planes, affine=False)

    def forward(self, x):
        if self.in_planes == self.out_planes and self.stride == 1:
            self.running1(x)                      # record input statistics
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.downsample(x)
        self.running2(out)                        # record pre-ReLU output statistics
        return self.relu(out)

    def deploy(self, merge_bn=False):
        # First conv: the leading in_planes output channels are reserved to carry
        # the input unchanged (Dirac kernels + an identity BN); the rest reuse conv1/bn1
        idconv1 = nn.Conv2d(self.in_planes, self.mid_planes, kernel_size=3, stride=self.stride, padding=1, bias=False).eval()
        idbn1 = nn.BatchNorm2d(self.mid_planes).eval()

        nn.init.dirac_(idconv1.weight.data[:self.in_planes])
        bn_var_sqrt = torch.sqrt(self.running1.running_var + self.running1.eps)
        idbn1.weight.data[:self.in_planes] = bn_var_sqrt
        idbn1.bias.data[:self.in_planes] = self.running1.running_mean
        idbn1.running_mean.data[:self.in_planes] = self.running1.running_mean
        idbn1.running_var.data[:self.in_planes] = self.running1.running_var

        idconv1.weight.data[self.in_planes:] = self.conv1.weight.data
        idbn1.weight.data[self.in_planes:] = self.bn1.weight.data
        idbn1.bias.data[self.in_planes:] = self.bn1.bias.data
        idbn1.running_mean.data[self.in_planes:] = self.bn1.running_mean
        idbn1.running_var.data[self.in_planes:] = self.bn1.running_var

        # Second conv: the extra input channels implement the residual addition
        idconv2 = nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3, stride=1, padding=1, bias=False).eval()
        idbn2 = nn.BatchNorm2d(self.out_planes).eval()
        downsample_bias = 0
        if self.in_planes == self.out_planes:
            nn.init.dirac_(idconv2.weight.data[:, :self.in_planes])
        else:
            # fuse the 1x1 downsample conv with its BN and pad the kernel to 3x3
            idconv2.weight.data[:, :self.in_planes], downsample_bias = self.fuse(
                F.pad(self.downsample[0].weight.data, [1, 1, 1, 1]),
                self.downsample[1].running_mean, self.downsample[1].running_var,
                self.downsample[1].weight, self.downsample[1].bias, self.downsample[1].eps)

        # fuse conv2 with bn2 into the remaining input channels
        idconv2.weight.data[:, self.in_planes:], bias = self.fuse(
            self.conv2.weight, self.bn2.running_mean, self.bn2.running_var,
            self.bn2.weight, self.bn2.bias, self.bn2.eps)

        # idbn2 is an identity BN built from running2; the fused biases are folded
        # into its running_mean
        bn_var_sqrt = torch.sqrt(self.running2.running_var + self.running2.eps)
        idbn2.weight.data = bn_var_sqrt
        idbn2.bias.data = self.running2.running_mean
        idbn2.running_mean.data = self.running2.running_mean + bias + downsample_bias
        idbn2.running_var.data = self.running2.running_var

        if merge_bn:
            return [torch.nn.utils.fuse_conv_bn_eval(idconv1, idbn1), self.relu,
                    torch.nn.utils.fuse_conv_bn_eval(idconv2, idbn2), self.relu]
        else:
            return [idconv1, idbn1, self.relu, idconv2, idbn2, self.relu]

    def fuse(self, conv_w, bn_rm, bn_rv, bn_w, bn_b, eps):
        # fold BN statistics into the conv weight; note the returned bias is the
        # negative of the usual fused bias, since it is later added to running_mean
        bn_var_rsqrt = torch.rsqrt(bn_rv + eps)
        conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
        conv_b = bn_rm * bn_var_rsqrt * bn_w - bn_b
        return conv_w, conv_b
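To sanity-check the conversion, a small test like the following can be used (my own sketch, not part of the original code): run the block in training mode for a few batches so the running statistics get populated, then compare the residual block with its deployed form in eval mode. Note that the test input must be non-negative, because in a real network the block input comes out of a preceding ReLU, and the RM construction relies on that.

# Equivalence check for ResBlock.deploy (sketch)
torch.manual_seed(0)
block = ResBlock(in_planes=64, mid_planes=128, out_planes=64)

block.train()
for _ in range(10):                           # populate running1/running2 and the BN stats
    block(torch.rand(4, 64, 32, 32))

block.eval()
deployed = nn.Sequential(*block.deploy(merge_bn=False)).eval()

x = torch.rand(2, 64, 32, 32)                 # non-negative, like the output of a previous ReLU
with torch.no_grad():
    print(torch.allclose(block(x), deployed(x), atol=1e-4))  # expected: True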
  3. Converting the residual block
    Here we simply print the original residual block and the converted one.

Original residual block:

(0): ResBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential()
(running1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
(running2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
)

Converted RM block:

(3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
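The printouts above can be reproduced with a small driver along the following lines (my own sketch; the original repository wraps the conversion in its own routine):

# Hypothetical driver for the printouts above
block = ResBlock(in_planes=64, mid_planes=128, out_planes=64)
print(block)                              # original residual block

block.eval()
rm_block = nn.Sequential(*block.deploy(merge_bn=False))
print(rm_block)                           # converted RM block: a plain conv/BN/ReLU stack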

Illustration of the conversion:

[Figure: how the ResBlock maps onto the converted conv/BN/ReLU stack]


  • The original conv2d(64, 64) becomes conv2d(64, 128): the number of output channels is increased by the number of input feature channels.
  • The weights of the added kernels follow a Dirac distribution (each filter has a single 1 in one channel and zeros everywhere else).
  • The result is equivalent to concat(x, conv(x)), i.e. the input feature maps are concatenated along the channel dimension with the output of the first conv.
  • For the second conv, the number of input channels is increased by the number of input feature channels, and the added kernel weights again follow a Dirac distribution. This step is equivalent to the "+" of the residual block. A small sketch verifying the identity behaviour of Dirac-initialized kernels follows this list.
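A minimal sketch (my own, reusing nn.init.dirac_ as in the code above) showing that a Dirac-initialized 3x3 conv reproduces its input exactly, which is what makes the "concat" and "+" steps above equivalences rather than approximations:

# A Dirac-initialized 3x3 conv acts as an identity mapping
C = 4
id_conv = nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False)
nn.init.dirac_(id_conv.weight)

x = torch.randn(1, C, 8, 8)
with torch.no_grad():
    print(torch.allclose(id_conv(x), x))  # True: each output channel copies its input channel

# In the converted block, such Dirac kernels appended to conv1 carry the input
# forward (the "concat"), and Dirac input-channel slices of conv2 add it back
# in (the residual "+"); this relies on the input already being non-negative,
# since a ReLU sits between the two convs.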

3. Pros and cons

Pros


  • The residual block can indeed be converted entirely into conv, BN, and ReLU layers; the conv and BN pairs can presumably be merged further (as the merge_bn=True branch of deploy already does).
  • After the equivalent conversion, pruning becomes convenient, because many kernel weights are exactly zero (see the sketch after this list).
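A quick way to see how sparse the converted kernels are (my own sketch, not the paper's pruning procedure):

# Fraction of exactly-zero weights in the converted convs; the Dirac-initialized
# slices account for most of the zeros
block = ResBlock(in_planes=64, mid_planes=128, out_planes=64).eval()

def zero_fraction(conv):
    w = conv.weight.detach()
    return (w == 0).float().mean().item()

for i, m in enumerate(block.deploy(merge_bn=False)):
    if isinstance(m, nn.Conv2d):
        print(f"layer {i}: {zero_fraction(m):.2%} of weights are exactly zero")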

Cons

  • After converting to an RM block, the number of kernels increases. Although many of the weights are zero, a lot of redundant computation is still performed, so the actual inference speed needs to be measured on the target device (a rough timing sketch is given below).
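A rough latency comparison on CPU could look like this (my own sketch; real numbers should come from the target device and runtime):

import time

block = ResBlock(64, 128, 64).eval()
rm_block = nn.Sequential(*block.deploy(merge_bn=True)).eval()  # conv/ReLU only, BN folded in

x = torch.rand(1, 64, 56, 56)

def bench(model, iters=100):
    with torch.no_grad():
        model(x)                              # warm-up
        start = time.perf_counter()
        for _ in range(iters):
            model(x)
    return (time.perf_counter() - start) / iters

print(f"ResBlock: {bench(block) * 1e3:.3f} ms   RM block: {bench(rm_block) * 1e3:.3f} ms")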