目录
- 1.Dataset基类
- 2.用DataLoader类实现自定义数据集
- 2.1DataLoader类的定义
- 3.DataLoader类中的多采样器子类
- 4.Torchtext工具与内置数据集
- 4.1Torchtext的内部结构
- 4.2安装Torchtext库
- 4.3查看Torchtext库的内置数据集
- 4.4安装Torchtext库的调用模块
- 4.5Torchtext库的内置预训练词向量
1.Dataset基类
在PyTorch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,我们能够非常快速的实现对数据的加载。我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法。
- __len__方法,能够实现通过全局的len()方法获取其中的元素个数
- __getitem__方法,能够通过传入索引的方式获取数值,例如通过dataset[i]获取其中的第i条数据
2.用DataLoader类实现自定义数据集
在PyTorch中,用torch.utils.data.DataLoader类可以构建带有批次的数据集。
2.1DataLoader类的定义
class DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0,
collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None)
具体参数解读如下:
- dataset:待加载的数据
- batch_size:每批次加载的样本数量,默认是1
- shuffle:是否把样本的顺序打乱。默认是False,表示不打乱样本的顺序
- sampler:接收一个采样器对象,用于按照指定的样本提取策略从数据集中提取样本。如果指定,则忽略shuffle参数
- num_workers:设置加载数据的额外进程数量。默认是0,表示不额外启动进程来加载数据,直接使用主进程对数据进行加载
- collate_fn:接收一个自定义函数。当该参数不为None时,系统会先从数据集中取出数据,然后将数据传入collate_fn中,由collate_fn参数所指向的函数对数据进行二次加工。collate_fn常用于测试和训练场景中对同一个数据集进行数据提取
- pin_memory:在数据返回前,是否将数据复制到CUDA内存中。默认值为False
- drop_last:是否丢弃最后数据,默认值是False,表示不丢弃。在样本总数不能被batch size整除的情况下,如果该值为True,则丢弃最后一个满足一个批次数量的数据;如果该值为False,则将最后不足一个批次数量的数据返回
- timeout:读取数据的超时时间,默认值为0。当超时设置时间还没读到数据是,系统就会报错
- worker_init_fn:每个子进程的初始化函数,在加载数据之前运行
- multiprocessing_context:多进程处理的配置参数
3.DataLoader类中的多采样器子类
DataLoader类是一个非常强大的数据集处理类。它几乎可以覆盖数据集的任何使用场景,在PyTorch程序中也非常常用。其中与DataLoader类配套的还有采样器sampler类,该类又派生了多个采样器子类,同时支持自定义采样器类。其中,内置的采样器子类有如下几种
- SequentialSampler:按照原有的样本顺序进行采样
- RandomSampler:按照随机顺序进行采样,可以设置是否重复采样
- SubsetRandomSampler:按照指定的集合或索引列表进行随机顺序采样
- WeightedRandomSampler:按照指定的概率进行随机顺序采样
- BatchSampler:按照指定的批次索引进行采样
4.Torchtext工具与内置数据集
4.1Torchtext的内部结构
Torchtext对数据的处理可以概括为Field、Dataset和迭代器这3部分
- Field:要如何处理某个字段
- Dataset:定义数据源信息
- 迭代器:返回模型所需要的处理后的数据。主要分为以下3种
- Iterator:标准迭代器
- BucketIerator:相比于标准迭代器,它会将类似长度的样本当作一批来处理。由于在文本处理中经常需要将“每一批样本长度”补齐为“当前批次中最长序列的长度”,因此当样本长度差别较大时,使用BucketIterator可以带来填充效率的提高。除此之外,还可以在Field中通过fix_length参数来对样本进行截断补齐操作
- BPTTIterator:基于BPTT(基于时间的反向传播算法)的迭代器,一般用于语言模型中。
4.2安装Torchtext库
pip install torchtext
4.3查看Torchtext库的内置数据集
在安装好Torchtext库后,可以在如下的路径中查看Torchtext库的内置数据集
本地pip安装包路径\Lib\site-packages\torchtext\datases\__init__.py
4.4安装Torchtext库的调用模块
在使用Torchtext库过程中,如果要间接使用其他的文本处理库,则还需要额外下载,例如,使用字段处理库的代码如下:
from torchtext import data
TEXT = data.Field(tokenize='spacy')
在调用Torchtext库的data.Field()函数时,可以向tokenize参数传入“revtok”、“subword”、“spacy”、“moses”字符串,表示分别使用revtok、NLTK、en模块的SpaCy库、sacremoses库进行字段处理,这些库都需要单独安装。
4.5Torchtext库的内置预训练词向量
Torchtext库中内置了若干个预训练词向量,可以在模型中直接用来对本地的权重进行初始化
charngram.100d、fasttext.en.300d、fasttext.simple.300d、glove.42B.300d、glove.840B.300d、glove.twitter.27B.25d、glove.twitter.27B.50d、glove.twitter.27B.100d、glove.twitter.27B.200d、glove.6B.50d、glove.6B.100d、glove.6B.200d、glove.6B.300d
这些词向量,前部分的名称表明其在训练时所用的模型,后部分都是“数字+d”的形式,代表将词映射成词向量的维度。这种本来就带有语义的词向量,可以大大加快模型的训练速度。