pytorch构建自己的图像数据集
- 数据准备
- 获取数据
- 重写Dataset类
- 数据载入
- 代码
数据准备
Pytorch读取和载入数据有专门的Dataset
和Dateloader
类,但是当我们想读取自己的数据集时,Dataset
类就不能用了,因此这篇博客教大家如何创建自己的数据集。在开始工作之前需要准备好自己的图像数据集,这里使用cifar10数据集为例,cifar10是一个十分类的公开数据集,拥有6w张32*32的图像,该数据集结构如下:
|-cifar10
|-----train
|---------airplane
|---------automobile
|---------…
|-----test
|---------airplane
|---------automobile
|---------…
你可以把自己的数据集也按照这样分类,数据集准备好后,我们就可以进行下一步处理了。总体的思路是把各个图像的路径和其对应的标签构成一个列表,这样就可以利用pytorch自带的Dataloder
类进行读取。
获取数据
利用glob包可以很方便的获取我们想要的路径,首先我们需要导入相应的包:
import glob
分别读取训练集和测试集的路径:
train_imgs_path = glob.glob(r'F:\jupyterFile\pycharmcode\cifar\train\*\*.png')
test_imgs_path = glob.glob(r'F:\jupyterFile\pycharmcode\cifar\test\*\*.png')#路径前的r为转义标志,否则路径可能报错
这里大家可以print
一下看看输出的是什么。路径获取后,我们就要给每一个路径赋予一个标签(0,1,2,…),首先设定一下有哪些类,注意类的名称一定要和文件夹的名称对应,cifar有10类:
species = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
接着创建两个空列表,把标签按照路径的顺序进行排列:
train_labels = []
test_labels = [] # 用于生成标签
# 对所有图片路径进行迭代
# 为训练集添加标签
for img in train_imgs_path:
# 区分出每个img,应该属于什么类别
for i, c in enumerate(species):
if c in img:
train_labels.append(i)
# 为测试集添加标签
for img in test_imgs_path:
for i, c in enumerate(species):
if c in img:
test_labels.append(i) # 为对应的数据集增加标签
实现的过程很简单,就是使用之前设定的类别名称在路径里搜索,从而为对应的类别赋予标签。到此我们已经拥有了训练数据和测试数据的路径及其对应的标签,下一步就是调用pytorch Dataset
类了。
重写Dataset类
这里我们需要重写pytorch自带的Dataset
类,便于读取我们自己的数据,同时还需要导入相应的库,代码在最后,下面是重写后的Dataset
类:
class Mydatasetpro(torch.utils.data.Dataset):
# 初始化函数,得到数据
def __init__(self, data_root, data_label, transform):
self.data = data_root
self.label = data_label
self.transforms = transform
# index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
def __getitem__(self, index):
data = self.data[index]
labels = self.label[index]
pil_img = Image.open(data)
data = self.transforms(pil_img)
return data, labels
# 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
def __len__(self):
return len(self.data)
数据载入
重写之后,我们就可以利用pytorch自带的DataLoader
进行数据的读取和载入了:
batchsz=32 #这里设置批次数量
data_train = Mydatasetpro(train_imgs_path, train_labels, transform) # 训练数据读取
data_train_loader = DataLoader(data_train, batch_size=batchsz, shuffle=True) # 训练数据载入,训练时数据标签要打乱
data_test = Mydatasetpro(test_imgs_path, test_labels, transform) # 测试数据读取
data_test_loader = DataLoader(data_test, batch_size=batchsz, shuffle=False) # 测试数据载入
我们得到的data_train_loader
和data_test_loader
就是最终处理好的数据,可以直接输入后续的模型。
代码
最后放一下整体的代码如下,在代码里我把相应的过程用函数data_get
封装了。有问题可以留言哈!
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
import glob
from PIL import Image
import matplotlib.pyplot as plt
# 重构Dataset类
class Mydatasetpro(torch.utils.data.Dataset):
# 初始化函数,得到数据
def __init__(self, data_root, data_label, transform):
self.data = data_root
self.label = data_label
self.transforms = transform
# index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
def __getitem__(self, index):
data = self.data[index]
labels = self.label[index]
pil_img = Image.open(data)
data = self.transforms(pil_img)
return data, labels
# 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
def __len__(self):
return len(self.data)
def data_get(train_imgs_path,test_imgs_path,species,batchsz,transform):
train_labels = []
test_labels = [] # 用于生成标签
# 对所有图片路径进行迭代
for img in train_imgs_path:
# 区分出每个img,应该属于什么类别
for i, c in enumerate(species):
if c in img:
train_labels.append(i)
for img in test_imgs_path:
# 区分出每个img,应该属于什么类别
for i, c in enumerate(species):
if c in img:
test_labels.append(i) # 为对应的数据集增加标签
data_train = Mydatasetpro(train_imgs_path, train_labels, transform) # 训练数据读取
data_train_loader = DataLoader(data_train, batch_size=batchsz, shuffle=True) # 训练数据载入,训练时数据标签要打乱
data_test = Mydatasetpro(test_imgs_path, test_labels, transform) # 测试数据读取
data_test_loader = DataLoader(data_test, batch_size=batchsz, shuffle=False) # 测试数据载入
return data_train_loader,data_test_loader
if __name__ == '__main__':
# 训练集和测试集数据路径
train_imgs_path = glob.glob(r'F:\jupyterFile\pycharmcode\cifar\train\*\*.png')
test_imgs_path = glob.glob(r'F:\jupyterFile\pycharmcode\cifar\test\*\*.png')
# 输入类别
species = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 图像变换
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
])
batchsz = 128 # 定义批处理量
data_train_loader, data_test_loader=data_get(train_imgs_path,test_imgs_path,species,batchsz,transform)
imgs_batch, labels_batch = next(iter(data_train_loader)) # 迭代方法获取批数据
print(imgs_batch.shape)
# 测试下数据集里的图片
plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs_batch[:32], labels_batch[:8])):
img = img.permute(1, 2, 0).numpy()
plt.subplot(2, 4, i+1)
plt.xlabel(species[label.numpy()])
plt.imshow(img)
plt.show()