最近花了两天时间终于走通的整个流程,记录一下

1. estimator模型增加placeholder,方便java预测。因为有embeding字段,所以需要特殊处理一下
    print("exporting model ...")
    inputs = {}
    for feat in my_feature_columns:
        atype = ""
        aname = ""
        if feat.name.endswith('_embedding'):
            atype = tf.int64
            aname = feat.name[0:-10]
        else:
            atype = feat.dtype
            aname = feat.name
        inputs[aname] = tf.placeholder(shape=[None], dtype=atype, name=aname)
    serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(inputs)
    classifier.export_saved_model(FLAGS.output_model, serving_input_receiver_fn)

 

2. 校验saved_model正确性
saved_model_cli show --dir 1628576905 --tag_set serve --signature_def "prediction"   
其中prediction 的值取决于 export_outputs 的key,   1628576905  是训练出的模型目录 根据上面的结果用下面的命令feed参数
saved_model_cli run --dir 1628576905 --tag_set serve --signature_def "prediction" --input_expr 'airQuality=[1];'  多行的格式是[1,...]

3. 导出pd模型
freeze_graph  --input_saved_model_dir=1628576905  --output_node_names=CTCVR  --output_graph=model.pb

4. java预测
SavedModelBundle model = SavedModelBundle.load("path", "serve"); //使用saved_model模型格式 则不用走第三步
Tensor<?> airQuality = Tensor.create(new long[] { 1,1 });
Tensor<?> output = model
                        .session()
                        .runner() 
                        .feed("airQuality", airQuality)
                        .fetch("CTCVR") // 该值取决步骤3的output的名称
                        .run()
                        .get(0);
        
float[][] finalRlt = new float[2][1];
output.copyTo(finalRlt);
System.out.println(JSON.toJSONString(finalRlt));

 

如果使用pd模型

        Graph graph = new Graph();
        byte[] graphBytes = IOUtils.toByteArray(new FileInputStream("a.pb"));
        graph.importGraphDef(graphBytes);
        Session  sess = new Session(graph);

 

注意事项:

1 tf版本:1.15

2 tf的int64 对应java中的long型

3 feed的long型数据时 请勿使用Long类,这绝对是一个大坑,花了很长时间才爬出来