一、模型保存
为了更好地保存和加载我们已经训练好的模型,TensorFlow使用tf.train.Saver类和checkpoint的机制去实现这一过程,
什么是checkpoints?
是用于存储变量的二进制文件,在其内部使用“ 字典结构 ”存储变量,键 即变量的名字,值 为变量的tensor值。
其中Saver类的定义如下所示:
class Saver(object):
def __init__(self,
var_list=None, #要保存的变量列表
reshape=False, #加载时是否恢复变量形状
sharded=False,
max_to_keep=5, #最大保留几个checkpoint点
keep_checkpoint_every_n_hours=10000.0, #隔多长时间保留一个checkpoint
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None)
1、保存的操作步骤:
saver=tf.train.Saver() #创建saver对象
save_path=saver.save(sess,"titanic_model_saver/titanic_model.ckpt")
print(f"模型的保存路径为{save_path}")
注意事项:
(1)Saver对象在初始化的时候如果没有指定需要存储的变量列表,默认只会自动收集saver定义之前的所有变量,在saver初始化后面的相关变量则不会保存下来。也可以指定保存一些指定的变量:
saver = tf.train.Saver([w1,w2]) #只保存 w1 w2这两个变量
(2)模型保存之后,会在相应的文件夹之下生成4个文件。
.ckpt文件:该文件是真实存储变量及其变量值的文件
.ckpt.meta文件:它是一个描述文件,在这个文件存储的是MetaGraphDef结构的对象经过二进制序列化之后的内容。 MetaGraphDef结构是由Protocol buffer定义的,其中包含了整个计算图的描述、各个变量的定义和声明、输入管道的形式、以及其他的一些信息
.ckpt.index文件:存储变量在checkpoints文件中的位置索引
checkpoint文件:最后还有一个名称为checkpoint的文件,这个是文件中存储了最新存档的文件路径。
2、相关的参数设置
前面的Saver没有添加任何参数,这样的模型存储,只会讲模型最终训练的数据存储起来,即存储最终的“ 稳定 ”的模型。除此之外,还可以引入“ 迭代计数器 ”的方式,即按照训练迭代轮次进行存储。即如下所示:
saver.save(sess,'my_model.ckpt', global_step=step)
这里的global_step和记录日志里面的 writer.add_summary(summary,global_step=step)中的是一个意思,然后会在自动生成的带有测试的轮次和版本号的checkpoint文件。
但是因为每一次迭代记录都会生成一组checkpoint,那么迭代成千上万次之后的训练后会占用大量的磁盘空间,为了防范这种情况,Saver类中的构造函数会有两个参数进行设置。
参数一:max_to_keep :此参数指定存储操作以更迭的方式只会保留最后的5个版本的checkpoint
参数二:keep_checkpoint_every_n_hour: 这种方式以时间为单位,每n个小时存储一个checkpoint,该参数的默认值是10000, 即10000个小时记录一个checkpoint。
当然我们可以在自己定义Saver对象的时候修改这两个参数的值,但是我们一般不推荐这样去做。
二、模型的恢复和加载
1、模型加载的方式
with tf.Session() as sess: #第一步:构造会话对象
#第二步:导入模型的图结构
#第三步:将这个会话绑定到导入的图中
#第三步也可以是这样操作,因为会从mymodel文件夹中获取checkpoint,而checkpoint中存储了最新存档的文件路径
2、获取模型内保存的变量以及相关的参数
(1)直接获取
w1=sess.graph.get_xxxxx() #直接获取,因为sess已经和加载的图进行了绑定
(2)先创建所获取的图对象
graph=tf.get_default_graph() #获取该session所绑定的默认图
w1=graph.get_xxxxxxx()
3、获取模型中tensor的值
(1)方法一
w2=sess.run('w2:0') #"name:index"的形式
print(w2)
(2)方法二
w2=sess.graph.get_operation_by_name('w2:0')
print(sess.run(w2)
由此可见,第一种方法更加简单,获取张量的时候,一定要写成“ name:index ”的形式。