pytorch中如何加载数据
一、pytorch数据主要涉及两个类:dataset和dataloader
1、dataset提供一种方式去获取数据及其label:
1)如何获取每一个数据及其label
2)告诉我们总共有多少的数据(神经网络是经常要对一个数据迭代多次,当我们知道有多少数据需要训练的时候,他才知道我们要训练多少次才能把数据迭代完,进行下一次迭代)
2、dataloader为后面的网络提供不同的数据形式
二、Dataset代码实践
1、首先下载一个数据集,链接:https://pan.baidu.com/s/1r1FwiaX1ohmU-PuZunJh1g 提取码:881f,将该数据集放在项目文件夹下面。
这个数据集分为了train和val,分别是训练数据集和验证数据集,训练数据集里面有ants和bees两个文件夹,里面都是蚂蚁和蜜蜂的图片。
2、在pycharm中新建一个python文件
首先在控制台测一下函数,看能否正常打开数据集中的图片:
三、在python控制台一条一条语句测试,同时观察右边变量的变化
下面介绍几个常用函数:
1、listdir(文件路径),括号中的变量为文件路径,该函数的作用是将括号中路径下的文件夹下的文件形成一个列表
2、path.join(路径1,路径2),作用是将两个路径变量合并起来。
具体可以在python控制台一个个代码敲进去,然后观察右边变量的变化,体会每个函数的作用
四、加载数据例程
from torch.utils.data import Dataset # 从torch里面的常用工具区utils关于数据的data区导入Dataset工具
from PIL import Image
import os
class MyData(Dataset):# 创建一个类
def init(self, root_dir, label_dir): # 初始化
self.root_dir = root_dir # 数据地址
self.label_dir = label_dir # 标签地址
self.path = os.path.join(self.root_dir, self.label_dir)# 想获得这个图片的地址,将两个地址合并
self.img_path = os.listdir(self.path)# 把所有图片的地址生成一个列表
def getitem(self, idx): # idx作为一个索引去获取图片的地址
img_name = self.img_path[idx] # 索引第idx张图片
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def len(self):
return len(self.img_path)
root_dir = “dataset/train” # 数据集的相对地址
ants_label_dir = “ants_img”
bees_label_dir = “bees_img”
ants_dataset = MyData(root_dir, ants_label_dir) #实例化一个蚂蚁类
bees_dataset = MyData(root_dir, bees_label_dir) # 实例化一个蜜蜂类