PB模型保存和模型查看方法
Tensorflow版本:2.1
一.模型的保存方法
2.1 用Tensorflow自带的Keras保存模型
(1)使用model.save()
方法
该方法一般只使用一个参数,方法中的参数形式不同,则保存的模型的格式也不同。
函数原型为:
def save(self,
filepath,
overwrite=True,
include_optimizer=True,
save_format=None,
signatures=None,
options=None)
model.save("../model/idefi_model")
若参数为路径文件夹则保存的模型为pb格式文件。会生成asserts、variables和XXX.pb文件。model.save("../models/idefi_model.h5")
若参数为h5后缀,则保存为XX.h5的模型文件
model.save('XXX.h5') #传入模型名称,生成h5格式的模型文件
model.save('./XXX/xxx') #传入文件夹路径,生成pb格式的模型文件
(2)使用tf.keras.models.save_model()
方法
该方法与model.save()方法类型,函数原型为:
tf.keras.models.save_model(
model, filepath, overwrite=True, include_optimizer=True, save_format=None,
signatures=None, options=None
)
一般与也是指定不同的路径类型便可得到不同的模型文件。如:
tf.keras.models.save_model(model,'./saved_models/wei_less_2') #模型保存为pb格式
tf.keras.models.save_model(model,'./saved_models/wei_less_2.h5') #模型保存为h5格式
可使用tf.keras.models.load_model
方法加载这两种方法保存的模型,返回的是Keras的模型对象
2.2 用Tensorflow接口保存模型
使用tf.saved_model.save
方法保存模型,该方法只能将模型保存为pb格式,函数原型如下:
tf.saved_model.save(
obj, export_dir, signatures=None, options=None
)
保存的模型可使用tf.saved_model.load
方法加载,返回的不是Keras模型对象,不能使用model.predict()
和model.fit()
方法。
二、模型查看工具以及用法
saved_model_cli
提供了一种通过命令行检查并恢复模型的机制,如果你的TensorFlow是通过pip安装的,那么saved_model_cli
应该已经被一同安装,saved_model_cli
主要有两个命令,一个是show
,一个是run
,我们可以通过如下方式检查该模型的相关信息:
saved_model_cli show --all --dir=model Path
输出的信息如下:
# 这里显示了标签的信息,标签的名称在Go语言加载模型时需要用到
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['__saved_model_init_op']:
The given SavedModel SignatureDef contains the following input(s):
The given SavedModel SignatureDef contains the following output(s):
outputs['__saved_model_init_op'] tensor_info:
dtype: DT_INVALID
shape: unknown_rank
name: NoOp
Method name is:
# 下面显示了一些层的名称,在C++加载pb模型时,需要根据输入输出层层的名称来获取这些层
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['inputs'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 40, 100, 1)
name: serving_default_inputs:0
The given SavedModel SignatureDef contains the following output(s):
outputs['outputs1'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 36)
name: StatefulPartitionedCall:0
outputs['outputs2'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 36)
name: StatefulPartitionedCall:1
outputs['outputs3'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 36)
name: StatefulPartitionedCall:2
outputs['outputs4'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 36)
name: StatefulPartitionedCall:3
outputs['outputs5'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 36)
name: StatefulPartitionedCall:4
Method name is: tensorflow/serving/predict
在Go中加载模型时,需要获取输入输出的op操作,通过op.Ouput()
函数获取输出,所以要找到模型op操作的名字,用一下代码可以查看模型op的名称:
model, err := tf.LoadSavedModel(modelPath, modelsNames, nil) // 载入模型
for _, op := range model.Graph.Operations() {
//log.Printf("Op name: %v, on device: %v", op.Name(), op.Device())
log.Printf("Op name: %v", op.Name())
}
找到输入输出名称后,用graph.Operation("XXX").Output()
便可得到输入输出
output, err := session.Run(
map[tf.Output]*tf.Tensor{
graph.Operation("serving_default_inputs").Output(0): img_test_tensor,
},
[]tf.Output{
graph.Operation("StatefulPartitionedCall").Output(0),
graph.Operation("StatefulPartitionedCall").Output(1),
graph.Operation("StatefulPartitionedCall").Output(2),
graph.Operation("StatefulPartitionedCall").Output(3),
graph.Operation("StatefulPartitionedCall").Output(4),
},nil)