python的numpy模块实现逻辑回归模型

 更新时间:2022年07月30日 10:02:26   作者:上进的小菜鸟  
这篇文章主要为大家详细介绍了python的numpy模块实现逻辑回归模型,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

使用python的numpy模块实现逻辑回归模型的代码,供大家参考,具体内容如下

使用了numpy模块,pandas模块,matplotlib模块

1.初始化参数

def initial_para(nums_feature):
    """initial the weights and bias which is zero"""
    #nums_feature是输入数据的属性数目,因此权重w是[1, nums_feature]维
    #且w和b均初始化为0
    w = np.zeros((1, nums_feature))
    b = 0
    return w, b

2.逻辑回归方程

def activation(x, w , b):
    """a linear function and then sigmoid activation function: 
    x_ = w*x +b,y = 1/(1+exp(-x_))"""
    #线性方程,输入的x是[batch, 2]维,输出是[1, batch]维,batch是模型优化迭代一次输入数据的数目
    #[1, 2] * [2, batch] = [1, batch], 所以是w * x.T(x的转置)
    #np.dot是矩阵乘法
    x_ = np.dot(w, x.T) + b
    #np.exp是实现e的x次幂
    sigmoid = 1 / (1 + np.exp(-x_))
    return sigmoid

3.梯度下降

def gradient_descent_batch(x, w, b, label, learning_rate):
    #获取输入数据的数目,即batch大小
    n = len(label)
    #进行逻辑回归预测
    sigmoid = activation(x, w, b)
    #损失函数,np.sum是将矩阵求和
    cost = -np.sum(label.T * np.log(sigmoid) + (1-label).T * np.log(1-sigmoid)) / n
    #求对w和b的偏导(即梯度值)
    g_w = np.dot(x.T, (sigmoid - label.T).T) / n
    g_b = np.sum((sigmoid - label.T)) / n
    #根据梯度更新参数
    w = w - learning_rate * g_w.T
    b = b - learning_rate * g_b
    return w, b, cost

4.模型优化

def optimal_model_batch(x, label, nums_feature, step=10000, batch_size=1):
    """train the model with batch"""
    length = len(x)
    w, b = initial_para(nums_feature)
    for i in range(step):
        #随机获取一个batch数目的数据
        num = randint(0, length - 1 - batch_size)
        x_batch = x[num:(num+batch_size), :]
        label_batch = label[num:num+batch_size]
        #进行一次梯度更新(优化)
        w, b, cost = gradient_descent_batch(x_batch, w, b, label_batch, 0.0001)
        #每1000次打印一下损失值
        if i%1000 == 0:
            print('step is : ', i, ', cost is: ', cost)
    return w, b

5.读取数据,数据预处理,训练模型,评估精度

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from random import randint
from sklearn.preprocessing import StandardScaler
 
def _main():
    #读取csv格式的数据data_path是数据的路径
    data = pd.read_csv('data_path')
    #获取样本属性和标签
    x = data.iloc[:, 2:4].values
    y = data.iloc[:, 4].values
    #将数据集分为测试集和训练集
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2, random_state=0)
    #数据预处理,去均值化
    standardscaler = StandardScaler()
    x_train = standardscaler.fit_transform(x_train)
    x_test = standardscaler.transform(x_test)
    #w, b = optimal_model(x_train, y_train, 2, 50000)
    #训练模型
    w, b = optimal_model_batch(x_train, y_train, 2, 50000, 64)
    print('trian is over')
    #对测试集进行预测,并计算精度
    predict = activation(x_test, w, b).T
    n = 0
    for i, p in enumerate(predict):
        if p >=0.5:
            if y_test[i] == 1:
                n += 1
        else:
            if y_test[i] == 0:
                n += 1
    print('accuracy is : ', n / len(y_test))

6.结果可视化

predict = np.reshape(np.int32(predict), [len(predict)])
    #将预测结果以散点图的形式可视化
    for i, j in enumerate(np.unique(predict)):
        plt.scatter(x_test[predict == j, 0], x_test[predict == j, 1], 
        c = ListedColormap(('red', 'blue'))(i), label=j)
    plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • python打印当前文件的绝对路径并解决打印为空的问题

    python打印当前文件的绝对路径并解决打印为空的问题

    这篇文章主要介绍了python打印当前文件的绝对路径并解决打印为空的问题,文中补充介绍了python中对文件路径的获取方法,需要的朋友可以参考下
    2023-03-03
  • Python实现普通图片转ico图标的方法详解

    Python实现普通图片转ico图标的方法详解

    ICO是一种图标文件格式,图标文件可以存储单个图案、多尺寸、多色板的图标文件。本文将利用Python实现普通图片转ico图标,感兴趣的小伙伴可以了解一下
    2022-11-11
  • python 多进程和协程配合使用写入数据

    python 多进程和协程配合使用写入数据

    这篇文章主要介绍了python 多进程和协程配合使用写入数据,帮助大家利用python高效办公,感兴趣的朋友可以了解下
    2020-10-10
  • python读取和保存图片5种方法对比

    python读取和保存图片5种方法对比

    为大家分享一下python读取和保存图片5种方法与比较,python中对象之间的赋值是按引用传递的,如果需要拷贝对象,需要用到标准库中的copy模块
    2018-09-09
  • Django CSRF验证失败请求被中断的问题

    Django CSRF验证失败请求被中断的问题

    这篇文章主要介绍了Django CSRF验证失败请求被中断的问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-09-09
  • Python入门教程(二十四)Python的迭代器

    Python入门教程(二十四)Python的迭代器

    这篇文章主要介绍了Python入门教程(二十四)Python的迭代器,Python是一门非常强大好用的语言,也有着易上手的特性,本文为入门教程,需要的朋友可以参考下
    2023-04-04
  • PyTorch中torch.load()的用法和应用

    PyTorch中torch.load()的用法和应用

    torch.load()它用于加载由torch.save()保存的模型或张量,本文主要介绍了PyTorch中torch.load()的用法和应用,具有一定的参考价值,感兴趣的可以了解一下
    2024-03-03
  • 利用pandas合并多个excel的方法示例

    利用pandas合并多个excel的方法示例

    这篇文章主要介绍了利用pandas合并多个excel的方法示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-10-10
  • python读取图片并修改格式与大小的方法

    python读取图片并修改格式与大小的方法

    这篇文章主要为大家详细介绍了python读取图片并修改格式与大小的方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-07-07
  • python 将字符串中的数字相加求和的实现

    python 将字符串中的数字相加求和的实现

    这篇文章主要介绍了python 将字符串中的数字相加求和的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07

最新评论