Keras模型转成tensorflow的.pb操作

 更新时间:2020年07月06日 15:33:10   作者:VickyD1023  
这篇文章主要介绍了Keras模型转成tensorflow的.pb操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

Keras的.h5模型转成tensorflow的.pb格式模型,方便后期的前端部署。直接上代码

from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenet import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
import tensorflow as tf
from keras import backend as K
import os
 
base_model = MobileNet((None, None, 3), alpha=1, include_top=False, pooling='avg', weights=None)
x = Dropout(0.75)(base_model.output)
x = Dense(10, activation='softmax')(x)
 
model = Model(base_model.input, x)
model.load_weights('mobilenet_weights.h5')
 
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
             output_names, freeze_var_names)
  return frozen_graph
 
output_graph_name = 'NIMA.pb'
output_fld = ''
#K.set_learning_phase(0)
 
print('input is :', model.input.name)
print ('output is:', model.output.name)
 
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name])
 
from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', os.path.join(output_fld, output_graph_name))

补充知识:keras h5 model 转换为tflite

在移动端的模型,若选择tensorflow或者keras最基本的就是生成tflite文件,以本文记录一次转换过程。

环境

tensorflow 1.12.0

python 3.6.5

h5 model saved by `model.save('tf.h5')`

直接转换

`tflite_convert --output_file=tf.tflite --keras_model_file=tf.h5`
output
`TypeError: __init__() missing 2 required positional arguments: 'filters' and 'kernel_size'`

先转成pb再转tflite

```

git clone git@github.com:amir-abdi/keras_to_tensorflow.git
cd keras_to_tensorflow
python keras_to_tensorflow.py --input_model=path/to/tf.h5 --output_model=path/to/tf.pb
tflite_convert \

 --output_file=tf.tflite \
 --graph_def_file=tf.pb \
 --input_arrays=convolution2d_1_input \
 --output_arrays=dense_3/BiasAdd \
 --input_shape=1,3,448,448
```

参数说明,input_arrays和output_arrays是model的起始输入变量名和结束变量名,input_shape是和input_arrays对应

官网是说需要用到tenorboard来查看,一个比较trick的方法

先执行上面的命令,会报convolution2d_1_input找不到,在堆栈里面有convert_saved_model.py文件,get_tensors_from_tensor_names()这个方法,添加`print(list(tensor_name_to_tensor))` 到 tensor_name_to_tensor 这个变量下面,再执行一遍,会打印出所有tensor的名字,再根据自己的模型很容易就能判断出实际的name。

以上这篇Keras模型转成tensorflow的.pb操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python子类如何继承父类的实例变量

    python子类如何继承父类的实例变量

    这篇文章主要介绍了python子类如何继承父类的实例变量,帮助大家更好的理解和学习使用python,感兴趣的朋友可以了解下
    2021-03-03
  • Python values()与itervalues()的用法详解

    Python values()与itervalues()的用法详解

    今天小编就为大家分享一篇Python values()与itervalues()的用法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11
  • 如何用python爬取微博热搜数据并保存

    如何用python爬取微博热搜数据并保存

    这篇文章主要介绍了如何用python爬取微博热搜数据并保存,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-02-02
  • pytorch中的numel函数用法说明

    pytorch中的numel函数用法说明

    这篇文章主要介绍了pytorch中的numel函数用法说明,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-05-05
  • 在pycharm中输入import torch报错如何解决

    在pycharm中输入import torch报错如何解决

    这篇文章主要介绍了在pycharm中输入import torch报错如何解决问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-01-01
  • 解决py2exe打包后,总是多显示一个DOS黑色窗口的问题

    解决py2exe打包后,总是多显示一个DOS黑色窗口的问题

    今天小编就为大家分享一篇解决py2exe打包后,总是多显示一个DOS黑色窗口的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06
  • Python第三方包PrettyTable安装及用法解析

    Python第三方包PrettyTable安装及用法解析

    这篇文章主要介绍了Python第三方包PrettyTable安装及用法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-07-07
  • Python中的 dir() 函数示例详解

    Python中的 dir() 函数示例详解

    dir()函数是Python 中一个非常有用的工具,可以用于查找对象的所有属性和方法,如获取当前作用域的变量和方法、查找模块中的导出内容、动态查找对象属性等,通过本文的介绍和示例代码,大家可以更全面地了解 dir() 函数的用法和注意事项,需要的朋友参考下吧
    2022-03-03
  • Python version 2.7 required, which was not found in the registry

    Python version 2.7 required, which was not found in the regi

    这篇文章主要介绍了安装PIL库时提示错误Python version 2.7 required, which was not found in the registry问题的解决方法,需要的朋友可以参考下
    2014-08-08
  • Python爬取视频时长场景实践示例

    Python爬取视频时长场景实践示例

    这篇文章主要为大家介绍了Python获取视频时长场景实践示例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-07-07

最新评论