K-中心点聚类算法
(1)任意选择k个对象作为初始的簇中心点
(2)将每个剩余对象指派给离它最近的中心点所代表的簇
(3)选择一个未被选择的中心点直到所有的中心点都被选择过
(4)选择一个未被选择过的非中心点对象,计算用该非中心点代替当前中心点的总代价并记录在S中,直到所有非中心点都被选择过。
(5)如果S中存在小于0的总代价,则找出S中总代价最小的一个替换方案,并用其中的非中心点替代对应的中心点,形成一个新的k个中心点的集合
(6)重复步骤2-5,直到不再发生簇的重新分配,即S中所有的总代价都不小于0。
代码(说明:下面列出的实现按簇内各点的均值更新中心,即K-means的做法;上述步骤(3)-(5)的中心点替换过程在所列代码中没有出现)
/**
 * A single cluster: an identifier, a center point and the list of points
 * currently assigned to it.
 */
public class Cluster {
    /** Identifier of this cluster. */
    private int id;
    /** Current center of this cluster. */
    private Point center;
    /** Points assigned to this cluster. */
    private List<Point> members = new ArrayList<Point>();

    /**
     * Creates a cluster with no members.
     *
     * @param id     identifier of the cluster
     * @param center initial center
     */
    public Cluster(int id, Point center) {
        this.id = id;
        this.center = center;
    }

    /**
     * Creates a cluster backed by the given member list (the list is stored, not copied).
     *
     * @param id      identifier of the cluster
     * @param center  initial center
     * @param members list used as the member storage
     */
    public Cluster(int id, Point center, List<Point> members) {
        this.id = id;
        this.center = center;
        this.members = members;
    }

    /**
     * Adds a point unless an equal point is already a member; in that case a
     * message is printed to standard output and the list is left untouched.
     *
     * @param newPoint the point to add
     */
    public void addPoint(Point newPoint) {
        if (members.contains(newPoint)) {
            System.out.println("样本数据点 {"+newPoint.toString()+"} 已经存在!");
            return;
        }
        members.add(newPoint);
    }

    /**
     * Sums the squared stored distance ({@code getDist()}) of every member.
     *
     * @return the sum of squared member distances
     */
    public float getdis() {
        float squaredTotal = 0;
        for (Point member : members) {
            squaredTotal += member.getDist() * member.getDist();
        }
        return squaredTotal;
    }

    /** @return the identifier of this cluster */
    public int getId() {
        return id;
    }

    /** @return the current center */
    public Point getCenter() {
        return center;
    }

    /** @param center the new center */
    public void setCenter(Point center) {
        this.center = center;
    }

    /** @return the internal member list itself (not a copy) */
    public List<Point> getMembers() {
        return members;
    }

    /**
     * Renders a header line, the center with the member count, and one line per member.
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("-----------Cluster").append(this.getId()).append("---------\n");
        text.append("Mid_Point: ").append(center)
            .append(" Points_num: ").append(members.size());
        for (Point member : members) {
            text.append("\n").append(member.toString());
        }
        return text.append("\n").toString();
    }
}
/**
 * File helpers for the clustering demo plus the demo entry point, which plots
 * the summed per-cluster squared distance against the number of clusters.
 */
public class datahandler {
    /**
     * Reads a comma-separated text file of floats, one data row per line.
     *
     * <p>The first line of the file is read and discarded (presumably a header
     * row — confirm against the data file). Blank lines are skipped. Any
     * failure (missing file, unparsable number, ...) is printed and the rows
     * collected up to that point are returned.
     *
     * @param fileName path of the file to read
     * @return the parsed rows, possibly empty; never {@code null}
     */
    public static List<float[]> readTxt(String fileName){
        List<float[]> list = new ArrayList<>();
        // try-with-resources closes the whole reader chain even when parsing fails
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(fileName))))) {
            br.readLine(); // drop the first line
            String line;
            while ((line = br.readLine()) != null) {
                // a blank line (e.g. a trailing newline) must not abort the import
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] temp = line.split(",");
                float[] c = new float[temp.length];
                for (int i = 0; i < temp.length; i++) {
                    c[i] = Float.parseFloat(temp[i]);
                }
                list.add(c);
            }
        } catch (Exception e) {
            // best effort: report the problem and return what was read so far
            e.printStackTrace();
        }
        return list;
    }

    /**
     * Writes the given text to {@code src/k/output.txt}, creating or
     * overwriting the file. Failures are printed, not propagated.
     *
     * @param content the text to write
     */
    public static void writeTxt(String content){
        File writename = new File("src/k/output.txt"); // relative path
        // FileWriter creates the file when it is missing; closing the writer flushes it
        try (BufferedWriter out = new BufferedWriter(new FileWriter(writename))) {
            out.write(content);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Runs the clustering for k = 1..19 on {@code src/k/t2.txt} and shows a
     * line chart of the summed {@code Cluster.getdis()} values per k.
     *
     * <p>A previous single-run variant (k = 5, timing output, dump of every
     * cluster through {@link #writeTxt(String)}) was kept here as
     * commented-out code.
     */
    public static void main(String[] args) {
        List<float[]> ret = readTxt("src/k/t2.txt");
        XYSeries series = new XYSeries("xySeries");
        for (int x = 1; x < 20; x++) {
            KMeansRun kRun = new KMeansRun(x, ret);
            Set<Cluster> clusterSet = kRun.run();
            // accumulate the per-cluster sums into one value for this k
            float y = 0;
            for (Cluster cluster : clusterSet) {
                y += cluster.getdis();
            }
            series.add(x, y);
        }
        XYSeriesCollection dataset = new XYSeriesCollection();
        dataset.addSeries(series);
        JFreeChart chart = ChartFactory.createXYLineChart(
                "sum of the squared errors", // chart title
                "K", // x axis label
                "SSE", // y axis label
                dataset, // data
                PlotOrientation.VERTICAL,
                false, // include legend
                false, // tooltips
                false // urls
        );
        ChartFrame frame = new ChartFrame("my picture", chart);
        frame.pack();
        frame.setVisible(true);
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
    }
}
/**
 * Distance helper for points.
 */
public class DistanceCompute {
    /**
     * Computes the Euclidean distance between two points.
     *
     * @param p1 first point
     * @param p2 second point
     * @return the square root of the summed squared coordinate differences
     * @throws IllegalArgumentException if the two coordinate arrays differ in length
     */
    public double getEuclideanDis(Point p1, Point p2) {
        final float[] first = p1.getlocalArray();
        final float[] second = p2.getlocalArray();
        final int dims = first.length;
        if (dims != second.length) {
            throw new IllegalArgumentException("length of array must be equal!");
        }
        double squaredSum = 0;
        int axis = 0;
        while (axis < dims) {
            // the difference is taken in float precision, the square in double
            float delta = first[axis] - second[axis];
            squaredSum += Math.pow(delta, 2);
            axis++;
        }
        return Math.sqrt(squaredSum);
    }
}
import java.util.*;
/**
 * K-means clustering over a list of float vectors.
 *
 * <p>{@link #run()} picks k distinct data points as initial centers, then
 * alternates between assigning every point to its nearest center and moving
 * each center to the mean of its members, until no center moves by more than
 * {@code disDiff} or {@code iterMaxTimes} passes have been made.
 */
public class KMeansRun {
    private int kNum; // number of clusters
    private int iterNum = 200; // iteration count; not read anywhere in this class
    private int iterMaxTimes = 100000; // upper bound on assign/update passes per run()
    private int iterRunTimes = 0; // passes actually made by the last run()
    private float disDiff = (float) 0.01; // a center moving farther than this forces another pass
    private List<float[]> original_data = null; // raw input rows
    // Per-instance on purpose: a static list would be overwritten by every new instance.
    private List<Point> pointList = null; // input rows wrapped as points
    private DistanceCompute disC = new DistanceCompute();
    private int len = 0; // dimension of a data point, taken from the first row

    /**
     * @param k             number of clusters, {@code 1 <= k <= original_data.size()}
     * @param original_data input rows; must be non-null and non-empty
     * @throws IllegalArgumentException if an argument violates the constraints above
     */
    public KMeansRun(int k, List<float[]> original_data) {
        this.kNum = k;
        this.original_data = original_data;
        // validate before touching the data so a bad argument fails with a clear message
        check();
        this.len = original_data.get(0).length;
        init();
    }

    /**
     * Validates the constructor arguments.
     */
    private void check() {
        if (kNum <= 0) {
            throw new IllegalArgumentException("k must be the number > 0");
        }
        if (original_data == null || original_data.isEmpty()) {
            throw new IllegalArgumentException("program can't get real data");
        }
        if (kNum > original_data.size()) {
            throw new IllegalArgumentException("k (" + kNum
                    + ") must not exceed the number of data points (" + original_data.size() + ")");
        }
    }

    /**
     * Wraps every input row in a {@code Point} whose id is the row index.
     */
    private void init() {
        pointList = new ArrayList<Point>();
        for (int i = 0, j = original_data.size(); i < j; i++) {
            pointList.add(new Point(i, original_data.get(i)));
        }
    }

    /**
     * Picks k data points with pairwise different coordinates as initial centers.
     *
     * <p>The points are visited in random order and a point is skipped when it
     * equals an already chosen center, so the selection always terminates.
     *
     * @throws IllegalArgumentException if the data holds fewer than k distinct points
     */
    private Set<Cluster> chooseCenterCluster() {
        Set<Cluster> clusterSet = new HashSet<Cluster>();
        List<Point> candidates = new ArrayList<Point>(pointList);
        Collections.shuffle(candidates, new Random());
        int id = 0;
        for (Point candidate : candidates) {
            if (id == kNum) {
                break;
            }
            boolean alreadyChosen = false;
            for (Cluster cluster : clusterSet) {
                if (cluster.getCenter().equals(candidate)) {
                    alreadyChosen = true;
                    break;
                }
            }
            if (!alreadyChosen) {
                clusterSet.add(new Cluster(id, candidate));
                id++;
            }
        }
        if (id < kNum) {
            throw new IllegalArgumentException("k (" + kNum
                    + ") exceeds the number of distinct data points (" + id + ")");
        }
        return clusterSet;
    }

    /**
     * Assigns every point to its nearest center and rebuilds the member lists.
     *
     * @param clusterSet the clusters whose centers are used and whose members are replaced
     */
    public void cluster(Set<Cluster> clusterSet){
        // label every point with the id of, and distance to, its nearest center
        for (Point point : pointList) {
            // start from +infinity so that no finite distance is ever rejected
            float minDis = Float.POSITIVE_INFINITY;
            for (Cluster cluster : clusterSet) {
                float dis = (float) disC.getEuclideanDis(point, cluster.getCenter());
                // a strict '<' keeps the first center on ties and ignores NaN distances
                if (dis < minDis) {
                    minDis = dis;
                    point.setClusterId(cluster.getId());
                    point.setDist(minDis);
                }
            }
        }
        // drop the old members, then hand every point to the cluster it is labelled with
        for (Cluster cluster : clusterSet) {
            cluster.getMembers().clear();
            for (Point point : pointList) {
                if (point.getClusterid() == cluster.getId()) {
                    cluster.addPoint(point);
                }
            }
        }
    }

    /**
     * Moves every center to the mean of its members.
     *
     * <p>A cluster without members keeps its center; dividing by a member
     * count of zero would turn the center into NaN.
     *
     * @param clusterSet the clusters to update
     * @return {@code true} if at least one center moved farther than {@code disDiff}
     */
    public boolean calculateCenter(Set<Cluster> clusterSet) {
        boolean ifNeedIter = false;
        for (Cluster cluster : clusterSet) {
            List<Point> members = cluster.getMembers();
            if (members.isEmpty()) {
                continue;
            }
            // per-dimension sum over all members
            float[] mean = new float[len];
            for (Point member : members) {
                float[] coords = member.getlocalArray();
                for (int i = 0; i < len; i++) {
                    mean[i] += coords[i];
                }
            }
            // turn the sums into averages
            for (int i = 0; i < len; i++) {
                mean[i] /= members.size();
            }
            Point newCenter = new Point(mean);
            if (disC.getEuclideanDis(cluster.getCenter(), newCenter) > disDiff) {
                ifNeedIter = true;
            }
            cluster.setCenter(newCenter);
        }
        return ifNeedIter;
    }

    /**
     * Runs k-means from freshly chosen random centers.
     *
     * @return the resulting clusters
     * @throws IllegalArgumentException if the data holds fewer than k distinct points
     */
    public Set<Cluster> run() {
        Set<Cluster> clusterSet = chooseCenterCluster();
        iterRunTimes = 0; // count the passes of this run only
        boolean ifNeedIter = true;
        // iterMaxTimes bounds the loop in case the centers keep moving
        while (ifNeedIter && iterRunTimes < iterMaxTimes) {
            cluster(clusterSet);
            ifNeedIter = calculateCenter(clusterSet);
            iterRunTimes++;
        }
        return clusterSet;
    }

    /**
     * @return the number of assign/update passes made by the last {@link #run()}
     */
    public int getIterTimes() {
        return iterRunTimes;
    }
}
public class Point {
private float[] localArray;
private int id;
private int clusterId; // 标识属于哪个类中心。
private float dist; // 标识和所属类中心的距离。
public Point(int id, float[] localArray) {
this.id = id;
this.localArray = localArray;
}
public Point(float[] localArray) {
this.id = -1; //表示不属于任意一个类
this.localArray = localArray;
}
public float[] getlocalArray() {
return localArray;
}
public int getId() {
return id;
}
public void setClusterId(int clusterId) {
this.clusterId = clusterId;
}
public int getClusterid() {
return clusterId;
}
public float getDist() {
return dist;
}
public void setDist(float dist) {
this.dist = dist;
}
@Override
public String toString() {
String result = "Point_id=" + id + " [";
for (int i = 0; i < localArray.length; i++) {
result += localArray[i] + " ";
}
return result.trim()+"] clusterId: "+clusterId;
}
@Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass())
return false;
Point point = (Point) obj;
if (point.localArray.length != localArray.length)
return false;
for (int i = 0; i < localArray.length; i++) {
if (Float.compare(point.localArray[i], localArray[i]) != 0) {
return false;
}
}
return true;
}
@Override
public int hashCode() {
float x = localArray[0];
float y = localArray[localArray.length - 1];
long temp = x != +0.0d ? Double.doubleToLongBits(x) : 0L;
int result = (int) (temp ^ (temp >>> 32));
temp = y != +0.0d ? Double.doubleToLongBits(y) : 0L;
result = 31 * result + (int) (temp ^ (temp >>> 32));
return result;
}
}