ResNet残差网络Pytorch实现——cifar10数据集训练


上一篇:​​【对花的种类进行批量数据预测】​​ ✌✌✌✌ ​​【目录】​​ ✌✌✌✌ 下一篇:​​【cifar10数据集的预测】​​


大学生一枚,最近在学习神经网络,写这篇文章只是记录自己的学习历程,本文参考了​​Github上fengdu78老师的文章​​进行学习


✌ 使用ResNet进行对cifar10数据集进行训练

import torchvision
import torch
from torchvision import transforms
import os
import json

import torch.nn as nn
import torch.optim as optim
from torchvision import transforms,datasets
from tqdm import tqdm

# 加载运算设备
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 数据处理
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# 每个批次的数据大小
batch_size=100

# 加载训练数据,不需要下载
train_dataset = torchvision.datasets.CIFAR10(root='./cifar10',
train=True,
download=False,
transform=data_transform)

# 训练数据的加载器
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True)


# 训练数据大小
train_num=len(train_dataset)

print('using {} images for training.'.format(train_num))

# 预测结果与真实分类的映射
cifar10_list=train_dataset.class_to_idx
cla_dict=dict((value,key) for key,value in cifar10_list.items())
json_str=json.dumps(cla_dict,indent=10)
with open('class_indices.json','w') as json_file:
json_file.write(json_str)

# 构建网络
net=resnet34()

# 加载模型参数
model_weight_path='./resnet34-pre.pth'
net.load_state_dict(torch.load(model_weight_path,map_location=device))

# 将每个参数置为False,反向传播时不会进行梯度更新
for param in net.parameters():
param.requires_grad=False

# 修改全连接层
in_channel=net.fc.in_features
net.fc=nn.Linear(in_channel,10)
net.to(device)

# 交叉熵损失函数
loss_function=nn.CrossEntropyLoss()

# 获得需要训练的参数
params=[p for p in net.parameters() if p.requires_grad]

# 优化器
optimizer=optim.Adam(params,lr=0.0001)

epochs=1
loss_sum=999
save_path='./resNet34_cifar10.pth'
train_steps=len(train_loader)

# 开始训练,所有数据只训练1次
for epoch in range(epochs):
net.train()
running_loss=0
train_bar=tqdm(train_loader)

# 训练集总共50000张图片,我设置的每批数据是100,所以对应是500*100
# 循环500次,每次训练的数据为100张
for data in train_bar:
images,labels=data
optimizer.zero_grad()
output=net(images.to(device))
loss=loss_function(output,labels.to(device))
loss.backward()
optimizer.step()

running_loss+=loss.item()

train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
epochs,
loss)
# 保存最好的模型参数
if running_loss/train_steps<loss_sum:
loss_sum=running_loss/train_steps
torch.save(net.state_dict(),save_path)