本节主要介绍VoxelNet的训练主体部分,其余部分请参考这里 其实目前大多数3D目标检测算法的网络结构和数据处理,特别是基于KITTI的三维目标检测,都可以参考这些处理方式,具有一定的通用性。

1 代码结构

  1. 加载预训练模型,没有的话使用cycleganpytorch代码讲解 pytorch alexnet代码_深度学习方式初始网络参数
  2. 各种超参数定义引入
  3. 构建TensorBoard,方便模型训练过程的可视化
  4. 搭建VoxelNet
  5. 设定好cuda和优化器
  6. 开始训练
pre_model = args.ckpt # 加载预训练模型
    # 各种超参数定义引入
    cfg.pos_threshold = hyper['pos']
    cfg.neg_threshold = hyper['neg']
    model_name = "model_%d"%(args.index+1)
	# 构建TensorBoard,方便模型训练过程的可视化
    writer = SummaryWriter('runs/%s'%(model_name[:-4]))
	# 搭建VoxelNet
    net = VoxelNet()
    # 设定好cuda和优化器
    net.to(cfg.device)
    optimizer = optim.SGD(net.parameters(), lr=hyper['lr'], momentum = hyper['momentum'], weight_decay=hyper['weight_decay'])

    if pre_model is not None and os.path.exists(os.path.join('./model',pre_model)) :
        ckpt = torch.load(os.path.join('./model',pre_model), map_location=cfg.device)
        net.load_state_dict(ckpt['model_state_dict'])
        cfg.last_epoch = ckpt['epoch']
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    else :
    	# 使用Xavier方式初始网络参数
        net.apply(weights_init)     
    train(net, model_name, hyper, cfg, writer, optimizer)
    writer.close()

2 具体训练过程

  1. 加载KITTI数据集
  2. 网络设为cycleganpytorch代码讲解 pytorch alexnet代码_pytorch_02模式,可以不断更新权重
  3. 构建cycleganpytorch代码讲解 pytorch alexnet代码_python_03
  4. 记录cycleganpytorch代码讲解 pytorch alexnet代码_3d_04网络预测分数cycleganpytorch代码讲解 pytorch alexnet代码_cycleganpytorch代码讲解_05和3D检测框偏移cycleganpytorch代码讲解 pytorch alexnet代码_cycleganpytorch代码讲解_05总体cycleganpytorch代码讲解 pytorch alexnet代码_cycleganpytorch代码讲解_05
  5. 保存每个Batch的正负预测框的信息
  6. 反向传播,不断迭代
def train(net, model_name, hyper, cfg, writer, optimizer):
	# 加载KITTI数据集
    dataset=KittiDataset(cfg=cfg,root='/media/jilinlee//KITTI/ponitnet_data/KITTI',set='train')
    data_loader = data.DataLoader(dataset, batch_size=cfg.N, num_workers=4, shuffle=True,pin_memory=False)
    # 网络设为$train$模式,可以不断更新权重
    net.train()
    # 构建VoxelLoss
    criterion = VoxelLoss(alpha=hyper['alpha'], beta=hyper['beta'], gamma=hyper['gamma'])
    # 记录RPN网络预测分数loss和3D检测框偏移loss和总体loss
    running_loss = 0.0
    running_reg_loss = 0.0
    running_conf_loss = 0.0
    epoch_size = len(dataset) // cfg.N
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[round(args.epoch*x) for x in (0.7, 0.9)], gamma=0.1)
    scheduler.last_epoch = cfg.last_epoch + 1
    optimizer.zero_grad()
    epoch = cfg.last_epoch
    while epoch < args.epoch :
        iteration = 0
        for voxel_features, voxel_coords, gt_box3d_corner, gt_box3d, images, calibs, ids in data_loader:
        	#体素特征 
            voxel_features = torch.tensor(voxel_features).to(cfg.device)
            # 正样本框
            pos_equal_one = []
            # 负样本框
            neg_equal_one = []
            # 正样本框相对于gt的偏移量
            targets = []
            with torch.no_grad():
                for i in range(len(gt_box3d)):
                    pos_equal_one_, neg_equal_one_, targets_ = dataset.cal_target(gt_box3d_corner[i], gt_box3d[i], cfg)
                    pos_equal_one.append(pos_equal_one_)
                    neg_equal_one.append(neg_equal_one_)
                    targets.append(targets_)
            pos_equal_one = torch.stack(pos_equal_one, dim=0)
            neg_equal_one = torch.stack(neg_equal_one, dim=0)
            targets = torch.stack(targets, dim=0)
            # 体素特征和体素的对应网格坐标一起送入网络
            score, reg = net(voxel_features, voxel_coords)
            conf_loss, reg_loss, _, _, _ = criterion(reg, score, pos_equal_one, neg_equal_one, targets)
            loss = hyper['lambda'] * conf_loss + reg_loss
            running_conf_loss += conf_loss.item()
            running_reg_loss += reg_loss.item()
            running_loss += (reg_loss.item() + conf_loss.item())
            # 反向传播,不断迭代
            loss.backward()
        scheduler.step()
        epoch += 1

其实不太好理解的部分是这里

  1. cycleganpytorch代码讲解 pytorch alexnet代码_python_08点云数据转为体素数据,这个是在cycleganpytorch代码讲解 pytorch alexnet代码_pytorch_09数据集导入部分完成的,所以这里直接获得体素特征
  2. 体素特征有7个维度cycleganpytorch代码讲解 pytorch alexnet代码_深度学习_10
  3. 而正负样本是根据点云的鸟瞰图中cycleganpytorch代码讲解 pytorch alexnet代码_cycleganpytorch代码讲解_11cycleganpytorch代码讲解 pytorch alexnet代码_3d_12cycleganpytorch代码讲解 pytorch alexnet代码_深度学习_13确定的,有个阈值来划分,因为作者假设在鸟瞰图上车辆的高度和整个cycleganpytorch代码讲解 pytorch alexnet代码_pytorch_14的扫描高度相比可以忽略,这样就成了俯视图。这个正负样本在后续的cycleganpytorch代码讲解 pytorch alexnet代码_3d_04网络部分会用到,尤其是计算cycleganpytorch代码讲解 pytorch alexnet代码_3d_16
  4. 然后就是体素坐标cycleganpytorch代码讲解 pytorch alexnet代码_深度学习_17的问题,虽然由于整个点云空间被划分为体素空间,但是提取体素特征的VoxelNet却是通过全连接层一个个的处理体素中的点云信息,这样丢失了点云对应的位置信息,所以需要加入划分时的体素坐标帮助定位。
for voxel_features, voxel_coords, gt_box3d_corner, gt_box3d, images, calibs, ids in data_loader:
        	#体素特征 
            voxel_features = torch.tensor(voxel_features).to(cfg.device)
            # 正样本框
            pos_equal_one = []
            # 负样本框
            neg_equal_one = []
            # 正样本框相对于gt的偏移量
            targets = []
            with torch.no_grad():
                for i in range(len(gt_box3d)):
                    pos_equal_one_, neg_equal_one_, targets_ = dataset.cal_target(gt_box3d_corner[i], gt_box3d[i], cfg)
                    pos_equal_one.append(pos_equal_one_)
                    neg_equal_one.append(neg_equal_one_)
                    targets.append(targets_)
            pos_equal_one = torch.stack(pos_equal_one, dim=0)
            neg_equal_one = torch.stack(neg_equal_one, dim=0)
            targets = torch.stack(targets, dim=0)
            # 体素特征和体素的对应网格坐标一起送入网络
            score, reg = net(voxel_features, voxel_coords)
            conf_loss, reg_loss, _, _, _ = criterion(reg, score, pos_equal_one, neg_equal_one, targets)
            loss = hyper['lambda'] * conf_loss + reg_loss

cycleganpytorch代码讲解 pytorch alexnet代码_深度学习_18


如图,粉色框体就是个体素,这里面包含了一定数量的点云,并将每个点云处理成了更加高维、可以被3D卷积处理的特征向量cycleganpytorch代码讲解 pytorch alexnet代码_3d_19,然后经过全连接网络,代码在下面,提取出点云特征,再经过最大值筛选(不是cycleganpytorch代码讲解 pytorch alexnet代码_pytorch_20这个方法,因为不是图像这类二维特征图)只保留点云可以表征cycleganpytorch代码讲解 pytorch alexnet代码_3d_21信息的数据,再和点云特征拼接起来生成一个既有点云高维特征又有低维形态特征的向量。

所以,需要加入体素位置信息,帮助后续网络可以对cycleganpytorch代码讲解 pytorch alexnet代码_3d_22进行具体定位。

class FCN(nn.Module):

    def __init__(self,cin,cout): # 网络模块的参数,在init这里传播
        super(FCN, self).__init__()
        self.linear = nn.Linear(cin,cout)
        self.BN = nn.BatchNorm2d(cout)
        self.ReLU= nn.ReLU(inplace=True)

    def forward(self,x): # 网络要处理的特征图,在forWord这里
        batch, voxelnums,_ = x.shape
        x = self.linear(x.view(batch*voxelnums,-1))
        x = self.BN(x)
        x = self.ReLU(x)
        return x.view(batch,voxelnums,-1)