伪码
对未知的数据进行以下操作:
1.计算未知数据和样本集合之间的距离(有多种距离公式可供选择,此处使用欧氏距离)
2.把距离按从小到大的次序排序
3.选择前k个距离最小的样本
4.确定k个样本所在类别出现的频率
5.选择出现频率最高的类别作为预测值欧氏距离公式
计算A(A0,A1,A2,…,An)与B(B0,B1,B2,…,Bn)之间的距离,公式为:
d = √ (xA0 - xB0)^2 + (xA1 - xB1)^2 + ... + (xAn - xBn)^2A点的n个分量可以理解为样本A的n个特征的量化,同理B。
一个简单的测试代码
首先定义一个kNN.py文件,文件中包含两个函数:createDataSet函数用来产生数据集,classify函数用来进行分类。
classify中注释的print函数可以帮助理解每一步在干什么。
""""
kNN一般步骤:
对未知类别进行以下操作:
1.计算已知类别数据集中的点与当前点之间的距离
2.按照距离递增次序排序
3.选择与当前点距离最小的k个点
4.确定前k个点所在类别的出现频率
5.返回前k个点出现频率最高的类别作为预测值
"""
import numpy as np
import operator
def createDataSet():
group = np.array([[1.0,1.1],
[1.0,1.0],
[0,0],
[0,0.1]])
labels = ['A','A','B','B']
return group,labels
#inX:未知点;dataSet:数据集;labels:数据集标签;k:选择几个最相似样本
def classify(inX,dataSet,labels,k):
#print('inX:',inX)
#print('dataSet:',dataSet)
dataSetSize = dataSet.shape[0]#数据集大小
diffMat = np.tile(inX,(dataSetSize,1)) - dataSet#np.tile:把数组沿各方向复制。先把未知值复制成与数据集形状一致的,再对应相减
#print('diffMat:',diffMat)
sqDiffMat = diffMat ** 2
#print('sqDiffMat:',sqDiffMat)
#print(type(sqDiffMat))
sqDistances = sqDiffMat.sum(axis = 1)#axis = 1:按行加;axis = 0:按列加。返回numpy数组类型
#print('sqDistances:',sqDistances)
#print(type(sqDistances))
distances = sqDistances ** 0.5#每个元素都取平方根
#print("distances:",distances)
sortedDistIndicies = distances.argsort()
#print('sortedDistIndicies:',sortedDistIndicies)
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
if voteIlabel not in classCount.keys():
classCount[voteIlabel] = 0
else:
classCount[voteIlabel] = classCount[voteIlabel] + 1
#print('classCount:',classCount)
sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)
#print('sortedClassCount:',sortedClassCount)
return sortedClassCount[0][0]再定义一个apply.py函数,进行预测。
import kNN
group,labels = kNN.createDataSet()
result = kNN.classify([-1,-1],group,labels,3)
print("The result of classification is " + result)可能需要解释的一些函数
1.numpy.tile():把数组沿着各方向复制
import numpy as np
a = [1,1]
b = np.tile(a,(2,1))
print(b)
c = [2,2]
d = np.tile(c,(1,2))
print(d)
e = [3,3]
f = np.tile(e,(2,2))
print(f)结果如下:
[[1 1]
[1 1]]
[[2 2 2 2]]
[[3 3 3 3]
[3 3 3 3]]2. array.sum(axis = ?):数组中元组沿着指定方向相加,返回一个新数组
axis=0:沿着y轴
axis=1:沿着x轴
g = b.sum(axis = 0)
print(g)
h = d.sum(axis = 1)
print(h)
i = f.sum(axis = 0)
print(i)
i = f.sum(axis = 1)
print(i)运行结果
[2 2]
[8]
[6 6 6 6]
[12 12]3. array.argsort():将数组中元素排序,返回下标数组
j = np.array([9,8,7,6,5,4,3,2,1])
k = j.argsort()
print(k)运行结果
[8 7 6 5 4 3 2 1 0]4.sorted()函数:对所有可迭代的对象进行排序操作。
sorted(可迭代对象, 比较函数, 指定可迭代对象中的一个元素来进行排序, reverse=False\True(正序或倒序))
sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)
#将classCount中的元素对象进行排序,指定排序对象为每个元素的第‘1’维度,也就是字典每个key的value通常可以将operator.itemgetter(x)与sorrted()进行搭配使用,x指定要排序的对象中数据的维数。
import operator
m = {'A':5,'B':1,'C':2}
n = sorted(m.items(),key=operator.itemgetter(1))
print('n:',n)
#n: [('B', 1), ('C', 2), ('A', 5)]
















