keras的get_value运行越来越慢的解决方案

 更新时间:2021年05月17日 12:09:57   作者:头狼586  
这篇文章主要介绍了keras的get_value运行越来越慢的解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

keras 深度学习框架中get_value函数运行越来越慢,内存消耗越来越大问题

问题描述

get_value问题截图

如上图所示,经过时间和内存消耗跟踪测试,发现是keras.backend.get_value() 函数导致的程序越来越慢,而且严重的造成内存泄露;

查看该函数内部实现,发现一个主要核心是x.eval(session=get_session()),该语句可能是导致内存泄露和运行慢的核心语句; 根据查看一些博文得到了运行得越来越慢的

原因该x.eval函数会添加新的节点到tf的图中;而这也导致了tf的图越来越大,内存泄露;

解决方法

import tensorflow.keras.backend as K

def get_my_session(gpu_fraction=0.1):
    '''Assume that you have 6GB of GPU memory and want to allocate ~2GB'''

    num_threads = os.environ.get('OMP_NUM_THREADS')
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)

    if num_threads:
        return tf.Session(config=tf.ConfigProto(
            gpu_options=gpu_options, intra_op_parallelism_threads=num_threads))
    else:
        return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

K.set_session(get_my_session())

如上图所示, 我在使用tensorflow之前(也就是该工程文件前面),对session进行自定义,然后用自定义的session设定keras.backend.set_session();

然后删除get_value() 函数,直接用get_value()中所使用的执行语句x.eval(session=get_my_session());这样这个添加节点导致内存泄露的核心语句x.eval()就使用的是该工程统一自定义session,然后用tf.reset_default_graph() 对图重置就可以了

即上图问题代码修改为:

output = ctc_decode(y_pred,input_length=input_length,)
output = output[0][0]
out = output.eval(session=get_my_session())
# 删除 K.get_value(out[0][0])
tf.reset_default_graph() # 然后重置tf图,这句很关键

这样就解决了get_value()导致的越来越慢的问题;

个人认为:这样可能就不会总是添加新的节点,导致tf图不断地无限变大;而是重复使用这一个自定义的节点。

补充:tensorflow与keras之间版本问题引起get_session问题解决办法

1.产生报错原因

import tensorflow.keras.backend as K
def __init__(self, **kwargs):
    self.__dict__.update(self._defaults) # set up default values
    self.__dict__.update(kwargs) # and update with user overrides
    self.class_names = self._get_class()
    self.anchors = self._get_anchors()
    self.sess = K.get_session()

报错如下:

get_session is not available when using TensorFlow 2.0.

意思是 tf2.0 没有 get_session

2.解决方案1

import tensorflow.python.keras.backend as K
sess = K.get_session()

3. 解决方案2

import tensorflow as tf
sess = tf.compat.v1.keras.backend.get_session()

之前一直采用方案1 解决,感觉比较方便;但是解决方案1 有其它属性会丢失问题

比如AttributeError: module ‘keras.backend' has no attribute image_dim_ordering

所以建议大家采用方案2

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python利用文件读写编写一个博客

    python利用文件读写编写一个博客

    这篇文章主要为大家详细介绍了python利用文件读写编写一个博客,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-08-08
  • 基于Python自制一个文件解压缩小工具

    基于Python自制一个文件解压缩小工具

    经常在办公的过程中会遇到各种各样的压缩文件处理,但是呢每个压缩软件支持的格式又是不同的。本文就来用Python自制一个文件解压缩小工具,可以支持7z/zip/rar三种格式,希望对大家有所帮助
    2023-02-02
  • Python程序员面试题 你必须提前准备!

    Python程序员面试题 你必须提前准备!

    Python程序员面试,这些问题你必须提前准备!供广大Python程序员参考,预祝大家顺利通过面试。
    2018-01-01
  • pandas读取excel,txt,csv,pkl文件等命令的操作

    pandas读取excel,txt,csv,pkl文件等命令的操作

    这篇文章主要介绍了pandas读取excel,txt,csv,pkl文件等命令的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • python实现log日志的示例代码

    python实现log日志的示例代码

    下面小编就为大家分享一篇python实现log日志的示例代码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • python实现textrank关键词提取

    python实现textrank关键词提取

    这篇文章主要为大家详细介绍了python实现textrank关键词提取,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-06-06
  • Python中实现插值法的示例详解

    Python中实现插值法的示例详解

    在数据处理和分析中,插值法是一种常用的数值分析技术,用于估计在已知数据点之间的值,本文将详细介绍Python中插值法的实现方法,需要的可以参考下
    2024-02-02
  • 浅谈Python批处理文件夹中的txt文件

    浅谈Python批处理文件夹中的txt文件

    这篇文章主要介绍了Python批处理文件夹中的txt文件,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-03-03
  • python简单验证码识别的实现过程

    python简单验证码识别的实现过程

    很多网站登录都需要输入验证码,如果要实现自动登录就不可避免的要识别验证码,这篇文章主要给大家介绍了关于python简单验证码识别的实现过程,需要的朋友可以参考下
    2021-06-06
  • Python爬取动态网页中图片的完整实例

    Python爬取动态网页中图片的完整实例

    这篇文章主要给大家介绍了关于Python爬取动态网页中图片的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-03-03

最新评论