def foward_student_train(self, sup_data, unsup_data):
    """Run the student on the merged batch and compute the supervised loss.

    The supervised and unsupervised batches are merged with
    ``self.combine_student_data``, pushed through the student's feature
    extractor and ``pts_bbox_head`` in a single pass, and the head output
    is split back with ``self.split_student_data``.

    Args:
        sup_data (dict): Supervised batch; ``'gt_bboxes_3d'`` and
            ``'gt_labels_3d'`` are read from it for the loss.
        unsup_data (dict): Unsupervised batch.

    Returns:
        tuple: ``(sup_loss, unsup_info)`` -- the value returned by
        ``self.student.pts_bbox_head.loss`` for the supervised part, and
        the student's predictions for the unsupervised part.
    """
    # Merge both batches so the student only needs one forward pass.
    merged_batch, batch_sizes = self.combine_student_data(
        sup_data, unsup_data)

    # Forward pass. Per the original inline comment the extractor returns
    # (bev_feats, None); only the first element is used here.
    extracted = self.student.extract_feat(
        points=None,
        img=merged_batch['img_inputs'],
        img_metas=merged_batch['img_metas'])
    bev_feats, _ = extracted
    head_outputs = self.student.pts_bbox_head(bev_feats)

    # Split the predictions back into supervised / unsupervised parts.
    sup_info, unsup_info = self.split_student_data(
        head_outputs, batch_sizes)

    # Only the supervised part has ground truth, so only it gets a loss.
    gt_boxes = sup_data['gt_bboxes_3d']
    gt_labels = sup_data['gt_labels_3d']
    sup_loss = self.student.pts_bbox_head.loss(gt_boxes, gt_labels, sup_info)
    return sup_loss, unsup_info
# Merge the supervised and unsupervised input data.
def combine_student_data(self, sup_data, unsup_data):
    """Merge the supervised and unsupervised batches for the student model.

    Every key of ``sup_data`` is merged with the same key of
    ``unsup_data``: ``'img_inputs'`` is handed to ``self.combine_imgs``
    and every other value is joined with ``+`` (concatenation when the
    values are lists -- presumably the case here; verify against caller).

    Args:
        sup_data (dict): Supervised batch. Must contain ``'img_metas'``.
        unsup_data (dict): Unsupervised batch. Must provide every key
            present in ``sup_data``.

    Returns:
        tuple: ``(merged, (num_sup, num_unsup))``. ``merged`` is a new
        mapping with the keys of ``sup_data``; the two counts are the
        ``len`` of each batch's ``'img_metas'``.

    Raises:
        AssertionError: If either argument is not a dict.
        KeyError: If ``unsup_data`` lacks a key of ``sup_data`` or
            either batch lacks ``'img_metas'``.
    """
    from copy import copy as shallow_copy

    assert isinstance(sup_data, dict) and \
        isinstance(unsup_data, dict)
    # Every key is overwritten below with a value built from the original
    # inputs, so a shallow copy is enough to get a fresh container with
    # the same type and key order. A deep copy would duplicate the whole
    # supervised batch only to throw it away.
    new_student_data = shallow_copy(sup_data)
    for key in sup_data:
        if key == 'img_inputs':
            new_student_data[key] = self.combine_imgs(
                sup_data[key], unsup_data[key])
        else:
            new_student_data[key] = sup_data[key] + unsup_data[key]
    batch_sizes = (len(sup_data['img_metas']),
                   len(unsup_data['img_metas']))
    return new_student_data, batch_sizes
# Split the prediction results back into supervised / unsupervised parts.
def split_student_data(self, student_data, student_data_length):
    """Split the student's predictions into supervised / unsupervised parts.

    Each entry of ``student_data`` is indexed at ``[0]`` to obtain a
    mapping of tensors. Every tensor is cut along its first dimension at
    ``student_data_length[0]``: rows before the cut form the supervised
    part, the remaining rows form the unsupervised part. Both slices are
    ``clone()``-d.

    Args:
        student_data: Iterable of one-element sequences, each wrapping a
            mapping of tensors (the original comments state 6 entries of
            length 1 -- TODO confirm against the head's output).
        student_data_length: Sequence whose first item is the number of
            supervised samples. Only index 0 is read.

    Returns:
        tuple: ``(sup, unsup)``, two tuples with one ``[dict]`` per entry
        of ``student_data``.
    """
    sup_parts, unsup_parts = [], []
    for entry in student_data:
        preds = entry[0]
        # Leading rows belong to the supervised samples.
        sup_parts.append([{
            name: preds[name][:student_data_length[0], ...].clone()
            for name in preds.keys()
        }])
        # NOTE(review): the original comment here spoke of detaching the
        # unsupervised gradient, but only clone() is called, which keeps
        # the result in the autograd graph -- confirm the intent.
        unsup_parts.append([{
            name: preds[name][student_data_length[0]:, ...].clone()
            for name in preds.keys()
        }])
    return tuple(sup_parts), tuple(unsup_parts)