1.下载数据集
使用torchvision.datasets
来下载数据集
-
root
用来指定下载后保存的位置(如果已经存在则不会下载) -
download
表示是否要下载 -
train
表示获取训练数据集或测试数据集 -
transform
代表对图像的操作, 这里仅仅使用了ToTensor()
把图像数据转换为Tensor
类型
其格式为()
书本原话:
注意:由于像素值为0到255的整数,所以刚好是uint8所能表示的范围,包括
transforms.ToTensor() 在内的一些关于图片的函数就默认输入的是uint8型,若不是,可能不会报错
但可能得不到想要的结果。所以,如果用像素值(0-255整数)表示图片数据,那么一律将其类型设置成
uint8,避免不必要的bug。
import torchvision
import torchvision.transforms as transforms
mnist_train = torchvision.datasets.FashionMNIST(root=r'D:\Source\Datasets\FashionMNIST', train=True, download=True,
transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root=r'D:\Source\Datasets\FashionMNIST', train=False, download=True,
transform=transforms.ToTensor())
查看一下读取的结果
2.查看数据集结构
对训练集切片查看一下数据类型和标签类型
这里的标签已经转换为数值型数据来存储
所以我们可以编写一个函数将其转换为 图像数据集原本对应的标签
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
3.查看图片与标签
先提取出其中的一张图片与标签来查看
img, label = mnist_train[0]
title = get_fashion_mnist_labels([label])[0] # 获取标签
plt.imshow(img.view((28,28)).numpy()) # 数据格式转换
plt.title(title) # 设置标题
plt.savefig('test.jpg') # 存储图片
查看多个图片和标签(以前十张为例)
import matplotlib.pyplot as plt
def show_fashion_mnist(images, labels):
# 这里的_表示我们忽略(不使用)的变量
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
X, y = [], []
for i in range(10):
X.append(mnist_train[i][0])
y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
plt.show()
4.按小批次读取数据集
使用DataLoader
它可以允许多线程来加速数据读取
具体的可以看下面链接中的文章,有对DataLoader
和Dataset
的详细介绍
from torch.utils.data import DataLoader
import sys
batch_size = 256
if sys.platform.startswith('win'):
# 0表示不用额外的进程来加速读取数据
num_workers = 0
else:
num_workers = 4
train_iter = DataLoader(mnist_train,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers)
test_iter = DataLoader(mnist_test,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers)
DataLoader
是个可遍历的对象
start = time()
for X, y in train_iter:
continue
print('%.2f sec' % (time() - start))
可以通过上述代码来查看读取一遍训练集需要的时间