就是通过如下几个函数实现的tensorflow模型保存的模型,是ckpt格式的模型。
saver = tf.train.Saver() ... saver.save(sess, saveFile)就可以保存出如下文件:
checkpoint model-450.data-00000-of-00001 model-450.index model-450.meta具体说明: checkpoint Checkpoint 文件会记录最近一次的断点文件(Checkpoint File) 的前缀,根据前缀可以找对对应的索引和数据文件。当调用tf.train.latest_checkpoint,可以快速找到最近一次的断点文件。 ckp.data-00000-of-00001 数据(data) 文件记录了所有变量(Variable) 的值。当restore 某个变量时,首先从索引文件中找到相应变量在哪个数据文件,然后根据索引直接获取变量的值,从而实现变量数据的恢复。 ckp.index 索引(index)文件,保存了一个不可变表的数据。其中,关键字为Tensor 的名称,其值描述该Tensor 的元数据信息,包括该Tensor 存储在哪个数据(data) 文件中,及其在该数据文件中的偏移,及其校验和等信息。 ckp.meta 元文件(meta) 中保存了MetaGraphDef 的持久化数据,即模型数据,它包括GraphDef, SaverDef 等元数据。
通俗的讲,这种模型是全部保存的,即模型的框架,参数及其他的信息。这种模型重载之后,是可以继续训练的,即可以pre-train或fine-tune。 个人理解,这种形式适合模型迭代需要,但不会应用于生产或者应用。
这种模型在加载的时候直接使用**saver.restore(sess, ckpt_file)**就可以了。具体的代码网上很多,这里就不赘述了。
这种保存形式的常规代码形式如下:
builder = tf.saved_model.builder.SavedModelBuilder("./model") signature = predict_signature_def(inputs={'myInput': x}, outputs={'myOutput': y}) builder.add_meta_graph_and_variables(sess=sess, tags=[tag_constants.SERVING], signature_def_map={'predict'}) builder.save()简单的保存形式如下:
tf.saved_model.simple_save(sess, "./model", inputs={"myInput": x}, outputs={"myOutput": y})具体代码大家可以网上细研究哈。 这种形式保存的文件是什么样的呢? 类似这个样子:
variables/ variables.data-*****-of-***** variables.index model.pb其中:model.pb是二进制模型文件,也就是图,variables路径下的是变量参数等。 这种模型的加载方式,大概如下:
with tf.Session(graph=tf.Graph()) as sess: tf.saved_model.loader.load(sess, ["test"], "./model") graph = tf.get_default_graph() input = ... x = sess.graph.get_tensor_by_name('input:0') y = sess.graph.get_tensor_by_name('output:0') result = sess.run(y, feed_dict={x: input})具体代码和使用大家可以看tensorflow的手册或者源码。
这种模型适合怎么应用?如果是部署在线服务(Serving)时,官方推荐使用 SavedModel 格式。
具体保存代码形式:
frozen_graph_def = tf.graph_util.convert_variables_to_constants( sess, sess.graph_def, output_node_names=["predict"]) with open('./model2/model_' + self.timestamp + '.pb', 'wb') as f: f.write(graph_def.SerializeToString())或者
frozen_graph_def = tf.graph_util.convert_variables_to_constants( sess, sess.graph_def, output_node_names=["predict"]) with tf.gfile.FastGFile('./model2/_model_' + self.timestamp + '.pb', mode='wb') as f: f.write(graph_def.SerializeToString())二者差不多哈,就是保存时候有点差异。这种方式保存的模型文件是什么?
model.pbpython上的模型加载形式如下:
output_graph_path = './model/model_1572338162.pb' with tf.Session() as sess: with gfile.FastGFile(output_graph_path, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def, name='') sess.run(tf.global_variables_initializer()) input_x = sess.graph.get_tensor_by_name("word_ids:0") sequence_lengths = sess.graph.get_tensor_by_name("sequence_lengths:0") dropout = sess.graph.get_tensor_by_name("dropout:0") output = sess.graph.get_tensor_by_name("proj/predict:0") logit = sess.run(output, feed_dict={input_x: sent_, sequence_lengths:[len(sent_[0])], dropout:[1]})对的,就这一个文件。这种方式保存的模型是序列化的模型,二进制文件。这个模型只保留必要的从输入到输出的一条路径的图,其他的不需要都不会保存,所以这个模型不能pre_train了。 但是应用场景,比较适合手机端等app调用,模型比较小。当然如果还想更小?就需要研究模型压缩方法了。此处不讨论。
以上python上的演示代码都是从不同项目上粘贴来的,不一致,大家可以自己体会哈。摘要性的介绍完毕。
下一篇介绍一下,java调用tensorflow模型进行使用,及在分布式上调用的问题。