Building datasets for image classification in PyTorch

  • Method 1: torchvision.datasets
  • Method 2: torchvision.datasets.ImageFolder
  • 3. torch.utils.data.DataLoader

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

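Assume the images all have the same size and are stored in folders first by dataset split (train, valid, test) and then by class, so that each class folder contains images of that class only.
The folder tree looks like this: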

|— image
 | |— train
 | | |— label1
 | | | |— img.png
 | | | | …
 | | |— label2
 | | | …
 | | |— labeln
 | |— valid
 | | |— label1
 | | |— label2
 | | | …
 | | |— labeln
 | |— test
 | | |— label1
 | | |— label2
 | | | …
 | | |— labeln

Method 1: torchvision.datasets

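  1. First, for each split, write one line per image to a .txt file in the form
    absolute_path<space>label
  2. PyTorch provides a built-in Dataset class. To build a dataset from your own data source, define a subclass that implements the following:
    __init__: initializes the dataset and stores its parameters.
    __len__: returns the number of samples in the dataset.
    __getitem__: returns the sample at a given index, which is what allows the dataset to be iterated.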
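Step 1 can be sketched with the standard library alone. This is a minimal sketch, not the author's script: the helper names write_listing and read_listing, the IMAGE_EXTENSIONS tuple, and the choice to use a sorted integer index as the label are all assumptions made for illustration.

```python
import os
import tempfile

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp")  # assumed set of image suffixes


def write_listing(split_dir, txt_path):
    """Write one 'absolute_path label_index' line per image under split_dir."""
    # Each sub-folder is one class; sorting keeps the label indices reproducible.
    class_names = sorted(
        d for d in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, d))
    )
    with open(txt_path, "w", encoding="utf-8") as f:
        for label, name in enumerate(class_names):
            class_dir = os.path.join(split_dir, name)
            for fname in sorted(os.listdir(class_dir)):
                if fname.lower().endswith(IMAGE_EXTENSIONS):
                    path = os.path.abspath(os.path.join(class_dir, fname))
                    f.write(f"{path} {label}\n")
    return class_names


def read_listing(txt_path):
    """Parse the listing back into a list of (path, int_label) pairs."""
    datalist = []
    with open(txt_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Split on the last space so paths that contain spaces still parse.
            path, label = line.rsplit(" ", 1)
            datalist.append((path, int(label)))
    return datalist


# Tiny demo on a throwaway folder tree with empty placeholder files.
demo_root = tempfile.mkdtemp()
for name in ("label1", "label2"):
    os.makedirs(os.path.join(demo_root, "train", name))
    open(os.path.join(demo_root, "train", name, "img.png"), "wb").close()

demo_txt = os.path.join(demo_root, "train.txt")
demo_classes = write_listing(os.path.join(demo_root, "train"), demo_txt)
demo_datalist = read_listing(demo_txt)
```

The list returned by read_listing has the (path, label) shape that a Dataset's __getitem__ can unpack directly.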
# Create the dataset
from torchvision import datasets, transforms
import torch
import cv2

class MyDataset(torch.utils.data.Dataset):  # subclass of torch.utils.data.Dataset
    def __init__(self, datalist, transform=None, target_transform=None):
        # datalist: sequence of (image path, label) pairs
        super().__init__()
        self.datalist = datalist
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, index):
        # Required: returns the sample stored at the given index
        img_file_fullpath, label = self.datalist[index]
        # Read the image. ImagePrehandle is a helper defined elsewhere (not shown here)
        image_pre_handle = ImagePrehandle(img_file_fullpath)
        _, ori_img = image_pre_handle.load_img(img_file_fullpath)
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
        if self.transform:
            img = self.transform(img)  # apply the image transform, if any
        if self.target_transform:
            label = self.target_transform(label)
        # Whatever is returned here is what each batch yields in the training loop
        return img, label

Usage

# Define the image transform and the label transform
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# One-hot encode the integer label; Lambda lives in torchvision.transforms
target_transform = transforms.Lambda(lambda y: torch.zeros(4, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
train_data = MyDataset(datalist, transform=transform, target_transform=target_transform)
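To make the label transform concrete, here is a small sketch of what the scatter_ call produces for one label. The helper name one_hot and the constant NUM_CLASSES are made up for illustration, and the sketch assumes the label is already an int (a label read from a text file would need int(...) first).

```python
import torch

NUM_CLASSES = 4  # matches the torch.zeros(4, ...) used in the label transform


def one_hot(y, num_classes=NUM_CLASSES):
    """Turn an integer class index into a float one-hot vector."""
    return torch.zeros(num_classes, dtype=torch.float).scatter_(
        0, torch.tensor(y), value=1
    )


encoded = one_hot(2)
print(encoded)  # tensor([0., 0., 1., 0.])
```

If the loss is torch.nn.CrossEntropyLoss, integer class indices can be passed as targets directly, so the one-hot step may be unnecessary in that setup.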

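Note:
image_pre_handle.load_img reads the image with cv2.imread, which returns a NumPy array. The transform above still works because ToTensor(), which accepts NumPy arrays, is applied first. Transforms that expect a PIL image may instead require reading the file with PIL.Image.open.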

Method 2: torchvision.datasets.ImageFolder

ImageFolder is a generic data loader. It is used as follows:

from torchvision.datasets import ImageFolder
ImageFolder(root="root folder path", transform=None, target_transform=None)

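root: the path where the images are stored, e.g. './image/train'.
transform: a function that takes the original image as input and returns a transformed image.
target_transform: a function that takes the target as input and returns a transformed target. For example, the input could be the image's annotation as a string and the output the index of that word.
The dataset object has the following attributes:
self.classes: a list of the class names
self.class_to_idx: the index that corresponds to each class name
self.imgs: a list of (image path, class) tuples
These values can be inspected afterwards on the returned dataset object.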

A generic data loader where the images are arranged in this way:

    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/[...]/xxz.png

    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/[...]/asd932_.png

Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes the path of an image file and checks whether the file is valid (used to check for corrupt files).

Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples

3. torch.utils.data.DataLoader

DataLoader is a data loader: during training it splits the training data into batches and yields one batch at a time until all of the data has been yielded.

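import torch.utils.data as Data  # assumed meaning of the Data alias

dataloaders = Data.DataLoader(
    dataset,           # the dataset to load from
    batch_size=1,      # number of samples per batch
    shuffle=False,     # whether to reshuffle the data every epoch
    sampler=None,      # strategy for drawing samples; leave shuffle=False when set
    num_workers=0,     # worker subprocesses for loading; 0 loads in the main process (noted for Windows)
    pin_memory=False,  # copy tensors into pinned (page-locked) memory before returning them
    drop_last=False,   # drop the last batch if it is smaller than batch_size
)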
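The effect of batch_size and drop_last can be seen on a toy dataset. This sketch uses TensorDataset with 10 fake samples in place of a real image dataset, which is an assumption made to keep it self-contained.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 fake "images" (3 x 8 x 8) with integer labels 0..9.
images = torch.zeros(10, 3, 8, 8)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

# Keep the incomplete last batch: 10 samples in batches of 4 gives 4, 4, 2.
loader = DataLoader(dataset, batch_size=4, shuffle=False, drop_last=False)
batch_sizes = [x.shape[0] for x, y in loader]
print(batch_sizes)  # [4, 4, 2]

# Drop the incomplete last batch: only the two full batches remain.
loader_drop = DataLoader(dataset, batch_size=4, shuffle=False, drop_last=True)
batch_sizes_drop = [x.shape[0] for x, y in loader_drop]
print(batch_sizes_drop)  # [4, 4]
```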

Args (with annotations):

dataset (Dataset): dataset from which to load the data.

batch_size (int, optional): how many samples per batch to load (default: 1).
Annotation: the number of samples in one batch.

shuffle (bool, optional): set to True to have the data reshuffled at every epoch (default: False).
Annotation: whether to shuffle the order.

sampler (Sampler or Iterable, optional): defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
Annotation: the sampling strategy; shuffle must be left as False when a sampler is used.

batch_sampler (Sampler or Iterable, optional): like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.

num_workers (int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
Annotation: loads data in parallel; 0 disables the workers (noted for Windows). The workers are subprocesses rather than threads.

collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

pin_memory (bool, optional): if True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the memory pinning example in the PyTorch documentation.
Annotation: this places tensors in pinned (page-locked) host memory, which speeds up the later copy to the GPU; it does not move them to the GPU itself.

drop_last (bool, optional): set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
Annotation: whether to drop the final samples that do not fill a whole batch.

timeout (numeric, optional): if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)

worker_init_fn (callable, optional): if not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

prefetch_factor (int, optional, keyword-only arg): number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default: 2)

persistent_workers (bool, optional): if True, the data loader will not shut down the worker processes after a dataset has been consumed once. This keeps the workers' Dataset instances alive. (default: False)
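The sampler argument and its conflict with shuffle can be sketched as follows. SubsetRandomSampler is one of the samplers shipped in torch.utils.data; the toy TensorDataset and the choice of the first 6 indices are assumptions for illustration.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# 10 one-feature samples whose label equals their index.
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

# Draw only from the first 6 samples, in random order.
sampler = SubsetRandomSampler(list(range(6)))
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
seen = sorted(int(y) for _, ys in loader for y in ys)
print(seen)  # [0, 1, 2, 3, 4, 5]

# Passing shuffle=True together with a sampler is rejected with a ValueError.
try:
    DataLoader(dataset, batch_size=4, sampler=sampler, shuffle=True)
    conflict_error = None
except ValueError as exc:
    conflict_error = exc
```

The order within each epoch is decided by the sampler, which is why DataLoader refuses to combine it with shuffle=True.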