当训练好一个CNN模型之后,可能要集成到项目工程中,或者移植到到不同的开发平台(比如Android, IOS), 一般项目工程或者App大多数采用C/C++, Java等语言,但是采用pytroch训练的模型用的是python语言,这样就存在一个问题,如何使用C/C++调用预训练好的模型, 如果解决了这个问题,那么训练好的模型才可以在App中得到广泛应用。
PyTorch模型从Python到C++的转换由Torch Script实现。Torch Script是PyTorch模型的一种表示,可由Torch Script编译器理解,编译和序列化。
1、将pytorch模型转换为Torch Script
将PyTorch模型转换为Torch Script有两种方法。
第一种方法是Tracing。该方法通过将样本输入到模型中一次来对该过程进行评估从而捕获模型结构,并记录该样本在模型中的flow。该方法适用于模型中很少使用控制flow的模型。
第二个方法就是向模型添加显式注释(Annotation),通知Torch Script编译器它可以直接解析和编译模型代码,受Torch Script语言强加的约束。
2、windows平台安装pytorch
windows平台的pytorch库为libtorch。
1)在pytorch官网下载libtroch, 官网提供了win/linux/Mac系统编译好的库,省去了编译库的过程。
官方网址:https://pytorch.org/
按照下列方式选择:【仅做推断,选择非CUDA版本即可】
下载release版本:点击红色部分下载网址,即可下载。
2)根据cmakelists.txt创建vs工程。
【也可以直接创建工程,将libtorch的头文件路径和库文件路径按照其他库的方式配置进工程即可,
头文件路径:D:\thirdLib\pytorch\libtorch\include
库文件路径:D:\thirdLib\pytorch\libtorch\lib】
如有以下的CmakeLists.txt和cpp测试文件,通过cmake创建vs工程,同时配置libtorch:
CmakeLists.txt文件内容:
cmake_minimum_required(VERSION 3.0.0 FATAL_ERROR)
project(simnet)find_package(Torch REQUIRED)
message(STATUS "Pytorch status:")
message(STATUS " libraries: ${TORCH_LIBRARIES}")add_executable(simnet simnetTest.cpp)
target_link_libraries(simnet ${TORCH_LIBRARIES})
set_property(TARGET simnet PROPERTY CXX_STANDARD 11)
simnetTest.cpp文件内容:
#include "torch/script.h"
#include "torch/torch.h"#include <iostream>
#include <memory>using namespace std;
int main(int argc, const char* argv[])
{
if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
}
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
module->to(at::kCUDA); assert(module != nullptr);
std::cout << "ok\n";
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}).to(at::kCUDA));
at::Tensor output = module->forward(inputs).toTensor(); std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
}
打开cmake,选择编译器,添加路径,点击“Configure”,发现报错找不到libtorch库,然后
配置Torch_DIR路径:D:\thirdLib\pytorch\libtorch\share\cmake\Torch
再次点击“Configure”,成功,点击“Generate”即可生成vs工程。打开工程进行编译、运行即可。
注意在环境变量中添加.dll文件的路径:D:\thirdLib\libtorch\libtorch\lib。