在 tensorflow2 中,保存模型有多种方法,方法都在tf.keras.models下面:

  1. 保存模型:save_model 或 model.save,读取模型:load_model。这里可保存成 TensorFlow格式 或者 HDF5 文件。
  2. 只保存模型权重:model.save_weights,读取模型权重:model.load_weights。这里同样可以保存成 TensorFlow格式 或者 HDF5 文件。
  3. 从配置读取模型结构:model.from_config(model.get_config())
  4. 从JSON读取模型结构:model_from_json(model.to_json())
  5. 从YAML读取模型结构:model_from_yaml(model.to_yaml())
  6. 使用tf.keras.callbacks.ModelCheckpoint,保存训练过程中的模型或权重。

下面主要介绍save_model 或 model.saveload_modelmodel.save_weightsmodel.load_weights 和 tf.keras.callbacks.ModelCheckpoint的定义及用法。


一、save_model 或 model.save 保存模型

model.save(
    filepath, overwrite=True, include_optimizer=True, save_format=None,
    signatures=None, options=None, save_traces=True
)

参数:

  • filepath:权重文件路径,如果是h5文件直接是完整文件路径即可。如果是tensorflow模型,可以是模型名称或"文件夹/模型名称",最好是有文件夹,因为保存时会生成多个文件,包括权重与索引文件。
  • overwrite:是否以静默方式覆盖目标位置上的任何现有文件,或向用户提供手动提示。
  • include_optimizer:如果为True,则将优化器的状态保存在一起。
  • save_format:'tf'或'h5',指示是将模型保存到Tensorflow SavedModel还是HDF5。在TF 2.X中默认为'tf',在TF 1.X中默认为'h5'。
  • signatures:要使用SavedModel保存的签名。仅适用于“ tf”格式。有关详细信息,请参见tf.saved_model.save中的signatures参数。
  • options:(仅适用于SavedModel格式)tf.saved_model.SaveOptions对象,该对象指定用于保存到SavedModel的选项。
  • save_traces:(仅适用于SavedModel格式)启用后,SavedModel将存储每个图层的功能跟踪。可以禁用此功能,以便仅存储每个层的配置。默认为True。禁用此功能将减少序列化时间并减小文件大小,但是需要所有自定义图层/模型都实现get_config()方法。

 二、load_model 加载模型

tf.keras.models.load_model(
    filepath, custom_objects=None, compile=True, options=None
)

参数: 

  • filepath:字符串或pathlib.Path对象、已保存模型的路径 或 已加载模型的h5py.File对象。如果是tensorflow格式,那么需要先用 tf.train.latest_checkpoint 获取最后的模型文件路径。
  • custom_objects:反序列化过程中要考虑的可选字典映射名称(字符串)到自定义类或函数。
  • compile:布尔值,是否在加载后编译模型。
  • options:可选的tf.saved_model.LoadOptions对象,用于指定从SavedModel加载的选项。

三、tf.keras.callbacks.ModelCheckpoint 在训练过程中保存模型

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch',
    options=None, **kwargs
)

参数:

filepath:模型路径

monitor:根据该字段,结合mode与save_best_only,保存最优模型

verbose:1.显示详细信息

save_best_only:结合mode与monitor字段,只保存最优模型

save_weights_only:只保存权重,调用的是model.save_weights

mode:'auto'、'min'、'max',最优时是按最大还是最小monitor来判断

save_freq:'epoch'或数字,数字代码步数,而不是epoch


完整Demo,包括5种方法的使用:

import os

import tensorflow as tf

# 训练数据
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

# 定义模型
def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])

  return model

# 创建模型
model = create_model()

# 显示模型结构
model.summary()

checkpoint_path = "./data/model_save_and_load/ep{epoch:04d}-loss{loss:.3f}-val_loss{val_loss:.3f}-acc{val_sparse_categorical_accuracy:.2f}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

last_model_path = tf.train.latest_checkpoint(checkpoint_dir)
if os.path.exists(last_model_path):
  model.load_weights(last_model_path) # 只加载权重,对应model.save_weights保存的文件
  model = tf.keras.models.load_model(last_model_path) # 加载完整模型,对应model.save保存的文件
  print('加载模型:{}'.format(last_model_path))

# 创建callback,用于保存训练时的模型权重
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True, # 只保存权重
                                                 save_best_only=True, # 只保存最佳模型
                                                 monitor='val_sparse_categorical_accuracy',
                                                 mode='max',
                                                 verbose=1)

# 训练模型
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # 通过callback保存训练时的模型

# 只保存权重
model.save_weights(os.path.join(checkpoint_dir,'last_weithts.ckpt'))
# 保存完整模型
model.save(os.path.join(checkpoint_dir,'last_model.ckpt'))

这节内容到此结束,感谢关注,欢迎点赞,哈哈哈