一、模型保存

为了更好地保存和加载我们已经训练好的模型,TensorFlow使用tf.train.Saver类和checkpoint的机制去实现这一过程,

什么是checkpoints?

        是用于存储变量的二进制文件,在其内部使用“ 字典结构 ”存储变量,键 即变量的名字,值 为变量的tensor值。

其中Saver类的定义如下所示:

class Saver(object):
    def __init__(self,
               var_list=None,                          #要保存的变量列表
               reshape=False,                          #加载时是否恢复变量形状
               sharded=False,
               max_to_keep=5,                          #最大保留几个checkpoint点
               keep_checkpoint_every_n_hours=10000.0,  #隔多长时间保留一个checkpoint
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None)

1、保存的操作步骤:

saver=tf.train.Saver()    #创建saver对象

save_path=saver.save(sess,"titanic_model_saver/titanic_model.ckpt")

print(f"模型的保存路径为{save_path}")

注意事项:

(1)Saver对象在初始化的时候如果没有指定需要存储的变量列表,默认只会自动收集saver定义之前的所有变量,在saver初始化后面的相关变量则不会保存下来。也可以指定保存一些指定的变量:

    saver = tf.train.Saver([w1,w2])   #只保存 w1 w2这两个变量

(2)模型保存之后,会在相应的文件夹之下生成4个文件。

.ckpt文件:该文件是真实存储变量及其变量值的文件

.ckpt.meta文件:它是一个描述文件,在这个文件存储的是MetaGraphDef结构的对象经过二进制序列化之后的内容。             MetaGraphDef结构是由Protocol buffer定义的,其中包含了整个计算图的描述、各个变量的定义和声明、输入管道的形式、以及其他的一些信息

.ckpt.index文件:存储变量在checkpoints文件中的位置索引

checkpoint文件:最后还有一个名称为checkpoint的文件,这个是文件中存储了最新存档的文件路径。

2、相关的参数设置

前面的Saver没有添加任何参数,这样的模型存储,只会讲模型最终训练的数据存储起来,即存储最终的“ 稳定 ”的模型。除此之外,还可以引入“ 迭代计数器 ”的方式,即按照训练迭代轮次进行存储。即如下所示:

      saver.save(sess,'my_model.ckpt', global_step=step)

这里的global_step和记录日志里面的  writer.add_summary(summary,global_step=step)中的是一个意思,然后会在自动生成的带有测试的轮次和版本号的checkpoint文件。

但是因为每一次迭代记录都会生成一组checkpoint,那么迭代成千上万次之后的训练后会占用大量的磁盘空间,为了防范这种情况,Saver类中的构造函数会有两个参数进行设置。

参数一:max_to_keep :此参数指定存储操作以更迭的方式只会保留最后的5个版本的checkpoint

参数二:keep_checkpoint_every_n_hour: 这种方式以时间为单位,每n个小时存储一个checkpoint,该参数的默认值是10000,                 即10000个小时记录一个checkpoint。

当然我们可以在自己定义Saver对象的时候修改这两个参数的值,但是我们一般不推荐这样去做。

二、模型的恢复和加载

1、模型加载的方式

with tf.Session() as sess:   #第一步:构造会话对象

#第二步:导入模型的图结构

#第三步:将这个会话绑定到导入的图中

#第三步也可以是这样操作,因为会从mymodel文件夹中获取checkpoint,而checkpoint中存储了最新存档的文件路径

2、获取模型内保存的变量以及相关的参数

      (1)直接获取

               w1=sess.graph.get_xxxxx()        #直接获取,因为sess已经和加载的图进行了绑定

      (2)先创建所获取的图对象

              graph=tf.get_default_graph()   #获取该session所绑定的默认图

              w1=graph.get_xxxxxxx()

3、获取模型中tensor的值

      (1)方法一

               w2=sess.run('w2:0')            #"name:index"的形式

               print(w2)

      (2)方法二

              w2=sess.graph.get_operation_by_name('w2:0')  

              print(sess.run(w2)

由此可见,第一种方法更加简单,获取张量的时候,一定要写成“ name:index ”的形式。