Pytorch 实现自己的残差网络图片分类器
本文主要讨论网络的代码实现,不对其原理进行深究。
一、项目模块包导入。
import sys
import getopt
import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torchvision.models as models
torch和torchvision两个包可能无法运用pip或者conda进行下载,需要从其他资源网站查找下载。
二、模型导入处理:
model = models.resnet18(pretrained=True)
在models包中封装了当前比较流行的大多数神经网络模型,参数pretrained=True表示导入模型的预训练模型,False表示仅导入网络结构而不包含参数。
fc_features = model.fc.in_features
model.fc = nn.Linear(fc_features, 102)
获取网络的参数信息,这里我们选取了网络的全连接层的参数信息,将其分类结果的数量设置为我们的问题标签数量,这里我们对102种花进行分类,所以设置的数值为102。
三、网络训练参数设置
inputfile = ''
savefile = ''
testfile=''
try:
opts, args = getopt.getopt(argv,'hi:t:s:',["ifile =","tfile =""sfile ="])
except getopt.GetoptError:
print("train.py -i <input_data_file> -t <test_data_file> -s <save_model_file>")
sys.exit(2)
for opt,arg in opts:
if opt =="-h":
print("train.py -i <input_data_file> -t <test_data_file> -s <save_model_file>")
sys.exit()
elif opt in ("-i","--ifile"):
inputfile = arg
elif opt in("-s","--sfile"):
savefile = arg
elif opt in("-t","--tfile"):
testfile = arg
print("the input datasets is :",inputfile)
print("the test datasets is :",testfile)
print("the save model file is :",savefile)
# parameters
arch = 'resnet18'
lr = 0.05
momentum = 0.9
weight_decay = 1e-4
resume = ''
epochs = 30
start_epoch = 0
evaluate = 0
best_prec1 = 0
print_freq = 10
inputfile,testfile,savefile分别对应训练数据集根目录,测试数据集根目录和模型保存目录,定义在main函数中通过cmd运行命令输入。lr代表learning rate表示学习率,学习率大,网络收敛速度快,但可能收敛效果不理想,后期网络效果起伏大。学习率小,收敛速度慢,后期网络效果起伏较小。epochs表示学习周期,训练所有的图片一次及为一个周期。有些参数是官网demo中的参数,这里可能不会用到,如果有其他需要,可以在本文的基础上阅读此链接中的demo代码:pytorch ResNet
四、图片数据导入
# data preparing
train_dir = inputfile
valid_dir = testfile
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
batch_size=10
data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])
# TODO: Load the datasets with ImageFolder
train_datasets = dsets.ImageFolder(train_dir,
data_transforms)
# TODO: Using the image datasets and the trainforms, define the dataloaders
trainloader = torch.utils.data.DataLoader(dataset = train_datasets,
batch_size=batch_size,
shuffle=True,
drop_last=False,
pin_memory=True)
valid_datasets = dsets.ImageFolder(valid_dir,
data_transforms)
validloader = torch.utils.data.DataLoader(dataset = valid_datasets,
batch_size=batch_size,
shuffle=True,
drop_last=False,
pin_memory=True)
batch_size是指一次性放入内存中训练的图片数量,如果放入数量太多,可能会引起程序奔溃。数量太少训练的速度会受到影响。这段代码中将训练路径中的图片路径和标签读入整合成DataLoader类型可以更加方便的使用模型进行训练,如果有需要使用自己的具体图片,其格式可以参考本文后的对单张图片进行格式更改板块的内容。
data_transforms包装了对于图片的一系列操作,Resize将图片更改为指定大小,Centercrop截取中心部位指定大小的图片,Totensro将图片转化为tensor格式,normalize将图片归一化。这里我将图片更改为大小为224*224的大小。
trainloader的使用中有一点需要注意,pytorch可以直接通过ImageFolder读取数据文件根目录,并通过其中的子目录将子目录中的图片归为一类,但是会为其自动分配一个数据标签,在后面的训练和预测过程中所使用的也是其自主生成的标签。
五、设置损失函数和分类器
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr,
momentum=momentum,
weight_decay=weight_decay)
cudnn.benchmark = True
六、训练网络
def train(train_loader, model, criterion, optimizer, epoch):
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
print_freq = 10
# switch to train mode
model.train()
for i, (input, target) in enumerate(train_loader):
# measure data loading time
target = target.cuda(non_blocking=True)
# compute output
output = model(input)
loss = criterion(output, target)
# measure accuracy and record loss
prec1, prec5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(prec1[0], input.size(0))
top5.update(prec5[0], input.size(0))
# compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
optimizer.step()
# measure elapsed time
if i % print_freq == 0:
print('Epoch: [{0}][{1}/{2}]\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.avg:.3f}\t'
'Prec@5 {top5.avg:.3f}'.format(
epoch, i, len(train_loader),
loss=losses, top1=top1, top5=top5))
# training
for epoch in range(start_epoch, epochs):
# train for one epoch
train(trainloader, model, criterion, optimizer, epoch)
prec1 = validate(validloader, model, criterion)
可输出每个周期的训练效果,返回训练结果。
七、测试网络
def validate(val_loader, model, criterion):
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
print_freq = 10
# switch to evaluate mode
model.eval()
with torch.no_grad():
for i, (input, target) in enumerate(val_loader):
target = target.cuda(non_blocking=True)
# compute output
output = model(input)
loss = criterion(output, target)
# measure accuracy and record loss
prec1, prec5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(prec1[0], input.size(0))
top5.update(prec5[0], input.size(0))
# measure elapsed time
if i % print_freq == 0:
print('Test: [{0}/{1}]\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.avg:.3f}\t'
'Prec@5 {top5.avg:.3f}'.format(
i, len(val_loader), loss=losses,
top1=top1, top5=top5))
print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))
return top1.avg
调用测试集对网络效果进行测试,显示top1和top5的平均正确率,返回top1平均正确率。
八、保存模型
def save_checkpoint(model, filename):
torch.save(model, filename)
因为每次训练模型的耗时较长,最好的办法是每次训练结束后把模型保存下来。
九、读取模型
new_model = torch.load(inputfile)
inputfile是模型的存放地址。
到此我们的网络就基本搭建完成了。下面我们来看看如何用一张图片进行预测。
一、包导入
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import glob, os
import sys
import getopt
import torch
二、参数输入
def main(argv):
inputfile = ''
topk = 1
try:
opts, args = getopt.getopt(argv,'hi:p:t:',["imfile =","ipfile =","topk ="])
except getopt.GetoptError:
print("train.py -i <input_model_file> -p <imput_picture_file> -t <top k most likely classes>")
sys.exit(2)
for opt,arg in opts:
if opt =="-h":
print("train.py -i <input_model_file> -p <imput_picture_file> -t <top k most likely classes>")
sys.exit()
elif opt in ("-i","--imfile"):
inputfile = arg
elif opt in("-p","--ipfile"):
picfile = arg
elif opt in("-t","--topk"):
topk = int(arg)
print("the input model is :",inputfile)
print("the input picture file is :",picfile)
print("the top k is", str(topk))
三、模型导入
new_model = torch.load(inputfile)
四、图片预处理
def process_image(image):
''' Scales, crops, and normalizes a PIL image for a PyTorch model,
returns an Numpy array
'''
im = Image.open(image)
im = im.resize((256,256))
im = im.crop((16,16,240,240))
# TODO: Process a PIL image for use in a PyTorch modela
np_image = np.array(im)
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
m = np.mean(np_image)
mx = np.max(np_image)
mn = np.min(np_image)
np_image = (np_image - mn) / (mx - mn)
np_image = (np_image - mean) / std
np_image = np_image.transpose(2,0,1)
return torch.from_numpy(np_image)
因为pytorch中的图片格式为颜色通道为1号位,而我们读取的图片矩阵的颜色通道为3号位,所以我们这里需要使用
np_image = np_image.transpose(2,0,1)
转换一下,将3号位的数据移到1号位,保持另外两位顺序不变。
五、图片增维
def predict(image_path, model, topk):
''' Predict the class (or classes) of an image using a trained deep learning model.'''
myinput = process_image(image_path)
myinput = torch.unsqueeze(myinput,0)
myinput = myinput.float()
output = model(myinput)
return output.topk(topk)
还记得我们在训练图片的时候使用的batch_size参数么,我们使用mode进行训练时,使用的其实不是单张图片,而是一个图片集,所以输入应为4维矩阵,这里我们使用1张图片进行训练,所以说需要使用
myinput = torch.unsqueeze(myinput,0)
增加图片数量维度。
六、图片预测
plt.subplot(311)
im = Image.open(picfile)
plt.imshow(im)
label = picfile.split("\\")[1]
plt.title(cat_to_name[str(label)])
plt.subplot(313)
probs, classes = predict(picfile, new_model,topk)
classes = classes.tolist()[0]
probs = probs.tolist()[0]
flowers = []
problity = []
num = []
i = 1
for flowerclass in classes:
flowers.append(cat_to_name[listtarget[flowerclass]])
num.append(i)
i += 1
for prob in probs:
problity.append(prob)
plt.bar(num,problity,facecolor = 'blue', edgecolor = 'white')#调用plot函数,这并不会立刻显示函数图像
temp = zip(num,problity)
i = 0
for x, y in temp:
plt.text(x ,y + 2,flowers[i], ha = 'center', va = 'bottom')
i += 1
# 去除坐标轴
plt.ylabel("possibility")
plt.xlabel("varieties")
plt.xticks(())
plt.show()#调用show函数显示函数图像
最后输出预测结果,检查是否正确。这里特别提一下,我所使用的数据集是一个包含102种花的图片数据集,存放在根目录下按照1~102编号102个文件夹中,但是文件编号和标签输出并不对应,大家要注意这个地方。
项目所涉及的train.py和predict.py文件可以在这里下载程序完整代码
或者我在我的百度云盘中进行下载:https://pan.baidu.com/s/13OMVbIMij1SGVM8jRfk7xA 提取码:k998
项目所需的数据文件并没有上传,如果对于数据导入或者标签输出方面有疑问的小伙伴可以在评论区提出,文中有说的不对的地方也请大家指正。
另外提一下,本文中的地址输入是使用命令行输入参数的形式,所以会有同学发现最后找不到地址。