简述

Pytorch自定义数据集方法,应该是用pytorch做算法的最基本的东西。
往往网络上给的demo都是基于torch自带的MNIST的相关类。所以,为了解决使用其他的数据集,在查阅了torch关于MNIST数据集的源码之后,很容易就可以推广到了我们自己需要的代码上。


文章目录

  • 简述
  • QuickStart
  • 补充说明
  • 数据预处理
  • np.ndarray
  • PIL.Image
  • 关于图片
  • 一个导入图片的demo


具体操作如下:

QuickStart

  1. 需要导入一些包。
from torch.utils.data import Dataset, DataLoader
  1. 再自定义一个用于当训练集合的类。
class TrainSet(Dataset):
    def __init__(self, X, Y):
        # 定义好 image 的路径
        self.X, self.Y = X, Y

    def __getitem__(self, index):
        return self.X[index], self.Y[index]

    def __len__(self):
        return len(self.X)
  1. 创建一个DataLoader
  • 其中X_tensorY_tensor分别是两个torch.Tensor的类型变量。(例如:下面以乱生成的为例(分别的shape是(10, 3, 14, 14)),一个是(10)
X_tensor = torch.ones((10, 3, 14, 14))
Y_tensor = torch.ones((10,))
mydataset = TrainSet(X_tensor, Y_tensor)
train_loader = DataLoader(mydataset, batch_size=10, shuffle=True)
  • 如果数据是numpy(如果不是numpy就转成numpy格式,比如下面是个df
# np_x, np_y  自己定义即可
X_tensor = torch.from_numpy(np_x)
Y_tensor = torch.from_numpy(np_y)

mydataset = TrainSet(X_tensor, Y_tensor)
train_loader = DataLoader(mydataset, batch_size=10, shuffle=True)

train_loader 使用的方式也非常简单:

for step, (x, y) in enumerate(train_loader):

这里的x,y就是每个batch所处理的数据。

  • batch_size表示用将原数据拆分之后,每batch_size个数据作为一组数据被调用。shuffle表示数据是否被洗牌(即刷新顺序,避免训练的时候多次调用结果都遇到同一batch,从而避免误差)

但套路基本差不多。(基本到这,就可以解决很多数据类型的转成pytorch可以用的数据集了。


补充说明

数据预处理

之后,假设你的训练集合为[X,Y],其中X是训练数据,Y是对应的数据的标签。

首先,需要知道的是,torch能处理的数据只能是torch.Tensor,所以有必要将其他数据转换为torch.Tensor

常见的有几种数据:

  • np.ndarray
  • PIL.Image

如果是图片数据,其实也有多种情况,根据数据维度不同,有些是二维图,有些是三维图(通俗来讲,就是黑白图和彩图)。

所以,我先按照数据类型的模式将一遍,再补充关于图片的处理。

np.ndarray

np.ndarray是非常常见的格式,转成Tensor也非常简单。

torch.Tensor(array)

或者是

torch.from_numpy(array)

这样代码的返回格式就是一个Tensor

PIL.Image

import torchvision.transforms as transforms
transforms.ToTensor()(image)

这样代码的返回格式就是一个Tensor

关于图片

  • 彩色的三维图: 上面方法就已经完成了对应的数据处理的步骤
  • 灰白或者是二值的二维图:就需要将数据增加一个维度了(因为往往关于图片,所用到的算法都是包括了卷积的步骤,所以要求增加一个维度)

具体操作如下: 明显,torch.Tensor(X)这样的步骤,其实是重复了上面的将np.ndarray转成torch.Tensor的步骤。同理可以换成上面的关于PIL.Image的方法

X_tensor = torch.unsqueeze(torch.Tensor(X), 1)
Y_tensor = torch.unsqueeze(torch.Tensor(Y), 1)

一个导入图片的demo

另外,附上一个我常用的读取自定义图片的dataset类

main函数部分是对数据集做测试。

import torch.utils.data as data
import glob
import os
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch

import piexif
import imghdr


class MyDataset(data.Dataset):
    def __init__(self, path, Train=True, Len=-1, resize=-1, img_type='png', remove_exif=False):

        if resize != -1:
            transform = transforms.Compose([
                transforms.Resize(resize),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
        else:
            transform = transforms.Compose([
                transforms.ToTensor(),
            ])
        img_format = '*.%s' % img_type

        if remove_exif:
            for name in glob.glob(os.path.join(path, img_format)):
                try:
                    piexif.remove(name)  # 去除exif
                except Exception:
                    continue

        # imghdr.what(img_path) 判断是否为损坏图片
        if Len == -1:
            self.dataset = [np.array(transform(Image.open(name).convert("RGB"))) for name in
                            glob.glob(os.path.join(path, img_format)) if imghdr.what(name)]
        else:
            self.dataset = [np.array(transform(Image.open(name).convert("RGB"))) for name in
                            glob.glob(os.path.join(path, img_format))[:Len] if imghdr.what(name)]
        self.dataset = np.array(self.dataset)
        self.dataset = torch.Tensor(self.dataset)
        self.Train = Train

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]


if __name__ == '__main__':
    path = r'D:\Software\DataSet\faces'
    dataset = MyDataset(path=path, resize=96, Len=10, img_type='jpg')
    print(len(dataset))
    plt.imshow(dataset[0].numpy().transpose(1, 2, 0) * 0.5 + 0.5)
    plt.show()
    print(dataset[0].max(), dataset[0].min())