PyTorch中常见损失函数的使用详解

 更新时间:2023年06月18日 11:34:38   作者:夏天是冰红茶  
损失函数,又叫目标函数,是指计算机标签值和预测值直接差异的函数,本文为大家整理了PyTorch中常见损失函数的简单解释和使用,希望对大家有所帮助

损失函数

损失函数,又叫目标函数。在编译神经网络模型必须的两个参数之一。另一个必不可少的就是优化器,我将在后面详解到。

重点

损失函数是指计算机标签值和预测值直接差异的函数。

这里我们会结束几种常见的损失函数的计算方法,pytorch中也是以及定义了很多类型的预定义函数,具体的公式不需要去深究(学了也不一定remember),这里暂时能做就是了解。

我们先来定义两个二维的数组,然后用不同的损失函数计算其损失值。

import torch
from torch.autograd import Variable
import torch.nn as nn
sample=Variable(torch.ones(2,2))
a=torch.Tensor(2,2)
a[0,0]=0
a[0,1]=1
a[1,0]=2
a[1,1]=3
target=Variable(a)
print(sample,target)

这里:

sample的值为tensor([[1., 1.],[1., 1.]])

target的值为tensor([[0., 1.],[2., 3.]])

nn.L1Loss

L1Loss计算方法很简单,取预测值和真实值的绝对误差的平均数。

loss=FunLoss(sample,target)['L1Loss']
print(loss)

在控制台中打印出来是

tensor(1.)

它的计算过程是这样的:(∣0−1∣+∣1−1∣+∣2−1∣+∣3−1∣)/4=1,先计算的是绝对值求和,然后再平均。

nn.SmoothL1Loss

SmoothL1Loss的误差在(-1,1)上是平方损失,其他情况是L1损失。

loss=FunLoss(sample,target)['SmoothL1Loss']
print(loss)

在控制台中打印出来是

tensor(0.6250)

nn.MSELoss

平方损失函数。其计算公式是预测值和真实值之间的平方和的平均数。

loss=FunLoss(sample,target)['MSELoss']
print(loss)

在控制台中打印出来是

tensor(1.5000)

nn.CrossEntropyLoss

交叉熵损失公式

此公式常在图像分类神经网络模型中会常常用到。

loss=FunLoss(sample,target)['CrossEntropyLoss']
print(loss)

在控制台中打印出来是

tensor(2.0794)

nn.NLLLoss

负对数似然损失函数

需要注意的是,这里的xlabel和上面的交叉熵损失里的是不一样的,这里是经过log运算后的数值。这个损失函数一般用在图像识别的模型上。

loss=FunLoss(sample,target)['NLLLoss']
print(loss)

这里,控制台报错,需要0D或1D目标张量,不支持多目标。可能需要其他的一些条件,这里我们如果遇到了再说。

损失函数模块化设计

class FunLoss():
    def __init__(self, sample, target):
        self.sample = sample
        self.target = target
        self.loss = {
            'L1Loss': nn.L1Loss(),
            'SmoothL1Loss': nn.SmoothL1Loss(),
            'MSELoss': nn.MSELoss(),
            'CrossEntropyLoss': nn.CrossEntropyLoss(),
            'NLLLoss': nn.NLLLoss()
        }
    def __getitem__(self, loss_type):
        if loss_type in self.loss:
            loss_func = self.loss[loss_type]
            return loss_func(self.sample, self.target)
        else:
            raise KeyError(f"Invalid loss type '{loss_type}'")
if __name__=="__main__":
    loss=FunLoss(sample,target)['NLLLoss']
    print(loss)

总结

这篇博客适合那些希望了解在PyTorch中常见损失函数的读者。通过FunLoss我们自己也能简单的去调用。

到此这篇关于PyTorch中常见损失函数的使用详解的文章就介绍到这了,更多相关PyTorch损失函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python中正则表达式的用法总结

    Python中正则表达式的用法总结

    今天小编就为大家分享一篇关于Python中正则表达式的用法总结,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2019-02-02
  • pytorch 转换矩阵的维数位置方法

    pytorch 转换矩阵的维数位置方法

    今天小编就为大家分享一篇pytorch 转换矩阵的维数位置方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • python spilt()分隔字符串的实现示例

    python spilt()分隔字符串的实现示例

    split() 方法可以实现将一个字符串按照指定的分隔符切分成多个子串,本文介绍了spilt的具体使用,感兴趣的可以了解一下
    2021-05-05
  • Python循环结构详解

    Python循环结构详解

    这篇文章主要介绍了Python循环结构详解,文中有非常详细的代码示例,对正在学习python的小伙伴们有很好的帮助,需要的朋友可以参考下
    2021-05-05
  • Python数字比较与类结构

    Python数字比较与类结构

    这篇文章主要介绍了Python数字比较与类结构,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-07-07
  • Python issubclass和isinstance函数的具体使用

    Python issubclass和isinstance函数的具体使用

    本文主要介绍了Python issubclass和isinstance函数的具体使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • Python使用Selenium模块实现模拟浏览器抓取淘宝商品美食信息功能示例

    Python使用Selenium模块实现模拟浏览器抓取淘宝商品美食信息功能示例

    这篇文章主要介绍了Python使用Selenium模块实现模拟浏览器抓取淘宝商品美食信息功能,涉及Python基于re模块的正则匹配及selenium模块的页面抓取等相关操作技巧,需要的朋友可以参考下
    2018-07-07
  • Django展示可视化图表的多种方式

    Django展示可视化图表的多种方式

    这篇文章主要介绍了Django展示可视化图表的多种方式,帮助大家更好的理解和学习使用django框架,感兴趣的朋友可以了解下
    2021-04-04
  • python开启多个子进程并行运行的方法

    python开启多个子进程并行运行的方法

    这篇文章主要介绍了python开启多个子进程并行运行的方法,涉及Python进程操作的相关技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-04-04
  • 通过conda把已有虚拟环境的python版本进行降级操作指南

    通过conda把已有虚拟环境的python版本进行降级操作指南

    当使用conda创建虚拟环境时,有时候可能会遇到python版本不对的问题,下面这篇文章主要给大家介绍了关于如何通过conda把已有虚拟环境的python版本进行降级操作的相关资料,需要的朋友可以参考下
    2024-05-05

最新评论