在上一篇博客里分析了Featdepth论文原理和核心源码,也就是模型部分,包括网络结构和损失函数计算:
苹果姐:单目深度估计自监督模型Featdepth解读(上)——论文理解和核心源码分析
本篇博客将介绍Featdepth使用的框架–openMMLab的使用以及作者进行的一些修改和扩展。
Featdepth的源码结构和monodepth2有很大的不同。后者完全是定制化的代码,很适合pytorch入门,前者是使用了商汤的计算机视觉框架OpenMMLab中的基础库mmcv,完全按照mmcv模板写的,在数据读取部分还借鉴了mmdetection的代码,是OpenMMLab中的目标检测库,可以说如果想看懂Featdepth源码结构,必须先学习一下mmcv框架,了解其核心组件Register/Config/Hook/Runner等功能和用法,最好也看看源码。
mmcv工程地址:GitHub - open-mmlab/mmcv: OpenMMLab Computer Vision Foundation
官方文档:Welcome to MMCV’s documentation!
关于OpenMMLab知乎和B站都有博客和视频,我在此只针对Featdepth用到的简要介绍一下。
模型训练部分的代码很短,如下所示:
from __future__ import division
import argparse
from mmcv import Config
from mmcv.runner import load_checkpoint
from mono.datasets.get_dataset import get_dataset
from mono.apis import (train_mono,
init_dist,
get_root_logger,
set_random_seed)
from mono.model.registry import MONO
import torch
def main():
args = parse_args()
print(args.config)
cfg = Config.fromfile(args.config)
cfg.work_dir = args.work_dir
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.gpus = [int(_) for _ in args.gpus.split(',')]
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
print('cfg is ', cfg)
# init logger before other steps
logger = get_root_logger(cfg.log_level)
logger.info('Distributed training: {}'.format(distributed))
# set random seeds
if args.seed is not None:
logger.info('Set random seed to {}'.format(args.seed))
set_random_seed(args.seed)
model_name = cfg.model['name']
model = MONO.module_dict[model_name](cfg.model)
if cfg.resume_from is not None:
load_checkpoint(model, cfg.resume_from, map_location='cpu')
elif cfg.finetune is not None:
print('loading from', cfg.finetune)
checkpoint = torch.load(cfg.finetune, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'], strict=False)
train_dataset = get_dataset(cfg.data, training=True)
if cfg.validate:
val_dataset = get_dataset(cfg.data, training=False)
else:
val_dataset = None
train_mono(model,
train_dataset,
val_dataset,
cfg,
distributed=distributed,
validate=cfg.validate,
logger=logger)
首先Config组件用来把各种格式的配置文件读取成Config对象,以便于读取
然后指定了预训练模型地址和GPU、cuda设置、是否使用分布式多卡、日志设置、随机种子等
model_name = cfg.model['name']
model = MONO.module_dict[model_name](cfg.model)
这两句是从已注册的模型字典中取出你想要训练的模型对象。熟悉设计模式的朋友们会发现这是工厂模式的典型用法:通过Register工具将各个模型注册进了工厂的module_dict里,根据配置项中的字符串取出相应的模型类并实例化。具体用法如下:
在from mono.model.registry import MONO这一句中,先是执行了mono/model/init_.py,导入了源码中的四个模型:
from .mono_baseline.net import Baseline
from .mono_autoencoder.net import autoencoder
from .mono_fm.net import mono_fm
from .mono_fm_joint.net import mono_fm_joint
每个net文件被导入时都自动运行了Registry类中的装饰函数:(装饰函数在被修饰的函数或类定义的时候就自动被调用,有关装饰函数的用法不熟悉的话可参见 Python @函数装饰器及用法(超级详细))
@MONO.register_module
class mono_fm_joint(nn.Module):
这里的Registry类是作者自己写的,是源码的简化版,只保留了注册module_dict的功能,也不支持传参。可能因为源码的参数中有个build_func,默认为build_from_cfg,是用来将类实例化的方法,作者省掉了这个方法,所以重新写了一个不需要传参的Registry。MONO是他首先实例化的一个Registry对象,也就是工厂,用来保存module_dict。
import torch
import torch.nn as nn
class Registry(object):
def __init__(self, name):
self._name = name
self._module_dict = dict()
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def _register_module(self, module_class):
"""Register a module.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not issubclass(module_class, nn.Module):
raise TypeError(
'module must be a child of nn.Module, but got {}'.format(
module_class))
module_name = module_class.__name__
if module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_class
def register_module(self, cls): # 作为装饰函数
self._register_module(cls)
return cls
MONO = Registry('mono')
后面是预训练模型载入或者断点恢复
然后就是数据读取操作,这里的dataset和monodepth2类似,都是继承自pytorch自带的torch.utils.data.Dataset,进行了相应的扩展。
然后就是train_mono函数了,这是训练的核心功能。
def train_mono(model,
dataset_train,
dataset_val,
cfg,
distributed=False,
validate=False,
logger=None):
if logger is None:
logger = get_root_logger(cfg.log_level)
# start training
if distributed:
_dist_train(model, dataset_train, dataset_val, cfg, validate=validate)
else:
_non_dist_train(model, dataset_train, dataset_val, cfg, validate=validate)
可以看出训练分为分布式和非分布式。非分布式核心代码如下:
def _non_dist_train(model, dataset_train, dataset_val, cfg, validate=False):
# prepare data loaders
data_loaders = [
build_dataloader(dataset_train,
cfg.imgs_per_gpu,
cfg.workers_per_gpu,
cfg.gpus.__len__(),
dist=False)
]
# put model on gpus
model = MMDataParallel(model, device_ids=cfg.gpus).cuda()
# build runner
optimizer = build_optimizer(model,
cfg.optimizer)
runner = Runner(model, batch_processor,
optimizer,
cfg.work_dir,
cfg.log_level)
runner.register_training_hooks(cfg.lr_config,
cfg.optimizer_config,
cfg.checkpoint_config,
cfg.log_config)
分布式核心代码如下:
def _dist_train(model, dataset_train, dataset_val, cfg, validate=False):
# prepare data loaders
data_loaders = [build_dataloader(dataset_train,
cfg.imgs_per_gpu,
cfg.workers_per_gpu,
dist=True)
]
# put model on gpus
model = MMDistributedDataParallel(model.cuda())
# build runner
optimizer = build_optimizer(model, cfg.optimizer)
print('cfg work dir is ', cfg.work_dir)
runner = Runner(model,
batch_processor,
optimizer,
cfg.work_dir,
cfg.log_level)
# register hooks
optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
runner.register_training_hooks(cfg.lr_config,
optimizer_config,
cfg.checkpoint_config,
cfg.log_config)
runner.register_hook(DistSamplerSeedHook())
这几乎就是用了mmcv的模板,从数据读取build_dataloader,到模型包装MMDataParallel/MMDistributedDataParallel、优化器build_optimizer、训练工作流Runner,以及各种hook注册。
这里详细解释一下各个步骤。
1.data_loader部分
一般来说data_loader可以使用pytorch原生的torch.utils.data.DataLoader,但是其中的sampler参数是可以自定义的。pytorch原生且默认的dataloader有两种:RandomSampler和SequentialSampler,以及以batch为单位的BatchSampler,分布式工具torch.utils.data.distributed中提供的是DistributedSampler,但原生的有几个缺点:
一是全都缺少分组功能,假如你的数据集是分为几类的,一个batch输入的图片必须属于同一类,就需要扩展
二是DistributedSampler缺少shuffle功能,假如想随机输入,也需要扩展
三是DistributedSampler只提供多卡数据补全功能,也就是说保证你的图片总数可以被gpu数整除,确保每个gpu有同样数量的图片,但缺少batchsize补全功能,也就是保证每个gpu的图片数量可以被batchsize整除(单卡训练的RandomSampler和SequentialSampler也没有),这时候只能在data_loader初始化时设置drop_last=True,通过去掉最后一个batch来保证每个batch的大小一样(这在有些模型中比较重要,因为可能网络的输入大小限制为batch_size,出现不能整除的情况最后一个batch会报错),这样会导致浪费部分图片,也需要扩展。
作者在这里用了mmdetection中写的三种sampler扩展类源码,其实并没有用到分组功能,只用到了shuffle和batch_size补全功能,分别是:
单卡训练采用GroupSampler,实现分组功能
多卡训练使用了DistributedGroupSampler和加入shuffle功能的DistributedSampler,实现分组和数据补全功能。这个源码就不贴了,看起来比较费劲,在notebook上构造数据跑了一下才理解。
2.模型并行化部分
这里单卡情况使用了MMDataParallel,多卡情况使用了MMDistributedDataParallel。
MMDataParallel的注释中描述了和pytorch的DataParallel的区别:
1.支持一个定制化类型:DataContainer,可以允许对输入数据更加灵活的控制。DataContainer的解释是可以解决原生DataParallel对数据大小必须一致、类型必须一致的限制。
2.支持两个API:train_step()和val_step(),这是mmcv中的工作流控制工具Runner中需要的方法。但MMDataParallel只支持单卡训练,多卡训练要使用MMDistributedDataParallel。MMDistributedDataParallel对pytorch原生的DistributedDataParallel的扩展和MMDataParallel一致。
3.优化器部分
主要通过build_optimizer函数读取配置文件中的优化器配置,从torch.optim中寻找相应的优化器类并实例化。这部分是featdepth作者自己写的简化版,直接用了pytorch自带的,在mmcv源码中支持自定义优化器,也通过Registry来注册和实例化。
4.工作流Runner部分
Runner是mmcv训练部分的引擎,整个训练的流程都由它来控制。详细解释可以参照OpenMMLab:MMCV 核心组件分析(七): Runner
Runner根据Epoch 和 Iter 模式又分为EpochBasedRunner和IterBasedRunner,默认Runner是EpochBasedRunner。需要以下参数进行初始化:
def __init__(self,
model,
batch_processor=None,
optimizer=None,
work_dir=None,
logger=None,
meta=None,
max_iters=None,
max_epochs=None):
实际使用的时候调用Runner.run():
def run(self,
data_loaders, # dataloader 列表
workflow, # 工作流列表,长度需要和 data_loaders 一致
max_epochs=None,
**kwargs):
其中workflow参数决定了工作流的顺序,例如workflow = [(‘train’, 1),(‘val’, 1)],代表一个train一个val流程。run是Runner的入口,会从workflow中读取工作流名称,用getattr()去调用对应的方法。例如调用train()方法:
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
其中run_iter()函数用来走train流程中的前向传播部分,主要包括计算输出、计算损失函数。
def run_iter(self, data_batch, train_mode, **kwargs):
if self.batch_processor is not None:
outputs = self.batch_processor(
self.model, data_batch, train_mode=train_mode, **kwargs)
elif train_mode:
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()"'
'and "model.val_step()" must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
解释:如果自定义了batch_processor方法,则调用batch_processor中的流程,否则调用model中的train_step()(此时model已被MMDataParallel或者MMDistributedDataParallel包装)。在train_step函数内部又进一步调用了model本身的train_step,如果不定义batch_processor,需要在模型中定义这个函数。例如:
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def train_step(self, data, optimizer):
images, labels = data
predicts = self(images) # -> self.__call__() -> self.forward()
loss = self.loss_fn(predicts, labels)
return {'loss': loss}
5.HOOK部分
HOOK部分具体介绍也可以参照博客:OpenMMLab:MMCV 核心组件分析(六): Hook
mmcv框架大量用到了HOOK,即除了主流程之外,其他所有的功能都通过HOOK去调用:
self.call_hook(‘after_run’)
框架支持默认HOOK、定制HOOK和自定义HOOK,默认HOOK可以使用runner.register_training_hooks()直接注册,定制HOOK可以从runner中导入之后用runner.registerhook()进行注册,如featdepth中使用的DistSamplerSeedHook,再有一种自定义HOOK,可以由用户自定义编写HOOK内容,就需要继承runner中的HOOK类,定义代表位置的函数和优先级,如before_train_epoch()、beforerun()等,在上层函数调用call_hook()的时候就会按照优先级去依次调用已注册的HOOK中的相应函数。例如:
class DistEvalHook(Hook):
def __init__(self, dataset, interval=1, cfg=None):
assert isinstance(dataset, Dataset)
self.dataset = dataset
self.interval = interval
self.cfg = cfg
def after_train_epoch(self, runner):
print('evaluation..............................................')
if not self.every_n_epochs(runner, self.interval):
return
runner.model.eval()
results = [None for _ in range(len(self.dataset))]
if runner.rank == 0:
prog_bar = mmcv.ProgressBar(len(self.dataset))
框架部分就介绍到这里,水平有限,欢迎指正
要了解模型的原理和核心代码,请继续阅读:
苹果姐:单目深度估计自监督模型Featdepth解读(上)——论文理解和核心源码分析