pytorch交叉熵损失函数的weight参数的使用
首先
必须将权重也转为Tensor的cuda格式;
然后
将该class_weight作为交叉熵函数对应参数的输入值。
1 | class_weight = torch.FloatTensor([ 0.13859937 , 0.5821059 , 0.63871904 , 2.30220396 , 7.1588294 , 0 ]).cuda() |
补充:关于pytorch的CrossEntropyLoss的weight参数
首先这个weight参数比想象中的要考虑的多
你可以试试下面代码
1 2 3 4 5 6 7 8 9 10 | import torch import torch.nn as nn inputs = torch.FloatTensor([ 0 , 1 , 0 , 0 , 0 , 1 ]) outputs = torch.LongTensor([ 0 , 1 ]) inputs = inputs.view(( 1 , 3 , 2 )) outputs = outputs.view(( 1 , 2 )) weight_CE = torch.FloatTensor([ 1 , 1 , 1 ]) ce = nn.CrossEntropyLoss(ignore_index = 255 ,weight = weight_CE) loss = ce(inputs,outputs) print (loss) |
这里的手动计算是:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803
加权呢?
1 2 3 4 5 6 7 8 9 10 | import torch import torch.nn as nn inputs = torch.FloatTensor([ 0 , 1 , 0 , 0 , 0 , 1 ]) outputs = torch.LongTensor([ 0 , 1 ]) inputs = inputs.view(( 1 , 3 , 2 )) outputs = outputs.view(( 1 , 2 )) weight_CE = torch.FloatTensor([ 1 , 2 , 3 ]) ce = nn.CrossEntropyLoss(ignore_index = 255 ,weight = weight_CE) loss = ce(inputs,outputs) print (loss) |
手算发现,并不是单纯的那权重相乘:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113
而是
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075
发现了么,加权后,除以的是权重的和,不是数目的和。
我们再验证一遍:
1 2 3 4 5 6 7 8 9 10 11 | import torch import torch.nn as nn inputs = torch.FloatTensor([ 0 , 1 , 2 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 , 0.5 ]) outputs = torch.LongTensor([ 0 , 1 , 2 , 2 ]) inputs = inputs.view(( 1 , 3 , 4 )) outputs = outputs.view(( 1 , 4 )) weight_CE = torch.FloatTensor([ 1 , 2 , 3 ]) ce = nn.CrossEntropyLoss(weight = weight_CE) # ce = nn.CrossEntropyLoss(ignore_index=255) loss = ce(inputs,outputs) print (loss) |
手算:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
loss3 = 0 + ln(e2 + e0 + e0) = 2.2395
loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943
求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472
可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

微信公众号搜索 “ 脚本之家 ” ,选择关注
程序猿的那些事、送书等活动等着你
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若内容造成侵权/违法违规/事实不符,请将相关资料发送至 reterry123@163.com 进行投诉反馈,一经查实,立即处理!
相关文章
Python简单获取网卡名称及其IP地址的方法【基于psutil模块】
这篇文章主要介绍了Python简单获取网卡名称及其IP地址的方法,结合实例形式分析了Python基于psutil模块针对本机网卡硬件信息的读取操作简单使用技巧,需要的朋友可以参考下2018-05-05
最新评论