Building DenseNet with PyTorch

Introduction

DenseNet is a popular convolutional neural network architecture in deep learning. By directly connecting the feature maps of different layers, it strengthens information flow and feature reuse, making the network more parameter-efficient and easier to train. This article shows how to build a DenseNet with the PyTorch framework, with code examples along the way.

Importing the Required Libraries

First, we import PyTorch and the other libraries we need.

import torch
import torch.nn as nn
import torch.optim as optim

Building the Basic Modules

We start with DenseNet's two basic building blocks: the DenseBlock and the TransitionBlock.

DenseBlock

The DenseBlock is the core of DenseNet and consists of several densely connected convolutional layers. The input to each layer is the concatenation of the outputs of all preceding layers, so the number of input channels grows by growth_rate with every layer.

class DenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate, num_layers):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        
        for i in range(num_layers):
            # Layer i receives the original input plus the i feature maps
            # produced so far, i.e. in_channels + i * growth_rate channels.
            # (Mutating in_channels inside the loop would double-count.)
            self.layers.append(self._make_dense_layer(in_channels + i * growth_rate, growth_rate))
    
    def _make_dense_layer(self, in_channels, growth_rate):
        # Pre-activation ordering (BN -> ReLU -> Conv), as in the DenseNet paper.
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1, bias=False)
        )
    
    def forward(self, x):
        for layer in self.layers:
            out = layer(x)
            # Append the new feature maps to everything computed so far.
            x = torch.cat((x, out), dim=1)
        return x
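
As a quick sanity check on the channel bookkeeping, the sketch below (sizes chosen purely for illustration, not from the article) runs a dummy tensor through a block: with 64 input channels, a growth rate of 32, and 4 layers, the output should have 64 + 4 * 32 = 192 channels.

# Quick shape check for DenseBlock (example sizes, for illustration only).
block = DenseBlock(in_channels=64, growth_rate=32, num_layers=4)
x = torch.randn(1, 64, 56, 56)   # one 64-channel 56x56 feature map
out = block(x)
print(out.shape)                 # torch.Size([1, 192, 56, 56]) = 64 + 4 * 32 channels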

TransitionBlock

The TransitionBlock reduces the number of channels coming out of a DenseBlock and downsamples the feature maps, using a 1x1 convolution followed by 2x2 average pooling.

class TransitionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TransitionBlock, self).__init__()
        self.transition = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            # 1x1 convolution compresses the channel count.
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            # 2x2 average pooling halves the spatial resolution.
            nn.AvgPool2d(kernel_size=2, stride=2)
        )
    
    def forward(self, x):
        return self.transition(x)
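
The short sketch below (again with illustrative sizes) shows the transition in action: 192 channels compressed to 96, and a 56x56 feature map downsampled to 28x28.

# Quick shape check for TransitionBlock (example sizes, for illustration only).
transition = TransitionBlock(in_channels=192, out_channels=96)
x = torch.randn(1, 192, 56, 56)
print(transition(x).shape)       # torch.Size([1, 96, 28, 28])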

Building the Full DenseNet

We now assemble the complete DenseNet from the building blocks above; the default block_config=(6, 12, 24, 16) with growth_rate=32 matches DenseNet-121.

class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=10):
        super(DenseNet, self).__init__()
        
        # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        num_channels = 64
        self.dense_blocks = nn.ModuleList()
        self.transition_blocks = nn.ModuleList()
        
        for i, num_layers in enumerate(block_config):
            block = DenseBlock(num_channels, growth_rate, num_layers)
            self.dense_blocks.append(block)
            num_channels += num_layers * growth_rate
            
            # Add a transition after every dense block except the last,
            # halving the channel count and the spatial resolution.
            if i != len(block_config) - 1:
                transition = TransitionBlock(num_channels, num_channels // 2)
                self.transition_blocks.append(transition)
                num_channels = num_channels // 2
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(num_channels, num_classes)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        # There is one fewer transition block than dense blocks, so index
        # the transitions explicitly; zip() would silently skip the final
        # dense block and break the classifier's expected channel count.
        for i, dense_block in enumerate(self.dense_blocks):
            x = dense_block(x)
            if i < len(self.transition_blocks):
                x = self.transition_blocks[i](x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
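
To see the whole model in action, here is a minimal usage sketch, assuming 224x224 RGB inputs as in the original DenseNet paper:

# Dummy forward pass through the full network (illustrative input size).
model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), num_classes=10)
x = torch.randn(2, 3, 224, 224)  # a batch of two dummy RGB images
logits = model(x)
print(logits.shape)              # torch.Size([2, 10])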

Flowchart

The flowchart below summarizes the overall structure of the DenseNet:

flowchart TD
    A[Input] --> B[Convolutional Layer]
    B --> C[Batch Normalization]
    C --> D[ReLU]
    D --> E[Max Pooling]
    E --> F[Dense Block 1]
    F --> G[Transition Block 1]
    G --> H[Dense Block 2]
    H --> I[Transition Block 2]
    I --> J[Dense Block 3]
    J --> K[Transition Block 3]
    K --> L[Dense Block 4]
    L --> M[Adaptive Average Pooling]
    M --> N[Fully Connected Layer]
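
Finally, the torch.optim import from earlier can be put to work. Below is a minimal training-step sketch in which randomly generated tensors stand in for a real data loader; in practice you would substitute your own dataset and loop over epochs.

# Minimal training-step sketch (random tensors stand in for real data).
model = DenseNet(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

images = torch.randn(8, 3, 224, 224)      # dummy image batch
labels = torch.randint(0, 10, (8,))       # dummy class labels

optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
optimizer.step()
print(loss.item())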