目录

一. 语义分割概述

二.  PSPNet语义分割原理和Pytorch实现

1. PSPNet算法原理

2. 环境配置

3.  训练数据集处理

4.数据预处理和加载

5. 模型构建

5. 训练

三.  KNN抠图

四. 总结

参考文献


一. 语义分割概述

图像语义分割是一种将图像分割成一系列具有特定语义类别属性区域的方法,目前已成为当前图像理解分析和计算机视觉 等领 域的热点研究内容。简单举个例子,下图为例:

PIL 图像语义分割 图像语义分割算法_pytorch

PIL 图像语义分割 图像语义分割算法_语义分割_02

                                                                                             图1.1 语义分割示例

左边是一张自然街景拍摄的图片,右边是对应的语义分割图。可以看到,分割的结果就是将同类的物体全都用一种颜色标注出来,每一类物体就是一种“语义”,例如图中人是一类语义,马路是一类语义,树是一类语义,电线杆是一类语义,等等。语义分割就是要按照语义类别进行分割。如果两个像素靠的很近,例如图中的人和树,但是它们不属于同一个语义类别,那么语义分割算法就需要将他们分割开来。可以看到,语义分割相比传统的分割添加了语义的概念,它需要算法具备一定的先验知识,大体上“知道”人是什么样子、树是什么样子、马路是什么样子,有了这种先验知识,才能准确的对图像每个像素进行分割(标注)。

如果从分类的角度来看这个问题,那么语义分割可以理解为为图像中的每个像素进行分类,类别就是图像中所有的语义种类(个数)。相比计算机视觉中一般的分类问题,语义分割的难度更大,因为其精度需要精确至像素级别。

语义分割方法按照时间可以大致分为两类:传统方法和深度学习方法。

  • 传统方法:主要采用马尔科夫随机场(MRF)和条件随机场(CRF)等方法进行数学建模,方法相对简单,运行速度快。缺点就是缺乏有效的先验知识,分割精度低;
  • 深度学习方法:目前主流的语义分割算法都是采用深度学习来实现,较传统算法来说,深度学习方法可以充分利用大样本数据的先验知识得到更佳的分割性能;

语义分割技术可以对整幅图像进行像素级的分析,目前,语义分割已经被广泛应用于自动驾驶、无人机落点判定、地质检测、面部分析、精准农业等场景中。

全连接神经网络FCN(Fully Convolutional Networks)是第一个基于深度学习被应用到语义分割场景中所提的方法,于2014年提出。也正是基于这篇论文的工作,后续正式开启了深度学习语义分割的研究热潮。直至今日,FCN论文中提出的很多语义分割处理思路依然被沿用。

对于一般的分类CNN网络,如VGG和Resnet,都会在网络的最后加入一些全连接层,经过softmax后就可以获得类别概率信息。但是这个概率信息是1维的,即只能标识整个图片的类别,不能标识每个像素点的类别,所以这种全连接方法不适用于图像分割。而FCN提出可以把后面几个全连接都换成卷积,这样就可以获得一张2维的特征图,后接softmax获得每个像素点的分类信息,从而解决了每个像素的分割问题。整个模型结构如下:

PIL 图像语义分割 图像语义分割算法_pytorch_03

                                                                                                    图1.2 FCN模型结构

可以看到整个模型结构全部采用了卷积,去掉了全连接层。直至今日,该模型的实现比较简单,但是在当时的背景下其性能超越了一众传统算法。

本文将从语义分割角度切入,以证件照分割任务为例,详细讲解如何利用语义分割技术实现智能证件照制作。本文核心在于利用高级的语义分割技术实现复杂背景下的人像分割。

二.  PSPNet语义分割原理和Pytorch实现

1. PSPNet算法原理

目前,经常用来作人像分割的模型主要有UNet和PSPNet,使用UNet的优势在于它的模型简洁和高效。最早UNet是用来处理医学影像分割任务的,它可以从相对较少的样本中学习到精确的分割边界。UNet另一个重要优势就是它的网络本身结构较小,可以方便的和MobileNet V2相结合,执行速度非常快,因此该网络被广泛应用于对速度要求较高的实时语义分割任务。

尽管使用UNet模型速度较快、分割精度高,但是该模型本身缺乏充分的上下文语义信息,尤其是对于人像分割任务,人像中的背景往往是异常复杂的、难以预测的,在先验信息不充分的情况下,UNet模型容易产生大面积误判。因此,为了保证语义分割的有效性,需要在模型构建的过程中考虑更多的全局语义信息。PSPNet模型就是一个性能突出的全局语义分割模型。

PSPNet算法认为,FCN在处理图像时没有有效的考虑到图像上下文语义信息,因此它在进行语义分割的时候容易“混淆”语义。比如下面这个图:

PIL 图像语义分割 图像语义分割算法_语义分割_04

           图 2.1 PSPNet上下文语义分割示意图,从左往右依次为:原图、真值、FCN算法结果、PSPNet算法结果

上图可以看到FCN算法错误的将船认成了其它语义信息(认成了车),两者外观非常相似,但是如果从上下文去推断整个场景,那么很明显,靠在湖边的更有可能的是船,因此,PSPNet算法给出了正确结果。这就是利用上下文语义信息的优势,可以隐式的捕捉到各语义之间的关系,然后利用这种上下文语义关系提高语义预测精度。

PSPNet模型结构如下图所示:

PIL 图像语义分割 图像语义分割算法_pytorch_05

                                                                                              图 2.2 PSPNet模型结构

整个模型包含两个子模块:

  • 特征提取子模块Feature Map:该模块采用深度残差网络ResNet架构对图像进行特征提取。很多研究表明越深的网络,对于图像语义特征挖掘的越充分,越利于图像的分类。采用ResNet网络不仅可以有效的拓展网络深度,同时相比VGG等网络模型ResNet更轻量。因此,PSPNet的首个子模型就采用了深度ResNet来提取特征。由于ResNet可以方便的应用于其它分类任务,为了提高分割性能并且加速收敛,可以进行迁移学习,利用预训练结束的ResNet模型权重来进行初始化(Pretrained Resnet)。本文人像分割任务也采用了这种方式。
  • 金字体池化子模块Paramid Pooling Module:如何挖掘上下文语义之间的关系信息呢?PSPNet采用了多尺度池化技术。也就对前面特征提取模块得到的特征层按照不同尺度进行池化, 池化的好处就是可以让各个部分特征进行堆叠融合,通过多个尺度的池化融合就可以实现上下文语义信息挖掘。

简单来说,PSPNet就是先用一个强大的特征提取网络对图像进行特征提取,然后使用池化技术,根据不同的池化核进行池化操作,最后,对于各个池化的结果再经过双线性上采样统一成相同的尺寸,各尺度特征级联融合后再进行语义级的像素预测分类。

2. 环境配置

本文采用Pytorch进行算法建模,Pytorch版本为1.4,cuda版本为10.1,Python版本为3.6.1。详细的环境安装教程请参考另一篇博客:。本文使用Windows10操作系统,两块GTX1080 TI显卡进行运算。

另外,为了实时观看训练的中间结果,本文使用TenosorboardX这个查看工具。相关介绍和使用请参考博客。

3.  训练数据集处理

为了能够有效的训练人像语义分割模型,这里采用爱分割提供的人像数据集AiFenGe,该数据集总共包含34425张图像,每张图像尺寸均已调整为600x800,并且同时提供对应的alpha图。部分样例如下图所示:

PIL 图像语义分割 图像语义分割算法_深度学习_06

从实际观测效果来看,该数据集的alpha图标注并不精确,尽管如此,我们还是可以用它来训练一个较好的PSPNet人像语义分割模型。需要注意的是该数据集的标注形式并不是以常见的alpha通道图给出,而是直接给出了抠图前景,因此在处理该数据集时需要先将alpha通道提取出来,然后进行二值分割即可,分割阈值选择为50。

最终我们需要的数据集包含3个文件夹img、trimap和alpha。其中img用来存储原始RGB图像,trimap用来存储trimap图(需要注意trimap是一个三值单通道图,即每个像素的取值只能是0、128或255),alpha用来存储alpha通道图(单通道图)。

转换脚本代码如下:

def genAiFenGe():
    """
    生成标准化的AiFenGe数据集,同时生成JSON文件列表
    """
    # 设置拷贝路径
    src_img_folder='E:\deeplearn\Matting_Human_Half\clip_img' 
    src_alpha_folder='E:\deeplearn\Matting_Human_Half\matting'
    des_img_folder='./data/AiFenGe/img' 
    des_alpha_folder='./data/AiFenGe/alpha' 
    des_trimap_folder='./data/AiFenGe/trimap' 

    # 检索文件
    imglist = getFileList(src_img_folder, [], 'jpg')
    alphalist = getFileList(src_alpha_folder, [], 'png')

    print('检索到 '+str(len(imglist))+' 个原始图像')
    print('检索到 '+str(len(alphalist))+ '个alpha通道图')

    # 逐张检查
    index=0
    save_img_list=list()
    save_alpha_list=list()
    save_trimap_list=list()
    for imgpath in imglist:
        imgname= os.path.splitext(os.path.basename(imgpath))[0]
        alphaname=imgname+'.png'

        for j in range(len(alphalist)):
            if alphaname in alphalist[j]:
                alphapath = alphalist[j]
                try:
                    img = cv2.imread(imgpath, cv2.IMREAD_COLOR)

                    alpha = cv2.imread(alphapath, cv2.IMREAD_UNCHANGED)
                    alpha = alpha[:,:,3] # 分离alpha通道
                    ret,alpha = cv2.threshold(alpha,50,255,cv2.THRESH_BINARY)

                    # 生成trimap
                    trimap = erode_dilate(alpha)

                    # 保存   
                    cv2.imwrite(des_img_folder+('/%d.png' % (index)),img)
                    cv2.imwrite(des_alpha_folder+('/%d.png' % (index)),alpha)
                    cv2.imwrite(des_trimap_folder+('/%d.png' % (index)),trimap)

                    # 记录
                    save_img_list.append(des_img_folder+('/%d.png' % (index)))
                    save_alpha_list.append(des_alpha_folder+('/%d.png' % (index)))
                    save_trimap_list.append(des_trimap_folder+('/%d.png' % (index)))

                    index += 1
                    print('当前写入第 %d 张图片' % (index))

                except Exception as err:
                    print(err)

    # 写入json文件
    with open('./data/aifenge_img.json', 'w') as jsonfile1:
        json.dump(save_img_list, jsonfile1)

    with open('./data/aifenge_alpha.json', 'w') as jsonfile2:
        json.dump(save_alpha_list, jsonfile2)

    with open('./data/aifenge_trimap.json', 'w') as jsonfile3:
        json.dump(save_trimap_list, jsonfile3)

    print('共写入 %d 张图片' % (index))

其中在生成trimap图时用了Opencv的腐蚀和膨胀操作,代码如下:

def erode_dilate(mask, size=(10, 10), smooth=True):
    """
    腐蚀膨胀生成trimap
    输入 mask:单通道二值掩码图
    """
    # 构造核
    if smooth:
        size = (size[0]-4, size[1]-4)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, size)

    # 膨胀
    dilated = cv2.dilate(mask, kernel, iterations=1)
    if smooth:  
        dilated[(dilated>5)] = 255
        dilated[(dilated <= 5)] = 0
    else:
        dilated[(dilated>0)] = 255

    # 腐蚀
    eroded = cv2.erode(mask, kernel, iterations=1)
    if smooth:
        eroded[(eroded<250)] = 0
        eroded[(eroded >= 250)] = 255
    else:
        eroded[(eroded < 255)] = 0

    res = dilated.copy()
    res[((dilated == 255) & (eroded == 0))] = 128

    # 保证trimap图中只有三种值
    # cnt0 = len(np.where(res >= 0)[0])
    # cnt1 = len(np.where(res == 0)[0])
    # cnt2 = len(np.where(res == 128)[0])
    # cnt3 = len(np.where(res == 255)[0])
    # assert cnt0 == cnt1 + cnt2 + cnt3
    
    return res

4.数据预处理和加载

为了加强语义分割性能,我们在提取图像时对图像作一定的变换,包括:随机裁剪、随机缩放、随机左右镜像、随机上下镜像。详细代码如下

class HumanDataset(Dataset):
    """
    人像数据集
    """
    def __init__(self, dataname, transforms=None):

        items = []
        img_path = './data/'+ dataname + '_img.json'
        trimap_path = './data/'+ dataname + '_trimap.json'
        alpha_path = './data/'+ dataname + '_alpha.json'

        with open(img_path, 'r') as j:
            imglist = json.load(j)
        with open(trimap_path, 'r') as j:
            trimaplist = json.load(j)
        with open(alpha_path, 'r') as j:
            alphalist = json.load(j)

        for i in range(len(imglist)):
            items.append((imglist[i], trimaplist[i], alphalist[i]))

        self.items = items
        self.transforms = transforms

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        image_name, trimap_name, alpha_name = self.items[index]
        image = cv2.imread(image_name, cv2.IMREAD_COLOR)
        trimap = cv2.imread(trimap_name, cv2.IMREAD_GRAYSCALE)
        alpha = cv2.imread(alpha_name, cv2.IMREAD_GRAYSCALE)

        if self.transforms is not None:
            for transform in self.transforms:
                image, trimap, alpha = transform(image, trimap, alpha)

        return image, trimap, alpha


class RandomPatch(object):
    """
    自定义随机块裁剪变换
    """
    def __init__(self, patch_size):
        self.patch_size = patch_size

    def __call__(self, image, trimap, alpha):
        # 随机尺度变化
        if random.random() < 0.5:
            h, w, c = image.shape
            scale = 0.75 + 0.5 * random.random() # 尺度变化范围(0.75到1.25)
            image = cv2.resize(image, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_CUBIC)
            trimap = cv2.resize(trimap, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_NEAREST) # 采用最近邻插值保证每个像素值为0、128或255
            alpha = cv2.resize(alpha, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_CUBIC)
        # 产生随机块
        if random.random() < 0.5:  # 产生裁剪块
            h, w, c = image.shape
            if h > self.patch_size and w > self.patch_size:
                x = random.randrange(0, w - self.patch_size)
                y = random.randrange(0, h - self.patch_size)
                image = image[y:y + self.patch_size, x:x + self.patch_size, :]
                trimap = trimap[y:y + self.patch_size, x:x + self.patch_size]
                alpha = alpha[y:y + self.patch_size, x:x + self.patch_size]
            else:
                image = cv2.resize(image, (self.patch_size, self.patch_size), interpolation=cv2.INTER_CUBIC)
                trimap = cv2.resize(trimap, (self.patch_size, self.patch_size), interpolation=cv2.INTER_NEAREST)
                alpha = cv2.resize(alpha, (self.patch_size, self.patch_size), interpolation=cv2.INTER_CUBIC)
        else: # 产生压缩块
            image = cv2.resize(image, (self.patch_size, self.patch_size), interpolation=cv2.INTER_CUBIC)
            trimap = cv2.resize(trimap, (self.patch_size, self.patch_size), interpolation=cv2.INTER_NEAREST)
            alpha = cv2.resize(alpha, (self.patch_size, self.patch_size), interpolation=cv2.INTER_CUBIC)

        return image, trimap, alpha


class RandomFlip(object):
    """
    自定义随机翻转
    """
    def __call__(self, image, trimap, alpha):
        if random.random() < 0.5: # 垂直翻转
            image = cv2.flip(image, 0)
            trimap = cv2.flip(trimap, 0)
            alpha = cv2.flip(alpha, 0)

        if random.random() < 0.5: # 水平翻转
            image = cv2.flip(image, 1)
            trimap = cv2.flip(trimap, 1)
            alpha = cv2.flip(alpha, 1)
        return image, trimap, alpha


class Normalize(object):
    """
    自定义归一化操作
    """
    def __call__(self, image, trimap, alpha):
        image = (image.astype(np.float32) - (114., 121., 134.,)) / 255.0
        trimap[trimap == 0] = 0
        trimap[trimap == 128] = 1
        trimap[trimap == 255] = 2
        alpha = alpha.astype(np.float32) / 255.0
        return image, trimap, alpha


class NumpyToTensor(object):
    """
    numpy数组转张量tensor
    """
    def __call__(self, image, trimap, alpha):
        h, w, c = image.shape
        image = torch.from_numpy(image.transpose((2, 0, 1))).view(c, h, w).float()
        trimap = torch.from_numpy(trimap).view(-1, h, w).long()  
        alpha = torch.from_numpy(alpha).view(1, h, w).float()
        return image, trimap, alpha

5. 模型构建

本文使用PSPNet算法进行人像语义分割,根据阿里巴巴提出的抠图算法SHM的思路,将其重新封装为TNet网络,其本质是一样的。其中PSPNet的特征提取部分采用了Resnet50作为backone,详细代码如下:

import torch
from torch import nn
import torchvision
import torch.nn.functional as F
import math
from torch.utils import model_zoo


def load_weights_sequential(target, source_state):
    """
    字典形式导入官方预训练模型
    """
    model_to_load = {k: v for k, v in source_state.items() if k in target.state_dict().keys()}
    target.load_state_dict(model_to_load)


class Bottleneck(nn.Module):
    """
    残差网络基本子模块
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
                               padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x   # 64 x 100 x100

        out = self.conv1(x)  # 64 x 100 x100
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers=(3, 4, 23, 3)):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)

        for m in self.modules():  # 对卷积和池化参数进行初始化
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x_3 = self.layer3(x)
        x = self.layer4(x_3)

        return x, x_3


def resnet50(pretrained=True):
    """
    特征提取模块
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth'))
    return model


class PSPModule(nn.Module):
    """
    多尺度卷积池化模块
    """
    def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)):
        super().__init__()
        self.stages = []
        self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
        self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
        self.relu = nn.ReLU()

    def _make_stage(self, features, size):
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
        return nn.Sequential(prior, conv)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=False) for stage in self.stages] + [feats]
        bottle = self.bottleneck(torch.cat(priors, 1)) # 按照通道进行级联
        return self.relu(bottle)


class PSPUpsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.PReLU()
        )

    def forward(self, x):
        h, w = 2 * x.size(2), 2 * x.size(3)
        p = F.interpolate(input=x, size=(h, w), mode='bilinear', align_corners=False)
        return self.conv(p)


class PSPNet(nn.Module):
    """
    PSPNet模型
    """
    def __init__(self, n_classes=3, sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024,
                 pretrained=True):
        super().__init__()
        self.feats = resnet50(pretrained)              # 特征提取模块
        self.psp = PSPModule(psp_size, 1024, sizes)    # 金字塔池化模块
        self.drop_1 = nn.Dropout2d(p=0.3)

        self.up_1 = PSPUpsample(1024, 256)             #上采样
        self.up_2 = PSPUpsample(256, 64)
        self.up_3 = PSPUpsample(64, 64)

        self.drop_2 = nn.Dropout2d(p=0.15)
        self.final = nn.Sequential(
            nn.Conv2d(64, n_classes, kernel_size=1),
        )

    def forward(self, x):
        f, class_f = self.feats(x) # 2048,50,50
        p = self.psp(f) #  1024,50,50
        p = self.drop_1(p)

        p = self.up_1(p) # 256,100,100
        p = self.drop_2(p) 

        p = self.up_2(p) # 64,200,200
        p = self.drop_2(p)

        p = self.up_3(p) # 64,400,400
        p = self.drop_2(p)

        return self.final(p) # 3,400,400


class TNet(nn.Module):
    """
    人像语义分割模型
    """
    def __init__(self):
        super(TNet, self).__init__()
        self.backbone = PSPNet()

    def forward(self, x):
        trimap = self.backbone(x)
        return trimap

5. 训练

详细的训练脚本如下:

import torch.backends.cudnn as cudnn
import torch
from torch import nn
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from models import TNet
from datasets import HumanDataset,RandomFlip,RandomPatch,Normalize,NumpyToTensor
from utils import *
from loss import ClassificationLoss


# 数据集参数
data_folder = './data/'   # 数据存放路径
dataname = 'aifenge'      # 数据集名称

# 学习参数
checkpoint = None     # 预训练模型路径,如果不存在则为None
batch_size = 8        # 批大小
start_epoch = 1       # 轮数起始位置
epochs = 100           # 迭代轮数
workers = 4           # 工作线程数
lr = 0.0001           # 学习率             
weight_decay = 0.0005 # 权重延迟

# 设备参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ngpu = 2           # 用来运行的gpu数量

cudnn.benchmark = True # 对卷积进行加速

writer = SummaryWriter() # 实时监控     使用命令 tensorboard --logdir runs  进行查看

def main():
    """
    训练.
    """
    global checkpoint,start_epoch,writer

    # 初始化
    model = TNet()
    # 初始化优化器
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                    lr=lr, betas=(0.9, 0.999),
                                    weight_decay=weight_decay)

    # 迁移至默认设备进行训练
    model = model.to(device)
    criterion = ClassificationLoss()
    criterion.to(device)

    # 加载预训练模型
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    
    if torch.cuda.is_available() and ngpu > 1:
        model = nn.DataParallel(model, device_ids=list(range(ngpu)))

    # 定制化的dataloaders
    transforms = [
                RandomPatch(400),
                RandomFlip(),
                Normalize(),
                NumpyToTensor()
            ]
    train_dataset = HumanDataset('aifenge',transforms)
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True) 


    # 开始逐轮训练
    for epoch in range(start_epoch, epochs+1):

        model.train()  # 训练模式:允许使用批样本归一化

        loss_epoch = AverageMeter()  # 统计损失函数

        n_iter = len(train_loader)

        # 按批处理
        for i, (imgs, trimaps_gt, alphas) in enumerate(train_loader):

            # 数据移至默认设备进行训练
            imgs = imgs.to(device)  
            trimaps_gt = trimaps_gt.to(device)  

            # 前向传播
            trimaps_pre = model(imgs)

            # 计算损失
            loss = criterion(trimaps_pre, trimaps_gt)  

            # 后向传播
            optimizer.zero_grad()
            loss.backward()

            # 更新模型
            optimizer.step()

            # 记录损失值
            loss_epoch.update(loss.item(), imgs.size(0))

            # 监控图像变化
            if i == n_iter-2:
                trimaps_pre_temp = trimap_to_image(trimaps_pre[:4,:3:,:,:])
                writer.add_image('TNet/epoch_'+str(epoch)+'_1', make_grid(imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('TNet/epoch_'+str(epoch)+'_2', make_grid(trimaps_pre_temp, nrow=4, normalize=True),epoch)
                writer.add_image('TNet/epoch_'+str(epoch)+'_3', make_grid(trimaps_gt[:4,:,:,:].float().cpu(), nrow=4, normalize=True),epoch)

            # 打印结果
            print("第 "+str(i)+ " 个batch训练结束")
 
        # 手动释放内存              
        del imgs, trimaps_pre, trimaps_gt, alphas, trimaps_pre_temp

        # 监控损失值变化
        writer.add_scalar('PreTrainTNet/Loss', loss_epoch.val, epoch)    

        # 保存预训练模型
        torch.save({
            'epoch': epoch,
            'model': model.module.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'results/checkpoint_tnet.pth')
    
    # 训练结束关闭监控
    writer.close()


if __name__ == '__main__':
    main()

下图是训练时损失函数变化

PIL 图像语义分割 图像语义分割算法_语义分割_07

整个学习过程在epoch=18时接近收敛位置。共耗时16小时左右(双GTX1080 TI显卡)。下图是通过tensorboard查看的中间训练结果,在epoch=2时的语义分割效果图:

PIL 图像语义分割 图像语义分割算法_PIL 图像语义分割_08

PIL 图像语义分割 图像语义分割算法_深度学习_09

PIL 图像语义分割 图像语义分割算法_PIL 图像语义分割_10

第一行为原始输入图,第二行为预测结果,第三行为标定的真值。可以看到,本文训练结果较好,预测和真值吻合度高。

三.  KNN抠图

前面语义分割模型能够从复杂的背景中得到一个较好的人像边界,一种有用的应用场景就是证件照制作。证件照制作核心步骤在于人像抠图,抠出前景人像后再与纯色背景进行合成,但是如果直接将人像分割的结果进行抠取,那么合成后边界会产生明显的视觉噪点。为了解决上述问题,可以采用抠图算法进行实现,例如KNN抠图。当然也可以采用SHM算法提出的思路,再额外训练一个抠图神经网络模型,但是这样做复杂度比较高。下面我们简单的使用传统抠图算法来完成这个任务,后面有兴趣的读者并且希望进步一提升整体性能的可以尝试SHM算法中关于抠图模块的实现(MNet),并且可以尝试语义分割和抠图整个端到端模型的训练。

首先安装抠图包pymatting:

pip3 install pymatting

这是一个比较综合的抠图工具包,官方网址:https://github.com/pymatting/pymatting,集成了很多抠图算法,包括

Closed Form Alpha Matting、Large Kernel Matting、KNN Matting、Learning Based Digital Matting、Random Walk Matting等。

下面是详细的单张测试代码:

import torch.backends.cudnn as cudnn
import torch
from torch import nn
from models import TNet
from utils import *
import time
import cv2
from pymatting import *


# 测试图像
imgPath = './results/1.jpg'

# 模型参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':

    # 预训练模型
    checkpoint = "./results/checkpoint_tnet.pth"

    # 加载模型
    checkpoint = torch.load(checkpoint)
    model = TNet()

    model = model.to(device)
    model.load_state_dict(checkpoint['model'])

    model.eval()

    # 加载图像
    img_org = cv2.imread(imgPath, cv2.IMREAD_COLOR)
    width = img_org.shape[1]
    height = img_org.shape[0]
    # resize image
    img = cv2.resize(img_org, (400,400), interpolation = cv2.INTER_CUBIC)

    img = (img.astype(np.float32) - (114., 121., 134.,)) / 255.0
    h, w, c = img.shape
    img = torch.from_numpy(img.transpose((2, 0, 1))).view(c, h, w).float()
    img= img.view(1, 3, h, w)

    # 记录时间
    start = time.time()

    # 转移数据至设备
    img = img.to(device)

    # 模型推理
    with torch.no_grad():
        trimap = model(img)
        n, c, h, w = trimap.size()
        if c == 3:
            trimap = torch.argmax(trimap, dim=1, keepdim=False)
        trimap.float().div_(2.0)
        trimap = trimap.float().mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()

        # 缩放
        trimap = erode_dilate(trimap)
        trimap = trimap.reshape(h, w)
        trimap = cv2.resize(trimap, (width,height), interpolation = cv2.INTER_NEAREST)
        cv2.imwrite('./results/trimap.png',trimap)

        # 抠图
        scale = 1.0
        image = load_image(imgPath, "RGB", scale, "box")
        trimap = load_image('./results/trimap.png', "GRAY", scale, "nearest")

        alpha = estimate_alpha_cf(image, trimap)
        bg = np.zeros([height,width,3],np.uint8)
        bg[:,:,:] = np.ones([height,width,3])*255

        foreground, background = estimate_foreground_ml(image, alpha, return_background=True)
        new_image = blend(image, bg, alpha)

        # 保存抠图结果
        grid = make_grid([new_image,])
        save_image("./results/merge.png", grid)


    print('用时  {:.3f} 秒'.format(time.time()-start))

下图展示了部分证件照抠图效果(测试样本来自网络):

PIL 图像语义分割 图像语义分割算法_深度学习_11

PIL 图像语义分割 图像语义分割算法_语义分割_12

PIL 图像语义分割 图像语义分割算法_人像_13

PIL 图像语义分割 图像语义分割算法_深度学习_14

PIL 图像语义分割 图像语义分割算法_pytorch_15

PIL 图像语义分割 图像语义分割算法_PIL 图像语义分割_16

PIL 图像语义分割 图像语义分割算法_语义分割_17

PIL 图像语义分割 图像语义分割算法_PIL 图像语义分割_18

 

在CPU下面单张测试速度(包含完整的语义分割和抠图)5.5秒。

从整体效果来看,依托PSPNet对于场景的的准确识别,整体的证件照抠图效果还是不错的,即使是复杂的背景该算法依然能够准确的对其进行分割。原则上来说,只要语义分割精度高,抠图效果就不会差。另外,可以看到,采用KNN抠图在边界处具有一定的毛刺,这个问题可以留在后面的文章中进行解决,后面我们会详细讲解如何构建一个专门的抠图网络MNet,使得抠图边界更加的自然。另外,本文使用的aifenge数据集本身分割精度并不是很高,如果需要进行产品化,那么首先要做的就是搜集高精度人像分割数据集,最好能够精细到发丝级别,然后在本文基础上进行迁移学习,相信可以达到不错的产品质量。

四. 总结

本文实现了一个有效的人像分割算法并运用于证件照制作,整体效果尚可。如果需要继续提高模型精度,可以扩大训练集规模和多样性,提升算法鲁棒性和精度。如果对算法实时性要求较高,那么可以采用UNet模型。当然,读者也可以阅读最新的文献,采用其它的语义分割模型或者更换更高精度的人像分割数据集。

由于水平有限,文中肯定存在模型理解和代码上的错误和不足,请读者多多指正,共同探讨、共同进步!后续将会尝试复现实例分割等相关内容,如果读者感兴趣可以继续关注。

参考文献

【1】Long J, Shelhamer E, Darrell T, et al. Fully convolutional networks for semantic segmentation[C]. computer vision and pattern recognition, 2014: 3431-3440.

【2】Olaf Ronneberger, Philipp Fischer, Thomas Brox. U-Net: Convolutional Networks for Biomedical Image Segmentation[C]// International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer International Publishing, 2015.

【3】Iglovikov, Vladimir, Shvets, Alexey. TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation[C]// International Conference on Compute Vision and Pattern Recognition. 2018.

【4】Zhao H, Shi J, Qi X, et al. Pyramid Scene Parsing Network[C]. computer vision and pattern recognition, 2017: 6230-6239.

【5】Chen Q, Ge T, Xu Y, et al. Semantic Human Matting[C]. acm multimedia, 2018: 618-626.

【6】Xu N, Price B, Cohen S, et al. Deep Image Matting[C]. computer vision and pattern recognition, 2017: 311-320.

【7】Shen X, Tao X, Gao H, et al. Deep Automatic Portrait Matting[C]. european conference on computer vision, 2016: 92-107.