简而言之,这个参数就是用来设定dataloader最后输出的batch内容;dataloader一次性从dataset得到batch大小的数据,但这些数据本身是分散的,拿图片举例,比如我们batch为8,则我们得到的是8个[3,256,256](256为图片形状,随便设置的)大小的张量,通过collate_fn这个参数转化为形状为[8,3,256,256]的张量作为dataloader的输出。
一般情况下,这个参数是不用设置的,那什么是不一般的情况呢,比如数据长度不同的时候,最明显的拿NLP里面的句子长度举例,每个batch里面的句子长度不一样,如果使用默认的collate_fn方法,就可能报错,这时候就需要自定义collate_fn参数。
而重写这个参数也很简单,就是自定义一个函数,假设我们这里给这个函数取名字叫做my_collate_fn(batch),注意,只能有一个输入变量batch,batch包含了dataset里面getitem返回的所有值。
拿vqa任务来举例子,比如dataset的getitem每次返回一张图片的数据data,label,以及相应的question,answer,如果我们设定dataloader一次性获得8个大小的batch,则此时传入my_collate_fn的变量为一个list,这个list包含8个[data, label, question, answer],如果要使用默认的collate_fn,则要求,这8个[data, label, question, answer]里面,每一个变量的形状都是相同的,就是说8个data的形状相同,8个label的形状相同,8个question的形状,8个answer的形状也相同,这样才可以用默认的collate_fn参数,否则会报错。如果不满足的话就需要自定义collate_fn参数
举例如下,此时自定义的vqa_collate_fn的输入是list为8的变量,我们单独看其中的每个变量的第三个元素,可以发现,第0和第1个变量的第三个元素长度是不相同的,这时候如果使用默认的collate_fn就会报错,而在这里,我们自定义的函数里面,我们知识把每个变量的第三个元素放在一个list里面直接返回:
补充:collate_fn = operator.itemgetter(0), 表示每个样本独立返回,不用合并成一个tensor