迭代器
理解 Python 的迭代器是解读 PyTorch 中 torch.utils.data
模块的关键。
在 Dataset
, Sampler
和 DataLoader
这三个类中都会用到 python 抽象类的魔法方法,包括__len__(self)
,__getitem__(self)
和 __iter__(self)
-
__len__(self)
: 定义当被len()
函数调用时的行为,一般返回迭代器中元素的个数 -
__getitem__(self)
: 定义获取容器中指定元素时的行为,相当于self[key]
,即允许类对象拥有索引操作 -
__iter__(self)
: 定义当迭代容器中的元素时的行为
迭代的意思类似于循环,每一次重复的过程被称为一次迭代的过程,而每一次迭代得到的结果会被用来作为下一次迭代的初始值。提供迭代方法的容器称为迭代器,通常接触的迭代器有序列(列表、元组和字符串)还有字典,这些数据结构都支持迭代操作。
实现迭代器的魔法方法有两个:__iter__(self)
和 __next__(self)
一个容器如果是迭代器,那就必须实现 __iter__(self)
魔法方法,这个方法实际上是返回是一个迭代器(通常是迭代器本身)。接下来重点要实现的是 __next__(self)
魔法方法,因为它决定了迭代的规则。
class Fibs:
def __init__(self, n=20):
self.a = 0
self.b = 1
self.n = n
def __iter__(self):
return self
def __next__(self):
self.a, self.b = self.b, self.a + self.b
if self.a > self.n:
raise StopIteration
return self.a
fibs = Fibs()
for each in fibs:
print(each)
# 输出
# 1 1 2 3 5 8 13
一般来说,迭代器满足以下几种特性:
- 迭代器是⼀个对象
- 迭代器可以被 next() 函数调⽤,并返回⼀个值
- 迭代器可以被 iter() 函数调⽤,并返回一个迭代器(可以是自身)
- 连续被 next() 调⽤时依次返回⼀系列的值
- 如果到了迭代的末尾,则抛出 StopIteration 异常
- 迭代器也可以没有末尾,只要被 next() 调⽤,就⼀定会返回⼀个值
- Python 中, next() 内置函数调⽤的是对象的 next() ⽅法
- Python 中, iter() 内置函数调⽤的是对象的 iter() ⽅法
- ⼀个实现了迭代器协议的的对象可以被 for 语句循环迭代直到终⽌
了解了什么是迭代器后,我们就可以开始解读 torch.utils.data
模块
对于 torch.utils.data
而言,重点是其 Dataset
, Sampler
, DataLoader
模块,辅以 collate
, fetch
, pin_memory
等组件对特定功能予以支持。
1 Dataset
Dataset 负责对 raw data source 封装,将其封装成 Python 可识别的数据结构,其必须提供提取数据个体的接口。
Dataset 共有 Map-style datasets 和 Iterable-style datasets 两种:
1.1 Map-style dataset
torch.utils.data.Dataset
它是一种通过实现 __getitem__()
和 __len()__
来获取数据的 Dataset,它表示从(可能是非整数)索引/关键字到数据样本的映射。访问时,这样的数据集用 dataset[idx]
访问 idx
对应的数据。
通常我们使用 Map-style 类型的 dataset 居多,其数据接口定义如下:
class Dataset(Generic[T_co]):
# Generic is an Abstract base class for generic types.
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
PyTorch 中所有定义的 Dataset 都是其子类。
对于一般计算机视觉任务,我们通常会在其中进行一些 resize, crop, flip 等预处理的操作
值得一提的是,PyTorch 源码中并没有提供默认的 __len__()
方法实现,原因是 return NotImplemented
或者 raise NotImplementedError()
之类的默认实现都会存在各自的问题,这点在其源码中也有注释加以体现。
1.2 Iterable-style dataset
torch.utils.data.IterableDataset
它是一种实现 __iter__()
来获取数据的 Dataset,这种类型的数据集特别适用于以下情况:随机读取代价很大甚至不大可能,且 batch size 取决于获取的数据。其接口定义如下:
class IterableDataset(Dataset[T_co]):
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
def __add__(self, other: Dataset[T_co]):
return ChainDataset([self, other])
特别地,当 DataLoader
的 num_workers > 0
时, 每个 worker 都将具有数据对象的不同样本。因此需要独立地对每个副本进行配置,以防止每个 worker 产生的数据不重复。同时,数据加载顺序完全由用户定义的可迭代样式控制。这允许更容易地实现块读取和动态批次大小(例如,通过每次产生一个批次的样本)
1.3 其他 Dataset
除了 Map-style dataset 和 Iterable-style dataset 以外,PyTorch 也在此基础上提供了其他类型的 Dataset 子类
-
torch.utils.data.ConcatDataset
: 用于连接多个ConcatDataset
数据集 -
torch.utils.data.ChainDataset
: 用于连接多个IterableDataset
数据集,在IterableDataset
的__add__()
方法中被调用 -
torch.utils.data.Subset
: 用于获取指定一个索引序列对应的子数据集
class Subset(Dataset[T_co]):
dataset: Dataset[T_co]
indices: Sequence[int]
def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
self.dataset = dataset
self.indices = indices
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
def __len__(self):
return len(self.indices)
-
torch.utils.data.TensorDataset
: 用于获取封装成 tensor 的数据集,每一个样本都通过索引张量来获得。
class TensorDataset(Dataset):
def __init__(self, *tensor):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in tensors
def __len__(self):
return self.tensors[0].size(0)
2 Sampler
torch.utils.data.Sampler
负责提供一种遍历数据集所有元素索引的方式。可支持用户自定义,也可以用 PyTorch 提供的,基类接口定义如下:
lass Sampler(Generic[T_co]):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
way to iterate over indices of dataset elements, and a :meth:`__len__` method
that returns the length of the returned iterators.
.. note:: The :meth:`__len__` method isn't strictly required by
:class:`~torch.utils.data.DataLoader`, but is expected in any
calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
"""
def __init__(self, data_source: Optional[Sized]) -> None:
pass
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
特别地,__len__()
方法不是必要的,但是当 DataLoader 需要计算 len()
的时候必须定义,这点在其源码中也有注释加以体现。
同样,PyTorch 也在此基础上提供了其他类型的 Sampler 子类
-
torch.utils.data.SequentialSampler
: 顺序采样样本,始终按照同一个顺序 -
torch.utils.data.RandomSampler
: 可指定有无放回地,进行随机采样样本元素 -
torch.utils.data.SubsetRandomSampler
: 无放回地按照给定的索引列表采样样本元素 -
torch.utils.data.WeightedRandomSampler
: 按照给定的概率来采样样本。样本元素来自[0,…,len(weights)-1]
, 给定概率(权重) -
torch.utils.data.BatchSampler
: 在一个batch中封装一个其他的采样器, 返回一个 batch 大小的 index 索引 -
torch.utils.data.DistributedSample
: 将数据加载限制为数据集子集的采样器。与torch.nn.parallel.DistributedDataParallel
结合使用。 在这种情况下,每个进程都可以将DistributedSampler
实例作为DataLoader
采样器传递
3 DataLoader
torch.utils.data.DataLoader
是 PyTorch 数据加载的核心,负责加载数据,同时支持 Map-style 和 Iterable-style Dataset,支持单进程/多进程,还可以设置 loading order, batch size, pin memory 等加载参数。其接口定义如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
对于每个参数的含义,以下给出一个表格进行对应介绍:
attribute | meaning | default value | type |
dataset | 加载数据的数据集 | | Dataset |
batch_size | 每个 batch 加载多少个样本 | 1 | int |
shuffle | 设置为 True 时,调用 RandomSampler 进行随机索引 | False | bool |
sampler | 定义从数据集中提取样本的策略 如果指定了, shuffle 参数必须为 False,(否则会和 RandomSampler 互斥) | None | Sampler, Iterable |
batch_sampler | 和 sampler 类似,但是一般传入 BatchSampler,每次返回一个 batch 大小的索引 其和 batch_size, shuffle 等参数是互斥的 | None | Sampler, Iterable |
num_workers | 要用于数据加载的子进程数,0 表示将在主进程中加载数据 | 0 | int |
collate_fn | 在将 Map-style datase t 取出的数据整合成 batch 时使用,合并样本列表以形成一个 batch | None | callable |
pin_memory | 如果为 True,则 DataLoader 在将张量返回之前将其复制到 CUDA 固定的内存中 | False | bool |
drop_last | 设置为 True 删除最后一个不完整的批次,如果该数据集大小不能被该批次大小整除。 如果 False 并且数据集的大小不能被批次大小整除,那么最后一批将较小 | False | bool |
timeout | 如果为正,则为从 worker 收集 batch 的超时值,应始终为非负数 超过这个时间还没读取到数据的话就会报错 | 0 | numeric |
worker_init_fn | 如果不为 None,它将会被每个 worker 子进程调用, 以 worker id ([0, num_workers - 1] 内的整形) 为输入 | None | callable |
prefetch_factor | 每个 worker 提前加载 的 sample 数量 | 2 | int |
persistent_workers | 如果为 True,dataloader 将不会终止 worker 进程,直到 dataset 迭代完成 | False | bool |
从参数定义中,我们可以看到 DataLoader 主要支持以下几个功能
- 支持加载
map-style
和iterable-style
的 dataset,主要涉及到的参数是dataset
- 自定义数据加载顺序,主要涉及到的参数有
shuffle
,sampler
,batch_sampler
,collate_fn
- 自动把数据整理成batch序列,主要涉及到的参数有
batch_size
,batch_sampler
,collate_fn
,drop_last
- 单进程和多进程的数据加载,主要涉及到的参数有
num_workers
,worker_init_fn
- 自动进行锁页内存读取 (memory pinning),主要涉及到的参数
pin_memory
- 支持数据预加载,主要涉及的参数
prefetch_factor
3.1 三者关系 (Dataset, Sampler, Dataloader)
通过以上介绍的三者工作内容不难推出其内在关系:
- 设置 Dataset,将数据 data source 包装成 Dataset 类,暴露提取接口。
- 设置 Sampler,决定采样方式。我们是能从 Dataset 中提取元素了,还是需要设置 Sampler 告诉程序提取 Dataset 的策略。
- 将设置好的 Dataset 和 Sampler 传入 DataLoader,同时可以设置
shuffle
,batch_size
等参数。使用 DataLoader 对象可以方便快捷地在数据集上遍历。
总结来说,即 Dataloader 负责总的调度,命令 Sampler 定义遍历索引的方式,然后用索引去 Dataset 中提取元素。于是就实现了对给定数据集的遍历。