1. package graph;
2.
3. import java.util.ArrayList;
4. import java.util.List;
5. import java.util.TreeSet;
6.
7. /**
8. * 决策树的ID3算法
9. * 参照实现http://www.blog.edu.cn/user2/huangbo929/archives/2006/1533249.shtml
10. * @author Leon.Chen
11. */
12. public class DTree {
13.
14. /**
15. * root
16. */
17. TreeNode root;
18.
19. /**
20. * 可见性数组
21. */
22. private static boolean[] visable;
23.
24. private Object[] array;
25.
26. private int index;
27.
28. /**
29. * @param args
30. */
31. @SuppressWarnings("boxing")
32. public static void main(String[] args) {
33. //初始数据
34. Object[] array = new Object[] {
35. new String[]{ "Sunny" ,"Hot" ,"High" ,"Weak" ,"No" },
36. new String[]{ "Sunny" ,"Hot" ,"High" ,"Strong" ,"No" },
37. new String[]{ "Overcast" ,"Hot" ,"High" ,"Weak" ,"Yes"},
38. new String[]{ "Rain" ,"Mild" ,"High" ,"Weak" ,"Yes"},
39. new String[]{ "Rain" ,"Cool" ,"Normal" ,"Weak" ,"Yes"},
40. new String[]{ "Rain" ,"Cool" ,"Normal" ,"Strong" ,"No" },
41. new String[]{ "Overcast" ,"Cool" ,"Normal" ,"Strong" ,"Yes"},
42. new String[]{ "Sunny" ,"Mild" ,"High" ,"Weak" ,"No" },
43. new String[]{ "Sunny" ,"Cool" ,"Normal" ,"Weak" ,"Yes"},
44. new String[]{ "Rain" ,"Mild" ,"Normal" ,"Weak" ,"Yes"},
45. new String[]{ "Sunny" ,"Mild" ,"Normal" ,"Strong" ,"Yes"},
46. new String[]{ "Overcast" ,"Mild" ,"High" ,"Strong" ,"Yes"},
47. new String[]{ "Overcast" ,"Hot" ,"Normal" ,"Weak" ,"Yes"},
48. new String[]{ "Rain" ,"Mild" ,"High" ,"Strong" ,"No" },
49. };
50.
51. DTree tree = new DTree();
52. tree.create(array,4);
53. }
54.
55. public void create(Object[] array,int index){
56. this.array = array;
57. init(array,index);
58. createDTree(array);
59. printDTree(root);
60. }
61.
62. public Object[] getMaxGain(Object[] array){
63. Object[] result = new Object[2];
64. double gain = 0;
65. int index = 0;
66.
67. for(int i=0;i<visable.length;i++){
68. if(!visable[i]){
69. double value = gain(array,i);
70. if(gain < value){
71. gain = value;
72. index = i;
73. }
74. }
75. }
76. result[0] = gain;
77. result[1] = index;
78. visable[index] = true;
79. return result;
80. }
81.
82. public void createDTree(Object[] array) {
83. Object[] maxgain = getMaxGain(array);
84. if (root == null) {
85. root = new TreeNode();
86. root.parent = null;
87. root.parentArrtibute = null;
88. root.arrtibutes = getArrtibutes(((Integer) maxgain[1]).intValue());
89. root.nodeName = getNodeName(((Integer) maxgain[1]).intValue());
90. root.childNodes = new TreeNode[root.arrtibutes.length];
91. insertTree(array,root);
92. }
93. }
94.
95. public void insertTree(Object[] array,TreeNode parentNode){
96. String[] arrtibutes = parentNode.arrtibutes;
97. for(int i=0;i<arrtibutes.length;i++){
98. Object[] pickArray = pickUpAndCreateArray(array,arrtibutes[i],getNodeIndex(parentNode.nodeName));
99. Object[] info = getMaxGain(pickArray);
100. double gain = ((Double)info[0]).doubleValue();
101. if(gain != 0){
102. int index = ((Integer) info[1]).intValue();
103. System.out.println("gain = "+gain+" ,node name = "+getNodeName(index));
104. TreeNode currentNode = new TreeNode();
105. currentNode.parent = parentNode;
106. currentNode.parentArrtibute = arrtibutes[i];
107. currentNode.arrtibutes = getArrtibutes(index);
108. currentNode.nodeName = getNodeName(index);
109. currentNode.childNodes = new TreeNode[currentNode.arrtibutes.length];
110. parentNode.childNodes[i] = currentNode;
111. insertTree(pickArray,currentNode);
112. }else {
113. TreeNode leafNode = new TreeNode();
114. leafNode.parent = parentNode;
115. leafNode.parentArrtibute = arrtibutes[i];
116. leafNode.arrtibutes = new String[0];
117. leafNode.nodeName = getLeafNodeName(pickArray);
118. leafNode.childNodes = new TreeNode[0];
119. parentNode.childNodes[i] = leafNode;
120. }
121. }
122.
123. }
124.
125. public void printDTree(TreeNode node){
126. System.out.println(node.nodeName);
127.
128. TreeNode[] childs = node.childNodes;
129. for(int i=0;i<childs.length;i++){
130. if(childs[i]!=null){
131. System.out.println(childs[i].parentArrtibute);
132. printDTree(childs[i]);
133. }
134. }
135. }
136.
137. /**
138. * @param dataArray 原始数组 D
139. * @param criterion 标准值
140. * @return double
141. */
142. public void init(Object[] dataArray,int index) {
143. this.index = index;
144. //数据初始化
145. visable = new boolean[((String[])dataArray[0]).length];
146. for(int i=0;i<visable.length;i++) {
147. if(i == index){
148. visable[i] = true;
149. }else {
150. visable[i] = false;
151. }
152. }
153. }
154.
155. public Object[] pickUpAndCreateArray(Object[] array,String arrtibute,int index){
156. List<String[]> list = new ArrayList<String[]>();
157. for(int i=0;i<array.length;i++){
158. String[] strs = (String[])array[i];
159. if(strs[index].equals(arrtibute)){
160. list.add(strs);
161. }
162. }
163. return list.toArray();
164. }
165.
166. /**
167. * Entropy(S)
168. * @param array
169. * @return double
170. */
171. public double gain(Object[] array,int index) {
172. String[] playBalls = getArrtibutes(this.index);
173. int[] counts = new int[playBalls.length];
174. for(int i=0;i<counts.length;i++) {
175. counts[i] = 0;
176. }
177. for(int i=0;i<array.length;i++) {
178. String[] strs = (String[])array[i];
179. for(int j=0;j<playBalls.length;j++) {
180. if(strs[this.index].equals(playBalls[j])) {
181. counts[j]++;
182. }
183. }
184. }
185. /**
186. * Entropy(S) = S -p(I) log2 p(I)
187. */
188. double entropyS = 0;
189. for(int i=0;i<counts.length;i++) {
190. entropyS += DTreeUtil.sigma(counts[i],array.length);
191. }
192. String[] arrtibutes = getArrtibutes(index);
193. /**
194. * total ((|Sv| / |S|) * Entropy(Sv))
195. */
196. double sv_total = 0;
197. for(int i=0;i<arrtibutes.length;i++){
198. sv_total += entropySv(array, index,arrtibutes[i],array.length);
199. }
200. return entropyS-sv_total;
201. }
202.
203. /**
204. * ((|Sv| / |S|) * Entropy(Sv))
205. * @param array
206. * @param index
207. * @param arrtibute
208. * @param allTotal
209. * @return
210. */
211. public double entropySv(Object[] array,int index,String arrtibute,int allTotal) {
212. String[] playBalls = getArrtibutes(this.index);
213. int[] counts = new int[playBalls.length];
214. for(int i=0;i<counts.length;i++) {
215. counts[i] = 0;
216. }
217.
218. for (int i = 0; i < array.length; i++) {
219. String[] strs = (String[]) array[i];
220. if (strs[index].equals(arrtibute)) {
221. for (int k = 0; k < playBalls.length; k++) {
222. if (strs[this.index].equals(playBalls[k])) {
223. counts[k]++;
224. }
225. }
226. }
227. }
228.
229. int total = 0;
230. double entropySv = 0;
231. for(int i=0;i<counts.length;i++){
232. total += counts[i];
233. }
234. for(int i=0;i<counts.length;i++){
235. entropySv += DTreeUtil.sigma(counts[i],total);
236. }
237. return DTreeUtil.getPi(total, allTotal)*entropySv;
238. }
239.
240. @SuppressWarnings("unchecked")
241. public String[] getArrtibutes(int index) {
242. TreeSet<String> set = new TreeSet<String>(new SequenceComparator());
243. for (int i = 0; i < array.length; i++) {
244. String[] strs = (String[]) array[i];
245. set.add(strs[index]);
246. }
247. String[] result = new String[set.size()];
248. return set.toArray(result);
249. }
250.
251. public String getNodeName(int index) {
252. String[] strs = new String[]{"Outlook","Temperature","Humidity","Wind","Play ball"};
253. for(int i=0;i<strs.length;i++){
254. if(i == index){
255. return strs[i];
256. }
257. }
258. return null;
259. }
260.
261. public String getLeafNodeName(Object[] array){
262. if(array!=null && array.length>0){
263. String[] strs = (String[])array[0];
264. return strs[index];
265. }
266. return null;
267. }
268.
269. public int getNodeIndex(String name) {
270. String[] strs = new String[]{"Outlook","Temperature","Humidity","Wind","Play ball"};
271. for(int i=0;i<strs.length;i++){
272. if(name.equals(strs[i])){
273. return i;
274. }
275. }
276. return -1;
277. }
278. }
279.
280. package graph;
281.
/**
 * A node of the decision tree: either a decision node that splits on one attribute,
 * or a leaf that carries a class label.
 *
 * @author B.Chen
 */
public class TreeNode {

    /** Name shown for this node: the attribute name, or the class label for a leaf. */
    String nodeName;

    /** Values of this node's attribute, one per child; empty for a leaf. */
    String[] arrtibutes;

    /** Children of this node, parallel to {@link #arrtibutes}. */
    TreeNode[] childNodes;

    /** Parent node; {@code null} for the root. */
    TreeNode parent;

    /** The value of the parent's attribute that leads to this node. */
    String parentArrtibute;

}
313.
314. package graph;
315.
/**
 * Static helpers for the entropy computations of the ID3 algorithm.
 */
public class DTreeUtil {

    /** Natural logarithm of 2, computed once for the base-2 conversion. */
    private static final double LOG_2 = Math.log(2);

    /**
     * One term of the entropy sum {@code Info(T) = sum over i of -p_i * log2(p_i)},
     * with {@code p_i = x / total}.
     *
     * @param x     number of occurrences of the value
     * @param total total number of observations
     * @return {@code -p * log2(p)}; 0 when {@code x} is 0
     */
    public static double sigma(int x, int total) {
        // By convention 0 * log2(0) is taken as 0; this also avoids log(0).
        if (x == 0) {
            return 0;
        }
        double p = getPi(x, total);
        return -(p * logYBase2(p));
    }

    /**
     * Base-2 logarithm.
     *
     * @param y the argument
     * @return {@code log2(y)}
     */
    public static double logYBase2(double y) {
        return Math.log(y) / LOG_2;
    }

    /**
     * Relative frequency of a value: occurrences divided by the total.
     * The division is done in floating point, so a {@code total} of 0 does not throw but
     * yields NaN (when {@code x} is 0) or an infinity.
     *
     * @param x     number of occurrences
     * @param total total number of observations
     * @return {@code x / total} as a double
     */
    public static double getPi(int x, int total) {
        return (double) x / total;
    }

}
355.
356. package graph;
357.
358. import java.util.Comparator;
359.
/**
 * Orders strings by their natural (lexicographic) order.
 *
 * <p>Typed as {@code Comparator<Object>} so that it can be passed wherever a comparator for
 * strings is expected without unchecked warnings, while still accepting the
 * {@code (Object, Object)} arguments of the original signature.
 */
public class SequenceComparator implements Comparator<Object> {

    /**
     * Compares two strings using {@link String#compareTo(String)}.
     *
     * @param o1 first string
     * @param o2 second string
     * @return a negative value, zero or a positive value as {@code o1} sorts before,
     *         equal to or after {@code o2}
     * @throws ClassCastException if either argument is not a {@code String}
     */
    @Override
    public int compare(Object o1, Object o2) {
        String str1 = (String) o1;
        String str2 = (String) o2;
        return str1.compareTo(str2);
    }

}
决策树ID3算法
原创
©著作权归作者所有:来自51CTO博客作者mb6434c781b2176的原创作品,请联系作者获取转载授权,否则将追究法律责任
上一篇:XML生成树型菜单
下一篇:(转)窗口间的关系与交互(二)

提问和评论都可以,用心的回复会被更多人看到
评论
发布评论
相关文章
-
机器学习 | 决策树ID3算法
ID3是Quinlan于1979年提出的,是机器学习中广为人知的一种算法,它的提出开创了决策树
人工智能 机器学习 分类 信息增益 数据 -
ML《决策树(一)ID3》
今天做个回顾和记录,简单做个学习,也是梳理下知识点,决策树的学习。本文的学习,
决策树 机器学习 ID3 信息增益 信息熵