Python实现梯度下降法的示例代码

 更新时间:2022年08月08日 10:46:32   作者:侯小啾  
梯度下降法的机器学习的重要思想之一,梯度下降法的目标,是使得代价函数最小。本文将对梯度下降算法的原理及实现展开详细介绍,感兴趣的快跟随小编一起学习学习吧

1.首先读取数据集

导包并读取数据,数据自行任意准备,只要有两列,可以分为自变量x和因变量y即可即可。

import numpy as np
import matplotlib.pyplot as plt

data = np.loadtxt("data.csv", delimiter=",")

x_data = data[:, 0]
y_data = data[:, 1]

2.初始化相关参数

# 初始化 学习率 即每次梯度下降时的步长 这里设置为0.0001
learning_rate = 0.0001

# 初始化 截距b 与 斜率k
b = 0
k = 0

# 初始化最大迭代的次数 以50次为例
n_iterables = 50

3.定义计算代价函数–>MSE

使用均方误差 MSE (Mean Square Error)来作为性能度量标准

假设共有m个样本数据,则均方误差:

将该公式定义为代价函数,此外为例后续求导方便,则使结果在原mse的基础上,再乘以1/2。

def compute_mse(b,  k,  x_data, y_data):
    total_error = 0
    for i in range(len(x_data)):
        total_error += (y_data[i] - (k * x_data[i] + b)) ** 2

    # 为方便求导:乘以1/2
    mse_ = total_error / len(x_data) / 2
    return mse_

4.梯度下降

分别对上述的MSE表达式(乘以1/2后)中的k,b求偏导,

更新b和k时,使用原来的b,k值分别减去关于b、k的偏导数与学习率的乘积即可。至于为什么使用减号,可以这么理解:以斜率k为例,当其导数大于零的时候,则表示均方误差随着斜率的增大而增大,为了使均方误差减小,则不应该使斜率继续增大,所以需要使其减小,反之当偏导大于零的时候也是同理。其次,因为这个导数衡量的是均方误差的变化,而不是斜率和截距的变化,所以这里需要引入一个学习率,使得其与偏导数的乘积能够在一定程度上起到控制截距和斜率变化的作用。

def gradient_descent(x_data, y_data, b,  k,  learning_rate,  n_iterables):
    m = len(x_data)
    # 迭代
    for i in range(n_iterables):
        # 初始化b、k的偏导
        b_grad = 0
        k_grad = 0

        # 遍历m次
        for j in range(m):
            # 对b,k求偏导
            b_grad += (1 / m) * ((k * x_data[j] + b) - y_data[j])
            k_grad += (1 / m) * ((k * x_data[j] + b) - y_data[j]) * x_data[j]

        # 更新 b 和 k  减去偏导乘以学习率
        b = b - (learning_rate * b_grad)
        k = k - (learning_rate * k_grad)
        # 每迭代 5 次  输出一次图形
        if i % 5 == 0:
            print(f"当前第{i}次迭代")
            print("b_gard:", b_grad, "k_gard:", k_grad)
            print("b:", b, "k:", k)
            plt.scatter(x_data, y_data, color="maroon", marker="x")
            plt.plot(x_data, k * x_data + b)
            plt.show()
    return b, k

5.执行

print(f"开始:截距b={b},斜率k={k},损失={compute_mse(b,k,x_data,y_data)}")
print("开始迭代")
b, k = gradient_descent(x_data, y_data, b, k, learning_rate, n_iterables)
print(f"迭代{n_iterables}次后:截距b={b},斜率k={k},损失={compute_mse(b,k,x_data,y_data)}")

代码执行过程产生了一系列的图像,部分图像如下图所示,随着迭代次数的增加,代价函数越来越小,最终达到预期效果,如下图所示:

第5次迭代:

第10次迭代:

第50次迭代:

执行过程的输出结果如下图所示:

可以看到,随着偏导数越来越小,斜率与截距的变化也越来越细微。

到此这篇关于Python实现梯度下降法的示例代码的文章就介绍到这了,更多相关Python梯度下降法内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 用python画个敬业福字代码

    用python画个敬业福字代码

    大家好,本篇文章主要讲的是用python画个敬业福字代码,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下
    2022-01-01
  • 用Anaconda安装本地python包的方法及路径问题(图文)

    用Anaconda安装本地python包的方法及路径问题(图文)

    这篇文章主要介绍了用Anaconda安装本地python包的方法及路径问题,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-07-07
  • Ubuntu20下的Django安装的方法步骤

    Ubuntu20下的Django安装的方法步骤

    这篇文章主要介绍了Ubuntu20下的Django安装的方法步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01
  • 详解Python中import机制

    详解Python中import机制

    这篇文章主要介绍了Python中import机制的相关资料,帮助大家更好的理解和学习python,感兴趣的朋友可以了解下
    2020-09-09
  • 详解Python中DOM方法的动态性

    详解Python中DOM方法的动态性

    这篇文章主要介绍了详解Python中DOM方法的动态性,xml.dom模块在Python的网络编程中相当有用,本文来自于IBM官网的开发者技术文档,需要的朋友可以参考下
    2015-04-04
  • Python tkinter控件样式详解

    Python tkinter控件样式详解

    tkinter对控件的诸多属性提供了可定制的功能,下面以最常用的按钮作为示例,集中展示其样式特点,感兴趣的小伙伴可以跟随小编一起学习一下
    2023-09-09
  • flask重启后端口被占用的问题解决(非kill)

    flask重启后端口被占用的问题解决(非kill)

    本文主要介绍了flask重启后端口被占用的问题解决(非kill),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-04-04
  • python实现人脸检测的简单实例

    python实现人脸检测的简单实例

    这篇文章主要给大家介绍了关于python实现人脸检测的相关资料,OpenCV 可以使用机器学习算法搜索图像中的人脸,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-02-02
  • Python必考的5道面试题集合

    Python必考的5道面试题集合

    这篇文章介绍了Python必考的5道面试题,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-07-07
  • 用实例解释Python中的继承和多态的概念

    用实例解释Python中的继承和多态的概念

    这篇文章主要介绍了用实例解释Python中的继承和多态的概念,继承和多台是学习每一门面对对象的编程语言时都必须掌握的重要知识,需要的朋友可以参考下
    2015-04-04

最新评论