深度学习好比炼丹,框架就是丹炉,网络结构及算法就是单方,而数据集则是原材料。现在世面上很多炼丹手册都是针对单一数据集进行炼丹,有了这些手册我们就能够很容易进行炼丹,但为了练好丹,我们常常收集各种公开的数据集,并构建私有数据集,此时,便会遇到如何更好的使用多个数据进行练丹的问题。

本文将使用pytorch这个丹炉,介绍如何联合读取多个原材料,而不是从新制作原材料和标签。

1、Pytorch的ConcatDataset介绍



class ConcatDataset(Dataset):
    """
    Dataset to concatenate multiple datasets.
    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets as the concatenation operation is done in an
    on-the-fly manner.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes



首先,ConcatDataset继承自Dataset类。

其次,ConcatDataset的构造函数要求一个列表L作为输入,其包含若干个数据集的。构造函数会计算出一个cumulative size列表,里面存放了”把L中的第n个数据集加上后一共有多少个样本“的序列。

然后,ConcatDataset重写了__len__方法,返回cumulative_size[-1],也就是若干个数据集的总样本数;

最后,重写了__getitem__,当给定索引 idx 的时候,会计算出该idx对应那个数据集及在这个数据集中的位置,这样就可以访问这个数据了。


2、多个数据集联合读取示例

假设我们需要读取MNIST、CIFAR10和CIFAR100三个数据集。

首先,这三个数据集在torchvision中已经实现,调用方式如下:



mnist_data = MNIST('./data', train=True, download=True)
cifar10_data = CIFAR10('./data', train=True, download=True)
cifar100_data = CIFAR100('./data', train=True, download=True)



如果是其他数据集也要先实现读取;

其次,定义一个数据种类和其访问接口的字典:



_DATASETS = {
    'MNIST': MNIST,
    'CIFAR10': CIFAR10,
    'CIFAR100': CIFAR100,
}



然后,定义一个数据信息类,存放数据地址等信息:



class DatasetCatalog:
    DATASETS = {
        'MNIST': {
            "root": "./data",
        },
        'CIFAR10': {
            "root": "./data",
        },
        'CIFAR100': {
            "root": "./data",
        }
    }

    @staticmethod
    def get(name):
        if "MNIST" in name:
            attrs = DatasetCatalog.DATASETS[name]
            args = dict(
                root=attrs["root"],
            )
            return dict(factory="MNIST", args=args)
        elif "CIFAR10" in name:
            attrs = DatasetCatalog.DATASETS[name]
            args = dict(
                root=attrs["root"],
            )
            return dict(factory="CIFAR10", args=args)
        elif "CIFAR100" in name:
            attrs = DatasetCatalog.DATASETS[name]
            args = dict(
                root=attrs["root"],
            )
            return dict(factory="CIFAR100", args=args)

        raise RuntimeError("Dataset not available: {}".format(name))



最后,定义一个制作数据集的函数,通过dataset_list指定需要加载的数据集名称,对于train模式,会返回合并后的数据集,对于val模式,返回各自的val数据集



def make_dataset(dataset_list, train=True, transform=None, target_transform=None, download=True):
    assert len(dataset_list) > 0
    data_sets = []
    for dataset_name in dataset_list:
        catalog = DatasetCatalog.get(dataset_name)
        args = catalog['args']
        factory = _DATASETS[catalog['factory']]
        args['train'] = train
        args['transform'] = transform
        args['target_transform'] = target_transform
        args['download'] = download
        if factory == MNIST:
            data_set = factory(**args)
        elif factory == CIFAR10:
            data_set = factory(**args)
        elif factory == CIFAR100:
            data_set = factory(**args)
        data_sets.append(data_set)

    if not train:
        return data_sets
    data_set = data_sets[0]
    if len(data_sets) > 1:
        data_set = ConcatDataset(data_sets)
    return data_set



具体看一下如何调用吧:



if __name__ == "__main__":
    dataset_list = ["MNIST", "CIFAR10", "CIFAR100"]
    concat_data = make_dataset(dataset_list, train=True, download=True)

    for i, (data, target) in enumerate(concat_data):
        print(np.array(data).shape)
        print(target)



获取了concat_data后,就可以通过dataloader来定义loader了。

对于其他数据集或私有数据集,可以改一改,能够实现任何想要的输出。

以上就是本文所有的内容,结合调试代码能够更快的学习哦。