前言

当我们把使用Python训练的模型固化成PB文件之后,再进行相应的模型压缩之后可以考虑往Mobile端移植了,本文主要讲解TensorFlow Model移植到Android端。

TensorFlow1.0之后推出了Java版本,所以间接为Android开发TensorFlow程序带来便利,以前我们需要用JNI去编写,可是JNI难于调试,C++代码对于普通Android开发者来讲还是比Java繁琐,所以本文以Java API讲述开发过程。

正文

下面就正式开始一直TensorFlow model到Android中啦。

  • 引入依赖

在TensorFlow更新到1.2.0版本之后,TensorFlow为广大开发者提供了gradle依赖,现在我们想要引入TensorFlow只需要在gradle中加入

compile 'org.tensorflow:tensorflow-android:1.2.0-rc0'

即可引入TensorFlow的库。

  • 复制PB文件

快速开发的话直接把PB文件放在assets文件夹里就行,如果正式上线的时候觉得PB文件一起打包较大的话可以放在服务器,打开APP的时候提示下载再复制进去就好。

  • 创建TensorFlowInterface类

这个类指的是我们读取、识别等一系列方法存放的类,名字随你取。

  • 载入TensorFlow

在类的第一行加入这句话,会在加载类的时候首先加载TensorFlow

{
        System.loadLibrary("tensorflow_inference");
    }
  • 定义常量

在这一步,我们先定义一些常量,比如输入节点名、输出节点名、输出图像的尺寸、通道、输入节点数据类型、输出节点数据类型。代码如下

private static final String input_layer = "inputs/X";
    private static final String output_layer = "output/predict";

    private Context context;
    private static final int HEIGHT = 64;
    private static final int WIDTH = 256;
    private static final int CHANNEL = 1;

    private float[] inputs = new float[HEIGHT*WIDTH*CHANNEL];
    private long[] outputs = new long[11];
  • 初始化模型

这一步TensorFlow的模型会载入到内存中,传入assets和PB文件名

TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(),"rounded_graph.pb");
  • 喂数据给输入节点

这里的参数是输入节点名,输入数据,输入数据的shape

inferenceInterface.feed(input_layer,inputs,1,16384);
  • run session
inferenceInterface.run(new String[] { output_layer }, false);
  • 获取输出数据

根据你在Python定义的输出格式,new一个接收输出数据的变量,从输出节点获取数据

byte[] outPuts = new byte[88];
inferenceInterface.fetch(output_layer,outPuts);
  • 数据变换

从输出节点获取到数据之后就需要你对自己的输出数据进行操作,比如我在我们model里最终输出的结果进行了Argmax的操作,Argmax返回的值类型是Int64的,在Android里只有long对应,但fetch方法的接受变量的参数类型只有double、float、int、byte,所以这里需要使用byte获取,再进行转换。这里跟传统的byte[8]转long有些不同,具体处理方式要看你定义的数据格式,我这里的byte[8]用网上的方法转long发现数值非常大,于是遍历一遍byte[8],发现每个子元素都是相同的数值,所以这里只取第一个元素,组成一个新的数组,再对这个数组进行解析。

long[] tOutputs=new long[11];
for (int i=0;i<11;i++)
{
    int k=i*8;
    tOutputs[i]=outPuts[k];
    Log.i("output",tOutputs[i]+"");
}
String outputStr="";
for(int i=0;i<11;i++){
    long char_idx=tOutputs[i];
    long char_code = 0;
    if (char_idx<10){
        char_code = char_idx + (int)('0');
    }
    else if (char_idx<36){
        char_code = char_idx-10 + (int)('A');
    }
    else if (char_idx<62){
        char_code = char_idx + (int)('a');
    }
    outputStr+= (char)char_code;
}

后记

有Java API确实相比C++来的更直观方便,而且native debug也比JNI好操作,等TensorFlowLite出来的时候,Android TensorFlow应用会更加广泛吧。