摘要:主要学习记录了ONNX转TensorRT流程、代码。末尾有完整代码。
目录:
3.1 创建TensorRT的日志记录器
3.2 创建bulider对象
3.3 设置engine参数
3.4 定义network并加载ONNX解析器
3.5 获取网络的输入输出
3.6 动态输入
3.7 检查设备是否支持FP16(半精度)推理
3.8 写入engine,并序列化model
3.9 完整代码
3.10 遇见的问题
3.1 创建TensorRT的日志记录器
log = trt.Logger()
3.2 创建bulider对象
使用日志记录器创建 TensorRT Builder 对象,并通过Builder创建network并从该网络生成engine
其中:trt.OnnxParser(network, log)需要传入两个参数。一个是已创建network,一个是日志记录器
builder = trt.Builder(log) # 使用日志记录器创建 TensorRT Builder 对象
parser = trt.OnnxParser(network, log) # 从network生成engine
3.3 设置engine参数
# 创建 Builder Config 对象
config = builder.create_builder_config()
# 设置 TensorRT 推理引擎使用的最大工作空间大小,单位为字节。指定最大可用显存
config.max_workspace_size = workspace * 1 << 30
3.4 定义network并加载ONNX解析器
通过builder创建一个空网络,什么都没有,需要将ONNX的模型结构信息写入创建的空network。
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network() # 通过Builder创建network,此时的network还是一个空网络
parser = trt.OnnxParser(network, log)
查看ONNX是否解析成功并将ONNX中的模型结构等信息写入network
# 查看是否解析成功,同时将模型结构写进了network
if not parser.parse_from_file(str(onnx)):
raise RuntimeError(f'failed to load ONNX file: {onnx}')
3.5 获取网络的输入输出
# 可能不是num_inputs,根据实际情况来。
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
3.6 动态输入
if dynamic:
im = torch.zeros(1, 3, *imgsz).to(device) # 我这儿输入是im = torch.zeros(1,3,640,640)
if im.shape[0] <= 1:
# log.warning(f"{trt} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
print('x')
profile = builder.create_optimization_profile()
for inp in inputs:
profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
config.add_optimization_profile(profile)
3.7 检查设备是否支持FP16(半精度)推理
其中:
- builder.platform_has_fast_fp16:用于检查当前设备是否可以进行半精度计算。
- half:自定义bool参数,用于决定是否半径都推理
- config.set_flag(trt.BuilderFlag.FP16):set_flag方法来设置config对象的标志,将FP16标志添加到flags中
if builder.platform_has_fast_fp16 and half:
config.set_flag(trt.BuilderFlag.FP16)
3.8 写入engine,并序列化model
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
t.write(engine.serialize())
如果希望在trt模型中加入classes(其余信息类似)。
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
classes = (['person', 'car'])
add_meta_to_model(t, classes, type='trt')
t.write(engine.serialize())
3.9 完整代码
import numpy as np
import tensorrt as trt
import torch
import logging
# logger to capture errors, warnings, and other information during the build and inference phases
TRT_LOGGER = trt.Logger()
def build_engine(onnx, dynamic=True, half=True):
# f = onnx.with_suffix('.engine')
f = 'trt.engine'
# 1、创建日志记录器
log = trt.Logger()
# 2、创建builder对象
builder = trt.Builder(log)
# 3、创建 Builder Config 对象
config = builder.create_builder_config()
# 4、将workspace*1 二进制左移30位后的10进制
workspace = 1
config.max_workspace_size = workspace * 1 << 30 # 设置 TensorRT 推理引擎使用的最大工作空间大小,单位为字节
# 5、定义networko并加载ONNX解析器
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, log)
if not parser.parse_from_file(str(onnx)): # 查看是否解析成功
raise RuntimeError(f'failed to load ONNX file: {onnx}')
# 6、获得网络的输入输出
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
# 7.判断是否动态输入
if dynamic:
im = torch.zeros(1,3,640,640)
if im.shape[0] <= 1:
# log.warning(f"{trt} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
print('x')
profile = builder.create_optimization_profile()
for inp in inputs:
profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
config.add_optimization_profile(profile)
# 判断是否支持FP16推理
if builder.platform_has_fast_fp16 and half:
config.set_flag(trt.BuilderFlag.FP16)
# build engine 文件的写入 这里的f是前面定义的engine文件
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
# 序列化model
t.write(engine.serialize())
return f, None
if __name__ == '__main__':
engine, context = build_engine(r'D:\zy\Yolo\yolov8-ZY\yolov8n.onnx')
3.10 遇见的问题
1、AttributeError: 'tensorrt.tensorrt.Builder' object has no attribute 'max_workspace_size'
原因是:tensorrt8.0以上删除了max_workspace_size属性。
- 降低tensorRT版本到7.x版本
- 或者如下
config = builder.create_builder_config() # 创建 Builder Config 对象
config.max_workspace_size = workspace * 1 << 30 # 设置 TensorRT 推理引擎使用的最大工作空间大小,单位为字节
上一篇:2、TensorRT学习笔记之PT转ONNX、可视化ONNX
下一篇:正在学习、持续更新(实战,瑞芯微RK3588部署yolov8检测模型)