实现步骤
在深度学习框架下,搭建一个网络进行训练和测试很方便,主要包括四个步骤:
1.准备数据,将数据集中的数据整理成程序代码可识别读取的形式
2.搭建网络,利用PyTorch提供的API搭建设计的网络或者自定义
3.训练网络,把1中准备好的数据送入2中搭建的网络中进行训练,获得网络各节点权值参数
4.测试网络,导入3中获取的参数,并输入网络一个数据,然后评估网络的输出结果
一.准备数据
一.1:总目标:在准备数据阶段,需要做的就是把训练集中所有的数据整理成[输入, 自定义输出格式]的形式,
在Dogs vs. Cats中,输入为一张张猫狗图片(input),给定输出是对应的猫或者狗信息 (label)。
训练时候(train 代码运行)调用这个类方法获取数据
一.2:主要步骤:数据的准备过程为自定义一个数据集类,需要继承PyTorch的data.Dataset父类,
需要写三个函数 构造方法:__init__(self,) 必须要重载两个父类的__getitem__()和__len__(self)
2.1 __init__(self,) 作用是写出 任何需要的初始化,有两个模块 一个用于训练读取,另一个是用于测试读取
如果是用于训练集读取图片 需要 提取图片的路径和定义标签
如果是用于测试集读取图片 需要 提取图片路径就行
2.2 __getitem__(self, ) 作用:获取数据集中数据内容
2.3 __len__(self) 作用:返回数据集大小
一.3:细节知识
3.1:因为默认的PyTorch数据集读取就是通过__getitem__()方法来实现的,
在这个方法中需要返回PyTorch的Tensor形式数据。
需要定义一个转换关系(需要在__init__(self,)方法中定义声明),用于将图像数据转换成PyTorch的Tensor形式
二.搭建网络
二.1:总目标:构造CNN网络
二.2: 主要步骤:新建的网络 (Net类) 必须要继承PyTorch的nn.Module父类,
需要写两个函数 构造方法:__init__(self,) 必须要重载一个父类的forward()方法
1. __init__(self,)中新建CNN中的各个计算层,包括卷积层,池化层,全连层和激活函数等。
2. forward()方法,该方法就是网络运算中的前向计算
二.3:细节知识:需要了解CNN基本框架
3.1:导入的包,函数作用
import torch.nn as nn: 在PyTorch框架下编写是神经网络,所有层结构和损失函数都是来自torch.nn
import torch.nn.functional as F: 调用卷积,池化函数 除了最基本的图片转成Tensor形式,还有一些图片预处理(图像增强方法)可以提高模型网络的准确率和泛化能力,
都在torchvision.transforms函数里面 有很多对图片预处理方法
1.transforms.ToTensor() # 将图像数据转换成Tensor形式
2.transforms.Resize(IMAGE_SIZE) # 将图像按比例缩放至合适尺寸(IMAGE_SIZE)
3.transforms.CenterCrop((IMAGE_SIZE, IMAGE_SIZE)) # 从图像中心裁剪合适大小的图像
3.2:标签:label采用one-hot编码,[1,0]表示猫,[0,1]表示狗,任何情况只有一个位置为"1",(**学习**)
在采用CrossEntropyLoss()计算Loss情况下,label只需要输入"1"的索引(找“1”的位置),即猫应输入0,狗应输入1
3.3:导入的包,函数作用
1.import os:含有操作系统相关的调用和操作函数,简单说就是文件操作,创建读取文件夹等
os.mkdir() #mkdir()函数——创建一个新文件夹
os.listdir(dir): #listdir()函数——返回指定目录下的所有目录和文件
2.import torch: 是PyTorch包
3.import torch.utils.data as data:实现自由的数据读取
torch.utils.data主要包括以下三个类但是pytorch读取训练集需要使用到2个类:
(1)torch.utils.data.Dataset:作用: (1) 创建数据集,有__getitem__(self, index)函数来根据索引序号获取图片和标签, 有__len__(self)函数来获取数据集的长度.
(2)torch.utils.data.DataLoader: 作用于第三部分训练网络train.py代码,此部分不需要
4.from PIL import Image:图像处理标准库
5.import numpy as np:图片变成矩阵数据类型(用于__getitem__方法里 测试部分 代码中)
6.import torchvision.transforms as transforms:是图片预处理
三.训练网络
三.1:总目标:将图片传入模型,得到训练后的参数保存
三.2: 主要步骤:
#实例化一个网络: model = Net()
#网络设定为训练模式: model.train()
#实例化一个优化器: optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#定义loss计算方法: criterion = torch.nn.CrossEntropyLoss()
#读取数据集中数据进行训练: for epoch in range(nepoch):
#训练所有数据后,保存网络的参数: torch.save(model.state_dict(), '{0}/model.pth'.format(model_cp))
三.3:细节知识:
3.1:在训练网络前,需要了解训练过程中涉及到的几个概念,epoch,batch size和iteration。
从通俗的角度出发,网络学习可以理解成给网络一个输入(image)一个输出(label),然后让调整网络的参数(参数优化),
使其内部输出值(网络的前向计算)和给定的输出值差异(损失Loss)最小,显然每给定一个输入都可以完成一次参数调整。
但如果对每个输入都做一次调整,参数调整的波动会比较大,
一些包含噪声的输入会产生错误的参数调整,网络难以取得较好效果
因此,就会想到给网络输入一些训练数据,累加它们的输出,再与给定的label比较计算Loss,
最后做一次调整,而这一些的大小就是batch size。
在Dogs vs. Cats训练集中,一共有25000个图片,都可以拿来做训练,
如果全部图片都训练过了,就认为完成了1个epoch,即epoch表示整个数据集进了多少次重复训练,
如果batch size设为100,那么网络参数总共调整的参数为25000/100=250次,
而250就是指的是iteration,也即1个epoch中网络参数调整(迭代)的次数。
batch size也不是越大越好,太大了会使网络参数调整缓慢,收敛慢,并且资源消耗也会提高。
3.2:在代码中,DataLoader()的作用是对定义好的数据集类做一次封装,由PyTorch内部执行,有以下几个作用:
可以定义为打乱数据集分布,使各个类型的样本均匀地参与网络训练;
可以设定多线程数据读取(指数据载入内存),提高训练效率,因为训练过程中文件的读取是比较耗时的
可以一次获得batch size大小的数据,并且是Tensor形式,比如读取4个训练数据,需要调用__getitem__()4次,
但是封装好后,在for img, label in dataloader:中,一次img可以获得4个(假设batch size为4)数据,
这也就是为什么代码中img的size是[16×3×200×200],而直接调用__getitem__()是[3×200×200]的原因,
因为DataLoader封装过程中加入了一维数据个数
当然,不用DataLoader封装也是可以的(测试代码中就没有),只是处理起来比较麻烦
3.3:需要注意的是代码loss = criterion(out, label.squeeze())中的.squeeze()方法,
因为从dataloader中获取的label是一个[batch size ×1]的Tensor,
而实际输入的应是1维Tensor,所以需要做一个维度变换。
3.4:模型的保存于加载
在PYTorch里面使用torch.save来保存模型的结构和参数,有两种保存方式:
(1)保存整个模型的结构信息和参数信息,保存的对象是模型model;
(2)保存模型的参数,保存的对象是模型的状态model.state._dict()。
可以这样保存,save的第一个参数是保存对象,第二个参数是保存路径及名称:
(1):torch.save(model,'./model.pth')
(2):torch.save(model.state dict(),'./model_state.pth')
加载模型有两种方式对应于保存模型的方式
(1)加载完整的模型结构和参数信息,使用load_mode1=torch.load('model.pth')
在网络较大的时候加载的时间比较长,同时存储空间也比较大;
(2)加载模型参数信息,需要先导入模型的结构,
然后通过model.1oad_state_dic(torch.load('model_state.pth')) 来导入。
3.5:导入包,函数:
from getdata import DogsVSCatsDataset as DVCD: 从第一部写的代码getdata.py导入DogsVSCatsDataset方法读取数据
from torch.utils.data import DataLoader as DataLoader:用DataLoader()方法定义好的数据集类做一次封装
from network import Net:导入第二部写的代码network.py
import torch:
from torch.autograd import Variable:导入torch.autograd里面的Variable()使的张量变成变量
import torch.nn as nn:导入torch里面的nn,定义损失函数loss计算方法
四.测试网络
四.1:总目标:将训练好的网络参数导入新定义的网络中,然后输入图像完成猫狗分类任务
四.2: 主要步骤:# setting model 建立网络模型
# get data 得到数据
# calculation 计算概率
# pring results 显示结果
四.3:细节知识:
3.1:导入包,函数:
import torch.nn.functional as F:调用softmax(out, dim=1)
import matplotlib.pyplot as plt:导入matplotlib是Python的一个绘图库,是 Python中最常用的可视化工具之一
plt.figure():创建一个绘画对象
plt.suptitle():设置标题
plt.imshow():绘制图片,imshow()接收一张图像,只是画出该图,并不会立刻显示出来。
plt.show():绘画完显示绘画的对象,所有画完后使用plt.show()才能进行结果的显示。
import random:导入random里面的random()方法返回随机生成的一个实数