解决Keras 中加入lambda层无法正常载入模型问题

 更新时间:2020年06月16日 16:45:36   作者:机器玄学实践者  
这篇文章主要介绍了解决Keras 中加入lambda层无法正常载入模型问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

刚刚解决了这个问题,现在记录下来

问题描述

当使用lambda层加入自定义的函数后,训练没有bug,载入保存模型则显示Nonetype has no attribute 'get'

问题解决方法:

这个问题是由于缺少config信息导致的。lambda层在载入的时候需要一个函数,当使用自定义函数时,模型无法找到这个函数,也就构建不了。

m = load_model(path,custom_objects={"reduce_mean":self.reduce_mean,"slice":self.slice})

其中,reduce_mean 和slice定义如下

  def slice(self,x, turn):
    """ Define a tensor slice function
    """
    return x[:, turn, :, :]
  def reduce_mean(self, X):
    return K.mean(X, axis=-1)

补充知识:含有Lambda自定义层keras模型,保存遇到的问题及解决方案

一,许多应用,keras含有的层已经不能满足要求,需要透过Lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。

保存时会报

TypeError: can't pickle _thread.RLock objects

二,解决方案,为了便于后续的部署,可以转成tensorflow的PB进行部署。

from keras.models import load_model
import tensorflow as tf
import os, sys
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io

def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=True):
  if not os.path.exists(output_dir):
    os.mkdir(output_dir)
  h5_model = build_model()
  h5_model.load_weights(h5_weight_path)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb'
  sess = K.get_session()
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

def build_model():
  inputs = Input(shape=(784,), name='input_img')
  x = Dense(64, activation='relu')(inputs)
  x = Dense(64, activation='relu')(x)
  y = Dense(10, activation='softmax')(x)
  h5_model = Model(inputs=inputs, outputs=y)
  return h5_model

if __name__ == '__main__':
  if len(sys.argv) == 3:
    # usage: python3 h5_to_pb.py h5_weight_path output_dir
    h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])

以上这篇解决Keras 中加入lambda层无法正常载入模型问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 对python模块中多个类的用法详解

    对python模块中多个类的用法详解

    今天小编就为大家分享一篇对python模块中多个类的用法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • Python最大连续区间和动态规划

    Python最大连续区间和动态规划

    这篇文章主要介绍了Python最大连续区间和动态规划,文章围绕Python最大连续区间和动态规划的相关资料展开内容,需要的小伙伴可以参考一下
    2022-01-01
  • Python接口自动化浅析数据驱动原理

    Python接口自动化浅析数据驱动原理

    这篇文章主要介绍了Python接口自动化浅析数据驱动原理,文中会详细描述怎样使用openpyxl模块操作excel及结合ddt来实现数据驱动,有需要的朋友可以参考下
    2021-08-08
  • Python实现的微信红包提醒功能示例

    Python实现的微信红包提醒功能示例

    这篇文章主要介绍了Python实现的微信红包提醒功能,结合实例形式分析了Python使用微信模块itchat实现微信红包提醒操作的相关实现技巧,需要的朋友可以参考下
    2019-08-08
  • 约瑟夫问题的Python和C++求解方法

    约瑟夫问题的Python和C++求解方法

    这篇文章主要介绍了约瑟夫问题的Python和C++求解方法,通过其示例我们也可以看出如今写法最简洁的编程语言和最复杂的语言之间的对比:D 需要的朋友可以参考下
    2015-08-08
  • python控制台英汉汉英电子词典

    python控制台英汉汉英电子词典

    这篇文章主要为大家详细介绍了python控制台英汉汉英电子词典,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2014-06-06
  • pandas DataFrame 删除重复的行的实现方法

    pandas DataFrame 删除重复的行的实现方法

    这篇文章主要介绍了pandas DataFrame 删除重复的行的实现方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-01-01
  • python基于socket实现的UDP及TCP通讯功能示例

    python基于socket实现的UDP及TCP通讯功能示例

    这篇文章主要介绍了python基于socket实现的UDP及TCP通讯功能,结合实例形式分析了基于Python socket模块的UDP及TCP通信相关客户端、服务器端实现技巧,需要的朋友可以参考下
    2019-11-11
  • 用python3读取python2的pickle数据方式

    用python3读取python2的pickle数据方式

    今天小编就为大家分享一篇用python3读取python2的pickle数据方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • Python正则简单实例分析

    Python正则简单实例分析

    这篇文章主要介绍了Python正则简单实例,具体分析了Python针对字符串的简单正则匹配测试中遇到的问题与相关注意事项,需要的朋友可以参考下
    2017-03-03

最新评论