此文不对理论做相关阐述,仅涉及代码实现:
1. 熵计算公式:
设 P 为正例在样本集 S 中所占的比例,Q 为反例所占的比例(P + Q = 1,并约定 0·log2(0) = 0):
$$\mathrm{Entropy}(S) = -P\log_2 P - Q\log_2 Q$$
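2. 信息增益计算(A 为某个属性,S_v 为 S 中属性 A 取值为 v 的样本子集):
$$\mathrm{Gain}(S, A) = \mathrm{Entropy}(S) - \sum_{v \in \mathrm{Values}(A)} \frac{|S_v|}{|S|}\,\mathrm{Entropy}(S_v)$$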
举例:
输入数据(按列转置后的格式:第一行为列数 5 与样本数 14;随后每行为一个列名及其 14 个取值;最后一行依次给出各属性列名,末尾为结果列名):
5 14
Outlook Sunny Sunny Overcast Rain Rain Rain Overcast Sunny Sunny Rain Sunny Overcast Overcast Rain
Temperature Hot Hot Hot Mild Cool Cool Cool Mild Cool Mild Mild Mild Hot Mild
Humidity High High High High Normal Normal Normal High Normal Normal Normal High Normal High
Wind Weak Strong Weak Weak Weak Strong Strong Weak Weak Weak Strong Strong Weak Strong
PlayTennis No No Yes Yes Yes No Yes No Yes Yes Yes Yes Yes No
Outlook Temperature Humidity Wind PlayTennis
1 package com.qunar.data.tree; 2 3 /** 4 * ********************************************************* 5 * <p/> 6 * Author: XiJun.Gong 7 * Date: 2016-09-02 15:28 8 * Version: default 1.0.0 9 * Class description: 10 * <p>统计该类型出现的次数</p> 11 * <p/> 12 * ********************************************************* 13 */ 14 public class CountMap<T> { 15 16 private T key; //类型 17 private int value; //出现的次数 18 19 public CountMap() { 20 this(null, 0); 21 } 22 23 public CountMap(T key, int value) { 24 this.key = key; 25 this.value = value; 26 } 27 28 public T getKey() { 29 return key; 30 } 31 32 public void setKey(T key) { 33 this.key = key; 34 } 35 36 public int getValue() { 37 return value; 38 } 39 40 public void setValue(int value) { 41 this.value = value; 42 } 43 }
package com.qunar.data.tree;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;

import java.util.*;

/**
 * *********************************************************
 * <p/>
 * Author: XiJun.Gong
 * Date: 2016-09-02 14:24
 * Version: default 1.0.0
 * Class description:
 * <p>Decision tree helper: computes the entropy of a labelled sample set and the
 * information gain of each attribute column of a data table.</p>
 * <p/>
 * *********************************************************
 *
 * @param <T> type of the cell values (attribute values and result labels)
 * @param <K> type of the column names; {@link #calculate} casts them to {@code String}
 */
public class DecisionTree<T, K> {

    /** Result label counted as a positive example; every other label counts as a counter example. */
    private static final String POSITIVE_EXAMPLE_TYPE = "Yes";

    /**
     * Returns {@code p * log2(p)}, using the convention {@code 0 * log2(0) = 0}.
     *
     * @param p a probability
     * @return {@code p * log2(p)}, or 0 when {@code p} is 0
     */
    public double pLog2(final double p) {
        if (0 == p) {
            return 0;
        }
        return p * (Math.log(p) / Math.log(2));
    }

    /**
     * Computes the binary entropy of a sample set.
     *
     * @param positiveExample number of positive examples
     * @param counterExample  number of counter examples
     * @return the entropy in bits; 0 for an empty set
     */
    public double entropy(final double positiveExample, final double counterExample) {
        double total = positiveExample + counterExample;
        if (total == 0) {
            // An empty set has no uncertainty; without this guard 0/0 would yield NaN.
            return 0d;
        }
        double positiveP = positiveExample / total;
        double counterP = counterExample / total;
        return -1d * (pLog2(positiveP) + pLog2(counterP));
    }

    /**
     * Groups the result labels by attribute value and counts them.
     *
     * @param features the values of one attribute column, one per sample
     * @param results  the result label of each sample, in the same order as {@code features}
     * @return a multimap from each attribute value to one {@link CountMap} per distinct label
     * @throws IllegalArgumentException if the two lists differ in length
     */
    public Multimap<T, CountMap<K>> merge(final List<T> features, final List<T> results) {
        if (features.size() != results.size()) {
            throw new IllegalArgumentException("features and results must have the same size: "
                    + features.size() + " != " + results.size());
        }
        Multimap<T, CountMap<K>> infoMap = ArrayListMultimap.create();
        Iterator<T> resultIterator = results.iterator();
        for (T feature : features) {
            // Labels are stored as T but counted under K; the only caller uses T == K == String.
            @SuppressWarnings("unchecked")
            K res = (K) resultIterator.next();
            CountMap<K> counter = findCounter(infoMap.get(feature), res);
            if (counter != null) {
                counter.setValue(counter.getValue() + 1);
            } else {
                infoMap.put(feature, new CountMap<K>(res, 1));
            }
        }
        return infoMap;
    }

    /** Returns the counter in {@code counters} whose key equals {@code label}, or null if none. */
    private CountMap<K> findCounter(final Collection<CountMap<K>> counters, final K label) {
        for (CountMap<K> counter : counters) {
            if (Objects.equals(counter.getKey(), label)) {
                return counter;
            }
        }
        return null;
    }

    /**
     * Computes the information gain of one attribute.
     *
     * @param infoMap   per-value label counts for the attribute, as produced by {@link #merge}
     * @param dataTable the input data table, keyed by column name
     * @param type      the attribute column name (e.g. Outlook, whose values are Sunny/Overcast/Rain)
     * @param entropyS  entropy of the whole sample set
     * @param totalSize total number of samples
     * @return the information gain of the attribute
     */
    public double gain(Multimap<T, CountMap<K>> infoMap,
                       Map<K, List<T>> dataTable,
                       final String type,
                       double entropyS,
                       final int totalSize) {
        // Distinct values of the attribute column.
        Set<T> subTypes = Sets.newHashSet();
        subTypes.addAll(dataTable.get(type));
        for (T subType : subTypes) {
            double subSize = 0;
            double positiveExample = 0;
            double counterExample = 0;
            for (CountMap<K> countMap : infoMap.get(subType)) {
                subSize += countMap.getValue();
                if (POSITIVE_EXAMPLE_TYPE.equals(countMap.getKey())) {
                    positiveExample += countMap.getValue();
                } else {
                    // Several non-positive labels may exist; all of them are counter examples.
                    counterExample += countMap.getValue();
                }
            }
            entropyS -= (subSize / totalSize) * entropy(positiveExample, counterExample);
        }
        return entropyS;
    }

    /**
     * Computes the information gain of every attribute column.
     *
     * @param dataTable  the data table, keyed by column name
     * @param types      the attribute column names (e.g. Outlook, Temperature, Humidity, Wind)
     * @param resultType the result column name (e.g. PlayTennis)
     * @return a map from attribute name to its information gain
     * @throws IllegalArgumentException if a named column is missing from {@code dataTable}
     */
    public Map<String, Double> calculate(Map<K, List<T>> dataTable, List<K> types, K resultType) {
        List<T> results = requireColumn(dataTable, resultType);
        int totalSize = results.size();
        int positiveExample = 0;
        int counterExample = 0;
        for (T exampleType : results) {
            if (POSITIVE_EXAMPLE_TYPE.equals(exampleType)) {
                ++positiveExample;
            } else {
                ++counterExample;
            }
        }
        // Entropy of the whole sample set.
        double entropyS = entropy(positiveExample, counterExample);

        Map<String, Double> answer = Maps.newHashMap();
        for (K type : types) {
            Multimap<T, CountMap<K>> infoMap = merge(requireColumn(dataTable, type), results);
            // gain() and the answer map are keyed by String, so K must be String here.
            double typeGain = gain(infoMap, dataTable, (String) type, entropyS, totalSize);
            answer.put((String) type, typeGain);
        }
        return answer;
    }

    /** Returns the named column, failing with a descriptive message instead of a bare NPE. */
    private List<T> requireColumn(final Map<K, List<T>> dataTable, final K name) {
        List<T> column = dataTable.get(name);
        if (column == null) {
            throw new IllegalArgumentException("no column named " + name + " in the data table");
        }
        return column;
    }

}

package com.qunar.data.tree;
2 3 import com.google.common.collect.Lists; 4 import com.google.common.collect.Maps; 5 6 import java.util.*; 7 8 /** 9 * ********************************************************* 10 * <p/> 11 * Author: XiJun.Gong 12 * Date: 2016-09-02 16:43 13 * Version: default 1.0.0 14 * Class description: 15 * <p/> 16 * ********************************************************* 17 */ 18 public class Main { 19 20 public static void main(String args[]) { 21 22 Scanner scanner = new Scanner(System.in); 23 while (scanner.hasNext()) { 24 DecisionTree<String, String> dt = new DecisionTree(); 25 Map<String, List<String>> dataTable = Maps.newHashMap(); 26 /*Map<String, List<String>> dataTable = Maps.newHashMap();*/ 27 List<String> types = Lists.newArrayList(); 28 String resultType; 29 int factorSize = scanner.nextInt(); 30 int demoSize = scanner.nextInt(); 31 String type; 32 33 for (int i = 0; i < factorSize; i++) { 34 List<String> demos = Lists.newArrayList(); 35 type = scanner.next(); 36 for (int j = 0; j < demoSize; j++) { 37 demos.add(scanner.next()); 38 } 39 dataTable.put(type, demos); 40 } 41 for (int i = 1; i < factorSize; i++) { 42 types.add(scanner.next()); 43 } 44 resultType = scanner.next(); 45 Map<String, Double> ans = dt.calculate(dataTable, types, resultType); 46 List<Map.Entry<String, Double>> list = new ArrayList<Map.Entry<String, Double>>(ans.entrySet()); 47 Collections.sort(list, new Comparator<Map.Entry<String, Double>>() { 48 49 50 @Override 51 public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) { 52 return (o2.getValue() > o1.getValue() ? 
1 : -1); 53 } 54 }); 55 56 for (Map.Entry<String, Double> iterator : list) { 57 System.out.println(iterator.getKey() + "= " + iterator.getValue()); 58 } 59 } 60 } 61 62 } 63 /** 64 *使用举例:* 65 5 14 66 Outlook Sunny Sunny Overcast Rain Rain Rain Overcast Sunny Sunny Rain Sunny Overcast Overcast Rain 67 Temperature Hot Hot Hot Mild Cool Cool Cool Mild Cool Mild Mild Mild Hot Mild 68 Humidity High High High High Normal Normal Normal High Normal Normal Normal High Normal High 69 Wind Weak Strong Weak Weak Weak Strong Strong Weak Weak Weak Strong Strong Weak Strong 70 PlayTennis No No Yes Yes Yes No Yes No Yes Yes Yes Yes Yes No 71 Outlook Temperature Humidity Wind PlayTennis 72 */
输出结果(按信息增益从大到小排序):
Outlook= 0.2467498197744391
Humidity= 0.15183550136234136
Wind= 0.04812703040826927
Temperature= 0.029222565658954647