鸢尾花数据分类,通过Python实现KNN分类算法。

项目来源:https://aistudio.baidu.com/aistudio/projectdetail/1988428

数据集来源:鸢尾花数据集https://aistudio.baidu.com/aistudio/datasetdetail/91206

1 import numpy as np 
  2 import pandas as pd
  3 import matplotlib as mpl 
  4 import matplotlib.pyplot as plt 
  5 
  6 # 读取鸢尾花数据集,header参数来指定标题的行。默认为0。如果没有标题,则使用None
  7 # 四个特征分别为花萼长度sepal length,花萼宽度sepal width,花瓣长度petal length,花瓣宽度petal width。鸢尾花的种类,共有3种,分别为山鸢尾Iris Setosa、杂色鸢尾Iris Versicolour、维吉尼亚鸢尾Iris Virginica。
  8 data = pd.read_csv('./data/data91206/iris.csv', header=0)
  9 # 显示全部数据
 10 # data
 11 
 12 # 显示前n行的数据,默认n的值为5
 13 # data.head()
 14 
 15 # 显示末尾的n行记录,默认n的值为5
 16 # data.tail()
 17 
 18 # 随机抽取样本,默认抽取一条,我们可以通过修改参数来指定抽取样本的数量
 19 data.sample(10)
 20 
 21 # 将类别文本映射为数值类型
 22 data['Species'] = data['Species'].map({'versicolor':0,'setosa':1,'virginica':2})
 23 # 删除不需要的Id列,并改变原来的文本,以下有两种方法
 24 # data.drop('Id', axis=1, inplace = True)
 25 data = data.drop('Id', axis=1)
 26 
 27 # 查看是否有重复数据
 28 # data.duplicated().any()
 29 
 30 # 查看数据集的列数
 31 # len(data)
 32 
 33 # 删除重复的记录
 34 data.drop_duplicates(inplace=True)
 35 # len(data)
 36 
 37 # 查看各个类别的鸢尾花有多少条记录
 38 data['Species'].value_counts()
 39 
 40 class KNN:
 41     '''使用Python语言实现K近邻算法。(实现分类)'''
 42     
 43     def __init__(self, k):
 44         '''初始化方法
 45         
 46         Parameters
 47         ------
 48         k : int
 49            邻居的个数
 50            
 51         '''
 52         self.k = k
 53         
 54     def fit(self, X, y):
 55         '''训练方法
 56         
 57         Parameters
 58         ------
 59         X : 类数组类型,形状为:{样本数量,特征数量}
 60             待训练的样本特征(属性)
 61         y : 类数组类型,形状为:{样本数量}
 62             每个样本的目标值(标签)
 63         
 64         '''
 65         # 将X转换为array数组
 66         self.X = np.asarray(X)
 67         self.y = np.asarray(y)
 68         
 69     def predict(self, X):
 70         '''根据参数传递的样本,对样本数据进行预测。
 71         
 72         Parameters
 73         -------
 74         X : 类数组类型,形状为:[样本数量,特征数量]
 75             待训陈的样本特征(属性)
 76                        
 77         Returns
 78         ----
 79         result : 数组类型
 80                 预测的结果
 81         '''
 82         
 83         X = np.asarray(X)
 84         result = []
 85         # 对array数组进行遍历,每次取数组中的一行。
 86         for x in X:
 87             # 对于测试集中的每一个样本,依次与训练集中的所有样本求距离。
 88             dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1))
 89             # 返回数组排序后,每个元素在原数组(排序之前的数组)中的索引
 90             index = dis.argsort()
 91             # 进行截断,只取前k个元素。【取距离最近的k个元素的索引】
 92             index = index[:self.k]
 93             # 返回数组中每个元素出现的次数。元素必须是非负的整数
 94             count = np.bincount(self.y[index])
 95             # 返回ndarray数组中值最大的元素对应的索引,该索引就是我们判定的索引
 96             # 最大元素索引,就是出现次数最多的元素
 97             result.append(count.argmax())
 98         return np.asarray(result)
 99     
100     
101     def predict2(self, X):
102         '''根据参数传递的样本,对样本数据进行预测。
103         
104         Parameters
105         -------
106         X : 类数组类型,形状为:[样本数量,特征数量]
107             待训陈的样本特征(属性)
108             
109         Returns
110         ----
111         result : 数组类型
112                 预测的结果
113         '''
114 
115         X = np.asarray(X)
116         result = []
117         # 对array数组进行遍历,每次取数组中的一行。
118         for x in X:
119             # 对于测试集中的每一个样本,依次与训练集中的所有样本求距离。
120             dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1))
121             # 返回数组排序后,每个元素在原数组(排序之前的数组)中的索引
122             index = dis.argsort()
123             # 进行截断,只取前k个元素。【取距离最近的k个元素的索引】
124             index = index[:self.k]
125             # 返回数组中每个元素出现的次数。元素必须是非负的整数。【使用weights考虑权重,权重为距离的倒数】      
126             count = np.bincount(self.y[index], weights=1 / dis[index])
127             # 返回ndarray数组中值最大的元素对应的索引,该索引就是我们判定的索引
128             # 最大元素索引,就是出现次数最多的元素
129             result.append(count.argmax())
130         return np.asarray(result)
131 
132 # 提取出每个类别的鸢尾花数据
133 t0 = data[data['Species'] == 0]
134 t1 = data[data['Species'] == 1]
135 t2 = data[data['Species'] == 2]
136 # 对每个类别数据进行打乱洗牌
137 t0 = t0.sample(len(t0), random_state=0)
138 t1 = t1.sample(len(t1), random_state=0)
139 t2 = t2.sample(len(t2), random_state=0)
140 # 构建训练集和测试集
141 train_X = pd.concat([t0.iloc[:40, :-1], t1.iloc[:40, :-1], t2.iloc[:40, :-1]], axis=0)
142 train_y = pd.concat([t0.iloc[:40, -1], t1.iloc[:40, -1], t2.iloc[:40, -1]], axis=0)
143 test_X = pd.concat([t0.iloc[40:, :-1], t1.iloc[40:, :-1], t2.iloc[40:, :-1]], axis=0)
144 test_y = pd.concat([t0.iloc[40:, -1], t1.iloc[40:, -1], t2.iloc[40:, -1]], axis=0)
145 # 创建KNN对象,进行训练与测试
146 knn = KNN(k=3)
147 # 进行训练
148 knn.fit(train_X, train_y)
149 # 进行测试,获得测试的结果
150 result = knn.predict(test_X)
151 
152 # 查看显示
153 # display(result)
154 # display(test_y)
155 
156 display(np.sum(result == test_y))
157 display(np.sum(result == test_y)/ len(result))
158 
159 # 考虑权重,进行一下测试。
160 result2 = knn.predict2(test_X)
161 display(np.sum(result2 == test_y))
162 
163 # 如果想显示中文的话,可以看这一段,默认情况下,matplotlib不支持中文显示,进行以下设置
164 # 设置字体为黑体,以支持中文显示
165 mpl.rcParams['font.family'] = 'SimHei'
166 # 设置在中文字体时,能够正常的显示负号(-)
167 mpl.rcParams['axes.unicode_minus'] = False
168 
169 # 绘制数据集数据
170 # 设置画布的大小
171 plt.figure(figsize=(10, 10))
172 plt.scatter(x=t0['Sepal.Length'][:40], y=t0['Petal.Length'][:40], color='r', label='versicolor')  
173 plt.scatter(x=t1['Sepal.Length'][:40], y=t1['Petal.Length'][:40], color='g', label='setosa')  
174 plt.scatter(x=t2['Sepal.Length'][:40], y=t2['Petal.Length'][:40], color='b', label='virginica')  
175 # 绘制测试集数据
176 right = test_X[result == test_y]
177 wrong = test_X[result != test_y]
178 plt.scatter(x=right['Sepal.Length'], y=right['Petal.Length'], color='c', marker='x', label='right')  
179 plt.scatter(x=wrong['Sepal.Length'], y=wrong['Petal.Length'], color='m', marker='>', label='wrong')  
180 # 英文显示title、label
181 plt.xlabel('Sepal.Length')
182 plt.ylabel('Petal.Length')
183 plt.title('KNN classification')
184 plt.legend(loc='best')
185 plt.show()