PB模型保存和模型查看方法

Tensorflow版本:2.1

一.模型的保存方法

2.1 用Tensorflow自带的Keras保存模型

(1)使用model.save()方法

该方法一般只使用一个参数,方法中的参数形式不同,则保存的模型的格式也不同。

函数原型为:

def save(self,
         filepath,
         overwrite=True,
         include_optimizer=True,
         save_format=None,
         signatures=None,
         options=None)
  • model.save("../model/idefi_model")若参数为路径文件夹则保存的模型为pb格式文件。会生成asserts、variables和XXX.pb文件。
  • model.save("../models/idefi_model.h5") 若参数为h5后缀,则保存为XX.h5的模型文件
model.save('XXX.h5')         #传入模型名称,生成h5格式的模型文件
model.save('./XXX/xxx')      #传入文件夹路径,生成pb格式的模型文件

(2)使用tf.keras.models.save_model()方法

该方法与model.save()方法类型,函数原型为:

tf.keras.models.save_model(
    model, filepath, overwrite=True, include_optimizer=True, save_format=None,
    signatures=None, options=None
)

一般与也是指定不同的路径类型便可得到不同的模型文件。如:

tf.keras.models.save_model(model,'./saved_models/wei_less_2')      #模型保存为pb格式
tf.keras.models.save_model(model,'./saved_models/wei_less_2.h5')   #模型保存为h5格式

可使用tf.keras.models.load_model方法加载这两种方法保存的模型,返回的是Keras的模型对象

2.2 用Tensorflow接口保存模型

使用tf.saved_model.save方法保存模型,该方法只能将模型保存为pb格式,函数原型如下:

tf.saved_model.save(
    obj, export_dir, signatures=None, options=None
)

保存的模型可使用tf.saved_model.load方法加载,返回的不是Keras模型对象,不能使用model.predict()model.fit()方法。

二、模型查看工具以及用法

saved_model_cli提供了一种通过命令行检查并恢复模型的机制,如果你的TensorFlow是通过pip安装的,那么saved_model_cli应该已经被一同安装,saved_model_cli主要有两个命令,一个是show,一个是run,我们可以通过如下方式检查该模型的相关信息:

saved_model_cli show  --all --dir=model Path

输出的信息如下:

# 这里显示了标签的信息,标签的名称在Go语言加载模型时需要用到
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:
# 下面显示了一些层的名称,在C++加载pb模型时,需要根据输入输出层层的名称来获取这些层
signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 40, 100, 1)
        name: serving_default_inputs:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['outputs1'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 36)
        name: StatefulPartitionedCall:0
    outputs['outputs2'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 36)
        name: StatefulPartitionedCall:1
    outputs['outputs3'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 36)
        name: StatefulPartitionedCall:2
    outputs['outputs4'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 36)
        name: StatefulPartitionedCall:3
    outputs['outputs5'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 36)
        name: StatefulPartitionedCall:4
  Method name is: tensorflow/serving/predict

在Go中加载模型时,需要获取输入输出的op操作,通过op.Ouput()函数获取输出,所以要找到模型op操作的名字,用一下代码可以查看模型op的名称:

model, err := tf.LoadSavedModel(modelPath, modelsNames, nil) // 载入模型
for _, op := range model.Graph.Operations() {
		//log.Printf("Op name: %v, on device: %v", op.Name(), op.Device())
		log.Printf("Op name: %v", op.Name())
	}

找到输入输出名称后,用graph.Operation("XXX").Output()便可得到输入输出

output, err := session.Run(
		map[tf.Output]*tf.Tensor{
			graph.Operation("serving_default_inputs").Output(0): img_test_tensor,
		},
		[]tf.Output{
			graph.Operation("StatefulPartitionedCall").Output(0),
			graph.Operation("StatefulPartitionedCall").Output(1),
			graph.Operation("StatefulPartitionedCall").Output(2),
			graph.Operation("StatefulPartitionedCall").Output(3),
			graph.Operation("StatefulPartitionedCall").Output(4),
		},nil)