ID3决策树是决策树的一种,其作用是根据已有数据训练出一棵决策树,再通过决策树的分支对新数据进行分类,属于有监督学习。
在生成决策树的过程中,ID3使用信息增益来选择每个节点的划分属性。信息熵衡量数据的混乱程度:数据越有序,熵值越低;某属性的信息增益越大,说明按该属性划分后数据越有序,即该属性对数据的分类效果越好。
信息熵计算公式:


$$Info(D)=-\sum_{i=1}^{n}P(x_i)\log_2 P(x_i)$$

信息增益:

$$Gain(A)=Info(D)-Info_A(D)$$

其中

$$Info_A(D)=\sum_{j=1}^{v}\frac{|D_j|}{|D|}\times Info(D_j)$$

增益率:

$$SplitInfo_A(D)=-\sum_{j=1}^{v}\frac{|D_j|}{|D|}\times\log_2\left(\frac{|D_j|}{|D|}\right)$$


$$GainRatio(A)=\frac{Gain(A)}{SplitInfo_A(D)}$$



同样作为分类决策树,CART和C4.5决策树分别采用基尼系数和信息增益率作为选择节点划分属性的判据。

使用平台:eclipse
实验数据:人工数据
相关程序:
package mytree;

import java.io.IOException;

public class Mytreemain {
    /**
     * Data-row indices (header excluded) used as the training set in the second example.
     * Together with TEST_ROWS they assume the data file has at least 17 data rows.
     */
    private static final int[] TRAIN_ROWS = {0, 1, 2, 5, 6, 9, 13, 14, 15, 16};
    /** Data-row indices (header excluded) used as the validation set in the second example. */
    private static final int[] TEST_ROWS = {3, 4, 7, 8, 10, 11, 12};

    /**
     * Runs two examples: builds a tree on all rows, then builds a tree on a training subset while
     * passing a validation subset to the builder's pruning check.
     *
     * @param args optional; {@code args[0]} overrides the data file name (default "watermelon.txt")
     * @throws IOException if the data file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String fileName = args.length > 0 ? args[0] : "watermelon.txt";
        InputStringData ori = new InputStringData();
        String[][] data = ori.loadData(fileName);

        // Row 0 is the header; the remaining rows are the samples.
        String[] feature = data[0];
        String[][] attribute = new String[data.length - 1][];
        System.arraycopy(data, 1, attribute, 0, attribute.length);

        // {{"null"}} is the marker the tree builder treats as "no rows".
        String[][] blank = {{"null"}};

        generateTree newTree = new generateTree(feature, attribute);
        newTree.node(feature, attribute, "start", "null", blank);
        System.out.println("此例子到此结束");
        System.out.println(" ");

        String[][] trainattribute = selectRows(attribute, TRAIN_ROWS);
        String[][] testattribute = selectRows(attribute, TEST_ROWS);
        // The builder is still given the full data set: it uses it to find attribute values that
        // are absent from a node's training subset.
        generateTree newTree1 = new generateTree(feature, attribute);
        newTree1.node(feature, trainattribute, "start", "null", testattribute);
    }

    /** Returns the rows of {@code rows} at the given indices, in the order listed (rows are shared, not copied). */
    private static String[][] selectRows(String[][] rows, int[] indices) {
        String[][] picked = new String[indices.length][];
        for (int i = 0; i < indices.length; i++) {
            picked[i] = rows[indices[i]];
        }
        return picked;
    }
}


package mytree;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

public class Tree {
    /** Feature header; element i names column i of every data row passed to these methods. */
    private String[] feature;

    /**
     * Creates a helper bound to a feature header.
     *
     * @param a feature names, in the same column order as the data rows
     */
    public Tree(String[]a){
        feature=a;
    }

    /** Returns the column of feature {@code n}, or -1 when the header does not contain it. */
    private int indexOf(String n){
        for(int i=0;i<feature.length;i++){
            if(feature[i].equals(n)){
                return i;
            }
        }
        return -1;
    }

    /** Returns the column of feature {@code n}; throws IllegalArgumentException if it is unknown. */
    private int requireIndex(String n){
        int index=indexOf(n);
        if(index<0){
            throw new IllegalArgumentException("unknown feature: "+n);
        }
        return index;
    }

    /**
     * Returns the distinct values that feature {@code n} takes in the rows of {@code t}.
     * The order is HashSet iteration order; callers that pick the first maximum rely on it
     * being reproducible, so the collection type is deliberately left unchanged.
     *
     * @throws IllegalArgumentException if {@code n} is not in the feature header
     */
    public String[]getvalue(String[][]t,String n){
        int index=requireIndex(n);
        Set<String> distinct=new HashSet<String>();
        for(String[] row: t){
            distinct.add(row[index]);
        }
        return distinct.toArray(new String[distinct.size()]);
    }

    /**
     * Returns the rows of {@code b} whose feature {@code n} equals {@code t} (rows are shared,
     * not copied). When nothing matches, when {@code n} is unknown, or when the rows are too
     * short to hold that column, returns the 1x1 marker {@code {{"null"}}} used by the tree
     * builder to mean "no rows".
     */
    public String[][]subattr(String[][]b,String n,String t){
        int index=indexOf(n);
        ArrayList<String[]> matches=new ArrayList<String[]>();
        if(index>=0){
            for(String[] row: b){
                // The guard skips marker rows such as {{"null"}} that have fewer columns.
                if(index<row.length&&row[index].equals(t)){
                    matches.add(row);
                }
            }
        }
        if(matches.isEmpty()){
            String[][] marker=new String[1][1];
            marker[0][0]="null";
            return marker;
        }
        return matches.toArray(new String[matches.size()][]);
    }

    /**
     * Returns, for each value {@code n[i]}, the fraction of rows of {@code b} whose feature
     * {@code t} equals it (computed with {@link Arith#div(double, double)}).
     * An empty {@code b} yields all zeros instead of a division by zero.
     *
     * @throws IllegalArgumentException if {@code t} is not in the feature header
     */
    public double[]getRatio(String[][]b,String t,String[] n){
        int index=requireIndex(t);
        double[] ratio=new double[n.length];
        if(b.length==0){
            return ratio;
        }
        Arith ari=new Arith();
        for(int i=0;i<n.length;i++){
            int count=0;
            for(String[] row: b){
                if(row[index].equals(n[i])){
                    count++;
                }
            }
            ratio[i]=ari.div(count,(double)b.length);
        }
        return ratio;
    }

    /**
     * Returns true when every row of {@code a} has the same value in feature {@code b}
     * (vacuously true for no rows).
     *
     * @throws IllegalArgumentException if {@code b} is not in the feature header
     */
    public boolean isSame(String[][]a,String b){
        int index=requireIndex(b);
        if(a.length==0){
            return true;
        }
        String first=a[0][index];
        for(int j=1;j<a.length;j++){
            if(!a[j][index].equals(first)){
                return false;
            }
        }
        return true;
    }

    /** Returns true if {@code thisnode} equals some element of {@code orinode}. */
    public boolean belongs(String thisnode, String[] orinode) {
        for(String candidate: orinode){
            if(thisnode.equals(candidate)){
                return true;
            }
        }
        return false;
    }
}



package mytree;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Scanner;

public class InputStringData {
    /** Directory searched by {@link #loadData(String)}. */
    static final String DEFAULT_DIR = "C:\\Users\\CJH\\Desktop\\R程序运行";

    /** Row and column counts of the matrix returned by the most recent successful load. */
    int countRow = 0, countCol = 0;

    /**
     * Loads a whitespace-separated text table from {@link #DEFAULT_DIR}.
     *
     * @param trainfile file name inside the default directory
     * @return one array per non-blank line, split on whitespace
     * @throws IOException if the file cannot be read or its rows have different column counts
     */
    public String[][] loadData(String trainfile) throws IOException {
        return loadData(new File(DEFAULT_DIR), trainfile);
    }

    /**
     * Loads a whitespace-separated text table. Blank lines are skipped; every other line must
     * have the same number of tokens as the first one. The file is read with the platform
     * default charset, as before.
     *
     * @param dir directory containing the file
     * @param trainfile file name inside {@code dir}
     * @return one array per non-blank line, split on whitespace
     * @throws IOException if the file cannot be read or its rows have different column counts
     */
    public String[][] loadData(File dir, String trainfile) throws IOException {
        File file = new File(dir, trainfile);
        ArrayList<String[]> rows = new ArrayList<String[]>();
        Scanner lines = new Scanner(file);
        try {
            while (lines.hasNextLine()) {
                ArrayList<String> tokens = new ArrayList<String>();
                Scanner words = new Scanner(lines.nextLine());
                try {
                    while (words.hasNext()) {
                        tokens.add(words.next());
                    }
                } finally {
                    words.close();
                }
                if (tokens.isEmpty()) {
                    // A blank line previously counted as a row and misaligned the whole matrix.
                    continue;
                }
                if (!rows.isEmpty() && tokens.size() != rows.get(0).length) {
                    throw new IOException(file + ": line with " + tokens.size()
                            + " columns, expected " + rows.get(0).length);
                }
                rows.add(tokens.toArray(new String[tokens.size()]));
            }
        } finally {
            lines.close();
        }
        // Reset per call instead of accumulating across calls on the same instance.
        countRow = rows.size();
        countCol = rows.isEmpty() ? 0 : rows.get(0).length;
        return rows.toArray(new String[rows.size()][]);
    }
}




package mytree;

import java.math.BigDecimal;
import java.math.RoundingMode;

public class Arith{
private static final int DEF_DIV_SCALE=10;

          public double add(double v1,double v2){
              BigDecimal b1=new BigDecimal(Double.toString(v1));
              BigDecimal b2=new BigDecimal(Double.toString(v2));
              return b1.add(b2).doubleValue();
              }
          public double sub(double v1,double v2){
              BigDecimal b1=new BigDecimal(Double.toString(v1));
              BigDecimal b2=new BigDecimal(Double.toString(v2));
              return b1.subtract(b2).doubleValue();
              }
          public double mul(double v1,double v2){
              BigDecimal b1=new BigDecimal(Double.toString(v1));
              BigDecimal b2=new BigDecimal(Double.toString(v2));
              return b1.multiply(b2).doubleValue();
              }
          public double div(double v1,double v2){
              return div(v1,v2,DEF_DIV_SCALE);
              }
          public double div(double v1,double v2,int scale){
                  if(scale<0){
                      throw new IllegalArgumentException(
                              "The scale must be a positive integer or zero");
                      }
                  BigDecimal b1=new BigDecimal(Double.toString(v1));
                  BigDecimal b2=new BigDecimal(Double.toString(v2));
                  return b1.divide(b2,scale,BigDecimal.ROUND_HALF_UP).doubleValue();
                  }
          public double mul(double v1,double v2,int scale){
              if(scale<0){
                  throw new IllegalArgumentException(
                          "The scale must be a positive integer or zero");
                  }
              BigDecimal b1=new BigDecimal(Double.toString(v1));
              BigDecimal b2=new BigDecimal(Double.toString(v2));
              if(v1!=0&&v2!=0){
                  BigDecimal b3=new BigDecimal(Double.toString(1));
                  BigDecimal b4=new BigDecimal(b3.divide(b2,scale,BigDecimal.ROUND_HALF_UP).doubleValue());
                  return b1.divide(b4,scale,BigDecimal.ROUND_HALF_UP).doubleValue();
              }
              else{
                  return 0;
              }
              }
              public double round(double v,int scale){
                  if(scale<0){
                      throw new IllegalArgumentException(
                              "The scale must be a positive integer or zero");
                      }
                  BigDecimal b=new BigDecimal(Double.toString(v));
                  BigDecimal one=new BigDecimal("1");
                  return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).doubleValue();
                  }
              }



package mytree;


public class generateTree {
    // Full feature header and full data set. The full data set is used to discover attribute
    // values that do not occur in a node's own subset.
    private String[] feature;
    private String[][]attribute;
    // Working state of the most recent node() call; a parent reads thisnode after a child call
    // returns to learn the child's label or split feature.
    private String[]subfeature;
    private String[][]splitattribute;
    private String[][]attributeofchild;
    private String valueoffather;
    private String thisnode;
    private String[][]splittest;

    /**
     * @param a full feature header; feature[0] is treated as the class label by node()
     * @param b full data set (one row per sample, columns ordered like {@code a})
     */
    public generateTree(String[] a,String[][]b){
        feature=a;
        attribute=b;
    }

    /**
     * Recursively builds (and prints) an ID3 subtree for one node, then runs a validation-based
     * pruning check against the test rows routed to this node.
     *
     * @param a feature names still available at this node, label column first
     * @param b training rows of this node, or the marker {{"null"}} for "no rows"
     * @param value attribute value on the edge from the parent ("start" at the root)
     * @param p parent's majority label, used when {@code b} is empty or no features remain
     * @param test validation rows routed to this node, or the marker {{"null"}}
     */
    public void node(String[] a,String[][] b,String value,String p,String[][]test){
        Tree mytree=new Tree(feature);
        subfeature=a;
        splitattribute=b;
        valueoffather=value;
        Arith ari=new Arith();

        // Majority label of this node (first maximum wins on ties).
        String[] orinode=mytree.getvalue(splitattribute, feature[0]);
        double[] nodeori=mytree.getRatio(splitattribute, feature[0], orinode);
        int ori=0;
        for(int i=1;i<nodeori.length;i++){
            if(nodeori[i]>nodeori[ori]){
                ori=i;
            }
        }
        String node1=orinode[ori];
        thisnode=node1;

        // Leaf cases.
        if(isNullMarker(splitattribute)){
            thisnode=p;
            report();
            return;
        }
        if(mytree.isSame(b,a[0])){
            report();
            return;
        }
        if(subfeature.length==1){
            // NOTE(review): uses the parent's label p rather than this node's majority node1;
            // kept as-is, confirm this is intended.
            thisnode=p;
            report();
            return;
        }

        // Information gain of every remaining feature; the last maximum wins on ties (>=).
        double entD=entropy(ari,nodeori);
        double[] gain=new double[subfeature.length-1];
        for(int i=0;i<gain.length;i++){
            String candidate=subfeature[i+1];
            String[] ithvalue=mytree.getvalue(splitattribute,candidate);
            double[] ratio=mytree.getRatio(splitattribute,candidate,ithvalue);
            double conditional=0;
            for(int j=0;j<ithvalue.length;j++){
                String[][] ithsplitmatrix=mytree.subattr(splitattribute,candidate,ithvalue[j]);
                String[] yvalue=mytree.getvalue(ithsplitmatrix,subfeature[0]);
                double[] Ddi=mytree.getRatio(ithsplitmatrix,subfeature[0],yvalue);
                conditional=ari.add(conditional,ari.mul(ratio[j],entropy(ari,Ddi),3));
            }
            gain[i]=ari.sub(entD,conditional);
        }
        int index=0;
        for(int i=1;i<gain.length;i++){
            if(gain[i]>=gain[index]){
                index=i;
            }
        }

        String splitFeature=subfeature[index+1];
        thisnode=splitFeature;
        String[] valueofchild=mytree.getvalue(splitattribute,splitFeature);
        String[] featureofchild=new String[subfeature.length-1];
        for(int src=0,dst=0;src<subfeature.length;src++){
            if(src!=index+1){
                featureofchild[dst++]=subfeature[src];
            }
        }
        report();

        // Values of the split feature that occur in the full data set but not in this subset
        // become leaves labelled with this node's majority label.
        String[][] blank={{"null"}};
        String[] allvalue=mytree.getvalue(attribute,splitFeature);
        if(valueofchild.length<allvalue.length){
            for(int i=0;i<allvalue.length;i++){
                if(!mytree.belongs(allvalue[i],valueofchild)){
                    generateTree missingChild=new generateTree(feature,attribute);
                    missingChild.node(featureofchild,blank,allvalue[i],node1,test);
                }
            }
        }

        // Grow each child, remembering the label it settled on and the test rows routed to it.
        String[] childLabel=new String[valueofchild.length];
        String[][][] childTest=new String[valueofchild.length][][];
        generateTree treeofchild=new generateTree(feature,attribute);
        for(int i=0;i<valueofchild.length;i++){
            attributeofchild=mytree.subattr(splitattribute,splitFeature,valueofchild[i]);
            // equals(), not ==: the marker must be recognised even if it is not an interned literal.
            if(isNullMarker(test)){
                splittest=new String[][]{{"null"}};
            }else{
                splittest=mytree.subattr(test,splitFeature,valueofchild[i]);
            }
            treeofchild.node(featureofchild,attributeofchild,valueofchild[i],node1,splittest);
            childLabel[i]=treeofchild.thisnode;
            childTest[i]=splittest;
            if(i==valueofchild.length-1){
                System.out.println("向上一层");
            }
        }

        if(isNullMarker(test)){
            return;
        }
        // Validation check. Each child is scored on its OWN test rows against its OWN label (the
        // old code reused the last child's for all), and counts are divided by test.length once.
        // Label is column 0 of each test row.
        // NOTE(review): a child that is not a leaf reports a feature name, which never matches a
        // label, so it scores 0 correct; this biases toward pruning, as in the original design.
        int correctBefore=0;
        int correctAfter=0;
        for(int i=0;i<childTest.length;i++){
            if(isNullMarker(childTest[i])){
                continue;
            }
            for(String[] row: childTest[i]){
                if(row[0].equals(childLabel[i])){
                    correctBefore++;
                }
                if(row[0].equals(node1)){
                    correctAfter++;
                }
            }
        }
        double accuracyb=ari.div(correctBefore,test.length,4);
        double accuracya=ari.div(correctAfter,test.length,4);
        if(accuracyb<accuracya){
            System.out.println(" ");
            System.out.println("");
            System.out.println("剪枝");
            System.out.print("子叶属性");
            System.out.println("a:"+accuracya+" "+"b:"+accuracyb);
            System.out.println("剪枝前正确率:"+accuracyb);
            System.out.println("剪枝后正确率:"+accuracya);
            // Print the original split feature before replacing it with the majority label.
            System.out.println("原本父属性:"+valueoffather+"原本本节属性: "+thisnode);
            System.out.println("原本父属性:"+valueoffather+"原本本节属性: "+node1);
            thisnode=node1;
        }
    }

    /** Prints the three-line trace for this node: edge value, node label/feature, "go up". */
    private void report(){
        System.out.println("继承属性:"+valueoffather);
        System.out.println("本节属性:"+thisnode);
        System.out.println("向上一层");
    }

    /** Returns -sum(p*log2 p) over the non-zero entries of {@code p}, using Arith rounding. */
    private static double entropy(Arith ari,double[] p){
        double h=0;
        for(int i=0;i<p.length;i++){
            if(p[i]!=0){
                h=ari.sub(h,ari.mul(p[i],ari.div(Math.log(p[i]),Math.log(2)),3));
            }
        }
        return h;
    }

    /** True when {@code m} is the {{"null"}} marker that stands for "no rows". */
    private static boolean isNullMarker(String[][] m){
        return m.length>0&&m[0].length>0&&"null".equals(m[0][0]);
    }
}

继承属性:start
本节属性:纹理
继承属性:稍糊
本节属性:触感
继承属性:软粘
本节属性:好瓜
向上一层
继承属性:硬滑
本节属性:坏瓜
向上一层
向上一层
继承属性:模糊
本节属性:坏瓜
向上一层
继承属性:清晰
本节属性:触感
继承属性:软粘
本节属性:脐部
继承属性:凹陷
本节属性:坏瓜
向上一层
继承属性:稍凹
本节属性:色泽
继承属性:浅白
本节属性:好瓜
向上一层
继承属性:青绿
本节属性:好瓜
向上一层
继承属性:乌黑
本节属性:坏瓜
向上一层
向上一层
继承属性:平坦
本节属性:坏瓜
向上一层
向上一层
继承属性:硬滑
本节属性:好瓜
向上一层
向上一层
向上一层
此例子到此结束

继承属性:start
本节属性:脐部
继承属性:稍凹
本节属性:触感
继承属性:软粘
本节属性:纹理
继承属性:模糊
本节属性:好瓜
向上一层
继承属性:稍糊
本节属性:好瓜
向上一层
继承属性:清晰
本节属性:色泽
继承属性:浅白
本节属性:好瓜
向上一层
继承属性:青绿
本节属性:好瓜
向上一层
继承属性:乌黑
本节属性:坏瓜
向上一层
向上一层
向上一层
继承属性:硬滑
本节属性:坏瓜
向上一层
向上一层
继承属性:凹陷
本节属性:纹理
向上一层
继承属性:模糊
本节属性:好瓜
向上一层
继承属性:稍糊
本节属性:坏瓜
向上一层
继承属性:清晰
本节属性:好瓜
向上一层
向上一层
继承属性:平坦
本节属性:坏瓜
向上一层
向上一层