p y t o r c h 中 的 D a t a L o a d e r 与 D a t a S e t pytorch中的DataLoader与DataSet pytorch中的DataLoader与DataSet
class torch.utils.data.Dataset
决定数据从哪读取,如何读取,进行何种预处理
表示Dataset的抽象类。
所有子类应该override __len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
LoadDataset(data_dir=major_config.train_image, transform=train_transform)
import os import random from PIL import Image from torch.utils.data import Dataset import major_config random.seed(1) # 类别对应表 dict_label = major_config.dict_label # 返回所有图片路径和标签 def get_img_label(data_dir): img_label_list = list() for root, dirs, _ in os.walk(data_dir): # 遍历类别 for sub_dir in dirs: img_names = os.listdir(os.path.join(root, sub_dir)) # img_names = list(filter(lambda x: x.endswith('.png'), img_names)) # 如果改了图片格式,这里需要修改 # 遍历图片 for i in range(len(img_names)): img_name = img_names[i] path_img = os.path.join(root, sub_dir, img_name) label = dict_label[sub_dir] img_label_list.append((path_img, int(label))) return img_label_list # 主要是用来接受索引返回样本用的 class LoadDataset(Dataset): def __init__(self, data_dir, transform=None): # 获取所有图片的路径、label , 和 确定预处理操作 self.img_label_list = get_img_label(data_dir) # img_label_list,在DataLoader中通过index读取样本 self.transform = transform #接受一个索引,返回一个样本 --- img, label def __getitem__(self, index): path_img, label = self.img_label_list[index] img = Image.open(path_img).convert('RGB') # 0~255 if self.transform is not None: img = self.transform(img) # 在这里做transform,转为tensor等等 return img, label def __len__(self): return len(self.img_label_list)
__getitem__的主要作用
主要是用来接受索引返回样本用的。(Sample :Index 生成索引)
__len__的主要作用
__getitem__接受索引的范围就是__len__里确定的范围
class torch.utils.data.DataLoader
class torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False )
- dataset (Dataset) – 加载数据的数据集。
- batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
- shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
- sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
- num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
- drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)。即当样本数不能被batchsize整除时,是否舍弃最后一批数据
_SingleProcessDataLoaderIter
def _next_data(self):
_sampler_iter
sampler
def fetch(self, possibly_batched_index): if self.auto_collation: data = [self.dataset[idx] for idx in possibly_batched_index] else: data = self.dataset[possibly_batched_index] return self.collate_fn(data)
def _next_data(self): index = self._next_index() # may raise StopIteration data = self._dataset_fetcher.fetch(index) # may raise StopIteration if self._pin_memory: data = _utils.pin_memory.pin_memory(data) return data