使用pytorch对图像处理时,需要将自己的图像数据转化为pytorch框架可以理解的DataSet,此时即需要创建自己的数据集,下面总结如何创建自定义的数据集

一、将图像整理为txt文件,txt文件每行的内容包括:图像的路径 和 图像分类标签

本例中图像是按照分类存放到其对应的子文件夹中的

import os
save_path = './data/txt' #保存的路径
dir_path = './data/piture/'#存放数据集的位置
#训练集和测试集的比例9:1
def generate_txt():
    if not os.path.exists(save_path):
        os.makedirs('./txt')
    train_txt = open('./txt/train.txt','a') #以追加方式打开文件
    test_txt = open('./txt/test.txt','a')

    #遍历每个子文件夹
    for sub_dir in os.listdir(dir_path):
        i = 1 #计数
        for file_name in os.listdir(sub_dir): #得到每张图像的名称
            img_path = dir_path + sub_dir + '/' + file_name
            img_label = int(sub_dir)
            if i%10 == 5:
                test_txt.write(img_path+' '+str(img_label)+'\n')
            else:
                train_txt.write(img_path+' '+str(img_label)+'\n')
            i = i + 1
    train_txt.close()
    test_txt.close()
generate_txt()

二、创建自定义的DataSet

from torchvision import transforms,utils
from torch.utils.data import Dataset,DataLoader
from PIL import Image
import numpy as np
import torch.optim as optim
import os
from torchvision import models

#设置学习率
learning_rate = 0.0001

#root=
#定义读取图片的格式:使用PIL的Image类的open方法读取图像内容,转换为RGB格式
def default_loader(path):
    return Image.open(path).convert('RGB')

#创建自己的数据集
class MyDataset(Dataset): #自己的数据集都需要继承Dataset类
    def __init__(self,txt,transform=None,target_transform=None,loader=default_loader):
        #对继承自父类的属性初始化
        super(MyDataset,self).__init__()
        #打开传入的txt文件
        img_info = []
        with open(txt,'r') as fh:
            for line in fh:
                line = line.rstrip('\n')
                line = line.strip()
                info = line.split()
                img_info.append((info[0],int(info[1])))#每张图片的路径和标签
        self.imgs = img_info  #[(path1,label1),(path2,label2),...]
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader #读取图像方法
    def __getitem__(self,index):
        '''对数据进行预处理:这里是加载图像转化为RGB格式,再按照索引读取每个元素的具体内容'''
        path,label = self.imgs[index] #(path,label)
        img = self.loader(path) #读取图像的内容
        img = self.transform(img) #是否将数据标签转换为Tensor
        return img,label,path #这里返回什么,在调用DataLoader时,就可以获得哪些内容
    def __len__(self):
        return len(self.imgs) #数据集的长度,多少张图片

#定义对象
train_data = MyDataset(txt='txt/train.txt',transform=transforms.ToTensor())#将图像内容转换为向量
print(train_data.__getitem__(10)) 
#将所有的数据都加载到了train_data中,使用__getitem__()获取对应索引下的图像  内容和标签

三、使用DataLoader设置访问的batch_size以及是否shuffle等选项

#*************定义Datasets只是对读入的数据进行了索引,需要使用DataLoader类进行进一步处理***
#实现batch_size:分批次读取
#shuffle=True:对数据进行随机读取
#train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True ,num_workers=4)
#test_loader = DataLoader(dataset=test_data, batch_size=6, shuffle=False,num_workers=4)

#图像的初始化操作
#Compose()串联多个图片变换的操作
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop((227,227)),
    transforms.ToTensor(),
])
#train_data = MyDataset(txt='txt/train.txt',transform=transforms.ToTensor())
#即得到50个batch
train_loader = DataLoader(dataset=train_data,batch_size=500,shuffle=True)
print(len(train_loader))#可以得到共多少个batch

四、使用DataLoader访问数据

因为在处理数据时,将输入转化为tensor, 所以对应输出的类型

for i,(input,label,path) in enumerate(testLoader):
        #得到原本的数据样本
        print(type(input)) #tensor
        print(type(label)) #tuple
        print(type(path))  #tuple