1. What Is Resume Training from a Checkpoint

Resume training (checkpoint resumption) means that if a run is interrupted for some reason before training finishes, the next run can continue from where the previous one left off instead of starting over. This is very convenient for models that need a long time to train.

2. Anatomy of the Model Files

[Figure: the files produced in the checkpoint directory]


The checkpoint file records bookkeeping information about the saves; it is what locates the most recently saved model.

The .meta file stores the structure of the current network graph; it can be reloaded with tf.train.import_meta_graph('MODEL_NAME.ckpt-1174.meta').

The .data file stores the current parameter names and values: the network weights, biases, operations, and so on.

The .index file stores auxiliary index information; it is an immutable string table.

As for the number 1174 at the end of the file names, it is the training step at which that checkpoint was saved; normally we only need the most recent one.
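If you want to inspect these files yourself, the snippet below (a small sketch, assuming checkpoints were saved under ./ckpt/ as in the code later in this article) lists the directory and resolves the newest save:

import tensorflow.compat.v1 as tf
import os

print(os.listdir('./ckpt/'))                  # checkpoint, *.meta, *.index, *.data-00000-of-00001
print(tf.train.latest_checkpoint('./ckpt/'))  # path prefix of the newest save, ending in its step number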

3. How to Implement Resume Training

Two conditions must be met:

(1) A snapshot of the training state is saved locally (saving the checkpoint data).
(2) The training state can be restored by reading that snapshot back (restoring the checkpoint data).
Both operations rely on TensorFlow's tf.train.Saver class (see the official documentation).

1. Create a tf.train.Saver object

saver = tf.train.Saver(max_to_keep=1)  # max_to_keep is how many of the most recent checkpoints to keep (default 5); set it to 1 to keep only the latest

2. Save the model with the Saver object's save method

saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=epoch)
# sess is the session to save, MODEL_SAVE_PATH is the save directory, MODEL_NAME is the model's name,
# and global_step tags the save with the current training step
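One more detail about this call: by default every save also rewrites the .meta graph file. Since the graph does not change during training, repeated saves inside the training loop can skip it via the write_meta_graph argument of Saver.save; a sketch reusing the names above:

# inside the training loop: skip rewriting the unchanged graph definition on every save
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=epoch, write_meta_graph=False)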

3. Restoring the checkpoint data

3.1 Load only the parameters, not the graph

saver.restore(sess, ckpt.model_checkpoint_path)  # restore the session, assigning the values in the checkpoint to the weights and biases
This is the approach usually chosen for resume training: when the network is not very complex, rebuilding the graph in code only takes milliseconds. Besides, the graph only has to be built once per run, because the network structure does not change during training.
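Put together, a minimal sketch of this approach (assuming the graph-building code has already been executed in the current process and ./ckpt/ is the save directory):

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state('./ckpt/')      # reads the 'checkpoint' bookkeeping file
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)  # overwrite the freshly initialized weights with the saved ones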

3.2 Load both the graph structure and the parameters

saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + ".meta")  # load the graph structure, i.e., the definition of the network
saver.restore(sess, ckpt.model_checkpoint_path)  # restore the session, assigning the values in the checkpoint to the weights and biases
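Note that when the graph comes from the .meta file, the Python variables that originally pointed at the placeholders and ops no longer exist, so they have to be fetched back by name. A sketch, assuming the placeholders were created with name='x' and name='y' (the names in your own graph may differ):

graph = tf.get_default_graph()
x = graph.get_tensor_by_name('x:0')   # input placeholder, by the name it was given when the graph was built
y = graph.get_tensor_by_name('y:0')   # label placeholder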

4. Full Code Walkthrough

In this part I attach a complete example for readers who are meeting resume training for the first time and want to actually use it in their own programs. The code has four parts: importing the required packages, loading the dataset, building the network, and training the model. The first three parts have little to do with resume training, so readers with some background can jump straight to the last part. Running the full code requires tensorflow>=2.0 and the CIFAR-10 dataset.

A quick note:

If you only save and reload the model parameters, every fresh run will indeed start from the previous parameters and keep improving on them, but the step counter starts again from zero each time, so the actual total number of training steps becomes impossible to pin down. In some situations, though, a model is only meaningfully comparable with other baseline models when it has been trained for a known, fixed number of steps. We therefore need to keep the step reached in the previous run and subtract it when the new run starts, so that the total number of training steps stays the same.
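The code below keeps the step count consistent by parsing the step number out of the checkpoint information. An alternative worth knowing (not used in this article's code) is to store the step counter as a variable inside the graph, so saver.restore() brings it back together with the weights; a self-contained toy sketch:

import os
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Toy model: a single trainable scalar, just to show the step counter being restored.
w = tf.get_variable('w', initializer=0.0)
loss = tf.square(w - 3.0)
global_step = tf.train.get_or_create_global_step()
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)
saver = tf.train.Saver(max_to_keep=1)

os.makedirs('./ckpt_demo/', exist_ok=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state('./ckpt_demo/')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    start = sess.run(global_step)                 # step reached in the previous run, no regex needed
    for _ in range(start, 100):                   # total step budget stays fixed at 100
        sess.run(train_op)
    saver.save(sess, './ckpt_demo/toy', global_step=global_step)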

4.1 Import the required packages

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # placeholder was removed in TF 2.x, so we keep using the TF 1.x API
import os                 # used to build and create the checkpoint directory
import pickle
import numpy as np
import re                 # regular expressions, used to extract the step number from the checkpoint info

MODEL_SAVE_PATH = './ckpt/'  # saver.save will write the checkpoint files into this folder
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)  # saver.save does not create the directory itself
MODEL_NAME = 'vgg model'
batch_size = 20
train_steps = 10000
test_steps = 100

CIFAR_DIR = r"D:\Dataset\cifar-10-python\cifar-10-batches-py"  # change to your own CIFAR-10 directory
print(os.listdir(CIFAR_DIR))

4.2 Load the data (this part is unrelated to resume training)

def load_data(filename):
    """read data from data file."""
    with open(filename, 'rb') as f:
        data = pickle.load(f, encoding='bytes')
        return data[b'data'], data[b'labels']

# tensorflow.Dataset.
class CifarData:
    def __init__(self, filenames, need_shuffle):
        all_data = []
        all_labels = []
        for filename in filenames:
            data, labels = load_data(filename)
            all_data.append(data)
            all_labels.append(labels)
        self._data = np.vstack(all_data)
        self._data = self._data / 127.5 - 1   # scale pixel values to the range [-1, 1]
        self._labels = np.hstack(all_labels)
        print(self._data.shape)
        print(self._labels.shape)
        
        self._num_examples = self._data.shape[0]
        self._need_shuffle = need_shuffle
        self._indicator = 0
        if self._need_shuffle:
            self._shuffle_data()
            
    def _shuffle_data(self):
        # [0,1,2,3,4,5] -> [5,3,2,4,0,1]
        p = np.random.permutation(self._num_examples)
        self._data = self._data[p]
        self._labels = self._labels[p]
    
    def next_batch(self, batch_size):
        """return batch_size examples as a batch."""
        end_indicator = self._indicator + batch_size
        if end_indicator > self._num_examples:
            if self._need_shuffle:
                self._shuffle_data()
                self._indicator = 0
                end_indicator = batch_size
            else:
                raise Exception("have no more examples")
        if end_indicator > self._num_examples:
            raise Exception("batch size is larger than all examples")
        batch_data = self._data[self._indicator: end_indicator]
        batch_labels = self._labels[self._indicator: end_indicator]
        self._indicator = end_indicator
        return batch_data, batch_labels

train_filenames = [os.path.join(CIFAR_DIR, 'data_batch_%d' % i) for i in range(1, 6)]
test_filenames = [os.path.join(CIFAR_DIR, 'test_batch')]

train_data = CifarData(train_filenames, True)
test_data = CifarData(test_filenames, False)

4.3 Build the network (unrelated to resume training)

x = tf.placeholder(tf.float32, [None, 3072])
y = tf.placeholder(tf.int64, [None])
# [None], eg: [0,5,6,3]
x_image = tf.reshape(x, [-1, 3, 32, 32])
# 32*32
x_image = tf.transpose(x_image, perm=[0, 2, 3, 1])

# conv1: first convolution block, producing feature maps
conv1_1 = tf.layers.conv2d(x_image,
                           32, # output channel number
                           (3,3), # kernel size
                           padding = 'same',
                           activation = tf.nn.relu,
                           name = 'conv1_1')
conv1_2 = tf.layers.conv2d(conv1_1,
                           32, # output channel number
                           (3,3), # kernel size
                           padding = 'same',
                           activation = tf.nn.relu,
                           name = 'conv1_2')

# 16 * 16
pooling1 = tf.layers.max_pooling2d(conv1_2,
                                   (2, 2), # kernel size
                                   (2, 2), # stride
                                   name = 'pool1')


conv2_1 = tf.layers.conv2d(pooling1,
                           32, # output channel number
                           (3,3), # kernel size
                           padding = 'same',
                           activation = tf.nn.relu,
                           name = 'conv2_1')
conv2_2 = tf.layers.conv2d(conv2_1,
                           32, # output channel number
                           (3,3), # kernel size
                           padding = 'same',
                           activation = tf.nn.relu,
                           name = 'conv2_2')
# 8 * 8
pooling2 = tf.layers.max_pooling2d(conv2_2,
                                   (2, 2), # kernel size
                                   (2, 2), # stride
                                   name = 'pool2')

conv3_1 = tf.layers.conv2d(pooling2,
                           32, # output channel number
                           (3,3), # kernel size
                           padding = 'same',
                           activation = tf.nn.relu,
                           name = 'conv3_1')
conv3_2 = tf.layers.conv2d(conv3_1,
                           32, # output channel number
                           (3,3), # kernel size
                           padding = 'same',
                           activation = tf.nn.relu,
                           name = 'conv3_2')
# 4 * 4 * 32
pooling3 = tf.layers.max_pooling2d(conv3_2,
                                   (2, 2), # kernel size
                                   (2, 2), # stride
                                   name = 'pool3')
# [None, 4 * 4 * 32]
flatten = tf.layers.flatten(pooling3)
y_ = tf.layers.dense(flatten, 10)

loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=y_)
# y_ -> softmax probabilities
# y  -> one-hot labels
# loss = cross entropy between them

# indices
predict = tf.argmax(y_, 1)
# [1,0,1,1,1,0,0,0]
correct_prediction = tf.equal(predict, y)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))

with tf.name_scope('train_op'):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

4.4 Model training and resuming from the checkpoint (the key part)

init = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=1)   # the Saver class provides the methods for saving and restoring the model

# train 10k: 73.4%
with tf.Session() as sess:
    sess.run(init)
    ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
    # a similar alternative: ckpt = tf.train.latest_checkpoint("./ckpt/"), which returns just the path string
    print(ckpt)  # checkpoint state, including the path of the most recently saved model

    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)  # restore the session, assigning the checkpointed values to the weights and biases
        steped = re.findall(r"\d+\.?\d*", str(ckpt))     # extract the step number from the checkpoint info
        print('the step finished last time is ' + steped[0])  # the step reached in the previous run
        steped = int(steped[0])                # used below to keep the total number of training steps fixed
        print('Model restored...')
    else:
        steped = 0
        print('No model')
    for step in range(train_steps-steped):
        batch_data, batch_labels = train_data.next_batch(batch_size)
        loss_val, acc_val, _ = sess.run(
            [loss, accuracy, train_op],
            feed_dict={
                x: batch_data,
                y: batch_labels})
        saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=step+steped+1)
        if (step+steped+1) % 100 == 0:      # the reported step continues from where the previous run stopped
            print('[Train] Step: %d, loss: %4.5f, acc: %4.5f' 
                  % (step+steped+1, loss_val, acc_val))
        if (step+steped+1) % 1000 == 0:
            test_data = CifarData(test_filenames, False)
            all_test_acc_val = []
            for j in range(test_steps):
                test_batch_data, test_batch_labels \
                    = test_data.next_batch(batch_size)
                test_acc_val = sess.run(
                    [accuracy],
                    feed_dict = {
                        x: test_batch_data, 
                        y: test_batch_labels
                    })
                all_test_acc_val.append(test_acc_val)
            test_acc = np.mean(all_test_acc_val)
            print('[Test ] Step: %d, acc: %4.5f' % (step+steped+1, test_acc))

4.5 Training results after resuming

[Figure: training log after resuming, with the step counter continuing from the last saved step]