keras回调函数的使用

 更新时间:2023年03月13日 10:13:38   作者:辛勤的小码农^_^  
本文主要介绍了keras回调函数的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

回调函数

  • 回调函数是一个对象(实现了特定方法的类实例),它在调用fit()时被传入模型,并在训练过程中的不同时间点被模型调用
  • 可以访问关于模型状态与模型性能的所有可用数据
  • 模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前状态。
  • 提前终止(early stopping):如果验证损失不再改善,则中断训练(当然,同时保存在训练过程中的最佳模型)。
  • 在训练过程中动态调节某些参数值:比如调节优化器的学习率。
  • 在训练过程中记录训练指标和验证指标,或者将模型学到的表示可视化(这些表示在不断更新):fit()进度条实际上就是一个回调函数。

fit()方法中使用callbacks参数

# 这里有两个callback函数:早停和模型检查点
callbacks_list=[
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy",#监控指标
        patience=2 #两轮内不再改善中断训练
    ),
    keras.callbacks.ModelCheckpoint(
        filepath="checkpoint_path",
        monitor="val_loss",
        save_best_only=True
    )
]
#模型获取
model=get_minist_model()
model.compile(optimizer="rmsprop",
             loss="sparse_categorical_crossentropy",
             metrics=["accuracy"])

model.fit(train_images,train_labels,
         epochs=10,callbacks=callbacks_list, #该参数使用回调函数
         validation_data=(val_images,val_labels))

test_metrics=model.evaluate(test_images,test_labels)#计算模型在新数据上的损失和指标
predictions=model.predict(test_images)#计算模型在新数据上的分类概率

训练结果

模型的保存和加载

#也可以在训练完成后手动保存模型,只需调用model.save('my_checkpoint_path')。
#重新加载模型
model_new=keras.models.load_model("checkpoint_path.keras")

通过对Callback类子类化来创建自定义回调函数

on_epoch_begin(epoch, logs) ←----在每轮开始时被调用
on_epoch_end(epoch, logs) ←----在每轮结束时被调用
on_batch_begin(batch, logs) ←----在处理每个批量之前被调用
on_batch_end(batch, logs) ←----在处理每个批量之后被调用
on_train_begin(logs) ←----在训练开始时被调用
on_train_end(logs ←----在训练结束时被调用

from matplotlib import pyplot as plt
# 实现记录每一轮中每个batch训练后的损失,并为每个epoch绘制一个图
class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs):
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs):
        self.per_batch_losses.append(logs.get("loss"))

    def on_epoch_end(self, epoch, logs):
        plt.clf()
        plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
                 label="Training loss for each batch")
        plt.xlabel(f"Batch (epoch {epoch})")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(f"plot_at_epoch_{epoch}")
        self.per_batch_losses = [] #清空,方便下一轮的技术
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(train_images, train_labels,
          epochs=10,
          callbacks=[LossHistory()],
          validation_data=(val_images, val_labels))

在这里插入图片描述

【其他】模型的定义 和 数据加载

def get_minist_model():
    inputs=keras.Input(shape=(28*28,))
    features=layers.Dense(512,activation="relu")(inputs)
    features=layers.Dropout(0.5)(features)
    outputs=layers.Dense(10,activation="softmax")(features)
    model=keras.Model(inputs,outputs)
    return model
    
#datset
from tensorflow.keras.datasets import mnist
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
train_images=train_images.reshape((60000,28*28)).astype("float32")/255
test_images=test_images.reshape((10000,28*28)).astype("float32")/255
train_images,val_images=train_images[10000:],train_images[:10000]
train_labels,val_labels=train_labels[10000:],train_labels[:10000]

到此这篇关于keras回调函数的使用的文章就介绍到这了,更多相关keras回调函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 解决Python pip 自动更新升级失败的问题

    解决Python pip 自动更新升级失败的问题

    今天小编就为大家分享一篇解决Python pip 自动更新升级失败的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • python正则-re的用法详解

    python正则-re的用法详解

    这篇文章主要介绍了python正则-re的用法详解,文中给大家提到了正则中的修饰符以及它的功能,需要的朋友可以参考下
    2019-07-07
  • Python使用函数默认值实现函数静态变量的方法

    Python使用函数默认值实现函数静态变量的方法

    这篇文章主要介绍了Python使用函数默认值实现函数静态变量的方法,是很实用的功能,需要的朋友可以参考下
    2014-08-08
  • Python实现的文本简单可逆加密算法示例

    Python实现的文本简单可逆加密算法示例

    这篇文章主要介绍了Python实现的文本简单可逆加密算法,结合完整实例形式分析了Python自定义加密与解密算法具体实现与使用技巧,需要的朋友可以参考下
    2017-05-05
  • Pytest中skip skipif跳过用例详解

    Pytest中skip skipif跳过用例详解

    今天给大家带来的是关于Python的相关知识,文章围绕着Pytest中skip skipif跳过用例展开,文中有非常详细的介绍及代码示例,需要的朋友可以参考下
    2021-06-06
  • Python内置函数memoryview()的实现示例

    Python内置函数memoryview()的实现示例

    本文主要介绍了Python内置函数memoryview()的实现示例,它允许你在不复制其内容的情况下操作同一个数组的不同切片,具有一定的参考价值,感兴趣的可以了解一下
    2024-05-05
  • Python OOP类中的几种函数或方法总结

    Python OOP类中的几种函数或方法总结

    今天小编就为大家分享一篇关于Python OOP类中的几种函数或方法总结,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2019-02-02
  • python matplotlib imshow热图坐标替换/映射实例

    python matplotlib imshow热图坐标替换/映射实例

    这篇文章主要介绍了python matplotlib imshow热图坐标替换/映射实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-03-03
  • pytorch中节约显卡内存的方法和技巧

    pytorch中节约显卡内存的方法和技巧

    显存不足是很多人感到头疼的问题,毕竟能拥有大量显存的实验室还是少数,而现在的模型已经越跑越大,模型参数量和数据集也越来越大,所以这篇文章给大家总结了一些pytorch中节约显卡内存的方法和技巧,需要的朋友可以参考下
    2023-11-11
  • Python学习之configparser模块的使用详解

    Python学习之configparser模块的使用详解

    ConfigParser是用来读取配置文件的包。这篇文章主要通过一些简单的实例带大家了解一下ConfigParser模块的具体使用,感兴趣的小伙伴跟随小编一起了解一下
    2023-01-01

最新评论