主要内容
- 1 Dataset
- 2 Sampler
- 3 DataLoader
- 3.1 三者关系 (Dataset, Sampler, Dataloader)
- 3.2 批处理
- 3.2.1 自动批处理(默认)
- 3.2.3 collate_fn
- 3.3 多进程处理 (multi-process)
- 4 预取 (prefetch)
- 5 代码详解
本篇博文主要用来记录参考链接中的所学重要知识,梳理清楚。
1 Dataset
Dataset 负责对 raw data source 封装,将其封装成 Python 可识别的数据结构,其必须提供提取数据个体的接口。
- Map-style sataset
torch.utils.data.Dataset
它是一种通过实现 getitem() 和 len() 来获取数据的 Dataset,它表示从(可能是非整数)索引/关键字到数据样本的映射。访问时,这样的数据集用 dataset[idx] 访问 idx 对应的数据。 - Iterable-style dataset
- 其他Dataset
torch.utils.data.TensorDataset
: 用于获取封装成 tensor 的数据集,每一个样本都通过索引张量来获得。
2 Sampler
torch.utils.data.Sampler
负责提供一种遍历数据集所有元素索引的方式。
特别地,len() 方法不是必要的,但是当 DataLoader 需要计算 len() 的时候必须定义,这点在其源码中也有注释加以体现。
同样,PyTorch 也在此基础上提供了其他类型的 Sampler 子类
torch.utils.data.SequentialSampler
: 顺序采样样本,始终按照同一个顺序torch.utils.data.RandomSampler
: 可指定有无放回地,进行随机采样样本元素torch.utils.data.BatchSampler
: 在一个batch中封装一个其他的采样器, 返回一个 batch 大小的 index 索引
3 DataLoader
torch.utils.data.DataLoader
是 PyTorch 数据加载的核心,负责加载数据,同时支持 Map-style 和 Iterable-style Dataset,支持单进程/多进程,还可以设置 loading order, batch size, pin memory 等加载参数。其接口定义如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
对于每个参数的含义,以下给出一个表格进行对应介绍:
attribute | meaning | default value | type |
dataset | 加载数据的数据集 | Dataset | |
batch_size | 每个 batch 加载多少个样本 | 1 | int |
shuffle | 设置为 True 时,调用 RandomSampler 进行随机索引 | False | bool |
sampler | 定义从数据集中提取样本的策略。如果指定了, shuffle 参数必须为 False,(否则会和 RandomSampler 互斥) | None | Sampler, Iterable |
batch_sampler | 和 sampler 类似,但是一般传入 BatchSampler,每次返回一个 batch 大小的索引。其和 batch_size, shuffle 等参数是互斥的 | None | Sampler, Iterable |
num_workers | 要用于数据加载的子进程数,0 表示将在主进程中加载数据 | 0 | int |
collate_fn | 在将 Map-style dataset 取出的数据整合成 batch 时使用,合并样本列表以形成一个 batch | None | callable |
pin_memory | 如果为 True,则 DataLoader 在将张量返回之前将其复制到 CUDA 固定的内存中 | False | bool |
drop_last | 设置为 True 删除最后一个不完整的批次,如果该数据集大小不能被该批次大小整除。如果 False 并且数据集的大小不能被批次大小整除,那么最后一批将较小 | False | bool |
timeout | 如果为正,则为从 worker 收集 batch 的超时值,应始终为非负数。超过这个时间还没读取到数据的话就会报错 | 0 | numeric |
worker_init_fn | 如果不为 None,它将会被每个 worker 子进程调用,以 worker id ([0, num_workers - 1] 内的整形) 为输入 | None | callable |
prefetch_facto | 每个 worker 提前加载 的 sample 数量 | 2 | int |
persistent_workers | 如果为 True,dataloader 将不会终止 worker 进程,直到 dataset 迭代完成 | False | bool |
3.1 三者关系 (Dataset, Sampler, Dataloader)
- 设置 Dataset,将数据 data source 包装成 Dataset 类,暴露提取接口。
- 设置 Sampler,决定采样方式。我们是能从 Dataset 中提取元素了,还是需要设置 Sampler 告诉程序提取 Dataset 的策略。
- 将设置好的 Dataset 和 Sampler 传入 DataLoader,同时可以设置 shuffle, batch_size 等参数。使用 DataLoader 对象可以方便快捷地在数据集上遍历。
总结来说,即 Dataloader 负责总的调度,命令 Sampler 定义遍历索引的方式,然后用索引去 Dataset 中提取元素。于是就实现了对给定数据集的遍历。
3.2 批处理
3.2.1 自动批处理(默认)
DataLoader 支持通过参数batch_size, drop_last, batch_sampler,自动地把取出的数据整理 (collate) 成批次样本 (batch)
batch_size 和 drop_last 参数用于指定 DataLoader 如何获取 dataset 的 key。特别地,对于 map-style 类型的 dataset,用户可以选择指定 batch_sample参数,一次就生成一个 keys list
在使用 sampler 产生的 indices 获取采样到的数据时,DataLoader 使用 collate_fn 参数将样本列表整理成 batch。抽象这个过程,其表示方式大致如下
# For Map-style
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
# For Iterable-style
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
3.2.3 collate_fn
当关闭自动批处理 (automatic batching) 时,collate_fn 作用于单个数据样本,只是在 PyTorch 张量中转换 NumPy 数组。
当开启自动批处理 (automatic batching) 时,collate_fn 作用于数据样本列表,将输入样本整理为一个 batch,一般做下面 3 件事情
- 添加新的批次维度(一般是第一维)
- 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量
- 它保留数据结构,例如,如果每个样本都是 dict,则输出具有相同键集但批处理过的张量作为值的字典(或list,当不能转换的时候)。list, tuples, namedtuples 同样适用
自定义 collate_fn 可用于自定义排序规则,例如,将顺序数据填充到批处理的最大长度,添加对自定义数据类型的支持等。
3.3 多进程处理 (multi-process)
为了避免在加载数据时阻塞计算代码,PyTorch 提供了一个简单的开关,只需将参数设置 num_workers 为正整数即可执行多进程数据加载,设置为 0 时执行单线程数据加载。
4 预取 (prefetch)
DataLoader 通过指定 prefetch_factor (默认为 2)来进行数据的预取。
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
...
self._reset(loader, first_iter=True)
def _reset(self, loader, first_iter=False):
...
# prime the prefetch loop
for _ in range(self._prefetch_factor * self._num_workers):
self._try_put_index()
通过源码可以看到,prefetch 功能仅适用于 多进程 加载中(下面会由多进程 dataloader 的代码分析)
5 代码详解
来看看具体的代码调用流程:
for data, label in train_loader:
......
for 循环会调用 dataloader 的 iter(self) 方法,以此获得迭代器来遍历 dataset
class DataLoader(Generic[T_co]):
...
def __iter__(self) -> '_BaseDataLoaderIter':
if self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else:
return self._get_iterator()
在 iter(self) 方法中,dataloader
调用了self._get_iterator()
方法,根据 num_worker
获得迭代器,并指示进行单进程还是多进程
class DataLoader(Generic[T_co]):
...
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
为了描述清晰,我们只考虑单进程的代码。下面是 class _SingleProcessDataLoaderIter(_BaseDataLoaderIter)
,以及其父类 class _BaseDataLoaderIter(object):
的重点代码片段:
class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
# 初始化赋值一些 DataLoader 参数,
# 以及用户输入合法性进行校验
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._index_sampler = loader._index_sampler
...
def __iter__(self) -> '_BaseDataLoaderIter':
return self
def _reset(self, loader, first_iter=False):
self._sampler_iter = iter(self._index_sampler)
self._num_yielded = 0
self._IterableDataset_len_called = loader._IterableDataset_len_called
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
def _next_data(self):
raise NotImplementedError
def __next__(self) -> Any:
with torch.autograd.profiler.record_function(self._profile_name):
if self._sampler_iter is None:
self._reset()
data = self._next_data() # 重点代码行,通过此获取数据
self._num_yielded += 1
...
return data
next = __next__ # Python 2 compatibility
def __len__(self) -> int:
return len(self._index_sampler) # len(_BaseDataLoaderIter) == len(self._index_sampler)
def __getstate__(self):
raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
_BaseDataLoaderIter
是所有 DataLoaderIter
的父类。dataloader获得了迭代器之后,for 循环需要调用 __next__()
来获得下一个对象,从而实现遍历。通过 __next__
方法调用 _next_data()
获取数据
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
从 _SingleProcessDataLoaderIter 的初始化参数可以看到,其在父类 _BaseDataLoaderIter 的基础上定义了 _dataset_fetcher, 并传入 _dataset, _auto_collation, _collate_fn 等参数,用于定义获取数据的方式。其具体实现会在稍后解释。
在 _next_data() 被调用后,其需要 next_index() 获取 index,并通过获得的 index 传入 _dataset_fetcher 中获取对应样本
class DataLoader(Generic[T_co]):
...
@property
def _auto_collation(self):
return self.batch_sampler is not None
@property
def _index_sampler(self):
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
class _BaseDataLoaderIter(object):
...
def _reset(self, loader, first_iter=False):
self._sampler_iter = iter(self._index_sampler)
...
def _next_index(self):
# sampler_iter 来自于 index_sampler
return next(self._sampler_iter) # may raise StopIteration
从这里看出,dataloader 提供了 sampler (可以是batch_sampler 或者是其他 sampler 子类),然后 _SingleProcessDataLoaderIter
迭代sampler获得索引
下面我们来看看 fetcher,fetcher 需要 index 来获取元素,并同时支持 Map-style dataset(对应 _MapDatasetFetcher
)和 Iterable-style dataset(对应 _IterableDatasetFetcher
),使其在Dataloader内能使用相同的接口 fetch,代码更加简洁。
- 对于 Map-style:直接输入索引 index,作为 map 的 key,获得对应的样本(即 value)
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
# 有batch_sampler,_auto_collation就为True,
# 就优先使用batch_sampler,对应在fetcher中传入的就是一个batch的索引
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
- 对于 Iterable-style: init 方法内设置了 dataset 初始的迭代器,fetch 方法内获取元素,index 其实已经没有多大作用了
class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
self.dataset_iter = iter(dataset)
def fetch(self, possibly_batched_index):
if self.auto_collation:
# 对于batch_sampler(即auto_collation==True)
# 直接使用往后遍历并提取len(possibly_batched_index)个样本(即1个batch的样本)
data = []
for _ in possibly_batched_index:
try:
data.append(next(self.dataset_iter))
except StopIteration:
break
if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
raise StopIteration
else:
# 对于sampler,直接往后遍历并提取1个样本
data = next(self.dataset_iter)
return self.collate_fn(data)
最后,我们通过索引传入 fetcher,fetch 得到想要的样本。
因此,整个过程调用关系总结 如下:
loader.__iter__
--> self._get_iterator()
--> class _SingleProcessDataLoaderIter
--> class _BaseDataLoaderIter
--> __next__()
--> self._next_data()
--> self._next_index()
-->next(self._sampler_iter)
即 next(iter(self._index_sampler))
--> 获得 index --> self._dataset_fetcher.fetch(index)
--> 获得 data
参考链接:PyTorch 源码解读之 torch.utils.data:解析数据处理全流程
问题记录:
1.pytorch 中的Dataset这个类为什么可以调用__getitem__?特殊方法名称