在本文中,我们通过金字塔池化模块提出了用于场景解析的PSPNet,该网络可以聚合不同区域的上下文信息来挖掘全局的上下文信息,我们的全局信息可以有效地在场景解析任务中产生高质量的结果。

1. Introduction

基于语义分割的场景解析是计算机视觉的一个基础课题,其目的是为图像中的每一个像素指定一个类别标签。

最先进的场景解析分析框架主要是基于全卷积网络(FCN),基于深度卷积神经网络的方法提高了对动态对象的理解,但由于场景的多样性和词汇的不受限制性,其仍然面临较大的挑战。比如下图中第一行展示的例子,FCN将船误认为汽车,这些错误是由于物体外观相似造成的,但是基于上下文信息,我们可以发现靠近河流应该为船。

英文nlp ner 最好的数据集 nlp net_人工智能

2. Pyramid Scene Parsing Network

2.1 Important Observations

英文nlp ner 最好的数据集 nlp net_池化_02

Mismatched Relationship 上下文关系具有普遍性和重要性,尤其对于复杂场景的理解。存在着共现的视觉模式,即有些物体常常是一起出现的,例如对于上图第一行,FCN根据外观将船预测为汽车,但众所周知,汽车很少在河上行驶,所以,缺乏收集上下文信息的能力会增大错误分类的机会

Confusion Categories ADE20k数据集中有许多类别标签,在分类时容易混淆。由于它们的外观十分相似,即使专家注释员标记,仍然产生了17.6%的误差。例如对于上图第二行,FCN预测其部分为摩天大楼,部分为建筑物,这是错误的,框中的物体要么为摩天大楼,要么为建筑物,而不是两者都有。这个问题可以通过类别间的关系来解决。

Inconspicuous Classes 场景包含任意大小的物体。一些小的物体很难找到但它们又很重要,相反,大的物体可能会超过FCN的感受野,从而导致预测不连续。例对于上图第三行,枕头与床单外观相似,忽略全局场景类别,可能无法解析枕头。为了提高标记非常小或非常大的物体的性能,应该特别注意包含不同类别物体的不同子区域。

2.2 Pyramid Pooling Module

英文nlp ner 最好的数据集 nlp net_深度学习_03

金字塔模块融合了四种不同尺度下的特征。上图红色突出显示的为最粗略的层级,是通过全局池化生成的单个bin输出。剩下的三个层级将输入特征图划分成若干个不同的子区域,并对每个子区域进行池化,金字塔池化模块中不同层级输出不同尺度的特征图,为了保持全局特征的权重,我们在每个金字塔层级后使用1x1的卷积核,当某个层级维数为n时,即可将语境特征的维数降到原始特征的1/n。然后,通过双线性插值直接对低维特征图进行上采样,使其与原始特征图尺度相同。最后,将不同层级的特征图拼接为最终的金字塔池化全局特征。

我们的金字塔池化模块是一个四层级的模块,分别有1x12x23x36x6bin大小。针对于每个层级的池化操作是选择最大池化还是平均池化,我们进行了大量的实验,实验表明平均池化效果更好

2.3 Network Architecture

英文nlp ner 最好的数据集 nlp net_池化_04

给定应该输入图像( a ),我们首先利用ResNet模型来提取特征,最终特征图的尺寸为输入图像的1/8,如上图(b)所示,然后我们利用上图( c )中所示的金字塔池化模块来提取上下文信息,其中金字塔池化模块分为4个层级,最终将4个层级提取的特征图融合为全局特征,在( c )模块的最后部分,我们将融合得到的全局特征和原始输入特征图拼接,这样提取的特征图就同时携带局部和全局上下文信息。最后,在 ( d )通过一层卷积生成最终的预测特征图

英文nlp ner 最好的数据集 nlp net_深度学习_05

在上图中,给出了我们增加辅助损失函数后的ResNet101模型的一个示例。除了使用softmax loss训练最终分类器的主要分支外,在第四阶段(即res4b22残差块)后应用另一个分类器。辅助损失函数有助于优化学习过程,我们通过增加不同的权重来使得主分支损失函数占主导地位,后续的实验证明这样做有利于快速收敛。

但是在测试时,我们放弃了这个辅助分支,只使用主分支进行最终预测

3. Pytorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.backbone.resnet import resnet50


class PSPModule(nn.Module):
    def __init__(self, num_channels, bin_size_list):
        super(PSPModule, self).__init__()
        num_filters = num_channels // len(bin_size_list)
        self.features = nn.ModuleList()
        for i in range(len(bin_size_list)):
            self.features.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(bin_size_list[i]),

                    nn.Conv2d(num_channels, num_filters, 1),
                    nn.BatchNorm2d(num_filters),
                    nn.ReLU()
                )
            )


    def forward(self, inputs):
        outs = [inputs]
        for idx, op in enumerate(self.features):
            x = op(inputs)
            x = F.interpolate(x, inputs.shape[2:], mode='bilinear', align_corners=True)
            outs.append(x)

        return torch.cat(outs, 1)


class PSPNet(nn.Module):
    def __init__(self, pretrained=False, num_classes=2):
        super(PSPNet, self).__init__()
        backbone = resnet50()

        if pretrained:
            backbone.load_state_dict(torch.load('./backbone/weights/resnet50.pth'))

        self.layer0 = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool
        )

        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        num_channels = 2048

        self.pspmodule = PSPModule(num_channels, (1, 2, 3, 6))

        num_channels *= 2

        self.classifier = nn.Sequential(
            nn.Conv2d(num_channels, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, num_classes, kernel_size=1)
        )

    def forward(self, inputs):
        x = self.layer0(inputs)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.pspmodule(x)
        x = self.classifier(x)
        x = F.interpolate(x, inputs.shape[2:], mode='bilinear', align_corners=True)
        return x


if __name__ == '__main__':
    inputs = torch.randn(2, 3, 512, 512)
    model = PSPNet()
    out = model(inputs)
    print(out.shape)