1.下载数据集

使用torchvision.datasets来下载数据集

  • root 用来指定下载后保存的位置(如果已经存在则不会下载)
  • download表示是否要下载
  • train 表示获取训练数据集或测试数据集
  • transform代表对图像的操作, 这里仅仅使用了ToTensor()把图像数据转换为Tensor类型
    其格式为(少量图像分类数据集下载 图像分类的数据集_少量图像分类数据集下载)


书本原话: 
注意:由于像素值为0到255的整数,所以刚好是uint8所能表示的范围,包括
transforms.ToTensor() 在内的一些关于图片的函数就默认输入的是uint8型,若不是,可能不会报错
但可能得不到想要的结果。所以,如果用像素值(0-255整数)表示图片数据,那么一律将其类型设置成
uint8,避免不必要的bug。
import torchvision
import torchvision.transforms as transforms
mnist_train = torchvision.datasets.FashionMNIST(root=r'D:\Source\Datasets\FashionMNIST', train=True, download=True,
                                                transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root=r'D:\Source\Datasets\FashionMNIST', train=False, download=True,
                                               transform=transforms.ToTensor())

查看一下读取的结果

少量图像分类数据集下载 图像分类的数据集_少量图像分类数据集下载_02

2.查看数据集结构

对训练集切片查看一下数据类型和标签类型

少量图像分类数据集下载 图像分类的数据集_深度学习_03


这里的标签已经转换为数值型数据来存储

所以我们可以编写一个函数将其转换为 图像数据集原本对应的标签

def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

少量图像分类数据集下载 图像分类的数据集_数据集_04

3.查看图片与标签

先提取出其中的一张图片与标签来查看

img, label = mnist_train[0]
title = get_fashion_mnist_labels([label])[0] # 获取标签
plt.imshow(img.view((28,28)).numpy())	# 数据格式转换
plt.title(title)	# 设置标题
plt.savefig('test.jpg')	# 存储图片

少量图像分类数据集下载 图像分类的数据集_人工智能_05


查看多个图片和标签(以前十张为例)

import matplotlib.pyplot as plt
def show_fashion_mnist(images, labels):
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
plt.show()

少量图像分类数据集下载 图像分类的数据集_人工智能_06

4.按小批次读取数据集

使用DataLoader 它可以允许多线程来加速数据读取

具体的可以看下面链接中的文章,有对DataLoaderDataset的详细介绍

from torch.utils.data import DataLoader
import sys
batch_size = 256
if sys.platform.startswith('win'):
    # 0表示不用额外的进程来加速读取数据
    num_workers = 0
else:
    num_workers = 4
train_iter = DataLoader(mnist_train,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_workers)
test_iter = DataLoader(mnist_test,
                       batch_size=batch_size,
                       shuffle=False,
                       num_workers=num_workers)

DataLoader是个可遍历的对象

start = time()
for X, y in train_iter:
	continue
print('%.2f sec' % (time() - start))

可以通过上述代码来查看读取一遍训练集需要的时间