Pytorch中torchvision包transforms模块应用小案例

Pytorch提供了torchvision这样一个视觉工具包,提供了很多视觉图像处理的工具,其中transforms模块主要提供了PIL Image对象和Tensor对象的常用操作,其中最核心的三个操作分别是:
(1)ToTensor:将PIL Image对象转换成Tensor,同时会自动将[0,255]归一化至[0,1]。
(2)ToPILImage:将Tensor对象转换成PIL Image对象。
(3)Compose:如果需要对图片数据集进行多个操作,可通过Compose将这些操作汇集起来,类似于nn.Sequential的原理,调用方法也相同,但注意Compose处理的数据格式是PIL Image格式。

一、程序运行所需的数据集下载

pytorch transformer 实战 pytorch transformer包_猫狗分类

二、案例要求和实现流程

将数据集中的图片进行猫狗分类,根据文件名前缀判断是猫还是狗,然后输出分类结果,并对图像进行自定义的操作transforms并展示。具体实现流程按下面代码的括号中的的顺序和注释依次进行理解,注意:文件路径因为每个用户保存的路径不同,所以要做相应的修改,否则会报错。

三、代码

import os
from PIL import Image
import numpy as np
from torch.utils import data
from torchvision import transforms as T

transform = T.Compose([ # 等同于sequential,调用方式也一致,此transforms输入数据类型是PIL Image,输出数据类型是tensor (2)
    T.Resize(224), # 缩放图片,保持长宽比不变,最短边为224像素
    T.CenterCrop(224), # 从图片中间切出224*224的图片
    T.ToTensor(), # 将图片(Image)转换成Tensor,归一化[0,1]
    T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) # 标准化[-1,1],处理后格式依旧为tensor格式
])

class DogCat(data.Dataset): # 继承data.Dataset类 (3)
    def __init__(self,root,transforms = None): # 初始化没有传递参数时默认transforms为无 (4)
        imgs = os.listdir(root) # 根据root路径搜寻出文件夹中的所有文件并得到文件名列表 (5)
        self.imgs = [os.path.join(root,img) for img in imgs] # 将文件的路径和文件名拼接 (6)
        # print(self.imgs)
        self.transform = transforms # 赋值,self.变量属于类变量,类内形式通用(7)

    def __getitem__(self, index): #继承data.Dataset类的DogCat类的实例对象可以自行用数标(index)启动(9)
        # print(index)
        # print(self.imgs[index])
        img_path = self.imgs[index] # 获取文件的路径 (10)
        lable = 1 if 'dog' in img_path.split('/')[-1] else 0 # 判断文件名前缀判断狗猫 (11)
        data = Image.open(img_path) # 通过文件路径获取图片的信息转换成PIL Image对象类型 (12)
        if self.transform: # 判断图像是否需要进行操作转换 (13)
            data = self.transform(data) # 图像需要转换,并且返回Tensor数据类型 (14)
            # print(data.type())
        return data,lable # 返回图像tensor数据和狗猫分类标记 (15)

    def __len__(self): # 图片文件的数量
        return len(self.imgs)

dataset = DogCat('E:\pythonProjecttest\dogcat/',transforms = transform) # 注意路径根据自己存放数据集的进行修改,构建类实例 (1)
# img,lable = dataset[0] # 只是启动dataset.__getitem__(self, 0),不是所有的dataset,可以省略
for img,lable in dataset: # 启动剩下的dataset.__getitem__(self, index),包含后续所有的dataset (8)(16)
    img = T.ToPILImage()(img) # 将图像tensor数据转换成PIL Image数据类型 (17)
    img.show() # 图像展示 (18)
    print(lable) # 输出狗猫分类结果

四、部分结果展示

D:\Anaconda\python.exe E:/pythonProjecttest/Dogcatclassification.py
0
0
0
0
1
1
1
1

Process finished with exit code 0

pytorch transformer 实战 pytorch transformer包_torchvision_02


pytorch transformer 实战 pytorch transformer包_pytorch_03


pytorch transformer 实战 pytorch transformer包_torchvision_04


pytorch transformer 实战 pytorch transformer包_torchvision_05