pytorch中使用LSTM详解

 更新时间:2022年07月27日 09:01:59   作者:qyhyzard  
这篇文章主要介绍了pytorch中使用LSTM,可以在troch.nn模块中找到LSTM类,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的朋友可以参考一下

LSMT层

可以在troch.nn模块中找到LSTM类

lstm = torch.nn.LSTM(*paramsters)

1、__init__方法

首先对nn.LSTM类进行实例化,需要传入的参数如下图所示:

一般我们关注这4个:

  • input_size表示输入的每个token的维度,也可以理解为一个word的embedding的维度。
  • hidden_size表示隐藏层也就是记忆单元C的维度,也可以理解为要将一个word的embedding维度转变成另一个大小的维度。除了C,在LSTM中输出的H的维度与C的维度是一致的。
  • num_layers表示有多少层LSTM,加深网络的深度,这个参数对LSTM的输出的维度是有影响的(后文会提到)。
  • bidirectional表示是否需要双向LSTM,这个参数也会对后面的输出有影响。

2、forward方法的输入

将数据input传入forward方法进行前向传播时有3个参数可以输入,见下图:

  • 这里要注意的是input参数各个维度的意义,一般来说如果不在实例化时制定batch_first=True,那么input的第一个维度是输入句子的长度seq_len,第二个维度是批量的大小,第三个维度是输入句子的embedding维度也就是input_size,这个参数要与__init__方法中的第一个参数对应。
  • 另外记忆细胞中的两个参数h_0c_0可以选择自己初始化传入也可以不传,系统默认是都初始化为0。传入的话注意维度[bidirectional * num_layers, batch_size, hidden_size]。

3、forward方法的输出

forward方法的输出如下图所示:

一般采用如下形式:

out,(h_n, c_n) = lstm(x)

out表示在最后一层上,每一个时间步的输出,也就是句子有多长,这个out的输出就有多长;其维度为[seq_len, batch_size, hidden_size * bidirectional]。因为如果的双向LSTM,最后一层的输出会把正向的和反向的进行拼接,故需要hidden_size * bidirectional。h_n表示的是每一层(双向算两层)在最后一个时间步上的输出;其维度为[bidirectional * num_layers, batch_size, hidden_size]
假设是双向的LSTM,且是3层LSTM,双向每个方向算一层,两个方向的组合起来叫一层LSTM,故共会有6层(3个正向,3个反向)。所以h_n是每层的输出,bidirectional * num_layers = 6。c_n表示的是每一层(双向算两层)在最后一个时间步上的记忆单元,意义不同,但是其余均与 h_n一样。

LSTMCell

可以在troch.nn模块中找到LSTMCell类

lstm = torch.nn.LSTMCell(*paramsters)

它的__init__方法的参数设置与LSTM类似,但是没有num_layers参数,因为这就是一个细胞单元,谈不上多少层和是否双向。
forward输入和输出与LSTM均有所不同:

其相比LSTM,输入没有了时间步的概念,因为只有一个Cell单元;输出 也没有out参数,因为就一个Cell,out就是h_1h_1c_1也因为只有一个Cell单元,其没有层数上的意义,故只是一个Cell的输出的维度[batch_size, hidden_size].

代码演示如下:

rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
hx = torch.randn(3, 20) # (batch, hidden_size)
cx = torch.randn(3, 20)
output = []
# 从输入的第一个维度也就是seq_len上遍历,每循环一次,输入一个单词
for i in range(input.size()[0]):
		# 更新细胞记忆单元
        hx, cx = rnn(input[i], (hx, cx))
        # 将每个word作为输入的输出存起来,相当于LSTM中的out
        output.append(hx)
output = torch.stack(output, dim=0)

到此这篇关于pytorch中使用LSTM详细解说的文章就介绍到这了,更多相关pytorch使用LSTM内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python序列类型种类详解

    python序列类型种类详解

    这篇文章主要介绍了python序列类型种类详解,需要的朋友们可以学习参考下。
    2020-02-02
  • Python 代码实现各种酷炫功能

    Python 代码实现各种酷炫功能

    这篇文章主要介绍了Python 代码实现各种酷炫功能,生成二维码、生成词云、批量抠图、文字情绪识别等功能分享,需要的小伙伴可以参考一下
    2022-03-03
  • 基于Python制作一个多进制转换工具

    基于Python制作一个多进制转换工具

    这篇文章主要介绍了如何利用Python制作一个多进制转换工具,可以实现2进制、4进制、8进制、10进制、16进制、32进制直接的互转,需要的可以参考一下
    2022-02-02
  • python爬虫爬取指定内容的解决方法

    python爬虫爬取指定内容的解决方法

    这篇文章主要介绍了python爬虫爬取指定内容,爬取一些网站下指定的内容,一般来说可以用xpath来直接从网页上来获取,但是当我们获取的内容不唯一的时候我们无法选择,我们所需要的、所指定的内容,需要的朋友可以参考下
    2022-06-06
  • 爬虫逆向抖音新版signature分析案例

    爬虫逆向抖音新版signature分析案例

    这篇文章主要为大家介绍了爬虫逆向抖音新版signature分析的案例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-02-02
  • python爬虫请求库httpx和parsel解析库的使用测评

    python爬虫请求库httpx和parsel解析库的使用测评

    这篇文章主要介绍了python爬虫请求库httpx和parsel解析库的使用测评,帮助大家更好的理解和学习使用python,感兴趣的朋友可以了解下
    2021-05-05
  • 对python 操作solr索引数据的实例详解

    对python 操作solr索引数据的实例详解

    今天小编就为大家分享一篇对python 操作solr索引数据的实例详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • Python使用Beautiful Soup包编写爬虫时的一些关键点

    Python使用Beautiful Soup包编写爬虫时的一些关键点

    这篇文章主要介绍了Python使用Beautiful Soup包编写爬虫时的一些关键点,文中讲到了parent属性的使用以及soup的编码问题,需要的朋友可以参考下
    2016-01-01
  • Python爬虫必备之Xpath简介及实例讲解

    Python爬虫必备之Xpath简介及实例讲解

    xpath是一种在XML文档中定位元素的语言,常用于xml、html文件解析,比css选择器使用方便,下面这篇文章主要给大家介绍了关于Python爬虫必备之Xpath简介及实例的相关资料,需要的朋友可以参考下
    2022-04-04
  • Python实现操作Redis的高级用法分享

    Python实现操作Redis的高级用法分享

    redis-py是Python操作Redis的第三方库,它提供了与Redis服务器交互的API,本文为大家介绍了Python利用redis-py操作Redis的高级用法,需要的可以收藏一下
    2023-05-05

最新评论