主要记录以下输入、输出参数处理过程,其他初始化百度资料很多。

背景

项目中用到鉴黄识别,从Github上找到了别人训练好的pb模型,项目地址: https://github.com/kingroc711/CVSample/tree/master/TensorFlow/inception_model

但是项目中只提供了python代码,首先对python不熟悉,并且发现tensorflow提供了对java预测模型的支持,并且项目使用的是java,所以想把tensorflow 集成到项目中,调用pb模型预测。

但通过tensorboard工具查看模型时发现输入参数为string,虽然可以跑通,但到现在也不理解入参为什么设计成string类型.

 

pb文件参数(output_graph.pb)

在调用模型之前,需要先清楚模型输入、输出参数类型。

输入名称:DecodeJpeg/contents:0    类型: string,实际传入图片文件原始数据就可以
输出名称:final_result:0        类型:  float

这个文件的输入、输出参数类型,通过CVSample项目库中python调用代码,找到输入、输出名称

也可以先用python生成日志,通过tensorboard工具分析日志,拿到模型输入、输出参数

 

推荐参考示例(LabelImage):

tensorflow 官方有一个labelImg的java示例,如果第一次使用tensorflow java api,应该会对你有用: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java

如果想运行这个示例,下载示例中提到的模型: https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip

这个示例中,输入、输出参数和鉴黄识别模型参数不太一样,所以也会有一些区别。

在这个示例中,对图片进行了一些图像预处理。 

图像是否需要预处理,需要看模型,有些模型需要,有些不需要(比如这个鉴黄模型)。

 

精简代码:

tensorflow: 1.15.0

<dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.15.0</version>
        </dependency>
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
public static void main(String[] args) throws IOException {
        try (Graph g = new Graph()) {
            //pb 模型文件
            byte modelBytes[] = Files.readAllBytes(new File("/opt/work/java_work/tensorflow_demo/inception_model/output_graph.pb").toPath());
            g.importGraphDef(modelBytes);
            try (Session s = new Session(g)) {
                //生成输入参数,此处生成从 https://github.com/tensorflow/tensorflow/blob/v1.15.0/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java 中找到的方法
                Tensor<String> tensor = (Tensor<String>) Tensor.create(Files.readAllBytes(Paths.get("/root/test.png")));
                Tensor<Float> result = s.runner()
                        //输入参数
                        .feed("DecodeJpeg/contents:0", tensor)
                        //输出参数
                        .fetch("final_result:0")
                        .run()
                        .get(0)
                        .expect(Float.class);
                //存储结果容器, 输出固定有5条数据,分别是每个分类(0:porn 1:neutral 2:hentai 3:drawings 4:sexy)的分数
                float[][] values = new float[1][5];
                result.copyTo(values);
                System.out.println(Arrays.toString(values[0]));
                //结果[0.027002065, 0.8941082, 0.02338332, 0.044249564, 0.011256761]
                //porn(色情): 0.027002065, neutral(正常): 0.8941082, hentai: 0.02338332, drawings: 0.044249564, sexy(性感): 0.011256761
            }
        }
    }

 

 

前前后后为了生成输入参数查了一周,网上资料是真的少,为了有相同问题的人可以快速解决,避免和我类似情况出现,所以此处记录以下。