- 算法简介
假设有一个样本数据集合,并且这个集合里边的的每个数据都存在分类用的标签(也就是监督学习)。输入没有标签的新数据,将新数据和原有数据进行对应的特征进行比较,然后算法提取样本集中特征最相似的数据的分类标签。由于类别常用K表示,所以叫KNN(k-NearestNeighbor,k邻近算法)算法。K通常小于20,因为再大,分的类就会太细,意义不大。
- 实现原理
- 计算出集合中的点与输入点的距离;
- 按照距离递增进行排序;
- 选取距离最小的K个点;
- 确定K个点所在类别的出现频率;
- 返回前K个点出现频率最高的类别当作前预测分类
- 代码实现
def classify(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
distance = sqDistances ** 0.5
sortedDistIndices = distance.argsotr()
classCount = {}
for i in range(k):
voteIlable = labels[sortedDistIndices[i]]
classCount[voteIlable] = classCount.get(voteIlable, 0) + 1
sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
这个方法是根据上边的思路取实现,但是我拿到这段代码的时候还是比较懵的。一个是python代码不熟,另一个就是刚刚入门机器学习,对很多的做法不理解,下边开始讲我所遇到的问题。
- 问题
- python版本转化问题
由于python版本不一样,所以上边的代码并不能直接转换成python3的代码,虽然IDE会显示报错,但是还需要更多百度。例如我在运行代码的时候会报这个错误:
AttributeError: 'dict' object has no attribute 'iteritems'
这是由于第12行中调用了“operator.itemgetter(1)”,这个方法在python3之后就变成了“items()”。
- 显示问题
由于暂时不会更加高级的python操作,所以没有找到“断点”的用法,所以我是直接把所有的参数都通过控制台打印的方法查看的。在打印的时候发现输出的结果都是科学计算法的形式,通过搜索发现原来是numpy默认都是用科学计算法显示数据,如果要用普通的形式显示结果,需要增加下边这行配置:
np.set_printoptions(suppress=True)
np是引入numpy的时候自定义的名称。suppress是压缩的意思,所以这行代码的意思是打印的时候压缩浮点精度。
- python语法问题
严格来说,这个其实不算问题,只是因为我没有了解python的语法,所以我记录在此。
classCount[voteIlable] = classCount.get(voteIlable, 0) + 1
classCount是定义好的一个dict,dict相当于Java里边的map,但是又有点不同。get方法相当于从这个dict里边根据key获取value,如果获取不到,就返回第二个参数(相当于默认值)。而前边的“classCount[voteIlable]”相当于是从dict里边获取到中括号里边的key值对应的value。
所以连起来的意思是:
从classCount中获取voteIlable的value,如果没有,就返回0。
然后+1,再把这个值赋到classCount中key值为voteIlable的value中。
看起来很拗口的逻辑,其实就是利用dict的特性去进行统计,也就是思路中的“统计K个点出现频率”。
- 优化后的代码
通过上边的理解和改造,最后上边的代码改造成python3之后就变成下边这样:
def classify(inX, dataSet, labels, k):
np.set_printoptions(suppress=True)
print("**********")
print(inX)
print(dataSet)
# shape[0] 是获取行数,shape[1]是获取列数。
dataSetSize = dataSet.shape[0]
# print("dataSetSize**********")
# print(dataSetSize)
set = np.tile(inX, (dataSetSize, 1))
# print("set**********")
# print(set)
# 计算距离
diffMat = set - dataSet
# print("diffMat**********")
# print(diffMat)
sqDiffMat = diffMat ** 2 #平方
# print("sqDiffMat**********")
# print(sqDiffMat)
sqDistances = sqDiffMat.sum(axis=1) # axis=1 是按行相加,axis=0是按列相加
# print("sqDistances**********")
# print(sqDistances)
distance = sqDistances ** 0.5
# print("distance**********")
# print(distance)
# 按距离递增次序排序
sortedDistIndices = distance.argsort()
classCount = {}
# 选取与输入店最近的K个点。
for i in range(k): # 循环K个点的距离
voteIlable = labels[sortedDistIndices[i]] # 相当于遍历labels里边的元素
# print(voteIlable)
# 确定前k个点所在类别出现的频率
classCount[voteIlable] = classCount.get(voteIlable, 0) + 1
# print("==================")
# print(classCount)
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
# 返回前k个点出现频率最高的类别作为预测分类
return sortedClassCount[0][0]
- 总结
- 需要尽快去顶python的调试方式,不能一直print;
- 算法看似复杂,但是用代码很好实现,但是也不能落下数学。