PyTorch的学习和使用(八)
Mon 22 Mon 29 Mon 05 settrace grad_fn pyTorch1.0 pyTorchToCaffe 完成进度表
一、pyTorch to Caffe
静态图: 网络在输入数据前就预先将网络定义好,与数据无关,将所有的操作过程定义好,运行时填入数据。比如caffe,tensorFlow等框架。
动态图: 数据在网络传输中动态的构建网络。比如pyTorch等。
两者各有优点,静态图由于提前将网络结构确定了,部署十分方便,但是由于数据在网络中的传递过程往往不可知,因此调试较为困难;动态图是根据数据的流动动态的构建网络图,因此数据在网络中的状态都是已知的,调试十分便捷。目前,两者都在吸取对方的优点,tensorFlow也在也如动态图的机制,pyTorch与Caffe2结合,在配合ONNX实现高效的部署。
pyTorch模型转换到caffe模型可以看为动态图到静态图之间的转换,主要需要进动态图到静态图之间的转换,即构建出动态图然后将其映射到静态图,并且将网络参数也进行转换。
二、python trace
机制捕获动态图
pyTroch框架采用python构建,通过使用python的trace
机制可以获取到网络在传递过程中所经过的结构,从而映射到静态图。主要步骤如下:
- 启动python的
trace
功能,并定义其回调函数。 - 在回调函数中捕获网络所调用的原子操作。
- 将对应的操作使用caffe的python接口进行映射。
- 将相应pyTorch的网络参数映射到caffe模型。
- 保存caffe模型,关闭python的
trace
功能。
代码框架如下:
import sys
import torch
from caffe import layers as L, params as P, to_proto
def tracea_fun(frame, event, arg):
//通过当前的frame栈得到每次调用的函数,并将其转换为相应的caffe调用
def main(model, input):
sys.settrace(trace_fun)
output = model(input)
sys.settrace(None)
if __name__ == "__main__":
input = DataLoder()
model = Net()
main(model, input)
2.1 sys.settrace
操作捕获
python的sys.settrace
定义docs.python.org:
Set the system’s trace function, which allows you to implement a Python source code debugger in Python.
Trace functions should have three arguments: frame, event, and arg. frame is the current stack frame. event is a string: ‘call’, ‘line’, ‘return’ or ‘exception’. arg depends on the event type.
因此通过sys.settrace
的回调函数中的frame栈可以捕获当前的操作,其中frame
为frame objects,定义见The standard type hierarchy,常用属性有:
- f_code: The code object being executed in this frame
- co_name: Function name
- co_varnames: A tuple containing the names of the local variables
- f_locals: The dictionary used to look up local variables
- f_back: The previous stack frame
则通过frame.f_code.co_name
和frmae.f_locals
可以获得网络传递过程中的函数名和参数。
2.2 pyTorch原子操作捕获
实现该方法的难点在于如何找到网络中数据的流向,比如进行的view
操作和resNet网络中何时进行add
操作,这些操作在pyTorch0.2中都封装成了相应的原子操作,只需要找到对应的调用函数即可(但是在pyToch0.3以上中直接调用C的接口,暂时不会怎么使用settrace
进行捕捉)。
以卷基层为例,在trace_fun
中的conv2d代码如下:
def trace_fun(frame, event, arg):
if frame.f_code.co_name == "conv2d":
groups = frame.f_locals["groups"]
pad_h = frame.f_locals["padding"][0]
pad_w = frame.f_locals["padding"][1]
stride_h = frame.f_locals["stride"][0]
stride_w = frame.f_locals["stride"][1]
dilation = frame.f_locals["dilation"]
weight = frame.f_locals["weight"]
bias = frame.f_locals["bias"]
bottom = getBottom()
name = "conv1"
top = L.Convolution(bottom, name=name,
kernel_h=kernel_h, kernel_w = kernel_w,
num_output=num_output, groups=groups,
stride_h=stride_h, stride_w=stride_w,
pad_h=pad_h, pad_w=pad_w,
dilation=dilation)
其中,getBottom()
为获取当前层的前一层,通过维护一个容器,在容器中以每层的物理地址作为该层的唯一索引进行检索,即使用id(feature)
来确定其前一层。
注意,该方法只用于第一个pyTorch0.2之前的版本,在0.3之后的版本通过直接调用C接口的方式,目前还不会将其操作栈剥离出来。
三、pyTorch grad_fn网络拓扑图构建
四、pyTorch1.0 ONNX和caffe2之间的使用