1. Overview

This post is a set of study notes on the videos by the Bilibili uploader 霹雳吧啦Wz.

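The whole project has been uploaded to my GitHub: https://github.com/lovewinds13/QYQXDeepLearning . You can download it and test it directly. The dataset files were removed because they are large; download them as described below.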

Paper: Deep Residual Learning for Image Recognition

2. ResNet

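ResNet (deep residual network) was proposed by Microsoft Research in 2015. That year it won first place in both the classification and the object detection tasks of the ImageNet competition, as well as first place in object detection and in image segmentation on the COCO dataset.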

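ResNet's innovations:

(1) An ultra-deep network structure that goes beyond 1000 layers;

(2) The residual module (residual block);

(3) Batch Normalization to speed up training (dropout is no longer used).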

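Problems ResNet solves:

(1) The "degradation" problem of deep neural networks: as more layers are stacked onto a shallow network, the model becomes more complex, yet accuracy saturates and then drops quickly. This is not overfitting, because the training error rises as well;

(2) Introducing BN (Batch Normalization) layers largely solves the vanishing and exploding gradient problems of plain nets.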

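Supplement:

Vanishing gradients: if the error gradient passed back through each layer is less than 1, then during backpropagation the gradient approaches 0 as the network gets deeper;
Exploding gradients: if the error gradient passed back through each layer is greater than 1, then during backpropagation the gradient grows larger and larger as the network gets deeper.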
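A quick numeric illustration of the two cases above. It rests on a deliberately simplified assumption: every layer multiplies the backpropagated gradient by the same constant factor, so the gradient reaching the first layer is that factor raised to the depth. `gradient_after` is a hypothetical helper written only for this illustration.

```python
def gradient_after(depth: int, factor: float) -> float:
    """Toy model: gradient after backpropagating through `depth` layers,
    assuming (for illustration only) each layer scales it by `factor`."""
    grad = 1.0
    for _ in range(depth):
        grad *= factor
    return grad


for depth in (10, 50, 100):
    print(depth, gradient_after(depth, 0.9), gradient_after(depth, 1.1))
```

With a factor of 0.9 the gradient after 100 layers is on the order of 1e-5, while with 1.1 it is on the order of 1e4: small per-layer deviations from 1 compound quickly with depth.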

2.1 Network Architecture

[Figure: plain and residual network architectures]

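The figure above shows 18-layer and 34-layer plain networks built for comparison, in which layers are simply stacked, followed by 18-layer and 34-layer residual networks obtained by adding shortcuts to the plain networks. With identity shortcuts, the plain and residual versions have the same number of parameters and the same computational cost, and both need far less computation than VGG-19.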

[Figure: error curves of plain networks (left) and residual networks (right)]


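The left plot shows the plain networks: making the network deeper actually raises the error rate. The right plot shows the results with the residual structure added, where the trend is reversed and the deeper network has the lower error.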

2.2 Residual Structure

[Figure: residual block]


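A residual block is implemented with a shortcut connection, which simply adds the block's input to its output. An identity shortcut adds no extra parameters and almost no extra computation (only an element-wise addition), yet it makes deep networks much easier to optimize: training converges faster and reaches better results, and the degradation problem largely disappears as the depth increases.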
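Written as formulas, in the notation of the ResNet paper, a block with an identity shortcut computes the first line below. When the shapes of input and output differ, the shortcut applies a linear projection $W_s$ instead (the second line), which is what the 1×1 convolution on the dashed shortcut in Section 2.3 implements. For a two-layer block, $\sigma$ is the ReLU, and BN and biases are omitted for clarity; the block's final ReLU is applied to $\mathbf{y}$ after the addition.

```latex
\begin{aligned}
\mathbf{y} &= \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathbf{x} && \text{(identity shortcut)}\\
\mathbf{y} &= \mathcal{F}(\mathbf{x}, \{W_i\}) + W_s\,\mathbf{x} && \text{(projection shortcut)}\\
\mathcal{F}(\mathbf{x}, \{W_i\}) &= W_2\,\sigma(W_1\mathbf{x}) && \text{(two-layer block)}
\end{aligned}
```

Only $W_s$ introduces parameters on the shortcut; the identity version adds none beyond those of $\mathcal{F}$.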

[Figure: BasicBlock (left) and Bottleneck (right)]


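In the figure above, the residual structure on the left is called BasicBlock and the one on the right is called Bottleneck; they are used for different network depths. In the Bottleneck, the 1×1 convolutions first reduce and then restore the dimension (the depth of the feature map), which also reduces the number of parameters.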

| Residual structure | Networks | Depth |
|---|---|---|
| BasicBlock | ResNet18/34 | shallow networks |
| Bottleneck | ResNet50/101/152 | deep networks |

Parameter counts of the two residual structures for an input of the same depth (note: depth = 256):

[Figure: BasicBlock vs. Bottleneck with a 256-channel input]

| Structure | Parameters |
|---|---|
| Left (BasicBlock) | 3×3×256×256 + 3×3×256×256 = 1,179,648 |
| Right (Bottleneck) | 1×1×256×64 + 3×3×64×64 + 1×1×64×256 = 69,632 |

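Number of parameters of a convolutional layer (bias ignored)
= kernel size (k × k) × kernel depth × number of kernels
= kernel size (k × k) × input feature map depth × output feature map depth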
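To double-check the numbers in the table above, the formula can be applied in plain Python. This is a minimal sketch: `conv_params` is a helper name introduced only for this example, and BN parameters are ignored, as in the table (the convolutions in model.py use `bias=False`).

```python
def conv_params(kernel_size: int, in_channels: int, out_channels: int) -> int:
    """Weights of one conv layer without bias: k * k * in_channels * out_channels."""
    return kernel_size * kernel_size * in_channels * out_channels


# Left structure: two 3x3 convs, 256 -> 256 -> 256
basic = conv_params(3, 256, 256) + conv_params(3, 256, 256)

# Right structure: 1x1 reduces to 64, 3x3 at 64, 1x1 restores to 256
bottleneck = conv_params(1, 256, 64) + conv_params(3, 64, 64) + conv_params(1, 64, 256)

print(basic, bottleneck)  # 1179648 69632
print(f"ratio: {basic / bottleneck:.1f}x")
```

At the same input and output depth, the bottleneck design needs roughly 17 times fewer weights, which is why it is used for the deeper ResNet-50/101/152.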

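Note: in a residual network, the output feature maps of the main branch and of the shortcut must have exactly the same shape, otherwise they cannot be added.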

2.3 Shortcut Connections

In ResNet, the shortcuts of some residual blocks are drawn as solid lines and others as dashed lines, as shown below.

[Figure: solid-line and dashed-line shortcuts]


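The dashed shortcut adds a 1×1 convolution to handle the dimension change: it downsamples the feature map in height and width and adjusts its depth to the number of channels required by the next residual structure. Note that the last ReLU of the residual block is applied after the shortcut output has been added.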
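To see why a 1×1 convolution with stride 2 is enough for the shortcut to match the main branch, apply the standard output-size formula out = ⌊(in + 2·padding − kernel) / stride⌋ + 1 to the first dashed block of conv3_x. This sketch assumes a 224×224 input and uses a hypothetical helper `conv_out_size`; the channel count (64 → 128) is matched separately by giving the 1×1 convolution 128 output channels.

```python
def conv_out_size(size: int, kernel_size: int, stride: int, padding: int) -> int:
    """Output height/width of a conv or pooling layer."""
    return (size + 2 * padding - kernel_size) // stride + 1


stem = conv_out_size(224, kernel_size=7, stride=2, padding=3)  # conv1: 7x7, stride 2
h = conv_out_size(stem, kernel_size=3, stride=2, padding=1)     # 3x3 max pooling, stride 2

# First block of conv3_x, starting from the conv2_x output of size h x h
main_branch = conv_out_size(h, kernel_size=3, stride=2, padding=1)  # 3x3 conv, stride 2
shortcut = conv_out_size(h, kernel_size=1, stride=2, padding=0)     # dashed 1x1 conv, stride 2
print(stem, h, main_branch, shortcut)  # 112 56 28 28
```

Both branches halve 56 to 28, so after the 1×1 convolution also maps the depth to 128 the two outputs can be added element-wise.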

(1) Residual structures of ResNet-18/34

[Figure: ResNet-18/34 residual structures]


(2) Residual structures of ResNet-50/101/152

[Figure: ResNet-50/101/152 residual structures]

2.4 ResNet Network Configurations

[Figure: ResNet configuration table]

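Characteristics of the ResNet architecture:

(1) Shortcut paths are added; the layers spanned by one shortcut form one residual block;

(2) There is no pooling inside a residual block; downsampling is done by setting the stride of a convolution;

(3) The first residual block of conv3_x, conv4_x and conv5_x uses a dashed shortcut to adjust the shape of the input feature map: stride = 2 halves its height and width, and its depth (channels) is adjusted to what the next residual structure needs. In ResNet-50/101/152, the first block of conv2_x also uses a dashed shortcut, but only to change the depth from 64 to 256 (stride 1);

(4) The final features are aggregated by average pooling and then passed to the fully connected layer;

(5) Every convolution is followed by a BN layer, and dropout is not used.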
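The layer count in each network's name can be recovered from its configuration. The sketch below assumes the standard per-stage block counts from the paper's configuration table (ResNet-34/50/101 match the lists used in model.py; ResNet-18 and ResNet-152 are not defined there) and counts only weighted layers on the main path: the 7×7 stem convolution, the convolutions inside the blocks, and the final fully connected layer. `CONFIGS` and `weighted_layers` are names made up for this example.

```python
# Blocks per stage (conv2_x..conv5_x) and conv layers per block
CONFIGS = {
    "resnet18": ([2, 2, 2, 2], 2),    # BasicBlock: two 3x3 convs
    "resnet34": ([3, 4, 6, 3], 2),
    "resnet50": ([3, 4, 6, 3], 3),    # Bottleneck: 1x1 + 3x3 + 1x1
    "resnet101": ([3, 4, 23, 3], 3),
    "resnet152": ([3, 8, 36, 3], 3),
}


def weighted_layers(blocks_per_stage, convs_per_block):
    # +1 for the 7x7 stem conv, +1 for the final fully connected layer;
    # 1x1 convs on dashed shortcuts are not counted, following the paper's naming
    return sum(blocks_per_stage) * convs_per_block + 2


for name, (blocks, convs) in CONFIGS.items():
    print(name, weighted_layers(blocks, convs))
```

ResNet-34 and ResNet-50 share the same block counts; only the block type (and therefore the number of convolutions per block) differs.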

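The ResNet structure is easy to modify and extend: by adjusting the number of channels inside the residual blocks and the number of stacked blocks, you can change the width and depth of the network and obtain networks with different representational capacity, without worrying much about "degradation". So, given enough training data, deepening the network step by step generally yields better performance.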

Results on the CIFAR-10 dataset:

[Figure: results on CIFAR-10]

2.5 Transfer Learning

[Figure: transfer learning]

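Transfer learning starts from a network that has already been trained and modifies the relevant layers to suit your own classification task. When hardware and resources are limited, this greatly improves efficiency. VGG, for example, has a huge number of parameters and is slow to train from scratch, whereas with transfer learning a new model can be trained quickly.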

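Common transfer learning approaches for CNNs:

(1) Load the weights, then train all parameters;

(2) Load the weights, then train only the last few layers;

(3) Load the weights, add one more fully connected layer on top of the original network, and train only that last fully connected layer.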
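The three approaches differ only in which parameters receive gradients. Below is a minimal sketch on a tiny hypothetical stand-in network (not the author's code; in train.py the pretrained model is resnet34 with downloaded weights) that counts how many parameters each approach would actually train.

```python
import torch
import torch.nn as nn

# Tiny hypothetical stand-in for a pretrained network
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # feature extractor
    nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
)
old_head = nn.Linear(8, 1000)  # original 1000-class classifier


def count_trainable(module: nn.Module) -> int:
    """Number of parameters the optimizer would update."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# (1) Train all parameters: nothing is frozen.
net1 = nn.Sequential(backbone, old_head)
all_params = count_trainable(net1)

# Freeze the pretrained parts for approaches (2) and (3).
for p in net1.parameters():
    p.requires_grad = False

# (2) Train only the last layer: replace the head with a new 5-class layer
#     (newly created layers have requires_grad=True by default).
net2 = nn.Sequential(backbone, nn.Linear(8, 5))
head_only = count_trainable(net2)

# (3) Keep the old 1000-class output and append one more FC layer on top of it.
net3 = nn.Sequential(backbone, old_head, nn.Linear(1000, 5))
extra_fc_only = count_trainable(net3)

# Hand only the trainable parameters to the optimizer, as train.py does.
optimizer = torch.optim.Adam([p for p in net3.parameters() if p.requires_grad], lr=1e-4)
print(all_params, head_only, extra_fc_only)  # 9224 45 5005
```

In train.py, approach (1) is what runs by default. Uncommenting the `param.requires_grad = False` loop, which sits before `net.fc` is replaced, turns it into approach (2), and the existing `params = [p for p in net.parameters() if p.requires_grad]` line then passes only the new head to Adam.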

[Figure: adding a fully connected layer after VGG's 1000-class output]


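As shown in the figure above, VGG outputs 1000 classes. On top of that, one more fully connected layer can be added and trained on its own, while the features extracted by the preceding convolutional layers are reused unchanged.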

3. Demo Implementation

3.1 Dataset

This post uses the flower classification dataset. Download link: http://download.tensorflow.org/example_images/flower_photos.tgz
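The loading code in Section 3.3.3 expects the images under data_set/flower_data/train/<class>/ and data_set/flower_data/val/<class>/. The split into these folders is not shown in this post, so the following is only a hypothetical sketch. It assumes the archive has been extracted to data_set/flower_data/flower_photos (one folder per class) and uses a 10% validation split; `split_dataset`, the paths, the ratio and the seed are all assumptions.

```python
import os
import random
import shutil


def split_dataset(src_root: str, dst_root: str, val_ratio: float = 0.1, seed: int = 0) -> None:
    """Copy src_root/<class>/* into dst_root/train/<class>/ and dst_root/val/<class>/."""
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    classes = sorted(d for d in os.listdir(src_root)
                     if os.path.isdir(os.path.join(src_root, d)))  # skip plain files
    for cls in classes:
        images = sorted(os.listdir(os.path.join(src_root, cls)))
        val_images = set(rng.sample(images, k=int(len(images) * val_ratio)))
        for name in images:
            split = "val" if name in val_images else "train"
            dst_dir = os.path.join(dst_root, split, cls)
            os.makedirs(dst_dir, exist_ok=True)
            shutil.copy(os.path.join(src_root, cls, name), dst_dir)


# Assumed location of the extracted archive; adjust to your own layout.
SRC = os.path.join("data_set", "flower_data", "flower_photos")
if __name__ == "__main__" and os.path.isdir(SRC):
    split_dataset(SRC, os.path.join("data_set", "flower_data"))
```

Because `datasets.ImageFolder` takes its class names from the sub-folder names, the train and val folders must contain the same class folders.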

[Figure: flower dataset]


3.2 model.py

"""
ResNet model
"""


import torch.nn as nn
import torch


"""
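# BasicBlock module
# Residual structure of ResNet18/34: two 3x3 convolutions
"""
class BasicBlock(nn.Module):
    expansion = 1   # ratio of the block's output channels to its first conv's kernel count; 1 = unchanged

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):   # downsample = dashed-line shortcut; **kwargs absorbs groups/width_per_group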
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False
                               )
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False
                               )
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None: # dashed-line shortcut: the input must be downsampled
            identity = self.downsample(x)   # shortcut branch

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

"""
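# Bottleneck module
# Residual structure of ResNet50/101/152: 1x1 + 3x3 + 1x1 convolutions
"""
class Bottleneck(nn.Module):
    """
    Note: in the original paper, on the main branch of a dashed-line residual block,
    the first 1x1 conv has stride 2 and the 3x3 conv has stride 1.
    The official PyTorch implementation instead uses stride 1 for the first 1x1 conv
    and stride 2 for the 3x3 conv, which improves top-1 accuracy by about 0.5%.
    See ResNet v1.5: https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4   # the third conv has 4x as many kernels as the first and second convs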
    def __init__(self, in_channel, out_channel, stride=1, downsample=None, groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)

        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1
                               )
        self.bn2 = nn.BatchNorm2d(width)

        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)   # shortcut branch

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out

"""
# Residual network
"""
class ResNet(nn.Module):
    # block = BasicBlock or Bottleneck
    # blocks_num: list with the number of residual blocks in each of conv2_x~conv5_x
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True, groups=1, width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64
        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    # channel: number of kernels in the first conv of each residual block
    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        # projection (dashed) shortcut when the size or depth changes; for ResNet50/101/152 block.expansion=4
        if stride != 1 or self.in_channel != channel*block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel*block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel*block.expansion)
            )

        layers =[]
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups = self.groups,
                            width_per_group = self.width_per_group,
                            ))
        self.in_channel = channel*block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups = self.groups,
                                width_per_group = self.width_per_group,
                                ))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x

"""
# resnet34
# https://download.pytorch.org/models/resnet34-333f7ec4.pth
"""
def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

"""
# resnet50
# https://download.pytorch.org/models/resnet50-19c8e357.pth
"""
def resnet50(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

"""
# resnet101
# https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
"""
def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

"""
# resnext50_32x4d
# https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
"""
def resnext50_32x4d(num_classes=1000, include_top=True):
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)

"""
# resnext101_32x8d
# https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
"""
def resnext101_32x8d(num_classes=1000, include_top=True):
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)

"""
Quick model test
"""
if __name__ == '__main__':
    input1 = torch.rand([1, 3, 224, 224])   # one 3-channel 224x224 image: [N, C, H, W]
    model_x = resnet34(num_classes=5, include_top=True)
    print(model_x)
    output = model_x(input1)
    print(output.shape)   # torch.Size([1, 5])

3.3 train.py

3.3.1 Imports

"""
Training (CPU)
"""

import os
import sys
import json
import time
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm   # progress bar

from model import resnet34

3.3.2 Data Preprocessing

data_transform = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),  # random crop, then resize to 224x224
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    "val": transforms.Compose([
        # transforms.Resize((224, 224)),  # tuple (224, 224)
        transforms.Resize(256),      # resize the shorter side to 256
        transforms.CenterCrop(224),  # then take the central 224x224 crop
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
}
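`transforms.ToTensor()` scales pixel values from 0..255 to [0, 1], and `transforms.Normalize(mean, std)` then computes (x − mean) / std for each channel; the mean and std above are the ImageNet per-channel statistics, matching the data the pretrained weights were trained on. A plain-Python sketch of that arithmetic for a single RGB pixel (no torchvision needed; `normalize_pixel` is a hypothetical helper):

```python
MEAN = [0.485, 0.456, 0.406]  # per-channel mean (R, G, B), as passed to Normalize
STD = [0.229, 0.224, 0.225]   # per-channel standard deviation


def normalize_pixel(rgb_uint8):
    """Mimic ToTensor (divide by 255) followed by Normalize ((x - mean) / std) per channel."""
    return [((v / 255.0) - m) / s for v, m, s in zip(rgb_uint8, MEAN, STD)]


print(normalize_pixel((255, 255, 255)))  # white pixel
print(normalize_pixel((0, 0, 0)))        # black pixel
```

A white pixel maps to roughly [2.249, 2.429, 2.640] and a black one to roughly [-2.118, -2.036, -1.804], so the network sees inputs centred around zero instead of in [0, 1].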

3.3.3 Loading the Dataset

3.3.3.1 Data path
# data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # data root path
data_root = os.path.abspath(os.path.join(os.getcwd(), "./"))
image_path = os.path.join(data_root, "data_set", "flower_data")
# image_path = data_root + "/data_set/flower_data/"
assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

Compared with the uploader's tutorial, the data path has been changed here.

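3.3.3.2 Loading the training set
batch_size = 16
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of dataloader workers
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                     transform=data_transform["train"]
                                     )
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=nw
                                           )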
3.3.3.3 Loading the validation set
val_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                       transform=data_transform["val"]
                                       )
val_num = len(val_dataset)
val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=4,
                                             shuffle=False,
                                             num_workers=nw
                                             )
3.3.3.4 Saving the class index mapping
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())   # invert to {index: class name}
json_str = json.dumps(cla_dict, indent=4)
with open("calss_indices.json", 'w') as json_file:
    json_file.write(json_str)
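`datasets.ImageFolder` numbers the classes in sorted order of their folder names, and JSON object keys are always strings. That is why the prediction script looks classes up with `class_indict[str(predict_cla)]`. A small standalone sketch, assuming the five folder names of flower_photos (the hard-coded list is an assumption standing in for `train_dataset.class_to_idx`):

```python
import json

# ImageFolder-style mapping: indices follow the sorted folder names
folders = ["tulips", "daisy", "roses", "dandelion", "sunflowers"]
class_to_idx = {name: i for i, name in enumerate(sorted(folders))}

# Invert to {index: class name}, as train.py does
cla_dict = dict((val, key) for key, val in class_to_idx.items())
json_str = json.dumps(cla_dict, indent=4)

loaded = json.loads(json_str)  # integer keys come back as strings
print(loaded)
print(loaded[str(0)])  # daisy
```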

3.3.4 Training

net = resnet34()  # instantiate the network (1000-class head, matching the pretrained weights)
# load pretrain weights
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-333f7ec4.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
# for param in net.parameters():
#     param.requires_grad = False

# change fc layer structure
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, 5)   # 5 flower classes
# net.to(device)
net.to("cpu")   # run on the CPU explicitly
loss_function = nn.CrossEntropyLoss()   # cross-entropy loss
# construct an optimizer
params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=0.0001)     # optimizer(trainable parameters, learning rate)

epochs = 10     # number of training epochs
save_path = "./ResNet34.pth"
best_accuracy = 0.0
train_steps = len(train_loader)
for epoch in range(epochs):
    net.train()     # training mode: BN uses batch statistics
    running_loss = 0.0
    train_bar = tqdm(train_loader, file=sys.stdout)     # progress bar
    for step, data in enumerate(train_bar):     # iterate over the training set
        images, labels = data   # training images and labels
        optimizer.zero_grad()   # clear gradients from the previous step
        logits = net(images)
        loss = loss_function(logits, labels)   # compute the loss
        loss.backward()     # backpropagation
        optimizer.step()    # update the parameters
        running_loss += loss.item()
        train_bar.desc = "train epoch [{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                  epochs,
                                                                  loss.item()
                                                                  )
    # validation
    net.eval()      # evaluation mode: BN uses its running statistics
    acc = 0.0
    with torch.no_grad():
        val_bar = tqdm(val_loader, file=sys.stdout)
        for val_data in val_bar:
            val_images, val_labels = val_data
            outputs = net(val_images)
            predict_y = torch.max(outputs, dim=1)[1]   # index of the highest score
            acc += torch.eq(predict_y, val_labels).sum().item()
    val_accuracy = acc / val_num
    print("[epoch %d] train_loss: %.3f    val_accuracy: %.3f" %
          (epoch + 1, running_loss / train_steps, val_accuracy))
    if val_accuracy > best_accuracy:    # keep the weights with the best validation accuracy
        best_accuracy = val_accuracy
        torch.save(net.state_dict(), save_path)
print("Finished Training.")

Training progress output:

[Figure: training progress output]

GPU training code: identical to the CPU version except that the model and each batch of data are moved to the selected device.

"""
Training (GPU)
"""
import os
import sys
import json
import time
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import resnet34


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"use device is {device}")

    data_transform = {
        "train": transforms.Compose([
                                    transforms.RandomResizedCrop(224),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        "val": transforms.Compose([
                                    transforms.Resize(256),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
    }
    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # data root path
    data_root = os.path.abspath(os.path.join(os.getcwd(), "./"))
    image_path = os.path.join(data_root, "data_set", "flower_data")
    # image_path = data_root + "/data_set/flower_data/"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"]
                                         )
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open("calss_indices.json", 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of dataloader workers
    nw = 0   # overrides the value above: load data in the main process
    print(f"Using {nw} dataloader workers every process.")

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw
                                               )
    val_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                       transform=data_transform["val"]
                                       )
    val_num = len(val_dataset)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=4,
                                             shuffle=False,
                                             num_workers=nw
                                             )
    print(f"Using {train_num} images for training, {val_num} images for validation.")

    # test_data_iter = iter(val_loader)
    # test_image, test_label = next(test_data_iter)

    """ Preview images from the dataset """
    # def imshow(img):
    #     img = img / 2 + 0.5
    #     np_img = img.numpy()
    #     plt.imshow(np.transpose(np_img, (1, 2, 0)))
    #     plt.show()
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = resnet34()    # instantiate the network
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-333f7ec4.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)  # 5 flower classes
    net.to(device)
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    # optimizer = optim.Adam(net.parameters(), lr=0.0001)
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 10
    save_path = "./ResNet34_GPU.pth"
    best_accuracy = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device)) # compute the loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch [{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                      epochs,
                                                                      loss
                                                                      )
        # validation
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_accuracy = acc / val_num
        print("[epoch %d] train_loss: %.3f    val_accuracy: %.3f" %
              (epoch + 1, running_loss / train_steps, val_accuracy))
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(net.state_dict(), save_path)
    print("Finished Training.")

if __name__ == '__main__':
    main()

3.3.5 Prediction

"""
Prediction
"""

import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import resnet34


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    image_path = "./daisy01.jpg"
    assert os.path.exists(image_path), "file: '{}' does not exist.".format(image_path)
    img = Image.open(image_path).convert("RGB")   # make sure the image has 3 channels
    plt.imshow(img)
    img = data_transform(img)   # [C, H, W]
    img = torch.unsqueeze(img, dim=0)   # add the batch dimension -> [N, C, H, W]
    # print(f"img={img}")
    json_path = "./calss_indices.json"
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, 'r') as f:
        class_indict = json.load(f)

    model = resnet34(num_classes=5).to(device)
    weights_path = "./ResNet34.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()   # remove the batch dimension
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()
        print_res = "class: {}  prob: {:.3}".format(class_indict[str(predict_cla)],
                                                    predict[predict_cla].numpy())
        plt.title(print_res)
        # for i in range(len(predict)):
        #     print("class: {}  prob: {:.3}".format(class_indict[str(i)],
        #                                             predict[i].numpy()))
        plt.show()

if __name__ == '__main__':
    main()

Prediction result:

[Figure: prediction result]

As can be seen, the predicted probability for the class is 0.988, a large improvement over the networks covered earlier.
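The prediction script only reports the single most likely class. To see the runners-up as well, `torch.topk` can be applied to the softmax output. A sketch with made-up logits standing in for `model(img)` (hypothetical values, not results from this model) and the five-class mapping assumed from the flower dataset:

```python
import torch

# Hypothetical index -> class mapping and made-up logits for one image
class_indict = {"0": "daisy", "1": "dandelion", "2": "roses", "3": "sunflowers", "4": "tulips"}
logits = torch.tensor([4.0, 1.0, 0.5, 0.2, -1.0])

probs = torch.softmax(logits, dim=0)          # probabilities over the 5 classes, summing to 1
top_probs, top_idx = torch.topk(probs, k=3)   # the 3 largest, in descending order
for p, i in zip(top_probs.tolist(), top_idx.tolist()):
    print("class: {:<10} prob: {:.3f}".format(class_indict[str(i)], p))
```

Since softmax preserves the ordering of the logits, the top-1 entry here is always the same class that `torch.argmax` picks in predict.py.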