data文件通常用来存放数据集预处理部分。我通过网上查找资料进行归纳总结,方便日后使用。

pytroch制作dataset和dataloader

1.dataset制作

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
"""
构建Dataset子类,
pytorch读取图片,主要是通过Dataset类,Dataset类作为所有的datasets的基类存在,所有的datasets都需要继承它,类似于c++中的虚基类。

"""

class MyDataset(Dataset): # 继承Dataset类
    def __init__(self, txt_path, transform=None, target_transform=None): # 定义txt_path参数
        fh = open(txt_path, 'r') # 读取txt文件
        imgs = []  # 定义imgs的列表
        for line in fh:
            line = line.rstrip() # 默认删除的是空白符('\n', '\r', '\t', ' ')
            words = line.split() # 默认以空格、换行(\n)、制表符(\t)进行分割,大多是"\"
            imgs.append((words[0], int(words[1]))) # 存放进imgs列表中

        self.imgs = imgs        # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index] # fn代表图片的路径,label代表标签
        img = Image.open(fn).convert('RGB')     # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1   参考:

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.imgs)   # 返回图片的长度

自己构建dataset类一定要包含 getitem (返回一个字典包含图片)和 len (返回数据集长度), getitem 通常包含对图片进行裁剪、旋转、放大搜小等操作,这些操作可以利用torchvision.transforms来对图像进行变换。

transforms.Compose([transforms.RandomResizedCrop(224),
 		    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

如果transforms.Compose()没有自己想要的图像变换操作,可以通过Pytorch中的transforms.Lambda()自定义图像变换操作。

2.dataloader使用

self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
            drop_last=True)

作用:torch.utils.data.DataLoader 主要是对数据进行 batch 的划分。

数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。

在训练模型时使用到此函数,用来 把训练数据分成多个小组 ,此函数 每次抛出一组数据 。直至把所有的数据都抛出。就是做一个数据的初始化。

好处:

使用DataLoader的好处是,可以快速的迭代数据。

用于生成迭代数据非常方便。

注意:

除此之外,特别要注意的是输入进函数的数据一定得是可迭代的。如果是自定的数据集的话可以在定义类中用def__len__、def__getitem__定义。