Pytorch 如何实现LSTM时间序列预测

 更新时间:2021年05月17日 09:29:45   作者:CodeInHand  
本文主要基于Pytorch深度学习框架,实现LSTM神经网络模型,用于时间序列的预测

开发环境说明:

Python 35

Pytorch 0.2

CPU/GPU均可

1、LSTM简介

人类在进行学习时,往往不总是零开始,学习物理你会有数学基础、学习英语你会有中文基础等等。

于是对于机器而言,神经网络的学习亦可不再从零开始,于是出现了Transfer Learning,就是把一个领域已训练好的网络用于初始化另一个领域的任务,例如会下棋的神经网络可以用于打德州扑克。

我们这讲的是另一种不从零开始学习的神经网络——循环神经网络(Recurrent Neural Network, RNN),它的每一次迭代都是基于上一次的学习结果,不断循环以得到对于整体序列的学习,区别于传统的MLP神经网络,这种神经网络模型存在环型结构,

具体下所示:

上图是RNN的基本单元,通过不断循环迭代展开模型如下所示,图中ht是神经网络的在t时刻的输出,xt是t时刻的输入数据。

这种循环结构对时间序列数据能够很好地建模,例如语音识别、语言建模、机器翻译等领域。

但是普通的RNN对于长期依赖问题效果比较差,当序列本身比较长时,由于神经网络模型的训练是采用backward进行,在梯度链式法则中容易出现梯度消失和梯度爆炸的问题,需要进一步改进RNN的模型结构。

针对Simple RNN存在的问题,LSTM网络模型被提出,LSTM的核心是修改了增添了Cell State,即加入了LSTM CELL,通过输入门、输出门、遗忘门把上一时刻的hidden state和cell state传给下一个状态。

如下所示:

遗忘门:ft = sigma(Wf*[ht-1, xt] + bf)

输入门:it = sigma(Wi*[ht-1, xt] + bi)

cell state initial: C't = tanh(Wc*[ht-1, xt] +bc)

cell state: Ct = ft*Ct-1+ itC't

输出门:ot = sigma(Wo*[ht-1, xt] + bo)

模型输出:ht = ot*tanh(Ct)

LSTM有很多种变型结构,实际工程化过程中用的比较多的是peephole,就是计算每个门的时候增添了cell state的信息,有兴趣的童鞋可以专研专研。

上一部分简单地介绍了LSTM的模型结构,下边将具体介绍使用LSTM模型进行时间序列预测的具体过程。

2、数据准备

对于时间序列,本文选取正弦波序列,事先产生一定数量的序列数据,然后截取前部分作为训练数据训练LSTM模型,后部分作为真实值与模型预测结果进行比较。正弦波的产生过程如下:

SeriesGen(N)方法用于产生长度为N的正弦波数值序列;

trainDataGen(seq,k)用于产生训练或测试数据,返回数据结构为输入输出数据。seq为序列数据,k为LSTM模型循环的长度,使用1~k的数据预测2~k+1的数据。

3、模型构建

Pytorch的nn模块提供了LSTM方法,具体接口使用说明可以参见Pytorch的接口使用说明书。此处调用nn.LSTM构建LSTM神经网络,模型另增加了线性变化的全连接层Linear(),但并未加入激活函数。由于是单个数值的预测,这里input_size和output_size都为1.

4、训练和测试

(1)模型定义、损失函数定义

(2)训练与测试

(3)结果展示

比较模型预测序列结果与真实值之间的差距

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

相关文章

  • Python的Django框架中URLconf相关的一些技巧整理

    Python的Django框架中URLconf相关的一些技巧整理

    这篇文章主要介绍了Python的Django框架中URLconf相关的一些技巧整理,包括视图配置和debug的示例等,需要的朋友可以参考下
    2015-07-07
  • Python的pywifi无线网络库的具体使用

    Python的pywifi无线网络库的具体使用

    pywifi是一个基于Python的用于操作无线网络的库,本文就来介绍一下pywifi的安装及实际应用场景使用,具有一定的参考价值,感兴趣的可以了解一下
    2024-02-02
  • selenium2.0中常用的python函数汇总

    selenium2.0中常用的python函数汇总

    这篇文章主要介绍了selenium2.0中常用的python函数,总结分析了selenium2.0中常用的python函数的功能、原理与基本用法,需要的朋友可以参考下
    2019-08-08
  • Python类型注解必备利器typing模块全面解读

    Python类型注解必备利器typing模块全面解读

    在Python 3.5版本后引入的typing模块为Python的静态类型注解提供了支持,这个模块在增强代码可读性和维护性方面提供了帮助,本文将深入探讨typing模块,介绍其基本概念、常用类型注解以及使用示例,以帮助读者更全面地了解和应用静态类型注解
    2024-01-01
  • Python机器学习pytorch模型选择及欠拟合和过拟合详解

    Python机器学习pytorch模型选择及欠拟合和过拟合详解

    如何发现可以泛化的模式是机器学习的根本问题,将模型在训练数据上过拟合得比潜在分布中更接近的现象称为过拟合,用于对抗过拟合的技术称为正则化
    2021-10-10
  • Python 使用list和tuple+条件判断详解

    Python 使用list和tuple+条件判断详解

    这篇文章主要介绍了Python 使用list和tuple+条件判断详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • Python列表(List)知识点总结

    Python列表(List)知识点总结

    在本篇文章中小编给大家分享了关于Python列表(List)知识点一直对应的实例内容,需要的朋友们学习下。
    2019-02-02
  • 深入理解NumPy简明教程---数组3(组合)

    深入理解NumPy简明教程---数组3(组合)

    本篇文章对NumPy数组进行较深入的探讨。首先介绍自定义类型的数组,接着数组的组合,最后介绍数组复制方面的问题,有兴趣的可以了解一下。
    2016-12-12
  • 基于Tensorflow读取MNIST数据集时网络超时的解决方式

    基于Tensorflow读取MNIST数据集时网络超时的解决方式

    这篇文章主要介绍了基于Tensorflow读取MNIST数据集时网络超时的解决方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • 深入理解python虚拟机GIL详解

    深入理解python虚拟机GIL详解

    在目前的 CPython 当中一直有一个臭名昭著的问题就是 GIL (Global Interpreter Lock ),就是全局解释器锁,他限制了 Python 在多核架构当中的性能,在本篇文章当中我们将详细分析一下 GIL 的利弊和 GIL 的 C 的源代码
    2023-10-10

最新评论