通过 ONNX 模型在 C++ 中加载和推理 GRU 模型的步骤如下:

1. 导出 GRU 模型为 ONNX 格式

首先,确保您已经有一个训练好的 GRU 模型,并将其导出为 ONNX 格式。以下是一个使用 PyTorch 导出 GRU 模型为 ONNX 格式的示例代码:

import torch  
import torch.nn as nn  

class GRUModel(nn.Module):  
    def __init__(self, input_size, hidden_size, output_size):  
        super(GRUModel, self).__init__()  
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)  
        self.fc = nn.Linear(hidden_size, output_size)  

    def forward(self, x):  
        out, _ = self.gru(x)  
        return self.fc(out[:, -1, :])  # 最后一时刻的输出  

# 创建模型实例  
model = GRUModel(input_size=10, hidden_size=20, output_size=1)  
model.eval()  

# 输入样本  
dummy_input = torch.randn(1, 5, 10)  # (batch_size, sequence_length, input_size)  

# 导出为 ONNX  
torch.onnx.export(model, dummy_input, "gru_model.onnx", input_names=['input'], output_names=['output'], dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

2. 在 C++ 中加载和推理 ONNX 模型

要在 C++ 中加载和推理 ONNX 模型,您可以使用 ONNX Runtime。以下是如何加载和推理 GRU ONNX 模型的示例代码:

安装 ONNX Runtime

确保您已经安装了 ONNX Runtime C++ API。可以参考 ONNX Runtime 的官方文档 来完成安装。

C++ 示例代码

#include <onnxruntime/core/session/onnxruntime_cxx_api.h>  
#include <iostream>  
#include <vector>  

int main() {  
    // 创建 ONNX Runtime 环境和会话  
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXRuntime");  
    Ort::SessionOptions session_options;  
    Ort::Session session(env, "gru_model.onnx", session_options);  

    // 准备输入数据  
    std::vector<float> input_data = { /* 填充您的输入数据 */ };  
    std::vector<int64_t> input_shape = {1, 5, 10};  // (batch_size, sequence_length, input_size)  
    
    // 创建输入张量  
    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(env.GetAllocator(0, OrtArenaAllocator), input_data.data(), input_data.size(), input_shape.data(), input_shape.size());  

    // 进行推理  
    std::vector<Ort::Value> input_tensors;  
    input_tensors.push_back(std::move(input_tensor));  

    // 获取输出节点名称  
    const char* output_node_names[] = {"output"};  
    
    // 执行推理  
    auto output_tensors = session.Run(Ort::RunOptions{nullptr},   
                                       &input_node_name,  // 输入节点名称  
                                       input_tensors.data(),  // 输入张量  
                                       1,  // 输入张量数量  
                                       output_node_names,      // 输出节点名称  
                                       1);  // 输出张量数量  

    // 处理输出结果  
    float* output_arr = output_tensors[0].GetTensorMutableData<float>();  
    std::cout << "Output: " << output_arr[0] << std::endl;  

    return 0;  
}

重要事项

  1. 安装 ONNX Runtime:确保您已经找到并配置好 ONNX Runtime C++ 客户端库。
  2. 输入数据:在填充输入数据时,请确保形状与模型要求匹配。
  3. 编译器设置:根据您的项目设置,可能需要配置 CMake 或 Makefile 来包含 ONNX Runtime 库和头文件。
  4. 输出处理:根据您的 GRU 模型和任务,调整输出处理逻辑