本文参考于博客,但是增加了自己的一些见解和修改,主要实现的是tensorflow保存模型、加载模型、修改模型、保存修改后的模型、使用修改后的模型做推理、模型转pb和使用pb做推理。
1.创建和保存模型:

import tensorflow as tf

w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")

# 定义一个op,用于后面恢复
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # 创建一个Saver对象,用于保存所有变量
    saver = tf.train.Saver()

    # 通过传入数据,执行op
    print(sess.run(w4, feed_dict={w1: 4, w2: 8}))
    # 打印 24.0 ==>(w1+w2)*b1

    # 现在保存模型
    saver.save(sess, './checkpoint_dir1/MyModel', global_step=1000)

输出结果:

24.0

2.加载模型并增加自己的op,然后保存修改后模型(注意看程序注解)

import tensorflow as tf
with tf.Session() as sess:
    # 先加载图和参数变量
    saver = tf.train.import_meta_graph('./checkpoint_dir1/MyModel-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./checkpoint_dir1'))
    # 访问placeholders变量,并且创建feed-dict来作为placeholders的新值
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")

    # 接下来,访问你想要执行的op
    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

    # 在当前图中能够加入op
    add_on_op = tf.multiply(op_to_restore, 2,name='add_on_op')
    
    #我自己为了验证增加的操作,没有指定名字的变量,tensorflow会自动生成名字,如下add_on_op_10和add_on_op_10_10会自动生成"add_on_op_10/add:0"和"add_on_op_10/add_1:0"
    with tf.name_scope('add_on_op_10'):
        add_on_op_10 = add_on_op + 10
        add_on_op_10_10 = add_on_op_10 + 10
    print("add_on_op_10_10",add_on_op_10_10)      #可以打印变量,获得后面执行该op时的名字  打印输出:add_on_op_10_10 Tensor("add_on_op_10/add_1:0", dtype=float32),其中"add_on_op_10/add_1:0"就是op的名字

    sess.run(tf.initialize_all_variables())
    saver.save(sess,'./checkpoint_dir2/MyModel', global_step=2000)

    print(sess.run(add_on_op_10_10,feed_dict={w1:13.0,w2:17.0}))

输出结果:

140.0

3.使用修改后的模型做推理

import tensorflow as tf
with tf.Session() as sess:
    #先加载图和参数变量
    saver = tf.train.import_meta_graph('./checkpoint_dir2/MyModel-2000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir2'))

    # 访问placeholders变量
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")

    # 接下来,访问你想要执行的op
    add_on_op = graph.get_tensor_by_name("add_on_op:0")
    add_on_op_10_10 = graph.get_tensor_by_name("add_on_op_10/add_1:0")
    
    #运行图,获得结果
    print(sess.run(add_on_op_10_10, feed_dict={w1: 13.0, w2: 17.0}))

输出结果:

140.0

注意:在访问要执行的op时,要先知道要执行的那个op中的名字,可以通过2获取,然后指定名字输出,如果这句add_on_op_10_10 = graph.get_tensor_by_name("add_on_op_10/add_1:0")"add_on_op_10/add_1:0"改为"add_on_op_10/add:0",结果将会输出130.0,因为它输出的是2中add_on_op_10 = add_on_op + 10计算的结果,所以说对节点名字的获知很重要

4.模型转pb

import tensorflow as tf

with tf.Session() as sess:
    #下面3行是加载之前保存的模型
    saver = tf.train.import_meta_graph('./checkpoint_dir2/MyModel-2000.meta')
    graph = tf.get_default_graph()
    saver.restore(sess,"./checkpoint_dir2/MyModel-2000")

    #下面是模型固化的过程,需要给出输入('w1', 'w2')和输出('add_on_op_10/add_1')名称,否则在使用时将无法
    #获得输出(测试发现,输入可以不指定名字,在使用时也可以获取,但是输出必须要指定)
    output_name = ['w1', 'w2', 'add_on_op_10/add_1']
    output_graph = './mymodel.pb'
    out_graph_def = tf.graph_util.convert_variables_to_constants(sess,
                                                                 input_graph_def=sess.graph.as_graph_def(),
                                                                 output_node_names=output_name)
    with tf.gfile.GFile(output_graph,"wb") as f:
        f.write(out_graph_def.SerializeToString())

5.使用pb模型

import tensorflow as tf
frozen_graph = './mymodel.pb'
#返回的节点名字,跟固化的时候相对应
return_elements = ['w1','w2','add_on_op_10/add_1']
#下面5行就是加载pb模型
with tf.gfile.GFile(frozen_graph,'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
    return_elements = tf.import_graph_def(graph_def, return_elements=return_elements)
print(return_elements)

#获取输入节点
#说明:不知道为什么载入的时候前面多了import/,可以通过上面一句打印查看它的名字:
# [<tf.Operation 'import/w1' type=Placeholder>, <tf.Operation 'import/w2' type=Placeholder>, <tf.Operation 'import/add_on_op_10/add_1' type=Add>],
# 在引用该节点时后面需要加:0
w1 = graph.get_tensor_by_name("import/w1:0")
w2 = graph.get_tensor_by_name("import/w2:0")
#获取输出节点
add_on_op = graph.get_tensor_by_name("import/add_on_op_10/add_1:0")

#新建sess,在sess中进行run操作,注意tf.Session(graph=graph)中必须要有graph=graph这个参数,否则报错。
with tf.Session(graph=graph) as sess:
    print(sess.run(add_on_op, feed_dict={w1: 13.0, w2: 17.0}))

输出结果:

140.0