文章目录
运行环境安装 Anaconda | python ==3.6.6
一、torchvision 图像数据读取 [0, 1]
import torchvision.transforms as transforms
transforms 模块提供了一般的图像转换操作类。
class torchvision.transforms.ToTensor
功能:
把shape=(H x W x C) 的像素值为 [0, 255] 的 PIL.Image 和 numpy.ndarray
转换成shape=(C x H x W)的像素值范围为[0.0, 1.0]
的 torch.FloatTensor。
class torchvision.transforms.Normalize(mean, std)
功能:
此转换类作用于torch.*Tensor。给定均值(R, G, B)和标准差(R, G, B),用公式channel = (channel - mean) / std进行规范化。
transforms.Compose 归一化到 [-1.0, 1.0 ]
二、torchvision 的 Transform
在深度学习时关于图像的数据读取:由于Tensorflow不支持与numpy的无缝切换,导致难以使用现成的pandas等格式化数据读取工具,造成了很多不必要的麻烦,而pytorch解决了这个问题。
pytorch自定义读取数据和进行Transform的部分请见文档:
http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
但是按照文档中所描述所完成的自定义Dataset只能够使用自定义的Transform步骤,而torchvision包中已经给我们提供了很多图像transform步骤的实现,为了使用这些已经实现的Transform步骤,我们可以使用如下方法定义Dataset:
三、读取图像数据类
3.1 class torchvision.datasets.ImageFolder 默认读取图像数据方法:
-
__init__
( 初始化)
-
classes, class_to_idx = find_classes(root)
:得到分类的类别名(classes)和类别名与数字类别的映射关系字典(class_to_idx)
其中 classes (list): List of the class names.
其中 class_to_idx (dict): Dict with items (class_name, class_index). -
imgs = make_dataset(root, class_to_idx)
得到imags列表。
其中 imgs (list): List of (image path, class_index) tuples
每个值是一个tuple,每个tuple包含两个元素:图像路径和标签
-
__getitem__
(图像获取)
-
path, target = self.imgs[index]
获取图像(路径,标签) -
img = self.loader(path)
数据读取。 -
img = self.transform(img)
数据、标签 转换成 tensor -
target = self.target_transform(target)
-
__len__
( 数据集数量)
-
return len(self.imgs)
图像获取 __getitem__
3.2 自定义数据读取方法
PyTorch中和数据读取相关的类都要继承一个基类:torch.utils.data.Dataset。
故需要改写其中的 __init__、__len__、__getitem__
-
__init__()
初始化传入参数:
- img_path 里面为所有图像数据(包括训练和测试)
txt_path 里面有 train.txt和val.txt两个文件:txt文件中每行都是图像路径,tab键,标签。 - 其中 self.img_name 和 self.img_label 的读取方式就跟你数据的存放方式有关(需要调整的地方)
-
__getitem__()
依然采用default_loader方法来读取图像。 -
Transform
中将每张图像都封装成 Tensor
https://pytorch-cn.readthedocs.io/zh/latest/package_references/data/#torchutilsdata 鸣谢