Tensorflow 模型的保存和读取
tensorflow 保存模型主要有两种途径,一个是通过tf.train.Saver,另一种是通过tf.python.saved_model.builder.SavedModelBuilder
可以保存的内容主要包括Variable/Constant/Placeholder, GraphDef,metaGraph
Graph & GraphDef & MetaGraph
Graph
Graph在tensorflow中可以定义为“Tensor和Operation的集合”,默认状态下,所有的tensor都被加载到default_graph中。
但是在真是情况下,计算往往会被放到GPU或者CPU上进行高速运算,因此使用python肯定无法解决这一问题。实际上呢,tensorflow会将graph序列化(serialized)成Protocaol Buffer(google 自用的序列化格式,可以直接编译出C++代码读取二进制的序列化文件,并进行操作(即计算)),然后再通过C++/CUDA读取Protocol Buffer。并运行计算
GraphDef
从 python Graph中序列化出来的图就叫做 GraphDef(不太严格,先这么理解)。GraphDef实际上是由多个叫NodeDef的ProtocolBuffer组成的。在概念上 NodeDef 与 (Python Graph 中的)Operation 相对应,保存placeholder,constant和operation。如下就是 GraphDef 的 ProtoBuf,由许多node组成的图表示。
假设对于如下的计算图:
node {
name: "Placeholder" # 注释:这是一个叫做 "Placeholder" 的node
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: "Placeholder_1" # 注释:这是一个叫做 "Placeholder_1" 的node
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: "mul" # 注释:一个 Mul(乘法)操作
op: "Mul"
input: "Placeholder" # 使用上面的node(即Placeholder和Placeholder_1)
input: "Placeholder_1" # 作为这个Node的输入
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
以上三个NodeDef定义了两个placeholder和一个multiply,Placeholder 通过 attr(attribute的缩写)来定义数据类型和 Tensor 的形状。Multiply通过 input 属性定义了两个placeholder作为其输入。无论是 Placeholder 还是 Multiply 都没有关于输出(output)的信息。其实 Tensorflow 中都是通过 Input 来定义 Node 之间的连接信息。
注意:Graphdef中不保存任何 Variable 的信息,所以如果我们从 graph_def 来构建图并恢复训练的话,是不能成功的,比如:
with tf.Graph().as_default() as graph:
tf.import_graph_def("graph_def_path")
saver= tf.train.Saver()
with tf.Session() as sess:
tf.trainable_variables()
其中 tf.trainable_variables() 只会返回一个空的list。Tf.train.Saver() 也会报告 no variables to save。
但是可以恢复GraphDef来进行Inference,因为可以吧weight存成Constant存入到GraphDef中,然后用这些constant的weight和placeholder的input来进行inference。
PS:tensorflow 1.3.0 版本也提供了一套叫做 freeze_graph 的工具来自动的将图中的 Variable 替换成 constant 存储在 GraphDef 里面,并将该图导出为 Proto。
MetaGraph
和GraphDef不同,MetaGraph可以保存Variable的信息
Meta Graph在具体实现上就是一个MetaGraphDef (同样是由 Protocol Buffer来定义的)。其包含了四种主要的信息
1. MetaInfoDef,存一些元信息(比如版本和其他用户信息)
2. GraphDef, MetaGraph的核心内容之一,我们刚刚介绍过
3. SaverDef,图的Saver信息(比如最多同时保存的check-point数量,需保存的Tensor名字等,但并不保存Tensor中的实际内容)
4. CollectionDef 任何需要特殊注意的 Python 对象,需要特殊的标注以方便import_meta_graph 后取回。(比如“train_op”,”prediction”
等等)
Collection: Graph中为方便用户管理变量二加入的集合,通过key(string)来对一组对象进行命名,当然这个key可以使自定义的也可以是TF定义的key(在tf.GraphKeys中定义)
比如可以吧python train_op = tf.train.AdamOptimizer(lr).minimize(loss)
和 python tf.add_to_collection("training_collection", train_op)
加入到collection中,然后保存到MetaGraph里,然后再import_meta_graph后,可以直接用python tf.get_collection("training_collection")
,然后直接python sess.run(train_op)
就可以直接开始进行训练。
注意,从MetaGraph中恢复的图可以进行训练,但所有变量都会从随机初始化的值开始。训练中Variable的实际值都保存在check-point中,如果要从之前训练的状态继续恢复训练,就要从check-point中restore
Saver & SavedModelBuilder
Saver类
构造函数:
__init__(
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
restore_sequentially=False,
save_relative_paths=False,
)
–保存模型时–:
var_list:特殊需要保存和恢复的变量和可保存对象列表或字典,默认为空,将会保存所有的可保存对象;
max_to_keep:保存多少个最新的checkpoint文件,默认为5,即保存最近五个checkpoint文件;
keep_checkpoint_every_n_hours:多久保存checkpoint文件,默认为10000小时,相当于禁用了这个功能;
save_relative_paths:为True时,checkpoint文件将不会记录完整的模型路径,而只会仅仅记录模型名字,这方便于将保存下来的模型复制到其他目录并使用的情况;
–恢复模型时–:
reshape:为True时,允许从已保存checkpoint文件里恢复并重新设定形状不一样的张量,默认为false;
sharded:碎片化checkpoint文件到每一个设备,默认false;
restore_sequentially:为True时,会在每个设备中顺序地恢复不同的变量,同时可以在恢复比较大的模型时节省内存;
注意: 所有在初始化saver之后的变量都不会被保存
1. Saver.save()
API:
save(
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix='meta',
write_meta_graph=True,
write_state=True
)
因为saver保存的是Variable的值,但是只有在session中variable才有确实的值,因此Saver.save需要一个会话,并且所有变量都已经被初始化
sess:一个建好图的会话,用以运行保存操作;
save_path:包含模型名字的绝对路径,最终会自动在模型名字添加相应后缀
global_step:该参数会自动添加到save_path名字用以区别不同步骤保存的模型;
latest_filename:生成检查点文件的名字,默认是“checkpoint”;
meta_graph_suffix:MetaGraphDef元图后缀,默认为“meta”;
write_meta_graph:指明是否要保存元图数据,默认为True;
write_state:指明是否要写CheckpointStateProto,默认为True
2.Saver.restore
该函数恢复一个已保存的模型,它需要一个已建好图结构的会话,恢复模型得到的变量无需初始化,在恢复过程中已有对保存变量做了初始化操作。
saver.restore(sess, save_path)
sess:用以恢复参数模型的会话;
save_path:已保存模型的路径,通常包含模型名字;
注意:saver.restore不会再会话中恢复计算图,saver.save默认会将MetaGraph保存到.meta中,所以呢,需要先import_meta_graph到一个saver中,然后再restore恢复变量的具体值
...
# Create a saver.
saver = tf.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection.
tf.add_to_collection('train_op', train_op)
sess = tf.Session()
for step in xrange(1000000):
sess.run(train_op)
if step % 1000 == 0:
# Saves checkpoint, which by default also exports a meta_graph
# named 'my-model-global_step.meta'.
saver.save(sess, 'my-model', global_step=step)
with tf.Session() as sess:
# 将meta-graph读取到当前的graph中,并返回saver
new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
new_saver.restore(sess, 'my-save-dir/my-model-10000')
# tf.get_collection() returns a list. In this example we only want the
# first one.
train_op = tf.get_collection('train_op')[0]
for step in xrange(1000000):
sess.run(train_op)
3. 其他
======================================================
def write_graph(graph_or_graph_def, logdir, name, as_text=True):
该函数存储一个tensorflow图原型到文件里,其参数含义如下:
graph_or_graph_def:tensorflow Graph或GraphDef;
logdir:保存图或图原型的目录;
as_text:默认为True,即以ASCII方式写到文件里
return:返回图或图原型保存的路径
======================================================
def import_graph_def(graph_def, input_map=None, return_elements=None, name=None, op_dict=None, producer_op_list=None):
该函数可加载已存储的”graph_def”到当前默认图里,并从系列化的tensorflow [GraphDef
]协议缓冲里提取所有的tf.Tensor和tf.Operation到当前图里,其参数如下:
graph_def:一个包含图操作OP且要导入GraphDef的默认图;
input_map:字典关键字映射,用以从已保存图里恢复出对应的张量值;
return_elements:从已保存模型恢复的Ops或Tensor对象;
return:从已保存模型恢复后的Ops和Tensorflow列表,其名字位于return_elements;
======================================================
MetaGraph导出方法:
def export_meta_graph(filename=None, collection_list=None, as_text=False, export_scope=None, clear_devices=False, clear_extraneous_savers=False):
该函数可以导出tensorflow元图及其所需的数据,其参数如下:
filename:保存路径及其文件名;
collection_list:要收集的字符串键的列表;
as_text:为True时导出的文本格式为ASCII编码;
export_scope:导出的名字空间,用以删除;
clear_devices:导出时将与设备相关的信息去掉,即导出文件不与特定设备环境关联;
clear_extraneous_savers:从图中删除与此导出操作无关的任何saver相关信息(保存/恢复操作和SaverDefs)。
return:MetaGraphDef proto;
======================================================
def import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None, **kwargs):
该函数以“MetaGraphDef”协议缓冲区作为输入,如果其参数是一个包含“MetaGraphDef”协议缓冲区的文件,它将以文件内容构造一个协议缓冲区,然后将“graph_def”字段中的所有节点添加到当前Graph,并重新创建所有由collection_list收集的列表内容,最后返回由“saver_def”字段构造的saver以供使用,其参数如下:
meta_graph_or_file:MetaGraphDef
协议缓冲区或者包含MetaGraphDef且带有路径的文件名;
clear_devices:导入时将与设备相关的信息去掉,即不与导出时的图设备环境关联,可兼容当前设备环境;
import_scope:导入名字空间,用以删除;
**kwargs:可选的参数;
return:在“MetaGraphDef”中由“saver_def”构造的存储模型,如果MetaGraphDef没有保存的变量则会直接返回None;
SavedModelBuilder & SavedModelLoader
SavedModelBuilder
初始化:
__init__(export_dir)
使用方法
1. 创建一个builder对象
builder = tf.saved_model.builder.SavedModelBuilder("dir")
- 训练完后,添加meta-graph和variable的值
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map=None, assets_collection=None)
builder.save()
add_meta_graph_and_variables
def add_meta_graph_and_variables(sess,tags,signature_def_map=None,assets_collection=None,legacy_init_op=None,clear_devices=False,main_op=None):
该函数可以将当前元图添加到SavedModel并保存变量,其参数如下:
sess:用于执行添加元图和变量功能的会话;
tags:用于保存元图的标签;(因为builde可以同时存多个图)
signature_def_map:用于保存元图的签名;
assets_collection:使用SavedModel保存的资源集合;
legacy_init_op:在恢复模型操作后,对Op和Ops组的遗留支持;
clear_devices:如果默认图形上的设备信息应该被清除,则应该设置为true;
main_op:在加载图时执行Op或Ops组的操作。请注意,当main_op被指定时,它将在加载恢复op后运行;
return:无返回
signature_def: 用来identify function输入输出的,它的数据结构是一个key-value形式的,value是input和output的map,常用在tensorflow的部署中,验证输入输出.。参考这里
形式是
signature_def: {
key : "my_classification_signature"
value: {
inputs: {
key : "inputs"
value: {
name: "tf_example:0"
dtype: DT_STRING
tensor_shape: ...
}
}
outputs: {
key : "classes"
value: {
name: "index_to_string:0"
dtype: DT_STRING
tensor_shape: ...
}
}
outputs: {
key : "scores"
value: {
name: "TopKV2:0"
dtype: DT_FLOAT
tensor_shape: ...
}
}
method_name: "tensorflow/serving/classify"
}
}
可以直接通过input和output的tensor生成signature_mao
tf.saved_model.signature_def_utils.build_signature_def(
inputs=None,
outputs=None,
method_name=None
)
SavedModelLoader
我们主要使用load(…)来恢复模型:
def load(sess, tags, export_dir, **saver_kwargs):
该函数可以从标签指定的SavedModel加载模型,其参数如下:
sess:恢复模型的会话;
tags:用于恢复元图的标签,需与保存时的一致,用于区别不同的模型;
export_dir:存储SavedModel协议缓冲区和要加载的变量的目录;
**saver_kwargs:可选的关键字参数传递给saver;
return:在提供的会话中加载的“MetaGraphDef”协议缓冲区,这可以用于进一步提取signature-defs, collection-defs等;
with tf.Session() as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], "dir")
保存目录只有一个pb文件,以及一个variables文件夹,里面存放的是variables.data-00000-of-00001和
variables.index,与save/restore方法比,没有checkpoint检查点文件以及以“.meta”为后缀的元数据文件,但是多了一个pb文件,这是这两种tensorflow保存和恢复模型方法的区别!