Python纯代码通过神经网络实现线性回归的拟合方式

 更新时间:2023年05月31日 10:51:22   作者:Zhao-Jichao  
这篇文章主要介绍了Python纯代码通过神经网络实现线性回归的拟合方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

纯代码通过神经网络实现线性回归的拟合

参考链接中的文章,有错误,我给更正了。

并且原文中是需要数据集文件的,我直接给替换成了一个数组,采用直接赋值的方式。

# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
class SimpleDataReader(object):
    def __init__(self, data_file):
        self.train_file_name = data_file
        self.num_train = 0
        self.XTrain = None
        self.YTrain = None
    # read data from file
    def ReadData(self):
        # data = np.load(self.train_file_name)
        # self.XTrain = data["data"]
        # self.YTrain = data["label"]
        self.XTrain = np.array([0.95, 3, 4, 5.07, 6.03, 8.21, 8.85, 12.02, 15], dtype=float)
        self.YTrain = np.array([5.1, 8.7, 11.5, 13, 15.3, 18, 21, 26.87, 32.5], dtype=float)
        self.num_train = self.XTrain.shape[0]
        #end if
    # get batch training data
    def GetSingleTrainSample(self, iteration):
        x = self.XTrain[iteration]
        y = self.YTrain[iteration]
        return x, y
    def GetWholeTrainSamples(self):
        return self.XTrain, self.YTrain
class NeuralNet(object):
    def __init__(self, eta):
        self.eta = eta
        self.w = 0
        self.b = 0
    def __forward(self, x):
        z = x * self.w + self.b
        return z
    def __backward(self, x,y,z):
        dz = z - y                  # 原错误为:dz = x * (z - y)
        db = dz
        dw = dz
        return dw, db
    def __update(self, dw, db):
        self.w = self.w - self.eta * dw
        self.b = self.b - self.eta * db
    def train(self, dataReader):
        for i in range(dataReader.num_train):
            # get x and y value for one sample
            x,y = dataReader.GetSingleTrainSample(i)
            # get z from x,y
            z = self.__forward(x)
            # calculate gradient of w and b
            dw, db = self.__backward(x, y, z)
            # update w,b
            self.__update(dw, db)
            # end for
    def inference(self, x):
        return self.__forward(x)
if __name__ == '__main__':
    # read data
    sdr = SimpleDataReader('ch04.npz')
    sdr.ReadData()
    # create net
    eta = 0.1
    net = NeuralNet(eta)
    net.train(sdr)
    # result
    print("w=%f,b=%f" %(net.w, net.b))
    # 绘图部分
    trainX,trainY = sdr.GetWholeTrainSamples()
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # 绘制散点图
    ax.scatter(trainX,trainY)
    # 绘制线性回归
    x = np.arange(0, 15, 0.01)
    f = np.vectorize(net.inference, excluded=['x'])
    plt.plot(x,f(x),color='red')
    # 显示图表
    plt.show()

在这里插入图片描述

Ref:

通过神经网络实现线性回归的拟合

Python使用线性回归和神经网络模型进行预测

公路运量主要包括公路客运量和公路货运量两个方面。

根据研究,某地区的公路运量主要与该地区的人数、机动车数量和公路面积有关,表5-11给出了某个地区20年的公路运量相关数据。

根据相关部门数据,该地区2010年和2011年的人数分别为73.39万和75.55万,机动车数量分别为3.9635万辆和4.0975万辆,公路面积分别为0.9880万平方千米和1.0268万平方千米。

请利用BP神经网络预测该地区2010年和2011年的公路客运量和公路货运量。

表5-11 运力数据表

年份 人数 机动车数量 公路面积 公里客运量 公里货运量
1990 20.55 0.6 0.09 5126 1237
1991 22.44 0.75 0.11 6217 1379
1992 25.37 0.85 0.11 7730 1385
1993 27.13 0.9 0.14 9145 1399
1994 29.45 1.05 0.2 10460 1663
1995 30.1 1.35 0.23 11387 1714
1996 30.96 1.45 0.23 12353 1834
1997 34.06 1.6 0.32 15750 4322
1998 36.42 1.7 0.32 18304 8132
1999 38.09 1.85 0.34 19836 8936
2000 39.13 2.15 0.36 21024 11099
2001 39.99 2.2 0.36 19490 11203
2002 41.93 2.25 0.38 20433 10524
2003 44.59 2.35 0.49 22598 11115
2004 47.3 2.5 0.56 25107 13320
2005 52.89 2.6 0.59 33442 16762
2006 55.73 2.7 0.59 36836 18673
2007 56.76 2.85 0.67 40548 20724
2008 59.17 2.95 0.69 42927 20803
2009 60.63 3.1 0.79 43462 21804

注:数据取自《Matlab在数学建模中的应用(第2版)》,卓金武,第134页。

在这里插入图片描述

#1.数据获取
import pandas as pd
data = pd.read_excel('运力数据表.xlsx')
x = data.iloc[:20,:4]
y = data.iloc[:20,4:]
#2.导入线性回归模块,简称为LR
from sklearn.linear_model import LinearRegression as LR
lr = LR()             #创建线性回归模型类
lr.fit(x,y)         #拟合
slr=lr.score(x,y)   #判定系数 R^2
c_x=lr.coef_         #x对应的回归系数
c_b=lr.intercept_   #回归系数常数项
#3.预测
x1 = data.iloc[20:,:4]
r1=lr.predict(x1)    #采用自带函数预测
#print('x回归系数为:',c_x)
#print('回归系数常数项为:',c_b)
#print('判定系数为:',slr)
#print('样本预测值为:',r1)
n=list(data["公里客运量(万人)"])
n.extend(r1[:,0])
num=pd.DataFrame(n).dropna()
g=list(data["公里货运量(万吨)"])
g.extend(r1[:,1])
gravity=pd.DataFrame(g).dropna()
import pandas as pd
import matplotlib.pyplot as plt  #导入绘图库中的pyplot模块,并且简称为plt
#构造绘图所需的横轴数据列和纵轴数据列
#在figure界面上绘制线性图
plt.rcParams['font.sans-serif'] = 'SimHei'     #设置字体为SimHei
plt.figure(1)
plt.plot(data["年份"],num,'r*--')  #红色“*”号连续图,
plt.xlabel('日期')
plt.ylabel('公里客运量(万人')
plt.title('公里客运量(万人)走势图')
plt.xticks(data["年份"],rotation = 45)
plt.savefig('myfigure1')
plt.figure(2)
plt.plot(data["年份"],gravity,'b*--')  #红色“*”号连续图,
plt.xlabel('日期')
plt.ylabel('公里货运量(万吨)')
plt.title('公里货运量(万吨)走势图')
plt.xticks(data["年份"],rotation = 45)
plt.savefig('myfigure2')
from sklearn.neural_network import MLPRegressor 
clf = MLPRegressor(solver='lbfgs', alpha=1e-5,hidden_layer_sizes=8, random_state=1) 
clf.fit(x, y);   
rv=clf.score(x,y)
r2=clf.predict(x1)   
print('样本预测值为:',r2)
n2=list(data["公里客运量(万人)"])
n2.extend(r2[:,0])
num2=pd.DataFrame(n2).dropna()
g2=list(data["公里货运量(万吨)"])
g2.extend(r2[:,1])
gravity2=pd.DataFrame(g2).dropna()

结果显示:

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python中Selenium模块的使用详解

    Python中Selenium模块的使用详解

    这篇文章主要介绍了Python中Selenium模块的使用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-10-10
  • Python入门教程之运算符重载详解

    Python入门教程之运算符重载详解

    运算符重载意味着赋予超出其预定义的操作含义的扩展含义。例如运算符 + 用于添加两个整数以及连接两个字符串和合并两个列表。本文将通过示例带大家详细了解Python的运算符重载,感兴趣的可以了解一下
    2022-09-09
  • Python lxml模块安装教程

    Python lxml模块安装教程

    这篇文章主要介绍了Python lxml模块安装教程,本文分别讲解了Windows系统和Linux系统下的安装教程,需要的朋友可以参考下
    2015-06-06
  • Python的numpy选择特定行列的方法

    Python的numpy选择特定行列的方法

    这篇文章主要介绍了Python的numpy选择特定行列的方法,有时需要抽取矩阵中特定行的特定列,比如,需要抽取矩阵x的0,1行的0,3列,结果为矩阵域,需要的朋友可以参考下
    2023-08-08
  • Python使用psutil获取系统信息

    Python使用psutil获取系统信息

    这篇文章介绍了Python使用psutil获取系统信息的方法,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-05-05
  • PyQt中实现自定义工具提示ToolTip的方法详解

    PyQt中实现自定义工具提示ToolTip的方法详解

    这篇文章主要为大家详细介绍了PyQt中实现自定义工具提示ToolTip的方法详解,文中的示例代码讲解详细,对我们学习有一定帮助,需要的可以参考一下
    2022-05-05
  • Python身份运算符is与is not区别用法基础教程

    Python身份运算符is与is not区别用法基础教程

    这篇文章主要为大家介绍了Python身份运算符is与is not区别用法基础教程详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-06-06
  • 对Python中GIL(全局解释器锁)的一点理解浅析

    对Python中GIL(全局解释器锁)的一点理解浅析

    首先需要明确的一点是GIL并不是Python的特性,它是在实现Python解析器(CPython)时所引入的一个概念,下面这篇文章主要给大家介绍了关于对Python中GIL的一点理解,文中通过示例代码介绍的非常详细,需要的朋友可以参考下
    2022-05-05
  • Python 中PyQt5 点击主窗口弹出另一个窗口的实现方法

    Python 中PyQt5 点击主窗口弹出另一个窗口的实现方法

    这篇文章主要介绍了Python 中PyQt5 点击主窗口弹出另一个窗口的实现方法,本文代码实例图文相结合的形式给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-07-07
  • 基于Python实现复刻人生重开模拟器

    基于Python实现复刻人生重开模拟器

    人生重开模拟器是由VickScarlet上传至GitHub的一款简单的文字网页游戏。本文将用Python复刻一下这个游戏,感兴趣的小伙伴可以尝试一下
    2022-10-10

最新评论