呵呵,继续。
本节的学习内容:
4.从剩余的词中提取文本特征,即最能代表文本的词
5.用空间向量表示文本,空间向量需标准化,即将数值映射到-1到1之间
6.利用所获取的空间向量进行聚类分析
7.交叉验证
第四步,提取文本特征
本文使用KNN算法和SVM算法学习提取文本特征的思想。
研究最终目的。
训练材料:
语料 | 分类 |
腐化 "生活作风" "女色" "情妇" "权色" "生活糜烂" "生活堕落" | 生活作风 |
"东城" "西城" "崇文" "宣武" "朝阳" "海淀" "丰台" "石景山" "房山" "通州" "顺义" "大兴" "昌平 " "平谷" "怀柔" "门头沟" "密云" "延庆" | 北京 |
上访 信访 举报 揭发 揭露 "买官" "卖官" | 上访举报 |
"河大撞人" 撞人 | 我爸是 |
"送钱短信" OR ( 驾校 AND 交警 ) | 送钱短信 |
"乐东县" "保亭县" "陵水县" "琼中县" "白沙县" "昌江县" "屯昌县" "定安县" "澄迈县" "临高县" "儋州" "东方" "五指山" "万宁" "琼海" "文昌" "三亚" "海口" | 海南 |
训练结果就是跟上面语料和分类的有极高的相似度。
下面是基本的KNN算法。KNN.java
package com.antbee.cluster.knn;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
/**
* @author KNN算法主体类
* @version 创建时间:2011-4-2 下午03:47:28
* 类说明
*/
public class KNN {
/**
* 设置优先级队列的比较函数,距离越大,优先级越高
*/
private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
public int compare(KNNNode o1, KNNNode o2) {
if (o1.getDistance() >= o2.getDistance()) {
return 1;
} else {
return 0;
}
}
};
/**
* 获取K个不同的随机数
* @param k 随机数的个数
* @param max 随机数最大的范围
* @return 生成的随机数数组
*/
public List<Integer> getRandKNum(int k, int max) {
List<Integer> rand = new ArrayList<Integer>(k);
for (int i = 0; i < k; i++) {
int temp = (int) (Math.random() * max);
if (!rand.contains(temp)) {
rand.add(temp);
} else {
i--;
}
}
return rand;
}
/**
* 计算测试元组与训练元组之前的距离
* @param d1 测试元组
* @param d2 训练元组
* @return 距离值
*/
public double calDistance(List<Double> d1, List<Double> d2) {
double distance = 0.00;
for (int i = 0; i < d1.size(); i++) {
distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
}
return distance;
}
/**
* 执行KNN算法,获取测试元组的类别
* @param datas 训练数据集
* @param testData 测试元组
* @param k 设定的K值
* @return 测试元组的类别
*/
public String knn(List<List<Double>> datas, List<Double> testData, int k) {
PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);
List<Integer> randNum = getRandKNum(k, datas.size());
for (int i = 0; i < k; i++) {
int index = randNum.get(i);
List<Double> currData = datas.get(index);
String c = currData.get(currData.size() - 1).toString();
KNNNode node = new KNNNode(index, calDistance(testData, currData), c);
pq.add(node);
}
for (int i = 0; i < datas.size(); i++) {
List<Double> t = datas.get(i);
double distance = calDistance(testData, t);
KNNNode top = pq.peek();
if (top.getDistance() > distance) {
pq.remove();
pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));
}
}
return getMostClass(pq);
}
/**
* 获取所得到的k个最近邻元组的多数类
* @param pq 存储k个最近近邻元组的优先级队列
* @return 多数类的名称
*/
private String getMostClass(PriorityQueue<KNNNode> pq) {
Map<String, Integer> classCount = new HashMap<String, Integer>();
for (int i = 0; i < pq.size(); i++) {
KNNNode node = pq.remove();
String c = node.getC();
if (classCount.containsKey(c)) {
classCount.put(c, classCount.get(c) + 1);
} else {
classCount.put(c, 1);
}
}
int maxIndex = -1;
int maxCount = 0;
Object[] classes = classCount.keySet().toArray();
for (int i = 0; i < classes.length; i++) {
if (classCount.get(classes[i]) > maxCount) {
maxIndex = i;
maxCount = classCount.get(classes[i]);
}
}
return classes[maxIndex].toString();
}
}
KNNNode.java 结点类
package com.antbee.cluster.knn;
/**
* @author KNN结点类,用来存储最近邻的k个元组相关的信息
* @version 创建时间:2011-4-2 下午03:43:39
* 类说明
*/
public class KNNNode {
private int index; // 元组标号
private double distance; // 与测试元组的距离
private String c; // 所属类别
public KNNNode(int index, double distance, String c) {
super();
this.index = index;
this.distance = distance;
this.c = c;
}
public int getIndex() {
return index;
}
public void setIndex(int index) {
this.index = index;
}
public double getDistance() {
return distance;
}
public void setDistance(double distance) {
this.distance = distance;
}
public String getC() {
return c;
}
public void setC(String c) {
this.c = c;
}
}
TestKNN.java 测试类
package com.antbee.cluster.knn;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
/**
* @author Weiya He E-mail:heweiya@gmail.com
* @version 创建时间:2011-4-2 下午03:49:04
* 类说明
*/
public class TestKNN {
/**
* 从数据文件中读取数据
* @param datas 存储数据的集合对象
* @param path 数据文件的路径
*/
public void read(List<List<Double>> datas, String path){
try {
BufferedReader br = new BufferedReader(new FileReader(new File(path)));
String data = br.readLine();
List<Double> l = null;
while (data != null) {
String t[] = data.split(" ");
l = new ArrayList<Double>();
for (int i = 0; i < t.length; i++) {
l.add(Double.parseDouble(t[i]));
}
datas.add(l);
data = br.readLine();
}
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 程序执行入口
* @param args
*/
@Test
public void test() {
TestKNN t = new TestKNN();
String datafile = this.getClass().getClassLoader().getResource("datafile.txt").toString();
datafile = datafile.replace("file:/", "");//windows 环境上要做的一步
String testfile = this.getClass().getClassLoader().getResource("testfile.txt").toString();
testfile = testfile.replace("file:/", "");//windows 环境上要做的一步
try {
List<List<Double>> datas = new ArrayList<List<Double>>();
List<List<Double>> testDatas = new ArrayList<List<Double>>();
t.read(datas, datafile);
t.read(testDatas, testfile);
KNN knn = new KNN();
for (int i = 0; i < testDatas.size(); i++) {
List<Double> test = testDatas.get(i);
System.out.print("测试元组: ");
for (int j = 0; j < test.size(); j++) {
System.out.print(test.get(j) + " ");
}
System.out.print("类别为: ");
System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 2)))));
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
datafile.txt文件内容:
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0
testfile.txt文件内容:
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5
最终的运行结果:
测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1
测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1
测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1
测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0
测试元组: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1
测试元组: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0
下面的工作就是如何让汉字也成为如上的Long类型的数字呢,我们现在使用词频的空间向量来代替这些文字。