K-中心点聚类算法

(1)任意选择k个对象作为初始的簇中心点
(2)将每个剩余对象指派给离它最近的中心点所代表的簇;
(3)依次选择一个尚未被选择过的中心点 O_i,直到所有中心点都被选择过;
(4)对当前的中心点 O_i,依次选择一个尚未被选择过的非中心点对象 O_h,计算用 O_h 代替 O_i 的总代价,并记录在 S 中,直到所有非中心点都被选择过;
(5)如果 S 中存在小于 0 的总代价,则找出 S 中代价最小的那一次替换,用对应的非中心点 O_h 替代中心点 O_i,形成一个新的 k 个中心点的集合;
(6)重复步骤 (2)~(5),直到簇的分配不再发生变化,即 S 中所有的总代价都大于或等于 0。

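代码(注意:下面的实现实际上是 K-means(K-均值)算法——簇中心取成员点各维度的平均值,而不是上文描述的 K-中心点算法;main 方法对 K = 1~19 分别运行聚类,并绘制总 SSE 随 K 变化的曲线,即肘部法则。)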

/**
 * One cluster of a clustering run: an integer id, a center point and the list of member points
 * currently assigned to it.
 */
public class Cluster {
    private int id; // identifier of this cluster
    private Point center; // current center of this cluster
    private List<Point> members = new ArrayList<Point>(); // points currently assigned to this cluster

    /**
     * Creates a cluster with no members.
     *
     * @param id     identifier of the cluster
     * @param center initial center point
     */
    public Cluster(int id, Point center) {
        this.id = id;
        this.center = center;
    }

    /**
     * Creates a cluster that uses {@code members} directly as its member list (no copy is made,
     * so later changes through either reference are visible to both).
     *
     * @param id      identifier of the cluster
     * @param center  initial center point
     * @param members list used as the member list
     */
    public Cluster(int id, Point center, List<Point> members) {
        this.id = id;
        this.center = center;
        this.members = members;
    }

    /**
     * Adds a point to this cluster unless this exact point object is already a member, in which
     * case a message is printed and the list is left unchanged.
     *
     * <p>The duplicate check is by identity on purpose: {@code Point.equals} compares coordinates
     * only, so a value-based check would silently drop distinct samples that happen to share the
     * same coordinates, which skews cluster sizes, centers and the SSE.
     *
     * @param newPoint point to add
     */
    public void addPoint(Point newPoint) {
        for (Point member : members) {
            if (member == newPoint) {
                System.out.println("样本数据点 {"+newPoint.toString()+"} 已经存在!");
                return;
            }
        }
        members.add(newPoint);
    }

    /**
     * Returns the sum over all members of the squared distance stored in each member
     * ({@code Point.getDist()}), i.e. this cluster's contribution to the SSE.
     *
     * @return sum of squared member distances; 0 for an empty cluster
     */
    public float getdis() {
        float sum = 0;
        for (Point point : members) {
            float d = point.getDist();
            sum += d * d;
        }
        return sum;
    }

    public int getId() {
        return id;
    }

    public Point getCenter() {
        return center;
    }

    public void setCenter(Point center) {
        this.center = center;
    }

    /** Returns the live (mutable) member list, not a copy. */
    public List<Point> getMembers() {
        return members;
    }

    /**
     * Renders a header line, the center and member count, then one line per member.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("-----------Cluster").append(this.getId()).append("---------\n");
        sb.append("Mid_Point:  ").append(center).append("  Points_num:  ").append(members.size());
        for (Point point : members) {
            sb.append('\n').append(point);
        }
        return sb.append('\n').toString();
    }
}
/**
 * File helpers for the clustering demo plus an entry point that runs K-means for K = 1..19 and
 * plots the total SSE for each K.
 */
public class datahandler {
    /** Output file used by {@link #writeTxt(String)}. */
    private static final String DEFAULT_OUTPUT = "src/k/output.txt";

    /**
     * Reads a comma-separated numeric file into a list of float rows.
     *
     * <p>The first line of the file is skipped (presumably a header row — confirm against the data
     * files). Blank lines are ignored. On any error the stack trace is printed and the rows
     * parsed so far are returned.
     *
     * @param fileName path of the file to read
     * @return one {@code float[]} per non-blank data line; never {@code null}
     */
    public static List<float[]> readTxt(String fileName){
        List<float[]> list = new ArrayList<>();
        // try-with-resources: the reader is closed even when parsing fails part-way.
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(fileName))))) {
            br.readLine(); // skip the first line
            String line;
            while ((line = br.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    // Float.parseFloat("") would throw and abort reading the rest of the file.
                    continue;
                }
                String[] temp = line.split(",");
                float[] row = new float[temp.length];
                for (int i = 0; i < temp.length; i++) {
                    row[i] = Float.parseFloat(temp[i]);
                }
                list.add(row);
            }
        } catch (Exception e) {
            // Best-effort reader: report the problem and return what was read so far.
            e.printStackTrace();
        }
        return list;
    }

    /**
     * Writes {@code content} to {@code src/k/output.txt}, replacing any previous content.
     * Errors are printed, not thrown.
     *
     * @param content text to write
     */
    public static void writeTxt(String content){
        writeTxt(content, DEFAULT_OUTPUT);
    }

    /**
     * Writes {@code content} to {@code path}, creating the file if needed and replacing any
     * previous content. Errors are printed, not thrown.
     *
     * @param content text to write
     * @param path    destination file
     */
    public static void writeTxt(String content, String path) {
        // FileWriter creates or truncates the file; try-with-resources flushes and closes it
        // even if write() fails.
        try (BufferedWriter out = new BufferedWriter(new FileWriter(new File(path)))) {
            out.write(content);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Runs K-means on {@code src/k/t2.txt} for K = 1..19 and shows a line chart of the summed
     * per-cluster SSE against K (elbow method). Closing the window exits the JVM.
     */
    public static void main(String[] args) {
/*        List<float[]> ret = readTxt("src/k/t2.txt");
        long s=System.currentTimeMillis();
        KMeansRun kRun = new KMeansRun(5, ret);
        Set<Cluster> clusterSet = kRun.run();
        System.out.println("K-means聚类算法运行时间:"+(System.currentTimeMillis()-s)+"ms");
        System.out.println("单次迭代运行次数:" + kRun.getIterTimes());
        StringBuilder stringBuilder=new StringBuilder();
        for (Cluster cluster : clusterSet) {
            System.out.println("Mid_Point:  "+cluster.getCenter()+" clusterId: "+cluster.getId()+"  Points_num:  "+cluster.getMembers().size());
           stringBuilder.append(cluster).append("\n");
        }
        writeTxt(stringBuilder.toString());*/
        List<float[]> ret = readTxt("src/k/t2.txt");
        if (ret.isEmpty()) {
            // KMeansRun cannot be built from an empty data set.
            System.err.println("no data read from src/k/t2.txt");
            return;
        }
        XYSeries series = new XYSeries("xySeries");
        for (int x = 1; x < 20; x++) {
            KMeansRun kRun = new KMeansRun(x, ret);
            Set<Cluster> clusterSet = kRun.run();
            // Total SSE for this K = sum of the per-cluster squared distances.
            float y = 0;
            for (Cluster cluster : clusterSet) {
                y += cluster.getdis();
            }
            series.add(x, y);
        }

        XYSeriesCollection dataset = new XYSeriesCollection();
        dataset.addSeries(series);
        JFreeChart chart = ChartFactory.createXYLineChart(
                "sum of the squared errors", // chart title
                "K", // x axis label
                "SSE", // y axis label
                dataset, // data
                PlotOrientation.VERTICAL,
                false, // include legend
                false, // tooltips
                false // urls
        );

        ChartFrame frame = new ChartFrame("my picture", chart);
        frame.pack();
        frame.setVisible(true);
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
    }
}
/**
 * Distance functions between {@code Point}s.
 */
public class DistanceCompute {
    /**
     * Computes the Euclidean (L2) distance between two points.
     *
     * @param p1 first point
     * @param p2 second point
     * @return square root of the summed squared per-coordinate differences
     * @throws IllegalArgumentException if the two coordinate arrays differ in length
     */
    public double getEuclideanDis(Point p1, Point p2) {
        final float[] first = p1.getlocalArray();
        final float[] second = p2.getlocalArray();

        if (first.length != second.length) {
            throw new IllegalArgumentException("length of array must be equal!");
        }

        double sumOfSquares = 0;
        int idx = 0;
        while (idx < first.length) {
            float delta = first[idx] - second[idx];
            sumOfSquares += Math.pow(delta, 2);
            idx++;
        }
        return Math.sqrt(sumOfSquares);
    }
}
import java.util.*;


    /**
     * K-means clustering over a list of equal-length float rows.
     *
     * <p>Each call to {@link #run()} picks {@code k} random data points with distinct coordinates
     * as initial centers. It then alternates nearest-center assignment and mean-center update.
     * It stops when no center moves farther than {@code disDiff}, or after {@code iterMaxTimes}
     * iterations.
     */
    public class KMeansRun {
        private int kNum;                             // number of clusters (k)

        private int iterMaxTimes = 100000;            // upper bound on iterations of one run()
        private int iterRunTimes = 0;                 // iterations performed by the latest run()
        private float disDiff = (float) 0.01;         // stop once no center moves farther than this

        private List<float[]> original_data = null;   // raw input rows
        private List<Point> pointList = null;         // input rows wrapped as Points (per instance)
        private DistanceCompute disC = new DistanceCompute();
        private int len = 0;                          // number of coordinates per row

        /**
         * @param k             number of clusters; must be positive and at most the number of
         *                      points with distinct coordinates
         * @param original_data input rows; non-empty, all rows non-null with the same positive length
         * @throws IllegalArgumentException if any of the above does not hold
         */
        public KMeansRun(int k, List<float[]> original_data) {
            this.kNum = k;
            this.original_data = original_data;
            // Validate before reading get(0) so bad input surfaces as IllegalArgumentException.
            check();
            this.len = original_data.get(0).length;
            init();
            checkEnoughDistinctPoints();
        }

        /**
         * Validates k and the shape of the input data.
         */
        private void check() {
            if (kNum <= 0){
                throw new IllegalArgumentException("k must be the number > 0");
            }
            if (original_data == null){
                throw new IllegalArgumentException("program can't get real data");
            }
            if (original_data.isEmpty()) {
                throw new IllegalArgumentException("data set must contain at least one point");
            }
            int dim = -1;
            for (int i = 0; i < original_data.size(); i++) {
                float[] row = original_data.get(i);
                if (row == null) {
                    throw new IllegalArgumentException("row " + i + " is null");
                }
                if (dim < 0) {
                    dim = row.length;
                    if (dim == 0) {
                        throw new IllegalArgumentException("rows must have at least one dimension");
                    }
                } else if (row.length != dim) {
                    throw new IllegalArgumentException(
                            "row " + i + " has " + row.length + " dimensions, expected " + dim);
                }
            }
        }

        /**
         * Wraps every input row in a {@code Point} whose id is the row index.
         */
        private void init() {
            pointList = new ArrayList<Point>(original_data.size());
            for (int i = 0, j = original_data.size(); i < j; i++){
                pointList.add(new Point(i, original_data.get(i)));
            }
        }

        /**
         * Ensures at least {@code kNum} points with distinct coordinates exist. Otherwise
         * {@link #chooseCenterCluster()} could never pick {@code kNum} distinct centers and would
         * loop forever.
         */
        private void checkEnoughDistinctPoints() {
            // Point.equals/hashCode compare coordinates, the same test chooseCenterCluster uses.
            int distinct = new HashSet<Point>(pointList).size();
            if (distinct < kNum) {
                throw new IllegalArgumentException(
                        "k (" + kNum + ") exceeds the number of distinct points (" + distinct + ")");
            }
        }

        /**
         * Randomly picks {@code kNum} points with pairwise-distinct coordinates and wraps each as the
         * center of a new cluster with ids {@code 0..kNum-1}.
         */
        private Set<Cluster> chooseCenterCluster() {
            Set<Cluster> clusterSet = new HashSet<Cluster>();
            Random random = new Random();
            for (int id = 0; id < kNum; ) {
                Point point = pointList.get(random.nextInt(pointList.size()));
                // Reject a candidate whose coordinates are already used by a center.
                boolean alreadyCenter = false;
                for (Cluster cluster : clusterSet) {
                    if (cluster.getCenter().equals(point)) {
                        alreadyCenter = true;
                        break;
                    }
                }
                if (!alreadyCenter) {
                    clusterSet.add(new Cluster(id, point));
                    id++;
                }
            }
            return clusterSet;
        }

        /**
         * Assigns every point to its nearest center, recording the cluster id and distance on
         * the point. Each cluster's member list is then rebuilt from scratch.
         */
        public void cluster(Set<Cluster> clusterSet){
            for (Point point : pointList) {
                // Start from +infinity so any finite distance wins.
                float minDis = Float.POSITIVE_INFINITY;
                for (Cluster cluster : clusterSet) {
                    float dis = (float) disC.getEuclideanDis(point, cluster.getCenter());
                    if (dis < minDis) { // strict: on ties the first cluster iterated wins
                        minDis = dis;
                        point.setClusterId(cluster.getId());
                        point.setDist(dis);
                    }
                }
            }
            // Clear all member lists, then distribute the points in a single pass.
            Map<Integer, Cluster> byId = new HashMap<Integer, Cluster>();
            for (Cluster cluster : clusterSet) {
                cluster.getMembers().clear();
                byId.put(cluster.getId(), cluster);
            }
            for (Point point : pointList) {
                Cluster owner = byId.get(point.getClusterid());
                if (owner != null) {
                    owner.addPoint(point);
                }
            }
        }

        /**
         * Moves every non-empty cluster's center to the mean of its members.
         *
         * @return {@code true} if any center moved farther than {@code disDiff}
         */
        public boolean calculateCenter(Set<Cluster> clusterSet) {
            boolean ifNeedIter = false;
            for (Cluster cluster : clusterSet) {
                List<Point> point_list = cluster.getMembers();
                if (point_list.isEmpty()) {
                    // An empty cluster has no mean (0/0 would yield a NaN center). Keep its
                    // current center so it can attract points in the next assignment step.
                    continue;
                }
                float[] sumAll = new float[len];
                for (Point member : point_list) {
                    float[] coords = member.getlocalArray();
                    for (int i = 0; i < len; i++) {
                        sumAll[i] += coords[i];
                    }
                }
                for (int i = 0; i < len; i++) {
                    sumAll[i] = sumAll[i] / point_list.size();
                }
                Point newCenter = new Point(sumAll);
                if (disC.getEuclideanDis(cluster.getCenter(), newCenter) > disDiff) {
                    ifNeedIter = true;
                }
                cluster.setCenter(newCenter);
            }
            return ifNeedIter;
        }

        /**
         * Runs K-means from a fresh random initialization.
         *
         * @return the resulting clusters
         */
        public Set<Cluster> run() {
            iterRunTimes = 0;
            Set<Cluster> clusterSet = chooseCenterCluster();
            boolean ifNeedIter = true;
            // iterMaxTimes bounds the loop in case the centers never settle within disDiff.
            while (ifNeedIter && iterRunTimes < iterMaxTimes) {
                cluster(clusterSet);
                ifNeedIter = calculateCenter(clusterSet);
                iterRunTimes++;
            }
            return clusterSet;
        }

        /**
         * Returns the number of iterations performed by the most recent {@link #run()}.
         */
        public int getIterTimes() {
            return iterRunTimes;
        }
    }
/**
 * A data point: its coordinates, an id, and the cluster id and distance recorded by the
 * clustering step.
 *
 * <p>Equality and hashing consider the coordinates only; {@code id}, {@code clusterId} and
 * {@code dist} are ignored.
 */
public class Point {
    private float[] localArray; // coordinates (stored as given, not copied)
    private int id;
    private int clusterId;  // id of the cluster this point is assigned to
    private float dist;     // distance to the center of the assigned cluster

    public Point(int id, float[] localArray) {
        this.id = id;
        this.localArray = localArray;
    }

    /** Creates a point with id -1, meaning it does not correspond to an input row. */
    public Point(float[] localArray) {
        this.id = -1;
        this.localArray = localArray;
    }

    /** Returns the internal coordinate array (not a copy). */
    public float[] getlocalArray() {
        return localArray;
    }

    public int getId() {
        return id;
    }

    public void setClusterId(int clusterId) {
        this.clusterId = clusterId;
    }

    public int getClusterid() {
        return clusterId;
    }

    public float getDist() {
        return dist;
    }

    public void setDist(float dist) {
        this.dist = dist;
    }

    /**
     * Formats as {@code Point_id=<id>  [<c0> <c1> ...] clusterId: <clusterId>}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("Point_id=").append(id).append("  [");
        for (int i = 0; i < localArray.length; i++) {
            if (i > 0) {
                sb.append(' ');
            }
            sb.append(localArray[i]);
        }
        return sb.append("] clusterId: ").append(clusterId).toString();
    }

    /**
     * Two points are equal when their coordinate arrays have the same length and every
     * coordinate has the same {@code Float.floatToIntBits} value. Under this test NaN equals NaN,
     * and 0.0 differs from -0.0.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return java.util.Arrays.equals(localArray, ((Point) obj).localArray);
    }

    /**
     * Hash over all coordinates, consistent with {@link #equals(Object)}. It is also safe for
     * zero-length coordinate arrays.
     */
    @Override
    public int hashCode() {
        return java.util.Arrays.hashCode(localArray);
    }
}