Pytorch linear 多维输入的参数问题

 更新时间:2022年08月19日 15:07:31   作者:又是花落时  
这篇文章主要介绍了Pytorch linear多维输入的参数的问题,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

问题: 由于 在输入lstm 层 每个batch 做了根据输入序列最大长度做了padding,导致每个 batch 的 length 不同。 导致输出 长度不同 。如:(batch, length, output_dim): (12,128,10),(12,111,10). 但是输入 linear 层的时候没有出现问题。

网站解释:

官网 pytorch linear:

  • Input:(*, H_{in})(∗,Hin​)where*∗means any number of dimensions including none andH_{in} = \text{in\_features}Hin​=in_features. 任意维度 number 理解有歧义 (a)number. k可以理解三维,四维。。。 (b) 可以理解 为某一维度的数 。
  • Output:(*, H_{out})(∗,Hout​)where all but the last dimension are the same shape as the input andH_{out} = \text{out\_features}Hout​=out_features.

代码解释:

分别 用三维 和二维输入数组,查看他们参数数目是否一样。

import torch
 
x = torch.randn(128, 20)  # 输入的维度是(128,20)
m = torch.nn.Linear(20, 30)  # 20,30是指维度
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)
 
# ans = torch.mm(input,torch.t(m.weight))+m.bias 等价于下面的
ans = torch.mm(x, m.weight.t()) + m.bias   
print('ans.shape:\n', ans.shape)
 
print(torch.equal(ans, output))

output:

m.weight.shape:
  torch.Size([30, 20])
m.bias.shape:
 torch.Size([30])
output.shape:
 torch.Size([128, 30])
ans.shape:
 torch.Size([128, 30])
True
x = torch.randn(128, 30,20)  # 输入的维度是(128,30,20)
m = torch.nn.Linear(20, 30)  # 20,30是指维度
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)
ouput:
m.weight.shape:
  torch.Size([30, 20])
m.bias.shape:
 torch.Size([30])
output.shape:
 torch.Size([128, 30, 30])

结果:

(128,30,20),和 (128,20) 分别是如 nn.linear(30,20) 层。

weight.shape 均为: (30,20)

linear() 参数数目只和 input_dim ,output_dim 有关。

weight 在源码的定义, 没找到如何计算多维input的代码。

到此这篇关于Pytorch linear 多维 输入的参数的文章就介绍到这了,更多相关Pytorch多维 输入内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python实现快速排序的示例(二分法思想)

    python实现快速排序的示例(二分法思想)

    本篇文章主要介绍了python实现快速排序的示例(二分法思想),小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-03-03
  • python中PIL安装简单教程

    python中PIL安装简单教程

    这篇文章主要为大家分享了python中PIL安装简单教程,感兴趣的小伙伴们可以参考一下
    2016-04-04
  • 解决Python 异常TypeError: cannot concatenate ''str'' and ''int'' objects

    解决Python 异常TypeError: cannot concatenate ''str'' and ''int''

    这篇文章主要介绍了解决Python 异常TypeError: cannot concatenate 'str' and 'int' objects,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • python方向键控制上下左右代码

    python方向键控制上下左右代码

    这篇文章主要介绍了python方向键控制上下左右代码,小编觉得还是挺不错的,具有一定借鉴价值,需要的朋友可以参考下
    2018-01-01
  • 推荐一款高效的python数据框处理工具Sidetable

    推荐一款高效的python数据框处理工具Sidetable

    这篇文章主要为大家介绍推荐一款高效的python数据框处理工具Sidetable,文章详细的讲解了Sidetable的安装及用法,有需要的朋友可以借鉴参考下,希望能够有所帮助
    2021-11-11
  • Python实现单项链表的最全教程

    Python实现单项链表的最全教程

    单向链表也叫单链表,是链表中最简单的一种形式,它的每个节点包含两个域,一个信息域(元素域)和一个链接域,这个链接指向链表中的下一个节点,而最后一个节点的链接域则指向一个空值,这篇文章主要介绍了Python实现单项链表,需要的朋友可以参考下
    2023-01-01
  • python制作的天气预报小工具(gui界面)

    python制作的天气预报小工具(gui界面)

    大家好啊!我用Tkinter写了一个天气预报小工具,支持34个省级行政区以及港澳台地区天气,覆盖全面。程序打包好放在了蓝奏云,与大家分享一下。
    2021-05-05
  • Python 3.11.0下载安装并使用help查看模块信息的方法

    Python 3.11.0下载安装并使用help查看模块信息的方法

    本文给大家介绍Python 3.11.0下载安装并使用help查看模块信息的相关知识,首先给大家讲解了Python 3.11.0下载及安装紧接着介绍了在命令行使用help查看模块信息的方法,感兴趣的朋友跟随小编一起看看吧
    2022-11-11
  • python的staticmethod与classmethod实现实例代码

    python的staticmethod与classmethod实现实例代码

    这篇文章主要介绍了python的staticmethod与classmethod实现实例代码,分享了相关代码示例,小编觉得还是挺不错的,具有一定借鉴价值,需要的朋友可以参考下
    2018-02-02
  • Python实现统计文本文件字数的方法

    Python实现统计文本文件字数的方法

    这篇文章主要介绍了Python实现统计文本文件字数的方法,涉及Python针对文本文件读取及字符串转换、运算等相关操作技巧,需要的朋友可以参考下
    2017-05-05

最新评论