在Tensorflow中,有两种保存模型的方法:一种是Checkpoint,另一种是Protobuf,也就是PB格式;
一. Checkpoint方法:
1.保存时使用方法:
tf.train.Saver()
生成四个文件:
checkpoint 检查点文件
model.ckpt.data-xxx 参数值
model.ckpt.index 各个参数
model.ckpt.meta 图的结构
2.恢复时使用方法:
saver.restore() :模型文件依赖Tensorflow,只能在其框架下使用,恢复模型之前需要定义下网络结构
saver=tf.train.import_meta_graph('./ckpt/mode..ckpt.meta') :直接加载网络结构,不需要重新定义网络
二. PB方法:
1. 保存模型为PB文件(谷歌推荐),具有语言独立性,可独立运行,序列化的格式,任何语言可解析它,允许其他语言和框架读取,训练和迁移;模型变量是固定的,模型大小会大大减少,适合在手机端运行;
2. 实现创建模型与使用模型的解耦,使得前向推导Inference代码统一;
3. PB文件表示MetaGraph的protocol buffer格式的文件;
4. GraphDef 不保存任何Variable信息,不能从graph_def 来构建图并恢复训练.
一般情况下,PB可直接生成;
当然也可以从checkpoint文件中生成,代码如下:


1 output_graph = os.path.join('./checkpoint/','frozen_graph.pb')
2 input_checkpoint = os.path.join('./checkpoint/','model.ckpt-xxxxx') #[xxxxxx为训练生成的step号]
3 saver = tf.train.import_meta_graph(input_checkpoint+'.meta',clear_devices=True)
4 graph = tf.get_default_graph()
5 input_graph_def = graph.as_graph_def
6
7 for op in graph.get_operations():
8 print("checkpoint2pb",,op.values())
9
10 variable_names = [ for v in tf.trainable_variables()]
11 pirnt("trainalbe_variables:",variable_names)
12
13 output_node_name=['fc2/add'] #fc2/add 上面的列表里需要存在该操作
14
15 with tf.Session() as sess:
16 saver.restore(sess,input_checkpoint)
17
18 output_graph_def = graph_util.convert_variables_to_constants(sess=sess,
19 input_graph_def = input_graph_def,
20 output_node_names = output_node_name)
21
22 with tf.gfile.GFile(output_graph,"wb") as f:
23 f.write(output_graph_def.SerializeToString())
24
25
26
27
人生,从没有一劳永逸 想要变强,只有不停奔跑
















