pytorch, onnx
摘要:为了将自定义算子的参数,或者自己想要保存的参数序列化到onnx中。
code
import torch
import torch.nn as nn
from torch.autograd import Function
import onnx
import torch.onnx
class Requant_(Function):
@staticmethod
def forward(ctx, input, requant_scale, shift): # ctx 必须要
input = input.double() * requant_scale / 2**shift # 为了等价于c中的移位操作。会存在int32溢出
input = torch.floor(input).float()
return torch.floor(input)
@staticmethod
def symbolic(g, *inputs):
return g.op("Requant", inputs[0], scale_f=23.0, shift_i=8)
requant_ = Requant_.apply
class TinyNet(nn.Module):
def __init__(self):
super(TinyNet, self).__init__()
self.conv1 = nn.Conv2d(3, 1, kernel_size=3, padding=1)
self.relu1 = nn.ReLU()
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = x.view(-1)
x = requant_(x, 5, 5)
return x
net = TinyNet().cuda()
ipt = torch.ones(2,3,12,12).cuda()
torch.onnx.export(net, (ipt,), 'tinynet.onnx', opset_version=11, enable_onnx_checker=False)
print(onnx.load('tinynet.onnx'))
关键点:
- 继承自torch.autograd
- scale_f=23.0, shift_i=8,_f表示浮点数,_i表示整形int32类型
onnx 模型
总结
这种是在pytorch中新写一个op,并序列化到onnx中,另外一个想法是:如果修改已有op的onnx序列化,比如conv2d,upsample等。得到onnx模型中,还需要对onnx模型解析,在把onnx模型转换成自己想要的表达。