解决pytorch中的kl divergence计算问题

 更新时间:2021年05月24日 09:20:47   投稿:jingxian  
这篇文章主要介绍了解决pytorch中的kl divergence计算问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

偶然从pytorch讨论论坛中看到的一个问题,KL divergence different results from tf,kl divergence 在TensorFlow中和pytorch中计算结果不同,平时没有注意到,记录下

一篇关于KL散度、JS散度以及交叉熵对比的文章

kl divergence 介绍

KL散度( Kullback–Leibler divergence),又称相对熵,是描述两个概率分布 P 和 Q 差异的一种方法。计算公式:

可以发现,P 和 Q 中元素的个数不用相等,只需要两个分布中的离散元素一致。

举个简单例子:

两个离散分布分布分别为 P 和 Q

P 的分布为:{1,1,2,2,3}

Q 的分布为:{1,1,1,1,1,2,3,3,3,3}

我们发现,虽然两个分布中元素个数不相同,P 的元素个数为 5,Q 的元素个数为 10。但里面的元素都有 “1”,“2”,“3” 这三个元素。

当 x = 1时,在 P 分布中,“1” 这个元素的个数为 2,故 P(x = 1) = 2/5 = 0.4,在 Q 分布中,“1” 这个元素的个数为 5,故 Q(x = 1) = 5/10 = 0.5

同理,

当 x = 2 时,P(x = 2) = 2/5 = 0.4 ,Q(x = 2) = 1/10 = 0.1

当 x = 3 时,P(x = 3) = 1/5 = 0.2 ,Q(x = 3) = 4/10 = 0.4

把上述概率带入公式:

至此,就计算完成了两个离散变量分布的KL散度。

pytorch 中的 kl_div 函数

pytorch中有用于计算kl散度的函数 kl_div

torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')

在这里插入图片描述

计算 D (p||q)

1、不用这个函数的计算结果为:

在这里插入图片描述

与手算结果相同

2、使用函数:

(这是计算正确的,结果有差异是因为pytorch这个函数中默认的是以e为底)

在这里插入图片描述

注意:

1、函数中的 p q 位置相反(也就是想要计算D(p||q),要写成kl_div(q.log(),p)的形式),而且q要先取 log

2、reduction 是选择对各部分结果做什么操作,默认为取平均数,这里选择求和

好别扭的用法,不知道为啥官方把它设计成这样

补充:pytorch 的KL divergence的实现

看代码吧~

import torch.nn.functional as F
# p_logit: [batch, class_num]
# q_logit: [batch, class_num]
def kl_categorical(p_logit, q_logit):
    p = F.softmax(p_logit, dim=-1)
    _kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1)
                                  - F.log_softmax(q_logit, dim=-1)), 1)
    return torch.mean(_kl)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python实现 PS 图像调整中的亮度调整

    Python实现 PS 图像调整中的亮度调整

    这篇文章主要介绍了Python实现 PS 图像调整中的亮度调整 ,需要的朋友可以参考下
    2019-06-06
  • 教你用pytorch训练五子棋ai示例代码

    教你用pytorch训练五子棋ai示例代码

    这篇文章主要介绍了五个与五子棋相关的Python文件,包括游戏逻辑、神经网络模型、训练代码以及玩家对战代码,文中通过代码介绍的非常详细,需要的朋友可以参考下
    2025-03-03
  • Python使用MoviePy轻松搞定视频编辑

    Python使用MoviePy轻松搞定视频编辑

    MoviePy 是一个使用 Python 编写的开源库,用于在视频编辑中创建、编辑和操作视频文件。本文就来教一下大家如何使用MoviePy轻松搞定视频编辑,需要的可以了解一下
    2023-05-05
  • 详解python路径拼接os.path.join()函数的用法

    详解python路径拼接os.path.join()函数的用法

    os.path.join()函数:连接两个或更多的路径名组件。这篇文章主要介绍了python路径拼接os.path.join()函数的用法,需要的朋友可以参考下
    2019-10-10
  • 如何使用pyinstaller打包时引入自己编写的库

    如何使用pyinstaller打包时引入自己编写的库

    这篇文章主要介绍了如何使用pyinstaller打包时引入自己编写的库,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-02-02
  • python如何调用php文件中的函数详解

    python如何调用php文件中的函数详解

    这篇文章主要给大家介绍了关于python如何调用php文件中函数的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-12-12
  • 解析pandas apply() 函数用法(推荐)

    解析pandas apply() 函数用法(推荐)

    这篇文章主要介绍了pandas apply() 函数用法,大家需要掌握函数作为一个对象,能作为参数传递给其它函数,也能作为函数的返回值,具体内容详情跟随小编一起看看吧
    2021-10-10
  • Python解析器安装指南分享(Mac/Windows/Linux)

    Python解析器安装指南分享(Mac/Windows/Linux)

    这篇文章主要介绍了Python解析器安装指南(Mac/Windows/Linux),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2025-03-03
  • python多进程及通信实现异步任务的方法

    python多进程及通信实现异步任务的方法

    这篇文章主要介绍了python多进程及通信实现异步任务需求,本人也是很少接触多进程的场景,对于python多进程的使用也是比较陌生的。在接触了一些多进程的业务场景下,对python多进程的使用进行了学习,觉得很有必要进行一个梳理总结,感兴趣的朋友一起看看吧
    2022-05-05
  • Python常见沙箱技术与沙箱逃逸避免方法详解

    Python常见沙箱技术与沙箱逃逸避免方法详解

    Python沙箱可以帮助你在安全的环境中运行不受信任的代码,本文将探讨 Python 沙箱的概念、常见的沙箱技术以及如何避免沙箱逃逸,感兴趣的可以了解下
    2024-01-01

最新评论