创建Tensorflow的模型
在Android平台受到设备的限制,本身并不能训练模型,因此需要使用已有的模型。
本文将介绍如何将TensorFlow的模型转换成tflite模型,以便在Android设备上使用。
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
import matplotlib.pyplot as plt
# Load the Fashion-MNIST training/test features and labels.
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()

# Add a trailing channel axis -> (N, 28, 28, 1), cast to float32 and scale
# the raw 0~255 pixel values by 1/255 so the network sees inputs in [0, 1].
x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0

# Human-readable label for each of the 10 class indices.
class_name = [
    "T恤", "裤子", "帽头衫", "连衣裙", "外套",
    "凉鞋", "衬衫", "运动鞋", "包", "靴子",
]
class FashionModel_CNN(tf.keras.Model):
    """CNN classifier for 28x28x1 images with 10 output classes.

    Architecture: Conv(32, 5x5, valid) -> MaxPool(2) -> Conv(64, 5x5, same)
    -> MaxPool(2) -> Flatten -> Dense(128, relu) -> Dense(10, softmax).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(
            filters=32, kernel_size=[5, 5], padding='valid',
            input_shape=(28, 28, 1), activation=tf.nn.relu)
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        self.conv2 = tf.keras.layers.Conv2D(
            filters=64, kernel_size=[5, 5], padding="same",
            activation=tf.nn.relu)
        self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        # For a 28x28 input the feature map here is 6x6x64, i.e. 2304 values
        # per sample (28 -> 24 after the valid 5x5 conv, -> 12 -> 12 -> 6).
        self.flatten = Flatten()
        self.dense1 = tf.keras.layers.Dense(units=128, activation=tf.nn.relu)
        # The output layer already applies softmax, so the model emits class
        # probabilities (matching a loss configured with from_logits=False).
        self.dense2 = tf.keras.layers.Dense(units=10, activation="softmax")

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: batch of images; the first layer is declared with
                input_shape=(28, 28, 1).

        Returns:
            Per-class probabilities, one row of 10 values per input sample.
        """
        x = self.conv1(inputs)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.dense1(x)
        # Do NOT apply tf.nn.softmax again here: dense2 is already a softmax
        # layer, and a second softmax would flatten the distribution (the top
        # probability could never exceed e / (e + 9), about 0.23), which
        # keeps the cross-entropy loss artificially high.
        return self.dense2(x)
def generate_nn(x_train, y_train, x_test, y_test):
    """Build, compile and train a FashionModel_CNN, then return it.

    Args:
        x_train: training features passed to ``model.fit``.
        y_train: training labels (integer class ids, as required by the
            sparse categorical loss/metric used here).
        x_test: features used as validation data after every epoch.
        y_test: labels used as validation data after every epoch.

    Returns:
        The trained FashionModel_CNN instance.
    """
    # Instantiate the network.
    net = FashionModel_CNN()

    # Training configuration: optimizer, loss and evaluation metric.
    adam = tf.keras.optimizers.Adam()
    cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    accuracy = tf.keras.metrics.sparse_categorical_accuracy
    net.compile(optimizer=adam, loss=cross_entropy, metrics=[accuracy])

    # Train for 10 epochs with mini-batches of 32, validating every epoch.
    net.fit(
        x_train,
        y_train,
        batch_size=32,
        epochs=10,
        validation_data=(x_test, y_test),
        validation_freq=1,
    )

    # Print the layer structure and parameter counts.
    net.summary()
    return net
# Train the network, then export it in SavedModel format to ./forAndroid
# so it can be reloaded or converted to TFLite later.
model = generate_nn(x_train, y_train, x_test, y_test)
tf.saved_model.save(model, export_dir="./forAndroid")
在Python中加载TensorFlow模型并进行测试
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
from tensorflow.python.keras.preprocessing.image import load_img
import matplotlib.pyplot as plt
# Load the SavedModel exported by the training script.
model = tf.saved_model.load("./forAndroid")

# Load Fashion-MNIST and preprocess it exactly as during training:
# shape (N, 28, 28, 1), float32, pixel values scaled from 0~255 to [0, 1].
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype("float32")
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype("float32")
x_train, x_test = x_train / 255.0, x_test / 255.0
class_name = ["T恤", "裤子", "帽头衫", "连衣裙", "外套", "凉鞋", "衬衫", "运动鞋", "包", "靴子"]

# Directory holding the hand-made test images 0.jpeg ... 9.jpeg.
path = "./test_data/exam_fashion/exam_fashion/"
images = ["%d.jpeg" % i for i in range(0, 10)]

# Constant 255 matrix used to invert each image (255 - pixel).
# NOTE(review): presumably the test pictures are dark-on-light while the
# training data is light-on-dark -- confirm against the image files.
matrix = np.full(784, 255.0).reshape(28, 28)
images_data = []
for imgfile in images:
    img_name = "%s%s" % (path, imgfile)
    print(img_name)
    img = tf.keras.preprocessing.image.load_img(
        img_name, color_mode="grayscale", target_size=(28, 28))
    img_data = tf.keras.preprocessing.image.img_to_array(img)
    # Invert, then scale to [0, 1]: the model was trained on inputs divided
    # by 255, so feeding raw 0~255 values would not match the training data.
    img_data = (matrix - img_data.reshape(28, 28)) / 255.0
    # The arithmetic above yields float64; the model expects float32.
    images_data.append(img_data.astype('float32'))

# Stack the samples into one batch of shape (number of images, 28, 28, 1).
images_data = np.array(images_data).reshape(len(images), 28, 28, 1)
print(x_test.dtype, images_data.dtype)

# Predict and print: running number, predicted class id, class label.
y_pred = model(images_data)
index_list = np.argmax(y_pred, axis=1)
for index, i in enumerate(index_list, start=1):
    print(index, i, class_name[i])
利用tflite_convert将tensorflow模型导出tflite
在控制台中执行:
tflite_convert --saved_model_dir=forAndroid --output_file=model.tflite
执行出现错误,错误如下:
2021-11-25 09:19:56.975579: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
2021-11-25 09:19:56.979582: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: DESKTOP-4CLUK38
2021-11-25 09:19:56.979752: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: DESKTOP-4CLUK38
2021-11-25 09:19:56.991993: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Traceback (most recent call last):
File "d:\anaconda3\envs\tensorflow\lib\runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "d:\anaconda3\envs\tensorflow\lib\runpy.py", line 87, in _run_code
exec(code, run_globals)
File "D:\anaconda3\envs\tensorflow\Scripts\tflite_convert.exe\__main__.py", line 7, in <module>
File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\lite\python\tflite_convert.py", line 697, in main
app.run(main=run_main, argv=sys.argv[:1])
File "d:\anaconda3\envs\tensorflow\lib\site-packages\absl\app.py", line 303, in run
_run_main(main, args)
File "d:\anaconda3\envs\tensorflow\lib\site-packages\absl\app.py", line 251, in _run_main
sys.exit(main(argv))
File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\lite\python\tflite_convert.py", line 680, in run_main
_convert_tf2_model(tflite_flags)
File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\lite\python\tflite_convert.py", line 281, in _convert_tf2_model
converter = lite.TFLiteConverterV2.from_saved_model(
File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\lite\python\lite.py", line 1348, in from_saved_model
saved_model = _load(saved_model_dir, tags)
File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\saved_model\load.py", line 864, in load
result = load_internal(export_dir, tags, options)["root"]
File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\saved_model\load.py", line 902, in load_internal
loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\saved_model\load.py", line 162, in __init__
self._load_all()
File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\saved_model\load.py", line 259, in _load_all
self._load_nodes()
File "d:\anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\saved_model\load.py", line 448, in _load_nodes
slot_variable = optimizer_object.add_slot(
AttributeError: '_UserObject' object has no attribute 'add_slot'
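根据日志中的提示,似乎要求rebuild TensorFlow(注意:"rebuild TensorFlow"这一行只是关于CPU指令集优化的提示信息,并不是出错的原因;真正的错误是回溯信息最后一行的 AttributeError: '_UserObject' object has no attribute 'add_slot')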
官方提供的解决方法是使用bazel编译并运行tflite_convert
形如:
bazel run //tensorflow/lite/python:tflite_convert -- --saved_model_dir=目录 --output_file=目标.tflite
因为在windows10安装bazel的代价太大
因此调整解决思路
解决的方法
于是转而使用运行代码来实现转换,希望得到更多的信息
编辑如下代码:
import tensorflow as tf
# Directory that contains the exported SavedModel.
saved_model_dir = "forAndroid"

# Build a TFLite converter from the SavedModel and run the conversion.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

# Write the converted model to model.tflite.
with open('model.tflite', 'wb') as tflite_file:
    tflite_file.write(tflite_model)
出现的问题
运行转换程序出现错误,错误的内容如下:
控制台输出的日志如下:
2021-11-25 09:28:46.532142: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
2021-11-25 09:28:46.535042: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: DESKTOP-4CLUK38
2021-11-25 09:28:46.535137: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: DESKTOP-4CLUK38
2021-11-25 09:28:46.535300: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
......
日志的信息量很大,为了过滤掉无关的提示信息,在代码中增加日志级别的设置,修改后的代码如下:
import os

# Set the native TensorFlow log threshold BEFORE importing tensorflow; once
# the library has started logging, changing the variable may have no effect.
# '2' hides INFO and WARNING messages, so only errors are printed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # noqa: E402  (must come after the env setting)

# Directory that contains the exported SavedModel.
saved_model_dir = "forAndroid"

# Build a TFLite converter from the SavedModel and run the conversion.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

# Write the converted model to model.tflite.
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
忽略这些提示性的警告信息后,可以发现tflite文件已经成功生成。