目前,Tensorflow的Java版本支持Windows、Mac OS、Linux、Android这几个操作系统。本次主要以Windows操作系统为列来介绍。**
在Windows操作系统中,如果要在Java语言中调用TensorFlow的模型,需要到TensorFlow官网的安装页面中下载一个TensorFlow的工具类包libtensorflow-1.5.0.jar,还有一个包含JNI接口的动态链接库文件压缩包libtensorflow_jni-cpu-windows-x86_64-1.5.0.zip,该压缩包展开后会得到TensorFlow_jni.dll动态链接库文件、注意,文件名中的版本号部分可能随着TensorFlow的升级而有所变化,在使用java程序调用神经网络模型的时候,这里文件都会用到。
下面是调用代码 保存的模型文件来进行预测的示例代码。

TestTF.java
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.SaveModelBundle;
import java.nio.FloatBuffer;
import java.util.Arrays;

public class TestTf {

public static void main(String[] args) {
SaveModelBudle smb =SaveModelBundle.load("export", "tag");
Session s=smb.session();

float[][] matrix={{1.0F,2.0F,3.0F,4.0F}};
System.out.println(Arrays.deepToString(matrix));

Tensor xFeed=Tensor.create(matrix);
Tensor result=s.runner.feed("x",xFeed).fetch("y").run().get(0);
FloatBuffer buf =FloatBuffer.allocate(2);
result.writeTo(buf);
System.out.println(result.toString());
System.out.println(buf.get(0));
System.out.println(buf.get(1));
   }
}

主要说明的是一下几点:

  • 模型的载入是通过SaveModelBundle.load(“export”,“tag”);这句语句来实现的。其中第一个参数指定了读取模型的位置,我们需要把代码生成的export文件夹复制到运行Java程序的目录下;第二个参数要与保存模型时指定的标记名一致才能正确读取。
  • 在Java中也要创建一个会话对象,程序中是用Session s=smb.session();这条语句来实现的。
  • matrix变量是我们准备进行预测的数据,程序中float[][]{{1.0F,2.0F,3.0F,4.0F}};是代表一个二维数组,相当于Python中的[[1,2,3,4]],数字后面加“F”在Java中表示该数字是浮点数。
  • Tensor xFeed=Tensor.create(matrix);这条语句调用了Tensor对象的create函数来根据matrix生成一个Tensor类型的变量xFeed,准备作为输入数据。只有Tensor类型的变量才能作为神经网络的输入。
  • Tensor result=s.runner().feed(“x”,xFeed).fetch(“y”).run().get(0);这条语句是调节网络进行计算的最主要函数,其中feed函数中对命名过的张量x用xFeed作为输入数据“喂”了进去;fetch函数则把命名过的张量y取出来。
  • 由于该神经网络的输出是一个浮点数类型的二维数组,我们还需要把它写到一个浮点数缓冲区内,这是由FloatBuffer buf =FloatBuffer.allocate(2);和result.writeTo(buf);这两条语句实现的。
  • 最后,输出result.toString()可以看出神经网络输出张量y的类型,输出buf.get(0)和buf.get(1)可以看出y中两个浮点数的计算值。