学习《机器学习实战》,对python语言不怎么熟悉,决定一段程序一段程序来学习,既学习算法,也顺便学习python的基础知识。最后,我将把python代码用Java重写一遍。

第一个算法kNNk-邻近算法)

这个算法的理论很简单,很容易理解。如果学过KMeans聚类算法,那么学这个算法会感觉更简单。

我对这个算法的过程理解如下:

第一步:把所有的训练集读入到内存中,这也是这个算法为什么会有空间复杂度高的原因了。

第二步:读入待分类的向量(如果是文本,要处理成向量的方式,VSM模型在这里起作用了)

第三步:计算待分类向量到所有训练集的距离。(既然是向量计算距离,一般用欧式距离就OK了)

第四步:对距离进行从小到大排序,取前k个训练集的Label

第五步:对前K个训练集的Label进行统计。把待分类向量分到Label个数最多的那一个类别。

第六步:算法结束。


学习了过程,再来学习代码的实现。

算法的过程了解后,就很容易得出需要输入的参数:待分类文本,训练集,训练集标签,k值。输出的参数:待分类文本的分类标签。

输入输出解决了的话,至少解决了问题的三分之一。

def classify(inX,dataSet,labels,k):
    //
    dataSetSize=dataSet.shape[0]
    diffMat=tile(inX,(dataSetSize,1))-dataSet
    sqDiffMat=diffMat**2
    sqDistances=sqDiffMat.sum(axis=1)
    distances=sqDistances**0.5
    sortedDistIndicies=distances.argsort();
    classCount={}
    for i in range(k):
        voteIlabel=labels[sortedDistIndicies[i]]
        classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
    sortedClassCount=sorted(classCount.iteritems(),
                            key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

python版的翻译成java版的:

package com.vancl.knn;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
public class KNN {
    /*
     * @param inX 待分类的文本
     * @param dataSet 训练集
     * @param labels 训练集的分类标签
     * @param k值
     * @return 分类器得到的分类标签
     * */
    public char classify(double[] inX,double[][] dataSet,char[] labels,int k){
        //对应python 的dataSet.shape()[0]
        int dataSetSize=dataSet.length;
        //对应python 的tile(inX,(dataSetSize,1))-dataSet
        //和diffMat**2 两行代码
        double[][] sqDiffMat=createDiffMat(inX,dataSet,dataSetSize);
        //对应python 的sqdiffMat.sum(axis=1);
        //和distances=sqDistances**0.5两行代码
        double[] distances=sum(sqDiffMat);
                       
        Node[] disNode=new Node[distances.length];
        for(int i=0;i<distances.length;i++){
            Node node=new Node(distances[i],i);
            disNode[i]=node;
        }
        //对应python中 sortedDistaIndicies=distances.argsort(),排序得到下标
        Arrays.sort(disNode,new KNNCompartor());
        //选择距离最小的k个点
        //对应pyhton的classCount={}
        Map<Character,Integer> classCount=new HashMap<Character,Integer>();
                       
        char voteLabel;
        //对应python的 for i in range(k):
        for(int i=0;i<k;i++){
            //对应python的voteIlabel=labels[sortedDistIndicies[i]]
            voteLabel=labels[disNode[i].idx];
            //对应 classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
            add(voteLabel,classCount);
        }
        //sortedClassCount=sorted(classCount.iteritems(),
        //       key=operator.itemgetter(1),reverse=True)
        ArrayList<Map.Entry<Character,Integer>> l = new ArrayList<Map.Entry<Character,Integer>>(classCount.entrySet()); 
        Collections.sort(l,new Comparator<Map.Entry<Character,Integer>>(){
            @Override
            public int compare(Entry<Character, Integer> o1,
                    Entry<Character, Integer> o2) {
                               
                return o2.getValue()-o1.getValue();
            }
        });
        //对应 return sortedClassCount[0][0]
        return l.get(0).getKey();
    }
                   
    public void add(char voteLabel,Map<Character,Integer> classCount){
        Integer id=classCount.get(voteLabel);
        if(id==null) id=0;
        classCount.put(voteLabel, id+1);
    }
                   
    private double[] sum(double[][] sqDiffMat) {
        int i,j;
        double[] sqDistances=new double[sqDiffMat.length];
        for(i=0;i<sqDiffMat.length;i++){
            sqDistances[i]=0;
            for(j=0;j<sqDiffMat[i].length;j++){
                sqDistances[i]+=sqDiffMat[i][j];
            }
            sqDistances[i]=Math.sqrt(sqDistances[i]);
        }
        return sqDistances;
    }
    private double[][] createDiffMat(double[] inX, double[][] dataSet,int dataSetSize) {
        double[][] diffMat=new double[dataSetSize][inX.length];
        for(int i=0;i<dataSetSize;i++){
            System.arraycopy(inX, 0, diffMat[i], 0, inX.length);
            for(int j=0;j<inX.length;j++){
                diffMat[i][j]=diffMat[i][j]-dataSet[i][j];
                diffMat[i][j]=Math.pow(diffMat[i][j], 2);
            }
        }
                           
        return diffMat;
    }
    class Node{
        public Node(double value, int idx) {
            super();
            this.value = value;
            this.idx = idx;
        }
        double value;
        int idx;
    }
    class KNNCompartor implements Comparator<Node>{
        @Override
        public int compare(Node o1, Node o2) {
            return Double.compare(o1.value, o2.value);
        }
                       
    }
    public static void main(String[] args) {
        KNN knn=new KNN();
        double[] inX={1,1};
        double[][] dataSet ={{1,1.1},{1,1},{0,0},{0,0.1}};
        char[]labels={'A','A','B','B'};
        char rs=knn.classify(inX, dataSet, labels,3 );
        System.out.println(rs);
    }
}

至此,明白为什么机器学习的书籍为什么大都选择python,而不选择java了。