数据读取Dataset与Dataloader
- 前言
- 官方通用的数据加载器
- 文件目录存储格式
- 主要函数
- 所有代码
- 代码部分讲解
- 官方通用的数据加载器收获
- 图片数据集(标签在图片名称上)
- 构建自己的Dataset(重要)
- data列表构建
- 总结
- 待续
前言
在pytorch学习这一块总是断断续续,完成大作业所写的代码再次回首已经完全看不懂了。所以我决定把学习过程中遇到的一些问题和知识总结出来,希望能取得一些进步吧。本人完全菜鸟,写这些笔记的主要目的是督促自己坚持学习下去,笔记中可能出现比较夸张的错误,恳请各位大佬谅解。
在数据集读取学习过程中遇到了很多很多很多困难,目前也只是对图片数据集(标签信息在图片名称上)读取稍微明白了一些,关于txt文件、CSV文件等。尤其是mat文件还是不是很明白怎么去处理。希望这篇笔记未来能把处理这些数据集的代码都写出来。
官方通用的数据加载器
以花卉分类为例
文件目录存储格式
主要函数
将图片数据集存储成指定上述存储格式后调用后面的函数即可实现
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
transform=data_transform["train"])
所有代码
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
"val": transforms.Compose([transforms.Resize((224, 224)), # cannot 224, must (224, 224)
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root path
image_path = os.path.join(data_root, "data_set", "flower_data") # flower data set path
assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
transform=data_transform["train"])
train_num = len(train_dataset)
# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
batch_size = 32
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
print('Using {} dataloader workers every process'.format(nw))
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True,
num_workers=nw)
validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
transform=data_transform["val"])
print("using {} images for training, {} images for validation.".format(train_num,
val_num))
代码部分讲解
- 创建列表
# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
将文件夹目录下的所有文件夹转换成字典,包含文件夹名称和对应的数字标签0,1,2,3,4
flower_list = train_dataset.class_to_idx
# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
颠倒key和val
cla_dict = dict((val, key) for key, val in flower_list.items())
# { 0:'daisy', 1:'dandelion', 2:'roses', 3:'sunflower', 4:'tulips'}
剩下的代码就是将其写入json文件
- 使用
json_path = './class_indices.json'
json_file = open(json_path, "r")
class_indict = json.load(json_file)
print("class: {}".format(class_indict[str(i)]))#i=0,1,2,3,4
如i=0时,由{ 0:‘daisy’, 1:‘dandelion’, 2:‘roses’, 3:‘sunflower’, 4:‘tulips’}可知应该输出daisy
print("class: {}".format(class_indict[str(0)]))
官方通用的数据加载器收获
官方官方通用的数据加载器中可以利用class_to_idx获取字典格式最终存储成json文件
def find_classes(path):
classes = [d.name for d in os.scandir(path) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return class_to_idx
使用
image_path = os.path.join(os.getcwd(), "imagedata") # data set path
bear_list = find_classes(image_path)
print(bear_list)
结果
{'ballone': 0, 'ballthree': 1, 'balltwo': 2, 'innerone': 3, 'innerthree': 4, 'innertwo': 5, 'normal': 6, 'outerone': 7, 'outerthree': 8, 'outertwo': 9}
最后见上一节代码部分讲解即可保存json文件,方便之后预测使用。
图片数据集(标签在图片名称上)
这里使用完成大作业的代码进行学习,代码为轴承故障10分类问题,数据集为mat文件经过一系列操作转换为小波时频图jpg文件(转换过程的代码有时间争取我也总结一下)。图片文件存储地址以及图片如下
构建自己的Dataset(重要)
class MyDataset(Dataset):
def __init__(self, data, transform, loder):
self.data = data
self.transform = transform
self.loader = loder
def __getitem__(self, item):
img, label = self.data[item]
img = self.loader(img)
img = self.transform(img)
return img, label
def __len__(self):
return len(self.data)
编程学的不好,术语不太会说,我就挑着代码里我不懂的地方简单说一下。
- data: data是一个列表,格式为 [图片的地址(对应img),图片的标签(对应label)]。data列表具体怎么构建的下面会介绍。
- loder: 读取图片函数,这里调用的自己编写的Myloader函数
def Myloader(path):
return Image.open(path).convert('RGB')
- Transform: 对数据进行预处理
具体使用如下
train = MyDataset(train_data, transform=transform_train, loder=Myloader)
test = MyDataset(test_data, transform=transform_test, loder=Myloader)
- __ getitem__
在DataLoader 送入torch中进行训练时,会自动调用数据集类的__getitem__()方法
train = MyDataset(train_data, transform=transform_train, loder=Myloader)
train_data = DataLoader(dataset=train, batch_size=10, shuffle=True, num_workers=0)
#截取DataLoader中的一段函数 def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,......
Dataset[T_co]从而调用类中的方法__getitem__,从而return img, label(见上面代码),从而得到每张照片(根据T_co的值)的img和label便于后续的训练代码
- getitem返回方式两种都可以(如下)
import torch
from torch.utils.data import Dataset,DataLoader
class MyDataset1(Dataset):
def __init__(self):
self.data = torch.tensor([[1,2,3],[2,3,4],[4,5,6]])
self.lable = torch.LongTensor([1,1,0,0])
def __getitem__(self,index):
data = (self.data[index],self.lable[index])
return data
def __len__(self):
return len(self.data)
class MyDataset2(Dataset):
def __init__(self):
self.data = torch.tensor([[1,2,3],[2,3,4],[4,5,6]])
self.lable = torch.LongTensor([1,1,0,0])
def __getitem__(self,index):
return self.data[index], self.lable[index]
def __len__(self):
return len(self.data)
mydataset1=MyDataset1()
mydataset2=MyDataset2()
mydataloder1 = DataLoader(dataset=mydataset1,batch_size=1)
mydataloder2 = DataLoader(dataset=mydataset2,batch_size=1)
for i,(data,label) in enumerate(mydataloder1):
print(data,label)
for i,(data,label) in enumerate(mydataloder2):
print(data,label)
- 注意!train中for使用的不同
for batch_idx, (data, target) in enumerate(train_loader):
for step, data in enumerate(train_bar):
images, labels = data
data列表构建
我的理解主要是干了这么一个事情,就是要把数据处理一下,创建了一个列表,里面装着图片和他的标签,所以步骤就是从文件中读取图片1转换成想要的格式,再读取图片1的标签,然后打包存进去。data里面就是[图片1,图片1的标签],[图片2,图片2的标签]…然后后续再索引需要的就可以了。现在的目的首先就是如何创建出data列表
# 得到一个包含路径与标签的列表
#标签通过find_label函数获取
#lens代表数据长度,也就是data中有几个数据,比如有2个就是data=[[图片1,图片1标签],[图片2,图片2标签]]
def init_process(path, lens):
data = []
name = find_label(path)
for i in range(lens[0], lens[1]):
data.append([path % i, name])
return data
# 将图片名称中的 字母标签 转换成 0,1,2,3,4,5,6,7,8,9标签
#举例
#这里str输入的是图片路径,如path1 = r'C:\Users\lenovo\Desktop\Modern Signal Processing\xiaobo_CNN\imagedata\normal\normal_%d.jpg'
#可以看出图片的英文标签是normal,经过函数find_label转换,标签变为0
def find_label(str):
first, last = 0, 0
for i in range(len(str) - 1, -1, -1):
if str[i] == '%' and str[i - 1] == '_':
last = i - 1
if (str[i] == 'n' or str[i] == 'b' or str[i] == 'i' or str[i] == 'o') and str[i - 1] == '\\':
first = i
break
name = str[first:last]
#print(name)
if name == 'normal':
return 0
elif name == 'ballone':
return 1
elif name == 'balltwo':
return 2
elif name == 'ballthree':
return 3
elif name == 'innerone':
return 4
elif name == 'innertwo':
return 5
elif name == 'innerthree':
return 6
elif name == 'outerone':
return 7
elif name == 'outertwo':
return 8
elif name == 'outerthree':
return 9
具体实现过程
#data1列表中包含了200张图片以及其对应的标签数据
#1.首先转换成data=[[图片1,图片1标签],[图片2,图片2标签]....]的格式,以data1举例
path1 =
data1 = init_process(path1, [1,200])
#2.拼接所有的data形成训练数据集
train_data = data1[1:140] + data2[1:140]+ data3[1:140]+ data4[1:140]+ data5[1:140]+ data6[1:140]+ data7[1:140]+ data8[1:140]+ data9[1:140]+ data10[1:140]
#3.由此可以创建出data列表,作为Dataset的输入,使用MyDataset函数
train = MyDataset(train_data, transform=transform_train, loder=Myloader)
#4.使用DataLoader函数
train_data = DataLoader(dataset=train, batch_size=10, shuffle=True, num_workers=0)
总结
放上整个代码(路径就不放了,存的乱七八糟)
def Myloader(path):
return Image.open(path).convert('RGB')
def init_process(path, lens):
data = []
name = find_label(path)
for i in range(lens[0], lens[1]):
data.append([path % i, name])
return data
class MyDataset(Dataset):
def __init__(self, data, transform, loder):
self.data = data
self.transform = transform
self.loader = loder
def __getitem__(self, item):
img, label = self.data[item]
img = self.loader(img)
img = self.transform(img)
return img, label
def __len__(self):
return len(self.data)
def find_label(str):
first, last = 0, 0
for i in range(len(str) - 1, -1, -1):
if str[i] == '%' and str[i - 1] == '_':
last = i - 1
if (str[i] == 'n' or str[i] == 'b' or str[i] == 'i' or str[i] == 'o') and str[i - 1] == '\\':
first = i
break
name = str[first:last]
#print(name)
if name == 'normal':
return 0
elif name == 'ballone':
return 1
elif name == 'balltwo':
return 2
elif name == 'ballthree':
return 3
elif name == 'innerone':
return 4
elif name == 'innertwo':
return 5
elif name == 'innerthree':
return 6
elif name == 'outerone':
return 7
elif name == 'outertwo':
return 8
elif name == 'outerthree':
return 9
def load_data():
transform_train = transforms.Compose([
#transforms.RandomResizedCrop(224),#对图片尺寸做一个缩放切割
#transforms.RandomHorizontalFlip(),#图像一半的概率翻转,一半的概率不翻转
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
transform_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
path1 =
data1 = init_process(path1, [1,200])
path2 =
data2 = init_process(path2, [1, 200])
path3 =
data3 = init_process(path3, [1, 200])
path4 =
data4 = init_process(path4, [1, 200])
path5 =
data5 = init_process(path5, [1, 200])
path6 =
data6 = init_process(path6, [1, 200])
path7 =
data7 = init_process(path7, [1, 200])
path8 =
data8 = init_process(path8, [1, 200])
path9 =
data9 = init_process(path9, [1, 200])
path10 =
data10 = init_process(path10, [1, 200])
# 800个训练
train_data = data1[1:140] + data2[1:140]+ data3[1:140]+ data4[1:140]+ data5[1:140]+ data6[1:140]+ data7[1:140]+ data8[1:140]+ data9[1:140]+ data10[1:140]
train = MyDataset(train_data, transform=transform_train, loder=Myloader)
# 200个测试
test_data = data1[141:180] + data2[161:180]+ data3[161:180]+ data4[161:180]+ data5[161:180]+ data6[161:180]+ data7[161:180]+ data8[161:180]+ data9[161:180]+ data10[161:180]
test = MyDataset(test_data, transform=transform_test, loder=Myloader)
train_data = DataLoader(dataset=train, batch_size=10, shuffle=True, num_workers=0)
test_data = DataLoader(dataset=test, batch_size=1, shuffle=True, num_workers=0)
return train_data, test_data
待续
mat文件 txt文件 CSV文件