Saving and Restoring TensorFlow Models
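The simplest way to save and restore a model is to use a tf.train.Saver object. The examples below use the TensorFlow 1.x graph-and-session API.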
Saving a model
Create a Saver with tf.train.Saver() to save all of the variables in the model.
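#!/usr/bin/env python3
# coding:utf-8
import os
import tensorflow as tf  # TensorFlow 1.x API
# Define two variables initialized from constants
v1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")
# Op that initializes all variables
init_op = tf.global_variables_initializer()
# With no arguments, the Saver handles every variable in the graph
saver = tf.train.Saver()
# Saver.save() expects the target directory to exist already
if not os.path.exists("model"):
    os.makedirs("model")
with tf.Session() as sess:
    sess.run(init_op)
    save_path = saver.save(sess, "model/model.ckpt")
    print("Model saved in file:", save_path)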
Output:
Model saved in file: model/model.ckpt
The checkpoint files appear in the model directory.
The variables are stored in binary files that essentially map variable names to tensor values.
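As a rough analogy only, a checkpoint can be pictured as a dictionary from variable names to arrays. The sketch below imitates that with NumPy's np.savez and np.load; the .npz file and its path are hypothetical stand-ins, and TensorFlow's real format (the .index and .data files) is different. To list the contents of an actual TensorFlow checkpoint, tf.train.list_variables() can be used instead.

```python
import os
import tempfile

import numpy as np

# Analogy only: a "checkpoint" pictured as a mapping from variable names to arrays.
# TensorFlow's real on-disk format (.index + .data files) is different.
ckpt_path = os.path.join(tempfile.mkdtemp(), "model_sketch.npz")  # hypothetical path
np.savez(ckpt_path, v1=np.full(3, 1.0), v2=np.full(5, 2.0))

# Inspect what was stored: each name maps to a saved array.
with np.load(ckpt_path) as ckpt:
    contents = {name: ckpt[name] for name in ckpt.files}

for name in sorted(contents):
    print(name, contents[name])
# v1 [1. 1. 1.]
# v2 [2. 2. 2. 2. 2.]
```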
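Restoring a model
Use a Saver, created the same way as in the saving script, to restore the variables.
When variables are restored from a file, they do not need to be initialized first.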
#!/usr/bin/env python3
# coding:utf-8
import tensorflow as tf
# Same variable names and shapes as when saving; the initial values do not matter
v1 = tf.Variable(tf.constant(0.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(0.0, shape=[5]), name="v2")
# No initialization is needed when the variables are restored from a file
# init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    # sess.run(init_op)
    saver.restore(sess, "model/model.ckpt")
    print("Model:")
    print(v1.eval())
    print(v2.eval())
Output:
Model:
[ 1. 1. 1.]
[ 2. 2. 2. 2. 2.]
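To see why no initialization is needed, here is a NumPy analogy (not TensorFlow code): the restoring program starts its variables at zero, like the tf.constant(0.0, ...) values above, and restoring simply overwrites each one by name with the saved value. The .npz file is a hypothetical stand-in for the checkpoint.

```python
import os
import tempfile

import numpy as np

# Analogy only (not TensorFlow): a saved mapping of names to arrays.
ckpt_path = os.path.join(tempfile.mkdtemp(), "model_sketch.npz")  # hypothetical path
np.savez(ckpt_path, v1=np.full(3, 1.0), v2=np.full(5, 2.0))

# The "new program" creates its variables with different starting values (zeros).
variables = {"v1": np.zeros(3), "v2": np.zeros(5)}

# Restoring overwrites each value by name, so the prior values are irrelevant.
with np.load(ckpt_path) as ckpt:
    for name in variables:
        variables[name][...] = ckpt[name]

print(variables["v1"])  # [1. 1. 1.]
print(variables["v2"])  # [2. 2. 2. 2. 2.]
```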
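Saving and restoring specific variables
If no arguments are passed to tf.train.Saver(), the saver handles all variables in the graph.
To save or restore only some variables, or to store them under different names, pass tf.train.Saver() a Python dictionary or list. In a dictionary, each key is the name used in the checkpoint and each value is the variable it manages; in a list, each variable is stored under its own name.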
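The naming rule can be sketched in plain Python. Var, checkpoint_names() and can_restore() below are hypothetical helpers that only model the bookkeeping, under the assumption that a restore succeeds when every name the saver asks for exists in the checkpoint (extra entries in the file are ignored). This is also why, in the list example, Saver([v1]) can restore from a checkpoint written by Saver([v1, v2]).

```python
from collections import namedtuple

# Hypothetical stand-in for a variable; only its name matters here.
# (In TensorFlow, the list case uses the variable's op name, e.g. "v1", not "v1:0".)
Var = namedtuple("Var", ["name"])

def checkpoint_names(var_list):
    """Hypothetical helper: the checkpoint name each variable is saved under."""
    if isinstance(var_list, dict):
        return dict(var_list)             # keys are the names in the file
    return {v.name: v for v in var_list}  # list: each variable's own name

def can_restore(var_list, names_in_checkpoint):
    """Assumed rule: every requested name must exist in the checkpoint."""
    return set(checkpoint_names(var_list)) <= set(names_in_checkpoint)

v1, v2 = Var("v1"), Var("v2")
print(sorted(checkpoint_names({"variable_v1": v1})))   # ['variable_v1']
print(sorted(checkpoint_names([v1, v2])))              # ['v1', 'v2']
# A checkpoint holding "v1" and "v2" can be read by a saver that asks for v1 only...
print(can_restore([v1], ["v1", "v2"]))                 # True
# ...but not by one that asks for a name the file does not contain.
print(can_restore({"variable_v1": v1}, ["v1", "v2"]))  # False
```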
Passing a dictionary
Saving
#!/usr/bin/env python3
# coding:utf-8
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")
init_op = tf.global_variables_initializer()
# Save only v1, under the checkpoint name "variable_v1"
saver = tf.train.Saver({"variable_v1": v1})
with tf.Session() as sess:
    sess.run(init_op)
    save_path = saver.save(sess, "model/model_v1.ckpt")
    print("Model saved in file:", save_path)
Output:
Model saved in file: model/model_v1.ckpt
The new checkpoint files appear in the model directory.
Restoring
#!/usr/bin/env python3
# coding:utf-8
import tensorflow as tf
v1 = tf.Variable(tf.constant(0.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(0.0, shape=[5]), name="v2")
# The key must match the name used when saving
saver = tf.train.Saver({"variable_v1": v1})
with tf.Session() as sess:
    # Only v1 is restored; v2 stays uninitialized and must not be evaluated
    saver.restore(sess, "model/model_v1.ckpt")
    print("Model v1:")
    print(v1.eval())
    # or, equivalently:
    # print(sess.run(v1))
Output:
Model v1:
[ 1. 1. 1.]
Passing a list
Saving
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")
init_op = tf.global_variables_initializer()
# With a list, each variable is saved under its own name ("v1", "v2")
saver = tf.train.Saver([v1, v2])
with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "model/model_v1v2.ckpt")
Restoring
import tensorflow as tf
v1 = tf.Variable(tf.constant(0.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(0.0, shape=[5]), name="v2")
# Restore only v1; the checkpoint also holds v2, which is simply ignored
saver = tf.train.Saver([v1])
with tf.Session() as sess:
    saver.restore(sess, "model/model_v1v2.ckpt")
    print(sess.run(v1))
Output:
[ 1. 1. 1.]
Creating multiple Saver objects
You can create as many Saver objects as you need to save and restore different subsets of variables.
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")
init_op = tf.global_variables_initializer()
# Two savers, each responsible for a different subset of the variables
saver1 = tf.train.Saver({"variable_v1": v1})
saver2 = tf.train.Saver({"variable_v2": v2})
with tf.Session() as sess:
    sess.run(init_op)
    saver1.save(sess, "model/model_v1.ckpt")
    saver2.save(sess, "model/model_v2.ckpt")
The files for both checkpoints, model_v1.ckpt and model_v2.ckpt, appear in the model directory.
The same variable can be listed in several Saver objects; its value changes only when one of those savers' restore() method is run.
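A minimal in-memory sketch of that behavior, assuming hypothetical Var and SketchSaver classes (not TensorFlow's API): two savers share the same variable under different names, save() only takes a snapshot, and the variable's value changes only when restore() is called.

```python
import numpy as np

class Var:
    """Hypothetical stand-in for a variable holding an array value."""
    def __init__(self, value):
        self.value = np.asarray(value, dtype=float)

class SketchSaver:
    """Hypothetical saver managing a subset of variables under chosen names."""
    def __init__(self, var_dict):
        self.var_dict = var_dict

    def save(self):
        # Snapshot: checkpoint name -> copy of the current value (variables untouched)
        return {name: v.value.copy() for name, v in self.var_dict.items()}

    def restore(self, ckpt):
        # Only here are variable values overwritten
        for name, v in self.var_dict.items():
            v.value = ckpt[name].copy()

v1 = Var([1.0, 1.0, 1.0])
saver_a = SketchSaver({"variable_v1": v1})
saver_b = SketchSaver({"v1_again": v1})  # the same variable in a second saver

ckpt_b = saver_b.save()   # snapshot of [1, 1, 1]
v1.value = np.zeros(3)    # change the variable afterwards
ckpt_a = saver_a.save()   # saving never modifies v1
print(v1.value)           # [0. 0. 0.]
saver_b.restore(ckpt_b)   # restore() is what changes the value
print(v1.value)           # [1. 1. 1.]
```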
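Complete example
Saving the model
Create a simple TensorFlow model that fits the parabola y = (x - h)^2 + v to two-dimensional data. A Saver object is defined, and train_graph() minimizes the loss over 100 iterations, each a full pass over the training points. The model is saved to local disk after every iteration and once more after optimization finishes. Each save writes a set of checkpoint files and updates a small text file named "checkpoint" that records the most recent checkpoints.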
#!/usr/bin/env python3
# coding:utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
tf.reset_default_graph()
# Placeholders for the x and y points
X = tf.placeholder("float")
Y = tf.placeholder("float")
# The two parameters to be learned
h_est = tf.Variable(0.0, name='hor_estimate')
v_est = tf.Variable(0.0, name='ver_estimate')
# y_est holds the estimated values on the y axis
y_est = tf.square(X - h_est) + v_est
# The loss is the squared distance between Y and y_est
cost = (tf.pow(Y - y_est, 2))
# Minimize the loss with a learning rate of 0.001
trainop = tf.train.GradientDescentOptimizer(0.001).minimize(cost)
# True horizontal and vertical shifts used to generate the data
h = 1
v = -2
# Training data with added noise
x_train = np.linspace(-2, 4, 201)
noise = np.random.randn(*x_train.shape) * 0.4
y_train = (x_train - h) ** 2 + v + noise
# Create a Saver object (by default it keeps the 5 most recent checkpoints)
saver = tf.train.Saver()
init = tf.global_variables_initializer()
# Train for 100 iterations (passes over the data)
def train_graph():
    with tf.Session() as sess:
        sess.run(init)
        for i in range(100):
            for (x, y) in zip(x_train, y_train):
                # Feed one data point at a time
                sess.run(trainop, feed_dict={X: x, Y: y})
            # Write a checkpoint after every iteration
            saver.save(sess, './model_iter', global_step=i)
        # Save the final model
        saver.save(sess, './model_final')
        h_ = sess.run(h_est)
        v_ = sess.run(v_est)
    return h_, v_
if __name__ == "__main__":
    result = train_graph()
    print("h_est = %.2f, v_est = %.2f" % result)
    # Visualize the data
    plt.rcParams['figure.figsize'] = (10, 6)
    plt.scatter(x_train, y_train)
    plt.xlabel('x_train')
    plt.ylabel('y_train')
    plt.show()
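For readers without TensorFlow at hand, the training loop in train_graph() can be mirrored in NumPy. This sketch assumes the same data, starting values (h = v = 0), learning rate (0.001) and per-sample updates over 100 passes, with the gradients of (Y - y_est)^2 derived by hand; the fixed random seed and the snapshots list (standing in for the per-iteration checkpoints) are additions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed (assumption, for reproducibility)

# Same synthetic data as the TensorFlow example: y = (x - 1)^2 - 2 + noise
h_true, v_true = 1.0, -2.0
x_train = np.linspace(-2, 4, 201)
y_train = (x_train - h_true) ** 2 + v_true + rng.standard_normal(x_train.shape) * 0.4

h_est, v_est = 0.0, 0.0
lr = 0.001
snapshots = []  # stands in for one checkpoint per iteration

for i in range(100):
    for x, y in zip(x_train, y_train):
        err = (x - h_est) ** 2 + v_est - y  # y_est - y
        # d/dh (y - y_est)^2 = 2*err * d(y_est)/dh = 2*err * (-2*(x - h_est))
        grad_h = -4.0 * err * (x - h_est)
        # d/dv (y - y_est)^2 = 2*err
        grad_v = 2.0 * err
        # Update both parameters from gradients computed at the same point
        h_est -= lr * grad_h
        v_est -= lr * grad_v
    snapshots.append((i, h_est, v_est))

print("h_est = %.2f, v_est = %.2f" % (h_est, v_est))
```

Because y_est = x^2 - 2hx + (h^2 + v) is linear in the reparametrized coefficients (-2h, h^2 + v), the loss has no spurious local minima, so plain gradient descent is expected to land near h = 1, v = -2.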
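When a model is saved, four types of files are written:
- ".meta" file: contains the graph structure.
- ".data" file(s): contain the values of the variables.
- ".index" file: maps each saved tensor to its location in the .data file(s).
- "checkpoint" file: a text-format protocol buffer listing the most recent checkpoints.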
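The "checkpoint" file is small enough to read by hand. The sketch below is a minimal, hypothetical parser for its model_checkpoint_path and all_model_checkpoint_paths fields; the sample contents are only illustrative of what the training script might leave behind with the default max_to_keep of 5 (the exact paths and any extra timestamp fields depend on the TensorFlow version). In real code, tf.train.latest_checkpoint() or tf.train.get_checkpoint_state() does this job.

```python
import re

# Illustrative contents of a "checkpoint" file (text-format protocol buffer)
SAMPLE = '''model_checkpoint_path: "model_final"
all_model_checkpoint_paths: "model_iter-96"
all_model_checkpoint_paths: "model_iter-97"
all_model_checkpoint_paths: "model_iter-98"
all_model_checkpoint_paths: "model_iter-99"
all_model_checkpoint_paths: "model_final"
'''

def parse_checkpoint_state(text):
    """Hypothetical minimal parser: reads only the quoted path fields."""
    latest = None
    all_paths = []
    for line in text.splitlines():
        m = re.match(r'\s*(\w+):\s*"(.*)"\s*$', line)
        if not m:
            continue  # skip unquoted fields such as timestamps
        key, value = m.groups()
        if key == "model_checkpoint_path":
            latest = value
        elif key == "all_model_checkpoint_paths":
            all_paths.append(value)
    return latest, all_paths

latest, all_paths = parse_checkpoint_state(SAMPLE)
print(latest)     # model_final
print(all_paths)
```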
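Other parameters of the Saver constructor control this process:
- max_to_keep: the maximum number of recent checkpoints to keep (default 5); older ones are deleted.
- keep_checkpoint_every_n_hours: in addition, keep one checkpoint for every N hours of training, even if max_to_keep would otherwise delete it.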
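The effect of max_to_keep on the complete example can be simulated in a few lines. simulate_retention() is a hypothetical helper that assumes the oldest checkpoint is deleted as soon as more than max_to_keep exist, and it ignores keep_checkpoint_every_n_hours. Applied to the 101 save calls in train_graph() (model_iter-0 to model_iter-99, then model_final), it predicts which checkpoints remain with the default max_to_keep=5.

```python
from collections import deque

def simulate_retention(save_names, max_to_keep=5):
    """Hypothetical sketch: drop the oldest checkpoint once more than max_to_keep exist."""
    kept = deque()
    deleted = []
    for name in save_names:
        kept.append(name)
        if len(kept) > max_to_keep:
            deleted.append(kept.popleft())
    return list(kept), deleted

# The save calls made by train_graph(): 100 per-iteration checkpoints, then the final one
names = ["model_iter-%d" % i for i in range(100)] + ["model_final"]
kept, deleted = simulate_retention(names)
print(kept)  # ['model_iter-96', 'model_iter-97', 'model_iter-98', 'model_iter-99', 'model_final']
```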
Restoring the model
The following example loads the model and prints the values of the two coefficients, h_est and v_est:
#!/usr/bin/env python3
# coding:utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# True horizontal and vertical shifts used to generate the data
h = 1
v = -2
# Regenerate noisy data for plotting (the noise differs from the training run)
x_train = np.linspace(-2, 4, 201)
noise = np.random.randn(*x_train.shape) * 0.4
y_train = (x_train - h) ** 2 + v + noise
# The graph is rebuilt from the .meta file, so the model does not need to be redefined
tf.reset_default_graph()
imported_meta = tf.train.import_meta_graph("./model_final.meta")
with tf.Session() as sess:
    # Restore the variable values from the latest checkpoint in ./
    imported_meta.restore(sess, tf.train.latest_checkpoint('./'))
    # Fetch the variables by tensor name (variable name + ":0")
    h_est2 = sess.run('hor_estimate:0')
    v_est2 = sess.run('ver_estimate:0')
    print("h_est: %.2f, v_est: %.2f" % (h_est2, v_est2))
plt.scatter(x_train, y_train, label='train data')
plt.plot(x_train, (x_train - h_est2) ** 2 + v_est2, color='red', label='model')
plt.xlabel('x_train')
plt.ylabel('y_train')
plt.legend()
plt.show()
References:
TensorFlow: Save and Restore Models