Building datasets for image classification in PyTorch
- Method 1: a custom torch.utils.data.Dataset
- Method 2: torchvision.datasets.ImageFolder
- 3. torch.utils.data.DataLoader
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
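Assume all the images have the same size. Store them in folders first by split (train, valid, test) and then by class, so that each class folder contains images of only that class.
The folder tree looks like this: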
|— image
| |— train
| | |— label1
| | | |— img.png
| | | | …
| | |— label2
| | | …
| | |— labeln
| |— valid
| | |— label1
| | |— label2
| | | …
| | |— labeln
| |— test
| | |— label1
| | |— label2
| | | …
| | |— labeln
Method 1: a custom torch.utils.data.Dataset
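- First, for each split, write every image into a .txt file, one line per image in the form
  absolute_path label
  (the absolute path, a space, then the label).
- PyTorch provides a built-in Dataset class. To build a dataset from your own data source, subclass it and define the following methods:
  `__init__`: the constructor, which stores the parameters the dataset needs (e.g. the data list and the transforms).
  `__len__`: returns the number of samples in the dataset.
  `__getitem__`: returns the sample at a given index, which is what lets the dataset be indexed and iterated over.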
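The original does not show the first step as code. Below is a minimal sketch of writing `absolute_path label` lines into a .txt file for one split. It assumes the folder tree above, and it assumes each label is written as an integer index obtained by sorting the class folder names. `write_split_list`, `read_split_list` and `IMG_EXTS` are hypothetical names.

```python
import os

IMG_EXTS = (".png", ".jpg", ".jpeg", ".bmp")  # assumed set of image extensions


def write_split_list(split_dir, out_txt):
    """Write one 'absolute_path label' line per image found in split_dir/<class>/."""
    # Sort the class folder names so label indices are deterministic
    # (the same alphabetical convention ImageFolder uses).
    classes = sorted(d for d in os.listdir(split_dir)
                     if os.path.isdir(os.path.join(split_dir, d)))
    class_to_idx = {c: i for i, c in enumerate(classes)}
    with open(out_txt, "w", encoding="utf-8") as f:
        for c in classes:
            class_dir = os.path.join(split_dir, c)
            for name in sorted(os.listdir(class_dir)):
                if name.lower().endswith(IMG_EXTS):
                    path = os.path.abspath(os.path.join(class_dir, name))
                    f.write(f"{path} {class_to_idx[c]}\n")
    return class_to_idx


def read_split_list(txt_path):
    """Parse the .txt file back into a list of (absolute_path, int_label) tuples."""
    datalist = []
    with open(txt_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            # Split on the LAST space so paths that contain spaces stay intact.
            path, label = line.rsplit(" ", 1)
            datalist.append((path, int(label)))
    return datalist
```

For example, `write_split_list("./image/train", "train.txt")` followed by `read_split_list("train.txt")` would produce a `datalist` in the shape `MyDataset` expects.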
# Create the dataset
import cv2
import torch


class MyDataset(torch.utils.data.Dataset):  # subclass torch.utils.data.Dataset
    def __init__(self, datalist, transform=None, target_transform=None):
        # datalist: list of (absolute image path, int label) tuples
        super().__init__()
        self.datalist = datalist
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, index):
        # Required: read and return the sample at the given index
        img_file_fullpath, label = self.datalist[index]
        # Read the image. The original used a custom helper,
        # ImagePrehandle(img_file_fullpath).load_img(...), which wraps cv2.imread.
        ori_img = cv2.imread(img_file_fullpath)  # H x W x C uint8 array, BGR order
        if ori_img is None:
            raise FileNotFoundError(f"cannot read image: {img_file_fullpath}")
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)  # convert BGR to RGB
        if self.transform:
            img = self.transform(img)  # optional image transform
        if self.target_transform:
            label = self.target_transform(label)  # optional label transform
        # Whatever is returned here is what each batch yields in the training loop
        return img, label
Usage:
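import torch
from torchvision import transforms
from torchvision.transforms import Lambda

# Image transform: array/PIL image -> tensor, then normalize with ImageNet mean/std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Label transform: int label -> one-hot float vector of length 4 (4 classes assumed)
target_transform = Lambda(lambda y: torch.zeros(4, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
# datalist: list of (absolute image path, int label) tuples, e.g. parsed from the .txt file
train_data = MyDataset(datalist, transform=transform, target_transform=target_transform)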
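Note:
image_pre_handle.load_img reads the image with cv2.imread, which returns a NumPy array (H x W x C, BGR order, hence the cvtColor call). The transforms above work because transforms.ToTensor() runs first and converts that array to a tensor. Otherwise you may need to read the image with PIL.Image.open, since some transforms expect a PIL image.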
Method 2: torchvision.datasets.ImageFolder
ImageFolder is a generic data loader for images stored one folder per class. Usage:
from torchvision.datasets import ImageFolder
dataset = ImageFolder(root="./image/train", transform=None, target_transform=None)
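root: the directory where the images are stored, e.g. './image/train'.
transform: a function that takes the original (PIL) image as input and returns a transformed image.
target_transform: a function that takes the target as input and returns a transformed target. For example, it could map a string annotation to a word index. Note that ImageFolder itself passes the integer class index as the target.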
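The returned dataset has the following attributes:
self.classes: a list of the class names
self.class_to_idx: a dict mapping each class name to its index
self.imgs: a list of (image path, class index) tuples
You can inspect these values on the returned dataset object.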
From the torchvision ImageFolder docstring:
A generic data loader where the images are arranged in this way:
    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/[...]/xxz.png
    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/[...]/asd932_.png
Args:
    root (string): Root directory path.
    transform (callable, optional): A function/transform that takes in a PIL image
        and returns a transformed version. E.g., transforms.RandomCrop
    target_transform (callable, optional): A function/transform that takes in the
        target and transforms it.
    loader (callable, optional): A function to load an image given its path.
    is_valid_file (callable, optional): A function that takes the path of an image file
        and checks whether the file is valid (used to check for corrupt files).
Attributes:
    classes (list): List of the class names sorted alphabetically.
    class_to_idx (dict): Dict with items (class_name, class_index).
    imgs (list): List of (image path, class_index) tuples
3. torch.utils.data.DataLoader
DataLoader is a data loader: during training it splits the dataset into mini-batches and yields one batch at a time until all the data has been returned.
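from torch.utils.data import DataLoader

dataloader = DataLoader(dataset,          # the Dataset to load from (e.g. train_data above)
                        batch_size=1,     # number of samples per batch
                        shuffle=False,    # whether to reshuffle the data every epoch
                        sampler=None,     # strategy for drawing samples; cannot be combined with shuffle=True
                        num_workers=0,    # number of worker subprocesses; 0 loads in the main process (a common choice on Windows)
                        pin_memory=False, # copy batches into page-locked (pinned) host memory for faster transfer to CUDA
                        drop_last=False)  # drop the last incomplete batch if the dataset size is not divisible by batch_size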
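Below is a minimal sketch of batching with made-up data: 10 random 3x4x4 tensors wrapped in `TensorDataset`. It shows the batch shapes and the effect of `drop_last`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 fake samples: 3x4x4 "images" with integer labels 0..9 (made-up data).
images = torch.randn(10, 3, 4, 4)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=False)
for batch_images, batch_labels in loader:
    print(batch_images.shape, batch_labels.shape)
# torch.Size([4, 3, 4, 4]) torch.Size([4])
# torch.Size([4, 3, 4, 4]) torch.Size([4])
# torch.Size([2, 3, 4, 4]) torch.Size([2])

# With drop_last=True the incomplete final batch of 2 samples is discarded.
sizes = [len(y) for _, y in DataLoader(dataset, batch_size=4, drop_last=True)]
print(sizes)  # [4, 4]
```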
| Args | Description |
| --- | --- |
| dataset (Dataset) | dataset from which to load the data. |
| batch_size (int, optional) | how many samples per batch to load (default: 1). |
| shuffle (bool, optional) | set to True to have the data reshuffled at every epoch (default: False). |
| sampler (Sampler or Iterable, optional) | defines the strategy to draw samples from the dataset. Can be any Iterable with `__len__` implemented. If specified, shuffle must not be specified. |
| batch_sampler (Sampler or Iterable, optional) | like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last. |
| num_workers (int, optional) | how many subprocesses to use for data loading. 0 means the data is loaded in the main process (default: 0). |
| collate_fn (callable, optional) | merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset. |
| pin_memory (bool, optional) | if True, the data loader copies Tensors into CUDA pinned (page-locked) memory before returning them (default: False). |
| drop_last (bool, optional) | set to True to drop the last incomplete batch if the dataset size is not divisible by the batch size. If False, the last batch will be smaller (default: False). |
| timeout (numeric, optional) | if positive, the timeout value for collecting a batch from workers. Should always be non-negative (default: 0). |
| worker_init_fn (callable, optional) | if not None, this is called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading (default: None). |
| prefetch_factor (int, optional, keyword-only arg) | number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers (default: 2 when num_workers > 0). |
| persistent_workers (bool, optional) | if True, the data loader does not shut down the worker processes after a dataset has been consumed once. This keeps the workers' Dataset instances alive (default: False). |
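The sketch below uses `sampler` and `collate_fn` together, with made-up data. `SubsetRandomSampler` draws only indices 0 to 7, standing in for a hypothetical training split. `keep_as_list` is a hypothetical collate function that returns the raw list of samples instead of stacked tensors. The last lines confirm that DataLoader rejects `shuffle=True` when a sampler is also given.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10).float(), torch.arange(10))

# Hypothetical split: draw only indices 0..7 (in random order) for training.
# shuffle stays at its default because sampler and shuffle are mutually exclusive.
train_sampler = SubsetRandomSampler(list(range(8)))


def keep_as_list(batch):
    # Hypothetical collate_fn: return the list of (x, y) sample tuples as-is
    # instead of stacking them into batch tensors.
    return batch


loader = DataLoader(dataset, batch_size=3, sampler=train_sampler, collate_fn=keep_as_list)
seen = sorted(int(y) for batch in loader for _, y in batch)
print(seen)  # [0, 1, 2, 3, 4, 5, 6, 7]

try:
    DataLoader(dataset, sampler=train_sampler, shuffle=True)
except ValueError:
    print("ValueError: sampler cannot be combined with shuffle=True")
```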