我们使用一个很经典的数据集Cifar10,而该数据集可以直接通过Pytorch内置函数获取到。
一、导入所需的库
import torch ## pytorch
import torchvision ## 迁移学习模型和许多其他视觉相关类
from torch import nn ## Pytorch核心神经网络模型类
from torch import optim ## 包含几个Pytorch优化器类
import torch.nn.functional as F ## 包含几个Pytorch提供的实用函数
from torchvision import datasets, transforms, models ## 对于数据集和变换的一些计算机视觉相关类
from torch.utils.data import * ## 包含几个数据操作的实用函数
from PIL import Image
import numpy as np
二、创建CIFAR10 Pytorch数据集
- 从torchvision下载CIFAR10训练集和测试集;
- 首先设置train=True,表明我们下载训练集。然后设为False来下载测试集;
- 设置download=True,由于我们是第一次获取这个数据集。因此,它将首先从CIFAR10类中预先指定的URL下载。
- 在首次运行这个cell,成功的下载数据集后,应该改变为False来避免每次下载;
- 以下操作的结果将是两个数据集对象,分别表示CIFAR10训练集和测试集。
train_dataset = datasets.CIFAR10('Cifar10', train=True,
download=True)
test_dataset = datasets.CIFAR10('Cifar10', train=False,
download=True)
这里有两个来自torchvision.datasets.cifar的数据集对象。这是Pytorch的Dataset类的一个子类,Dataset是表示任何数据集的主类。这个特殊的类表示存储在其内部数据结构中的CIFAR10数据。稍后,这些对象将被传递给一个Pytorch Dataloader对象(稍后解释)来处理这些图像。
我们可以验证两个数据集的长度(图像的数量)
len(train_dataset),len(test_dataset)
(50000, 10000)
如上所示,我们分别有50000张图片的训练集和10000张图片的测试集。
二、张量(Tensors)快速介绍
张量是一种表示单个类型(整数或浮点数等)的n维数据对象的通用方式。例如:
- 一个单值(整型或者浮点值)是一个0维张量;
- 一个有N个元素的数组是一维张量;
- 一个有M行N列的矩阵是一个二维张量(MxN);
- 用三个矩阵表示的三个RGB(红,绿,蓝)颜色通道的MxN图像是一个三维张量(3xMxN); 图像张量包含在dataset对象中的字段train_data。让我们来看看代表一个图像张量的形状。
train_dataset.data[0].shape
(32, 32, 3)
说明我们的图片大小为3通道32x32,让我们用matplotlib.plyplot模块查看图片
%matplotlib inline
import matplotlib.pyplot as plt
plt.imshow(train_dataset.data[100])
这似乎是艘船,由于分辨率低(32x32),图片非常模糊