TensorFlow模型的保存及模型的应用

mac2024-03-25  29

tensorflow的模型保存形式?

1.ckpt格式

就是通过如下几个函数实现的tensorflow模型保存的模型,是ckpt格式的模型。

saver = tf.train.Saver() ... saver.save(sess, saveFile)

就可以保存出如下文件:

checkpoint model-450.data-00000-of-00001 model-450.index model-450.meta

具体说明: checkpoint Checkpoint 文件会记录最近一次的断点文件(Checkpoint File) 的前缀,根据前缀可以找对对应的索引和数据文件。当调用tf.train.latest_checkpoint,可以快速找到最近一次的断点文件。 ckp.data-00000-of-00001 数据(data) 文件记录了所有变量(Variable) 的值。当restore 某个变量时,首先从索引文件中找到相应变量在哪个数据文件,然后根据索引直接获取变量的值,从而实现变量数据的恢复。 ckp.index 索引(index)文件,保存了一个不可变表的数据。其中,关键字为Tensor 的名称,其值描述该Tensor 的元数据信息,包括该Tensor 存储在哪个数据(data) 文件中,及其在该数据文件中的偏移,及其校验和等信息。 ckp.meta 元文件(meta) 中保存了MetaGraphDef 的持久化数据,即模型数据,它包括GraphDef, SaverDef 等元数据。

通俗的讲,这种模型是全部保存的,即模型的框架,参数及其他的信息。这种模型重载之后,是可以继续训练的,即可以pre-train或fine-tune。 个人理解,这种形式适合模型迭代需要,但不会应用于生产或者应用。

这种模型在加载的时候直接使用**saver.restore(sess, ckpt_file)**就可以了。具体的代码网上很多,这里就不赘述了。

2.SavedModel 格式

这种保存形式的常规代码形式如下:

builder = tf.saved_model.builder.SavedModelBuilder("./model") signature = predict_signature_def(inputs={'myInput': x}, outputs={'myOutput': y}) builder.add_meta_graph_and_variables(sess=sess, tags=[tag_constants.SERVING], signature_def_map={'predict'}) builder.save()

简单的保存形式如下:

tf.saved_model.simple_save(sess, "./model", inputs={"myInput": x}, outputs={"myOutput": y})

具体代码大家可以网上细研究哈。 这种形式保存的文件是什么样的呢? 类似这个样子:

variables/ variables.data-*****-of-***** variables.index model.pb

其中:model.pb是二进制模型文件,也就是图,variables路径下的是变量参数等。 这种模型的加载方式,大概如下:

with tf.Session(graph=tf.Graph()) as sess: tf.saved_model.loader.load(sess, ["test"], "./model") graph = tf.get_default_graph() input = ... x = sess.graph.get_tensor_by_name('input:0') y = sess.graph.get_tensor_by_name('output:0') result = sess.run(y, feed_dict={x: input})

具体代码和使用大家可以看tensorflow的手册或者源码。

这种模型适合怎么应用?如果是部署在线服务(Serving)时,官方推荐使用 SavedModel 格式。

3.FrozenGraphDef 格式

具体保存代码形式:

frozen_graph_def = tf.graph_util.convert_variables_to_constants( sess, sess.graph_def, output_node_names=["predict"]) with open('./model2/model_' + self.timestamp + '.pb', 'wb') as f: f.write(graph_def.SerializeToString())

或者

frozen_graph_def = tf.graph_util.convert_variables_to_constants( sess, sess.graph_def, output_node_names=["predict"]) with tf.gfile.FastGFile('./model2/_model_' + self.timestamp + '.pb', mode='wb') as f: f.write(graph_def.SerializeToString())

二者差不多哈,就是保存时候有点差异。这种方式保存的模型文件是什么?

model.pb

python上的模型加载形式如下:

output_graph_path = './model/model_1572338162.pb' with tf.Session() as sess: with gfile.FastGFile(output_graph_path, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def, name='') sess.run(tf.global_variables_initializer()) input_x = sess.graph.get_tensor_by_name("word_ids:0") sequence_lengths = sess.graph.get_tensor_by_name("sequence_lengths:0") dropout = sess.graph.get_tensor_by_name("dropout:0") output = sess.graph.get_tensor_by_name("proj/predict:0") logit = sess.run(output, feed_dict={input_x: sent_, sequence_lengths:[len(sent_[0])], dropout:[1]})

对的,就这一个文件。这种方式保存的模型是序列化的模型,二进制文件。这个模型只保留必要的从输入到输出的一条路径的图,其他的不需要都不会保存,所以这个模型不能pre_train了。 但是应用场景,比较适合手机端等app调用,模型比较小。当然如果还想更小?就需要研究模型压缩方法了。此处不讨论。

以上python上的演示代码都是从不同项目上粘贴来的,不一致,大家可以自己体会哈。摘要性的介绍完毕。

下一篇介绍一下,java调用tensorflow模型进行使用,及在分布式上调用的问题。

最新回复(0)