TorchScript 简介

TorchScript是PyTorch模型(nn.Module的子类)的中间表示,可以在高性能环境(例如C ++)中运行。

转换模块

跟踪(Tracing)模块

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

x, h = torch.rand(3, 4), torch.rand(3, 4)
my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

输出结果:

if x.sum() > 0:
MyCell(
  original_name=MyCell
  (dg): MyDecisionGate(original_name=MyDecisionGate)
  (linear): Linear(original_name=Linear)
)

TorchScript将其定义记录在中间表示(或IR)中,在深度学习中通常称为图形。我们可以检查带有.graph属性的图:

print(traced_cell.graph)

输出结果:

graph(%self.1 : __torch__.___torch_mangle_135.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.___torch_mangle_134.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /tmp/ipykernel_10930/2555227347.py:7:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /tmp/ipykernel_10930/2555227347.py:7:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /tmp/ipykernel_10930/2555227347.py:7:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)

但是,这是一个非常低级的表示形式,图中包含的大多数信息对最终用户没有用。相反,我们可以使用.code属性来给出代码的Python语法解释:

print(traced_cell.code)

输出结果:

if x.sum() > 0:
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  _1 = torch.tanh(_0)
  return (_1, _1)

那么为什么我们要做所有这些呢? 有以下几个原因:

  • TorchScript代码可以在其自己的解释器中调用,该解释器基本上是受限制的Python解释器。该解释器不被全局解释器锁定,因此可以在同 一实例上同时处理许多请求。
  • 这种格式使我们可以将整个模型保存到磁盘上,并将其加载到另一个环境中,例如在以Python以外的语言编写的服务器中
  • TorchScript为我们提供了一种表示形式,其中我们可以对代码进行编译器优化以提供更有效的执行
  • TorchScript允许我们与许多后端/设备运行时进行接口,这些运行时比单个操作员需要更广泛的程序视图。

我们可以看到,调用traced_cell产生的结果与Python模块相同:

print(my_cell(x, h))
print(traced_cell(x, h))

输出结果:

if x.sum() > 0:
(tensor([[ 0.1474,  0.9080, -0.3856,  0.8971],
        [-0.2458,  0.7805,  0.4175,  0.7327],
        [-0.4040,  0.7462,  0.4431,  0.7378]], grad_fn=<TanhBackward0>), tensor([[ 0.1474,  0.9080, -0.3856,  0.8971],
        [-0.2458,  0.7805,  0.4175,  0.7327],
        [-0.4040,  0.7462,  0.4431,  0.7378]], grad_fn=<TanhBackward0>))
(tensor([[ 0.1474,  0.9080, -0.3856,  0.8971],
        [-0.2458,  0.7805,  0.4175,  0.7327],
        [-0.4040,  0.7462,  0.4431,  0.7378]], grad_fn=<TanhBackward0>), tensor([[ 0.1474,  0.9080, -0.3856,  0.8971],
        [-0.2458,  0.7805,  0.4175,  0.7327],
        [-0.4040,  0.7462,  0.4431,  0.7378]], grad_fn=<TanhBackward0>))

脚本转换模块

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.script(my_cell)
print(traced_cell.code)

输出:

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)

torch.jit.script跟踪完全按照我们所说的去做:运行代码,记录发生的操作,并构 造一个可以做到这一点的ScriptModule。不幸的是,诸如控制流之类的东西被抹去了。

保存和加载模型

一旦有了ScriptModule(通过跟踪或注释PyTorch模型),您就可以将其序列化为文件了。torch提供API,以存档格式将TorchScript模块保存到磁盘或从磁盘加载TorchScript模块。这种格式包括代码,参数,属性和调试信息,这意味着归档文件是模型的独立表示形式,可以在完全独立的过程中加载,稍后,您将可以使用C ++从此文件加载模块并执行它。

traced.save('wrapped_rnn.zip')

loaded = torch.jit.load('wrapped_rnn.zip')

print(loaded)
print(loaded.code)

C++加载执行TorchScript模型

要在C ++中加载序列化的PyTorch模型,您的应用程序必须依赖于PyTorch C ++ API(也称为LibTorch)。LibTorch发行版包含共享库,头文件 和CMake构建配置文件的集合。虽然CMake不是依赖LibTorch的要求,但它是推荐的方法,并且将来会得到很好的支持。 对于本教程,我们将 使用CMake和LibTorch构建一个最小的C ++应用程序,该应用程序简单地加载并执行序列化的PyTorch模型。

Libtorch下载:https://pytorch.org/

// // 使用以下命令从文件中反序列化脚本模块: torch::jit::load().
torch::jit::script::Module module = torch::jit::load("保存的TorchScript模型");

CMakeLists.txt 加入以下几句,链接LibTorch

find_package(Torch REQUIRED)

target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}")
set_property(TARGET ${PROJECT_NAME}  PROPERTY CXX_STANDARD 11)

成功用C ++加载了序列化的TorchScript模型之后,让我们将这些行添加到C ++应用程序的main()函数中:

// 创建输入向量
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));

// 执行模型并将输出转化为张量
// 通过调用toTensor()将其转换为张量
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';