使用pytorch训练图像分类模型需要加载数据集,关于train_set,train_loader的写法介绍如下。

首先参考MNIST数据集的train_set,train_loader的写法。

# 训练集
trainset = torchvision.datasets.MNIST(root='./datasets/ch08/pytorch', # 选择数据的根目录
train=True,
download=True, # 不从网络上download图片
transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 测试集
testset = torchvision.datasets.MNIST(root='./datasets/ch08/pytorch', # 选择数据的根目录
train=False,
download=True, # 不从网络上download图片
transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)

MNIST直接使用的是torchvision.dataset.MNIST类加载数据集,train_loader的使用都是一样的。所以核心问题是如何书写这个Dataset类。

Dataset的核心是需要返回包含所有数据集的列表,以及每个数据集对应的标签。

我的数据集存放方式为:train_print目录下有10个文件夹,里面有10个类别的数据集。

卷积神经网络-加载数据集_深度学习

比如打开‘01’文件夹如下所示。

卷积神经网络-加载数据集_cnn_02

我采用的思路是,将这10个文件夹所有图片的绝对路径存放到一个dataset-text.txt文件中。下一步遍历该dataset-text.txt文件。这个文件长这样。

卷积神经网络-加载数据集_pytorch_03


很容易能够发现标签存放在倒数第二个位置啦。每张图片的标签就存放在文件绝对路径中,也就是文件夹‘01’的文件名。

因此,有了这个txt文件,我们遍历的时候,读取一行路径,​​将该路径存放到images中,把倒数第二个位置上的标签存放到labels中​​,核心工作就完成了。Dataset类的编写没有固定的,但核心都是不管你通过什么方法,把数据集和对应的标签存放到images,labels中。如果你仔细看到这里还不明白,就私信我。因为说实话这个问题我也困扰了很久,因为我也是入门。

class MyDataset(Dataset):
def __init__(self, dataset_path, num_class, transforms=None):
super(MyDataset,self).__init__()
images = []
labels = []
txt_path = self.dataset2txt(dataset_path,num_class)


with open(txt_path, 'r') as f:
for line in f:
if int(line.split('/')[-2]) > args.num_class:
break
line = line.strip('\n')
images.append(line)
labels.append(int(line.split('/')[-2]))
self.images = images
self.labels = labels
self.transforms = transforms

def __getitem__(self, index):
image = Image.open(self.images[index])
label = self.labels[index]

if self.transforms is not None:
image = self.transforms(image)

return image, label

def __len__(self):
return len(self.labels)


def dataset2txt(self,dataset_path, class_num=None):
'''
transform dataset into a txt file which contain every Image
:param In_path: path of dataset
:param num_class: classes
:return:path of txt file
'''

# 1.创建文件
# 一下两行代码目的是与数据集同级目录下新建dataset-text.txt文件
txt_path = os.path.abspath(os.path.dirname(dataset_path))
txt_path = txt_path + '/dataset-text.txt'
# 删除已经存在的文件,要保证每次操作的文件是一个空的txt文件
if os.path.exists(txt_path):
os.remove(txt_path)

f = open(txt_path, 'w')
f.close()
# 2.写入文件
# 打开数据集,将主目录下所有文件夹放入list中
dirs = os.listdir(dataset_path)
# 将文件夹按从小到大排序,文件夹的名字是按照数字命名的,01,02,03...
dirs.sort()

# 打开第二级每个文件夹,将并将每个文件的绝对路径写入到上面新建的txt文件
for i, dir in enumerate(dirs):
file = os.path.abspath(dataset_path) + '/' + dirs[i]
DIRLIST = os.listdir(file)
for j, d in enumerate(DIRLIST):
content = file + '/' + d + '\n'
# 每次执行前一定确保要写入的文件是空的
with open(txt_path, 'a') as f:
f.write(content)

return txt_path

剩下的就是书写train_set,train_loader.

train_set = MyDataset(dataset_path=args.dataset, num_class=args.num_class, transforms=transform)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)