Pytorch 快速详解如何构建自己的Dataset完成数据预处理(附详细过程)


Pytorch(五) 使用DataSet和DataLoader数据加载

在这篇文章中我已经简单的介绍了Dataset和DataLoader的简单用法,但是大多数实际情况中数据集的存储都没有那么简单,所以写了本文来记录一下如何自定义DataSet


介绍

在实际的案例当中,如图像分类等任务来说,我们需要训练的数据集往往是存储在一个文件夹中的,而数据集的存储格式都是类似的

以蚂蚁和蜜蜂图片数据集 hymenoptera_data 来举例

pytorch数据并行 推理 pytorch合并dataset_数据集


一般的数据集都会分为两个文件夹

  • train 训练集
  • val 测试集

打开训练集之后的数据存储又分为两种情况

情况1

对于图像分类来说, 肯定需要一个 label和一个img

有些数据集喜欢把它们分开成两个文件夹

pytorch数据并行 推理 pytorch合并dataset_机器学习_02


img文件夹 中存放的是图片

pytorch数据并行 推理 pytorch合并dataset_python_03


label文件夹中存放的是标签,通常以txt文件来存储,文件名和图片名相同,而文件的内容代表了图片的标签

pytorch数据并行 推理 pytorch合并dataset_数据集_04

情况2

对于一些简单的数据集来说,可能不会把labelimg分开存放

比如情况1中提到的蚂蚁蜜蜂数据集

ants目录下的全是蚂蚁的图片

bees文件夹下全是蜜蜂的图片

这里的文件夹名就代表了图片的label

不过常用的情况 是把图片的label包含在了图片的命名当中

如下图

pytorch数据并行 推理 pytorch合并dataset_pytorch数据并行 推理_05

自定义Dataset

Dataset的主要作用就是提供一种方式来获取数据和其label
自定义的Dataset需要满足如下两个功能

  • 如何获取每一个数据和其label
  • 告诉我们数据集一共有多少个数据

导入库

from PIL import Image
from torch.utils.data import Dataset
import os

如果没有下载相应的库就用 pip下载一下

TensorDataset回顾

在我之前的文章中提到,对于简单的数据可以用TensorDataset来包装
而通过for循环也可以遍历取出Dataset的中的数据和 label 它的原理就是内置了一个方法可以通过 index 来获取到相应的数据
下面就写个小案例来回顾一下TensorDataset的使用

构建数据x和标签y

x = torch.tensor(
    [[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3],
     [4, 5, 6], [7, 8, 9]])
y = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])`

pytorch数据并行 推理 pytorch合并dataset_pytorch数据并行 推理_06

from torch.utils.data import TensorDataset
train_dataset = TensorDataset(x, y)
print(train_dataset[0])

pytorch数据并行 推理 pytorch合并dataset_pytorch数据并行 推理_07


显然,如上图,train_dataset中数据的存放格式也是一个数据加一个标签的元组形式,并且可以通过 index来获取

当然在实际问题中也可以之间用for循环来遍历

Dataset自定义实现

那么对于我们的图像数据来说,要想达到遍历取数据和label的效果,需要我们自定义Dataset 在这里我们需要达到的目标是 通过 index 可以获取到一个 包含图像数据和标签的元组
并且要知道数据集的中长度
因此Dataset的初步模板就出来了, 如下

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        pass

    def __len__(self):
        pass

其中 getitem的作用是通过index来返回一个包含数据和标签的元组
len的作用是返回数据集的长度
init很明显是类构造器
下面就来一步一步的实现它

获取图片并显示

PILImage可以帮助我们解析并查看一张图片
本案例的猫狗数据集我放在了D:\Source\Datasets\cat_and_dog 简单的读取一个图片来看看

from PIL import Image
img_path = r'D:\Source\Datasets\cat_and_dog\train\cat.10.jpg'
img = Image.open(img_path)
img.show()

pytorch数据并行 推理 pytorch合并dataset_深度学习_08


通过 Image.open(img_path)方法读取到的是一个PIL.JpegImagePlugin.JpegImageFile对象,它包含了很多东西,我们就把它当做图像的数据

pytorch数据并行 推理 pytorch合并dataset_python_09


显然,读取图像需要图像的完整路径,那么思路很明显,我们可以把图像的路径存成一个列表,然后就可以通过index来获取列表中的值,(这里取出的是图像的完整路径),然后再通过Image把它解析成数据,取出其标签,返回

完成getitem方法

这一步其实不难,重点在理解其思路

下图就达到了返回img图像列表的操作,最后只需要把图像和train_path拼接起来就可以获取到图像的完整路径

pytorch数据并行 推理 pytorch合并dataset_数据集_10

完整代码

代码如下

# -*- coding: utf-8 -*-
# @Time    : 2021/1/31 11:01
# @Author  : Tong Tianyu
# @File    : demo.py
from PIL import Image
from torch.utils.data import Dataset
import os


class MyDataset(Dataset):
    def __init__(self, data_path):
        self.data_path = data_path
        self.img_list = os.listdir(self.data_path)

    def __getitem__(self, index):
        img_title = self.img_list[index]
        img_label = img_title.split('.')[0]
        img_path = os.path.join(self.data_path, img_title)
        img = Image.open(img_path)
        return img, img_label

    def __len__(self):
        return len(self.img_list)


train_path = r'D:\Source\Datasets\cat_and_dog\train'
train_dataset = MyDataset(train_path)

效果如下, 完成目标

pytorch数据并行 推理 pytorch合并dataset_深度学习_11