PyTorch DataLoader 源代码 - 调试阶段

在本集中,我们将继续上集数据标准化的地方。只是这一次,我们将要调试代码,而不是编写代码,尤其是要调试PyTorch源代码,以查看规范化数据集时到底发生了什么。

stm32上安装python stm32 pytorch_机器学习

调试PyTorch源代码的简短程序

在我们开始调试之前,我们只想给我们快速概述一下我们编写的程序,这将使我们能够逐步看到数据集的归一化,并看到它在hood和PyTorch下面到底是如何完成的。

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn

from torch.utils.data import DataLoader

正如我们在上一集所讨论的,我们有平均值和标准差值。现在,我们不需要计算这些,我们只需要把这些值拉出来,然后硬编码到程序里。

mean = 0.2860347330570221
std = 0.3530242443084717

我们不想费尽周折地重新计算这些值,所以我们在这里对它们进行了硬涂层。我们有了平均值和标准差,我们知道我们需要这两个值来对数据集的每个成员或每个像素进行归一化。

接下来,我们使用FashionMNIST类构造函数初始化我们的训练集。这里需要注意的关键点是transforms

train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
        , transforms.Normalize(mean, std)
    ])
)

transforms第一个配置是将pill image转换为张量,然后第二个配置是归一化变换,它要对我们的数据进行归一化。我们的目标是在源代码中验证这个特殊的变换是如何工作的。

最后,我们创建一个DataLoader并使用它。

loader = DataLoader(train_set, batch_size=1)
image, label = next(iter(loader))

调试PyTorch源代码

好了,现在我们准备好实际调试了。为了调试,我们将继续前进,并确保我们已经选择了我的python运行配置,然后我们将点击,开始调试。
使用此链接可以访问PyTorch DataLoader类的当前源代码。本讨论假定使用PyTorch版本1.5.0

The Sampler: To Shuffle Or Not

采样器是获取索引值的对象,它将用于从底层数据集中获取实际值。我们可以看到,有两个特殊的采样器是相关的,随机采样器和顺序采样器。如果shuffle值为,则true采样器将为随机采样器,否则为连续采样器。

批次大小的使用方式

我们发现采样器用于在以下代码中收集索引值:

def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch

在这里,我们可以看到batch_size参数在起作用,因为它限制了所收集索引值的数量。
请注意,yield此处的关键字注意,使得这个迭代器变成了所谓的generator
获取索引值后,它们将通过以下方式用于获取数据:

def fetch(self, possibly_batched_index):
    if self.auto_collation:
        data = [self.dataset[idx] for idx in possibly_batched_index]
    else:
        data = self.dataset[possibly_batched_index]
    return self.collate_fn(data)

就像这样,从底层数据集中提取每个样本的工作。

data = [self.dataset[idx] for idx in possibly_batched_index]

这种语法或符号被称为列表理解。
这将返回一个数据元素的列表,然后使用 collate_fn()方法提取并放入一个单一的批量张量。

标准化数据集

最后,我们发现返回给批处理的每个元素都使用normalize()功能性api 的功能进行了规范化。

def normalize(tensor, mean, std, inplace=False):
    """Normalize a tensor image with mean and standard deviation.
    tensor.sub_(mean).div_(std)
    return tensor

请注意,数据集类调用了一个转换,然后调用了功能性api。我们还遇到了一些错误的设计,需要进行一些修改才能保持一致。

请注意,此处使用“ 黑客 ”一词是指我们看到代码在进行不必要的转换。
在本集中,我们调试了PyTorch DataLoader,以查看如何从PyTorch数据集中提取数据并对其进行规范化。我们看到了几个构造函数参数的影响,并看到了如何构建批处理。