ID3是数据挖掘分类中的一种(是一种if-then的模式),其中运用到熵的概念,表示随机变量不确定性的度量
H(x)=-∑pi *log pi
信息增益是指特征A对训练数据集D的信息增益g(D,A),定义为集合D的经验熵H(D)与特征A给定条件下D的经验条件熵H(D|A)之差
g(D,A)=H(D)-H(D|A)
其中H(Y|X)=∑pi H(Y|X=xi)
Pi=P(x=xi)
ID3 是一种自顶向下增长树的贪婪算法,在每个结点选取能最好地分类样例的属性。继续这个过程直到这棵树能完美分类训练样例,或所有的属性都使用过了。
ID3算法流程
ID3(Examples,Target_attribute,Attributes)
Examples 即训练样例集。Target_attribute 是这棵树要预测的目标属性。Attributes
是除目标属性外供学习到的决策树测试的属性列表。返回能正确分类给定
Examples 的决策树。
创建树的 Root 结点
如果 Examples 都为正,那么返回 label =+ 的单结点树 Root
如果 Examples 都为反,那么返回 label =- 的单结点树 Root
如果 Attributes 为空,那么返回单结点树 Root,label=Examples 中最普遍的
Target_attribute 值
否则
A←Attributes 中分类 Examples 能力最好*的属性
Root 的决策属性←A
对于A的每个可能值v
在Root下加一个新的分支对应测试A= vi
令Examples vi 为Examples中满足A属性值为v i的子集
如果的子集Examples vi 为空在这个新分支下加一个叶子结点,结点的 label=Examples vi
中最普遍的 Target_attribute 值
否则在这个新分支下加一个子树 ID3(Examples vi ,Target_attribute, Attributes-{A})
结束
返回 Root
其主要代码如下
1 /**
2 * 利用源数据构造决策树
3 * @param node 正在处理处理的节点,
4 * @param parentAttrValue父节点划分的属性
5 */
6 private void buildDecisionTree(AttrNode node, String parentAttrValue,
7 String[][] remainData, ArrayList<String> remainAttr, boolean isID3) {
8 node.setParentAttrValue(parentAttrValue);
9
10 String attrName = "";
11 double gainValue = 0;
12 double tempValue = 0;
13
14 // 如果只有1个属性则直接返回
15 if (remainAttr.size() == 1) {
16 System.out.println("attr null");
17 return;
18 }
19
20 // 选择剩余属性中信息增益最大的作为下一个分类的属性
21 for (int i = 0; i < remainAttr.size(); i++) {
22 // 判断是否用ID3算法还是C4.5算法
23 if (isID3) {
24 // ID3算法采用的是按照信息增益的值来比
25 tempValue = computeGain(remainData, remainAttr.get(i));
26 } else {
27 // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足
28 tempValue = computeGainRatio(remainData, remainAttr.get(i));
29 }
30
31 if (tempValue > gainValue) {
32 gainValue = tempValue;
33 attrName = remainAttr.get(i);
34 }
35 }
36
37 node.setAttrName(attrName);
38 ArrayList<String> valueTypes = attrValue.get(attrName);
39 remainAttr.remove(attrName);//将选择的属性从剩余的属性集合中去除
40
41 AttrNode[] childNode = new AttrNode[valueTypes.size()];
42 String[][] rData;
43 for (int i = 0; i < valueTypes.size(); i++) {
44 // 移除非此值类型的数据
45 rData = removeData(remainData, attrName, valueTypes.get(i));
46
47 childNode[i] = new AttrNode();
48 boolean sameClass = true;
49 ArrayList<String> indexArray = new ArrayList<>();
50 for (int k = 1; k < rData.length; k++) {//rdata[0]保存的是attrName
51 indexArray.add(rData[k][0]);//将编号统计进去
52 // 判断是否为同一类的,是否同为yes或者同为no
53 if (!rData[k][attrNames.length - 1]
54 .equals(rData[1][attrNames.length - 1])) {
55 // 只要有1个不相等,就不是同类型的
56 sameClass = false;
57 break;
58 }
59 }
60
61 if (!sameClass) {
62 // 创建新的对象属性,对象的同个引用会出错,rAttr是剩余的属性
63 ArrayList<String> rAttr = new ArrayList<>();
64 for (String str : remainAttr) {
65 rAttr.add(str);
66 }
67
68 buildDecisionTree(childNode[i], valueTypes.get(i), rData,
69 rAttr, isID3);
70 } else {
71 // 如果是同种类型,则直接为数据节点
72 childNode[i].setParentAttrValue(valueTypes.get(i));
73 childNode[i].setChildDataIndex(indexArray);
74 }
75
76 }
77 node.setChildAttrNode(childNode);
78 }
View Code
计算信息增益
/**
* 为某个属性计算信息增益
*
* @param remainData
* 剩余的数据
* @param value
* 待划分的属性名称
* @return
*/
private double computeGain(String[][] remainData, String value) {
double gainValue = 0;
// 源熵的大小将会与属性划分后进行比较
double entropyOri = 0;
// 子划分熵和
double childEntropySum = 0;
// 属性子类型的个数
int childValueNum = 0;
// 属性值的种数
ArrayList<String> attrTypes = attrValue.get(value);
// 子属性对应的权重比
HashMap<String, Integer> ratioValues = new HashMap<>();
for (int i = 0; i < attrTypes.size(); i++) {
// 首先都统一计数为0
ratioValues.put(attrTypes.get(i), 0);
}
// 还是按照一列,从左往右遍历
for (int j = 1; j < attrNames.length; j++) {
// 判断是否到了划分的属性列
if (value.equals(attrNames[j])) {
for (int i = 1; i <= remainData.length - 1; i++) {
childValueNum = ratioValues.get(remainData[i][j]);
// 增加个数并且重新存入
childValueNum++;
ratioValues.put(remainData[i][j], childValueNum);
}
}
}
// 计算原熵的大小
entropyOri = computeEntropy(remainData, value, null, true);
for (int i = 0; i < attrTypes.size(); i++) {
double ratio = (double) ratioValues.get(attrTypes.get(i))
/ (remainData.length - 1);
childEntropySum += ratio
* computeEntropy(remainData, value, attrTypes.get(i), false);
// System.out.println("ratio:value: " + ratio + " " +
// computeEntropy(remainData, value,
// attrTypes.get(i), false));
}
// 二者熵相减就是信息增益
gainValue = entropyOri - childEntropySum;
return gainValue;
}
View Code
若使用C4.5就会使用信息增益比
1 /**
2 * 计算信息增益比
3 *
4 * @param remainData
5 * 剩余数据
6 * @param value
7 * 待划分属性
8 * @return
9 */
10 private double computeGainRatio(String[][] remainData, String value) {
11 double gain = 0;
12 double spiltInfo = 0;
13 int childValueNum = 0;
14 // 属性值的种数
15 ArrayList<String> attrTypes = attrValue.get(value);
16 // 子属性对应的权重比
17 HashMap<String, Integer> ratioValues = new HashMap<>();
18
19 for (int i = 0; i < attrTypes.size(); i++) {
20 // 首先都统一计数为0
21 ratioValues.put(attrTypes.get(i), 0);
22 }
23
24 // 还是按照一列,从左往右遍历
25 for (int j = 1; j < attrNames.length; j++) {
26 // 判断是否到了划分的属性列
27 if (value.equals(attrNames[j])) {
28 for (int i = 1; i <= remainData.length - 1; i++) {
29 childValueNum = ratioValues.get(remainData[i][j]);
30 // 增加个数并且重新存入
31 childValueNum++;
32 ratioValues.put(remainData[i][j], childValueNum);
33 }
34 }
35 }
36
37 // 计算信息增益
38 gain = computeGain(remainData, value);
39 // 计算分裂信息,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀):
40 for (int i = 0; i < attrTypes.size(); i++) {
41 double ratio = (double) ratioValues.get(attrTypes.get(i))
42 / (remainData.length - 1);
43 spiltInfo += -ratio * Math.log(ratio) / Math.log(2.0);
44 }
45
46 // 计算机信息增益率
47 return gain / spiltInfo;
48 }
49
50 /**
51 * 利用源数据构造决策树
52 * @param node 正在处理处理的节点,
53 * @param parentAttrValue父节点划分的属性
54 */
55 private void buildDecisionTree(AttrNode node, String parentAttrValue,
56 String[][] remainData, ArrayList<String> remainAttr, boolean isID3) {
57 node.setParentAttrValue(parentAttrValue);
58
59 String attrName = "";
60 double gainValue = 0;
61 double tempValue = 0;
62
63 // 如果只有1个属性则直接返回
64 if (remainAttr.size() == 1) {
65 System.out.println("attr null");
66 return;
67 }
68
69 // 选择剩余属性中信息增益最大的作为下一个分类的属性
70 for (int i = 0; i < remainAttr.size(); i++) {
71 // 判断是否用ID3算法还是C4.5算法
72 if (isID3) {
73 // ID3算法采用的是按照信息增益的值来比
74 tempValue = computeGain(remainData, remainAttr.get(i));
75 } else {
76 // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足
77 tempValue = computeGainRatio(remainData, remainAttr.get(i));
78 }
79
80 if (tempValue > gainValue) {
81 gainValue = tempValue;
82 attrName = remainAttr.get(i);
83 }
84 }
85
86 node.setAttrName(attrName);
87 ArrayList<String> valueTypes = attrValue.get(attrName);
88 remainAttr.remove(attrName);//将选择的属性从剩余的属性集合中去除
89
90 AttrNode[] childNode = new AttrNode[valueTypes.size()];
91 String[][] rData;
92 for (int i = 0; i < valueTypes.size(); i++) {
93 // 移除非此值类型的数据
94 rData = removeData(remainData, attrName, valueTypes.get(i));
95
96 childNode[i] = new AttrNode();
97 boolean sameClass = true;
98 ArrayList<String> indexArray = new ArrayList<>();
99 for (int k = 1; k < rData.length; k++) {//rdata[0]保存的是attrName
100 indexArray.add(rData[k][0]);//将编号统计进去
101 // 判断是否为同一类的,是否同为yes或者同为no
102 if (!rData[k][attrNames.length - 1]
103 .equals(rData[1][attrNames.length - 1])) {
104 // 只要有1个不相等,就不是同类型的
105 sameClass = false;
106 break;
107 }
108 }
109
110 if (!sameClass) {
111 // 创建新的对象属性,对象的同个引用会出错,rAttr是剩余的属性
112 ArrayList<String> rAttr = new ArrayList<>();
113 for (String str : remainAttr) {
114 rAttr.add(str);
115 }
116
117 buildDecisionTree(childNode[i], valueTypes.get(i), rData,
118 rAttr, isID3);
119 } else {
120 // 如果是同种类型,则直接为数据节点
121 childNode[i].setParentAttrValue(valueTypes.get(i));
122 childNode[i].setChildDataIndex(indexArray);
123 }
124
125 }
126 node.setChildAttrNode(childNode);
127 }
View Code
ID3 : 归纳偏置的更贴切近似:较短的树比较长的得到优先。那些信息增益高的属性
更靠近根结点的树得到优先