STDC(实时语义分割网络-2021)


文章目录

  • STDC(实时语义分割网络-2021)
  • 一、前言
  • 1、Motivation——如何提升实时性?
  • 二、主要工作
  • 1、特征提取模块——STDC(short-term dense concatenate module)
  • 2、STDC网络结构
  • 3、整体网络结构
  • 4、Detail Guidance of Low-level Features
  • 三、训练
  • 总结



源代码

一、前言

1、Motivation——如何提升实时性?

主要采取的方式有两种:
   一是:选用轻量级的骨干网络(如DFANet和BiSeNet V1),同时研究特征融合/聚合模块来提升检测精度。但目前常用的轻量级骨干网络是在图像分类任务中提出的,对语义分割任务的效果也许不够好。

   二是:输入分辨率较小的原始图像。但分辨率变小,容易忽略掉边界和小目标周围的细节外观。
   1)BiSeNet解决方式:添加额外的辅助路径,但是增加了耗时,并且辅助路径总是缺乏底层信息的引导。
   2)STDC解决方式:使用detail guidance方式对低层特征中的空间信息进行编码,而不需要额外的耗时路径。首先利用detail aggregation模块生成detail ground-truth;然后利用二元交叉熵损失和dice loss来优化细节信息的学习任务,将其视为一种侧信息学习(只用于训练,推理不需要)。最后将其与主干网络得到的语义信息进行融合。


二、主要工作

   针对上面的两种动机,作者提出了一种为分割任务而设计的轻量级特征提取网络,并改善了耗时的双边结构,采用一种detail guidance方式。

实时语义分割对计算资源_git


图一


1、特征提取模块——STDC(short-term dense concatenate module)

(1)通道数的设计
  (a)在图像分类任务中,通常的做法是在较高层使用更多的通道。但在语义分割中,更值得关注的是可扩展的感受野和多尺度信息
  (b)低层阶段需要足够的通道来编码感受野较小的细粒度信息;而具有较大感受野的高层特征更侧重于高级语义信息的诱导,设置成跟低层阶段相同的通道数可能导致信息冗余

(2)特征信息的提取

  为了丰富特征信息,采用了跳跃连接将各阶段特征图进行级联连接,在连接之前,要通过3*3的平均池化将不同阶段的特征图下采样到相同的空间大小。

实时语义分割对计算资源_计算机视觉_02


图二


   从图得知,*对于通道数,前三个block会不断减少一半,block4保持不变;在卷积核上,除了第一个Block1采用1×1卷积,其余都采用3×3卷积*。


总结就是,STDC模块先是几何倍数减小了通道数,并采用的了较小的卷积核来有效降低了参数量,并让所有block进行连接,以此来增加感受野和多尺度特征。

2、STDC网络结构

实时语义分割对计算资源_git_03


图3


编码网络除了输入层和预测层外,还包含6个阶段:

  • 第1阶段ConvX1和第2阶段ConvX2,都只采用了一个3×3卷积层,因S=2,故分辨率会依次下降一半,第2阶段结束后,特征图分辨率=(H/4,W/4),表中显示通道数变成64。(浅层阶段提取的是外观特征信息,为了追求效率,论文中写到根据经验一个卷积层是足够的)
  • Stage3、Stage4、Stage5中使用之前提到的STDC模块,每个Stage都用到了两个STDC模块,第一个STDC模块选用的是图二中(c)模式,即S=2进行下采样一次,后面分辨率不变。所以经过每个Stage都会让特征图分辨率下降一半,最终分辨率=(H/32, W/32)。
  • 第6个阶段通过一个ConvX、一个全局平均池化层和两个全连接层输出预测logits(应付分类网络,对于分割网络直接舍弃)。

3、整体网络结构

实时语义分割对计算资源_git_04


图4


  将STDC网络(除去第6个阶段)作为编码器的backbone,采用和BiSeNet一样的context path对上下文信息进行编码。   具体来说,对1/32特征图的Stage5,使用全局平均池化来提供具有大感受野的全局上下文信息,对应图中的灰色矩形(*经过全局池化后会上采样到Stage5同样大小,可以把其当成一个特征增强的阶段6*),之后就是按照FPN结构(论文中称U型结构,其实是一样的方式)进行对各个阶段依次融合的方式(*代码中采用特征相加,而不是特征拼接,所以应该是FPN结构*)。遵循BiSeNet,在特征相加之前,先使用注意力细化模块ARM对Stage4和Stage5的特征图进行细化。ARM结构如下图所示:

实时语义分割对计算资源_git_05


图5


然后将Stage3的特征图和经过FPN结构后的特征图送入FFM中(论文中写道:认为Stage3来自骨干网络的较低层特征,保留了丰富的细节信息,而经过FPN解码后的特征具有更强的全局上下文信息),FFM结构图如下:

实时语义分割对计算资源_实时语义分割对计算资源_06


图6


4、Detail Guidance of Low-level Features

  通常的网络模型就跟上面所说的结构一般,而该论文中展示了BiseNet中所用的空间路径与STDC骨干网络中的浅层阶段的特征图,发现空间路径比骨干网络自身的浅层阶段拥有更好的空间细节。为了解决STDC网络自身的细节特征不足的问题,提出了detail guidance模块,来引导低层学习到更多的空间信息。

Detail Ground-truth Generation:图4中的(c)所示,首先通过2d-拉普拉斯算子(图3的e)生成不同尺寸的软细节特征图,再上采用到原始大小,把三个软细节图stack再一起,再压缩一个维度(因为使用stack拼接后会多一个维度),然后经过一个1×1卷积(论文中写的使用可训练的1x1卷积进行融合,但在官方代码中用了一个不可训练参数提前设定不随训练更新的fusion_kernel进行融合)。最后,根据阈值将其转变成值[0,1],为最终的带有边界和角点信息的二值GT。

Detail Loss:生成细节GT肯定是为了进行监督训练,作者考虑到:细节像素的数量是远远小于非细节像素的,所以这里还存在一个类别不平衡的问题。选用的解决方案是采用二值交叉熵(binary cross-entropy)和Dice loss来共同优化细节学习。(Dice loss衡量的是预测图和GT之间的重叠,且对前景/背景像素的数量不敏感,故可以缓解类别不平衡的问题
   如图4(b)所示,可选用Stage1-3来进行监督训练,并先通过Detail head(其实就是一层卷积和一个单卷积)来增强特征表示。该方法在推理阶段舍弃,故不会增加推理时间。

三、训练

1、在model_stages.py中,应该可以选择对第1,2,3阶段都做细节损失训练,对第2,3阶段做细节训练,只对第3阶段做监督训练(但在原代码中,只对一个阶段做细节损失训练

# 前三个做分割损失,后三个做细节损失
  if self.use_boundary_2 and self.use_boundary_4 and self.use_boundary_8:
       return feat_out, feat_out16, feat_out32, feat_out_sp2, feat_out_sp4, feat_out_sp8
        
  if (not self.use_boundary_2) and self.use_boundary_4 and self.use_boundary_8:
       return feat_out, feat_out16, feat_out32, feat_out_sp4, feat_out_sp8

  if (not self.use_boundary_2) and (not self.use_boundary_4) and self.use_boundary_8:
       return feat_out, feat_out16, feat_out32, feat_out_sp8
        
  if (not self.use_boundary_2) and (not self.use_boundary_4) and (not self.use_boundary_8):
       return feat_out, feat_out16, feat_out32

2、再来看train.py中,分割损失和细节损失如下:

criteria_p = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    criteria_16 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    criteria_32 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    boundary_loss_func = DetailAggregateLoss()

(1)对于分割损失,STDC中采用的是OHEM Loss(Online Hard Example Mining)。
  和Focal loss一样,最初被提出都是用于目标检测的类别不平衡的问题。OHEM Loss在训练过程中,不是使用一个batch中所有样本来计算损失,而是筛选出损失值较大的那一部分(即Hard Example)来参与后续计算。这个给过程是在训练中产生的,故称为online。

class OhemCELoss(nn.Module):
    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')

    def forward(self, logits, labels):
        N, C, H, W = logits.size()
        loss = self.criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)
        if loss[self.n_min] > self.thresh:  
            loss = loss[loss>self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)

(2)对于边界损失,采用的是BCE loss+Dice loss

class DetailAggregateLoss(nn.Module):
    def __init__(self, *args, **kwargs):
        super(DetailAggregateLoss, self).__init__()
        
        self.laplacian_kernel = torch.tensor(
            [-1, -1, -1, -1, 8, -1, -1, -1, -1],
            dtype=torch.float32).reshape(1, 1, 3, 3).requires_grad_(False).type(torch.cuda.FloatTensor)
        
        self.fuse_kernel = torch.nn.Parameter(torch.tensor([[6./10], [3./10], [1./10]],
            dtype=torch.float32).reshape(1, 3, 1, 1).type(torch.cuda.FloatTensor))

    def forward(self, boundary_logits, gtmasks):
        # boundary_logits = boundary_logits.unsqueeze(1)
        boundary_targets = F.conv2d(gtmasks.unsqueeze(1).type(torch.cuda.FloatTensor), self.laplacian_kernel, padding=1)
        boundary_targets = boundary_targets.clamp(min=0)
        boundary_targets[boundary_targets > 0.1] = 1
        boundary_targets[boundary_targets <= 0.1] = 0

        # 三种拉普拉斯(步距S=2,4,8)
        boundary_targets_x2 = F.conv2d(gtmasks.unsqueeze(1).type(torch.cuda.FloatTensor), self.laplacian_kernel, stride=2, padding=1)
        boundary_targets_x2 = boundary_targets_x2.clamp(min=0)
        
        boundary_targets_x4 = F.conv2d(gtmasks.unsqueeze(1).type(torch.cuda.FloatTensor), self.laplacian_kernel, stride=4, padding=1)
        boundary_targets_x4 = boundary_targets_x4.clamp(min=0)

        boundary_targets_x8 = F.conv2d(gtmasks.unsqueeze(1).type(torch.cuda.FloatTensor), self.laplacian_kernel, stride=8, padding=1)
        boundary_targets_x8 = boundary_targets_x8.clamp(min=0)

        boundary_targets_x8_up = F.interpolate(boundary_targets_x8, boundary_targets.shape[2:], mode='nearest')
        boundary_targets_x4_up = F.interpolate(boundary_targets_x4, boundary_targets.shape[2:], mode='nearest')
        boundary_targets_x2_up = F.interpolate(boundary_targets_x2, boundary_targets.shape[2:], mode='nearest')
        
        boundary_targets_x2_up[boundary_targets_x2_up > 0.1] = 1
        boundary_targets_x2_up[boundary_targets_x2_up <= 0.1] = 0
        
        
        boundary_targets_x4_up[boundary_targets_x4_up > 0.1] = 1
        boundary_targets_x4_up[boundary_targets_x4_up <= 0.1] = 0
       
        
        boundary_targets_x8_up[boundary_targets_x8_up > 0.1] = 1
        boundary_targets_x8_up[boundary_targets_x8_up <= 0.1] = 0

       
        boudary_targets_pyramids = torch.stack((boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), dim=1)
        
        boudary_targets_pyramids = boudary_targets_pyramids.squeeze(2)
        boudary_targets_pyramid = F.conv2d(boudary_targets_pyramids, self.fuse_kernel)

        boudary_targets_pyramid[boudary_targets_pyramid > 0.1] = 1
        boudary_targets_pyramid[boudary_targets_pyramid <= 0.1] = 0
        
        if boundary_logits.shape[-1] != boundary_targets.shape[-1]:
            boundary_logits = F.interpolate(
                boundary_logits, boundary_targets.shape[2:], mode='bilinear', align_corners=True)
        
        bce_loss = F.binary_cross_entropy_with_logits(boundary_logits, boudary_targets_pyramid)
        dice_loss = dice_loss_func(torch.sigmoid(boundary_logits), boudary_targets_pyramid)
        return bce_loss,  dice_loss

总结

未完,日后待更新。

参考文章1