0 引言

随着深度学习领域中各类算法的迅速发展,卷积神经网络(CNN)被广泛应用在了分类任务上,输出的结果是整个图像的类标签。在生物医学领域,医生需要对病人的病灶区域进行病理分析,这时需要一种更先进的网络模型,即能通过少量的图片训练集,就能实现对像素点类别的预测,并且可以对像素点进行着色绘图,形成更复杂、严谨的判断。于是U-Net网络被设计了出来。

1 U-Net概念及原理

U-Net网络结构最早由Ronneberger等人于2015年提出。该图像的核心思想是引入了跳跃连接,使得图像分割的精度大大提升。
U-Net网络的主要结构包括了解码器编码器瓶颈层三个部分。

  • 编码器:包括了四个程序块。每个程序块都包括 POINTNET是图神经网络 神经网络finetune_卷积 的卷积(使用Relu激活函数),步长为 POINTNET是图神经网络 神经网络finetune_POINTNET是图神经网络_02POINTNET是图神经网络 神经网络finetune_卷积_03
  • 解码器: 与编码器部分对称,也包括四个程序块,每个程序块包括步长为 POINTNET是图神经网络 神经网络finetune_POINTNET是图神经网络_02POINTNET是图神经网络 神经网络finetune_卷积_03 的上采样操作,然后与编码部分进行特征映射级联(Concatenate),即拼接,最后通过两个 POINTNET是图神经网络 神经网络finetune_卷积
  • 瓶颈层:包含两个 POINTNET是图神经网络 神经网络finetune_卷积

最后经过一个 POINTNET是图神经网络 神经网络finetune_POINTNET是图神经网络_08的卷积层得到最后的输出。

POINTNET是图神经网络 神经网络finetune_深度学习_09


如图所示,该网络模型形似字母“U”,故称为U-Net。

整体过程:
先对图片进行卷积和池化。比如说一开始输入的图片大小是 POINTNET是图神经网络 神经网络finetune_深度学习_10,进过四次池化后,分别得到 POINTNET是图神经网络 神经网络finetune_POINTNET是图神经网络_11 , POINTNET是图神经网络 神经网络finetune_深度学习_12 , POINTNET是图神经网络 神经网络finetune_神经网络_13, POINTNET是图神经网络 神经网络finetune_2d_14 四个不同尺寸的特征图。然后对 POINTNET是图神经网络 神经网络finetune_神经网络_15 的特征图做上采样,得到 POINTNET是图神经网络 神经网络finetune_深度学习_16 的特征图。将这个 POINTNET是图神经网络 神经网络finetune_深度学习_16的特征图与之前池化得到的 POINTNET是图神经网络 神经网络finetune_深度学习_16 特征图进行通道上的拼接(concat),然后再对拼接之后的特征图做卷积和上采样,得到 POINTNET是图神经网络 神经网络finetune_卷积_19 的特征图,然后再与之前的 POINTNET是图神经网络 神经网络finetune_深度学习_12

在本图片上的U-Net中,它输入大小为 POINTNET是图神经网络 神经网络finetune_2d_21, 而输出大小为 POINTNET是图神经网络 神经网络finetune_深度学习_22, 那是因为它在卷积过程中没有加padding层所造成的。

2 代码

import torch
import torch.nn as nn
import torch.nn.functional as F

# Double Convolution
class DoubleConv2d(nn.Module):
    def __init__(self, inputChannel, outputChannel):
        super(DoubleConv2d, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inputChannel, outputChannel, kernel_size=3, padding=1),
            nn.BatchNorm2d(outputChannel),
            nn.ReLU(True),
            nn.Conv2d(outputChannel, outputChannel, kernel_size=3, padding=1),
            nn.BatchNorm2d(outputChannel),
            nn.ReLU(True)
        )

    def forward(self, x):
        out = self.conv(x)
        return out


# Down Sampling
class DownSampling(nn.Module):
    def __init__(self):
        super(DownSampling, self).__init__()
        self.down = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        out = self.down(x)
        return out


# Up Sampling
class UpSampling(nn.Module):

    # Use the deconvolution
    def __init__(self, inputChannel, outputChannel):
        super(UpSampling, self).__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose2d(inputChannel, outputChannel, kernel_size=2, stride=2),
            nn.BatchNorm2d(outputChannel)
        )

    def forward(self, x, y):
        x =self.up(x)
        diffY = y.size()[2] - x.size()[2]
        diffX = y.size()[3] - x.size()[3]
        x = F.pad(x, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        out = torch.cat([y, x], dim=1)
        return out


class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        self.layer1 = DoubleConv2d(1, 64)
        self.layer2 = DoubleConv2d(64, 128)
        self.layer3 = DoubleConv2d(128, 256)
        self.layer4 = DoubleConv2d(256, 512)
        self.layer5 = DoubleConv2d(512, 1024)
        self.layer6 = DoubleConv2d(1024, 512)
        self.layer7 = DoubleConv2d(512, 256)
        self.layer8 = DoubleConv2d(256, 128)
        self.layer9 = DoubleConv2d(128, 64)

        self.layer10 = nn.Conv2d(64, 2, kernel_size=3, padding=1)  # The last output layer

        self.down = DownSampling()
        self.up1 = UpSampling(1024, 512)
        self.up2 = UpSampling(512, 256)
        self.up3 = UpSampling(256, 128)
        self.up4 = UpSampling(128, 64)

    def forward(self, x):
        conv1 = self.layer1(x)
        down1 = self.down(conv1)
        conv2 = self.layer2(down1)
        down2 = self.down(conv2)
        conv3 = self.layer3(down2)
        down3 = self.down(conv3)
        conv4 = self.layer4(down3)
        down4 = self.down(conv4)
        conv5 = self.layer5(down4)
        up1 = self.up1(conv5, conv4)
        conv6 = self.layer6(up1)
        up2 = self.up2(conv6, conv3)
        conv7 = self.layer7(up2)
        up3 = self.up3(conv7, conv2)
        conv8 = self.layer8(up3)
        up4 = self.up4(conv8, conv1)
        conv9 = self.layer9(up4)
        out = self.layer10(conv9)
        return out


# Test part

mynet = Unet()
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# mynet.to(device)
input = torch.rand(3, 1, 572, 572)
# output = mynet(input.to(device))
output = mynet(input)
print(output.shape)  # (3,2,572,572)

https://www.jianshu.com/p/a73f74992b1a https://arxiv.org/pdf/1505.04597v1.pdf