环境:win10+python3.9+pytorch1.8.2

开源工程链接:https://github.com/kuangliu/pytorch-cifar


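1、准备 CIFAR-10 数据(以单张图片文件存放:训练集放在 ./data/cifar-10/train/,测试集放在 ./data/cifar-10/test/;文件名以类别编号开头并用下划线分隔,如 3_xxx,编号 0~9 依次对应 plane、car、bird、cat、deer、dog、frog、horse、ship、truck):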

链接: https://pan.baidu.com/s/1nJOtE2QV4AAA34cnOYU8uQ

提取码:pni8

cifar-10图像分类模型训练(pytorch)_cifar-10

cifar-10图像分类模型训练(pytorch)_pytorch_02

cifar-10图像分类模型训练(pytorch)_图像分类_03


2、修改训练配置(数据路径、超参数、所用网络等),训练脚本如下:

cifar-10图像分类模型训练(pytorch)_onnx模型_04

cifar-10图像分类模型训练(pytorch)_cifar-10_05

cifar-10图像分类模型训练(pytorch)_onnx模型_06

cifar-10图像分类模型训练(pytorch)_onnx模型_07

cifar-10图像分类模型训练(pytorch)_onnx模型_08

cifar-10图像分类模型训练(pytorch)_pytorch_09

'''Train CIFAR10 with PyTorch.'''
'''https://github.com/kuangliu/pytorch-cifar'''
'''https://blog.csdn.net/xu_fu_yong/article/details/92848502?utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromMachineLearnPai2%7Edefault-9.control&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromMachineLearnPai2%7Edefault-9.control'''

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
import torch.optim.lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
import os
import cv2 as cv
import numpy as np

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
max_epoch = 100            # total number of training epochs
test_epoch = 5             # run validation every `test_epoch` epochs
display = 100              # print the training loss every `display` mini-batches
train_batch_size = 64      # mini-batch size of the training loader
val_batch_size = 32        # mini-batch size of the validation loader

# ---------------------------------------------------------------------------
# File locations
# ---------------------------------------------------------------------------
train_data_dir = './data/cifar-10/train/'                  # training images
test_data_dir = './data/cifar-10/test/'                    # test / validation images
save_model_path = './torch_cifar-10_Lenet.pth'             # latest checkpoint, used for resuming
save_best_model_path = './torch_cifar-10_Lenet_best.pth'   # weights of the best-validated model


class MyDataset_Cifar10(Dataset):
    """CIFAR-10 images stored as individual files in a single directory.

    Every file name must start with its integer class index followed by an
    underscore, e.g. ``3_xxx``; the class index is the text before the first
    underscore. Each item is ``(image, label)``: ``image`` is a float32 tensor
    of shape (3, 32, 32) in RGB channel order with values scaled by 1/255,
    and ``label`` is a 0-dim integer tensor.

    Args:
        image_dir: directory containing the image files.

    Raises:
        ValueError: if a file name does not start with ``<digits>_``.
    """

    def __init__(self, image_dir):
        self.root_dir = image_dir
        # Keep regular files only (sub-directories cannot be decoded as images)
        # and sort them so the sample order does not depend on os.listdir().
        self.name_list = sorted(
            name for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )
        # Labels are kept as digit strings and converted in __getitem__.
        self.label_list = []
        for name in self.name_list:
            # partition() splits on the first "_" only, so the remainder of the
            # name may itself contain underscores (the old 2-way unpack raised).
            label, sep, _ = name.partition('_')
            if not sep or not label.isdecimal():
                raise ValueError(
                    'Cannot parse a class label from file name {!r} in {!r}; '
                    'expected "<label>_<name>".'.format(name, image_dir))
            self.label_list.append(label)

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, item):
        # os.path.join works whether or not root_dir ends with a separator.
        path = os.path.join(self.root_dir, self.name_list[item])
        img = cv.imread(path)
        # cv.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError('Could not read image file: {}'.format(path))
        # OpenCV decodes to BGR; convert to RGB float32 scaled by 1/255.
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img = cv.resize(img, (32, 32))
        # HWC -> CHW, the layout PyTorch convolution layers expect.
        img = torch.tensor(img.transpose((2, 0, 1)))
        label = torch.tensor(int(self.label_list[item]))
        return img, label


# Model under training. Other architectures from `models` can be swapped in
# here: AlexNet(), VGG('VGG11'), ResNet18(), MyNet32(), MyNet().
# Module.cuda() returns the module itself, so this both builds and moves it.
net = LeNet().cuda()

# Loss, optimizer and learning-rate schedule.
cross_entropy_loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005,
)
# Multiply the learning rate by 0.9 every 50 calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)

# Datasets and their loaders: training batches are reshuffled, validation is not.
train_set = MyDataset_Cifar10(train_data_dir)
val_set = MyDataset_Cifar10(test_data_dir)
trainloader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
valloader = DataLoader(val_set, batch_size=val_batch_size, shuffle=False)


def _snapshot_state_dict(model):
    """Return an independent copy of ``model.state_dict()``.

    ``Module.state_dict()`` returns references to the live parameter and
    buffer tensors, so keeping it without copying would silently follow every
    later optimizer step instead of freezing the weights.
    """
    import copy
    return copy.deepcopy(model.state_dict())


def _count_correct(model, loader):
    """Return how many samples in ``loader`` ``model`` classifies correctly.

    Puts ``model`` in eval mode, runs without gradient tracking, and moves each
    batch to the device holding the model's parameters.
    """
    device = next(model.parameters()).device
    model.eval()
    correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels.to(device)).sum().item()
    return correct


def train(start_epoch):
    """Train the global ``net`` for epochs ``start_epoch + 1`` .. ``max_epoch``.

    Every ``test_epoch`` epochs the model is evaluated on ``valloader``. At the
    end of each epoch the LR scheduler is stepped, then a resumable checkpoint
    (model, optimizer, scheduler, epoch) is written to ``save_model_path`` and
    the weights with the best validation accuracy seen in this call are
    written to ``save_best_model_path``.

    Args:
        start_epoch: last epoch already completed (0 to start from scratch).

    Note:
        ``best_acc`` starts at 0.0 on every call, so after resuming, the
        best-model file no longer reflects accuracy reached before the restart.
    """
    # Follow wherever the model was placed instead of hard-coding .cuda().
    device = next(net.parameters()).device
    # Copy, do not alias: a bare state_dict() would track later updates.
    best_model = _snapshot_state_dict(net)
    best_acc = 0.0
    # +1 so that exactly `max_epoch` epochs run when starting from 0.
    for e in range(start_epoch + 1, max_epoch + 1):
        print('Epoch {}/{}'.format(e, max_epoch))
        print('-' * 10)

        net.train()
        for i, (inputs, labels) in enumerate(trainloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = cross_entropy_loss(outputs, labels)
            loss.backward()
            optimizer.step()

            if i % display == 0:
                print('{} train loss:{} learning rate:{}'.
                      format(i * train_batch_size, loss.item(), optimizer.param_groups[0]['lr']))

        if e % test_epoch == 0:
            print('testing...')
            correct = _count_correct(net, valloader)
            # Guard against an empty validation directory.
            acc = correct / len(val_set) if len(val_set) else 0.0
            print('val acc:{}'.format(acc))

            if acc > best_acc:
                best_acc = acc
                best_model = _snapshot_state_dict(net)

        # Step the scheduler BEFORE checkpointing so the saved optimizer LR and
        # scheduler counter both describe the state at the start of epoch e + 1.
        scheduler.step()
        state = {'model': net.state_dict(), 'optimizer': optimizer.state_dict(),
                 'scheduler': scheduler.state_dict(), 'epoch': e}
        torch.save(state, save_model_path)
        torch.save(best_model, save_best_model_path)


def test(image_path='D:/cat.jpg'):
    """Classify one image with the latest checkpoint and print the class name.

    Args:
        image_path: image file to classify (defaults to the original
            hard-coded ``'D:/cat.jpg'``).

    Returns:
        The predicted class name, which is also printed.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    img = cv.imread(image_path)
    # cv.imread returns None (it does not raise) when the file cannot be read.
    if img is None:
        raise FileNotFoundError('Could not read image file: {}'.format(image_path))
    # Same preprocessing as the training dataset: RGB, float32 / 255, 32x32, CHW.
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB).astype(np.float32) / 255.0
    img = cv.resize(img, (32, 32))
    img = torch.tensor(img.transpose((2, 0, 1))).unsqueeze(0)  # add batch dim

    # Run on whatever device the model lives on.
    device = next(net.parameters()).device
    checkpoint = torch.load(save_model_path, map_location=device)
    net.load_state_dict(checkpoint['model'])
    net.eval()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = net(img.to(device))
    _, preds = torch.max(outputs, 1)
    predicted = classes[preds.item()]
    print(predicted)
    return predicted


if __name__ == '__main__':

    # Resume from the saved checkpoint if one exists; otherwise train from scratch.
    if os.path.exists(save_model_path):
        checkpoint = torch.load(save_model_path)
        net.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        if 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            # Checkpoints without scheduler state: the scheduler is stepped once
            # per completed epoch, so realign its counter with the epoch count
            # to keep the StepLR decay points where they would have been.
            scheduler.last_epoch = start_epoch
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存模型,将从头开始训练!')

    train(start_epoch)

    #test()


3、训练好模型后,将其转换为 ONNX 模型:

cifar-10图像分类模型训练(pytorch)_图像分类_10

import torch
from models import *

# Rebuild the network and load the best weights written by the training script.
model = LeNet()
# The training script saves weights from a model moved to CUDA; map them to
# the CPU so this export also works on machines without a GPU (the model
# here stays on the CPU).
checkpoint = torch.load("./torch_cifar-10_Lenet_best.pth", map_location="cpu")
model.load_state_dict(checkpoint)
model.eval()  # export inference-mode behaviour

# Example input used to trace the graph: (batch, channels, height, width).
batch_size = 1
dummy_input = torch.randn(batch_size, 3, 32, 32)

input_names = ["input"]
output_names = ["output"]
# Mark dimension 0 as dynamic so the exported model accepts any batch size.
dynamic_axes = {"input": {0: 'batch'}, "output": {0: 'batch'}}
torch.onnx.export(model, dummy_input, "./torch_cifar-10_Lenet_best.onnx",
                  verbose=True, input_names=input_names, output_names=output_names,
                  dynamic_axes=dynamic_axes,
                  opset_version=11)