文章目录

  • 一、什么是knn算法
  • 二、算法原理
  • 三、通用步骤
  • 四、简单应用


一、什么是knn算法

knn算法实际上是利用训练数据集对特征向量空间进行划分,并作为其分类的模型。其输入是实例的特征向量,输出为实例的类别。寻找最近的k个数据,推测新数据的分类。

二、算法原理

nlp 相似度计算 knn相似度计算_算法原理


对于上面的这个散点图,已知的点是分布在一个二维空间的,当然,在实际生活中,情况会变得复杂,可能是多维的。这个例子表示的是肿瘤病人的相关信息,横轴表示肿瘤的大小,纵轴表示肿瘤发现的时间,这两个轴均是病人肿瘤的特征信息。红色的点表示良性肿瘤,蓝色的表示恶性,表示的是肿瘤的分类。

nlp 相似度计算 knn相似度计算_数据集_02


有了这些初始信息,假如这个时候来了一条新的肿瘤特征信息,我们就可以通过knn算法来对其分类进行预测:

首先我们会取一个k值,假如这里k=3(k如何取值后面会提到)。接下来要做的就是在所有的点中选择离这个新的点最近的三个点,然后这三个点根据自己的分类进行投票,这里的话三个点都是恶性肿瘤,所以根据恶性对良性3:0推测这个新的点也是恶性肿瘤。

knn算法的本质其实是认为两个样本如果他们足够相似的话,那么他们就具有更高的概率属于同一类别。这里我们通过计算两个样本在样本空间中的距离来衡量他们的相似度。

三、通用步骤

  1. 计算距离
    计算待求点与其他点之间的距离(常用欧几里得距离或马氏距离)
  2. 升序排列
    将计算的距离按照升序排列,即距离越近越靠前
  3. 取前k个
  4. 加权平均

四、简单应用

nlp 相似度计算 knn相似度计算_升序_03


对于上图这样的100条数据(数据来源:https://www.kaggle.com/sajidsaifi/prostate-cancer),表示癌症病人诊疗结果与各项指标之间的关系,使用knn算法来预测新数据基于各项指标的诊疗结果

import pandas as pd
import random
from collections import Counter

# 读取数据
file_data = pd.read_csv("./_Cancer.csv").to_dict(orient="records")

# 分组:测试集(少量)、训练集--(确保算法是有效的、可行的)
# 避免偶然情况,先将数据集打乱
random.shuffle(file_data)
test_set = file_data[:len(file_data) // 3]
train_set = file_data[len(file_data) // 3:]

# 计算距离(这里使用欧氏距离)
def distance(p1, p2):
    res = 0
    for key in (
            "radius", "texture", "perimeter", "area", "smoothness", "compactness", "symmetry", "fractal_dimension"):
        res += (p1[key] - p2[key]) ** 2
    return res ** 0.5

K = 5  #这里取K=5

# KNN
def knn(new_data):
    # 1.计算距离
    res = [{"diagnosis_result": train_data["diagnosis_result"], "distance": distance(new_data, train_data)} for
           train_data in train_set]

    # 2.升序排列
    res = sorted(res, key=lambda item: item["distance"])

    # 3.取前K个
    res = res[:K]

    # 4.投票选举(也可以加权平均)
    res_list = [item["diagnosis_result"] for item in res]
    votes = Counter(res_list)
    return votes.most_common(1)[0][0]

# 测试
correct = 0
for test in test_set:
    if test["diagnosis_result"] == knn(test):
        correct += 1
print("正确率 = {:.2f}%".format((correct / len(test_set)) * 100))

100条数据测得准确率在80%左右