deeplabv3 pytorch 训练 deeplabv3代码pytorch_卷积

import torch  # 导入 PyTorch 库
import torch.nn as nn  # 导入 PyTorch 神经网络模块
import torch.nn.functional as F  # 导入 PyTorch 函数模块
from torchvision.models.segmentation import deeplabv3_resnet50  # 从 torchvision 导入预训练的 DeepLabv3 ResNet50 模型


class DeepLabV3Plus(nn.Module):  # 定义 DeepLabV3Plus 类,继承自 nn.Module
    def __init__(self, num_classes=21, pretrained_backbone=True):  # 初始化方法,参数包括类别数,默认值为 21,和是否使用预训练的骨干网络,默认为 True
        super(DeepLabV3Plus, self).__init__()  # 调用父类的初始化方法
        self.deeplabv3 = deeplabv3_resnet50(pretrained_backbone=pretrained_backbone)  # 使用预训练的 DeepLabv3 ResNet50 模型
        self.deeplabv3.classifier = DeepLabHead(2048, num_classes)  # 用自定义的 DeepLabHead 替换原有的分类器

    def forward(self, x):  # 定义前向传播方法
        return self.deeplabv3(x)  # 调用 deeplabv3 模型的前向传播


class DeepLabHead(nn.Sequential):  # 定义 DeepLabHead 类,继承自 nn.Sequential
    def __init__(self, in_channels, num_classes):  # 初始化方法,参数包括输入通道数和类别数
        super(DeepLabHead, self).__init__(
            ASPP(in_channels, 256),  # 添加 ASPP 模块
            nn.Conv2d(256, 256, 3, padding=1, bias=False),  # 添加卷积层
            nn.BatchNorm2d(256),  # 添加批量归一化层
            nn.ReLU(),  # 添加 ReLU 激活函数
            nn.Conv2d(256, num_classes, 1)  # 添加最后的卷积层,用于类别预测
        )


class ASPP(nn.Module):  # 定义 ASPP(空洞空间金字塔池化)类,继承自 nn.Module
    def __init__(self, in_channels, out_channels, atrous_rates=None):  # 初始化方法,参数包括输入通道数、输出通道数和空洞率列表
        super(ASPP, self).__init__()
        if atrous_rates is None:  # 如果没有提供空洞率列表,则使用默认值
            atrous_rates = [6, 12, 18]

        layers = []  # 创建一个空列表,用于存放 ASPP 模块的层
        # 添加一个卷积层、批量归一化层和 ReLU 激活函数
        layers.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU
        ))
        for rate in atrous_rates:  # 遍历空洞率列表
            layers.append(ASPPConv(in_channels, out_channels, rate))  # 添加 ASPPConv 层,使用当前空洞率

        self.convs = nn.ModuleList(layers)  # 将 layers 列表转换为 ModuleList
        self.global_pooling = nn.Sequential(  # 定义全局平均池化层
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        self.out_conv = nn.Sequential(  # 定义输出卷积层
            nn.Conv2d(out_channels * (2 + len(atrous_rates)), out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):  # 定义 ASPP 类的前向传播方法
        x_pool = self.global_pooling(x)  # 对输入 x 进行全局平均池化
        x_pool = F.interpolate(x_pool, size=x.shape[2:], mode='bilinear', align_corners=False)  # 将池化结果上采样到原始尺寸
        x_aspp = [x_pool] + [conv(x) for conv in self.convs]  # 对输入 x 应用 ASPPConv 层
        x = torch.cat(x_aspp, dim=1)  # 将上采样的全局池化结果和 ASPPConv 层的结果沿通道维度拼接
        return self.out_conv(x)  # 应用输出卷积层


class ASPPConv(nn.Sequential):  # 定义 ASPPConv 类,继承自 nn.Sequential
    def __init__(self, in_channels, out_channels, dilation):  # 初始化方法,参数包括输入通道数、输出通道数和空洞率
        super(ASPPConv, self).__init__(
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),  # 添加带空洞的卷积层
            nn.BatchNorm2d(out_channels),  # 添加批量归一化层
            nn.ReLU()  # 添加 ReLU 激活函数
        )


if __name__ == '__main__':
    model = DeepLabV3Plus(num_classes=21)  # 创建一个 DeepLabV3Plus 模型实例,指定 21 个类别
    print(model)  # 输出模型结构
  1. 输入图像:DeepLab v3+接受一张RGB图像作为输入。输入图像的尺寸可能会被调整,以适应模型的要求。
  2. 骨干网络(Backbone):输入图像首先传入骨干网络。骨干网络是一个卷积神经网络,负责从图像中提取特征。常见的骨干网络有ResNet、MobileNet和Xception。骨干网络输出一个特征图(Feature Map)。
  3. 空洞空间金字塔池化(Atrous Spatial Pyramid Pooling,ASPP):特征图经过ASPP模块,该模块使用不同采样率的空洞卷积并行处理特征图。这种方法提高了模型在不同尺度的物体上的性能。然后,所有不同采样率的输出被连接在一起,生成一个更丰富的特征表示。
  4. 编码器-解码器结构:DeepLab v3+在原始DeepLab v3基础上加入了编码器-解码器结构,以获得更精确的分割结果。编码器部分可以被认为是骨干网络和ASPP模块的组合。
  5. 解码器部分的任务是将编码器输出的粗糙特征图上采样(Upsample)到原始输入图像的分辨率。解码器首先对编码器输出执行双线性上采样。然后,解码器从骨干网络中获取低级特征(低级特征包含更多的空间信息),与上采样的特征图连接在一起。最后,通过一个1x1卷积层进行通道维度的降维,再进行上采样,得到与输入图像分辨率相同的输出特征图。
  6. 输出分割结果:输出特征图的通道数等于预定义的类别数量。对于每个像素,模型计算每个类别的概率。通常通过在输出特征图上应用Softmax激活函数来实现。每个像素被分配到具有最高概率的类别,形成最终的分割结果。