p y t o r c h 中 的 D a t a L o a d e r 与 D a t a S e t pytorch中的DataLoader与DataSet pytorch中的DataLoader与DataSet


class torch.utils.data.Dataset

决定数据从哪读取,如何读取,进行何种预处理

表示Dataset的抽象类。
所有子类应该override __len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。

LoadDataset(data_dir=major_config.train_image, transform=train_transform)
import os
import random
from PIL import Image
from torch.utils.data import Dataset
import major_config
random.seed(1)

# 类别对应表
dict_label = major_config.dict_label

# 返回所有图片路径和标签
def get_img_label(data_dir):
    img_label_list = list()
    for root, dirs, _ in os.walk(data_dir):
        # 遍历类别
        for sub_dir in dirs:
            img_names = os.listdir(os.path.join(root, sub_dir))
            # img_names = list(filter(lambda x: x.endswith('.png'), img_names))   # 如果改了图片格式,这里需要修改
            # 遍历图片
            for i in range(len(img_names)):
                img_name = img_names[i]
                path_img = os.path.join(root, sub_dir, img_name)
                label = dict_label[sub_dir]
                img_label_list.append((path_img, int(label)))
    return img_label_list

# 主要是用来接受索引返回样本用的
class LoadDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        # 获取所有图片的路径、label , 和 确定预处理操作
        self.img_label_list = get_img_label(data_dir)  # img_label_list,在DataLoader中通过index读取样本
        self.transform = transform

    #接受一个索引,返回一个样本 ---  img, label
    def __getitem__(self, index):
        path_img, label = self.img_label_list[index]
        img = Image.open(path_img).convert('RGB')     # 0~255
        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等
        return img, label

    def __len__(self):
        return len(self.img_label_list)




__getitem__的主要作用

主要是用来接受索引返回样本用的。(Sample :Index 生成索引)

__len__的主要作用

__getitem__接受索引的范围就是__len__里确定的范围


class torch.utils.data.DataLoader

class torch.utils.data.DataLoader(  dataset, 
									batch_size=1,
									shuffle=False,
									sampler=None, 
									num_workers=0, 
									collate_fn=<function default_collate>,
									pin_memory=False, 
									drop_last=False
)
  • dataset (Dataset) – 加载数据的数据集。
  • batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
  • shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
  • sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
  • num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
  • drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)。即当样本数不能被batchsize整除时,是否舍弃最后一批数据

_SingleProcessDataLoaderIter

def _next_data(self):

_sampler_iter

sampler

pytorch中的DataLoader与DataSet_数据集

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data