浅谈keras2 predict和fit_generator的坑

 更新时间:2020年06月17日 15:02:49   作者:BYR_jiandong  
这篇文章主要介绍了浅谈keras2 predict和fit_generator的坑,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

1、使用predict时,必须设置batch_size,否则效率奇低。

查看keras文档中,predict函数原型:

predict(self, x, batch_size=32, verbose=0)

说明:

只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。

所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。

经验:

使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。

2、fit_generator

说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。

如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于

batch_size = 1,会性能非常低。

经验:

必须明确fit_generator参数steps_per_epoch

补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题

为什么要使用model.fit_generator?

在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。

fit_generator的定义如下:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

其中各项的具体解释,请参考Keras中文文档

我们重点关注的是generator参数:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

那么,问题来了,如何构建这个generator呢?有以下几种办法:

自己创建一个generator生成器

自己定义一个 Sequence (keras.utils.Sequence) 对象

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator

1.自己创建一个generator生成器

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。

此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。

此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。

具体实现如下:

  def myGenerator(batch_size):
    # loading data
    X_train,Y_train=load_data(...)
    
    # data processing
    # ................
    
    total_size=X_train.size
    #batch_size means how many data you want to train one step
    
    while 1:
      for i in range(total_size//batch_size):
        yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
  return myGenerator

接着你可以调用该生成器:

self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)

以上这篇浅谈keras2 predict和fit_generator的坑就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python 如何利用pandas 和 matplotlib绘制柱状图

    Python 如何利用pandas 和 matplotlib绘制柱状图

    Python 中的 pandas 和 matplotlib 库提供了丰富的功能,可以帮助你轻松地绘制各种类型的图表,本文将介绍如何使用这两个库,绘制一个店铺销售数量的柱状图,并添加各种元素,如数据标签、图例、网格线等,感兴趣的朋友一起看看吧
    2023-10-10
  • 用Python编程实现语音控制电脑

    用Python编程实现语音控制电脑

    是否经常好莱坞电影里看强大的语音识别系统? 是否每每看到都会羡慕嫉妒恨? 可是我们真心买不起啊。
    2014-04-04
  • Python重试库 Tenacity详解(推荐)

    Python重试库 Tenacity详解(推荐)

    这篇文章主要介绍了Python重试库Tenacity详解,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-09-09
  • Django接收自定义http header过程详解

    Django接收自定义http header过程详解

    这篇文章主要介绍了Django接收自定义http header过程详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-08-08
  • Win10下Python3.7.3安装教程图解

    Win10下Python3.7.3安装教程图解

    到2019年初,Python3已经更新到了Python3.7.3,Python有两个大版本Python2和Python3,Python3是现在和未来的主流。这篇文章主要介绍了Win10下Python3.7.3安装教程图解,非常不错,感兴趣的朋友跟随小编一起看看吧
    2019-07-07
  • PyTorch模型保存与加载实例详解

    PyTorch模型保存与加载实例详解

    大家都知道pytorch的模型和参数是分开的,可以分别保存或加载模型和参数,这篇文章主要给大家介绍了关于PyTorch模型保存与加载的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-04-04
  • python 按不同维度求和,最值,均值的实例

    python 按不同维度求和,最值,均值的实例

    今天小编就为大家分享一篇python 按不同维度求和,最值,均值的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • python简单商城购物车实例代码

    python简单商城购物车实例代码

    这篇文章主要为大家详细介绍了python简单商城购物车的实例代码,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-03-03
  • 在python list中筛选包含字符的字段方式

    在python list中筛选包含字符的字段方式

    这篇文章主要介绍了在python list中筛选包含字符的字段方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11
  • Python设置Socket代理及实现远程摄像头控制的例子

    Python设置Socket代理及实现远程摄像头控制的例子

    这篇文章主要介绍了Python设置Socket代理及实现远程摄像头控制的例子,皆是对socket模块的实际运用,需要的朋友可以参考下
    2015-11-11

最新评论