Keras之自定义损失(loss)函数用法说明

 更新时间:2020年06月10日 08:41:51   作者:鹊踏枝  
这篇文章主要介绍了Keras之自定义损失(loss)函数用法说明,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • mac系统下Redis安装和使用步骤详解

    mac系统下Redis安装和使用步骤详解

    这篇文章主要介绍了mac下Redis安装和使用步骤详解,并将python如何操作Redis做了简单介绍,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • 详细介绍Python中的偏函数

    详细介绍Python中的偏函数

    这篇文章主要介绍了Python中的偏函数,示例代码基于Python2.x版本,需要的朋友可以参考下
    2015-04-04
  • 使用tensorflow根据输入更改tensor shape

    使用tensorflow根据输入更改tensor shape

    这篇文章主要介绍了使用tensorflow根据输入更改tensor shape,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • Python的UTC时间转换讲解

    Python的UTC时间转换讲解

    今天小编就为大家分享一篇关于Python的UTC时间转换讲解,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2019-02-02
  • Python2.7 实现引入自己写的类方法

    Python2.7 实现引入自己写的类方法

    下面小编就为大家分享一篇Python2.7 实现引入自己写的类方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • 利用pyinstaller将py文件打包为exe的方法

    利用pyinstaller将py文件打包为exe的方法

    本篇文章主要介绍了利用pyinstaller将py文件打包为exe的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-05-05
  • Pandas聚合运算和分组运算的实现示例

    Pandas聚合运算和分组运算的实现示例

    这篇文章主要介绍了Pandas聚合运算和分组运算的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-10-10
  • python中的hashlib和base64加密模块使用实例

    python中的hashlib和base64加密模块使用实例

    这篇文章主要介绍了python中的hashlib和base64加密模块使用实例,hashlib模块支持的加密算法有md5 sha1 sha224 sha256 sha384 sha512,需要的朋友可以参考下
    2014-09-09
  • 将string类型的数据类型转换为spark rdd时报错的解决方法

    将string类型的数据类型转换为spark rdd时报错的解决方法

    今天小编就为大家分享一篇关于将string类型的数据类型转换为spark rdd时报错的解决方法,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2019-02-02
  • python实现视频分帧效果

    python实现视频分帧效果

    这篇文章主要为大家详细介绍了python实现视频分帧效果,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-05-05

最新评论