python实现机器学习之多元线性回归

 更新时间:2021年04月20日 08:50:07   作者:婉如  
这篇文章主要为大家详细介绍了python实现机器学习之多元线性回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

总体思路与一元线性回归思想一样,现在将数据以矩阵形式进行运算,更加方便。
一元线性回归实现代码

下面是多元线性回归用Python实现的代码:

import numpy as np

def linearRegression(data_X,data_Y,learningRate,loopNum):
 W = np.zeros(shape=[1, data_X.shape[1]])
 # W的shape取决于特征个数,而x的行是样本个数,x的列是特征值个数
 # 所需要的W的形式为 行=特征个数,列=1 这样的矩阵。但也可以用1行,再进行转置:W.T
 # X.shape[0]取X的行数,X.shape[1]取X的列数
 b = 0

 #梯度下降
 for i in range(loopNum):
  W_derivative = np.zeros(shape=[1, data_X.shape[1]])
  b_derivative, cost = 0, 0

  WXPlusb = np.dot(data_X, W.T) + b # W.T:W的转置
  W_derivative += np.dot((WXPlusb - data_Y).T, data_X) # np.dot:矩阵乘法
  b_derivative += np.dot(np.ones(shape=[1, data_X.shape[0]]), WXPlusb - data_Y)
  cost += (WXPlusb - data_Y)*(WXPlusb - data_Y)
  W_derivative = W_derivative / data_X.shape[0] # data_X.shape[0]:data_X矩阵的行数,即样本个数
  b_derivative = b_derivative / data_X.shape[0]


  W = W - learningRate*W_derivative
  b = b - learningRate*b_derivative

  cost = cost/(2*data_X.shape[0])
  if i % 100 == 0:
   print(cost)
 print(W)
 print(b)

if __name__== "__main__":
 X = np.random.normal(0, 10, 100)
 noise = np.random.normal(0, 0.05, 20)
 W = np.array([[3, 5, 8, 2, 1]]) #设5个特征值
 X = X.reshape(20, 5)  #reshape成20行5列
 noise = noise.reshape(20, 1)
 Y = np.dot(X, W.T)+6 + noise
 linearRegression(X, Y, 0.003, 5000)

特别需要注意的是要弄清:矩阵的形状

在梯度下降的时候,计算两个偏导值,这里面的矩阵形状变化需要注意。

梯度下降数学式子:

以代码中为例,来分析一下梯度下降中的矩阵形状。
代码中设了5个特征。

WXPlusb = np.dot(data_X, W.T) + b

W是一个1*5矩阵,data_X是一个20*5矩阵
WXPlusb矩阵形状=20*5矩阵乘上5*1(W的转置)的矩阵=20*1矩阵

W_derivative += np.dot((WXPlusb - data_Y).T, data_X)

W偏导矩阵形状=1*20矩阵乘上 20*5矩阵=1*5矩阵

b_derivative += np.dot(np.ones(shape=[1, data_X.shape[0]]), WXPlusb - data_Y)

b是一个数,用1*20的全1矩阵乘上20*1矩阵=一个数

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • python config文件的读写操作示例

    python config文件的读写操作示例

    这篇文章主要介绍了python config文件的读写操作,结合简单示例形式分析了Python针对config文件的设置、读取、写入相关操作技巧,需要的朋友可以参考下
    2019-09-09
  • Python中的pandas表格模块、文件模块和数据库模块

    Python中的pandas表格模块、文件模块和数据库模块

    这篇文章介绍了Python中的pandas表格模块、文件模块和数据库模块,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-05-05
  • python中import和from-import的区别解析

    python中import和from-import的区别解析

    这篇文章主要介绍了python中import和from-import的区别解析,本文通过实例代码给大家讲解的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-12-12
  • 利用Chatgpt开发一款加减乘除计算器(Python代码实现)

    利用Chatgpt开发一款加减乘除计算器(Python代码实现)

    这篇文章主要为大家详细介绍了如何利用Chatgpt开发一款加减乘除计算器(用Python代码实现),文中的示例代码讲解详细,感兴趣的小伙伴可以了解一下
    2023-02-02
  • 关于python中range()的参数问题

    关于python中range()的参数问题

    这篇文章主要介绍了关于python中range()的参数问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-05-05
  • Python OpenCV实现3种滤镜效果实例

    Python OpenCV实现3种滤镜效果实例

    opencv是一个很强大的库,支持多个编程语言,下面这篇文章主要给大家介绍了关于Python OpenCV实现3种滤镜效果的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下
    2022-04-04
  • Python操作MongoDB详解及实例

    Python操作MongoDB详解及实例

    这篇文章主要介绍了Python操作MongoDB详解及实例的相关资料,需要的朋友可以参考下
    2017-05-05
  • python3 使用OpenCV计算滑块拼图验证码缺口位置(场景示例)

    python3 使用OpenCV计算滑块拼图验证码缺口位置(场景示例)

    这篇文章主要介绍了python3 使用OpenCV计算滑块拼图验证码缺口位置,本文通过场景示例给大家详细介绍,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-08-08
  • python 实现

    python 实现"神经衰弱"翻牌游戏

    这篇文章主要介绍了python 实现"神经衰弱"游戏,帮助大家更好的理解和使用python的pygame库,感兴趣的朋友可以了解下
    2020-11-11
  • python绘制BA无标度网络示例代码

    python绘制BA无标度网络示例代码

    今天小编就为大家分享一篇python绘制BA无标度网络示例代码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11

最新评论