pytorch 实现cross entropy损失函数计算方式

 更新时间:2020年01月02日 14:29:20   作者:HawardScut  
今天小编就为大家分享一篇pytorch 实现cross entropy损失函数计算方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

均方损失函数:

这里 loss, x, y 的维度是一样的,可以是向量或者矩阵,i 是下标。

很多的 loss 函数都有 size_average 和 reduce 两个布尔类型的参数。因为一般损失函数都是直接计算 batch 的数据,因此返回的 loss 结果都是维度为 (batch_size, ) 的向量。

(1)如果 reduce = False,那么 size_average 参数失效,直接返回向量形式的 loss

(2)如果 reduce = True,那么 loss 返回的是标量

a)如果 size_average = True,返回 loss.mean();
b)如果 size_average = False,返回 loss.sum();

注意:默认情况下, reduce = True,size_average = True

import torch
import numpy as np

1、返回向量

loss_fn = torch.nn.MSELoss(reduce=False, size_average=False)


a=np.array([[1,2],[3,4]])
b=np.array([[2,3],[4,5]])
input = torch.autograd.Variable(torch.from_numpy(a))
target = torch.autograd.Variable(torch.from_numpy(b))

这里将Variable类型统一为float()(tensor类型也是调用xxx.float())

loss = loss_fn(input.float(), target.float())
print(loss)
tensor([[ 1., 1.],
  [ 1., 1.]])

2、返回平均值

a=np.array([[1,2],[3,4]])
b=np.array([[2,3],[4,4]])
loss_fn = torch.nn.MSELoss(reduce=True, size_average=True)
input = torch.autograd.Variable(torch.from_numpy(a))
target = torch.autograd.Variable(torch.from_numpy(b))
loss = loss_fn(input.float(), target.float())
 print(loss)
tensor(0.7500)

以上这篇pytorch 实现cross entropy损失函数计算方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python爬虫基本知识

    python爬虫基本知识

    最近在做一个项目,这个项目需要使用网络爬虫从特定网站上爬取数据,于是乎,我打算写一个爬虫系列的文章,与大家分享如何编写一个爬虫。下面这篇文章给大家介绍了python爬虫基本知识,感兴趣的朋友一起看看吧
    2018-03-03
  • Python Tkinter GUI编程实现Frame切换

    Python Tkinter GUI编程实现Frame切换

    本文主要介绍了Python Tkinter GUI编程实现Frame切换,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-04-04
  • Python中的引用和拷贝浅析

    Python中的引用和拷贝浅析

    这篇文章主要介绍了Python中的引用和拷贝浅析,本文同时讲解了深拷贝和浅拷贝、引用计数和垃圾回收等内容,需要的朋友可以参考下
    2014-11-11
  • python字符串格式化(%格式符和format方式)

    python字符串格式化(%格式符和format方式)

    在编写程序的过程中,经常需要进行格式化输出,每次用每次查,干脆就在这里整理一下,下面这篇文章主要给大家介绍了关于python字符串格式化的相关资料,分别是%格式符和format方式,需要的朋友可以参考下
    2022-02-02
  • Python解决C盘卡顿问题及操作脚本示例

    Python解决C盘卡顿问题及操作脚本示例

    这篇文章主要为大家介绍了Python解决C盘卡顿问题脚本示例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2024-01-01
  • Python使用Keras库中的LSTM模型生成新文本内容教程

    Python使用Keras库中的LSTM模型生成新文本内容教程

    Python语言使用金庸小说文本库,对文本进行预处理,然后使用Keras库中的LSTM模型创建和训练了模型,根据这个模型,我们可以生成新的文本,并探索小说的不同应用
    2024-01-01
  • 解读Opencv中Filter2D函数的补全方式

    解读Opencv中Filter2D函数的补全方式

    这篇文章主要介绍了解读Opencv中Filter2D函数的补全方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-12-12
  • Python实现的矩阵转置与矩阵相乘运算示例

    Python实现的矩阵转置与矩阵相乘运算示例

    这篇文章主要介绍了Python实现的矩阵转置与矩阵相乘运算,结合实例形式分析了Python针对矩阵进行转置与相乘运算的相关实现技巧与操作注意事项,需要的朋友可以参考下
    2019-03-03
  • python GUI实例学习

    python GUI实例学习

    给大家介绍一下python GUI实例学习的心得以及实现的方式,希望能帮助到你。
    2017-11-11
  • Python实现按特定格式对文件进行读写的方法示例

    Python实现按特定格式对文件进行读写的方法示例

    这篇文章主要介绍了Python实现按特定格式对文件进行读写的方法,可实现文件按原有格式读取与写入的功能,涉及文件的读取、遍历、转换、写入等相关操作技巧,需要的朋友可以参考下
    2017-11-11

最新评论