浅谈Keras中fit()和fit_generator()的区别及其参数的坑

 更新时间:2021年05月17日 11:14:02   作者:MrLeaper  
这篇文章主要介绍了Keras中fit()和fit_generator()的区别及其参数的坑,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

1、fit和fit_generator的区别

首先Keras中的fit()函数传入的x_train和y_train是被完整的加载进内存的,当然用起来很方便,但是如果我们数据量很大,那么是不可能将所有数据载入内存的,必将导致内存泄漏,这时候我们可以用fit_generator函数来进行训练。

下面是fit传参的例子:

history = model.fit(x_train, y_train, epochs=10,batch_size=32, 
                    validation_split=0.2)

这里需要给出epochs和batch_size,epoch是这个数据集要被轮多少次,batch_size是指这个数据集被分成多少个batch进行处理。

最后可以给出交叉验证集的大小,这里的0.2是指在训练集上占比20%。

fit_generator函数必须传入一个生成器,我们的训练数据也是通过生成器产生的,下面给出一个简单的生成器函数:

batch_size = 128
def generator():
    while 1:
        row = np.random.randint(0,len(x_train),size=batch_size)
        x = np.zeros((batch_size,x_train.shape[-1]))
        y = np.zeros((batch_size,))
        x = x_train[row]
        y = y_train[row]
        yield x,y

这里的生成器函数我产生的是一个batch_size为128大小的数据,这只是一个demo。如果我在生成器里没有规定batch_size的大小,就是每次产生一个数据,那么在用fit_generator时候里面的参数steps_per_epoch是不一样的。

这里的坑我困惑了好久,虽然不是什么大问题

下面是fit_generator函数的传参:

history = model.fit_generator(generator(),epochs=epochs,steps_per_epoch=len(x_train)//(batch_size*epochs))

2、batch_size和steps_per_epoch的区别

首先batch_size = 数据集大小/steps_per_epoch的,如果我们在生成函数里设置了batch_size的大小,那么在fit_generator传参的时候,,steps_per_epoch=len(x_train)//(batch_size*epochs)

我得完整demo代码:

from keras.datasets import imdb
from keras.preprocessing.sequence import pad_sequences
from keras.models import Sequential
from keras import layers
import numpy as np
import random
from sklearn.metrics import f1_score,accuracy_score
max_features = 10000
maxlen = 500
batch_size = 32
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
x_train = pad_sequences(x_train,maxlen=maxlen)
x_test = pad_sequences(x_test,maxlen=maxlen)
 
def generator():
    while 1:
        row = np.random.randint(0,len(x_train),size=batch_size)
        x = np.zeros((batch_size,x_train.shape[-1]))
        y = np.zeros((batch_size,))
        x = x_train[row]
        y = y_train[row]
        yield x,y
# generator()
 
model = Sequential()
model.add(layers.Embedding(max_features,32,input_length=maxlen))
model.add(layers.GRU(64,return_sequences=True))
model.add(layers.GRU(32))
# model.add(layers.Flatten())
# model.add(layers.Dense(32,activation='relu'))
 
model.add(layers.Dense(1,activation='sigmoid'))
model.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['acc'])
print(model.summary())
 
# history = model.fit(x_train, y_train, epochs=1,batch_size=32, validation_split=0.2)
history = model.fit_generator(generator(),epochs=1,steps_per_epoch=len(x_train)//(batch_size)) 
 
print(model.evaluate(x_test,y_test))
y = model.predict_classes(x_test) 
print(accuracy_score(y_test,y))

补充:model.fit_generator()详细解读

如下所示:

from keras import models
model = models.Sequential()

首先

利用keras,搭建顺序模型,具体搭建步骤省略。完成搭建后,我们需要将数据送入模型进行训练,送入数据的方式有很多种,models.fit_generator()是其中一种方式。

具体说,model.fit_generator()是利用生成器,分批次向模型送入数据的方式,可以有效节省单次内存的消耗。

具体函数形式如下:

fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, \
callbacks=None, validation_data=None, validation_steps=None,\
 class_weight=None, max_q_size=10, workers=1, pickle_safe=False, initial_epoch=0)

参数解释:

generator:一般是一个生成器函数;

steps_per_epochs:是指在每个epoch中生成器执行生成数据的次数,若设定steps_per_epochs=100,这情况如下图所示;

epochs:指训练过程中需要迭代的次数;

verbose:默认值为1,是指在训练过程中日志的显示模式,取 1 时表示“进度条模式”,取2时表示“每轮一行”,取0时表示“安静模式”;

validation_data, validation_steps指验证集的情况,使用方式和generator, steps_per_epoch相同;

models.fit_generator()会返回一个history对象,history.history 属性记录训练过程中,连续 epoch 训练损失和评估值,以及验证集损失和评估值,可以通过以下方式调取这些值!

acc = history.history["acc"]
val_acc = history.history["val_acc"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python三元表达式的基本用法详解

    Python三元表达式的基本用法详解

    Python的三元表达式是一种紧凑、简洁的条件表达式,允许在一行代码中根据条件选择不同的值,三元表达式通常用于需要在单行中根据条件进行值选择的情况,有助于提高代码的可读性和简洁性,本文给大家介绍了Python三元表达式的基本用法,需要的朋友可以参考下
    2023-10-10
  • Python中for循环语句实战案例

    Python中for循环语句实战案例

    这篇文章主要给大家介绍了关于Python中for循环语句的相关资料,python中for循环一般用来迭代字符串,列表,元组等,当for循环用于迭代时不需要考虑循环次数,循环次数由后面的对象长度来决定,需要的朋友可以参考下
    2023-09-09
  • Python数据可视化之简单折线图的绘制

    Python数据可视化之简单折线图的绘制

    这篇文章主要为大家详细介绍了Python数据可视化之绘制简单折线图的相关资料,文中的示例代码简洁易懂,感兴趣的小伙伴可以了解一下
    2022-10-10
  • Python序列之list和tuple常用方法以及注意事项

    Python序列之list和tuple常用方法以及注意事项

    这篇文章主要介绍了Python序列之list和tuple常用方法以及注意事项,sequence(序列)是一组有顺序的对象的集合,序列可以包含一个或多个元素,也可以没有任何元素,序列有两种:list (表) 和 tuple(元组),需要的朋友可以参考下
    2015-01-01
  • 浅谈Python反射 & 单例模式

    浅谈Python反射 & 单例模式

    这篇文章主要介绍了Python反射 & 单例模式,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-03-03
  • 使用Python的matplotlib库绘制柱状图

    使用Python的matplotlib库绘制柱状图

    这篇文章主要介绍了使用Python的matplotlib库绘制柱状图,Matplotlib是Python中最常用的可视化工具之一,可以非常方便地创建海量类型地2D图表和一些基本的3D图表,可根据数据集自行定义x,y轴,绘制图形,需要的朋友可以参考下
    2023-07-07
  • Pytorch 使用不同版本的cuda的方法步骤

    Pytorch 使用不同版本的cuda的方法步骤

    这篇文章主要介绍了Pytorch 使用不同版本的cuda的方法步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-04-04
  • 解决Python pandas plot输出图形中显示中文乱码问题

    解决Python pandas plot输出图形中显示中文乱码问题

    今天小编就为大家分享一篇解决Python pandas plot输出图形中显示中文乱码问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • 关于python的第三方库下载与更改方式

    关于python的第三方库下载与更改方式

    这篇文章主要介绍了关于python的第三方库下载与更改方式,使用python的朋友都知道python有很多非常方便的第三方库可以使用,那么如果下载这些第三方库呢,今天小编就带你们来看看
    2023-04-04
  • 详解python异步编程之asyncio(百万并发)

    详解python异步编程之asyncio(百万并发)

    这篇文章主要介绍了详解python异步编程之asyncio(百万并发),小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-07-07

最新评论