最近因为项目要求,需要把模型的训练和测试过程分开,这里主要涉及两个过程:训练图的存取和参数的存取。
以下所有/home/yy/xiajbxie/model是我的模型的存储路径,将其换成你自己的即可。

tf.train.Saver()

Saver的作用中文社区已经讲得相当清楚。tf.train.Saver()类的基本操作时save()和restore()函数,分别负责模型参数的保存和恢复。参数保存示例如下:

import tensorflow as tf

# Create some variables.
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")

# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# initialize the variables, save the variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  v1, v2 = sess.run([v1, v2])
  print(v1)
  print(v2)
  # Do some work with the model.
  # Save the variables to disk.
  save_path = saver.save(sess, "/home/yy/xiajbxie/model")
  print "Model saved in file: ", save_path

运行结果:

[[-0.0493206   0.12752049]]
[[ 1.9456626   0.6319563  -0.1296857 ]
 [-0.7834143   0.33656874 -0.96077037]]
Model saved in file:  /home/yy/xiajbxie/model

参数恢复示例如下:

import tensorflow as tf

# Create some variables.
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")


# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/home/yy/xiajbxie/model")
  print "Model restored."
  print(sess.run([v1, v2]))

运行结果:

Model restored.
[array([[-0.0493206 ,  0.12752049]], dtype=float32), array([[ 1.9456626 ,  0.6319563 , -0.1296857 ],
       [-0.7834143 ,  0.33656874, -0.96077037]], dtype=float32)]

saver.save()函数的参数为需保存的会话,以及模型的存储路径。保存后我们进入模型的存储路径会看到4个新增文件,4个文件根据tensorflow版本不同名字不同,以上例为例,1.2版本4个文件如下:
1. checkpoint:其中存储模型所在的路径
2. model.meta:包含计算图的完整信息
3. model.index:与下面的文件一起保存所有的变量值
4. model.data-00000-of-00001

可以看到,在模型参数恢复前需事先定义要恢复的变量,并且变量名需要与模型中存储的变量名保持一致。
官方文档的说法是无需在参数恢复前对其进行初始化,但实际操作的时候有出现过报错“FailedPreconditionError (see above for traceback): Attempting to use uninitialized value”的情况,此时利用tf.global_variables_initializer()初始化变量可解决问题。

tf.train.import_meta_graph()

模型参数恢复之前需要先定义模型中保存的变量,如果不想这样做可以把模型的计算图也恢复出来。tf.train.import_meta_graph()函数就用于恢复模型,它的输入参数为模型路径,返回一个Saver类实例,再调用这个实例的restore()函数就可以恢复其参数了。示例如下:

import tensorflow as tf

sess = tf.Session()
new_saver = tf.train.import_meta_graph('/home/yy/xiajbxie/model.meta')

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('/home/yy/xiajbxie')
    if ckpt and ckpt.model_checkpoint_path:
        print ckpt.model_checkpoint_path
        new_saver.restore(sess, ckpt.model_checkpoint_path)

    v1 = tf.get_default_graph().get_tensor_by_name('v1:0')
    v2 = tf.get_default_graph().get_tensor_by_name('v2:0')
    print(sess.run([v1, v2]))

执行结果:

/home/yy/xiajbxie/model
[array([[-0.0493206 ,  0.12752049]], dtype=float32), array([[ 1.9456626 ,  0.6319563 , -0.1296857 ],
       [-0.7834143 ,  0.33656874, -0.96077037]], dtype=float32)]

其中get_checkpoint_state()用于在传入的路径中寻找tensorflow检查点。

tips

  • 在不知道要重载的tensor叫什么名字时可以在训练阶段打印变量名来观察。
  • 不能在与训练数据相同的计算图下载入以前保存的计算图,如果实在要这样做也要保证两个计算图中不包含名字相同的变量。
  • 利用tf.Graph()来生成新的计算图,利用tf.Graph().as_default()来将新生成的计算图设置为默认。