python实现KNN分类算法

 更新时间:2019年10月16日 10:31:54   作者:王念晨  
这篇文章主要为大家详细介绍了python实现KNN分类算法,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

一、KNN算法简介

邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。

kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。

二、算法过程

1.读取数据集

2.处理数据集数据 清洗,采用留出法hold-out拆分数据集:训练集、测试集

3.实现KNN算法类:

   1)遍历训练数据集,离差平方和计算各点之间的距离

   2)对各点的距离数组进行排序,根据输入的k值取对应的k个点

   3)k个点中,统计每个点出现的次数,权重为距离的导数,得到最大的值,该值的索引就是我们计算出的判定类别

三、代码实现及数据分析

import numpy as np
import pandas as pd
 
# 读取鸢尾花数据集,header参数来指定标题的行。默认为0。如果没有标题,则使用None。
data = pd.read_csv("你的目录/Iris.csv",header=0)
# 显示前n行记录。默认n的值为5。
#data.head()
# 显示末尾的n行记录。默认n的值为5。
#data.tail()
# 随机抽取样本。默认抽取一条,我们可以通过参数进行指定抽取样本的数量。
# data.sample(10)
# 将类别文本映射成为数值类型
 
data["Species"] = data["Species"].map({"Iris-virginica": 0, "Iris-setosa": 1, "Iris-versicolor": 2})
# 删除不需要的Id列。
data.drop("Id", axis=1, inplace=True )
data.drop_duplicates(inplace=True)
## 查看各个类别的鸢尾花具有多少条记录。
data["Species"].value_counts()

分析:首先读取数据集,如下图

最后一列为数据集的分类名称,但是在程序中,我们更倾向于使用如0、1、2数字来表示分类,所以对数据集进行处理,处理后的数据集如下:

然后采用留出法对数据集进行拆分,一部分用作训练,一部分用作测试,如下图:

#构建训练集与测试集,用于对模型进行训练与测试。
# 提取出每个类比的鸢尾花数据
t0 = data[data["Species"] == 0]
t1 = data[data["Species"] == 1]
t2 = data[data["Species"] == 2]
# 对每个类别数据进行洗牌 random_state 每次以相同的方式洗牌 保证训练集与测试集数据取样方式相同
t0 = t0.sample(len(t0), random_state=0)
t1 = t1.sample(len(t1), random_state=0)
t2 = t2.sample(len(t2), random_state=0)
# 构建训练集与测试集。
train_X = pd.concat([t0.iloc[:40, :-1], t1.iloc[:40, :-1], t2.iloc[:40, :-1]] , axis=0)#截取前40行,除最后列外的列,因为最后一列是y
train_y = pd.concat([t0.iloc[:40, -1], t1.iloc[:40, -1], t2.iloc[:40, -1]], axis=0)
test_X = pd.concat([t0.iloc[40:, :-1], t1.iloc[40:, :-1], t2.iloc[40:, :-1]], axis=0)
test_y = pd.concat([t0.iloc[40:, -1], t1.iloc[40:, -1], t2.iloc[40:, -1]], axis=0)

实现KNN算法类:

#定义KNN类,用于分类,类中定义两个预测方法,分为考虑权重不考虑权重两种情况
class KNN:
 ''' 使用Python语言实现K近邻算法。(实现分类) '''
 def __init__(self, k):
  '''初始化方法 
   Parameters
   -----
   k:int 邻居的个数
  '''
  self.k = k
 
 def fit(self,X,y):
  '''训练方法
   Parameters
   ----
   X : 类数组类型,形状为:[样本数量, 特征数量]
   待训练的样本特征(属性)
  
  y : 类数组类型,形状为: [样本数量]
   每个样本的目标值(标签)。
  '''
  #将X转换成ndarray数组
  self.X = np.asarray(X)
  self.y = np.asarray(y)
  
 def predict(self,X):
  """根据参数传递的样本,对样本数据进行预测。
  
  Parameters
  -----
  X : 类数组类型,形状为:[样本数量, 特征数量]
   待训练的样本特征(属性) 
  
  Returns
  -----
  result : 数组类型
   预测的结果。
  """
  X = np.asarray(X)
  result = []
  # 对ndarray数组进行遍历,每次取数组中的一行。
  for x in X:
   # 对于测试集中的每一个样本,依次与训练集中的所有样本求距离。
   dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1))
   ## 返回数组排序后,每个元素在原数组(排序之前的数组)中的索引。
   index = dis.argsort()
   # 进行截断,只取前k个元素。【取距离最近的k个元素的索引】
   index = index[:self.k]
   # 返回数组中每个元素出现的次数。元素必须是非负的整数。【使用weights考虑权重,权重为距离的倒数。】
   count = np.bincount(self.y[index], weights= 1 / dis[index])
   # 返回ndarray数组中,值最大的元素对应的索引。该索引就是我们判定的类别。
   # 最大元素索引,就是出现次数最多的元素。
   result.append(count.argmax())
  return np.asarray(result)
#创建KNN对象,进行训练与测试。
knn = KNN(k=3)
#进行训练
knn.fit(train_X,train_y)
#进行测试
result = knn.predict(test_X)
# display(result)
# display(test_y)
display(np.sum(result == test_y))
display(np.sum(result == test_y)/ len(result))

得出计算结果:

26
0.9629629629629629

得出该模型计算的结果中,有26条记录与测试集相等,准确率为96%

接下来绘制散点图:

#导入可视化所必须的库。
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams["font.family"] = "SimHei"
mpl.rcParams["axes.unicode_minus"] = False
 
#绘制散点图。为了能够更方便的进行可视化,这里只选择了两个维度(分别是花萼长度与花瓣长度)。
# {"Iris-virginica": 0, "Iris-setosa": 1, "Iris-versicolor": 2})
# 设置画布的大小
plt.figure(figsize=(10, 10))
# 绘制训练集数据
plt.scatter(x=t0["SepalLengthCm"][:40], y=t0["PetalLengthCm"][:40], color="r", label="Iris-virginica")
plt.scatter(x=t1["SepalLengthCm"][:40], y=t1["PetalLengthCm"][:40], color="g", label="Iris-setosa")
plt.scatter(x=t2["SepalLengthCm"][:40], y=t2["PetalLengthCm"][:40], color="b", label="Iris-versicolor")
# 绘制测试集数据
right = test_X[result == test_y]
wrong = test_X[result != test_y]
plt.scatter(x=right["SepalLengthCm"], y=right["PetalLengthCm"], color="c", marker="x", label="right")
plt.scatter(x=wrong["SepalLengthCm"], y=wrong["PetalLengthCm"], color="m", marker=">", label="wrong")
plt.xlabel("花萼长度")
plt.ylabel("花瓣长度")
plt.title("KNN分类结果显示")
plt.legend(loc="best")
plt.show()

程序运行结果如下:

四、思考与优化

①尝试去改变邻居的数量。

②在考虑权重的情况下,修改邻居的数量。

③对比查看结果上的差异。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • python采集百度百科的方法

    python采集百度百科的方法

    这篇文章主要介绍了python采集百度百科的方法,涉及Python正则匹配及页面抓取的相关技巧,需要的朋友可以参考下
    2015-06-06
  • Python操作MySQL数据库的方法

    Python操作MySQL数据库的方法

    pymsql是Python中操作MySQL的模块,其使用方法和MySQLdb几乎相同。接下来通过本文给大家介绍Python操作MySQL数据库的方法,感兴趣的朋友一起看看吧
    2018-06-06
  • 编写Python CGI脚本的教程

    编写Python CGI脚本的教程

    这篇文章主要介绍了编写Python CGI脚本的教程,CGI是Python和服务器软件连接的接口,需要的朋友可以参考下
    2015-06-06
  • python基础教程之五种数据类型详解

    python基础教程之五种数据类型详解

    这篇文章主要介绍了python基础教程之五种数据类型详解的相关资料,这里对Python 的数据类型进行了详细介绍,需要的朋友可以参考下
    2017-01-01
  • Python数据分析库pandas基本操作方法

    Python数据分析库pandas基本操作方法

    下面小编就为大家分享一篇Python数据分析库pandas基本操作方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • Python正则表达式 r'(.*) are (.*?) .*'的深入理解

    Python正则表达式 r'(.*) are (.*?) .*'的深入理解

    日常的开发工作中经常会有处理字符串的需求,简单的字符串处理,我们使用python内置的字符串处理函数就可以了,但是复杂的字符串匹配就需要借助正则表达式了,这篇文章主要给大家介绍了关于Python正则表达式 r‘(.*) are (.*?) .*‘的相关资料,需要的朋友可以参考下
    2022-07-07
  • python PyQt5/Pyside2 按钮右击菜单实例代码

    python PyQt5/Pyside2 按钮右击菜单实例代码

    本文通过实例代码给大家介绍了python PyQt5/Pyside2 按钮右击菜单,代码简单易懂,非常不错,具有一定的参考借鉴价值,需要的朋友参考下吧
    2019-08-08
  • 网站渗透常用Python小脚本查询同ip网站

    网站渗透常用Python小脚本查询同ip网站

    这篇文章主要介绍了网站渗透常用Python小脚本查询同ip网站,需要的朋友可以参考下
    2017-05-05
  • python常用数据结构集合详解

    python常用数据结构集合详解

    这篇文章主要介绍了python常用数据结构集合详解,文章围绕主题展开详细的内容介绍,具有一定的参考价值,感兴趣的小伙伴可以参考一下,希望对你的学习有所帮助
    2022-08-08
  • pytorch 数据处理:定义自己的数据集合实例

    pytorch 数据处理:定义自己的数据集合实例

    今天小编就为大家分享一篇pytorch 数据处理:定义自己的数据集合实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12

最新评论