Pytorch损失函数nn.NLLLoss2d()用法说明

 更新时间:2020年07月07日 14:24:25   作者:起步晚就要快点跑  
这篇文章主要介绍了Pytorch损失函数nn.NLLLoss2d()用法说明,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

最近做显著星检测用到了NLL损失函数

对于NLL函数,需要自己计算log和softmax的概率值,然后从才能作为输入

输入 [batch_size, channel , h, w]

目标 [batch_size, h, w]

输入的目标矩阵,每个像素必须是类型.举个例子。第一个像素是0,代表着类别属于输入的第1个通道;第二个像素是0,代表着类别属于输入的第0个通道,以此类推。

x = Variable(torch.Tensor([[[1, 2, 1],
       [2, 2, 1],
       [0, 1, 1]],
       [[0, 1, 3],
       [2, 3, 1],
       [0, 0, 1]]]))

x = x.view([1, 2, 3, 3])
print("x输入", x)

这里输入x,并改成[batch_size, channel , h, w]的格式。

soft = nn.Softmax(dim=1)

log_soft = nn.LogSoftmax(dim=1)

然后使用softmax函数计算每个类别的概率,这里dim=1表示从在1维度

上计算,也就是channel维度。logsoftmax是计算完softmax后在计算log值

手动计算举个栗子:第一个元素

y = Variable(torch.LongTensor([[1, 0, 1],
       [0, 0, 1],
       [1, 1, 1]]))

y = y.view([1, 3, 3])

输入label y,改变成[batch_size, h, w]格式

loss = nn.NLLLoss2d()
out = loss(x, y)
print(out)

输入函数,得到loss=0.7947

来手动计算

第一个label=1,则 loss=-1.3133

第二个label=0, 则loss=-0.3133

.
…
…
loss= -(-1.3133-0.3133-0.1269-0.6931-1.3133-0.6931-0.6931-1.3133-0.6931)/9 =0.7947222222222223

是一致的

注意:这个函数会对每个像素做平均,每个batch也会做平均,这里有9个像素,1个batch_size。

补充知识:PyTorch:NLLLoss2d

我就废话不多说了,大家还是直接看代码吧~

import torch
import torch.nn as nn
from torch import autograd
import torch.nn.functional as F
 
inputs_tensor = torch.FloatTensor([
[[2, 4],
 [1, 2]],
[[5, 3],
 [3, 0]],
[[5, 3],
 [5, 2]],
[[4, 2],
 [3, 2]],
 ])
inputs_tensor = torch.unsqueeze(inputs_tensor,0)
# inputs_tensor = torch.unsqueeze(inputs_tensor,1)
print '--input size(nBatch x nClasses x height x width): ', inputs_tensor.shape
 
targets_tensor = torch.LongTensor([
 [0, 2],
 [2, 3]
])
 
targets_tensor = torch.unsqueeze(targets_tensor,0)
print '--target size(nBatch x height x width): ', targets_tensor.shape
 
inputs_variable = autograd.Variable(inputs_tensor, requires_grad=True)
inputs_variable = F.log_softmax(inputs_variable)
targets_variable = autograd.Variable(targets_tensor)
 
loss = nn.NLLLoss2d()
output = loss(inputs_variable, targets_variable)
print '--NLLLoss2d: {}'.format(output)

以上这篇Pytorch损失函数nn.NLLLoss2d()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • flask应用部署到服务器的方法

    flask应用部署到服务器的方法

    这篇文章主要介绍了flask应用部署到服务器的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • Python日志无延迟实时写入的示例

    Python日志无延迟实时写入的示例

    今天小编就为大家分享一篇Python日志无延迟实时写入的示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • Python3实现发送QQ邮件功能(附件)

    Python3实现发送QQ邮件功能(附件)

    这篇文章主要为大家详细介绍了Python3实现发送QQ邮件功能,附件方面,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-12-12
  • Python中的lambda和apply用法及说明

    Python中的lambda和apply用法及说明

    这篇文章主要介绍了Python中的lambda和apply用法及说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-12-12
  • python+selenium定时爬取丁香园的新型冠状病毒数据并制作出类似的地图(部署到云服务器)

    python+selenium定时爬取丁香园的新型冠状病毒数据并制作出类似的地图(部署到云服务器)

    这篇文章主要介绍了python+selenium定时爬取丁香园的新冠病毒每天的数据并制作出类似的地图(部署到云服务器),本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-02-02
  • Python selenium爬取微博数据代码实例

    Python selenium爬取微博数据代码实例

    这篇文章主要介绍了Python selenium爬取微博数据代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05
  • Django框架设置cookies与获取cookies操作详解

    Django框架设置cookies与获取cookies操作详解

    这篇文章主要介绍了Django框架设置cookies与获取cookies操作,结合实例形式详细分析了Django框架针对cookie操作的各种常见技巧与操作注意事项,需要的朋友可以参考下
    2019-05-05
  • Python动态声明变量赋值代码实例

    Python动态声明变量赋值代码实例

    这篇文章主要介绍了Python动态声明变量赋值代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-12-12
  • Django ForeignKey与数据库的FOREIGN KEY约束详解

    Django ForeignKey与数据库的FOREIGN KEY约束详解

    这篇文章主要介绍了Django ForeignKey与数据库的FOREIGN KEY约束详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • Python中JSON转换的全面指南与最佳实践

    Python中JSON转换的全面指南与最佳实践

    JSON是现代应用程序中最流行的数据交换格式之一,Python通过内置的json模块提供了强大的JSON处理能力,本文将深入探讨Python中的JSON转换,包括基本用法、高级特性以及最佳实践,需要的朋友可以参考下
    2025-03-03

最新评论