The previous post covered the overall flow of optimizing an ONNX model in TVM.

This post walks through the conversion from an ONNX model to a Relay model, covering two main topics:

  • converting ONNX operators to Relay operators
  • how the Relay operators themselves are implemented

Converting ONNX operators to Relay operators

As shown in the previous post, the ONNX-to-Relay conversion comes down to the statement below. Its inputs are the ONNX model and the input shape information; its outputs are the Relay IR module and the model parameters.

# onnx -> relay
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

This part is implemented in python/tvm/relay/frontend/onnx.py. The core of the conversion is the GraphProto class, which reads the nodes, inputs, and outputs of the ONNX model and maps each ONNX operator to Relay IR. Its public entry point is the from_onnx function, whose logic can roughly be summarized by the pseudocode below:

def from_onnx(self, graph, opset, get_output_expr=False):
    inputs, params = read_model_inputs(graph)  # model parameters
    nodes = read_model_node(graph)             # nodes and operator info
    convert_map = _get_convert_map(opset)      # op-name -> converter map
    check_op_support(nodes)
    for node in nodes:
        op = self._convert_operator(op_name, inputs, attr, opset)
    return mod, params

From this we can see that every operator conversion in the ONNX frontend is driven by _get_convert_map: the returned convert_map holds a conversion function for each supported operator, and _convert_operator performs the actual conversion.

def _convert_operator(self, op_name, inputs, attrs, opset):
    convert_map = _get_convert_map(opset)
    if op_name in _identity_list:  # this list is empty
        sym = get_relay_op(op_name)(*inputs, **attrs)
    elif op_name in convert_map:
        sym = convert_map[op_name](inputs, attrs, self._params)
    else:
        raise NotImplementedError("Operator {} not implemented.".format(op_name))
    return sym
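The dispatch logic above can be sketched in isolation. All names below are hypothetical stand-ins, not TVM's actual API; the point is only the table-lookup-then-call pattern:

```python
# Minimal sketch of an op-name -> converter dispatch table,
# mimicking _get_convert_map / _convert_operator (hypothetical names).

def convert_conv(inputs, attrs):
    # Stand-in converter: just records what it was called with.
    return ("conv2d", inputs, attrs)

def convert_relu(inputs, attrs):
    return ("relu", inputs, attrs)

CONVERT_MAP = {
    "Conv": convert_conv,
    "Relu": convert_relu,
}

def convert_operator(op_name, inputs, attrs):
    # Unsupported ops raise, just like the real ONNX frontend.
    if op_name not in CONVERT_MAP:
        raise NotImplementedError("Operator {} not implemented.".format(op_name))
    return CONVERT_MAP[op_name](inputs, attrs)

print(convert_operator("Relu", ["x"], {}))
```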

Let's walk through the conversion of the convolution operator as a concrete example. Its entry in the conversion map is:

"Conv": Conv.get_converter(opset)

The actual conversion:

class Conv(OnnxOpConverter):
    """Operator converter for Conv."""
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Use shape of input to determine convolution type.
        data = inputs[0]
        input_shape = infer_shape(data)
        ndim = len(input_shape)
        # auto_pad ...

        # construct op from attrs
        out = AttrCvt(
            op_name=dimension_picker("conv"),
            transforms={
                "kernel_shape": "kernel_size",
                "dilations": ("dilation", 1),
                "pads": ("padding", 0),
                "group": ("groups", 1),
            },
            custom_check=dimension_constraint(),
        )([data, inputs[1]], attr, params)

        use_bias = len(inputs) == 3
        if use_bias:
            out = _op.nn.bias_add(out, inputs[2])
        return out

The corresponding Relay operator is finally constructed inside AttrCvt (python/tvm/relay/frontend/common.py) via:

get_relay_op(op_name)(*inputs, **new_attrs)
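AttrCvt's renaming-with-defaults behavior can be sketched roughly as follows. This is a simplified stand-in, not the real implementation in common.py: each transform entry maps an ONNX attribute name either to a new name, or to a (new_name, default) pair used when the attribute is absent.

```python
# Simplified sketch of AttrCvt-style attribute translation
# (hypothetical helper, not TVM's actual AttrCvt class).

TRANSFORMS = {
    "kernel_shape": "kernel_size",
    "dilations": ("dilation", 1),
    "pads": ("padding", 0),
    "group": ("groups", 1),
}

def convert_attrs(attrs, transforms):
    new_attrs = {}
    for old_name, rule in transforms.items():
        if isinstance(rule, tuple):
            # Rename, falling back to a default when absent.
            new_name, default = rule
            new_attrs[new_name] = attrs.get(old_name, default)
        else:
            # Plain rename; the attribute is expected to be present.
            new_attrs[rule] = attrs[old_name]
    return new_attrs

print(convert_attrs({"kernel_shape": (3, 3), "pads": (1, 1, 1, 1)}, TRANSFORMS))
```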

Relay operator implementation

Continuing with the convolution operator as the example: the conversion path described above contains the following lookup.

for candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib):
    op = getattr(candidate, op_name, None)
    if op is not None:
        break

Following the conv2d operator, its implementation is found in _op.nn:

def conv2d(
    data,
    weight,
    strides=(1, 1),
    padding=(0, 0),
    dilation=(1, 1),
    groups=1,
    channels=None,
    kernel_size=None,
    data_layout="NCHW",
    kernel_layout="OIHW",
    out_layout="",
    out_dtype="",
):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(strides, int):
        strides = (strides, strides)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    padding = get_pad_tuple2d(padding)
    return _make.conv2d(
        data, weight, strides, padding, dilation, groups, channels,
        kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
    )

The _make.conv2d here is a PackedFunc obtained through the registration below:

tvm._ffi._init_api("relay.op.nn._make", __name__)

In src/relay/op/nn/convolution.cc we can find the registration of conv2d:

TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d")
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, DataType out_dtype) {
      return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,
                                   kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
                                   "nn.conv2d");
    });

MakeConv is a template shared by all convolution variants; it is instantiated with the appropriate attribute type and op name:

template <typename T>
inline Expr MakeConv(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                     Array<IndexExpr> dilation, int groups, IndexExpr channels,
                     Array<IndexExpr> kernel_size, std::string data_layout,
                     std::string kernel_layout, std::string out_layout, DataType out_dtype,
                     std::string op_name) {
  auto attrs = make_object<T>();
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->dilation = std::move(dilation);
  attrs->groups = groups;
  attrs->channels = std::move(channels);
  attrs->kernel_size = std::move(kernel_size);
  attrs->data_layout = std::move(data_layout);
  attrs->kernel_layout = std::move(kernel_layout);
  attrs->out_layout = std::move(out_layout);
  attrs->out_dtype = std::move(out_dtype);
  const Op& op = Op::Get(op_name);
  return Call(op, {data, weight}, Attrs(attrs), {});
}
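What MakeConv does at its core — fill an attrs object, look up the op by name, and wrap everything in a Call node — can be sketched in Python with dataclasses. These are simplified models, not the real C++ objects:

```python
from dataclasses import dataclass

# Simplified stand-ins for Conv2DAttrs and the Call expression node.
@dataclass
class Conv2DAttrs:
    strides: tuple = (1, 1)
    padding: tuple = (0, 0)
    dilation: tuple = (1, 1)
    groups: int = 1

@dataclass
class Call:
    op: str        # Op::Get(op_name) is modeled by carrying the name.
    args: tuple
    attrs: object

def make_conv(data, weight, strides, padding, dilation, groups, op_name):
    attrs = Conv2DAttrs(strides, padding, dilation, groups)
    return Call(op_name, (data, weight), attrs)

call = make_conv("data", "weight", (1, 1), (1, 1), (1, 1), 1, "nn.conv2d")
print(call.op, call.attrs.padding)
```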