前言

  本文主要介绍detectron2如何构建模型。本文将首先介绍注册器,然后介绍如何利用注册器注册模型,最后介绍下构建流程即可。感兴趣可以看下mmdetection中注册器,你会发现这两种优秀框架所用到的设计思想一致。

1、Registry介绍

  Registry你可以理解成一个能够存储类的字典。比如{‘BackBone’, resnet}。看下源码:

class Registry(Iterable[Tuple[str, Any]]):
    """
    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name (str): the name of this registry
        """
        self._name: str = name
        self._obj_map: Dict[str, Any] = {}

    def _do_register(self, name: str, obj: Any) -> None:
        assert (
            name not in self._obj_map
        ), "An object named '{}' was already registered in '{}' registry!".format(
            name, self._name
        )
        self._obj_map[name] = obj

    def register(self, obj: Any = None) -> Any:
        """
        Register the given object under the the name `obj.__name__`.
        Can be used as either a decorator or not. See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class: Any) -> Any:
                name = func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj)

    def get(self, name: str) -> Any:
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(
                "No object named '{}' found in '{}' registry!".format(name, self._name)
            )
        return ret

  注释介绍了两种注册新模型的方法。register函数后来会作为装饰器来注册模型,而get会在构建模型时根据key来索引对应的类。可能听不懂,请跳转:mmdetection之Registry介绍

2、构建ResNet50为例

2.1. 构造ResNet类

 代码来自detectron2/modeling/backbone文件夹内。
detectron2源码阅读4--注册器构建模型_python
 首先看下backbone.py,里面定义了一个父类,是所有backbone的父类,此处没什么可讲的。

class Backbone(nn.Module, metaclass=ABCMeta):
    """
    Abstract base class for network backbones.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self):
        pass

    @property
    def size_divisibility(self) -> int:
        return 0

    def output_shape(self):
        # this is a backward-compatible default
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

 在来看下resnet.py。主要就是继承自上述父类,构建ResNet。

class ResNet(Backbone):
    """
    Implement :paper:`ResNet`.
    """

    def __init__(self, stem, stages, num_classes=None, out_features=None):

2.2 利用Registry注册ResNet

  注册器的使用方式如下:

BACKBONE_REGISTRY = Registry("BACKBONE")    # 实例化一个注册器
@BACKBONE_REGISTRY.register()               # 用注册器的register来装饰ResNet类
def build_resnet_backbone(cfg, input_shape):

通过上述方式,在程序运行后,在BACKBONE_REGISTRY注册器中就包含了ResNet类。然后再通过build.py中的build_backbone接口函数,就能实例化一个ResNet。

def build_backbone(cfg, input_shape=None):
    """
    Build a backbone from `cfg.MODEL.BACKBONE.NAME`.

    Returns:
        an instance of :class:`Backbone`
    """
    if input_shape is None:
        input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN))

    backbone_name = cfg.MODEL.BACKBONE.NAME
    backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg, input_shape)  # .get获取类,并cfg实例化类
    assert isinstance(backbone, Backbone)
    return backbone                                    # 构建出一个resnet
3、SparseRCNN

  现在若想构建自己的模型,该如何使用注册器呢?这里我以开源项目SparseRCNN为例,源码使用的detectron2进行构建。新建一个project目录,然后创建一个detector.py文件。该函数内容为:

from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess
__all__ = ["SparseRCNN"]

@META_ARCH_REGISTRY.register()
class SparseRCNN(nn.Module):

 跟ResNet构建方式一致,首先导入modeling中的META_ARCH_REGISTRY,然后将类SparseRCNN注册进去。之后在detectron2/engine/defaults.py中调用build_model就实例化一个网络。

总结

 detectron2通过借助注册器能够很方便搭建网络。