keras实现theano和tensorflow训练的模型相互转换

 更新时间:2020年06月19日 11:50:22   作者:零落_World  
这篇文章主要介绍了keras实现theano和tensorflow训练的模型相互转换,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

我就废话不多说了,大家还是直接看代码吧~

</pre><pre code_snippet_id="1947416" snippet_file_name="blog_20161025_1_3331239" name="code" class="python">

# coding:utf-8
"""
If you want to load pre-trained weights that include convolutions (layers Convolution2D or Convolution1D),
be mindful of this: Theano and TensorFlow implement convolution in different ways (TensorFlow actually implements correlation, much like Caffe),
and thus, convolution kernels trained with Theano (resp. TensorFlow) need to be converted before being with TensorFlow (resp. Theano).
"""
from keras import backend as K
from keras.utils.np_utils import convert_kernel
from text_classifier import keras_text_classifier
import sys
 
def th2tf( model):
  import tensorflow as tf
  ops = []
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      ops.append(tf.assign(layer.W, converted_w).op)
  K.get_session().run(ops)
  return model
 
def tf2th(model):
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      K.set_value(layer.W, converted_w)
  return model
 
def conv_layer_converted(tf_weights, th_weights, m = 0):
  """
  :param tf_weights:
  :param th_weights:
  :param m: 0-tf2th, 1-th2tf
  :return:
  """
  if m == 0: # tf2th
    tc = keras_text_classifier(weights_path=tf_weights)
    model = tc.loadmodel()
    model = tf2th(model)
    model.save_weights(th_weights)
  elif m == 1: # th2tf
    tc = keras_text_classifier(weights_path=th_weights)
    model = tc.loadmodel()
    model = th2tf(model)
    model.save_weights(tf_weights)
  else:
    print("0-tf2th, 1-th2tf")
    return
if __name__ == '__main__':
  if len(sys.argv) < 4:
    print("python tf_weights th_weights <0|1>\n0-tensorflow to theano\n1-theano to tensorflow")
    sys.exit(0)
  tf_weights = sys.argv[1]
  th_weights = sys.argv[2]
  m = int(sys.argv[3])
  conv_layer_converted(tf_weights, th_weights, m)

补充知识:keras学习之修改底层为TensorFlow还是theano

我们知道,keras的底层是TensorFlow或者theano

要知道我们是用的哪个为底层,只需要import keras即可显示

修改方法:

打开

修改

以上这篇keras实现theano和tensorflow训练的模型相互转换就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python3 sleep 延时秒 毫秒实例

    python3 sleep 延时秒 毫秒实例

    这篇文章主要介绍了python3 sleep 延时秒 毫秒实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • 详解Python的三种可变参数

    详解Python的三种可变参数

    这篇文章主要介绍了Python的三种可变参数,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-05-05
  • python实现对指定字符串补足固定长度倍数截断输出的方法

    python实现对指定字符串补足固定长度倍数截断输出的方法

    今天小编就为大家分享一篇python实现对指定字符串补足固定长度倍数截断输出的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11
  • Queue队列中join()与task_done()的关系及说明

    Queue队列中join()与task_done()的关系及说明

    这篇文章主要介绍了Queue队列中join()与task_done()的关系及说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-02-02
  • Python抓取通过Ajax加载数据的示例

    Python抓取通过Ajax加载数据的示例

    在网页上,有一些内容是通过执行Ajax请求动态加载数据渲染出来的,本文主要介绍了使用Python抓取通过Ajax加载数据,感兴趣的可以了解一下
    2023-05-05
  • 通过代码实例解析Pytest运行流程

    通过代码实例解析Pytest运行流程

    这篇文章主要介绍了通过代码实例解析Pytest运行流程,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-08-08
  • python实现大战外星人小游戏实例代码

    python实现大战外星人小游戏实例代码

    这篇文章主要介绍了python实现大战外星人小游戏,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-12-12
  • python正则实现计算器功能

    python正则实现计算器功能

    这篇文章主要为大家详细介绍了python正则实现计算器功能,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-12-12
  • 利用pandas读取中文数据集的方法

    利用pandas读取中文数据集的方法

    今天小编就为大家分享一篇利用pandas读取中文数据集的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • numpy.reshape()的函数的具体使用

    numpy.reshape()的函数的具体使用

    本文主要介绍了numpy.reshape()的函数的具体使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02

最新评论