文章目录

一、CKPT 转换成 pb 格式

1 . checkpoint 文件介绍

使用 tf.train.saver() 保存模型时会产生多个文件,会把计算图的结构和图上参数取值分成了不同的文件存储。这种方法是在TensorFlow中是最常用的保存方式。

model.data-00000-of-00001 保存模型中每个变量的取值
model.index
model.meta 文件保存了TensorFlow计算图的网络结构,使用tf.train.import_meta_graph 加载到当前默认的图来使用。

2 . 转换流程

通过传入 CKPT 模型的路径得到模型的图和变量数据
通过 import_meta_graph 导入模型中的图
通过 saver.restore 从模型中恢复图中各个变量的数据
通过 graph_util.convert_variables_to_constants 将模型持久化

详情参考:​​使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)​

查看图的节点

# 默认图的所有节点名称
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]

# 图graph的所有节点名称
tensor_name_list = [tensor.name for tensor in graph.as_graph_def().node]

# 可训练的节点名称
variable_names = [v.name for v in tf.trainable_variables()]

若是有网络network架构文件,可以直接看网络结构,查看输入输出节点。

方式一:单一网络结构

checkpoint2pb.py

import tensorflow as tf
from tensorflow.python.framework import graph_util


def freeze_graph(input_checkpoint, output_graph):
'''
:param input_checkpoint:
:param output_graph:
:return:
'''
output_node_names = "level_16/active" # 输出的节点名称
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()

with tf.Session() as sess:
#===================================
#file = open('./nodes.txt', 'a+')
#tensor_name_list = [tensor.name for tensor in graph.as_graph_def().node]
#for n in tensor_name_list:
# file.write(n + '\n')
#file.close()
#===================================
saver.restore(sess, input_checkpoint)
output_graph_def = graph_util.convert_variables_to_constants(
sess=sess,
input_graph_def=input_graph_def,
output_node_names=output_node_names.split(","))

with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == '__main__':
input_checkpoint = '/home/xxx/net_101/model'
out_pb_path = "/home/xxx/net_101/net_101_frozen_model.pb"
freeze_graph(input_checkpoint,out_pb_path)

1、函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。因为网络其实是比较复杂的,定义了输出结点的名字,那么freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。

2、在保存的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称,对于鄙人的代码,需要固化的节点只有一个:output_node_names。注意节点名称与张量的名称的区别。

张量的名称:​​"input:0"​​​ 节点的名称:​​"output"​
3、源码中通过graph = tf.get_default_graph()获得默认的图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + ‘.meta’, clear_devices=True)恢复的图,因此必须先执行tf.train.import_meta_graph,再执行tf.get_default_graph() 。

查看上述代码保存的​​nodes.txt​​文件

global_step/initial_value
global_step
global_step/Assign
global_step/read
corr_alpha/initial_value
corr_alpha
corr_alpha/Assign
corr_alpha/read
corr_beta/initial_value
corr_beta
corr_beta/Assign
corr_beta/read
x_data # <===== 输入节点
learnin_rate
bnorm_decay
flag_train
level_0/conv/weights/Initializer/random_normal/shape
level_0/conv/weights/Initializer/random_normal/mean
level_0/conv/weights/Initializer/random_normal/stddev
...
level_16/conv/Conv2D
level_16/bias/beta/Initializer/Const
level_16/bias/beta
level_16/bias/beta/Assign
level_16/bias/beta/read
level_16/bias/add
level_16/active # <===== 输出节点
Tensordot/a
Tensordot/transpose/perm
Tensordot/transpose
Tensordot/Reshape/shape

方式二:全卷积网络结构(运行时网络)

该方式 加载运行时网络,​​saver = tf.train.Saver(net.variables_list)​​ ,确定需要的网络节点

import tensorflow as tf
from tensorflow.python.framework import graph_util
from network import FullConvNet




def freeze_graph(input_checkpoint, output_graph):
tf.reset_default_graph()
output_node_names = ["level_16/active"]

x_data = tf.placeholder(tf.float32, [1, None, None, 1], name="x_data")
net = FullConvNet(x_data, 0.9, tf.constant(False), num_levels=17)
saver = tf.train.Saver(net.variables_list)

with tf.Session(graph=tf.get_default_graph()) as sess:

file2 = open('./read_static_nodes.txt', 'w')
for xx in sess.graph.as_graph_def().node:
file2.write(xx.name+ '\n')
file2.close()

input_graph_def = sess.graph.as_graph_def()
saver.restore(sess, input_checkpoint)
output_graph_def = graph_util.convert_variables_to_constants(
sess=sess,
input_graph_def=input_graph_def,
output_node_names=output_node_names)

with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))
print("============================")
for op in sess.graph.get_operations():
print(op.name, op.values())


if __name__ == '__main__':
input_checkpoint = '/home/xxx/net_jpg{}/model'.format(name)
out_pb_path = "/home/xxx/net_jpg{0}_frozen_model.pb".format(name)
freeze_graph(input_checkpoint, out_pb_path)

方式三:(分段)多网络结构

import tensorflow as tf
from tensorflow.python.framework import graph_util
from Network import MISLNet,CompareNet


"""
第一次输入网络
input_data (<tf.Tensor 'input_data:0' shape=(?, 256, 256, 3) dtype=float32>,)
phase (<tf.Tensor 'phase:0' shape=<unknown> dtype=bool>,)
输出节点:MISLNet/dense2_out

第二次输入网络
feature1 (<tf.Tensor 'feature1:0' shape=(?, 200) dtype=float32>,)
feature2 (<tf.Tensor 'feature2:0' shape=(?, 200) dtype=float32>,)
输出节点:CompareNet/add_3
"""

def freeze_graph(input_checkpoint, output_graph):
tf.reset_default_graph()
output_node_names = ["MISLNet/dense2_out","CompareNet/add_3"]
saver = tf.train.Saver()
with tf.Session(graph=tf.get_default_graph()) as sess:

file2 = open('./read_static_nodes.txt', 'w')
for xx in sess.graph.as_graph_def().node:
file2.write(xx.name+ '\n')
file2.close()

input_graph_def = sess.graph.as_graph_def()
saver.restore(sess, input_checkpoint)
output_graph_def = graph_util.convert_variables_to_constants(
sess=sess,
input_graph_def=input_graph_def,
output_node_names=output_node_names)

with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))
print("============================")
for op in sess.graph.get_operations():
print(op.name, op.values())



if __name__ == '__main__':
input_checkpoint = "/home/xxx/models/test"
out_pb_path = "/home/xxx/models/freeze_model.pb"
print(input_checkpoint)
freeze_graph(input_checkpoint, out_pb_path)

二、bp2onnx

官方示例:​​https://github.com/onnx/tutorials/blob/master/tutorials/TensorflowToOnnx-1.ipynb​​​​https://github.com/onnx/tensorflow-onnx​

单输入输出

python -m tf2onnx.convert \
--input /home/xxx/net_101_frozen_model.pb \
--inputs x_data:0 \
--outputs level_16/active:0 \
--output /home/xxx/net_101_frozen_model.onnx \

多输入输出(指定节点名称:节点id)

python -m tf2onnx.convert \
--input ./image_server/models/freeze_model.pb \
--inputs input_data:0,phase:0,feature1:0,feature2:0 \
--outputs MISLNet/dense2_out:0,CompareNet/add_3:0 \
--output ./image_server/models/freeze_model.onnx \
--verbose \
--opset 12

检查网络(检查IR是否形成良好)查看网络:​​https://netron.app/​

import onnx


model = onnx.load("./onnx/entry_202008120940_s3_79779.onnx")
onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))

三、h5 转 onnx 推理

使用onnxruntime 与 opencv 进行推理

# pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tf2onnx
# pip install -i https://pypi.tuna.tsinghua.edu.cn/simple onnxmltools
# pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorflow>=2.2.0

import onnxmltools
import onnx
import numpy as np
import onnxruntime
import cv2


def h5_to_onnx(input_h5, output_onnx):
model = load_model(input_h5)
onnx_model = onnxmltools.convert_keras(model, model.name)
onnx.save_model(onnx_model, output_onnx)


def onnx_autoIAFIS(output_onnx):
x_input = np.load(r"../ocr_server/weight/matrix_autoIAFIS.npy")
print("x_input", x_input.shape)
tensorflow_icbc = onnxruntime.InferenceSession(output_onnx)
ort_inputs = {tensorflow_icbc.get_inputs()[0].name: x_input.astype(np.float32)}
ort_outs = tensorflow_icbc.run(None, ort_inputs)
print("ort_outs", type(ort_outs), len(ort_outs), ort_outs[0].shape)
test_y = np.argmax(np.array(ort_outs), axis=2)[:, 0]
characters = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
result = ''.join([characters[x] for x in test_y])
print(result)


def cv_autoIAFIS(output_onnx):
img = cv2.imread("../test_image/9LN8_autoAIFIS.jpg")
imgG = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
_, th = cv2.threshold(imgG, 150, 255, cv2.THRESH_BINARY)
dst = np.expand_dims(th, axis=-1)
dst = np.expand_dims(dst, axis=0)
blob = dst / 255.0
print(blob.shape) # NHWC

net = cv2.dnn.readNetFromONNX(output_onnx)
net.setInput(blob)
result_out = net.forward(net.getUnconnectedOutLayersNames())
test_y = np.argmax(np.array(result_out), axis=2)[:, 0][::-1]
characters = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
result = ''.join([characters[x] for x in test_y])
print(result)



if __name__ == '__main__':
input_h5 = '../ocr_server/weight/tensorflow_autoIAFIS.h5'
output_onnx = '../ocr_server/weight/tensorflow_autoIAFIS.onnx'

h5_to_onnx(input_h5, output_onnx)
onnx_autoIAFIS(output_onnx)
cv_autoIAFIS(output_onnx)