Pytorch 快速详解如何构建自己的Dataset完成数据预处理(附详细过程)
Pytorch(五) 使用DataSet和DataLoader数据加载
在这篇文章中我已经简单的介绍了Dataset和DataLoader的简单用法,但是大多数实际情况中数据集的存储都没有那么简单,所以写了本文来记录一下如何自定义DataSet
介绍
在实际的案例当中,如图像分类等任务来说,我们需要训练的数据集往往是存储在一个文件夹中的,而数据集的存储格式都是类似的
以蚂蚁和蜜蜂图片数据集 hymenoptera_data
来举例
一般的数据集都会分为两个文件夹
-
train
训练集 -
val
测试集
打开训练集之后的数据存储又分为两种情况
情况1
对于图像分类来说, 肯定需要一个 label和一个img
有些数据集喜欢把它们分开成两个文件夹
img
文件夹 中存放的是图片
label
文件夹中存放的是标签,通常以txt
文件来存储,文件名和图片名相同,而文件的内容代表了图片的标签
情况2
对于一些简单的数据集来说,可能不会把label
和img
分开存放
比如情况1中提到的蚂蚁蜜蜂数据集
ants
目录下的全是蚂蚁的图片
bees
文件夹下全是蜜蜂的图片
这里的文件夹名就代表了图片的label
不过常用的情况 是把图片的label
包含在了图片的命名当中
如下图
自定义Dataset
Dataset的主要作用就是提供一种方式来获取数据和其label
自定义的Dataset需要满足如下两个功能
- 如何获取每一个数据和其label
- 告诉我们数据集一共有多少个数据
导入库
from PIL import Image
from torch.utils.data import Dataset
import os
如果没有下载相应的库就用 pip
下载一下
TensorDataset回顾
在我之前的文章中提到,对于简单的数据可以用TensorDataset
来包装
而通过for循环也可以遍历取出Dataset的中的数据和 label
它的原理就是内置了一个方法可以通过 index
来获取到相应的数据
下面就写个小案例来回顾一下TensorDataset
的使用
构建数据x和标签y
x = torch.tensor(
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3],
[4, 5, 6], [7, 8, 9]])
y = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])`
from torch.utils.data import TensorDataset
train_dataset = TensorDataset(x, y)
print(train_dataset[0])
显然,如上图,train_dataset
中数据的存放格式也是一个数据加一个标签的元组形式,并且可以通过 index
来获取
当然在实际问题中也可以之间用for
循环来遍历
Dataset自定义实现
那么对于我们的图像数据来说,要想达到遍历取数据和label的效果,需要我们自定义Dataset
在这里我们需要达到的目标是 通过 index
可以获取到一个 包含图像数据和标签的元组
并且要知道数据集的中长度
因此Dataset的初步模板就出来了, 如下
class MyDataset(Dataset):
def __init__(self):
pass
def __getitem__(self, index):
pass
def __len__(self):
pass
其中 getitem
的作用是通过index
来返回一个包含数据和标签的元组len
的作用是返回数据集的长度init
很明显是类构造器
下面就来一步一步的实现它
获取图片并显示
PIL
的Image
可以帮助我们解析并查看一张图片
本案例的猫狗数据集我放在了D:\Source\Datasets\cat_and_dog
简单的读取一个图片来看看
from PIL import Image
img_path = r'D:\Source\Datasets\cat_and_dog\train\cat.10.jpg'
img = Image.open(img_path)
img.show()
通过 Image.open(img_path)
方法读取到的是一个PIL.JpegImagePlugin.JpegImageFile
对象,它包含了很多东西,我们就把它当做图像的数据
显然,读取图像需要图像的完整路径,那么思路很明显,我们可以把图像的路径存成一个列表,然后就可以通过index来获取列表中的值,(这里取出的是图像的完整路径),然后再通过Image把它解析成数据,取出其标签,返回
完成getitem方法
这一步其实不难,重点在理解其思路
下图就达到了返回img图像列表的操作,最后只需要把图像和train_path拼接起来就可以获取到图像的完整路径
完整代码
代码如下
# -*- coding: utf-8 -*-
# @Time : 2021/1/31 11:01
# @Author : Tong Tianyu
# @File : demo.py
from PIL import Image
from torch.utils.data import Dataset
import os
class MyDataset(Dataset):
def __init__(self, data_path):
self.data_path = data_path
self.img_list = os.listdir(self.data_path)
def __getitem__(self, index):
img_title = self.img_list[index]
img_label = img_title.split('.')[0]
img_path = os.path.join(self.data_path, img_title)
img = Image.open(img_path)
return img, img_label
def __len__(self):
return len(self.img_list)
train_path = r'D:\Source\Datasets\cat_and_dog\train'
train_dataset = MyDataset(train_path)
效果如下, 完成目标