Faster R-CNN原理介绍
看本文之前需要先了解Fast RCNN。
Faster R-CNN提出了一种加快计算region proposals的方法,就是通过建立RPN(Region Proposal Network)网络。RPN是一个全连接的卷积网络,通过 end-to-end的方式训练出来高质量的region proposal。然后将Faster R-CNN训练好的卷积特征和Fast R-CNN共享(通过attention model)。
文章目录 [隐藏]
- 1 介绍
- 2 相关工作
- 2.1 Object Proposals
- 2.2 检测方面相关工作
- 3 Faster R-CNN
- 3.1 Region Proposal Networks
- 3.2 RPN训练数据的生成
- 3.3 ROI Pooling的训练数据
- 3.4 共享特征
- 3.5 三种训练方式
- 3.6 四步交替训练法
- 3.7 训练细节
- 4 实现代码
- 5 论文
介绍
比较经典的region proposal方法是Selective Search,CPU耗时为2秒,EdgeBox是一种在准确率和效率之间做tradeoff的方式,耗时0.2秒,依然用了很长时间。
除了改进计算方式之外,其实还可以将Selective Search的计算放到GPU上面运行,但是放到GPU上就无法共享特征了。
作者提出,通过「Region Proposal Network」计算region proposals,在测试阶段,耗时只有10ms。
相关工作
Object Proposals
Object Proposals的提取方式有很多种,按照工作原理可以分为两大类:
- grouping super-pixels, Selective Search属于此类
- sliding windows,EdgeBox属于此类
文中还提到了三篇对比提取Object Proposals方法的论文,将来值得一看。
检测方面相关工作
- R-CNN通过对候选region直接进行end-to-end的分类
- OverFeat通过全连接层对localization坐标做预估,随后再用全连接层做检测
- MultiBox通过一个神经网络生成候选proposal,然后将proposal作为detection网络输入。
(此方式生成候选proposal的网络和detection网络没有共享feature) - DeepMask本文提出,下面介绍
Faster R-CNN
Region Proposal Networks
整体结构如下图:
关于anchor的理解:
利用anchor是从第二列这个位置开始进行处理,这个时候,原始图片已经经过一系列卷积层和池化层以及relu,得到了这里的 feature:51x39x256(256是层数)
在这个特征参数的基础上,通过一个3x3的滑动窗口,在这个51x39的区域上进行滑动,stride=1,padding=2,这样一来,滑动得到的就是51x39个3x3的窗口。
对于每个3x3的窗口,作者就计算这个滑动窗口的中心点所对应的原始图片的中心点。然后作者假定,这个3x3窗口,是从原始图片上通过SPP池化得到的,而这个池化的区域的面积以及比例,就是一个个的anchor。换句话说,对于每个3x3窗口,作者假定它来自9种不同原始区域的池化,但是这些池化在原始图片中的中心点,都完全一样。这个中心点,就是刚才提到的,3x3窗口中心点所对应的原始图片中的中心点。如此一来,在每个窗口位置,我们都可以根据9个不同长宽比例、不同面积的anchor,逆向推导出它所对应的原始图片中的一个区域,这个区域的尺寸以及坐标,都是已知的。而这个区域,就是我们想要的 proposal。所以我们通过滑动窗口和anchor,成功得到了 51x39x9 个原始图片的proposal。接下来,每个proposal我们只输出6个参数:每个 proposal 和 ground truth 进行比较得到的前景概率和背景概率(2个参数)(对应图上的 cls_score);由于每个 proposal 和 ground truth 位置及尺寸上的差异,从 proposal 通过平移放缩得到 ground truth 需要的4个平移放缩参数(对应图上的 bbox_pred)。
所以根据我们刚才的计算,我们一共得到了多少个anchor box呢?
51 x 39 x 9 = 17900
约等于 20 k
bingo!
RPN的输入是任意大小的图像,输出是一组打过分的候选框(object proposals)。在卷积的最后一层feature map上使用固定大小的窗口滑动,每个窗口会输出固定大小维度的特征(图中256),每一个窗口对候选的9个box进行回归坐标和分类(这里的分类表示box中是不是一个object,而不是具体的类别)。
为了将一个物体可以在不同的尺寸下识别出来,有两种主要的方式,1)对输入进行剪裁/缩放,2)对feature map进行不同大小的划窗,RPN则是采用了不同的方式。
生成训练数据的过程为先看anchor覆盖ground truth是否超过75%,超过就将当前anchor的object分类标记「存在」;如果都没有超过75%,就选择一个覆盖比例最大的标记为「存在」.
RPN的目标函数是分类和回归损失的和,分类采用交叉熵,回归采用稳定的Smooth L1,
SmoothL1公式为:
SmoothL1(x)={0.5x2|x|−0.5|x|<1otherwiseSmoothL1(x)={0.5x2|x|<1|x|−0.5otherwise
整体损失函数具体为:
到这里还有一个地方没有说清楚,就是9类anchor都是固定位置和大小的(当前window),回归的时候怎么回归?其实9个anchor是会变化的,要学回归的就是每一个anchor平移比例和缩放比例的信息。
如果网络足够聪明,并且当前window只有一个object,所有anchor,通过平移和缩放后,结果应该是一样的。下面为具体四个维度的计算方式:
训练的过程对正负样本进行抽样,保证比例为1:1.
RPN训练数据的生成
输入是一张图片,这个不用多说了,对VGG16来说,conv5卷积输出大小为W*H/16,一共有anchor数量为A=W*H/16*3*3,类别个数为2*A,回归的bounding box为 4*A。前向传播会将2A+6A个数据都计算出来,反向传播的时候会根据gt_data只会挑出来出来部分anchor进行反向传播。具体代码:
|
ROI Pooling的训练数据
在测试(inference)的时候,ROI Pooling从RPN网络得到候选的roi列表,通过conv5拿到所有的特征,进行后面的分类和回归。
在训练的时候,如果还是只使用RPN预测的roi可能会训练速度很慢或者loss就不收敛。这里需要对拿到的roi列表数据进行一些改造。
- 将ground-truth box加入到rpn_roi
- 通过rpn_roi和ground-truth box对比,调整覆盖覆盖比例
比较巧妙的一点是R-CNN的分类label生成和bbox的label生成,方法就是训练的时候根据每次生成的rpn_roi和gt对比,生成相应的标注信息,反向传播的时候只将目标类别的loss权重不为0,达到反向传播的时候只考虑正确的分类。
共享特征
这里说的特征共享,是指RPN和Fast R-CNN的特征共享,也就是生成候选和检测之间的共享。算法的整个流程如下图:
可以看到,卷积层可以让RPN和后面R-CNN公用。
三种训练方式
- 交替训练,先训练RPN->Fast R-CNN -> RPN -> Fast R-CNN这样一直迭代
- 近似联和训练,将每一次反向传播的RPN的梯度和Fast R-CNN的梯度合并
- 联和训练,从Fast R-CNN一直往后传递,RPN看做是Fast R-CNN的输入
四步交替训练法
- 第一步,训练RPN网络,通过使用ImageNet数据预训练好的网络初始化网络,然后进行end-to-end的fine-tune
- 第二步,使用Fast R-CNN训练模型,将上一步RPN的输出的region proposals作为输入,同时也用ImageNet的预训练好的网络初始化,此时RPN和Fast R-CNN没有共享网络。
- 第三步,使用Fast R-CNN训练的网络初始化RPN的网络,接着训练fine-tune RPN网络中非公共部分
- 第四步,对Fast R-CNN进行fine-tune,只修改非公共部分
后两步骤可以迭代进行,论文作者发现迭代带来很少的收益。
训练细节
由于图片是固定大小的,feature map中的anchor box可能并非完全在照片内部,需要将所有有外漏的anchor box都删掉。作者发现,如果不删掉就可能导致最终的训练不收敛。测试的时候,由于proposal boxes是RPN网络预测出来的,所以可能会导致「不完全在图像内」,处理方式是将超出图片的部分使用就近的图片边界替代。
为了减少RPN生成的region重叠的问题,通过NMS(non-maximum suppression)对重叠的region进行删减。
实现代码
python https://github.com/rbgirshick/py-faster-rcnn
Tensorflow 实现:https://github.com/smallcorgi/Faster-RCNN_TF
论文
《Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks》