1. package graph;        
2.        
3. import java.util.ArrayList;  
4. import java.util.List;  
5. import java.util.TreeSet;        
6.        
7. /**       
8.  * 决策树的ID3算法       
9.  * 参照实现http://www.blog.edu.cn/user2/huangbo929/archives/2006/1533249.shtml     
10.  * @author Leon.Chen       
11.  */       
12. public class DTree {        
13.             
14.     /**     
15.      * root     
16.      */       
17.     TreeNode root;        
18.             
19.     /**     
20.      * 可见性数组     
21.      */       
22.     private static boolean[] visable;        
23.             
24.     private Object[] array;    
25.       
26.     private int index;  
27.        
28.     /**     
29.      * @param args     
30.      */       
31.     @SuppressWarnings("boxing")        
32.     public static void main(String[] args) {        
33.         //初始数据        
34.         Object[] array = new Object[] {         
35.                         new String[]{ "Sunny"    ,"Hot"   ,"High"    ,"Weak"    ,"No" },        
36.                         new String[]{ "Sunny"    ,"Hot"   ,"High"    ,"Strong"  ,"No" },        
37.                         new String[]{ "Overcast" ,"Hot"   ,"High"    ,"Weak"    ,"Yes"},        
38.                         new String[]{ "Rain"     ,"Mild"  ,"High"    ,"Weak"    ,"Yes"},        
39.                         new String[]{ "Rain"     ,"Cool"  ,"Normal"  ,"Weak"    ,"Yes"},        
40.                         new String[]{ "Rain"     ,"Cool"  ,"Normal"  ,"Strong"  ,"No" },        
41.                         new String[]{ "Overcast" ,"Cool"  ,"Normal"  ,"Strong"  ,"Yes"},        
42.                         new String[]{ "Sunny"    ,"Mild"  ,"High"    ,"Weak"    ,"No" },        
43.                         new String[]{ "Sunny"    ,"Cool"  ,"Normal"  ,"Weak"    ,"Yes"},        
44.                         new String[]{ "Rain"     ,"Mild"  ,"Normal"  ,"Weak"    ,"Yes"},        
45.                         new String[]{ "Sunny"    ,"Mild"  ,"Normal"  ,"Strong"  ,"Yes"},        
46.                         new String[]{ "Overcast" ,"Mild"  ,"High"    ,"Strong"  ,"Yes"},        
47.                         new String[]{ "Overcast" ,"Hot"   ,"Normal"  ,"Weak"    ,"Yes"},        
48.                         new String[]{ "Rain"     ,"Mild"  ,"High"    ,"Strong"  ,"No" },        
49.                         };        
50.                 
51.         DTree tree = new DTree();         
52.         tree.create(array,4);  
53.     }   
54.       
55.     public void create(Object[] array,int index){  
56.         this.array = array;  
57.         init(array,index);  
58.         createDTree(array);  
59.         printDTree(root);  
60.     }  
61.          
62.     public Object[] getMaxGain(Object[] array){     
63.         Object[] result = new Object[2];     
64.         double gain = 0;     
65.         int index = 0;  
66.           
67.         for(int i=0;i<visable.length;i++){     
68.             if(!visable[i]){     
69.                 double value = gain(array,i);     
70.                 if(gain < value){     
71.                     gain = value;     
72.                     index = i;     
73.                 }     
74.             }     
75.         }     
76.         result[0] = gain;     
77.         result[1] = index;     
78.         visable[index] = true;     
79.         return result;     
80.     }     
81.          
82.     public void createDTree(Object[] array) {     
83.         Object[] maxgain = getMaxGain(array);     
84.         if (root == null) {     
85.             root = new TreeNode();     
86.             root.parent = null;     
87.             root.parentArrtibute = null;     
88.             root.arrtibutes = getArrtibutes(((Integer) maxgain[1]).intValue());     
89.             root.nodeName = getNodeName(((Integer) maxgain[1]).intValue());    
90.             root.childNodes = new TreeNode[root.arrtibutes.length];  
91.             insertTree(array,root);  
92.         }  
93.     }     
94.          
95.     public void insertTree(Object[] array,TreeNode parentNode){  
96.         String[] arrtibutes = parentNode.arrtibutes;  
97.         for(int i=0;i<arrtibutes.length;i++){  
98.             Object[] pickArray = pickUpAndCreateArray(array,arrtibutes[i],getNodeIndex(parentNode.nodeName));  
99.             Object[] info = getMaxGain(pickArray);  
100.             double gain = ((Double)info[0]).doubleValue();  
101.             if(gain != 0){  
102.                 int index = ((Integer) info[1]).intValue();  
103.                 System.out.println("gain = "+gain+" ,node name = "+getNodeName(index));  
104.                 TreeNode currentNode = new TreeNode();  
105.                 currentNode.parent = parentNode;  
106.                 currentNode.parentArrtibute = arrtibutes[i];  
107.                 currentNode.arrtibutes = getArrtibutes(index);  
108.                 currentNode.nodeName = getNodeName(index);  
109.                 currentNode.childNodes = new TreeNode[currentNode.arrtibutes.length];  
110.                 parentNode.childNodes[i] = currentNode;  
111.                 insertTree(pickArray,currentNode);  
112.             }else {  
113.                 TreeNode leafNode = new TreeNode();  
114.                 leafNode.parent = parentNode;  
115.                 leafNode.parentArrtibute = arrtibutes[i];  
116.                 leafNode.arrtibutes = new String[0];  
117.                 leafNode.nodeName = getLeafNodeName(pickArray);  
118.                 leafNode.childNodes = new TreeNode[0];  
119.                 parentNode.childNodes[i] = leafNode;  
120.             }  
121.         }  
122.   
123.     }  
124.       
125.     public void printDTree(TreeNode node){  
126.         System.out.println(node.nodeName);  
127.   
128.         TreeNode[] childs = node.childNodes;  
129.         for(int i=0;i<childs.length;i++){  
130.             if(childs[i]!=null){  
131.                 System.out.println(childs[i].parentArrtibute);  
132.                 printDTree(childs[i]);  
133.             }  
134.         }  
135.     }  
136.     
137.     /**       
138.      * @param dataArray 原始数组 D      
139.      * @param criterion 标准值       
140.      * @return double       
141.      */       
142.     public void init(Object[] dataArray,int index) {  
143.         this.index = index;  
144.         //数据初始化     
145.         visable = new boolean[((String[])dataArray[0]).length];        
146.         for(int i=0;i<visable.length;i++) {      
147.             if(i == index){     
148.                 visable[i] = true;      
149.             }else {     
150.                 visable[i] = false;      
151.             }     
152.         }     
153.     }     
154.       
155.     public Object[] pickUpAndCreateArray(Object[] array,String arrtibute,int index){  
156.         List<String[]> list = new ArrayList<String[]>();  
157.         for(int i=0;i<array.length;i++){  
158.             String[] strs = (String[])array[i];  
159.             if(strs[index].equals(arrtibute)){  
160.                 list.add(strs);  
161.             }  
162.         }  
163.         return list.toArray();  
164.     }  
165.     
166.     /**     
167.      * Entropy(S)     
168.      * @param array     
169.      * @return double      
170.      */       
171.     public double gain(Object[] array,int index) {        
172.         String[] playBalls = getArrtibutes(this.index);     
173.         int[] counts = new int[playBalls.length];        
174.         for(int i=0;i<counts.length;i++) {     
175.             counts[i] = 0;        
176.         }     
177.         for(int i=0;i<array.length;i++) {     
178.             String[] strs = (String[])array[i];     
179.             for(int j=0;j<playBalls.length;j++) {     
180.                 if(strs[this.index].equals(playBalls[j])) {     
181.                     counts[j]++;     
182.                 }     
183.             }     
184.         }     
185.         /**   
186.          * Entropy(S) = S -p(I) log2 p(I)   
187.          */    
188.         double entropyS = 0;     
189.         for(int i=0;i<counts.length;i++) {        
190.             entropyS += DTreeUtil.sigma(counts[i],array.length);        
191.         }     
192.         String[] arrtibutes = getArrtibutes(index);     
193.         /**   
194.          * total ((|Sv| / |S|) * Entropy(Sv))    
195.          */    
196.         double sv_total = 0;     
197.         for(int i=0;i<arrtibutes.length;i++){     
198.             sv_total += entropySv(array, index,arrtibutes[i],array.length);     
199.         }     
200.         return entropyS-sv_total;     
201.     }     
202.          
203.     /**   
204.      * ((|Sv| / |S|) * Entropy(Sv))   
205.      * @param array   
206.      * @param index   
207.      * @param arrtibute   
208.      * @param allTotal   
209.      * @return   
210.      */    
211.     public double entropySv(Object[] array,int index,String arrtibute,int allTotal) {     
212.         String[] playBalls = getArrtibutes(this.index);     
213.         int[] counts = new int[playBalls.length];     
214.         for(int i=0;i<counts.length;i++) {     
215.             counts[i] = 0;        
216.         }     
217.     
218.         for (int i = 0; i < array.length; i++) {     
219.             String[] strs = (String[]) array[i];     
220.             if (strs[index].equals(arrtibute)) {     
221.                 for (int k = 0; k < playBalls.length; k++) {     
222.                     if (strs[this.index].equals(playBalls[k])) {     
223.                         counts[k]++;     
224.                     }     
225.                 }     
226.             }     
227.         }     
228.     
229.         int total = 0;     
230.         double entropySv = 0;      
231.         for(int i=0;i<counts.length;i++){     
232.             total += counts[i];     
233.         }     
234.         for(int i=0;i<counts.length;i++){     
235.             entropySv += DTreeUtil.sigma(counts[i],total);      
236.         }      
237.         return DTreeUtil.getPi(total, allTotal)*entropySv;     
238.     }     
239.             
240.     @SuppressWarnings("unchecked")     
241.     public String[] getArrtibutes(int index) {        
242.         TreeSet<String> set = new TreeSet<String>(new SequenceComparator());        
243.         for (int i = 0; i < array.length; i++) {        
244.             String[] strs = (String[]) array[i];        
245.             set.add(strs[index]);        
246.         }        
247.         String[] result = new String[set.size()];        
248.         return set.toArray(result);        
249.     }     
250.             
251.     public String getNodeName(int index) {      
252.         String[] strs = new String[]{"Outlook","Temperature","Humidity","Wind","Play ball"};  
253.         for(int i=0;i<strs.length;i++){  
254.             if(i == index){  
255.                 return strs[i];  
256.             }  
257.         }  
258.         return null;      
259.     }  
260.       
261.     public String getLeafNodeName(Object[] array){  
262.         if(array!=null && array.length>0){  
263.             String[] strs = (String[])array[0];  
264.             return strs[index];  
265.         }  
266.         return null;          
267.     }  
268.       
269.     public int getNodeIndex(String name) {    
270.         String[] strs = new String[]{"Outlook","Temperature","Humidity","Wind","Play ball"};  
271.         for(int i=0;i<strs.length;i++){  
272.             if(name.equals(strs[i])){  
273.                 return i;  
274.             }  
275.         }  
276.         return -1;      
277.     }   
278. }        
279.   
280. package graph;        
281.        
/**
 * One node of the decision tree built by {@code DTree}: either an inner node
 * that splits on an attribute, or a leaf that carries a class label.
 *
 * @author B.Chen
 */
public class TreeNode {

    /**
     * Name of this node: the attribute it splits on for an inner node, or the
     * class label for a leaf.
     */
    String nodeName;

    /**
     * Values of the attribute this node splits on; an empty array for a leaf.
     */
    String[] arrtibutes;

    /**
     * Child nodes; {@code childNodes[i]} is the branch taken for
     * {@code arrtibutes[i]}.
     */
    TreeNode[] childNodes;

    /**
     * Parent node; {@code null} for the root.
     */
    TreeNode parent;

    /**
     * The value of the parent's attribute that leads to this node;
     * {@code null} for the root.
     */
    String parentArrtibute;

}
313.   
314. package graph;     
315.     
/**
 * Numeric helpers for the entropy calculations of the ID3 decision tree.
 */
public class DTreeUtil {

    /** Natural logarithm of 2, hoisted so it is not recomputed on every call. */
    private static final double LOG_2 = Math.log(2);

    /**
     * One term of the entropy sum: {@code -p * log2(p)} with
     * {@code p = x / total}.
     *
     * @param x     number of occurrences of one value
     * @param total size of the sample
     * @return the entropy contribution of that value; {@code 0} when the value
     *         does not occur or the sample is empty
     */
    public static double sigma(int x, int total) {
        // Both limits are 0 by convention (p * log p -> 0 as p -> 0) and would
        // otherwise evaluate to NaN.
        if (x == 0 || total == 0) {
            return 0;
        }
        double p = getPi(x, total);
        return -(p * logYBase2(p));
    }

    /**
     * Base-2 logarithm.
     *
     * @param y the argument
     * @return {@code log2(y)}
     */
    public static double logYBase2(double y) {
        return Math.log(y) / LOG_2;
    }

    /**
     * Relative frequency of a value: occurrences divided by sample size.
     *
     * @param x     number of occurrences
     * @param total size of the sample
     * @return {@code x / total} as a double; {@code 0} for an empty sample
     *         instead of NaN or Infinity
     */
    public static double getPi(int x, int total) {
        if (total == 0) {
            return 0;
        }
        return (double) x / total;
    }

}
355.   
356. package graph;  
357.   
358. import java.util.Comparator;  
359.   
/**
 * Orders values by {@link String#compareTo(String)}.
 * <p>
 * Declared as {@code Comparator<Object>} rather than a raw {@code Comparator}
 * so that it can be handed to any API expecting a
 * {@code Comparator<? super String>} without unchecked warnings, while the
 * {@code compare(Object, Object)} signature stays exactly as before.
 */
public class SequenceComparator implements Comparator<Object> {

    /**
     * Compares two values as strings.
     *
     * @param o1 first value; must be a {@code String}
     * @param o2 second value; must be a {@code String}
     * @return a negative number, zero or a positive number when {@code o1} is
     *         lexicographically less than, equal to or greater than {@code o2}
     * @throws ClassCastException if either argument is not a {@code String}
     */
    @Override
    public int compare(Object o1, Object o2) throws ClassCastException {
        String left = (String) o1;
        String right = (String) o2;
        return left.compareTo(right);
    }

}