鸢尾花数据分类,通过Python实现KNN分类算法。
项目来源:https://aistudio.baidu.com/aistudio/projectdetail/1988428
数据集来源:鸢尾花数据集https://aistudio.baidu.com/aistudio/datasetdetail/91206
1 import numpy as np
2 import pandas as pd
3 import matplotlib as mpl
4 import matplotlib.pyplot as plt
5
6 # 读取鸢尾花数据集,header参数来指定标题的行。默认为0。如果没有标题,则使用None
7 # 四个特征分别为花萼长度sepal length,花萼宽度sepal width,花瓣长度petal length,花瓣宽度petal width。鸢尾花的种类,共有3种,分别为山鸢尾Iris Setosa、杂色鸢尾Iris Versicolour、维吉尼亚鸢尾Iris Virginica。
8 data = pd.read_csv('./data/data91206/iris.csv', header=0)
9 # 显示全部数据
10 # data
11
12 # 显示前n行的数据,默认n的值为5
13 # data.head()
14
15 # 显示末尾的n行记录,默认n的值为5
16 # data.tail()
17
18 # 随机抽取样本,默认抽取一条,我们可以通过修改参数来指定抽取样本的数量
19 data.sample(10)
20
21 # 将类别文本映射为数值类型
22 data['Species'] = data['Species'].map({'versicolor':0,'setosa':1,'virginica':2})
23 # 删除不需要的Id列,并改变原来的文本,以下有两种方法
24 # data.drop('Id', axis=1, inplace = True)
25 data = data.drop('Id', axis=1)
26
27 # 查看是否有重复数据
28 # data.duplicated().any()
29
30 # 查看数据集的列数
31 # len(data)
32
33 # 删除重复的记录
34 data.drop_duplicates(inplace=True)
35 # len(data)
36
37 # 查看各个类别的鸢尾花有多少条记录
38 data['Species'].value_counts()
39
40 class KNN:
41 '''使用Python语言实现K近邻算法。(实现分类)'''
42
43 def __init__(self, k):
44 '''初始化方法
45
46 Parameters
47 ------
48 k : int
49 邻居的个数
50
51 '''
52 self.k = k
53
54 def fit(self, X, y):
55 '''训练方法
56
57 Parameters
58 ------
59 X : 类数组类型,形状为:{样本数量,特征数量}
60 待训练的样本特征(属性)
61 y : 类数组类型,形状为:{样本数量}
62 每个样本的目标值(标签)
63
64 '''
65 # 将X转换为array数组
66 self.X = np.asarray(X)
67 self.y = np.asarray(y)
68
69 def predict(self, X):
70 '''根据参数传递的样本,对样本数据进行预测。
71
72 Parameters
73 -------
74 X : 类数组类型,形状为:[样本数量,特征数量]
75 待训陈的样本特征(属性)
76
77 Returns
78 ----
79 result : 数组类型
80 预测的结果
81 '''
82
83 X = np.asarray(X)
84 result = []
85 # 对array数组进行遍历,每次取数组中的一行。
86 for x in X:
87 # 对于测试集中的每一个样本,依次与训练集中的所有样本求距离。
88 dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1))
89 # 返回数组排序后,每个元素在原数组(排序之前的数组)中的索引
90 index = dis.argsort()
91 # 进行截断,只取前k个元素。【取距离最近的k个元素的索引】
92 index = index[:self.k]
93 # 返回数组中每个元素出现的次数。元素必须是非负的整数
94 count = np.bincount(self.y[index])
95 # 返回ndarray数组中值最大的元素对应的索引,该索引就是我们判定的索引
96 # 最大元素索引,就是出现次数最多的元素
97 result.append(count.argmax())
98 return np.asarray(result)
99
100
101 def predict2(self, X):
102 '''根据参数传递的样本,对样本数据进行预测。
103
104 Parameters
105 -------
106 X : 类数组类型,形状为:[样本数量,特征数量]
107 待训陈的样本特征(属性)
108
109 Returns
110 ----
111 result : 数组类型
112 预测的结果
113 '''
114
115 X = np.asarray(X)
116 result = []
117 # 对array数组进行遍历,每次取数组中的一行。
118 for x in X:
119 # 对于测试集中的每一个样本,依次与训练集中的所有样本求距离。
120 dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1))
121 # 返回数组排序后,每个元素在原数组(排序之前的数组)中的索引
122 index = dis.argsort()
123 # 进行截断,只取前k个元素。【取距离最近的k个元素的索引】
124 index = index[:self.k]
125 # 返回数组中每个元素出现的次数。元素必须是非负的整数。【使用weights考虑权重,权重为距离的倒数】
126 count = np.bincount(self.y[index], weights=1 / dis[index])
127 # 返回ndarray数组中值最大的元素对应的索引,该索引就是我们判定的索引
128 # 最大元素索引,就是出现次数最多的元素
129 result.append(count.argmax())
130 return np.asarray(result)
131
132 # 提取出每个类别的鸢尾花数据
133 t0 = data[data['Species'] == 0]
134 t1 = data[data['Species'] == 1]
135 t2 = data[data['Species'] == 2]
136 # 对每个类别数据进行打乱洗牌
137 t0 = t0.sample(len(t0), random_state=0)
138 t1 = t1.sample(len(t1), random_state=0)
139 t2 = t2.sample(len(t2), random_state=0)
140 # 构建训练集和测试集
141 train_X = pd.concat([t0.iloc[:40, :-1], t1.iloc[:40, :-1], t2.iloc[:40, :-1]], axis=0)
142 train_y = pd.concat([t0.iloc[:40, -1], t1.iloc[:40, -1], t2.iloc[:40, -1]], axis=0)
143 test_X = pd.concat([t0.iloc[40:, :-1], t1.iloc[40:, :-1], t2.iloc[40:, :-1]], axis=0)
144 test_y = pd.concat([t0.iloc[40:, -1], t1.iloc[40:, -1], t2.iloc[40:, -1]], axis=0)
145 # 创建KNN对象,进行训练与测试
146 knn = KNN(k=3)
147 # 进行训练
148 knn.fit(train_X, train_y)
149 # 进行测试,获得测试的结果
150 result = knn.predict(test_X)
151
152 # 查看显示
153 # display(result)
154 # display(test_y)
155
156 display(np.sum(result == test_y))
157 display(np.sum(result == test_y)/ len(result))
158
159 # 考虑权重,进行一下测试。
160 result2 = knn.predict2(test_X)
161 display(np.sum(result2 == test_y))
162
163 # 如果想显示中文的话,可以看这一段,默认情况下,matplotlib不支持中文显示,进行以下设置
164 # 设置字体为黑体,以支持中文显示
165 mpl.rcParams['font.family'] = 'SimHei'
166 # 设置在中文字体时,能够正常的显示负号(-)
167 mpl.rcParams['axes.unicode_minus'] = False
168
169 # 绘制数据集数据
170 # 设置画布的大小
171 plt.figure(figsize=(10, 10))
172 plt.scatter(x=t0['Sepal.Length'][:40], y=t0['Petal.Length'][:40], color='r', label='versicolor')
173 plt.scatter(x=t1['Sepal.Length'][:40], y=t1['Petal.Length'][:40], color='g', label='setosa')
174 plt.scatter(x=t2['Sepal.Length'][:40], y=t2['Petal.Length'][:40], color='b', label='virginica')
175 # 绘制测试集数据
176 right = test_X[result == test_y]
177 wrong = test_X[result != test_y]
178 plt.scatter(x=right['Sepal.Length'], y=right['Petal.Length'], color='c', marker='x', label='right')
179 plt.scatter(x=wrong['Sepal.Length'], y=wrong['Petal.Length'], color='m', marker='>', label='wrong')
180 # 英文显示title、label
181 plt.xlabel('Sepal.Length')
182 plt.ylabel('Petal.Length')
183 plt.title('KNN classification')
184 plt.legend(loc='best')
185 plt.show()