本问题已经有最佳答案,请猛点这里访问。
我对机器学习和Tensorflow框架比较陌生。我试图使用MNIST手写数字数据集,对受此处介绍的代码影响很大的受过训练的模型,并对我创建的测试示例进行推断。但是,我正在使用GPU的远程计算机上进行培训,并试图将数据保存到目录中,以便可以在本地计算机上传输数据和推断
看来我可以用tf.saved_model.simple_save保存一些模型,但是,我不确定如何使用保存的数据进行推断以及如何使用给定新图像的数据进行预测。似乎有多种方法可以保存模型,但是我不确定使用Tensorflow框架进行哪种惯例或"正确方式"。
到目前为止,这是我认为我需要的行,但是不确定是否正确。
tf.saved_model.simple_save(sess, 'mnist_model',
inputs={'x': self.x},
outputs={'y_': self.y_, 'y_conv':self.y_conv})
如果有人可以指出如何正确保存经过训练的模型以及可以使用哪些变量来推断已保存的模型的方向,我将非常感激。
一种执行此操作的方法是在图形定义中创建一个tf.train.Saver()对象,然后使用该对象将网络保存到指定目录。 然后可以将该目录中的权重从远程计算机下载到本地,然后在本地还原。 这是一个小的示例网络:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# >>>> Config. Vars <<<<
TRAIN_STEPS = 1000
SAVE_EVERY = 100
# >>>> Network <<<<
inputs = tf.placeholder(tf.float32, shape=[None, 784])
labels = tf.placeholder(tf.float32, shape=[None, 10])
h1 = tf.layers.dense(inputs, 256, activation=tf.nn.relu, use_bias=True)
logits = tf.layers.dense(h1, 10, use_bias=True)
predictions = tf.nn.softmax(logits)
prediction_ids = tf.argmax(predictions, axis=1)
# >>>> Loss & Optimisation <<<<
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)
opt = tf.train.AdamOptimizer().minimize(loss)
# >>>> Utilities <<<<
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
# >>>> Training - run on remote, comment out locally <<<<
for i in range(TRAIN_STEPS):
print("Train step {}".format(i), end="
")
batch_data, batch_labels = mnist.train.next_batch(batch_size=128)
feed_dict = {
inputs: batch_data,
labels: batch_labels
}
l, _ = sess.run([loss, opt], feed_dict=feed_dict)
if i % SAVE_EVERY == 0:
saver.save(sess,"saved_model/network_weights.ckpt")
# >>>> Using the network - run locally to use the network <<<
saver.restore(sess,"saved_model/network_weights.ckpt")
test_data, test_labels = mnist.test.images, mnist.test.labels
feed_dict = {
inputs: test_data,
labels: test_labels
}
preds = sess.run(prediction_ids, feed_dict=feed_dict)
print(preds)
因此,一旦您在网络中定义了保护程序,就可以使用它来将权重保存到指定的目录-在这种情况下,将保存在目录" saved_models"中,在运行此特定代码之前需要先创建该目录。
还原模型就像调用saver.restore()然后将会话和权重存储路径传递给它一样简单。 因此,您可以在远程计算机上运行此代码,将" saved_models"目录下载到本地计算机,然后在注释掉训练部分以实际使用模型的情况下运行此代码。
哦,我知道了,我想我没有意识到还原图时需要重建模型并重新定义变量。 谢谢!