使用pytorch对图像处理时,需要将自己的图像数据转化为pytorch框架可以理解的DataSet,此时即需要创建自己的数据集,下面总结如何创建自定义的数据集
一、将图像整理为txt文件,txt文件每行的内容包括:图像的路径 和 图像分类标签
本例中图像是按照分类存放到其对应的子文件夹中的
import os
save_path = './data/txt' #保存的路径
dir_path = './data/piture/'#存放数据集的位置
#训练集和测试集的比例9:1
def generate_txt():
if not os.path.exists(save_path):
os.makedirs('./txt')
train_txt = open('./txt/train.txt','a') #以追加方式打开文件
test_txt = open('./txt/test.txt','a')
#遍历每个子文件夹
for sub_dir in os.listdir(dir_path):
i = 1 #计数
for file_name in os.listdir(sub_dir): #得到每张图像的名称
img_path = dir_path + sub_dir + '/' + file_name
img_label = int(sub_dir)
if i%10 == 5:
test_txt.write(img_path+' '+str(img_label)+'\n')
else:
train_txt.write(img_path+' '+str(img_label)+'\n')
i = i + 1
train_txt.close()
test_txt.close()
generate_txt()
二、创建自定义的DataSet
from torchvision import transforms,utils
from torch.utils.data import Dataset,DataLoader
from PIL import Image
import numpy as np
import torch.optim as optim
import os
from torchvision import models
#设置学习率
learning_rate = 0.0001
#root=
#定义读取图片的格式:使用PIL的Image类的open方法读取图像内容,转换为RGB格式
def default_loader(path):
return Image.open(path).convert('RGB')
#创建自己的数据集
class MyDataset(Dataset): #自己的数据集都需要继承Dataset类
def __init__(self,txt,transform=None,target_transform=None,loader=default_loader):
#对继承自父类的属性初始化
super(MyDataset,self).__init__()
#打开传入的txt文件
img_info = []
with open(txt,'r') as fh:
for line in fh:
line = line.rstrip('\n')
line = line.strip()
info = line.split()
img_info.append((info[0],int(info[1])))#每张图片的路径和标签
self.imgs = img_info #[(path1,label1),(path2,label2),...]
self.transform = transform
self.target_transform = target_transform
self.loader = loader #读取图像方法
def __getitem__(self,index):
'''对数据进行预处理:这里是加载图像转化为RGB格式,再按照索引读取每个元素的具体内容'''
path,label = self.imgs[index] #(path,label)
img = self.loader(path) #读取图像的内容
img = self.transform(img) #是否将数据标签转换为Tensor
return img,label,path #这里返回什么,在调用DataLoader时,就可以获得哪些内容
def __len__(self):
return len(self.imgs) #数据集的长度,多少张图片
#定义对象
train_data = MyDataset(txt='txt/train.txt',transform=transforms.ToTensor())#将图像内容转换为向量
print(train_data.__getitem__(10))
#将所有的数据都加载到了train_data中,使用__getitem__()获取对应索引下的图像 内容和标签
三、使用DataLoader设置访问的batch_size以及是否shuffle等选项
#*************定义Datasets只是对读入的数据进行了索引,需要使用DataLoader类进行进一步处理***
#实现batch_size:分批次读取
#shuffle=True:对数据进行随机读取
#train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True ,num_workers=4)
#test_loader = DataLoader(dataset=test_data, batch_size=6, shuffle=False,num_workers=4)
#图像的初始化操作
#Compose()串联多个图片变换的操作
train_transforms = transforms.Compose([
transforms.RandomResizedCrop((227,227)),
transforms.ToTensor(),
])
#train_data = MyDataset(txt='txt/train.txt',transform=transforms.ToTensor())
#即得到50个batch
train_loader = DataLoader(dataset=train_data,batch_size=500,shuffle=True)
print(len(train_loader))#可以得到共多少个batch
四、使用DataLoader访问数据
因为在处理数据时,将输入转化为tensor, 所以对应输出的类型
for i,(input,label,path) in enumerate(testLoader):
#得到原本的数据样本
print(type(input)) #tensor
print(type(label)) #tuple
print(type(path)) #tuple