参考视频:ResNet网络结构,BN以及迁移学习详解

一、ResNet网络结构

plotneuralnet画resnet网络 resnet网络结构详解_人工智能


plotneuralnet画resnet网络 resnet网络结构详解_迁移学习_02


梯度消失:每一层的误差梯度都小于1,反向传播过程中,每向前传播一层都要乘以一个小于1的数,当网络越来越深,每次都乘以一个小于1的数,梯度会趋向于0

梯度爆炸:每一层的梯度都大于1,反向传播过程中,每向前传播一层都要乘以一个大于1的系数,当网络越来越深,每次都乘以一个大于1的数,梯度会越来愈大趋向爆炸

解决方案:全局初始化,BN标准化处理

使用残差结构解决退化问题

plotneuralnet画resnet网络 resnet网络结构详解_机器学习_03


256-d:输入深度为256

使用残差结构越多,节省参数越多

左边结构:3×3×256×256+3×3×256×256 = 1179648

右边结构:1×1×256×64+3×3×64×64+1×1×64×256=69632各种不同层数的残差结构:

plotneuralnet画resnet网络 resnet网络结构详解_机器学习_04

以34层残差结构为例:

plotneuralnet画resnet网络 resnet网络结构详解_迁移学习_05


plotneuralnet画resnet网络 resnet网络结构详解_人工智能_06


plotneuralnet画resnet网络 resnet网络结构详解_方差_07


plotneuralnet画resnet网络 resnet网络结构详解_机器学习_08


为什么残差分支有的是实线有的是虚线呢?

plotneuralnet画resnet网络 resnet网络结构详解_方差_09


虚线的残差结构的作用是:可以将输入的特征矩阵的高度宽度深度进行变化

实线的残差结构输入特征矩阵和输出特征矩阵宽度高度深度一模一样

所以每个卷积的第一层都需要虚线残差结构使输出的特征矩阵调整为当前层所需要的矩阵结构

二、Batch Normalization

plotneuralnet画resnet网络 resnet网络结构详解_方差_10


我们在图像预处理过程中通常会对图像进行标准化处理,这样能够加速网络的收敛,如下图所示,对于Conv1来说输入的就是满足某一分布的特征矩阵,但对于Conv2而言输入的feature map就不一定满足某一分布规律了(注意这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律)。而我们Batch Normalization的目的就是使我们的feature map满足均值为0,方差为1的分布规律。

plotneuralnet画resnet网络 resnet网络结构详解_迁移学习_11


“对于一个拥有d维的输入x,我们将对它的每一个维度进行标准化处理。” 假设我们输入的x是RGB三通道的彩色图像,那么这里的d就是输入图像的channels即d=3,,其中就代表我们的R通道所对应的特征矩阵,依此类推。标准化处理也就是分别对我们的R通道,G通道,B通道进行处理。上面的公式不用看,原文提供了更加详细的计算公式:

u表示均值,另一个表示方差

plotneuralnet画resnet网络 resnet网络结构详解_方差_12


举例说明:

plotneuralnet画resnet网络 resnet网络结构详解_迁移学习_13


使用BN时需要注意的问题

(1)训练时要将traning参数设置为True,在验证时将trainning参数设置为False。在pytorch中可通过创建模型的model.train()和model.eval()方法控制。

(2)batch size尽可能设置大点,设置小后表现可能很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。

(3)建议将bn层放在卷积层(Conv)和激活层(例如Relu)之间,且卷积层不要使用偏置bias,因为没有用,参考下图推理,即使使用了偏置bias求出的结果也是一样的

参考博文:Batch Normalization详解以及pytorch实验

三、迁移学习

plotneuralnet画resnet网络 resnet网络结构详解_机器学习_14


plotneuralnet画resnet网络 resnet网络结构详解_人工智能_15


对于浅层的卷积层学到了一些通用信息(角点,纹理)在其他的网络中也适用,可以将浅层网络的一些参数迁移到新的网络中去,使新的网络也拥有识别底层通用特征的能力了,新的网络拥有了这些底层通用的检测识别能力之后,就能够更加快速得去学习新的数据集的高维特征

plotneuralnet画resnet网络 resnet网络结构详解_人工智能_16