文章目录

  • DataLoaderIter && DataLoader
  • Dataset



因为每次和数据打交道,天天可以碰到

torch.utils.data.Dataset, torch.utils.data.DataLoader

我看到的代码 都是一步一步封装,首先定义数据增强的措施,然后把这些措施封装到预处理中(这里用到了torchvision.transforms),定义好预处理后,就应该采样了sample(这里继承了Dataset,一般用RandomSampler,BatchSampler,SequentialSampler),里面封装了预处理,接着进行数据的加载loader(这里继承了Dataloader),这里封装的就是之前定义的sample,最后dataloader封装进了dataloaderIter中,进行逐步迭代。当然还需要对数据集进行定义,编写需要的方法(这里继承了Dataset)。 当程序开始执行的时候,它一步一步倒着去执行,依次遍历,首先 根据得到batch索引,然后根据索引得到数据,接着进行处理,最终得到所需要的数据。

DataLoaderIter && DataLoader

Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
一般先进行定义:

loader_train = LTRLoader('train', dataset_train, 
                            training=True, 
                            batch_size=settings.batch_size,
                            num_workers=settings.num_workers,
                            shuffle=True, drop_last=True, stack_dim=1)

其中dataset_train,就是Dataset类型,也就是定义的sampler,这里最主要的是重写__ getitem__()方法,得到数据。
一般运行代码:
for i,data in enumenrate(loader,1): 就自动跳到了dataloader.pyclass DataLoader()中的 def __ iter 方法,

pytorch如何打印模型参数_python


在这里选择使用单线程还是多线程进行数据的迭代,这里的MultiProcessingDataLoaderIter( Iterates once over the DataLoader’s dataset, as specified by the sampler)继承的是BaseDataLoaderIter,开始初始化,然后Dataloader进行初始化,然后进入
next __()方法 随机生成索引,进而生成batch,最后调用 _get_data() 方法得到data。idx, data = self._get_data(), data = self.data_queue.get(timeout=timeout)
这里用到了
队列

总结一下: 1.调用了dataloader 的__iter__() 方法, 产生了一个DataLoaderIter
2.反复调用DataLoaderIter 的__next__()来得到batch, 具体操作就是, 多次调用dataset的__getitem__()方法 (如果num_worker>0就多线程调用), 然后用collate_fn来把它们打包成batch. 中间还会涉及到shuffle , 以及sample 的方法等,
3当数据读完后, next()抛出一个StopIteration异常, for循环结束, dataloader 失效.

因此dataloader作用:

1.定义了一堆成员变量, 到时候赋给DataLoaderIter,

2.然后有一个__iter__() 函数, 把自己 “装进” DataLoaderIter 里面.

以下是DataLoader()的参数:

pytorch如何打印模型参数_数据_02


其中shuffle参数和sampler参数相关,sampler option is mutually exclusive with shuffle

Dataloader中存在一个默认的collate_fn函数

pytorch如何打印模型参数_python_03


需要根据自己的需求重写collate_fn函数,该函数的作用是将得到的数据整理成一个batch。

代码为:先判断batch数据的类型,然后形成batch

pytorch如何打印模型参数_预处理_04

Dataset

基本的Dataset类如下:有__init__() 、__ getitem__()、__ len__()、__ iter__() 方法

pytorch如何打印模型参数_数据_05


当然你定义的数据集是需要继承此类,并且覆写 _ init_() 、__ getitem__()、_ len_() 方法,甚至自己实现get_frames(),get_sequence_info()等方法

如果你要加载你的数据也需要继承此类,并且覆写 _ init_() 、__ getitem__()、_ len_() 方法,可以调用数据集中自己写的方法


第一个初始化,定义你用于训练的数据集,以什么比例进行sample(多个数据集的情况),每个epoch训练样本的数目,预处理方法等等

第二个是根据索引得到所需要的数据

video_keys=self.videos.keys() 
video = self.videos[video_keys[rand_vid]] #rand_vid为索引
video_ids=video[0]
video_id_keys=video.ids.keys()
rand_trackid_z = np.random.choice(list(range(len(video_id_keys))))
  
  #simafc中经过一系列的路径,文件名,进行随机选择 需要的图片,如果有预处理方法,再得到图片后进行预处理。有时候除了加载所需要的图片,还要加载真值

第三个是返回数据的长度

要想进行下一步的操作,读取正确的数据,并进行一定的处理是很重要的。