Keras之自定义损失(loss)函数用法说明

 更新时间:2020年06月10日 08:41:51   作者:鹊踏枝  
这篇文章主要介绍了Keras之自定义损失(loss)函数用法说明,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • pyqt 实现为长内容添加滑轮 scrollArea

    pyqt 实现为长内容添加滑轮 scrollArea

    今天小编就为大家分享一篇pyqt 实现为长内容添加滑轮 scrollArea,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06
  • Python实战之基于OpenCV的美颜挂件制作

    Python实战之基于OpenCV的美颜挂件制作

    在本文中,我们将学习如何创建有趣的基于Snapchat的增强现实,主要包括两个实战项目:在检测到的人脸上的鼻子和嘴巴之间添加胡子挂件,在检测到的人脸上添加眼镜挂件。感兴趣的童鞋可以看看哦
    2021-11-11
  • Python 中 function(#) (X)格式 和 (#)在Python3.*中的注意事项

    Python 中 function(#) (X)格式 和 (#)在Python3.*中的注意事项

    这篇文章主要介绍了Python 中 function(#) (X)格式 和 (#)在Python3.*中的注意事项,非常不错,具有一定的参考借鉴价值,需要的朋友可以参考下
    2018-11-11
  • Python2.x中str与unicode相关问题的解决方法

    Python2.x中str与unicode相关问题的解决方法

    这篇文章主要介绍了Python2.x中str与Unicode相关问题的解决方法,Python2.x版本中由于没有默认使用Unicode而会在实际使用中碰到一些字符问题,针对这些问题本文讨论了一些解决方法,需要的朋友可以参考下
    2015-03-03
  • Python在日志中隐藏明文密码的方法

    Python在日志中隐藏明文密码的方法

    logging日志模块是python的一个内置模块,该模块定义了一些函数和类,为上层应用程序或库实现了一个强大而又灵活的日志记录系统,这篇文章主要介绍了Python如何在日志中隐藏明文密码 ,需要的朋友可以参考下
    2023-10-10
  • Python中if elif else及缩进的使用简述

    Python中if elif else及缩进的使用简述

    这篇文章主要介绍了Python中if elif else及缩进的使用,代码简单易懂,非常不错,具有一定的参考借鉴价值,需要的朋友参考下吧
    2018-05-05
  • 简单介绍Python中用于求最小值的min()方法

    简单介绍Python中用于求最小值的min()方法

    这篇文章主要介绍了简单介绍Python中用于求最小值的min()方法,是Python入门中的基础知识,需要的朋友可以参考下
    2015-05-05
  • python __add__()的具体使用

    python __add__()的具体使用

    本文主要介绍了python __add__()的具体使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • python集合常见运算案例解析

    python集合常见运算案例解析

    这篇文章主要介绍了python集合常见运算,结合具体实例形式分析了Python使用集合生成随机数的几种常用算法的效率比较,需要的朋友可以参考下
    2019-10-10
  • Python文件时间操作步骤代码详解

    Python文件时间操作步骤代码详解

    这篇文章主要介绍了Python文件时间操作步骤代码详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-04-04

最新评论