java运用kmeans算法进行聚类
文章目录
- java运用kmeans算法进行聚类
- 一、Kmeans算法使用步骤
- 二、Java实现
- 1.准备工作
- 误差平方和的计算
- 需要用到的数据集
- 工具:eclipse及效果图演示
- 2.代码
- 3.使用weka验证
- 三、源码
一、Kmeans算法使用步骤
- 选出k值,随机出k个簇的中心点。
- 分别计算每个点和k个中心点之间的欧式距离,就近归类。
欧式距离计算方法如下: - 最终中心点集可以划分为k类,分别重新计算每类中新的中心点。
重新计算的中心点坐标即把当前簇内所有点的x,y坐标分别加起来取平均值
- 重复2,3步骤对所有点进行归类,如果当所有分类的质心点不再改变,则最终收敛。
二、Java实现
java代码参考这篇文章为了满足老师作业的要求进行了一定程度的添加和修改(主要改进在于对计算过程进行了记录以及后续误差平方和的计算)
1.准备工作
误差平方和的计算
需要用到的数据集
(我们在java代码导入数据集时需要先新建一个txt文件)
工具:eclipse及效果图演示
先看效果图:(首先导入文件,然后点击按钮后会显示每次迭代结果和sse,sse越小说明迭代效果越好)
2.代码
代码如下:
定义一个类表示每个点·的坐标
class Point{
String name;
double x,y;
public Point(String name,double x,double y){
this.name=name;
this.x=x;
this.y=y;
}
public Point(){
}
//计算两个坐标点之间的欧式距离
public static double distance(Point a,Point b){
return Math.sqrt((a.x-b.x)*(a.x-b.x)+(a.y-b.y)*(a.y-b.y));
}
//计算两个坐标点之间距离的平方
public static double squaredistance(Point a,Point b){
return (a.x-b.x)*(a.x-b.x)+(a.y-b.y)*(a.y-b.y);
}
//计算每个坐标点距离哪个簇最近
public static int ClusterDistance(Point p,ArrayList<Cluster> array){
int a=0;
for(int i=0;i<array.size();i++){
if(Point.distance(p,array.get(a).center)>Point.distance(p, array.get(i).center)){
a=i;
}
}
return a;
}
public static void PointAdd(ArrayList<Point> PointArray, ArrayList<double[]> DoubleArray){
for(int i=0;i<DoubleArray.size();i++){
PointArray.add(new Point("p"+(i+1),DoubleArray.get(i)[0],DoubleArray.get(i)[1]));
}
}
}
定义一个类表示每个簇
//定义一个类用于表示簇
class Cluster{
Point center=new Point();//簇中心点
ArrayList<Point> Array=new ArrayList<Point>(); //簇中的坐标元素
boolean changed=true;//用于判断该簇的中心点坐标是否发生变化
public Cluster(double x,double y){
center.x=x;
center.y=y;
}
//计算簇的新中心点坐标
public static void NewCenter(Cluster cluster){
double sumx=0,sumy=0;
int i=0;
for(Point ex:cluster.Array){
sumx+=ex.x;
sumy+=ex.y;
i++;
}
if(cluster.center.x!=sumx/i||cluster.center.y!=sumy/i){
cluster.center.x=sumx/i;
cluster.center.y=sumy/i;
cluster.changed=true;
}
else if(cluster.center.x==sumx/i&&cluster.center.y==sumy/i){
cluster.changed=false;
}
}
//判断所有簇的中心点是否不再发生变化
public static boolean Changing(ArrayList<Cluster> array){
boolean ex=false;
for(Cluster e:array){
if(e.changed==true){
ex=true;
}
}
return ex;
}
}
随机选择初始簇中心的类
//用于获取1~n之间的t个不重复的整数
class GetRandom{
@SuppressWarnings("unchecked")
static public int[] fn(int n,int t)
{
ArrayList numbers=new ArrayList();
int[] rtnumbers=new int[t];
for(int i=0;i<n;i++){ //初始化数组
numbers.add(i+1);
}
for(int j=0;j<t;j++){
int raNum=(int)(Math.random()*numbers.size());
rtnumbers[j]=Integer.parseInt(numbers.get(raNum).toString());
numbers.remove(raNum);
}
return rtnumbers;
}
}
另外定义一个类来声明需要使用到的全局变量
class Pub{
static StringBuffer str=new StringBuffer();//存放textArea的内容
static StringBuffer str1=new StringBuffer();//存放textArea1的内容
static ArrayList<Point> TotalArray=new ArrayList<Point>(); //定义用于存放坐标点的泛型数组
static ArrayList<Cluster> ClusterArray=new ArrayList<Cluster>(); //定义用于存放各个簇的点的泛型数组
static ArrayList<Color> ColorArray=new ArrayList<Color>();
}
绘制可视化窗口相关代码
class DrawFrame extends JFrame
{
private JFileChooser chooser=new JFileChooser();
JTextArea textArea=new JTextArea();
JTextArea textArea1=new JTextArea();
JTextField testField=new JTextField(5);
public DrawFrame()
{
setTitle("Kmeans");
setSize(700,700);
setLayout(null);
//-----------Menu---------------/
JMenuBar menuBar=new JMenuBar();//菜单栏
setJMenuBar(menuBar);
JMenu openMenu=new JMenu("File");//菜单对象
menuBar.add(openMenu);
JMenuItem FileOpen=new JMenuItem("Open File..."); //菜单项
openMenu.add(FileOpen);
FileOpen.addActionListener(new FileOpenListener());
//---------TitlePanel-----------/
JLabel TitleLabel=new JLabel("k-means聚类算法模拟");
TitleLabel.setFont(new Font("Dialog",1,22));
add(TitleLabel);
TitleLabel.setBounds(250, -10, 300, 60);
//----------kPanel--------------/
JPanel kPanel=new JPanel();
JLabel label=new JLabel("请输入k值(必须为整数且不大于数据集中点的个数):");
// JLabel label=new JLabel("请输入k值(必须为整数)(1<=k" + Pub.ClusterArray.size() + "):");
JButton testButton=new JButton("开始模拟");
testButton.addActionListener(new ButtonListener());
kPanel.add(label);
kPanel.add(testField);
kPanel.add(testButton);
add(kPanel);
kPanel.setBounds(80,50,500,35);
//-----------textPanel1--------------/
JScrollPane scrollPane=new JScrollPane(textArea);
add(scrollPane);
scrollPane.setBounds(350,100,300,400);
//-----------textPanel2--------------/
JScrollPane scrollPane2=new JScrollPane(textArea1);
add(scrollPane2);
scrollPane2.setBounds(60,100,250,400);
}
文件及测试按钮的监听
//选择文件监听事件
private class FileOpenListener implements ActionListener{
public void actionPerformed(ActionEvent event){
chooser.setCurrentDirectory(new File("D://"));
int result=chooser.showOpenDialog(DrawFrame.this);
if(result==JFileChooser.APPROVE_OPTION){
String FilePath=chooser.getSelectedFile().getPath();
ArrayList<double[]> test=new takenumber().getPoint(FilePath);
//先把以前点集中的数据清空
Pub.TotalArray.clear();
new Point().PointAdd(Pub.TotalArray, test);
//把str1中存的字符串清空
Pub.str1.delete(0, Pub.str1.length());
//把从文件中获取的点坐标转化为字符串
for(Point ex:Pub.TotalArray){
Pub.str1.append(" "+ex.name+" ("+ex.x+","+ex.y+")\n");
}
textArea1.setText(Pub.str1.toString());
textArea1.setFont(new Font("Serif",0,16));
textArea.setText("");
textArea.setEditable(false);
textArea1.setEditable(false);
Pub.ClusterArray.clear();
}
}
}
//测试按钮监听事件
private class ButtonListener implements ActionListener{
public void actionPerformed(ActionEvent event){
if(testField.getText().equals("")||Integer.parseInt(testField.getText())==0){
Pub.str.append("k值不能为空或者为0!请输入一个值!");
textArea.setText(Pub.str.toString());
Pub.str.delete(0, Pub.str.length());
}
else if(Integer.parseInt(testField.getText())>Pub.TotalArray.size()){
Pub.str.append("请输入正确的k值!");
textArea.setText(Pub.str.toString());
// textArea.setText(String.valueOf(Pub.TotalArray.size()));
Pub.str.delete(0, Pub.str.length());
System.out.print(Pub.ClusterArray.size());
}
else{
int k=Integer.parseInt(testField.getText());
//在所有元素的数组里随机选择k个坐标点用于代表初始簇中心
int ramdon[]=new GetRandom().fn(Pub.TotalArray.size(), k);
Pub.ClusterArray.clear();
//将随机选择的点加入到簇中
for(int i=0;i<k;i++){
Pub.ClusterArray.add(new Cluster(Pub.TotalArray.get(ramdon[i]-1).x,Pub.TotalArray.get(ramdon[i]-1).y));
}
Pub.str.append("初始随机选择的"+k+"个簇中心点的坐标为:\n");
for(Cluster ex:Pub.ClusterArray){
Pub.str.append("("+ex.center.x+","+ex.center.y+")"+"\n");
}
//将点分到不同的簇中
for(Point ex:Pub.TotalArray){
Pub.ClusterArray.get(Point.ClusterDistance(ex,Pub.ClusterArray)).Array.add(ex);
}
/*
//计算每一个簇的中心点
if(Pub.TotalArray.size()!=k||k==1){
for(Cluster ex:Pub.ClusterArray){
Cluster.NewCenter(ex);
}}*/
int m=1;
ArrayList<Point> PointArray=new ArrayList<Point>();//用于存放要删除的点的数组
while(Cluster.Changing(Pub.ClusterArray)!=false){
Pub.str.append("第"+(m++)+"次迭代的结果为:\n");
int ii=1;
for(int i = 1; i <= Pub.ClusterArray.size(); i++) {
Pub.str.append(" "+"C"+i+"("+String.format("%.3f", Pub.ClusterArray.get(i-1).center.x)+","+String.format("%.3f", Pub.ClusterArray.get(i-1).center.y)+")"+ " ");
}
Pub.str.append("\n");
for(int i = 0; i < Pub.TotalArray.size(); i++) {
Point point1 = Pub.TotalArray.get(i);
Pub.str.append("("+ String.format("%.2f",point1.x)+","+String.format("%.2f",point1.y)+")");
for(int j=0; j<Pub.ClusterArray.size(); j++) {
Point point2 = Pub.ClusterArray.get(j).center;
String d1 = String.format("%.2f",Point.distance(point1, point2));
Pub.str.append(" "+ d1 + " ");
}
Pub.str.append("\n");
}
for(Cluster ex:Pub.ClusterArray){
Cluster.NewCenter(ex);
}
for(Cluster ex:Pub.ClusterArray){
Pub.str.append("第"+(ii++)+"个簇的所有点为:\n");
for(Point p:ex.Array){
Pub.str.append(p.name+"("+p.x+","+p.y+")"+"\n");
}
Pub.str.append("新的中心点坐标为:"+"("+String.format("%.3f",ex.center.x)+","+String.format("%.3f",ex.center.y)+")"+"\n");
Pub.str.append("---------------------------------------------------------------\n");
}
for(int j=0;j<k;j++){
for(Point t:Pub.ClusterArray.get(j).Array){
if(Point.ClusterDistance(t,Pub.ClusterArray)!=j){
Pub.ClusterArray.get(Point.ClusterDistance(t, Pub.ClusterArray)).Array.add(t);
PointArray.add(t);
}
}
Pub.ClusterArray.get(j).Array.removeAll(PointArray);
PointArray.clear();
}
/*
for(Cluster ex:Pub.ClusterArray){
Cluster.NewCenter(ex);
} */
}
double sse = 0;
System.out.print(Pub.ClusterArray.size());
for(int i=0; i < Pub.ClusterArray.size(); i++) {
Point point1 = Pub.ClusterArray.get(i).center;
for(int j=0;j<Pub.ClusterArray.get(i).Array.size();j++) {
Point point2 = Pub.ClusterArray.get(i).Array.get(j);
// double squarediatance = Point.distance(point1, point2);
sse+=Point.squaredistance(point1, point2);
// sse+=Point.distance(point1, point2);
// System.out.print(Pub.ClusterArray.get(j).Array.size());
}
}
Pub.str.append("平方和误差为:" + String.format("%.2f",sse));
textArea.setText(Pub.str.toString());
textArea.setEditable(false);
Pub.str.delete(0, Pub.str.length());
}
}
}
}
还有一个从文件中获取数据时需要用到的类:
public class takenumber {
public static void main(String args[]){
}
public ArrayList<double[]> getPoint(String path){
File filePath = new File(path);
Scanner scanner;
try {
scanner = new Scanner(filePath);
ArrayList<String> StringArray=new ArrayList<String>();
while(scanner.hasNextLine()){
StringArray.add(scanner.nextLine());
}
scanner.close();
ArrayList<double[]> DoubleArray=new ArrayList<double[]>();
for(int i=0;i<StringArray.size();i++){
String []point=StringArray.get(i).split(",");
double[] ex={Double.parseDouble(point[0]),Double.parseDouble(point[1])};
DoubleArray.add(i,ex);
}
return DoubleArray;
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
return null;
}
}
}
3.使用weka验证
为了验证结果的准确性,我又使用weka进行了依次计算,得到了一个与我计算相同的结果
(weka聚类时初始点也是随机的,我们可以通过改变seed值改变初始随机点)
三、源码
源文件如下链接