我们常用python进行tensorflow深度模型训练,然后,训练后的模型需要应用到web端调用,或者app应用调用,甚至分布式任务使用等等。在这些应用中java代码调用,是避免不了的。本文就介绍一下,java加载tensorflow模型的方式,分别单机调用和分布式调用。
首先需要加载一些jar包,如果是maven项目在pom.xml中添加以下依赖。至于版本信息,切记一定要和你python代码训练的时候的tensorflow版本一致。
<dependency> <groupId>org.tensorflow</groupId> <artifactId>tensorflow</artifactId> <version>0.13</version> </dependency>单价加载模型的代码简要如下,做了简化哈,但是核心的都在这里。offLineInit函数就是单机或者叫离线的模型加载方式。这里的模型一般是pb模型。可以是序列化的也可以是savemodel形式的。
import org.tensorflow.Session; import org.tensorflow.Graph; import org.tensorflow.Tensor; import java.io.File; import java.io.BufferedInputStream; import java.io.FileInputStream; public class SplitByDeepModel implements Serializable{ public Session session; public boolean offLineInit(String path) throws Exception{ /* * 单机加载tensorflow模型,初始化环境 */ if (!this.tensorflow_factory.InitModelsGeneratorByBasePath(path)) throw new UDFArgumentException("tensorflow model Generator Init failed"); byte[] graphDef = readAllBytes(path+"/models/model_test1.pb"); Graph graph = new Graph(); graph.importGraphDef(graphDef); session = new Session(graph); return true; } }分布式加载方式如下,代码如下:
public boolean InitModelByBasePath(String tensorflow_model_path){ try { // model label modelTag = "Mytest"; modelPath = tensorflow_model_path; System.err.println("tensorflow models Path:" + modelPath); Configuration conf = new Configuration(); Graph graph = new Graph(); FileSystem fs = FileSystem.get(conf); ByteArrayOutputStream contents = new ByteArrayOutputStream(); IOUtils.copy(fs.open(new Path(modelPath)), contents); graph.importGraphDef(contents.toByteArray()); this.session = new Session(graph); }catch(Exception e) { System.err.println("Init tensorflow model from " + modelPath + " error: " + e.toString()); e.printStackTrace(); } return false; }其中,tensorflow_model_path,hdfs模型路径。这种方式需要将模型文件加载到hdfs系统中,然后应用在每个executor里面进行加载。
3.分布式加载和单机加载的差异
差异主要在文件读取的方式上,其他都是一样的,都是sess加载模型图,然后初始化。 这里有一点需要强调一下。无论是单机模式还是分布式模式加载,都是使用的pb模型文件,加载之后相当于一个函数。但是这个函数是非序列化的,不能跨节点传参。这要注意。。
4.模型加载后使用 代码如下:
public float[][] lstmPredict(int[][] input) { Tensor input_X = Tensor.create(input); int[] seqs = new int[1]; seqs[0] = input[0].length; Tensor seq_len = Tensor.create(seqs); Tensor dropout = Tensor.create(1.0f); Tensor out= this.session.runner().feed("word_ids",input_X).feed("sequence_lengths", seq_len). feed("dropout", dropout).fetch("proj/predict").run().get(0); float [][][] ans = new float[1][input[0].length][3]; out.copyTo(ans); // tensor对象必须close对象,来清理内存,因为,tensorflow的底层代码都是c++写的 // c++的内存释放不受java管理,所以必须主动释放,否做很容易出现内存溢出问题 input_X.close(); out.close(); seq_len.close(); dropout.close(); return ans[0]; }其中Tensor变量都是,预测时候必须输入的变量以及输出的变量。 这个代码是一个lstm的case。其中:seq_len,dropout,input_X都是必须的输出;out是输出形式。
这里特别强调下,最后的close部分,一定要每次预测之后,所有参与计算的tensor都要close。 为什么?因为tensorflow的底层代码都是c++写的,它的内存管理不受java控制,所以必须手工close,否则很容易出现内存溢出的问题,如果是分布式跑,很可能服务器都宕机了。
我在做这件事的时候,也是经历了各种趟坑,然后在一位大神的指导下,实现了分布式跑预测模型。这也是小白必须经历的过程。