keras 自定义loss层+接受输入实例

 更新时间:2020年06月28日 14:52:35   作者:lgy_keira  
这篇文章主要介绍了keras 自定义loss层+接受输入实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Django实现从数据库中获取到的数据转换为dict

    Django实现从数据库中获取到的数据转换为dict

    这篇文章主要介绍了Django实现从数据库中获取到的数据转换为dict,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-03-03
  • 如何用python反转图片,视频

    如何用python反转图片,视频

    这篇文章主要介绍了如何用python反转图片,视频,帮助大家更好的利用python处理图像,感兴趣的朋友可以了解下
    2021-04-04
  • 对pyqt5多线程正确的开启姿势详解

    对pyqt5多线程正确的开启姿势详解

    今天小编就为大家分享一篇对pyqt5多线程正确的开启姿势详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06
  • Django跨域请求原理及实现代码

    Django跨域请求原理及实现代码

    这篇文章主要介绍了Django跨域请求原理及实现代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-11-11
  • Python数据分析之缺失值检测与处理详解

    Python数据分析之缺失值检测与处理详解

    在实际的数据处理中,缺失值是普遍存在的,如何使用 Python 检测和处理缺失值,就是本文要讲的主要内容。感兴趣的同学可以关注一下
    2021-12-12
  • 详解如何使用Python处理INI、YAML和JSON配置文件

    详解如何使用Python处理INI、YAML和JSON配置文件

    在软件开发中,配置文件是存储程序配置信息的常见方式,INI、YAML和JSON是常用的配置文件格式,各自有着特定的结构和用途,Python拥有丰富的库和模块,本文将重点探讨如何使用Python处理这三种格式的配置文件,需要的朋友可以参考下
    2023-12-12
  • Python干货实战之八音符酱小游戏全过程详解

    Python干货实战之八音符酱小游戏全过程详解

    读万卷书不如行万里路,只学书上的理论是远远不够的,只有在实战中才能获得能力的提升,本篇文章手把手带你用Python实现一个八音符酱小游戏,大家可以在过程中查缺补漏,提升水平
    2021-10-10
  • 基于Python开发云主机类型管理脚本分享

    基于Python开发云主机类型管理脚本分享

    这篇文章主要为大家详细介绍了如何基于Python开发一个云主机类型管理脚本,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下
    2023-02-02
  • Python中的省略号(Ellipsis)赋值方式详解

    Python中的省略号(Ellipsis)赋值方式详解

    在Python编程中,省略号(...)是一种特殊对象,主要用作函数占位、未实现的方法示例和NumPy数组处理,本文通过示例详细解释了省略号的赋值方式及其在不同编程场景下的应用,帮助提升Python编程技巧
    2024-10-10
  • 在Python程序中进行文件读取和写入操作的教程

    在Python程序中进行文件读取和写入操作的教程

    这篇文章主要介绍了在Python程序中进行文件读取和写入操作的教程,是Python学习当中的基础知识,需要的朋友可以参考下
    2015-04-04

最新评论