Pytorch中关于RNN输入和输出的形状总结

 更新时间:2023年06月15日 08:35:29   作者:会唱歌的猪233  
这篇文章主要介绍了Pytorch中关于RNN输入和输出的形状总结,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

Pytorch对RNN输入和输出的形状总结

个人对于RNN的一些总结。

RNN的输入和输出

RNN的经典图如下所示


各个参数的含义

  • Xt: t时刻的输入,形状为[batch_size, input_dim]。对于整个RNN来说,总的X输入为[seq_len, batch_size, input_dim],具体如何理解batch_size和seq_len在下面有说明。
  • St: t时刻隐藏层的状态,也有时用ht表示,形状为[batch_size, hidden_size],St=f(U·Xt+W·St-1),通过W和U矩阵的映射,将embedding后的Xt和上一状态St-1转为St
  • Ot: t时刻的输出,Ot=g(V·St),形状为[batch_size, hidden_size],总的为输出O为[seq_len, batch_size, hidden_size]

Pytorch中的使用

Pytorch中RNN函数如下

RNN的主要参数如下

nn.RNN(input_size, hidden_size, num_layers=1, bias=True)

参数解释

  • input_size: 输入特征的维度,一般rnn中输入的是词向量,那么就为embedding-dim
  • hidden_size: 隐藏层神经元的个数,或者也叫输出的维度
  • num_layers: 隐藏层的个数,默认为1

output=输出O, 隐藏状态St,其中输出O=[time_step, batch_size, hidden_size],St为t时刻的隐藏层状态

理解RNN中的batch_size和seq_len

深度学习中采用mini-batch的方法进行迭代优化,在CNN中batch的思想较容易理解,一次输入batch个图片,进行迭代。但是RNN中引入了seq_len(time_step), 理解较为困难,下面是我自己的一些理解。

首先假如我有五句话,作为训练的语料。

sentences = ["i like dog", "i love coffee", "i hate milk", "i like music", "i hate you"]

那么在输入RNN之前要先进行embedding,比如one-hot encoding,容易得到这里的embedding-dim为9.

那么输入的sentences可以表示为如下方式

t=0t=1t=2
batch1ilikedog
batch2ilovecoffee
batch3ihatemilk
batch4ilikemusic
batch5ihateyou

那么在RNN的训练中。

  • t=0时, 输入第一个batch[i, i, i, i, i]这里用字符表示,其实应该是对应的one-hot编码。
  • t=1时,输入第二个batch[like, love, hate, like, hate]
  • t=2时,输入第三个batch[dog, coffee, milk, music, you]

那么对应的时间t来说,RNN需要对先后输入的batch_size个字符进行前向计算迭代,得到输出。

Pytorch双向RNN隐藏层和输出层结果拆分

1 RNN隐藏层和输出层结果的形状

从Pytorch官方文档可以得到,对于批量化输入的RNN来讲,其隐藏层的shape为(num_directions*num_layers, batch_size, hidden_size)。

其输出的shape为(seq_len, batch_size, D*hidden_size)。

2 双向RNN情况下,隐藏层和输出层结果拆分

当采用双向RNN时,其输出的结果包含正向和反向两个方向输出的结果。

2.1 输出层结果拆分

其中对于输出output来讲,从官方文档我们可以得到,其拆分正向和反向两个方向结果的方法为:

output.shape = (seq_len, batch_size, num_directions*hidden_size)

output.view(seq_len, batch, num_directions, hidden_size)

其中,对于(num_directions)方向维度,正向和反向的维度值分别为​​0​​​和​​1​。

2.2 隐藏层结果拆分

而对于隐藏层,包括初始值h_0以及最终输出h_n,也都包含两个方向的隐藏状态,但是其拆分方式跟输出层不一样。

方法如下:

h_0, h_n.shape = (num_directions*num_layers, batch_size, hidden_size)

h_0, h_n.view(num_layers, num_directions, batch_size, hidden_size)

可以从简单单层双向RNN的输出结果来验证,此时RNN的输出结果与最后一层的隐藏层结果是一样的。

import torch
import torch.nn as nn
if __name__ == "__main__":
    # input_size: 3, hidden_size: 5, num_layers: 3
    BiRNN_Net = nn.RNN(3, 5, 3, bidirectional=True, batch_first=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # batch_size: 1, seq_len: 1, input_size: 3
    inputs = torch.zeros(1, 1, 3, device=device)
    # state: (num_directions*num_layers, batch_size, hidden_size)
    state = torch.randn(6, 1, 5, device=device)
    BiRNN_Net.to(device)
    output, hidden = BiRNN_Net(inputs, state)
    output_re = output.reshape((1, 1, 2, 5))
    hidden_re = hidden.reshape((3, 2, 1, 5))
    print(output)
    print(output_re)
    print(hidden)
    print(hidden_re)

输出结果可以看出,隐藏层的结果是优先num_layers网络层数这一个维度来构成的。

tensor([[[ 0.3939, -0.9160,  0.5054,  0.2949, -0.5225,  0.0533,  0.4197,
          -0.7200, -0.1262, -0.7975]]], device='cuda:0',
       grad_fn=<CudnnRnnBackward0>)
tensor([[[[ 0.3939, -0.9160,  0.5054,  0.2949, -0.5225],
          [ 0.0533,  0.4197, -0.7200, -0.1262, -0.7975]]]], device='cuda:0',
       grad_fn=<ReshapeAliasBackward0>)
tensor([[[-0.2606,  0.5410, -0.2663,  0.6418, -0.2902]],
        [[ 0.1367,  0.7222, -0.3051, -0.6410, -0.3062]],
        [[ 0.2433,  0.3287, -0.4809, -0.1782, -0.5582]],
        [[ 0.4824, -0.8529,  0.7604,  0.8508, -0.1902]],
        [[ 0.3939, -0.9160,  0.5054,  0.2949, -0.5225]],
        [[ 0.0533,  0.4197, -0.7200, -0.1262, -0.7975]]], device='cuda:0',
       grad_fn=<CudnnRnnBackward0>)
tensor([[[[-0.2606,  0.5410, -0.2663,  0.6418, -0.2902]],
         [[ 0.1367,  0.7222, -0.3051, -0.6410, -0.3062]]],
        [[[ 0.2433,  0.3287, -0.4809, -0.1782, -0.5582]],
         [[ 0.4824, -0.8529,  0.7604,  0.8508, -0.1902]]],
        [[[ 0.3939, -0.9160,  0.5054,  0.2949, -0.5225]],
         [[ 0.0533,  0.4197, -0.7200, -0.1262, -0.7975]]]], device='cuda:0',
       grad_fn=<ReshapeAliasBackward0>)

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python实现远程调用MetaSploit的方法

    Python实现远程调用MetaSploit的方法

    这篇文章主要介绍了Python实现远程调用MetaSploit的方法,是很有借鉴价值的一个技巧,需要的朋友可以参考下
    2014-08-08
  • 浅谈Python3中strip()、lstrip()、rstrip()用法详解

    浅谈Python3中strip()、lstrip()、rstrip()用法详解

    这篇文章主要介绍了浅谈Python3中strip()、lstrip()、rstrip()用法详解,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-04-04
  • 利用OpenCV给彩色图像添加椒盐噪声的方法

    利用OpenCV给彩色图像添加椒盐噪声的方法

    椒盐噪声是数字图像中的常见噪声,一般是图像传感器、传输信道及解码处理等产生的黑白相间的亮暗点噪声,椒盐噪声常由图像切割产生,这篇文章主要给大家介绍了关于利用OpenCV给彩色图像添加椒盐噪声的相关资料,需要的朋友可以参考下
    2021-10-10
  • Python中的index()方法使用教程

    Python中的index()方法使用教程

    这篇文章主要介绍了Python中的index()方法使用教程,是Python入门学习中的基础知识,需要的朋友可以参考下
    2015-05-05
  • python绘制柱形图的方法

    python绘制柱形图的方法

    这篇文章主要为大家详细介绍了python绘制柱形图的方法,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-04-04
  • Python调用DeepSeek API的案例详细教程

    Python调用DeepSeek API的案例详细教程

    这篇文章主要为大家详细介绍了以 Python 为例的调用 DeepSeek API 的小白入门级详细教程,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下
    2025-02-02
  • python爬虫如何解决图片验证码

    python爬虫如何解决图片验证码

    这篇文章主要介绍了python爬虫如何解决图片验证码,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2021-02-02
  • Python 数据筛选功能实现

    Python 数据筛选功能实现

    这篇文章主要介绍了Python 数据筛选,无论是在数据分析还是数据挖掘的时候,数据筛选总会涉及到,这里我总结了一下python中列表,字典,数据框中一些常用的数据筛选的方法,需要的朋友可以参考下
    2023-04-04
  • python 五子棋如何获得鼠标点击坐标

    python 五子棋如何获得鼠标点击坐标

    这篇文章主要介绍了python 五子棋如何获得鼠标点击坐标,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-11-11
  • python实现扫描日志关键字的示例

    python实现扫描日志关键字的示例

    下面小编就为大家分享一篇python实现扫描日志关键字的示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04

最新评论