深度学习好比炼丹,框架就是丹炉,网络结构及算法就是单方,而数据集则是原材料。现在世面上很多炼丹手册都是针对单一数据集进行炼丹,有了这些手册我们就能够很容易进行炼丹,但为了练好丹,我们常常收集各种公开的数据集,并构建私有数据集,此时,便会遇到如何更好的使用多个数据进行练丹的问题。
本文将使用pytorch这个丹炉,介绍如何联合读取多个原材料,而不是从新制作原材料和标签。
1、Pytorch的ConcatDataset介绍
class ConcatDataset(Dataset):
"""
Dataset to concatenate multiple datasets.
Purpose: useful to assemble different existing datasets, possibly
large-scale datasets as the concatenation operation is done in an
on-the-fly manner.
Arguments:
datasets (sequence): List of datasets to be concatenated
"""
@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r
def __init__(self, datasets):
super(ConcatDataset, self).__init__()
assert len(datasets) > 0, 'datasets should not be an empty iterable'
self.datasets = list(datasets)
self.cumulative_sizes = self.cumsum(self.datasets)
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx]
@property
def cummulative_sizes(self):
warnings.warn("cummulative_sizes attribute is renamed to "
"cumulative_sizes", DeprecationWarning, stacklevel=2)
return self.cumulative_sizes
首先,ConcatDataset继承自Dataset类。
其次,ConcatDataset的构造函数要求一个列表L作为输入,其包含若干个数据集的。构造函数会计算出一个cumulative size列表,里面存放了”把L中的第n个数据集加上后一共有多少个样本“的序列。
然后,ConcatDataset重写了__len__方法,返回cumulative_size[-1],也就是若干个数据集的总样本数;
最后,重写了__getitem__,当给定索引 idx 的时候,会计算出该idx对应那个数据集及在这个数据集中的位置,这样就可以访问这个数据了。
2、多个数据集联合读取示例
假设我们需要读取MNIST、CIFAR10和CIFAR100三个数据集。
首先,这三个数据集在torchvision中已经实现,调用方式如下:
mnist_data = MNIST('./data', train=True, download=True)
cifar10_data = CIFAR10('./data', train=True, download=True)
cifar100_data = CIFAR100('./data', train=True, download=True)
如果是其他数据集也要先实现读取;
其次,定义一个数据种类和其访问接口的字典:
_DATASETS = {
'MNIST': MNIST,
'CIFAR10': CIFAR10,
'CIFAR100': CIFAR100,
}
然后,定义一个数据信息类,存放数据地址等信息:
class DatasetCatalog:
DATASETS = {
'MNIST': {
"root": "./data",
},
'CIFAR10': {
"root": "./data",
},
'CIFAR100': {
"root": "./data",
}
}
@staticmethod
def get(name):
if "MNIST" in name:
attrs = DatasetCatalog.DATASETS[name]
args = dict(
root=attrs["root"],
)
return dict(factory="MNIST", args=args)
elif "CIFAR10" in name:
attrs = DatasetCatalog.DATASETS[name]
args = dict(
root=attrs["root"],
)
return dict(factory="CIFAR10", args=args)
elif "CIFAR100" in name:
attrs = DatasetCatalog.DATASETS[name]
args = dict(
root=attrs["root"],
)
return dict(factory="CIFAR100", args=args)
raise RuntimeError("Dataset not available: {}".format(name))
最后,定义一个制作数据集的函数,通过dataset_list指定需要加载的数据集名称,对于train模式,会返回合并后的数据集,对于val模式,返回各自的val数据集
def make_dataset(dataset_list, train=True, transform=None, target_transform=None, download=True):
assert len(dataset_list) > 0
data_sets = []
for dataset_name in dataset_list:
catalog = DatasetCatalog.get(dataset_name)
args = catalog['args']
factory = _DATASETS[catalog['factory']]
args['train'] = train
args['transform'] = transform
args['target_transform'] = target_transform
args['download'] = download
if factory == MNIST:
data_set = factory(**args)
elif factory == CIFAR10:
data_set = factory(**args)
elif factory == CIFAR100:
data_set = factory(**args)
data_sets.append(data_set)
if not train:
return data_sets
data_set = data_sets[0]
if len(data_sets) > 1:
data_set = ConcatDataset(data_sets)
return data_set
具体看一下如何调用吧:
if __name__ == "__main__":
dataset_list = ["MNIST", "CIFAR10", "CIFAR100"]
concat_data = make_dataset(dataset_list, train=True, download=True)
for i, (data, target) in enumerate(concat_data):
print(np.array(data).shape)
print(target)
获取了concat_data后,就可以通过dataloader来定义loader了。
对于其他数据集或私有数据集,可以改一改,能够实现任何想要的输出。
以上就是本文所有的内容,结合调试代码能够更快的学习哦。