1)熵与信息增益:
2)以下是实现代码:
1 //import java.awt.color.ICC_ColorSpace;
2 import java.io.*;
3 import java.util.ArrayList;
4 import java.util.Collections;
5 import java.util.Comparator;
6 import java.util.HashMap;
7 import java.util.HashSet;
8 import java.util.Iterator;
9 //import java.util.Iterator;
10 import java.util.List;
11 //import java.util.Locale.Category;
12 import java.util.Map;
13 import java.util.Map.Entry;
14 import java.util.Set;
15 class decisionTree{
16
17 private static Map<String, Map<String, Integer>> featureValuesAndCounts=new HashMap<String, Map<String,Integer>>();
18 private static ArrayList<String[]> dataSet=new ArrayList<String[]>();
19 private static ArrayList<String> features=new ArrayList<String>();
20 private static Set<String> category=new HashSet<String>();
21 //public static DecisionNode root=new DecisionNode();
22 //private static ArrayList<ArrayList<String>> featureValue=new ArrayList<ArrayList<String>>();
23 public static void GetDataSet()
24 {
25 File file = new File("C:\\Users\\hfz\\workspace\\decisionTree\\src\\loan.txt");
26 try{
27 BufferedReader br = new BufferedReader(new FileReader(file));//
28 String s = null;
29 s=br.readLine();//读取第一行的内容,即是各特征的名称
30 String[] tempFeatures=s.split(",");
31 for (String string1 : tempFeatures) {
32 features.add(string1);
33 }
34 s=br.readLine(); //开始读取特征值
35 String[] tt=null;
36 int flag=s.length();
37 while(flag!=0){//英文文档读到结尾得到的值是null,而中文文档读到结尾得到的值却是""
38 tt=s.split(",");
39 dataSet.add(tt); //将特征值存入
40 category.add(tt[tt.length-1]);//category为集合类型,用于存储类型值
41
42 s=br.readLine();
43 if (s!=null) {
44 flag = s.length();
45 }
46 else{
47 flag=0;
48 }
49
50 }
51
52 for (int j = 0; j < features.size(); j++) {//逻辑上模拟列优先的方式读取二维数组形式的数据集,就是首先读取一个特征名称,再遍历数据集
53 Map<String, Integer> ttt=new HashMap<String, Integer>();//将某特征的各个特征值存入Map中,然后再度第二个特征,再遍历数据集。。。
54 for (int i = 0; i < dataSet.size(); i++) {
55 String currentFeatureValue=dataSet.get(i)[j];
56 if(!(ttt.containsKey(currentFeatureValue)))
57 ttt.put(currentFeatureValue, 1);
58 else {
59 ttt.replace(currentFeatureValue, ttt.get(currentFeatureValue)+1);
60 }
61
62 }
63 featureValuesAndCounts.put(features.get(j), ttt);//嵌套形式的Map,第一层的key是特征名称,value是一个新的Map
64 // 新Map中key是特征的各个值,value是特征值在数据集中出现的次数。
65
66 }
67
68 br.close();
69 }
70
71 catch(Exception e){
72 e.printStackTrace();
73 }
74 }
75 public static DecisionNode treeGrowth(ArrayList<String[]> dataset,String currentFeatureName,
76 String currentFeatureValue,ArrayList<String> current_features,
77 Map<String,Map<String,Integer>> current_featureValuesCounts){
78 /*
79 dataset:用于split方法,从dataset数据集中去除掉具有某个特征值的对应的若干实例,生成一个新的新的数据集
80 currentFeatureName:当前的特征名称,用于叶子节点,赋值给叶子节点的featureName字段
81 currentFeatureValue:当前特征名称对应的特征值,也用于叶子节点,赋值给featureValue字段
82 current_features:当前数据集中包含的所有特征名称,用于findBestAttribute方法,找到信息增益最大的的属性
83 current_featureValuesCounts:当前数据集中所有特征的各个特征值出现的次数,用于findBestAttribute方法,用于计算条件熵,进而计算信息增益。
84 */
85 ArrayList<String> classList=new ArrayList<String>();
86 int flag=0;
87 for (String[] string : dataset) {
88 //测试数据集中类型值的数量,flag表示数据集中的类型数量
89 if (classList.contains(string[string.length-1])) {
90
91 }
92 else {
93 classList.add(string[string.length-1]);
94 flag++;//如果flag>1表示数据集
95 }
96
97 }
98 if(1==flag){//如果只有一个类结果,则返回此叶子节点
99 DecisionNode d=new DecisionNode();
100 d.init(currentFeatureName,classList.get(0),currentFeatureValue);
101 return d;
102 }
103 if (dataset.get(0).length==1) {//如果数据集已经没有属性了只剩下类结果,则返回占比最大的类结果,也是叶子节点
104 DecisionNode d=new DecisionNode();
105 d.init(currentFeatureName,classify(classList),currentFeatureValue);
106 return d;
107 }
108
109 /*
110 DecisionNode是一个自定义的递归型的数据类型,类中一个children字段是DecisionNode类型的数组,
111 正好用这种类型来存储递归算法产生的结果(决策树),也就是用这种结构来存储一棵树。
112 */
113 //程序运行到这里就说明此节点不是叶子节点
114 DecisionNode root2=new DecisionNode();//那么root2就是一个决策属性节点(非叶子节点)了,非叶子节点就有孩子节点,下面就是计算它的孩子节点
115
116 int bestFeatureIndex=findBestAttribute(dataset,current_features,current_featureValuesCounts);
117 String bestFeatureLabel=current_features.get(bestFeatureIndex);
118 //root.testCondition=bestFeatureLabel;
119 ArrayList<String> feature_values=new ArrayList<String>();
120 for (Entry<String, Integer> featureEntry : current_featureValuesCounts.get(bestFeatureLabel).entrySet()) {
121 feature_values.add(featureEntry.getKey());
122
123 }
124 //给非叶子节点,也就是特征节点仅仅赋特征名称值
125 root2.init(currentFeatureName,currentFeatureValue);//java中不能是使用像C++中默认参数的函数,只能通过重载来实现同样的目的。
126 for (String values : feature_values) {
127 //DecisionNode tempRoot=new DecisionNode();
128
129 ArrayList<String[]> subDataSet = splitDataSet(dataset, bestFeatureIndex, values);//生成子数据集,即去除了包含values的实例,
130 // 接下来就是计算对此数据集利用决策树进行决策,又需要调用treeGrow方法
131 //所以,接下来需要得到对应这个子数据集的特征名称以及每个特征值在数据集中出现的次数
132 ArrayList<String> currentAttibutes=new ArrayList<>();
133 Iterator item1=current_features.iterator();
134 while(item1.hasNext()){
135 currentAttibutes.add(item1.next().toString());//这个子数据集的特征名称
136 }
137
138 Map<String,Map<String,Integer>> currentAttributeValuesCounts=new HashMap<String, Map<String, Integer>>();
139 //ArrayList<String[]> subDataSet = splitDataSet(dataset, bestFeatureIndex, values);
140 currentAttibutes.remove(bestFeatureLabel);
141 for (int j = 0; j < currentAttibutes.size(); j++) {
142 Map<String, Integer> ttt=new HashMap<String, Integer>();
143 for (int i = 0; i <subDataSet.size(); i++) {
144 String currentFeatureValueXX=subDataSet.get(i)[j];
145 if(!(ttt.containsKey(currentFeatureValueXX)))
146 ttt.put(currentFeatureValueXX, 1);
147 else {
148 ttt.replace(currentFeatureValueXX, ttt.get(currentFeatureValueXX)+1);
149 }
150
151 }
152 currentAttributeValuesCounts.put(currentAttibutes.get(j), ttt);//每个特征值在数据集中出现的次数
153
154 }
155
156 root2.add(treeGrowth(subDataSet, bestFeatureLabel, values, currentAttibutes, currentAttributeValuesCounts));
157
158 }
159
160
161 return root2;
162
163 }
164
165 public static void main(String[] agrs){
166 decisionTree.GetDataSet();
167 DecisionNode dd=decisionTree.treeGrowth(dataSet,"oo","xx",features,featureValuesAndCounts);
168 System.out.print(dd);
169
170
171
172 }
173
174 public static double calEntropy(ArrayList<String[]> dataset){//熵表示随机变量X不确定性的度量,在决策树中计算的熵就是决策结果这个变量的熵。
175 int sampleCounts=dataset.size();
176 Map<String, Integer> categoryCounts=new HashMap<String, Integer>();
177 for (String[] strings : dataset) {
178
179 if(categoryCounts.containsKey(strings[strings.length-1]))
180 categoryCounts.replace(strings[strings.length-1], categoryCounts.get(strings[strings.length-1])+1);
181 else {
182 categoryCounts.put(strings[strings.length-1],1);
183 }
184
185 }
186 double shannonEnt=0.0;
187 for (Integer value: categoryCounts.values()) {
188 double probability=value.doubleValue()/sampleCounts;
189 shannonEnt-=probability*(Math.log10(probability)/Math.log10(2));
190
191 }
192 return shannonEnt;
193 }
194
195 public static int findBestAttribute(ArrayList<String[]> dataset,ArrayList<String> currentFeatures,
196 Map<String,Map<String,Integer>> currentFeatureValuesCounts){
197 double baseEntroy=calEntropy(dataset);//计算基础熵,就是在不划分出某个特征的情况下。
198 double bestInfoGain=0.0;
199 int bestFeatureIndex=-1;
200
201 for (int i = 0; i <currentFeatures.size(); i++) {//遍历当前数据集的每个特征,计算每个特征的信息增益
202 double conditionalEntroy=0.0;
203 Map<String,Integer> tempFeatureCounts=currentFeatureValuesCounts.get(currentFeatures.get(i));
204 //Map类型有一个entrySet方法,此方法返回一个Map.Entry类型的集合,其中集合中的每个元素就是一个键值对,利用增强型的for循环可以遍历Map中
205 //key(entry.getkey)和value(entry.getValue)
206 for (Entry<String, Integer> entry : tempFeatureCounts.entrySet()) {
207 //计算条件熵,就是根据某个具体特征值划分出新的数据集,计算新的数据集的基础熵,再乘以权值,累加得到某个特征的条件熵。
208 conditionalEntroy+=(entry.getValue().doubleValue()/dataset.size())*calEntropy(splitDataSet(dataset, i, entry.getKey()));
209 }
210 if (baseEntroy-conditionalEntroy>bestInfoGain) {
211 bestInfoGain=baseEntroy-conditionalEntroy;
212 bestFeatureIndex=i;
213
214 }
215 }
216 if (-1==bestFeatureIndex){
217 System.out.print("cannot find best attribute!");
218 return -1;
219 }
220 else {
221 return bestFeatureIndex;//返回信息增益最大的特征的索引,在当前特征(currentFeatures)中的索引。
222 }
223 }
224 public static String classify(ArrayList<String> dataset) {
225
226 Map<String, Integer> categoryCount = new HashMap<String, Integer>();
227 for (String s1 : dataset) {
228 if (categoryCount.containsKey(s1)) {
229 categoryCount.replace(s1, categoryCount.get(s1) + 1);
230 } else {
231 categoryCount.put(s1, 1);
232 }
233 }
234 int maxCounts=-1;
235 String maxCountsCategory=null;
236 for (Entry<String,Integer> entry:categoryCount.entrySet()){//利用Map.Entry得到Map中的Value最大的键值对。
237 if (entry.getValue()>maxCounts){
238 maxCounts=entry.getValue();
239 maxCountsCategory=entry.getKey();
240 }
241 }
242 return maxCountsCategory;
243
244 }
245
246 public static ArrayList<String[]> splitDataSet(ArrayList<String[]> dataset,int featureIndex,String featureValue
247 ){
248 ArrayList<String[]> tempDataSet=new ArrayList<String[]>();
249 for (String[] strings : dataset) {
250 if (strings[featureIndex].equals(featureValue)) {
251
252 String[] xx=strings.clone();//数组的clone方法实现的是浅拷贝,实质就是以下的过程
253 /*
254 for (int i = featureIndex; i < strings.length-1; i++) {
255 xx[i]=strings[i];//就是把引用的值(地址)复制了一份,指向了同一个对象。
256 }
257
258 */
259 for (int i = featureIndex; i < strings.length-1; i++) {//xx中各个元素的值与strings中各个元素的值完全相等。
260 xx[i]=xx[i+1];//只是复制了引用的值而已,跟引用指向的对象没一点关系。Java将基本类型和引用类型变量都看成是值而已·
261 }
262 //最最最需要注意的一点,以上代码不能以下面这种形式实现
263 /*
264 for (int i = featureIndex; i < strings.length-1; i++) {//
265 strings[i]=strings[i+1];//这样会改变strings指向的对象,进而影响到dataset,改变了函数的参数dataset,
266 这样就在函数内“无意间”修改了dataset的值,集合类型,其实所有引用类型都是,以参数形式传入函数的话,可能会“无意间”就被修改了
267 }
268 */
269 String[] tempStrings=new String[xx.length-1];
270 for (int i = 0; i < tempStrings.length; i++) {
271 tempStrings[i]=xx[i];
272
273 }
274 tempDataSet.add(tempStrings);
275 }
276
277
278 }
279 return tempDataSet;
280 }
281
282 }
283 class DecisionNode{
284 public String featureName;
285 public String result;
286 public String featureValue;
287 public List<DecisionNode> children=new ArrayList<DecisionNode>();
288 public void add(DecisionNode node){
289 children.add(node);
290 }
291 public void init(String featureName,String result,String featureValue){
292 this.featureName=featureName;
293 this.result=result;
294 this.featureValue=featureValue;
295 }
296 public void init(String featureName,String featureValue){
297 this.featureName=featureName;
298 this.featureValue=featureValue;
299 }
300 }
参考:
http://www.blogjava.net/zhenandaci/archive/2009/03/24/261701.html