大家了解了数据集格式后,接下来我会给大家介绍ResNeXt的数据预处理工作是怎么进行的。我在代码的关键部分都做了详细的注释,大家一定要看代码。
from torchvision import transforms, datasets
import os
import torch
from PIL import Image
import scipy.io as scio
# File-name suffixes treated as images; compared against the lower-cased
# file name in ImageNetTrainDataSet._is_image_file.
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
def ImageNetData(args):
    """Create the ImageNet train/val data loaders and report dataset sizes.

    Args:
        args: Object exposing ``data_dir`` (dataset root), ``batch_size``
            and ``num_workers``.

    Returns:
        tuple: ``(dataloders, dataset_sizes)``, two dicts keyed by
        ``'train'`` and ``'val'``. The first maps to a
        ``torch.utils.data.DataLoader``, the second to the number of
        samples in the corresponding dataset.
    """
    # Resize/crop/flip work on PIL images while Normalize() works on
    # tensors, so ToTensor() has to sit between the two groups.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    train_pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    val_pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    data_transforms = {'train': train_pipeline, 'val': val_pipeline}

    devkit_data = os.path.join(args.data_dir, 'ILSVRC2012_devkit_t12', 'data')

    image_datasets = {}
    # Arguments: training image root, folder-name -> class mapping file,
    # and the transform applied to every image.
    image_datasets['train'] = ImageNetTrainDataSet(
        os.path.join(args.data_dir, 'ILSVRC2012_img_train'),
        os.path.join(devkit_data, 'meta.mat'),
        data_transforms['train'],
    )
    # Arguments: validation image folder, per-image ground-truth label
    # file, and the transform applied to every image.
    image_datasets['val'] = ImageNetValDataSet(
        os.path.join(args.data_dir, 'ILSVRC2012_img_val'),
        os.path.join(devkit_data, 'ILSVRC2012_validation_ground_truth.txt'),
        data_transforms['val'],
    )

    # Wrap each dataset in a loader that yields batched tensors.
    dataloders = {}
    dataset_sizes = {}
    for phase in ['train', 'val']:
        dataloders[phase] = torch.utils.data.DataLoader(
            image_datasets[phase],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        )
    for phase in ['train', 'val']:
        dataset_sizes[phase] = len(image_datasets[phase])
    return dataloders, dataset_sizes
class ImageNetTrainDataSet(torch.utils.data.Dataset):
    """ImageNet training set stored as one sub-folder of images per class.

    Samples are ``(image, label)`` pairs. The label of an image is the
    index assigned to its class folder by the ``synsets`` table of the
    mapping file.
    """

    def __init__(self, root_dir, img_label, data_transforms):
        """Index every image under ``root_dir``.

        Args:
            root_dir: Directory containing one sub-folder per class.
            img_label: Path to the ``.mat`` mapping file; its ``'synsets'``
                entry is read with ``scipy.io.loadmat``.
            data_transforms: Callable applied to each PIL image, or ``None``
                to return the image untransformed.
        """
        synsets = scio.loadmat(img_label)['synsets']
        # Only the first 1000 rows are used; the row index becomes the class
        # label. synsets[i][0][1][0] is (per the original author's note) the
        # name of the class folder.
        label_dic = {synsets[i][0][1][0]: i for i in range(1000)}
        # Entries directly under root_dir (the class folders).
        self.img_path = os.listdir(root_dir)
        self.data_transforms = data_transforms
        self.label_dic = label_dic
        self.root_dir = root_dir
        # _make_dataset() reads self.root_dir and self.label_dic, so it must
        # run after they have been assigned.
        self.imgs = self._make_dataset()

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.imgs)

    def __getitem__(self, item):
        """Load sample ``item`` and return ``(image, label)``.

        The image is opened as RGB. If the transform raises, a message
        naming the file is printed and the untransformed image is returned.
        """
        path, label = self.imgs[item]
        img = Image.open(path).convert('RGB')
        if self.data_transforms is not None:
            try:
                img = self.data_transforms(img)
            except Exception:
                # Report the file that actually failed. self.img_path holds
                # class-folder names, so indexing it with a sample index
                # would name the wrong entry or raise IndexError.
                print("Cannot transform image: {}".format(path))
        return img, label

    def _make_dataset(self):
        """Collect ``(image_path, class_index)`` for every image file.

        Class folders, walked directories and file names are all visited
        in sorted order, so the resulting list is deterministic.
        """
        class_to_idx = self.label_dic
        images = []
        base_dir = os.path.expanduser(self.root_dir)
        for target in sorted(os.listdir(base_dir)):
            class_dir = os.path.join(base_dir, target)
            if not os.path.isdir(class_dir):
                continue
            for root, _, fnames in sorted(os.walk(class_dir)):
                for fname in sorted(fnames):
                    if self._is_image_file(fname):
                        images.append(
                            (os.path.join(root, fname), class_to_idx[target])
                        )
        return images

    def _is_image_file(self, filename):
        """Check whether a file name has a known image extension.

        Args:
            filename (string): path to a file

        Returns:
            bool: True if the filename ends with a known image extension
        """
        filename_lower = filename.lower()
        return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)
class ImageNetValDataSet(torch.utils.data.Dataset):
    """ImageNet validation set: a flat image folder plus a label file.

    The i-th image (file names in sorted order) is paired with the i-th
    non-blank line of the label file.
    """

    def __init__(self, img_path, img_label, data_transforms):
        """Index the validation images and read their labels.

        Args:
            img_path: Directory holding the validation images.
            img_label: Text file with one integer label per line; each
                value is shifted down by one when stored.
            data_transforms: Callable applied to each PIL image, or ``None``
                to return the image untransformed.
        """
        self.data_transforms = data_transforms
        # Sort the names so image order lines up with the label file.
        img_names = sorted(os.listdir(img_path))
        self.img_path = [os.path.join(img_path, img_name) for img_name in img_names]
        with open(img_label, "r") as input_file:
            # Skip blank lines (e.g. a trailing empty line), which int()
            # would otherwise reject with ValueError.
            self.img_label = [int(line) - 1 for line in input_file if line.strip()]

    def __len__(self):
        """Return the number of validation images."""
        return len(self.img_path)

    def __getitem__(self, item):
        """Load sample ``item`` and return the tuple ``(image, label)``.

        The image is opened as RGB. If the transform raises, a message
        naming the file is printed and the untransformed image is returned.
        """
        img = Image.open(self.img_path[item]).convert('RGB')
        label = self.img_label[item]
        if self.data_transforms is not None:
            try:
                img = self.data_transforms(img)
            except Exception:
                print("Cannot transform image: {}".format(self.img_path[item]))
        return img, label
这里大家需要了解的是python中的 __getitem__方法的用法。
另外大家疑惑最多的应该是这部分:
label_array = scio.loadmat(img_label)['synsets']# read the 'synsets' entry of the mapping file; per the author's note, its key content is the class <-> folder correspondence
label_dic = {}
for i in range(1000):
label_dic[label_array[i][0][1][0]] = i# label_array[i][0][1][0]: image folder id (per the author's note), mapped to class index i; 1000 entries are read, one per class
其实,这是由.mat文件中的数据类型所决定的:scio.loadmat(img_label) 读出来的是字典型数据,因此我们需要先取出 'synsets' 所对应的内容。为了方便大家理解,这里可以将 label_array 打印出来,查看它的属性(剧透:尺寸是[1860,1]),对应的代码:
import scipy.io as scio
# Inspect meta.mat: print the type returned by loadmat, then each row
# index of 'synsets' together with the value stored at [0][1][0].
path = './/ImageNet//ILSVRC2012_devkit_t12//data//meta.mat'
result = scio.loadmat(path)
print(type(result))
# Iterate over the rows that actually exist instead of a fixed 2000, which
# raises IndexError as soon as the index passes the end of the array.
for i, row in enumerate(result['synsets']):
    print(i)
    print(row[0][1][0])  # same element as result['synsets'][i][0][1][0]
到这里,应该就没有问题了。