Tensorflow 模型的保存和读取

tensorflow 保存模型主要有两种途径,一个是通过tf.train.Saver,另一种是通过tf.python.saved_model.builder.SavedModelBuilder

可以保存的内容主要包括Variable/Constant/Placeholder, GraphDef,metaGraph

Graph & GraphDef & MetaGraph

参考博客

Graph

Graph在tensorflow中可以定义为“Tensor和Operation的集合”,默认状态下,所有的tensor都被加载到default_graph中。

但是在真是情况下,计算往往会被放到GPU或者CPU上进行高速运算,因此使用python肯定无法解决这一问题。实际上呢,tensorflow会将graph序列化(serialized)成Protocaol Buffer(google 自用的序列化格式,可以直接编译出C++代码读取二进制的序列化文件,并进行操作(即计算)),然后再通过C++/CUDA读取Protocol Buffer。并运行计算

具体Protocol Buffer的信息可以参考这里

GraphDef

从 python Graph中序列化出来的图就叫做 GraphDef(不太严格,先这么理解)。GraphDef实际上是由多个叫NodeDef的ProtocolBuffer组成的。在概念上 NodeDef 与 (Python Graph 中的)Operation 相对应,保存placeholder,constant和operation。如下就是 GraphDef 的 ProtoBuf,由许多node组成的图表示。

假设对于如下的计算图:

tensorflow 保存为pb文件 tensorflow保存模型_Saver

node {
  name: "Placeholder"     # 注释:这是一个叫做 "Placeholder" 的node
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "Placeholder_1"     # 注释:这是一个叫做 "Placeholder_1" 的node
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "mul"                 # 注释:一个 Mul(乘法)操作
  op: "Mul"
  input: "Placeholder"        # 使用上面的node(即Placeholder和Placeholder_1)
  input: "Placeholder_1"      # 作为这个Node的输入
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}

以上三个NodeDef定义了两个placeholder和一个multiply,Placeholder 通过 attr(attribute的缩写)来定义数据类型和 Tensor 的形状。Multiply通过 input 属性定义了两个placeholder作为其输入。无论是 Placeholder 还是 Multiply 都没有关于输出(output)的信息。其实 Tensorflow 中都是通过 Input 来定义 Node 之间的连接信息。

注意:Graphdef中不保存任何 Variable 的信息,所以如果我们从 graph_def 来构建图并恢复训练的话,是不能成功的,比如:

with tf.Graph().as_default() as graph:
  tf.import_graph_def("graph_def_path")
  saver= tf.train.Saver()
  with tf.Session() as sess:
    tf.trainable_variables()

其中 tf.trainable_variables() 只会返回一个空的list。Tf.train.Saver() 也会报告 no variables to save。

但是可以恢复GraphDef来进行Inference,因为可以吧weight存成Constant存入到GraphDef中,然后用这些constant的weight和placeholder的input来进行inference。

PS:tensorflow 1.3.0 版本也提供了一套叫做 freeze_graph 的工具来自动的将图中的 Variable 替换成 constant 存储在 GraphDef 里面,并将该图导出为 Proto。

MetaGraph

和GraphDef不同,MetaGraph可以保存Variable的信息

Meta Graph在具体实现上就是一个MetaGraphDef (同样是由 Protocol Buffer来定义的)。其包含了四种主要的信息
1. MetaInfoDef,存一些元信息(比如版本和其他用户信息)
2. GraphDef, MetaGraph的核心内容之一,我们刚刚介绍过
3. SaverDef,图的Saver信息(比如最多同时保存的check-point数量,需保存的Tensor名字等,但并不保存Tensor中的实际内容)
4. CollectionDef 任何需要特殊注意的 Python 对象,需要特殊的标注以方便import_meta_graph 后取回。(比如“train_op”,”prediction”
等等)

Collection: Graph中为方便用户管理变量二加入的集合,通过key(string)来对一组对象进行命名,当然这个key可以使自定义的也可以是TF定义的key(在tf.GraphKeys中定义)

比如可以吧python train_op = tf.train.AdamOptimizer(lr).minimize(loss)python tf.add_to_collection("training_collection", train_op)加入到collection中,然后保存到MetaGraph里,然后再import_meta_graph后,可以直接用python tf.get_collection("training_collection"),然后直接python sess.run(train_op)就可以直接开始进行训练。

注意,从MetaGraph中恢复的图可以进行训练,但所有变量都会从随机初始化的值开始。训练中Variable的实际值都保存在check-point中,如果要从之前训练的状态继续恢复训练,就要从check-point中restore

Saver & SavedModelBuilder

Saver类

构造函数:

__init__(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    restore_sequentially=False,
    save_relative_paths=False,
)

–保存模型时–

var_list:特殊需要保存和恢复的变量和可保存对象列表或字典,默认为空,将会保存所有的可保存对象;

max_to_keep:保存多少个最新的checkpoint文件,默认为5,即保存最近五个checkpoint文件;

keep_checkpoint_every_n_hours:多久保存checkpoint文件,默认为10000小时,相当于禁用了这个功能;

save_relative_paths:为True时,checkpoint文件将不会记录完整的模型路径,而只会仅仅记录模型名字,这方便于将保存下来的模型复制到其他目录并使用的情况;

–恢复模型时–

reshape:为True时,允许从已保存checkpoint文件里恢复并重新设定形状不一样的张量,默认为false;

sharded:碎片化checkpoint文件到每一个设备,默认false;

restore_sequentially:为True时,会在每个设备中顺序地恢复不同的变量,同时可以在恢复比较大的模型时节省内存;

注意: 所有在初始化saver之后的变量都不会被保存

1. Saver.save()

API:

save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True
)

因为saver保存的是Variable的值,但是只有在session中variable才有确实的值,因此Saver.save需要一个会话,并且所有变量都已经被初始化

sess:一个建好图的会话,用以运行保存操作;

save_path:包含模型名字的绝对路径,最终会自动在模型名字添加相应后缀

global_step:该参数会自动添加到save_path名字用以区别不同步骤保存的模型;

latest_filename:生成检查点文件的名字,默认是“checkpoint”;

meta_graph_suffix:MetaGraphDef元图后缀,默认为“meta”;

write_meta_graph:指明是否要保存元图数据,默认为True;

write_state:指明是否要写CheckpointStateProto,默认为True

2.Saver.restore

该函数恢复一个已保存的模型,它需要一个已建好图结构的会话,恢复模型得到的变量无需初始化,在恢复过程中已有对保存变量做了初始化操作。

saver.restore(sess, save_path)

sess:用以恢复参数模型的会话;
save_path:已保存模型的路径,通常包含模型名字;

注意:saver.restore不会再会话中恢复计算图,saver.save默认会将MetaGraph保存到.meta中,所以呢,需要先import_meta_graph到一个saver中,然后再restore恢复变量的具体值

...
# Create a saver.
saver = tf.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection.
tf.add_to_collection('train_op', train_op)
sess = tf.Session()
for step in xrange(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
        # Saves checkpoint, which by default also exports a meta_graph
        # named 'my-model-global_step.meta'.
        saver.save(sess, 'my-model', global_step=step)


with tf.Session() as sess:
  # 将meta-graph读取到当前的graph中,并返回saver
  new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
  new_saver.restore(sess, 'my-save-dir/my-model-10000')
  # tf.get_collection() returns a list. In this example we only want the
  # first one.
  train_op = tf.get_collection('train_op')[0]
  for step in xrange(1000000):
    sess.run(train_op)
3. 其他

======================================================

def write_graph(graph_or_graph_def, logdir, name, as_text=True):

该函数存储一个tensorflow图原型到文件里,其参数含义如下:

graph_or_graph_def:tensorflow Graph或GraphDef;

logdir:保存图或图原型的目录;

as_text:默认为True,即以ASCII方式写到文件里

return:返回图或图原型保存的路径

======================================================

def import_graph_def(graph_def, input_map=None, return_elements=None, name=None, op_dict=None, producer_op_list=None):

该函数可加载已存储的”graph_def”到当前默认图里,并从系列化的tensorflow [GraphDef]协议缓冲里提取所有的tf.Tensor和tf.Operation到当前图里,其参数如下:

graph_def:一个包含图操作OP且要导入GraphDef的默认图;

input_map:字典关键字映射,用以从已保存图里恢复出对应的张量值;

return_elements:从已保存模型恢复的Ops或Tensor对象;

return:从已保存模型恢复后的Ops和Tensorflow列表,其名字位于return_elements;

======================================================

MetaGraph导出方法:

def export_meta_graph(filename=None, collection_list=None, as_text=False, export_scope=None, clear_devices=False, clear_extraneous_savers=False):

该函数可以导出tensorflow元图及其所需的数据,其参数如下:

filename:保存路径及其文件名;

collection_list:要收集的字符串键的列表;

as_text:为True时导出的文本格式为ASCII编码;

export_scope:导出的名字空间,用以删除;

clear_devices:导出时将与设备相关的信息去掉,即导出文件不与特定设备环境关联;

clear_extraneous_savers:从图中删除与此导出操作无关的任何saver相关信息(保存/恢复操作和SaverDefs)。

return:MetaGraphDef proto;

======================================================

def import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None, **kwargs):

该函数以“MetaGraphDef”协议缓冲区作为输入,如果其参数是一个包含“MetaGraphDef”协议缓冲区的文件,它将以文件内容构造一个协议缓冲区,然后将“graph_def”字段中的所有节点添加到当前Graph,并重新创建所有由collection_list收集的列表内容,最后返回由“saver_def”字段构造的saver以供使用,其参数如下:

meta_graph_or_file:MetaGraphDef协议缓冲区或者包含MetaGraphDef且带有路径的文件名;

clear_devices:导入时将与设备相关的信息去掉,即不与导出时的图设备环境关联,可兼容当前设备环境;

import_scope:导入名字空间,用以删除;

**kwargs:可选的参数;

return:在“MetaGraphDef”中由“saver_def”构造的存储模型,如果MetaGraphDef没有保存的变量则会直接返回None;

SavedModelBuilder & SavedModelLoader

SavedModelBuilder

初始化:

__init__(export_dir)

使用方法
1. 创建一个builder对象

builder = tf.saved_model.builder.SavedModelBuilder("dir")
  1. 训练完后,添加meta-graph和variable的值
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map=None, assets_collection=None)

builder.save()

add_meta_graph_and_variables

def add_meta_graph_and_variables(sess,tags,signature_def_map=None,assets_collection=None,legacy_init_op=None,clear_devices=False,main_op=None):

该函数可以将当前元图添加到SavedModel并保存变量,其参数如下:

sess:用于执行添加元图和变量功能的会话;

tags:用于保存元图的标签;(因为builde可以同时存多个图)

signature_def_map:用于保存元图的签名;

assets_collection:使用SavedModel保存的资源集合;

legacy_init_op:在恢复模型操作后,对Op和Ops组的遗留支持;

clear_devices:如果默认图形上的设备信息应该被清除,则应该设置为true;

main_op:在加载图时执行Op或Ops组的操作。请注意,当main_op被指定时,它将在加载恢复op后运行;

return:无返回

signature_def: 用来identify function输入输出的,它的数据结构是一个key-value形式的,value是input和output的map,常用在tensorflow的部署中,验证输入输出.。参考这里

形式是

signature_def: {
  key  : "my_classification_signature"
  value: {
    inputs: {
      key  : "inputs"
      value: {
        name: "tf_example:0"
        dtype: DT_STRING
        tensor_shape: ...
      }
    }
    outputs: {
      key  : "classes"
      value: {
        name: "index_to_string:0"
        dtype: DT_STRING
        tensor_shape: ...
      }
    }
    outputs: {
      key  : "scores"
      value: {
        name: "TopKV2:0"
        dtype: DT_FLOAT
        tensor_shape: ...
      }
    }
    method_name: "tensorflow/serving/classify"
  }
}

可以直接通过input和output的tensor生成signature_mao

tf.saved_model.signature_def_utils.build_signature_def(
    inputs=None,
    outputs=None,
    method_name=None
)

SavedModelLoader

我们主要使用load(…)来恢复模型:

def load(sess, tags, export_dir, **saver_kwargs):

该函数可以从标签指定的SavedModel加载模型,其参数如下:

sess:恢复模型的会话;

tags:用于恢复元图的标签,需与保存时的一致,用于区别不同的模型;

export_dir:存储SavedModel协议缓冲区和要加载的变量的目录;

**saver_kwargs:可选的关键字参数传递给saver;

return:在提供的会话中加载的“MetaGraphDef”协议缓冲区,这可以用于进一步提取signature-defs, collection-defs等;

with tf.Session() as sess:  
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], "dir")

保存目录只有一个pb文件,以及一个variables文件夹,里面存放的是variables.data-00000-of-00001和
variables.index,与save/restore方法比,没有checkpoint检查点文件以及以“.meta”为后缀的元数据文件,但是多了一个pb文件,这是这两种tensorflow保存和恢复模型方法的区别!

参考1

参考2