1, 计算已知类别数据集合中的点与当前点之间的距离(使用欧氏距离公式: $d = \sqrt{(x - x_1)^2 + (y - y_1)^2}$)

2, 按照距离递增次序排序(由近到远)

3, 选取与当前点距离最小的K个点(如上题中的 k=3,k=5)

4, 确定前K个点所在类别的出现频率

5, 将频率最高的那组,作为该点的预测分类

 1 package com.data.knn;
2
/**
 * A labelled point in the 2-D plane.
 *
 * <p>Besides its coordinates and class label, a point carries a mutable
 * {@code dist} slot in which a caller can store the distance from this point
 * to some other point. {@link #distance(Point)} only computes a value; it
 * never reads or writes that slot.
 *
 * <p>Author:  XiJun.Gong<br>
 * Date:    2016-09-06 12:02<br>
 * Version: default 1.0.0
 */
public class Point {

    /** X coordinate. */
    private double x;

    /** Y coordinate. */
    private double y;

    /** Caller-managed distance to another point; 0.0 until {@link #setDist(double)} is called. */
    private double dist;

    /** Category (class label) this point belongs to. */
    private String label;

    /** Creates a point at the origin with an empty label. */
    public Point() {
        this(0d, 0d, "");
    }

    /**
     * Creates a point.
     *
     * @param x     x coordinate
     * @param y     y coordinate
     * @param label category the point belongs to
     */
    public Point(double x, double y, String label) {
        this.x = x;
        this.y = y;
        this.label = label;
    }

    /**
     * Computes the Euclidean distance between this point and {@code a}.
     *
     * @param a the other point
     * @return {@code sqrt((a.x - x)^2 + (a.y - y)^2)}
     * @throws NullPointerException if {@code a} is {@code null}
     */
    public double distance(final Point a) {
        // Math.hypot avoids the intermediate overflow/underflow that squaring
        // the deltas by hand produces for very large or very small coordinates.
        return Math.hypot(a.x - x, a.y - y);
    }

    /** @return the x coordinate */
    public double getX() {
        return x;
    }

    /** @param x the new x coordinate */
    public void setX(double x) {
        this.x = x;
    }

    /** @return the y coordinate */
    public double getY() {
        return y;
    }

    /** @param y the new y coordinate */
    public void setY(double y) {
        this.y = y;
    }

    /** @return the category label */
    public String getLabel() {
        return label;
    }

    /** @param label the new category label */
    public void setLabel(String label) {
        this.label = label;
    }

    /** @return the distance value last stored with {@link #setDist(double)} */
    public double getDist() {
        return dist;
    }

    /** @param dist distance to store on this point */
    public void setDist(double dist) {
        this.dist = dist;
    }
}

KNN实现

 1 package com.data.knn;
2
5
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
10
11 /**
12  * *********************************************************
13  * <p/>
14  * Author:     XiJun.Gong
15  * Date:       2016-09-06 11:59
16  * Version:    default 1.0.0
17  * Class description：
18  * <p/>
19  * *********************************************************
20  */
21 public class knn {
22
23     private List<Point> dataSet;    //统计频率
24     private Point newPoint;         //当前点
25
26
27     //进行KNN分类
28     public String classify(List<Point> dataSet, final Point newPoint, Integer K) {
29
30         Preconditions.checkArgument(K < dataSet.size(), "K的值超过了dataSet的元素");
31         //求解每一个点到新的点的距离
32         for (Point point : dataSet) {
33             point.setDist(newPoint.distance(point));
34         }
35         //进行排序
36         Collections.sort(dataSet, new Comparator<Point>() {
37             @Override
38             public int compare(Point o1, Point o2) {
39                 //return o1.distance(newPoint) < o2.distance(newPoint) ? 1 : -1;
40                 return o1.getDist() < o2.getDist() ? 1 : -1;
41             }
42         });
43
44         //统计前K个标签的频率
45         Map<String, Integer> map = Maps.newHashMap();
46         Integer maxCnt = -9999; //最高频率
47         String label = "";  //最高频率标签
48         Integer currentCnt = 0; //当前标签的频率
49         Integer times = 0;
50         for (Point point : dataSet) {
51             currentCnt = 1;
52             if (map.containsKey(point.getLabel())) {
53                 currentCnt += map.get(point);
54             }
55             if (maxCnt < currentCnt) {
56                 maxCnt = currentCnt;
57                 label = point.getLabel();
58             }
59             map.put(point.getLabel(), currentCnt);
60             times++;
61             if (times > K) break;
62         }
63         return label;
64     }
65
66
67 }
 1 package com.data.knn;
2
4
5 import java.util.List;
6
7 /**
8  * *********************************************************
9  * <p/>
10  * Author:     XiJun.Gong
11  * Date:       2016-09-06 14:45
12  * Version:    default 1.0.0
13  * Class description：
14  * <p/>
15  * *********************************************************
16  */
17 public class Main {
18
19     public static void main(String args[]) {
20         List<Point> list = Lists.newArrayList();
25         Point point = new Point(0.5, 0.5, null);
26         KNN knn = new KNN();
27         System.out.println(knn.classify(list, point, 3));
28     }
29 }

A