一、PC端搭建模型与训练,保存模型为PB文件
'''
简单起见,创建一个最简单的tensorflow模型,没有实际功能
只为了演示在android studio搭建App
运行代码就会在 你选择的本地目录找打PB文件
'''
import tensorflow as tf
# 此处的输入层定义为常量,为了简单起见,输入层名字为input,简单起见
input_constant = tf.constant([1,2],dtype=tf.float32,name="input")
#输出层名字为output,简单起见
out_data = tf.add(input_constant,input_constant,name="output")
# Session运行,这样就有了模型参数
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(out_data))
#把当前的模型保存为PB文件,PB文件会保存当前tensorflow的模型,将其他值固化为常量
# 第一个参数 sess指定为当前的Session
# 第二个参数是要保存的 图的定义,默认是当前图
# 然后是要输出的节点
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
# 这里是选择要保存的位置
with tf.gfile.FastGFile('android_tensorflow.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())
二、配置Android Studio build.gradle(app)的参数,至关重要
apply plugin: 'com.android.application'
//该文件一共有删除需要改动
//第一步:在android/defaultConfig内部添加添加
// multiDexEnabled true
// ndk {
// abiFilters "armeabi-v7a"
// }
//对应于app/libs/armeabi-v7文件夹
//第二步,在android内添加 sourceSets {
// main {
// jni.srcDirs = []
// jniLibs.srcDirs = ['libs']
// }
// }
//对应app/libs文件夹
//第三步
//加入这一行代码就可以,编译之后 我们的Android项目就可以使用tensorflow接口
//不需要Bazle之类的工具,进行一系列的操作
//implementation files('libs/libandroid_tensorflow_inference_java.jar')
android {
compileSdkVersion 27
defaultConfig {
applicationId "com.example.tan.simple"
minSdkVersion 15
targetSdkVersion 27
versionCode 1
versionName "1.0"
testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
//1
multiDexEnabled true
ndk {
abiFilters "armeabi-v7a"
}
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
}
}
//2
sourceSets {
main {
jni.srcDirs = []
jniLibs.srcDirs = ['libs']
}
}
}
dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar'])
implementation 'com.android.support:appcompat-v7:27.1.1'
implementation 'com.android.support.constraint:constraint-layout:1.1.3'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'com.android.support.test:runner:1.0.2'
androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
//3
implementation files('libs/libandroid_tensorflow_inference_java.jar')
}
三、把该有的文件都添加进去
所有存储位置都很清晰地显示在图片了,配置好参数,存储好三个文件,就可以开始准备Java代码了
四、运行Java SDK代码,执行模型加载与调用
package com.example.tan.simple;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.TextView;
import android.widget.Toast;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
public class MainActivity extends AppCompatActivity {
static { //加载libtensorflow_inference.so库文件
System.loadLibrary("tensorflow_inference");
}
//保存要输入和输出的结果
float[] inputs = new float[]{7,8}; //随机给定值看是否达到想加的效果
float[] outputs = new float[2];
// 这里是那个PB文件的绝对路径
String filename = "android_tensorflow.pb";
TensorFlowInferenceInterface tf;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
try{// 以PB文件创建一个tensorflow的接口
tf = new TensorFlowInferenceInterface(getAssets(),filename);
setTitle("成功加载模型");
}catch (Exception e){
setTitle("加载模型失败");
}
Button btn = findViewById(R.id.btn);
btn.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
//生成随机输入数据,便于观察模型是否可以动态调用
inputs[0] = (float)Math.random();
inputs[1] = (float)Math.random();
// feed 参数, 第一个参数是 张量的名称
// 第二个是一个一维数组存放数据
// 最后指定矩阵的维度,我这里是1行2列
tf.feed("input",inputs,1,2);
//运行要输出的张量
tf.run(new String[]{"output"});
//然后将结果获取到,保存在数组中,方便我们获取
tf.fetch("output",outputs);
//或者用textView等在Android界面显示出来
String ret = "识别结果:\n"+String.valueOf(outputs[0])+"\n"+String.valueOf(outputs[1]);
//然后你可以控制台打印出来,或者用textView等在Android界面显示出来,看有没有达到我们的效果,有没有显示两个矩阵的想加
Toast.makeText(MainActivity.this, ret, Toast.LENGTH_LONG).show();
TextView txt = findViewById(R.id.txt);
txt.setText(ret);
}
});
}
}
五、生成App
此处是一个界面搭建效果图
下面是手机端的显示结果:
至此,本项目完整展示,并且获得了圆满成功。