背景
人工智能火了,tensorflow 也火了,Google推出移动版的TensorFlow Lite,作为一个Android开发应该熟悉一下。今天的目标就是能够在移动端也能进行部署深度学习框架,既然Android也能运行TensorFlow 为何不尝试一下,这是程序员们的通病,干就完了。
本次开发环境为TensorFlow 2.1+python 3.7+Android studio 3.6.1
windows10下搭建TensorFlow
这里我只是简单的说一下,毕竟今天的目标不是搭建环境,而是如何在Android上部署TensorFlow。你直接下载anaconda版的python,用命令安装,有坑,不要用pip命令,有两方面有原因,
- 第一、太慢,我无数次下载都是超时,翻墙也不好使;
- 第二、报错(没找到合适的解决办法,统一方法就是降级到1.x,TensorFlow2.0有这个问题 )。
解决办法:采用conda install python 这个命令。如何你对配置环境确实毫无头绪,你可以参考TensorFlow官网,或者百度一大堆教你如何搭建环境的。
编写python代码
今天我用了一个很简单的例子,用TensorFlow 拟合一个函数,y=ax+b,给出x,y,通过TensorFlow 算出a,b;看一下我拟合效果
其实很简单,代码如下:
import tensorflow as tf
# 创建一个简单的 Keras 模型。
x = [-1, 0, 1, 2, 3, 4]
y = [-3, -1, 1, 3, 5, 7]
model = tf.keras.models.Sequential(
[tf.keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(x, y, epochs=500)
print(model.predict([1, 3, 7]))
记住([1, 3, 7])得到的结果,我们会在Android也运行这组数据,看结果是否一样。
模型转换
上面是用Keras写例子,既然python 的TensorFlow 已经写好了,那如何在Android上用呢,就需要用到转换,将python代码转成Android 可以用的。
export_dir = 'saved_model/test'
tf.saved_model.save(model, export_dir)
#转换模型。
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
tflite_model_file = pathlib.Path('saved_model/model.tflite')
tflite_model_file.write_bytes(tflite_model)
这个tf.saved_model.save(model, export_dir) 可能会报错,无法创建目录,你可以手动来创建目录。
在对应的目录下,找到一个.tflite的文件,这个文件就是我们在Android要调用。
新建一个Android 项目
Android 环境需要哪些依赖?
TensorFlow lite文件放在哪?
如何调用?
数据如何输入?
结果在哪?
下面我都会一一说明。
1.依赖
- 在build.gradle中依赖 implementation ‘org.tensorflow:tensorflow-lite:0.0.0-nightly’,这个必不可少的。
- 添加arm支持
defaultConfig {
……
ndk {
abiFilters 'armeabi-v7a', 'arm64-v8a'
}
}
- 防止aapt对模型进行压缩,不然读取模型时会报错。
android{
……
aaptOptions {
noCompress "tflite" //表示不让aapt压缩的文件后缀
}
}
2.如何读取文件并使用
- 将**.tflite**文件放在assets文件下面
- 知道文件放在哪,接下来就要了解如何获取这个模型,并在Android 中使用。
public class TFLiteLoader {
private static Context mContext;
Interpreter mInterpreter;
private static TFLiteLoader instance;
public static TFLiteLoader newInstance(Context context) {
mContext = context;
if (instance == null) {
instance = new TFLiteLoader();
}
return instance;
}
Interpreter get() {
try {
if (Objects.isNull(mInterpreter))
mInterpreter = new Interpreter(loadModelFile(mContext));
} catch (IOException e) {
e.printStackTrace();
}
return mInterpreter;
}
// 获取文件
private MappedByteBuffer loadModelFile(Context context) throws IOException {
AssetFileDescriptor fileDescriptor = context.getAssets().openFd("model.tflite");
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
}
输入输入和输出:
float[][] input = new float[][]{{1, 3, 7}};
float[][] output = new float[3][1];
TFLiteLoader.newInstance(getApplicationContext()).get().run(input, output);
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 1; j++) {
Log.i(TAG, output[i][j] + "");
}
}
运行结果如下: