环境:Windows 10 + Python 3.9 + PyTorch 1.8.2
开源工程链接:https://github.com/kuangliu/pytorch-cifar
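1、准备 CIFAR-10 数据(训练图片放在 ./data/cifar-10/train/,测试图片放在 ./data/cifar-10/test/;文件名格式为“类别编号_名称”,例如 3_0001.png,训练脚本会从文件名中解析出类别标签):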
链接: https://pan.baidu.com/s/1nJOtE2QV4AAA34cnOYU8uQ
提取码:pni8
2、编写训练脚本并设置训练参数:
'''Train CIFAR10 with PyTorch.'''
'''https://github.com/kuangliu/pytorch-cifar'''
'''https://blog.csdn.net/xu_fu_yong/article/details/92848502?utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromMachineLearnPai2%7Edefault-9.control&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromMachineLearnPai2%7Edefault-9.control'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from models import *
import torch.optim.lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
import os
import cv2 as cv
import numpy as np
# Training parameters
max_epoch = 100  # number of training epochs
test_epoch = 5  # run validation on val_set every `test_epoch` epochs
display = 100  # print the training loss every `display` batches
train_batch_size = 64  # batch size of the training DataLoader
val_batch_size = 32  # batch size of the validation DataLoader
# Directory containing the training images
train_data_dir = './data/cifar-10/train/'
# Directory containing the test images (used as the validation set)
test_data_dir = './data/cifar-10/test/'
# Latest checkpoint (model + optimizer state + epoch); used to resume training
save_model_path = './torch_cifar-10_Lenet.pth'
# Weights (bare state_dict) of the model with the best validation accuracy
save_best_model_path = './torch_cifar-10_Lenet_best.pth'
class MyDataset_Cifar10(Dataset):
    """CIFAR-10 image-folder dataset whose labels are encoded in file names.

    Every file directly inside ``image_dir`` must be named
    ``<label>_<anything>``, where ``<label>`` is the decimal class index
    (e.g. ``3_0001.png``). Only the first ``_`` separates the label, so the
    remainder of the name may contain further underscores.

    Args:
        image_dir: Directory that directly contains the image files.

    Raises:
        ValueError: If a file name does not start with ``<digits>_``.
    """

    def __init__(self, image_dir):
        self.root_dir = image_dir
        # Sort so the sample order does not depend on the file system's
        # directory-listing order.
        self.name_list = sorted(os.listdir(image_dir))
        self.label_list = []
        for name in self.name_list:
            label, sep, _ = name.partition('_')
            # Fail fast here instead of on the first __getitem__ call.
            if not sep or not label.isdecimal():
                raise ValueError(
                    'expected file name "<label>_<name>", got {!r}'.format(name))
            # Kept as a string for compatibility; converted in __getitem__.
            self.label_list.append(label)

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, item):
        """Return ``(image, label)`` for sample index ``item``.

        ``image`` is a float32 tensor in CHW layout with shape (3, 32, 32),
        RGB channel order, scaled to [0, 1]; ``label`` is a 0-dim int64
        tensor.

        Raises:
            IOError: If OpenCV cannot read the image file.
        """
        path = os.path.join(self.root_dir, self.name_list[item])
        img = cv.imread(path)
        if img is None:
            # cv.imread signals failure by returning None instead of raising.
            raise IOError('failed to read image: {}'.format(path))
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img = cv.resize(img, (32, 32))
        img = img.transpose((2, 0, 1))  # HWC -> CHW, the layout nn.Conv2d expects
        return torch.tensor(img), torch.tensor(int(self.label_list[item]))
# Model selection: keep exactly one assignment active. The classes are
# presumably provided by the project's `models` package (`from models import *`).
#net = AlexNet()
#net = VGG('VGG11')
#net = ResNet18()
net = LeNet()
#net = MyNet32()
#net = MyNet()
# Use the GPU when available, otherwise fall back to the CPU instead of
# crashing inside net.cuda() on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)
cross_entropy_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
# Multiply the learning rate by 0.9 once every 50 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)
# Load the data: training images are shuffled; validation order is fixed.
train_set = MyDataset_Cifar10(train_data_dir)
val_set = MyDataset_Cifar10(test_data_dir)
trainloader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
valloader = DataLoader(val_set, batch_size=val_batch_size, shuffle=False)
def _train_one_epoch(device):
    """Run one optimisation pass of ``net`` over ``trainloader``."""
    net.train()
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = cross_entropy_loss(net(inputs), labels)
        loss.backward()
        optimizer.step()
        if i % display == 0:
            print('{} train loss:{} learning rate:{}'.format(
                i * train_batch_size, loss.item(), optimizer.param_groups[0]['lr']))


def _evaluate(device):
    """Return the top-1 accuracy of ``net`` on ``valloader`` as a float."""
    if len(val_set) == 0:
        return 0.0
    net.eval()
    correct = 0
    with torch.no_grad():
        for inputs, labels in valloader:
            preds = net(inputs.to(device)).argmax(dim=1)
            correct += (preds == labels.to(device)).sum().item()
    return correct / len(val_set)


def train(start_epoch):
    """Train ``net`` for epochs ``start_epoch + 1`` .. ``max_epoch`` (inclusive).

    Args:
        start_epoch: Number of epochs already completed: 0 for a fresh run,
            or the ``'epoch'`` value stored in a resumed checkpoint.

    After every epoch a resumable checkpoint (model, optimizer, scheduler,
    epoch) is written to ``save_model_path``, and the weights with the best
    validation accuracy so far are written to ``save_best_model_path``.
    Validation runs every ``test_epoch`` epochs.

    NOTE: ``best_acc`` restarts at 0.0 on every call, so after resuming, the
    first evaluation overwrites the best-model file even if it is worse.
    """
    import copy  # only needed to snapshot the best weights

    # Follow whatever device the network lives on instead of hard-coding .cuda().
    device = next(net.parameters()).device
    # state_dict() returns references to the live parameter tensors; without
    # a deep copy the "best" snapshot would silently track the latest weights.
    best_model = copy.deepcopy(net.state_dict())
    best_acc = 0.0
    # Epochs are 1-based; the "+ 1" makes epoch max_epoch itself run.
    for e in range(start_epoch + 1, max_epoch + 1):
        print('Epoch {}/{}'.format(e, max_epoch))
        print('-' * 10)
        _train_one_epoch(device)
        if e % test_epoch == 0:
            print('testing...')
            acc = _evaluate(device)
            print('val acc:{}'.format(acc))
            if acc > best_acc:
                best_acc = acc
                best_model = copy.deepcopy(net.state_dict())
        # Step before checkpointing so the saved optimizer LR and scheduler
        # state are the ones epoch e + 1 should start with.
        scheduler.step()
        state = {'model': net.state_dict(), 'optimizer': optimizer.state_dict(),
                 'scheduler': scheduler.state_dict(), 'epoch': e}
        torch.save(state, save_model_path)
        torch.save(best_model, save_best_model_path)
def test(image_path='D:/cat.jpg'):
    """Classify one image with the latest checkpoint and print its class name.

    Args:
        image_path: Image file to classify (defaults to the original
            hard-coded 'D:/cat.jpg').

    Raises:
        IOError: If OpenCV cannot read ``image_path``.
    """
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    img = cv.imread(image_path)
    if img is None:
        # cv.imread returns None on failure instead of raising.
        raise IOError('failed to read image: {}'.format(image_path))
    # Same preprocessing as the training dataset: RGB, [0, 1], 32x32, CHW.
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB).astype(np.float32) / 255.0
    img = cv.resize(img, (32, 32))
    img = img.transpose((2, 0, 1))
    device = next(net.parameters()).device
    img = torch.tensor(img).unsqueeze(0).to(device)  # add the batch dimension
    checkpoint = torch.load(save_model_path, map_location=device)
    net.load_state_dict(checkpoint['model'])
    net.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        outputs = net(img)
    print(classes[outputs.argmax(dim=1).item()])
if __name__ == '__main__':
    # If a saved checkpoint exists, load it and continue training from it.
    if os.path.exists(save_model_path):
        # map_location lets a checkpoint saved on a GPU load on any device.
        checkpoint = torch.load(save_model_path,
                                map_location=next(net.parameters()).device)
        net.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        if 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            # Older checkpoints carry no scheduler state; align the step
            # counter with the completed epochs so LR decays do not restart
            # from zero after resuming.
            scheduler.last_epoch = start_epoch
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存模型,将从头开始训练!')
    train(start_epoch)
    #test()  # uncomment (and comment out train above) for single-image inference
3、模型训练完成后,将其转换为 ONNX 模型:
import torch
from models import *
# Export the best LeNet weights to ONNX with a dynamic batch dimension.
model = LeNet()
# The *_best.pth file produced by training holds a bare state_dict. Loading
# onto the CPU lets a checkpoint saved from a CUDA model load on any machine;
# the export below also runs on the CPU.
checkpoint = torch.load("./torch_cifar-10_Lenet_best.pth", map_location="cpu")
model.load_state_dict(checkpoint)
# Trace in inference mode (matters for layers like dropout/batch-norm, if any).
model.eval()
batch_size = 1
# Example input used for tracing: (batch, channels, height, width) = (N, 3, 32, 32).
dummy_input = torch.randn(batch_size, 3, 32, 32)
input_names = ["input"]
output_names = ["output"]
# Mark axis 0 of both tensors as dynamic so the ONNX model accepts any batch size.
dynamic_axes = {"input": {0: 'batch'}, "output": {0: 'batch'}}
torch.onnx.export(model, dummy_input, "./torch_cifar-10_Lenet_best.onnx",
                  verbose=True, input_names=input_names, output_names=output_names,
                  dynamic_axes=dynamic_axes,
                  opset_version=11)