1. 将 PyTorch 模型导出为 TorchScript 模型（注：以下代码生成的是 TorchScript `.pt` 文件，而不是 ONNX）
首先加载 `.pth` 权重构建模型，再使用 `torch.jit.trace` 将其转换为 TorchScript 格式的 `.pt` 模型：
import torchvision.models as models
from mmpose.apis import init_pose_model
from mmcv.runner import load_checkpoint
import torch
def _convert_batchnorm(module, *, bn_cls=torch.nn.BatchNorm3d):
    """Recursively replace ``torch.nn.SyncBatchNorm`` layers with plain BN layers.

    Args:
        module: Root module to convert. Its children are replaced in place via
            ``add_module``; if ``module`` itself is a SyncBatchNorm, a new
            layer is returned instead.
        bn_cls: Batch-norm class used for the replacement layers. Defaults to
            ``torch.nn.BatchNorm3d`` (the original behaviour). ``BatchNorm3d``
            only accepts 5D input, so 2D networks such as HRNet should pass
            ``torch.nn.BatchNorm2d``.

    Returns:
        The converted module. Affine parameters (including their
        ``requires_grad`` flags), running statistics and the train/eval mode
        of every replaced layer are carried over.
    """
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = bn_cls(module.num_features, module.eps,
                               module.momentum, module.affine,
                               module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        # These are registered buffers, so plain assignment rebinds them
        # (None is accepted when track_running_stats is False).
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        # A freshly constructed layer starts in training mode; keep the
        # original's mode so an already eval()'d model stays in eval.
        module_output.train(module.training)
    for name, child in module.named_children():
        module_output.add_module(name,
                                 _convert_batchnorm(child, bn_cls=bn_cls))
    return module_output
def export_pytorch_model(
        *,
        pose_config="demo/hrnet_w32_coco_256x192.py",
        pose_checkpoint="checkpoints/hrnet_w32_coco_256x192-c78dce93_20200708.pth",
        output_path='./hrnet.pt',
        input_shape=(1, 3, 256, 192),
        device='cpu',
        is_localizer=False):
    """Trace an mmpose model with ``torch.jit.trace`` and save it as a ``.pt`` file.

    All arguments are keyword-only and default to the original hard-coded
    HRNet-W32 COCO 256x192 setup, so ``export_pytorch_model()`` behaves as
    before.

    Args:
        pose_config: Path to the mmpose config file.
        pose_checkpoint: Path to the ``.pth`` checkpoint.
        output_path: Where the traced TorchScript module is saved.
        input_shape: Shape of the dummy input used for tracing
            (presumably N, C, H, W -- confirm against the config).
        device: Device string passed to ``init_pose_model`` and used for the
            dummy input.
        is_localizer: Allow falling back to ``model._forward`` when the model
            has no ``forward_dummy`` (replaces the undefined ``args.is_localizer``).

    Returns:
        The traced ``torch.jit.ScriptModule``.

    Raises:
        NotImplementedError: If no trace-compatible forward is available.
    """
    # Per the original note, the model returned here carries a ``model.cfg``
    # attribute after construction.
    model = init_pose_model(pose_config, pose_checkpoint, device)
    model = _convert_batchnorm(model)
    # Newly created BN layers default to training mode; force eval so the
    # traced graph uses running statistics.
    model.eval()

    # Tracing/export does not support kwargs, so swap in a forward that
    # takes only the image tensor.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    elif is_localizer and hasattr(model, '_forward'):
        model.forward = model._forward
    else:
        raise NotImplementedError(
            'Please implement the forward method for exporting.')

    # torch.Tensor(*shape) returns uninitialised memory; trace on real data.
    dummy_input = torch.randn(*input_shape, device=device)
    traced_model = torch.jit.trace(model, dummy_input)
    traced_model.save(output_path)
    return traced_model
if __name__ == '__main__':
    # Script entry point: run the export with the default config/checkpoint.
    export_pytorch_model()
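2. 将 `.pt` 模型转换为 RKNN 模型：使用 RKNN-Toolkit2 的 `rknn.load_pytorch(model='./hrnet.pt', input_size_list=[[1, 3, 256, 192]])` 加载 TorchScript 模型，然后依次调用 `rknn.build()` 和 `rknn.export_rknn()`。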
参考文献
https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc