Tensorflow与Keras自适应使用显存方式

 更新时间:2020年06月22日 10:19:44   作者:一呆飞仙  
这篇文章主要介绍了Tensorflow与Keras自适应使用显存方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

Tensorflow支持基于cuda内核与cudnn的GPU加速,Keras出现较晚,为Tensorflow的高层框架,由于Keras使用的方便性与很好的延展性,之后更是作为Tensorflow的官方指定第三方支持开源框架。

但两者在使用GPU时都有一个特点,就是默认为全占满模式。在训练的情况下,特别是分步训练时会导致显存溢出,导致程序崩溃。

可以使用自适应配置来调整显存的使用情况。

一、Tensorflow

1、指定显卡

代码中加入

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

或者在运行代码前,在终端

export CUDA_VISIBLE_DEVICES=0

2、为显存分配使用比例

在建立tf.Session加入设置数据(显存使用比例为1/3),但有时你虽然设置了使用上限,在程序需要更高显存时还是会越过该限制

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

3、自适应分配

会自适应分配显存,不会将显存全部分配导致资源浪费

config = tf.ConfigProto() 
config.gpu_options.allow_growth=True 
sess = tf.Session(config=config) 

二、Keras

与tensorflow大差不差,就是将tf.Session配置转置Keras配置

1、指定显卡

代码中加入

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

或者在运行代码前,在终端

export CUDA_VISIBLE_DEVICES=0

2、为显存分配使用比例

import tensorflow as tf
import keras.backend.tensorflow_backend as KTF

config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.333
session = tf.Session(config=config)
KTF.set_session(session)

3、自适应分配

import keras.backend.tensorflow_backend as KTF

config = tf.ConfigProto() 
config.gpu_options.allow_growth=True 
session = tf.Session(config=config)
KTF.set_session(session)

4、如有设置fit_generator

将多线程关闭

#可将
use_multiprocessing=True
#改为
use_multiprocessing=False

补充知识:Keras 自动分配显存,不占用所有显存

自动分配显存,不占用所有显存

import keras.backend.tensorflow_backend as KTF
import tensorflow as tf
import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
config = tf.ConfigProto()
config.gpu_options.allow_growth=True #不全部占满显存, 按需分配
sess = tf.Session(config=config)
KTF.set_session(sess)

以上这篇Tensorflow与Keras自适应使用显存方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • drf-router和authenticate认证源码分析

    drf-router和authenticate认证源码分析

    在 Rest Framework 中提供了两个 router , 可以帮助我们快速的实现路由的自动生成,本文通过实例代码给大家介绍drf-router和authenticate认证源码分析,感兴趣的朋友跟随小编一起看看吧
    2021-07-07
  • 用Cython加速Python到“起飞”(推荐)

    用Cython加速Python到“起飞”(推荐)

    这篇文章主要介绍了用Cython加速Python到“起飞”,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-08-08
  • 用python实现监控视频人数统计

    用python实现监控视频人数统计

    今天教各位小伙伴学习怎么用python实现监控视频人数统计,文中有非常详细的代码示例,对正在学习python的小伙伴有很大的帮助,需要的朋友可以参考下
    2021-05-05
  • Python使用微信SDK实现的微信支付功能示例

    Python使用微信SDK实现的微信支付功能示例

    这篇文章主要介绍了Python使用微信SDK实现的微信支付功能,结合实例形式分析了Python调用微信SDK接口实现微信支付功能的具体步骤与相关操作技巧,需要的朋友可以参考下
    2017-06-06
  • Python中的yield全方位解读

    Python中的yield全方位解读

    这篇文章主要介绍了Python中的yield全方位解读,在 Python 中,使用了 yield 的函数被称为生成器,跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器,需要的朋友可以参考下
    2023-08-08
  • Python多线程处理实例详解【单进程/多进程】

    Python多线程处理实例详解【单进程/多进程】

    这篇文章主要介绍了Python多线程处理,结合实例形式总结分析了Python单进程、多进程、多线程等相关操作技巧与注意事项,需要的朋友可以参考下
    2019-01-01
  • Python连接和操作PostgreSQL数据库的流程步骤

    Python连接和操作PostgreSQL数据库的流程步骤

    PostgreSQL 是一种开源的对象关系型数据库管理系统(ORDBMS),以其强大的功能和稳定性而广受欢迎,本文将详细介绍如何使用 Python 连接和操作 PostgreSQL 数据库,需要的朋友可以参考下
    2024-10-10
  • Pytest使用logging模块写日志的实例详解

    Pytest使用logging模块写日志的实例详解

    logging是python语言中的一个日志模块,专门用来写日志的,日志级别通常分为debug、info、warning、error、critical几个级别,一般情况下,默认的日志级别为warning,在调试或者测试阶段,下面就快速体验一下logging模块写日志的用法,感兴趣的朋友跟随小编一起看看吧
    2022-12-12
  • 解决pycharm下pyuic工具使用的问题

    解决pycharm下pyuic工具使用的问题

    这篇文章主要介绍了解决pycharm下pyuic工具使用的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • Python利用三层神经网络实现手写数字分类详解

    Python利用三层神经网络实现手写数字分类详解

    这篇文章主要介绍了如何设计一个三层神经网络模型来实现手写数字分类。本文给大家介绍的非常详细,感兴趣的小伙伴快来跟小编一起学习一下
    2021-11-11

最新评论