最近花了两天时间终于走通的整个流程,记录一下
1. estimator模型增加placeholder,方便java预测。因为有embeding字段,所以需要特殊处理一下
print("exporting model ...")
inputs = {}
for feat in my_feature_columns:
atype = ""
aname = ""
if feat.name.endswith('_embedding'):
atype = tf.int64
aname = feat.name[0:-10]
else:
atype = feat.dtype
aname = feat.name
inputs[aname] = tf.placeholder(shape=[None], dtype=atype, name=aname)
serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(inputs)
classifier.export_saved_model(FLAGS.output_model, serving_input_receiver_fn)
2. 校验saved_model正确性
saved_model_cli show --dir 1628576905 --tag_set serve --signature_def "prediction"
其中prediction 的值取决于 export_outputs 的key, 1628576905 是训练出的模型目录 根据上面的结果用下面的命令feed参数
saved_model_cli run --dir 1628576905 --tag_set serve --signature_def "prediction" --input_expr 'airQuality=[1];' 多行的格式是[1,...]
3. 导出pd模型
freeze_graph --input_saved_model_dir=1628576905 --output_node_names=CTCVR --output_graph=model.pb
4. java预测
SavedModelBundle model = SavedModelBundle.load("path", "serve"); //使用saved_model模型格式 则不用走第三步
Tensor<?> airQuality = Tensor.create(new long[] { 1,1 });
Tensor<?> output = model
.session()
.runner()
.feed("airQuality", airQuality)
.fetch("CTCVR") // 该值取决步骤3的output的名称
.run()
.get(0);
float[][] finalRlt = new float[2][1];
output.copyTo(finalRlt);
System.out.println(JSON.toJSONString(finalRlt));
如果使用pd模型
Graph graph = new Graph();
byte[] graphBytes = IOUtils.toByteArray(new FileInputStream("a.pb"));
graph.importGraphDef(graphBytes);
Session sess = new Session(graph);
注意事项:
1 tf版本:1.15
2 tf的int64 对应java中的long型
3 feed的long型数据时 请勿使用Long类,这绝对是一个大坑,花了很长时间才爬出来