最近因为项目要求,需要把模型的训练和测试过程分开,这里主要涉及两个过程:训练图的存取和参数的存取。
以下所有/home/yy/xiajbxie/model是我的模型的存储路径,将其换成你自己的即可。
tf.train.Saver()
Saver的作用中文社区已经讲得相当清楚。tf.train.Saver()类的基本操作时save()和restore()函数,分别负责模型参数的保存和恢复。参数保存示例如下:
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# initialize the variables, save the variables to disk.
with tf.Session() as sess:
sess.run(init_op)
v1, v2 = sess.run([v1, v2])
print(v1)
print(v2)
# Do some work with the model.
# Save the variables to disk.
save_path = saver.save(sess, "/home/yy/xiajbxie/model")
print "Model saved in file: ", save_path
运行结果:
[[-0.0493206 0.12752049]]
[[ 1.9456626 0.6319563 -0.1296857 ]
[-0.7834143 0.33656874 -0.96077037]]
Model saved in file: /home/yy/xiajbxie/model
参数恢复示例如下:
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/home/yy/xiajbxie/model")
print "Model restored."
print(sess.run([v1, v2]))
运行结果:
Model restored.
[array([[-0.0493206 , 0.12752049]], dtype=float32), array([[ 1.9456626 , 0.6319563 , -0.1296857 ],
[-0.7834143 , 0.33656874, -0.96077037]], dtype=float32)]
saver.save()
函数的参数为需保存的会话,以及模型的存储路径。保存后我们进入模型的存储路径会看到4个新增文件,4个文件根据tensorflow版本不同名字不同,以上例为例,1.2版本4个文件如下:
1. checkpoint
:其中存储模型所在的路径
2. model.meta
:包含计算图的完整信息
3. model.index
:与下面的文件一起保存所有的变量值
4. model.data-00000-of-00001
可以看到,在模型参数恢复前需事先定义要恢复的变量,并且变量名需要与模型中存储的变量名保持一致。
官方文档的说法是无需在参数恢复前对其进行初始化,但实际操作的时候有出现过报错“FailedPreconditionError (see above for traceback): Attempting to use uninitialized value”的情况,此时利用tf.global_variables_initializer()
初始化变量可解决问题。
tf.train.import_meta_graph()
模型参数恢复之前需要先定义模型中保存的变量,如果不想这样做可以把模型的计算图也恢复出来。tf.train.import_meta_graph()
函数就用于恢复模型,它的输入参数为模型路径,返回一个Saver类实例,再调用这个实例的restore()函数就可以恢复其参数了。示例如下:
import tensorflow as tf
sess = tf.Session()
new_saver = tf.train.import_meta_graph('/home/yy/xiajbxie/model.meta')
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state('/home/yy/xiajbxie')
if ckpt and ckpt.model_checkpoint_path:
print ckpt.model_checkpoint_path
new_saver.restore(sess, ckpt.model_checkpoint_path)
v1 = tf.get_default_graph().get_tensor_by_name('v1:0')
v2 = tf.get_default_graph().get_tensor_by_name('v2:0')
print(sess.run([v1, v2]))
执行结果:
/home/yy/xiajbxie/model
[array([[-0.0493206 , 0.12752049]], dtype=float32), array([[ 1.9456626 , 0.6319563 , -0.1296857 ],
[-0.7834143 , 0.33656874, -0.96077037]], dtype=float32)]
其中get_checkpoint_state()用于在传入的路径中寻找tensorflow检查点。
tips
- 在不知道要重载的tensor叫什么名字时可以在训练阶段打印变量名来观察。
- 不能在与训练数据相同的计算图下载入以前保存的计算图,如果实在要这样做也要保证两个计算图中不包含名字相同的变量。
- 利用tf.Graph()来生成新的计算图,利用tf.Graph().as_default()来将新生成的计算图设置为默认。