图文详解梯度下降算法的原理及Python实现

 更新时间:2022年08月03日 09:01:10   作者:Mr.Winter`  
梯度下降是迭代法的一种,可以用于求解最小二乘问题(线性和非线性都可以)。本文将通过图文详解梯度下降算法的原理及实现,需要的可以参考一下

1.引例

给定如图所示的某个函数,如何通过计算机算法编程求f(x)min?

2.数值解法

传统方法是数值解法,如图所示

按照以下步骤迭代循环直至最优:

① 任意给定一个初值x0

② 随机生成增量方向,结合步长生成Δx;

③ 计算比较f(x0)与f(x0+Δx)的大小,若f(x0+Δx)<f(x0)则更新位置,否则重新生成Δx;

④ 重复②③直至收敛到最优f(x)min。

数值解法最大的优点是编程简明,但缺陷也很明显:

① 初值的设定对结果收敛快慢影响很大;

② 增量方向随机生成,效率较低;

③ 容易陷入局部最优解;

④ 无法处理“高原”类型函数。

所谓陷入局部最优解是指当迭代进入到某个极小值或其邻域时,由于步长选择不恰当,无论正方向还是负方向,学习效果都不如当前,导致无法向全局最优迭代。就本问题而言如图所示,当迭代陷入x=xj时,由于学习步长step的限制,无法使f(xj±Step)<f(xj),因此迭代就被锁死在了图中的红色区段。可以看出x=xj并非期望的全局最优。

若出现下图所示的“高原”函数,也可能使迭代得不到更新。

3.梯度下降算法

梯度下降算法可视为数值解法的一种改进,阐述如下:

记第k轮迭代后,自变量更新为x=xk,令目标函数f(x)在x=xk泰勒展开:

f(x)=f(xk​)+f′(xk​)(x−xk​)+o(x)

考察f(x)min ,则期望f(xk+1)<f(xk),从而:

f(xk+1​)−f(xk​)=f′(xk​)(xk+1​−xk​)<0

若f′(xk)>0则xk+1<xk ,即迭代方向为负;反之为正。不妨设xk+1−xk=−f′(xk),从而保证f(xk+1)−f(xk)<0。必须指出,泰勒公式成立的条件是x→x0,故|f′(xk)|不能太大,否则xk+1与xk距离太远产生余项误差。因此引入学习率γ∈(0,1)来减小偏移度,即xk+1-xk=−γf′(xk​)

在工程上,学习率γ \gammaγ要结合实际应用合理选择,γ \gammaγ过大会使迭代在极小值两侧振荡,算法无法收敛;γ \gammaγ过小会使学习效率下降,算法收敛慢。

对于向量 ,将上述迭代公式推广为

xk+1​=xk​−γ∇xk​​

其中

为多元函数的梯度,故此迭代算法也称为梯度下降算法

梯度下降算法通过函数梯度确定了每一次迭代的方向和步长,提高了算法效率。但从原理上可以知道,此算法并不能解决数值解法中初值设定、局部最优陷落和部分函数锁死的问题。

4.代码实战:Logistic回归

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
from Logit import Logit

'''
* @breif: 从CSV中加载指定数据
* @param[in]: file -> 文件名
* @param[in]: colName -> 要加载的列名
* @param[in]: mode -> 加载模式, set: 列名与该列数据组成的字典, df: df类型
* @retval: mode模式下的返回值
'''
def loadCsvData(file, colName, mode='df'):
    assert mode in ('set', 'df')
    df = pd.read_csv(file, encoding='utf-8-sig', usecols=colName)
    if mode == 'df':
        return df
    if mode == 'set':
        res = {}
        for col in colName:
            res[col] = df[col].values
        return res

if __name__ == '__main__':
    # ============================
    # 读取CSV数据
    # ============================
    csvPath = os.path.abspath(os.path.join(__file__, "../../data/dataset3.0alpha.csv"))
    dataX = loadCsvData(csvPath, ["含糖率", "密度"], 'df')
    dataY = loadCsvData(csvPath, ["好瓜"], 'df')
    label = np.array([
        1 if i == "是" else 0
        for i in list(map(lambda s: s.strip(), list(dataY['好瓜'])))
    ])

    # ============================
    # 绘制样本点
    # ============================
    line_x = np.array([np.min(dataX['密度']), np.max(dataX['密度'])])
    mpl.rcParams['font.sans-serif'] = [u'SimHei']
    plt.title('对数几率回归模拟\nLogistic Regression Simulation')
    plt.xlabel('density')
    plt.ylabel('sugarRate')
    plt.scatter(dataX['密度'][label==0],
                dataX['含糖率'][label==0],
                marker='^',
                color='k',
                s=100,
                label='坏瓜')
    plt.scatter(dataX['密度'][label==1],
                dataX['含糖率'][label==1],
                marker='^',
                color='r',
                s=100,
                label='好瓜')

    # ============================
    # 实例化对数几率回归模型
    # ============================
    logit = Logit(dataX, label)

    # 采用梯度下降法
    logit.logitRegression(logit.gradientDescent)
    line_y = -logit.w[0, 0] / logit.w[1, 0] * line_x - logit.w[2, 0] / logit.w[1, 0]
    plt.plot(line_x, line_y, 'b-', label="梯度下降法")

    # 绘图
    plt.legend(loc='upper left')
    plt.show()

到此这篇关于图文详解梯度下降算法的原理及Python实现的文章就介绍到这了,更多相关Python梯度下降算法内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python读取Ansible playbooks返回信息示例解析

    Python读取Ansible playbooks返回信息示例解析

    这篇文章主要为大家介绍了Python读取Ansible playbooks返回信息示例解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-12-12
  • python help函数实例用法

    python help函数实例用法

    在本篇文章里小编给大家整理了关于python help函数实例用法及相关实例,需要的朋友们可以学习下。
    2020-12-12
  • python实现WebSocket服务端过程解析

    python实现WebSocket服务端过程解析

    这篇文章主要介绍了python实现WebSocket服务端过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-10-10
  • pandas读取csv文件提示不存在的解决方法及原因分析

    pandas读取csv文件提示不存在的解决方法及原因分析

    这篇文章主要介绍了pandas读取csv文件提示不存在的解决方法及原因分析,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • Flask Paginate实现表格分页的使用示例

    Flask Paginate实现表格分页的使用示例

    flask_paginate是Flask框架的一个分页扩展,用于处理分页相关的功能,本文就来介绍一下Flask Paginate实现表格分页的使用示例,感兴趣的可以了解一下
    2023-11-11
  • Python实现公历(阳历)转农历(阴历)的方法示例

    Python实现公历(阳历)转农历(阴历)的方法示例

    这篇文章主要介绍了Python实现公历(阳历)转农历(阴历)的方法,涉及农历算法原理及Python日期运算相关操作技巧,需要的朋友可以参考下
    2017-08-08
  • python类型强制转换long to int的代码

    python类型强制转换long to int的代码

    python的int型最大值和系统有关,32位和64位系统结果是不同的,分别为2的31次方减1和2的63次方减1,可以通过sys.maxint查看此值
    2013-02-02
  • 简单的Python动态可视化神器,编程小白也能上手

    简单的Python动态可视化神器,编程小白也能上手

    这篇文章就来介绍简单的Python动态可视化神器,最近发现了一个宝藏动态可视化库,非常简单,即使是小白也能轻松上手。这个库就是motionchart,它能够用 pandas 的 dataframe 数据直接创建交互式的动态图表,下面来简单看一下如何使用。

    2021-10-10
  • python实现拉普拉斯特征图降维示例

    python实现拉普拉斯特征图降维示例

    今天小编就为大家分享一篇python实现拉普拉斯特征图降维示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11
  • 2020最新pycharm汉化安装(python工程狮亲测有效)

    2020最新pycharm汉化安装(python工程狮亲测有效)

    这篇文章主要介绍了2020最新pycharm汉化安装(python工程狮亲测有效),文中通过图文介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-04-04

最新评论