论文链接:https://arxiv.org/pdf/1608.06993.pdf
PyTorch代码:https://github.com/bamos/densenet.pytorch/blob/master/densenet.py
本来想研究下轻量级网络PeleeNet,但由于PeleeNet是由DenseNet变化而来,因此就先研究了下DenseNet。
参考链接:https://zhuanlan.zhihu.com/p/37189203
我们知道,ResNet模型的核心是通过建立前面层与后面层之间的“短路连接”(shortcuts,skip connection),从而使得有助于训练过程中梯度的反向传播,以便训练出更深的CNN网络。
而DenseNet其实相对于ResNet而言,也是采用了“短路连接”的方法,同时增加了特征复用的通道。如图1所示,图1描述的是组成DenseNet的基本单元Dense Block,一共由5个dense layer组成,在每个dense layer中都是只生成4个通道的feature maps,也就是growth rate = 4。通过图1 可以看出,使用特征复用的方式,可以减小模型的参数量,因为每一层只需要生成4个新的特征层,参数量大小和growth rate的大小有关。这就是DenseNet网络中最大的创新之处。
图1 Dense Block
在Dense Block中,第l层接受的输入是前面所有层的输出。
一、网络结构
表1 DenseNet网络结构
其中,DenseNet用于ImageNet分类的网络结构如表1所示,其中“conv”代表的是“BN+ReLU+Conv”操作,网络的growth rate=32。网络构建代码如下:
通过表1和源代码可以看出,DenseNet由Convolution卷积层、Pooling池化层、Dense Block、Transition层以及Classification分类层组成。接下来以DenseNet-121为例进行讲解,对各部分进行讲解。
二、各模块组成
1. Convolution卷积层和Pooling池化层
步长均为2,主要是为了快速缩小输入图片的大小,对应代码中的self.conv1。
2. Dense Block
这一部分是论文的核心,也就是图1所示的结构,只不过图1中的DenseBlock只有4层(第一个是Input,第二层才开始算是Dense Block)。
在表1所示结构中,共有4个Dense Block,每个Dense Block分别包括6、12、24、16个Dense layer。 每个dense layer采用的是Bottleneck的形式,由BN+ReLU+1x1Conv+BN+ReLU+3x3Conv组成,这里大家要注意一下,使用的是BN+ReLU+Conv的方式,而不是Conv+ReLU+BN。
但在第一个1x1 Conv中,输入通道数量为那么第l层输入的channel数为k0 + growth rate * (l-1),输出的通道数量为4*growth rate,k0表示在当前dense block中输入通道的数量,此处将输出通道数扩展为4倍,主要是为了减小模型的大小而设置,因为在较深的卷积层中,特征层数较多中间层输出特征数量小,可以减小模型大小,但是对于浅层特征来说,采用固定的中间层输出数量,可能会增加模型的大小,这一缺陷在PeleeNet中作者也提到了;
而在3x3 Conv的卷积中,输入通道数量为4*growth rate,输出通道数量为growth rate;
然后再降当前dense layer中输入通道的数量与输出通道数量相加,即进行concat操作,即可得出下一dense layer的输入,即k0 + growth rate * l。
对应代码的位置是self._make_dense函数,每执行一次,生成一个denseblock,共执行4次。
3. Transition层
主要是
连接两个相邻的DenseBlock,并且降低特征图尺寸大小。Transition层包括一个1x1的卷积和2x2的AvgPooling。另外,Transition层可以起到压缩模型的作用,因为它包括1个1x1的卷积层,因此可以更改特征的通道数,
作者在论文中加入了一个超参数θ,用于压缩模型的通道数,减小模型权重的大小。
4.Classification分类层
用于对输入图片进行分类,将7*7*n_channel的张量转化为1*1*n_channel后,后接维度为1000的全连接层,用于输出ImageNet的分类结果。对应代码中的self.fc。
三、总结
DenseNet优势 : 由于将特征进行了重复利用,因此
减小了模型的大小,同时,采用密集连接的方式,使得
模型更易训练。另外,也正是由于这种特征链接方式,使得低层的特征在模型分类或者检测时,也更加容易将特征进行充分利用,实现了
低层纹理信息和高层语义信息的结合。但是DenseNet存在一个较为严重的问题,就是耗费显存比较大,具体原因可参考链接:
为了解决DenseNet占用显存过大的问题,论文作者也给出了解决方法,详见论文:
https://arxiv.org/pdf/1707.06990.pdf,代码链接如下:https://github.com/gpleiss/efficient_densenet_pytorch