pytorch如何实现多个矩阵拼接

 更新时间:2023年09月09日 08:59:59   作者:Arxan_hjw  
这篇文章主要介绍了pytorch如何实现多个矩阵拼接问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

pytorch多个矩阵拼接

问题描述

在处理数据的时候遇到一个for循环中生成多个【max_len*max_len】的二维矩阵,现需要将这些矩阵在第一维上进行堆叠,形成一个新的【batch * max_len * max_len】三维矩阵

实现过程

a = torch.ones(3, 3)  # 假设生成的矩阵形状为3*3
c = []  # 定义一个空列表用于存储矩阵
for i in range(3):
    a = a
    c.append(a.unsqueeze(0))
# 使用cat方法可之间实现该操作
c = torch.cat(c, dim=0)  
print(c.size())

输出c的形状:

torch.Size([3, 3, 3])

pytorch中torch.cat()矩阵拼接的用法

深度学习模型里的输出的东西还是有点搞的。torch.cat()的用处还是蛮大的。

下面直接举例子理解。

一维拼接

import torch
a = torch.Tensor([1, 2, 3])
b = a * 2
c = torch.cat((a, b), dim=0)  # dim=-1为取最后一维。这里只有一维-1和0是一样的
print(a.shape)
print(c.shape)
print(c)

二维拼接

dim就是选择哪一维进行拼接,dim=-1就表示最后一维进行拼接,这个也很好理解,索引-1一般都指最后一个字符

a = torch.Tensor([[1, 2]])
b = a * 2
c1 = torch.cat((a, b), dim=0)
c2 = torch.cat((a, b), dim=1)  # 这里第二维是最后一维,dim=-1和dim=1是一样的
print("a:", a)
print("a.shape:", a.shape)
print("c1:", c1)
print("c1.shape:", c1.shape)
print("c2:", c2)
print("c2.shape:", c2.shape)

当你使用pytorch深度学习模型时,隐藏层不止一层,最好将所有的隐藏层都利用起来,那么就需要进行隐藏层的拼接了。

假设隐藏层h_n.shape为(2,3,4)表示有2个隐藏层,batch_size为3(3个样本一起训练),隐藏层大小为4。由于隐藏层都包含了一定的信息,那么我们都利用起来应该效果比较好(听学长说很多论文都证明过了),那么每个样本对应的隐藏层应该都拼接起来用即2*4的大小。这样就需要用到拼接了。

h_n = torch.randn(2, 3, 4)  # 假设隐藏层
# 下面三种写法是一个意思
feature_map = torch.cat([h_n[i] for i in range(h_n.shape[0])], dim=-1)  # 索引第i个整元素,元素里剩下的维度缺省是全取的意思
feature_map1 = torch.cat([h_n[i, :, :] for i in range(h_n.shape[0])], dim=-1)
feature_map2 = torch.cat([h_n[i] for i in range(h_n.shape[0])], dim=1)
print(feature_map.shape)
print(feature_map1.shape)
print(feature_map2.shape)

隐藏层拼接完之后就可以放进全连接层然后出结果了。

由于LSTM的现在时刻的输出是前一个时刻的隐藏层和现在时刻的输入经过softmax得到的,而现在时刻的隐藏层是 现在时刻的输出*tanh(现在时刻的细胞状态)得到的,现在时刻的隐藏层也是包含了现在输入的信息的,因此直接放入全连接然后出结果就好了,至于模型的输出可以不用,直接用隐藏层也是可以的吧。或者说隐藏层就相当于包含着各自特征信息,输出层也是基于隐藏层来的,因此我们深度学习模型里直接用隐藏层就是在直接用那些特征吧(强行理解一波)

用模型的输出或者模型隐藏层应该都是可以得出结果的,目前对我来说,效果应该都差不多。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python使用functools实现注解同步方法

    Python使用functools实现注解同步方法

    这篇文章主要介绍了Python使用functools实现注解同步方法,非常不错,具有参考借鉴价值,需要的朋友可以参考下
    2018-02-02
  • Python操作列表的常用方法分享

    Python操作列表的常用方法分享

    这篇文章主要介绍了Python操作列表的常用方法,需要的朋友可以参考下
    2014-02-02
  • Python实现拉格朗日插值法的示例详解

    Python实现拉格朗日插值法的示例详解

    插值法是一种数学方法,用于在已知数据点(离散数据)之间插入数据,以生成连续的函数曲线,而格朗日插值法是一种多项式插值法。本文就来用Python实现拉格朗日插值法,希望对大家有所帮助
    2023-02-02
  • 使用python tkinter实现各种个样的撩妹鼠标拖尾效果

    使用python tkinter实现各种个样的撩妹鼠标拖尾效果

    这篇文章主要介绍了使用python tkinter实现各种个样的撩妹鼠标拖尾效果,本文通过实例代码,给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-09-09
  • Ubuntu下Anaconda和Pycharm配置方法详解

    Ubuntu下Anaconda和Pycharm配置方法详解

    这篇文章主要为大家详细介绍了Ubuntu下Anaconda和Pycharm配置方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-06-06
  • Python易忽视知识点小结

    Python易忽视知识点小结

    这篇文章主要介绍了Python易忽视知识点,实例分析了Python中容易被忽视的常见操作技巧,需要的朋友可以参考下
    2015-05-05
  • python可迭代类型遍历过程中数据改变会不会报错

    python可迭代类型遍历过程中数据改变会不会报错

    这篇文章主要介绍了python可迭代类型遍历过程中数据改变会不会报错问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-12-12
  • python决策树预测学生成绩等级实现详情

    python决策树预测学生成绩等级实现详情

    这篇文章主要为介绍了python决策树预测学生成绩等级,使用决策树完成学生成绩等级预测,可选取部分或全部特征,分析参数对结果的影响,并进行调参优化,决策树可视化进行调参优化分析
    2022-04-04
  • Python编写一个多线程的12306抢票程序的示例

    Python编写一个多线程的12306抢票程序的示例

    对于很多人来说,抢购火车票人们成了一个令人头疼的问题,本文主要介绍了Python编写一个多线程的12306抢票程序的示例,具有一定的参考价值,感兴趣的可以了解一下
    2023-09-09
  • Python字典操作得力助手Get()函数的使用

    Python字典操作得力助手Get()函数的使用

    在Python编程中,get()函数是字典(Dictionary)对象中非常有用的函数,本文将详细介绍get()函数的用法及示例代码,感兴趣的可以了解一下
    2023-11-11

最新评论