pytorch中如何加载数据
一、pytorch数据主要涉及两个类:dataset和dataloader
1、dataset提供一种方式去获取数据及其label:
1)如何获取每一个数据及其label
2)告诉我们总共有多少的数据(神经网络是经常要对一个数据迭代多次,当我们知道有多少数据需要训练的时候,他才知道我们要训练多少次才能把数据迭代完,进行下一次迭代)
2、dataloader为后面的网络提供不同的数据形式

二、Dataset代码实践

1、首先下载一个数据集,链接:https://pan.baidu.com/s/1r1FwiaX1ohmU-PuZunJh1g 提取码:881f,将该数据集放在项目文件夹下面。

这个数据集分为了train和val,分别是训练数据集和验证数据集,训练数据集里面有ants和bees两个文件夹,里面都是蚂蚁和蜜蜂的图片。

2、在pycharm中新建一个python文件

首先在控制台测一下函数,看能否正常打开数据集中的图片:

pytorch文档pdf pytorch教程pdf_pytorch


三、在python控制台一条一条语句测试,同时观察右边变量的变化

pytorch文档pdf pytorch教程pdf_数据_02


下面介绍几个常用函数:

1、listdir(文件路径),括号中的变量为文件路径,该函数的作用是将括号中路径下的文件夹下的文件形成一个列表

pytorch文档pdf pytorch教程pdf_pytorch_03


2、path.join(路径1,路径2),作用是将两个路径变量合并起来。

pytorch文档pdf pytorch教程pdf_数据集_04


具体可以在python控制台一个个代码敲进去,然后观察右边变量的变化,体会每个函数的作用

pytorch文档pdf pytorch教程pdf_深度学习_05


四、加载数据例程

from torch.utils.data import Dataset # 从torch里面的常用工具区utils关于数据的data区导入Dataset工具
 from PIL import Image
 import os
 class MyData(Dataset):# 创建一个类
 def init(self, root_dir, label_dir): # 初始化
 self.root_dir = root_dir # 数据地址
 self.label_dir = label_dir # 标签地址
 self.path = os.path.join(self.root_dir, self.label_dir)# 想获得这个图片的地址,将两个地址合并
 self.img_path = os.listdir(self.path)# 把所有图片的地址生成一个列表
 def getitem(self, idx): # idx作为一个索引去获取图片的地址
 img_name = self.img_path[idx] # 索引第idx张图片
 img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
 img = Image.open(img_item_path)
 label = self.label_dir
 return img, label
 def len(self):
 return len(self.img_path)
root_dir = “dataset/train” # 数据集的相对地址
 ants_label_dir = “ants_img”
 bees_label_dir = “bees_img”
 ants_dataset = MyData(root_dir, ants_label_dir) #实例化一个蚂蚁类
 bees_dataset = MyData(root_dir, bees_label_dir) # 实例化一个蜜蜂类