tensorflow保存模型和加载模型的方法(Python和Android)
一、tensorflow保存模型的几种方法:
(1) tf.train.saver()保存模型
使用 tf.train.saver()保存模型,该方法保存模型文件的时候会产生多个文件,会把计算图的结构和图上参数取值分成了不同的文件存储。这种方法是在TensorFlow中是最常用的保存方式。
例如:
运行后,会在save目录下保存了四个文件:
其中
- checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;
- model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
- ckpt.data : 保存模型中每个变量的取值
(2)tf.train.write_graph()
使用 tf.train.write_graph()保存模型,该方法只是保存了模型的结构,并不保存训练完毕的参数值。
(3)convert_variables_to_constants固化模型结构
很多时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),方便在其他地方使用(如在Android中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。 我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。
TensoFlow为我们提供了convert_variables_to_constants()方法,该方法可以固化模型结构,将计算图中的变量取值以常量的形式保存。而且保存的模型可以移植到Android平台。
参考资料:
【2】这里主要实现第三种方法,因为该方法保存的模型可以移植到Android平台运行。以下Python代码,都共享在
Github:https://github.com/PanJinquan/tensorflow-learning-tutorials/tree/master/MNIST-Demo;
【3】移植Android的详细过程可参考本人的另一篇博客资料《将tensorflow MNIST训练模型移植到Android》:
javascript:void(0)
二、训练和保存模型
以MNIST手写数字识别为例,这里首先使用Python版的TensorFlow实现SoftMax Regression分类器,并将训练好的模型的网络拓扑结构和参数保存为pb文件,其中convert_variables_to_constants函数,会将计算图中的变量取值以常量的形式保存
三、加载和测试
批量测试:
单个样本测试:
读取图片进行测试:
源码Github:https://github.com/PanJinquan/MNIST-TensorFlow-Python
上面TensorFlow保存训练好的模型,可以移植到Android,详细过程可以参考另一篇博客资料《将tensorflow MNIST训练模型移植到Android》