前言
本文主要介绍detectron2如何构建模型。本文将首先介绍注册器,然后介绍如何利用注册器注册模型,最后介绍下构建流程即可。感兴趣可以看下mmdetection中注册器,你会发现这两种优秀框架所用到的设计思想一致。
1、Registry介绍Registry你可以理解成一个能够存储类的字典。比如{‘BackBone’, resnet}。看下源码:
class Registry(Iterable[Tuple[str, Any]]):
"""
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name: str) -> None:
"""
Args:
name (str): the name of this registry
"""
self._name: str = name
self._obj_map: Dict[str, Any] = {}
def _do_register(self, name: str, obj: Any) -> None:
assert (
name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(
name, self._name
)
self._obj_map[name] = obj
def register(self, obj: Any = None) -> Any:
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not. See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class: Any) -> Any:
name = func_or_class.__name__
self._do_register(name, func_or_class)
return func_or_class
return deco
# used as a function call
name = obj.__name__
self._do_register(name, obj)
def get(self, name: str) -> Any:
ret = self._obj_map.get(name)
if ret is None:
raise KeyError(
"No object named '{}' found in '{}' registry!".format(name, self._name)
)
return ret
注释介绍了两种注册新模型的方法。register函数后来会作为装饰器来注册模型,而get会在构建模型时根据key来索引对应的类。可能听不懂,请跳转:mmdetection之Registry介绍。
2、构建ResNet50为例2.1. 构造ResNet类
代码来自detectron2/modeling/backbone文件夹内。
首先看下backbone.py,里面定义了一个父类,是所有backbone的父类,此处没什么可讲的。
class Backbone(nn.Module, metaclass=ABCMeta):
"""
Abstract base class for network backbones.
"""
def __init__(self):
super().__init__()
@abstractmethod
def forward(self):
pass
@property
def size_divisibility(self) -> int:
return 0
def output_shape(self):
# this is a backward-compatible default
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in self._out_features
}
在来看下resnet.py。主要就是继承自上述父类,构建ResNet。
class ResNet(Backbone):
"""
Implement :paper:`ResNet`.
"""
def __init__(self, stem, stages, num_classes=None, out_features=None):
2.2 利用Registry注册ResNet
注册器的使用方式如下:
BACKBONE_REGISTRY = Registry("BACKBONE") # 实例化一个注册器
@BACKBONE_REGISTRY.register() # 用注册器的register来装饰ResNet类
def build_resnet_backbone(cfg, input_shape):
通过上述方式,在程序运行后,在BACKBONE_REGISTRY注册器中就包含了ResNet类。然后再通过build.py中的build_backbone接口函数,就能实例化一个ResNet。
def build_backbone(cfg, input_shape=None):
"""
Build a backbone from `cfg.MODEL.BACKBONE.NAME`.
Returns:
an instance of :class:`Backbone`
"""
if input_shape is None:
input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN))
backbone_name = cfg.MODEL.BACKBONE.NAME
backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg, input_shape) # .get获取类,并cfg实例化类
assert isinstance(backbone, Backbone)
return backbone # 构建出一个resnet
3、SparseRCNN
现在若想构建自己的模型,该如何使用注册器呢?这里我以开源项目SparseRCNN为例,源码使用的detectron2进行构建。新建一个project目录,然后创建一个detector.py文件。该函数内容为:
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess
__all__ = ["SparseRCNN"]
@META_ARCH_REGISTRY.register()
class SparseRCNN(nn.Module):
跟ResNet构建方式一致,首先导入modeling中的META_ARCH_REGISTRY,然后将类SparseRCNN注册进去。之后在detectron2/engine/defaults.py中调用build_model就实例化一个网络。
总结detectron2通过借助注册器能够很方便搭建网络。