Python使用pytorch动手实现LSTM模块

 更新时间:2022年07月27日 08:53:21   作者:qyhyzard  
这篇文章主要介绍了Python使用pytorch动手实现LSTM模块,LSTM是RNN中一个较为流行的网络模块。主要包括输入,输入门,输出门,遗忘门,激活函数,全连接层(Cell)和输出

LSTM 简介:

LSTM是RNN中一个较为流行的网络模块。主要包括输入,输入门,输出门,遗忘门,激活函数,全连接层(Cell)和输出。

其结构如下:

上述公式不做解释,我们只要大概记得以下几个点就可以了:

  • 当前时刻LSTM模块的输入有来自当前时刻的输入值,上一时刻的输出值,输入值和隐含层输出值,就是一共有四个输入值,这意味着一个LSTM模块的输入量是原来普通全连接层的四倍左右,计算量多了许多。
  • 所谓的门就是前一时刻的计算值输入到sigmoid激活函数得到一个概率值,这个概率值决定了当前输入的强弱程度。 这个概率值和当前输入进行矩阵乘法得到经过门控处理后的实际值。
  • 门控的激活函数都是sigmoid,范围在(0,1),而输出输出单元的激活函数都是tanh,范围在(-1,1)。

Pytorch实现如下:

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.nn import init
from torch import Tensor
import math
class NaiveLSTM(nn.Module):
    """Naive LSTM like nn.LSTM"""
    def __init__(self, input_size: int, hidden_size: int):
        super(NaiveLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # input gate
        self.w_ii = Parameter(Tensor(hidden_size, input_size))
        self.w_hi = Parameter(Tensor(hidden_size, hidden_size))
        self.b_ii = Parameter(Tensor(hidden_size, 1))
        self.b_hi = Parameter(Tensor(hidden_size, 1))

        # forget gate
        self.w_if = Parameter(Tensor(hidden_size, input_size))
        self.w_hf = Parameter(Tensor(hidden_size, hidden_size))
        self.b_if = Parameter(Tensor(hidden_size, 1))
        self.b_hf = Parameter(Tensor(hidden_size, 1))

        # output gate
        self.w_io = Parameter(Tensor(hidden_size, input_size))
        self.w_ho = Parameter(Tensor(hidden_size, hidden_size))
        self.b_io = Parameter(Tensor(hidden_size, 1))
        self.b_ho = Parameter(Tensor(hidden_size, 1))

        # cell
        self.w_ig = Parameter(Tensor(hidden_size, input_size))
        self.w_hg = Parameter(Tensor(hidden_size, hidden_size))
        self.b_ig = Parameter(Tensor(hidden_size, 1))
        self.b_hg = Parameter(Tensor(hidden_size, 1))

        self.reset_weigths()

    def reset_weigths(self):
        """reset weights
        """
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def forward(self, inputs: Tensor, state: Tuple[Tensor]) \
        -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """Forward
        Args:
            inputs: [1, 1, input_size]
            state: ([1, 1, hidden_size], [1, 1, hidden_size])
        """
#         seq_size, batch_size, _ = inputs.size()

        if state is None:
            h_t = torch.zeros(1, self.hidden_size).t()
            c_t = torch.zeros(1, self.hidden_size).t()
        else:
            (h, c) = state
            h_t = h.squeeze(0).t()
            c_t = c.squeeze(0).t()

        hidden_seq = []

        seq_size = 1
        for t in range(seq_size):
            x = inputs[:, t, :].t()
            # input gate
            i = torch.sigmoid(self.w_ii @ x + self.b_ii + self.w_hi @ h_t +
                              self.b_hi)
            # forget gate
            f = torch.sigmoid(self.w_if @ x + self.b_if + self.w_hf @ h_t +
                              self.b_hf)
            # cell
            g = torch.tanh(self.w_ig @ x + self.b_ig + self.w_hg @ h_t
                           + self.b_hg)
            # output gate
            o = torch.sigmoid(self.w_io @ x + self.b_io + self.w_ho @ h_t +
                              self.b_ho)

            c_next = f * c_t + i * g
            h_next = o * torch.tanh(c_next)
            c_next_t = c_next.t().unsqueeze(0)
            h_next_t = h_next.t().unsqueeze(0)
            hidden_seq.append(h_next_t)

        hidden_seq = torch.cat(hidden_seq, dim=0)
        return hidden_seq, (h_next_t, c_next_t)

def reset_weigths(model):
    """reset weights
    """
    for weight in model.parameters():
        init.constant_(weight, 0.5)
### test 
inputs = torch.ones(1, 1, 10)
h0 = torch.ones(1, 1, 20)
c0 = torch.ones(1, 1, 20)
print(h0.shape, h0)
print(c0.shape, c0)
print(inputs.shape, inputs)
# test naive_lstm with input_size=10, hidden_size=20
naive_lstm = NaiveLSTM(10, 20)
reset_weigths(naive_lstm)
output1, (hn1, cn1) = naive_lstm(inputs, (h0, c0))
print(hn1.shape, cn1.shape, output1.shape)
print(hn1)
print(cn1)
print(output1)

对比官方实现:

# Use official lstm with input_size=10, hidden_size=20
lstm = nn.LSTM(10, 20)
reset_weigths(lstm)
output2, (hn2, cn2) = lstm(inputs, (h0, c0))
print(hn2.shape, cn2.shape, output2.shape)
print(hn2)
print(cn2)
print(output2)

可以看到与官方的实现有些许的不同,但是输出的结果仍旧一致。

到此这篇关于Python使用pytorch动手实现LSTM模块的文章就介绍到这了,更多相关Python实现LSTM模块内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python 网络爬虫--关于简单的模拟登录实例讲解

    Python 网络爬虫--关于简单的模拟登录实例讲解

    今天小编就为大家分享一篇Python 网络爬虫--关于简单的模拟登录实例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • python之Django自动化资产扫描的实现

    python之Django自动化资产扫描的实现

    这篇文章主要介绍了python之Django自动化资产扫描的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-04-04
  • 基于pytorch实现对图片进行数据增强

    基于pytorch实现对图片进行数据增强

    图像数据增强是一种在训练机器学习和深度学习模型时常用的策略,尤其是在计算机视觉领域,具体而言,它通过创建和原始图像稍有不同的新图像来扩大训练集,本文给大家介绍了如何基于pytorch实现对图片进行数据增强,需要的朋友可以参考下
    2024-01-01
  • Python中的条件判断语句与循环语句用法小结

    Python中的条件判断语句与循环语句用法小结

    这篇文章主要介绍了Python中的条件判断语句与循环语句用法小结,条件语句和循环语句是Python程序流程控制的基础,需要的朋友可以参考下
    2016-03-03
  • Python中for循环语句实战案例

    Python中for循环语句实战案例

    这篇文章主要给大家介绍了关于Python中for循环语句的相关资料,python中for循环一般用来迭代字符串,列表,元组等,当for循环用于迭代时不需要考虑循环次数,循环次数由后面的对象长度来决定,需要的朋友可以参考下
    2023-09-09
  • python开发之anaconda以及win7下安装gensim的方法

    python开发之anaconda以及win7下安装gensim的方法

    这篇文章主要介绍了python开发之anaconda以及win7下安装gensim的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • Python使用matplotlib 模块scatter方法画散点图示例

    Python使用matplotlib 模块scatter方法画散点图示例

    这篇文章主要介绍了Python使用matplotlib 模块scatter方法画散点图,结合实例形式分析了Python数值运算与matplotlib模块图形绘制相关操作技巧,需要的朋友可以参考下
    2019-09-09
  • Python绘图之turtle库的基础语法使用

    Python绘图之turtle库的基础语法使用

    这篇文章主要给大家介绍了关于Python绘图之turtle库的基础语法使用的相关资料, Turtle库是Python语言中一个很流行的绘制图像的函数库,再绘图的时候经常需要用到的一个库需要的朋友可以参考下
    2021-06-06
  • 通过cmd进入python的步骤

    通过cmd进入python的步骤

    在本篇文章里小编给大家整理了关于通过cmd进入python的步骤和实例,需要的朋友们可以参考下。
    2020-06-06
  • 简单了解Python中的几种函数

    简单了解Python中的几种函数

    这篇文章主要介绍了简单了解Python中的几种函数,具有一定参考价值。需要的朋友可以了解下。
    2017-11-11

最新评论