Note: this implementation is intended only for experiments and tests on small datasets. It is not suitable for production use.

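The algorithm assumes that every attribute column of the training data takes discrete values. Continuous (non-discrete) attributes must first be discretized during preprocessing.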
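Equal-width binning is one simple way to do that preprocessing. The sketch below is not part of the original code. The class name `EqualWidthDiscretizer`, the sample ages and the choice of three bins are all hypothetical. It splits the observed range of a numeric attribute into equal intervals and replaces each value with the label of its interval. The labels reuse the age values from the training data below, but the actual cut points behind that dataset are unknown.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical preprocessing helper (not the author's code): equal-width binning
 * that turns a numeric attribute into discrete labels before training.
 */
class EqualWidthDiscretizer {

    /**
     * Splits [min, max] of the values into labels.length equal-width intervals
     * and maps each value to the label of the interval it falls into.
     */
    static List<String> discretize(double[] values, String[] labels) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        int bins = labels.length;
        double width = (max - min) / bins;
        List<String> out = new ArrayList<>();
        for (double v : values) {
            // If all values are equal (width == 0), everything goes to the first bin.
            int b = (width == 0) ? 0 : (int) ((v - min) / width);
            if (b >= bins) {
                b = bins - 1; // the maximum lies exactly on the upper edge
            }
            out.add(labels[b]);
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical raw ages; range 23..62 gives bins [23,36), [36,49), [49,62].
        double[] ages = {23, 35, 41, 58, 62, 29};
        String[] labels = {"youth", "middle_aged", "senior"};
        System.out.println(discretize(ages, labels));
    }
}
```

Equal-width bins are sensitive to outliers; equal-frequency bins or domain-chosen thresholds are common alternatives.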

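The algorithm uses a DecimalCalculate class, a helper built on Java's BigDecimal for high-precision floating-point arithmetic. Its implementation is identical to the Arith class in a post I reposted, "对BigDecimal常用方法的归类" (a classification of commonly used BigDecimal methods).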
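The original DecimalCalculate source is not reproduced here. As a rough stand-in, assuming it follows the usual "Arith" pattern (convert each double through its string form into a BigDecimal, compute, convert back), the two methods the Bayes class calls might look like this. The rounding mode (half-up) and the exact method bodies are assumptions.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

/**
 * Hypothetical stand-in for util.DecimalCalculate (assumed Arith-style helper).
 * Only mul and div, the two methods used by the Bayes class, are sketched.
 */
class DecimalCalculate {

    /** Exact product of v1 and v2 computed in BigDecimal, returned as a double. */
    static double mul(double v1, double v2) {
        BigDecimal b1 = new BigDecimal(Double.toString(v1));
        BigDecimal b2 = new BigDecimal(Double.toString(v2));
        return b1.multiply(b2).doubleValue();
    }

    /** v1 / v2 rounded half-up to `scale` digits after the decimal point. */
    static double div(double v1, double v2, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException("scale must be non-negative");
        }
        BigDecimal b1 = new BigDecimal(Double.toString(v1));
        BigDecimal b2 = new BigDecimal(Double.toString(v2));
        return b1.divide(b2, scale, RoundingMode.HALF_UP).doubleValue();
    }

    public static void main(String[] args) {
        System.out.println(div(9, 14, 3));      // prior of a class with 9 of 14 tuples
        System.out.println(mul(0.643, 0.333));  // one step of the running product
    }
}
```

With `scale = 3`, as in the Bayes class, every prior and conditional probability is rounded to three decimals before multiplication, so very small probabilities collapse to 0.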

The implementation is as follows:

```java
package Bayes;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

import util.DecimalCalculate;

/**
 * Main Bayes classifier class
 * @author Rowen
 * @qq 443773264
 * @mail luowen3405@163.com
 * @blog blog.csdn.net/luowen3405
 * @date 2011.03.15
 */
public class Bayes {

    /**
     * Groups the training tuples by class (the class label is the last column).
     * @param datas training tuples
     * @return Map<class, training tuples belonging to that class>
     */
    Map<String, ArrayList<ArrayList<String>>> datasOfClass(ArrayList<ArrayList<String>> datas) {
        Map<String, ArrayList<ArrayList<String>>> map = new HashMap<String, ArrayList<ArrayList<String>>>();
        for (int i = 0; i < datas.size(); i++) {
            ArrayList<String> t = datas.get(i);
            String c = t.get(t.size() - 1);
            if (map.containsKey(c)) {
                map.get(c).add(t);
            } else {
                ArrayList<ArrayList<String>> nt = new ArrayList<ArrayList<String>>();
                nt.add(t);
                map.put(c, nt);
            }
        }
        return map;
    }

    /**
     * Predicts the class of a test tuple from the training data.
     * @param datas training tuples
     * @param testT test tuple (attribute values only, no class label)
     * @return predicted class of the test tuple
     */
    public String predictClass(ArrayList<ArrayList<String>> datas, ArrayList<String> testT) {
        Map<String, ArrayList<ArrayList<String>>> doc = this.datasOfClass(datas);
        Object classes[] = doc.keySet().toArray();
        // Start below any possible score so that some class is always selected.
        // Starting at 0.00 left maxPIndex at -1 (and classes[-1] threw) whenever
        // every class scored 0, e.g. because of an unseen attribute value.
        double maxP = -1.0;
        int maxPIndex = -1;
        for (int i = 0; i < classes.length; i++) {
            String c = classes[i].toString();
            ArrayList<ArrayList<String>> d = doc.get(c);
            // Prior P(Ci) = |Ci| / |D|, rounded to 3 decimal places
            double pOfC = DecimalCalculate.div(d.size(), datas.size(), 3);
            // Multiply in P(xk | Ci) for every attribute (naive independence assumption)
            for (int j = 0; j < testT.size(); j++) {
                double pv = this.pOfV(d, testT.get(j), j);
                pOfC = DecimalCalculate.mul(pOfC, pv);
            }
            if (pOfC > maxP) {
                maxP = pOfC;
                maxPIndex = i;
            }
        }
        return classes[maxPIndex].toString();
    }

    /**
     * Computes the probability that the given attribute column takes the given value.
     * @param d training tuples belonging to one class
     * @param value attribute value
     * @param index attribute column index
     * @return probability, rounded to 3 decimal places
     */
    private double pOfV(ArrayList<ArrayList<String>> d, String value, int index) {
        int count = 0;
        int total = d.size();
        for (int i = 0; i < total; i++) {
            if (d.get(i).get(index).equals(value)) {
                count++;
            }
        }
        return DecimalCalculate.div(count, total, 3);
    }
}
```
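Written as a formula, the decision rule in predictClass is the following: pOfC starts as the class prior and is multiplied by the value returned by pOfV for each attribute, and the class with the largest product wins. Here $|D|$ is the number of training tuples, $|C_i|$ the number of tuples in class $C_i$, $|X_k|$ the number of those tuples whose $k$-th attribute equals $x_k$, and $n$ the number of attributes in the test tuple. The formula omits the 3-decimal rounding that DecimalCalculate.div applies to each factor in the code.

$$
c = \arg\max_{C_i}\; P(C_i)\prod_{k=1}^{n} P(x_k \mid C_i),
\qquad
P(C_i) = \frac{|C_i|}{|D|},
\qquad
P(x_k \mid C_i) = \frac{|X_k|}{|C_i|}
$$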

Test class:

```java
package Bayes;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;

/**
 * Test class for the Bayes algorithm
 * @author Rowen
 * @qq 443773264
 * @mail luowen3405@163.com
 * @blog blog.csdn.net/luowen3405
 * @date 2011.03.15
 */
public class TestBayes {

    // One reader shared by all reads: creating a new BufferedReader on System.in
    // in every method can lose input that an earlier reader already buffered.
    private final BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));

    /**
     * Reads one test tuple; input ends at an empty line or at end of input.
     * @return one test tuple
     * @throws IOException
     */
    public ArrayList<String> readTestData() throws IOException {
        ArrayList<String> candAttr = new ArrayList<String>();
        String str;
        // readLine() returns null at end of input; check it before calling equals
        while ((str = reader.readLine()) != null && !str.equals("")) {
            StringTokenizer tokenizer = new StringTokenizer(str);
            while (tokenizer.hasMoreTokens()) {
                candAttr.add(tokenizer.nextToken());
            }
        }
        return candAttr;
    }

    /**
     * Reads training tuples, one per line, until an empty line or end of input.
     * @return set of training tuples
     * @throws IOException
     */
    public ArrayList<ArrayList<String>> readData() throws IOException {
        ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>();
        String str;
        while ((str = reader.readLine()) != null && !str.equals("")) {
            StringTokenizer tokenizer = new StringTokenizer(str);
            ArrayList<String> s = new ArrayList<String>();
            while (tokenizer.hasMoreTokens()) {
                s.add(tokenizer.nextToken());
            }
            datas.add(s);
        }
        return datas;
    }

    public static void main(String[] args) {
        TestBayes tb = new TestBayes();
        Bayes bayes = new Bayes();
        try {
            System.out.println("Please enter the training data");
            ArrayList<ArrayList<String>> datas = tb.readData();
            while (true) {
                System.out.println("Please enter a test tuple");
                ArrayList<String> testT = tb.readTestData();
                if (testT.isEmpty()) {
                    break; // no more test tuples (empty tuple or end of input)
                }
                String c = bayes.predictClass(datas, testT);
                System.out.println("The class is: " + c);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
```

Training data (one tuple per line; the last column is the class label):

```
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no
```

Testing on the original training data gives the following results:

```
Please enter a test tuple
youth high no fair
The class is: no
Please enter a test tuple
youth high no excellent
The class is: no
Please enter a test tuple
middle_aged high no fair
The class is: yes
Please enter a test tuple
senior medium no fair
The class is: yes
Please enter a test tuple
senior low yes fair
The class is: yes
Please enter a test tuple
senior low yes excellent
The class is: yes
Please enter a test tuple
middle_aged low yes excellent
The class is: yes
Please enter a test tuple
youth medium no fair
The class is: no
Please enter a test tuple
youth low yes fair
The class is: yes
Please enter a test tuple
senior medium yes fair
The class is: yes
Please enter a test tuple
youth medium yes excellent
The class is: yes
Please enter a test tuple
middle_aged medium no excellent
The class is: yes
Please enter a test tuple
middle_aged high yes fair
The class is: yes
Please enter a test tuple
senior medium no excellent
The class is: no
```

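The results show that 13 of the 14 test instances are classified correctly, an accuracy of about 93% (13/14 ≈ 92.9%). Because the test instances are the training tuples themselves, this figure is an optimistic estimate of accuracy on unseen data. It suggests the algorithm produces reasonable predictions, but there is still room to improve its accuracy.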
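The one misclassified instance is `senior low yes excellent`, whose label in the training data is no. Working the decision rule by hand with exact fractions shows why: among the 9 yes tuples, the values senior, low, yes (third column) and excellent occur 3, 3, 6 and 3 times; among the 5 no tuples they occur 2, 1, 1 and 3 times. The code's 3-decimal rounding does not change the outcome here.

$$
\begin{aligned}
P(\text{yes})\prod_{k} P(x_k \mid \text{yes}) &= \tfrac{9}{14}\cdot\tfrac{3}{9}\cdot\tfrac{3}{9}\cdot\tfrac{6}{9}\cdot\tfrac{3}{9} = \tfrac{1}{63} \approx 0.0159 \\
P(\text{no})\prod_{k} P(x_k \mid \text{no}) &= \tfrac{5}{14}\cdot\tfrac{2}{5}\cdot\tfrac{1}{5}\cdot\tfrac{1}{5}\cdot\tfrac{3}{5} = \tfrac{3}{875} \approx 0.0034
\end{aligned}
$$

The yes score is about 4.6 times the no score, driven mainly by the larger prior and by how often the third attribute is yes within class yes, so the tuple is assigned to yes.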

One possible improvement:

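A single attribute value can carry too much weight in the classification. For example, when an attribute value occurs 0 times in some class, that one value alone rules out the test instance belonging to that class, no matter how well the other attributes match, which can cause errors. To avoid this, the probability computation can be modified as follows: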

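Replace the original estimate $P(x_k \mid C_i) = |X_k| / |C_i|$ with $P(x_k \mid C_i) = (|X_k| + mp) / (|C_i| + m)$, where $|X_k|$ is the number of tuples in class $C_i$ whose $k$-th attribute equals $x_k$, $m$ can be set to the number of training tuples, and $p$ is the prior probability of the value under an equal-likelihood assumption (that is, $1/v$ when attribute $k$ has $v$ distinct values). A value that never occurs in a class then yields a small positive probability instead of 0.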
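A minimal sketch of the smoothed estimate, assuming the caller supplies p (for example 1/v for an attribute with v distinct values). The class `MEstimate` and its methods are hypothetical and not the author's code; `pOfVSmoothed` mirrors the counting loop of pOfV but returns the smoothed value without 3-decimal rounding.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical sketch of P(xk|Ci) = (|Xk| + m*p) / (|Ci| + m).
 */
class MEstimate {

    /** Smoothed estimate from a raw count, the class size, m and p. */
    static double mEstimate(int count, int classSize, int m, double p) {
        return (count + m * p) / (classSize + m);
    }

    /**
     * Smoothed counterpart of pOfV: d holds the training tuples of one class,
     * value is the test value of attribute column index.
     */
    static double pOfVSmoothed(List<List<String>> d, String value, int index, int m, double p) {
        int count = 0;
        for (List<String> t : d) {
            if (t.get(index).equals(value)) {
                count++;
            }
        }
        return mEstimate(count, d.size(), m, p);
    }

    public static void main(String[] args) {
        // The five tuples of class "no" from the training data.
        List<List<String>> no = new ArrayList<>();
        no.add(Arrays.asList("youth", "high", "no", "fair", "no"));
        no.add(Arrays.asList("youth", "high", "no", "excellent", "no"));
        no.add(Arrays.asList("senior", "low", "yes", "excellent", "no"));
        no.add(Arrays.asList("youth", "medium", "no", "fair", "no"));
        no.add(Arrays.asList("senior", "medium", "no", "excellent", "no"));

        // middle_aged never occurs in class "no", so the unsmoothed estimate is 0/5.
        // With m = 14 (training tuples) and p = 1/3 (three age values) it becomes
        // (0 + 14/3) / (5 + 14) = 14/57.
        System.out.println(pOfVSmoothed(no, "middle_aged", 0, 14, 1.0 / 3));
    }
}
```

Note that a large m (here 14, larger than either class) pulls every estimate strongly toward p. Smaller values of m keep more of the observed frequencies.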