KNN(k-NearestNeighbor)中文称K-近邻算法,是数据挖掘/机器学习中最简单的算法。既然叫做K近邻,那么按照字面理解就很容易知道该算法与距离相关,而在实际中距离也是该算法的核心。

        该算法需要解决的问题也很简单,就是对未知样本进行分类,那么对于一个未知的数据,我们如何确定其分类到底属于A,B,C还是D呢?按照KNN的算法而言,最基础的方法就是找出距离该数据最近的K个数据点(已知标签),最后看哪个标签出现的概率最大则将其赋值为该测试样本的最终标签。


KNN算法流程

  • 计算未知数据与已知样本数据间的距离distances;
  • 得到与未知样本距离最近的K个样本数据点;
  • 确定该K个样本数据所在的类别出现的频率;
  • 返回前K个样本数据所出现的频率最高的类最为该未知数据的预测类;


KNN算法优缺点

  • 优点:由于算法很简单,而且是按照最终的投票确定预测分类,因此该算法相对而言精度较高且对异常值不敏感
  • 缺点:由于每次进入未知数据都需要将其与其他所有训练数据进行距离的计算和排序,因此该算法的时间复杂度和空间复杂度都较高;此外,因为涉及到距离,因此该算法只能适用于数值型和标称型的数据。

KNN算法的改进

  • 将存储的训练元组预先排序并安排在搜索树中;
  • 通过并行实现距离的计算,缩短计算时间;
  • 剪枝或精简:删除证明是“无用的”元组;


Python代码的实现

  • 距离的衡量:本代码中以欧式距离为例进行计算;即:d={(A0-B0)^2+(A1-B1)^2}^0.5
  • 所使用的库:Numpy , operator (关于两个库将在今后的文章中做专门介绍)
  • 参数说明:

           - inX             输入数据矩阵;

           - dataSet     训练数据集;

           - labels        数据集所对应的标签;

           - K                所选择的近邻数量;

<pre name="code" class="python">def classify0(inX, dataSet, labels, k):  
    # 确定数据的维数                                           
    dataSetSize = dataSet.shape[0]                                                  
    # 计算矩阵之间的差矩阵
    diffMat = tile(inX, (dataSetSize,1)) - dataSet                                  
    # 计算距离
    sqDiffMat = diffMat**2                                                          
    sqDistances = sqDiffMat.sum(axis=1)                                             
    distances = sqDistances**0.5                                                    
    # 对距离进行排序
    sortedDistIndicies = distances.argsort()                                        
    classCount={}                                                                   
    # 选择距离最小的K个点并统计K个近邻点中各类的数量
    for i in range(k):                                                              
        voteIlabel = labels[sortedDistIndicies[i]]                                  
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1                   
    # 对类进行排序并返回最高的为预测类
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]