pytorch交叉熵损失函数的weight参数的使用

 更新时间:2021年05月24日 09:59:08   作者:Nick Blog  
这篇文章主要介绍了pytorch交叉熵损失函数的weight参数的使用,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 解析python调用函数加括号和不加括号的区别

    解析python调用函数加括号和不加括号的区别

    这篇文章主要介绍了python调用函数加括号和不加括号的区别,不带括号时,调用的是这个函数本身 ,是整个函数体,是一个函数对象,不须等该函数执行完成,具体实例代码跟随小编一起看看吧
    2021-10-10
  • Python简单获取网卡名称及其IP地址的方法【基于psutil模块】

    Python简单获取网卡名称及其IP地址的方法【基于psutil模块】

    这篇文章主要介绍了Python简单获取网卡名称及其IP地址的方法,结合实例形式分析了Python基于psutil模块针对本机网卡硬件信息的读取操作简单使用技巧,需要的朋友可以参考下
    2018-05-05
  • python使用socket向客户端发送数据的方法

    python使用socket向客户端发送数据的方法

    这篇文章主要介绍了python使用socket向客户端发送数据的方法,涉及Python使用socket实现数据通信的技巧,非常具有实用价值,需要的朋友可以参考下
    2015-04-04
  • python中的继承机制super()函数详解

    python中的继承机制super()函数详解

    这篇文章主要介绍了python中的继承机制super()函数详解,super 是用来解决多重继承问题的,直接用类名调用父类方法在使用单继承的时候没问题,但是如果使用多继承,会涉及到查找顺序、重复调用等问题,需要的朋友可以参考下
    2023-08-08
  • Python管理Windows服务小脚本

    Python管理Windows服务小脚本

    这篇文章主要为大家详细介绍了Python管理Windows服务的小脚本,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-03-03
  • django的csrf实现过程详解

    django的csrf实现过程详解

    这篇文章主要介绍了django的csrf实现过程相加,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • Python实现句子翻译功能

    Python实现句子翻译功能

    这篇文章主要介绍了Python实现句子翻译功能,涉及urllib库的使用等相关内容,具有一定参考价值,需要的朋友可以了解下。
    2017-11-11
  • Python爬虫学习之翻译小程序

    Python爬虫学习之翻译小程序

    这篇文章主要为大家详细介绍了Python爬虫学习之翻译小程序,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-07-07
  • python实现简易通讯录修改版

    python实现简易通讯录修改版

    这篇文章主要为大家详细介绍了python实现简易通讯录的修改版,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-03-03
  • Python使用FFMPEG压缩视频的方法

    Python使用FFMPEG压缩视频的方法

    FFMPEG是一个完整的,跨平台的解决方案,记录,转换和流音频和视频,,这篇文章主要介绍了FFMPEG视频压缩与Python使用方法,需要的朋友可以参考下
    2023-09-09

最新评论