MapReduce之KNN算法
什么是
-邻近算法(KNN)
KNN分类问题是找出一个数据集中与一个给定查询数据点最近的个数据点。这个操作也称KNN连接。定义为:给定两个数据集
和
,对于
中的每一个对象,希望从
中找到
个最近的相邻对象。其中
为查询数据集,
为训练数据集
KNN分类
KNN的中心思想为建立一个分类方法,使得对于将(响应变量)与
(预测变量)关联的“平滑“函数
的形式没有任何假设:
函数是非参数化的,在KNN中,给定一个新的点
,要动态识别训练数据集中与
相似的
个观察(
个邻近)。近邻由一个距离或不相似度来定义。通过计算查询对象与所有训练数据对象之间的欧氏距离,然后将这个查询对象分配到
个最近的训练数据中大多数对象所在的类。因为要计算每个对象之间的距离,所以所有数据类型必须为double。
KNN距离函数
给定如下两个维对象
和
:
欧氏距离
曼哈顿距离
闵可夫斯基距离
KNN算法非形式化描述
KNN算法可以总结为以下的简单步骤:
- 1、确定
(
取决于具体需求)
- 2、计算新输入与所有训练数据之间的距离
- 3、对距离进行排序,并根据第
个最小距离确定
个近邻
- 4、收集这些近邻所属的类别
- 5、根据多数投票确定新输入数据类别
MapReduce解决方案
在理解了KNN算法的步骤之后,理解MapReduce方案就简单了,在映射器运行之前将训练集中的数据读取出来,接下来通过计算每条数据与训练集数据中的距离,对距离进行排序,根据多数投票原则确定新输入数据类别,整个操作过程使用映射器即可实现。
输入数据
S.txt文件如下
100;c1;1.0,1.0
101;c1;1.1,1.2
102;c1;1.2,1.0
103;c1;1.6,1.5
104;c1;1.3,1.7
105;c1;2.0,2.1
106;c1;2.0,2.2
107;c1;2.3,2.3
208;c2;9.0,9.0
209;c2;9.1,9.2
210;c2;9.2,9.0
211;c2;10.6,10.5
212;c2;10.3,10.7
213;c2;9.6,9.1
214;c2;9.4,10.4
215;c2;10.3,10.3
300;c3;10.0,1.0
301;c3;10.1,1.2
302;c3;10.2,1.0
303;c3;10.6,1.5
304;c3;10.3,1.7
305;c3;1.0,2.1
306;c3;10.0,2.2
307;c3;10.3,2.3R.txt文件如下:
1000;3.0,3.0
1001;10.1,3.2
1003;2.7,2.7
1004;5.0,5.0
1005;13.1,2.2
1006;12.7,12.7mapper阶段任务
这个阶段的主要任务两个:
- 1、读取训练集中的数据
- 2、计算训练集数据与输入数据距离并根据投票原则实现分类
mapper阶段编码
package com.deng.KNN;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import java.io.IOException;
import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;
public class KNNMapper extends Mapper<LongWritable,Text,Text, Text> {
private static Text reduceKey;
private static Text reduceValue;
private static List<Point> training=null;
public static List<Point> readTrainingFromHFDS() throws IOException{
return KNNUtil.readFromHDFS("input/S.txt");
}
//从文件系统中读取数据并存入链表中
public void setup(Context context) throws IOException{
training=readTrainingFromHFDS();
}
public void map(LongWritable key,Text value,Context context){
String line=value.toString();
Point query=new Point(line); //查询数据
SortedMap<Double,Point> top=new TreeMap<Double, Point>(); //按照距离由小到大存取
for(int i=0;i<training.size();i++){
double distance=KNNUtil.calculateEuclidianDistance(query.getVector(),training.get(i).getVector());
top.put(distance,training.get(i));
if(top.size()>5){
top.remove(top.firstKey());
}
}
//根据投票原则进行分类,majorityVote为输入数据按照投票原则分类到的祖
String majorityVote=null;
int maxCount=0;
for(Point p:top.values()) {
p.addCount();
if (p.getGroupCount() > maxCount) {
maxCount = p.getGroupCount();
majorityVote = p.getGroup();
}
}
reduceKey=new Text(query.getGroup());
reduceValue=new Text(majorityVote);
try {
context.write(reduceKey,reduceValue);
} catch (IOException e) {
e.printStackTrace();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}工具类KNNUtil如下
package com.deng.KNN;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Vector;
public class KNNUtil {
//计算两个输入数据欧氏距离
public static double calculateEuclidianDistance(Vector<Double> query,Vector<Double> training){
double sum=0.0;
for(int i=0;i<query.size();i++){
sum+=Math.pow(training.get(i)-query.get(i),2);
}
return sum;
}
//从文件系统中读取数据
public static List<Point> readFromHDFS(String p) throws IOException{
BufferedReader br=new BufferedReader(new FileReader(p));
String str;
int k=0;
List<Point> points=new ArrayList<>();
while((str=br.readLine())!=null){
Point point=new Point(str);
System.out.println(point);
points.add(point);
}
br.close();
return points;
}
}自定义类point如下
package com.deng.KNN;
import java.util.Vector;
public class Point {
private String group;
private Integer groupCount;
private Vector<Double> vector=new Vector<>();
public Point(){}
public Point(String s){
// 输入数据中,训练集数据和输入数据输入格式不同,利用长度来进行区分并标记
String[] line=s.split(";");
if(line.length==3){
group=line[1];
String[] tokens=line[2].split(",");
for(int i=0;i<tokens.length;i++){
vector.add(Double.parseDouble(tokens[i]));
}
}else{
group=line[0];
String[] tokens=line[1].split(",");
for(int i=0;i<tokens.length;i++){
vector.add(Double.parseDouble(tokens[i]));
}
}
groupCount=0;
}
public String getGroup() {
return group;
}
public Vector<Double> getVector() {
return vector;
}
public Integer getGroupCount() {
return groupCount;
}
//封装加法操作
public void addCount(){
this.groupCount++;
}
@Override
public String toString() {
return "Point{" +
"group='" + group + '\'' +
", vector=" + vector +
'}';
}
}驱动程序如下
package com.deng.KNN;
import com.deng.util.FileUtil;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import java.io.IOException;
public class KNNDriver {
public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
FileUtil.deleteDirs("output");
String[] otherArgs=new String[]{"input/R.txt","output"};
Configuration conf=new Configuration();
Job job=new Job(conf,"KNN");
job.setJarByClass(KNNDriver.class);
job.setMapperClass(KNNMapper.class);
job.setMapOutputKeyClass(Text.class);
job.setMapOutputValueClass(Text.class);
job.setNumReduceTasks(0);
FileInputFormat.addInputPath(job,new Path(otherArgs[0]));
FileOutputFormat.setOutputPath(job,new Path(otherArgs[1]));
System.exit((job.waitForCompletion(true)?0:1));
}
}运行结果如下

















