题 目:基于KNN算法实现的单个图片数字识别 

1.问题描述

本文题目选自Kaggle官网中的0到9单个数字识别问题。

网址:Digit Recognizer | Kaggle

图片中数字识别 python_List

 

其中的训练数据集和测试数据集均来自Kaggle官网。

     训练集数据为train.csv,测试集数据为test.csv。

     现在需要根据训练集数据进行训练模型并以此对测试集数据中的数字进行测试识别并给出最终的识别结果。

2.KNN算法思想简介

kNN算法,即K最近邻(k-NearestNeighbor)分类算法,是最简单的机器学习算法之

一。算法思想很简单:从训练样本集中选择k个与测试样本“距离”最近的样本,这k个样本中出现频率最高的类别即作为测试样本的类别。

KNN算法所解决的问题简要描述如下:

目标:分类未知类别案例。

输入:待分类未知类别案例项目。已知类别案例集合D ,其中包含 j个已知类别的案例。

输出:项目可能的类别。

我们考虑样本为二维的情况下,利用knn方法进行二分类的问题。图中三角形和方形是已知类别的样本点,这里我们假设三角形为正类,方形为负类。图中圆形点是未知类别的数据,我们要利用这些已知类别的样本对它进行分类。

图片中数字识别 python_图片中数字识别 python_02

 


k 近邻算法例子示意图




分类过程如下:

  1. 首先我们事先定下k值(就是指k近邻方法的k的大小,代表对于一个待分类的数

据点,我们要寻找几个它的邻居)。这边为了说明问题,我们取两个k值,分别为3和5。

2.根据事先确定的距离度量公式(如:欧氏距离),得出待分类数据点和所有已知类样

本点中,距离最近的k个样本。

3.统计这k个样本点中,各个类别的数量。如上图,如果我们选定k值为3,则正类样本(三角形)有2个,负类样本(方形)有1个,那么我们就把这个圆形数据点 定为正类;而如果我们选择k值为5,则正类样本(三角形)有2个,负类样本(方形)有3个,那么我们这个数据点定为负类。即,根据k个样本中,数量最多的 样本是什么类别,我们就把这个数据点定为什么类别。

3.基于KNN算法实现的单个数字识别

3.1实现流程

在上述KNN算法思想的基础上,现在具体实现对图片上的单个数字进行识别。

详细实现步骤如下:

1)读取训练集train.csv文件中的每一行并保存到内存,其中训练集的每一行共785个数字,第一个数字为该图片的真实数字值,剩下的784为图片对应的一维向量。

2)读取训练集test.csv文件中的图片向量并保存到内存,其中每个图片是由28*28的矩阵实现,其对应的特征向量为1*784的一维数组。

3) 遍历内存中每个测试图片向量,并计算该测试图片向量和每个训练集中的训练图片向量的欧氏向量距离。选取距离最小的前K训练图片上的数字,其中出现频次最大的训练图片,其数字值即认为是当前测试图片上的数字值。

注意:本次实验中的每个图片(训练图片以及测试图片),均由28*28的数字方阵实现,以行优先转换为1*784一维向量。

两个n维向量a(x11,x12,…,x1n)与 b(x21,x22,…,x2n)间的欧氏距离表达式如下:

图片中数字识别 python_i++_03

 

3.2核心算法描述

function classify()中的参数解释:

输入:

testList:本次进行归类的测试图片对应的一维特征向量(注意:以集合形式表示)。

K:0~20之间的数字,决定归类的范围大小。

输出:本次进行归类的测试图片最终归类的数字值。

       

/**

         * @Title: classify

         * @Description:

         * @param  testList 测试图片向量

         * @param  k K值决定临近K个图片

         * @returnint  

         * @throws

         */

         int classify(ArrayList<Integer> testList, intk) {

                   HashMap<Double, String> distanceMap = new HashMap<Double, String>();

                   // 计算当前图片向量和每个训练集向量的距离

                   Iterator<Entry<String, ArrayList<Integer>>> train_iter = trainingMap.entrySet().iterator();

                   while (train_iter.hasNext()) {

                            Map.Entry<String, ArrayList<Integer>> train_entry = (Entry<String, ArrayList<Integer>>) train_iter.next();

                            String train_key = (String) train_entry.getKey();

                            ArrayList<Integer> trainList = train_entry.getValue();                    

                            doubledistance = getDistance(trainList, testList);

                            distanceMap.put(distance, train_key);

                   }

                  

                   // 初始化每个数字出现的频次map

                   HashMap<String, Integer> countMap = new HashMap<String, Integer>();

                   for(inti=0;i<10; i++) {

                            countMap.put(String.valueOf(i), 0);

                   }

                  

                   // 将距离map排序,并选取前K小的图片

                   SortedMap<Double, String> sortMap = new TreeMap<Double, String>(distanceMap);

                   Set<Entry<Double, String>> sort_entry = sortMap.entrySet();

                   Iterator<Entry<Double, String>> sort_it = sort_entry.iterator();

                  

                   inti = 0;

                   while (sort_it.hasNext() && i < k) {

                            Entry<Double, String> entry = sort_it.next();

                            String str = entry.getValue();

                            String digit = str.split("/")[1];

                           

                            intnum = countMap.get(digit);

                            num ++;

                            i++;

                            countMap.put(digit, num);

                   }

                  

                   // 删选出现频次最大的数字作为当前图片所属数字

                   intmax = 0;

                   String targetValue = "unknown";

                  

                   Iterator<Entry<String, Integer>> iter2 = countMap.entrySet().iterator();

                   while(iter2.hasNext()) {

                            Map.Entry<String, Integer> entry = (Entry<String, Integer>) iter2.next();

                            String digit = entry.getKey();

                            intnum = entry.getValue();

                            if(num > max) {

                                     max = num;

                                     targetValue = digit;

                            }

                   }

                   return Integer.valueOf(targetValue);

         }

针对当前的测试图片向量,首先计算当前测试图片向量和每个训练图片向量的距离,选取其中距离最小的前K个训练图片向量,并统计其中出现频次最大的图片对应的数字值,该数字值即为当前测试图片最后预测的数字值。

3.2实验结果与分析

本次实验环境:windows10 + eclipse,程序中读取CSV文件使用到第三方的javacsv.jar包。

本次实验中,train.csv文件中共21155个训练图片向量,test.csv文件中共14687个测试图片向量,K值取20,测试图片的数字值归类大部分“预测”正确,只有少部分的图片判断错误。

以下是程序运行结果的部分截图:

Test.csv中第1432行,图片数字值为”4“,最终归类为4,正确。



图片中数字识别 python_git_04

 

Test.csv中第1433行,图片数字值为”9“,最终归类为9,正确。

 

图片中数字识别 python_List_05

 



Test.csv中第1434行,图片数字值为”0“,最终归类为0,正确。

图片中数字识别 python_git_06

 

Test.csv中第14686行,图片数字值为”8“,最终归类为8,正确。

图片中数字识别 python_List_07

 

4源码                                    

最后附上本次基于KNN思想实现单个数字图片识别的全部源码。

/** 

* @Title: DigitClassification.java

* @Package com.org.meify

* @Description: 单个数字图片识别

* @authormeify

* @date 2016年1月15日下午4:19:04

* @version V1.0 

*/

publicclass DigitClassification {

         // 训练集图片向量map

         privatestatic HashMap<String, ArrayList<Integer>> trainingMap = new HashMap<String, ArrayList<Integer>>();

         // 测试集图片向量map

         privatestatic HashMap<String, ArrayList<Integer>> testMap = new HashMap<String, ArrayList<Integer>>();

        

         /**

         * @Title: getSubArray

         * @Description: 从数字下标1开始截取子集合

         * @param  arr

         * @return ArrayList<Integer>  

         * @throws

         */

         public ArrayList<Integer> getSubArray(String[] arr) {

                   ArrayList<Integer> list = new ArrayList<Integer>();

                   for (inti = 1; i < arr.length; i++) {

                            list.add(Integer.valueOf(arr[i]));

                   }

                   returnlist;

         }

        

         /**

         * @Title: toList

         * @Description: 将数组转为集合

         * @param  arr 

         * @return ArrayList<Integer>  

         * @throws

         */

         public ArrayList<Integer> toList(String[] arr) {

                   ArrayList<Integer> list = new ArrayList<Integer>();

                   for (inti = 0; i < arr.length; i++) {

                            list.add(Integer.valueOf(arr[i]));

                   }

                   returnlist;

         }

        

         /**

         * @Title: getDistance

         * @Description: 计算两个向量之间的距离(向量欧式距离)

         * @param@param list1

         * @param@param list2   

         * @return double  

         * @throws

         */

         publicdouble getDistance(ArrayList<Integer> list1, ArrayList<Integer> list2) {

                   if(list1.size() != list2.size()) {

                            System.out.println("警告:两个向量大小不等");

                            return 0.0d;

                   }

                   intsum = 0;

                   for (inti = 0; i < list1.size(); i++) {

                            inta = list1.get(i);

                            intb = list2.get(i);

                            sum += (a - b) * (a - b);

                   }

                   return Math.sqrt((double) sum);

         }

                  

         /**

         * @Title: display

         * @Description: 以矩阵的形式展示图片点阵图片大小为28*28

         * @paramlist   

         * @return void  

         * @throws

         */

         void display(ArrayList<Integer> list) {

                   if(list.size() == 784) {

                            for(inti=0;i<784;i++) {

                                     System.out.print(list.get(i) + "  ");

                                     if((i+1)%28 == 0) {

                                               System.out.println();

                                     }

                            }

                   }

         }

        

         /**

         * @Title: loadTrainData

         * @Description: 加载训练集数据

         * @param  path   

         * @return void  

         * @throws

         */

         void loadTrainData(String path) {

                   try {

                            ArrayList<String[]> csvList = new ArrayList<String[]>();

                            CsvReader reader = new CsvReader(path, ',', Charset.forName("SJIS")); 

                            reader.readHeaders(); // 跳过CSV header栏

                            intindex = 1;

                            while (reader.readRecord()) {

                                     csvList.add(reader.getValues());

                                     String[] arr = reader.getValues();

                                     String key = "line" + index + "/" + arr[0]; // key:行号 + 图片的真实数字值ֵ

                                     ArrayList<Integer> values = getSubArray(arr);

                                     trainingMap.put(key, values);

                                     index ++;

                            }

                            reader.close();

                   } catch (Exception ex) {

                            System.out.println(ex);

                   }

         }

                  

         /**

         * @Title: loadTestData

         * @Description: 加载测试集数据

         * @param@param path   

         * @return void  

         * @throws

         */

         void loadTestData(String path) {

                   try {

                            ArrayList<String[]> csvList = new ArrayList<String[]>();

                            CsvReader reader = new CsvReader(path, ',', Charset.forName("SJIS")); 

                            reader.readHeaders(); // 跳过CSV header栏

                            intindex = 1;

                            while (reader.readRecord()) {

                                     csvList.add(reader.getValues());

                                     String[] arr = reader.getValues();

                                     String key = "line" + index; // key:行号

                                     ArrayList<Integer> values = toList(arr);

                                     testMap.put(key, values);

                                     index ++;

                            }

                            reader.close();

                   } catch (Exception ex) {

                            System.out.println(ex);

                   }

         }

                  

         /**

         * @Title: classify

         * @Description:

         * @param  testList 测试图片向量

         * @param  k K值决定临近K个图片

         * @returnint  

         * @throws

         */

         int classify(ArrayList<Integer> testList, intk) {

                   HashMap<Double, String> distanceMap = new HashMap<Double, String>();

                   // 计算当前图片向量和每个训练集向量的距离

                   Iterator<Entry<String, ArrayList<Integer>>> train_iter = trainingMap.entrySet().iterator();

                   while (train_iter.hasNext()) {

                            Map.Entry<String, ArrayList<Integer>> train_entry = (Entry<String, ArrayList<Integer>>) train_iter.next();

                            String train_key = (String) train_entry.getKey(); //                                  

ArrayList<Integer> trainList = train_entry.getValue(); //                           

                            doubledistance = getDistance(trainList, testList);

                            distanceMap.put(distance, train_key);

                   }

                  

                   // 初始化每个数字出现的频次map

                   HashMap<String, Integer> countMap = new HashMap<String, Integer>();

                   for(inti=0;i<10; i++) {

                            countMap.put(String.valueOf(i), 0);

                   }

                  

                   // 将距离map排序,并选取前K小的图片

                   SortedMap<Double, String> sortMap = new TreeMap<Double, String>(distanceMap);

                   Set<Entry<Double, String>> sort_entry = sortMap.entrySet();

                   Iterator<Entry<Double, String>> sort_it = sort_entry.iterator();

                  

                   inti = 0;

                   while (sort_it.hasNext() && i < k) {

                            Entry<Double, String> entry = sort_it.next();

                            String str = entry.getValue();

                            String digit = str.split("/")[1];

                            intnum = countMap.get(digit);

                            num ++;

                            i++;

                            countMap.put(digit, num);

                   }

                  

                   // 删选出现频次最大的数字作为当前图片所属数字

                   intmax = 0;

                   String targetValue = "unknown";

                   Iterator<Entry<String, Integer>> iter2 = countMap.entrySet().iterator();

                   while(iter2.hasNext()) {

                            Map.Entry<String, Integer> entry = (Entry<String, Integer>) iter2.next();

                            String digit = entry.getKey();

                            intnum = entry.getValue();

                            if(num > max) {

                                     max = num;

                                     targetValue = digit;

                            }

                   }

                   return Integer.valueOf(targetValue);

         }

                  

         /**

         * @Title: getRealDigits

         * @Description: 将所有的测试图片向量进行归类,使用KNN算法思想,得出每个测试图片上的数字

         * @param k   

         * @return void  

         * @throws

         */

         void getRealDigits(intk) {

                   Iterator<Entry<String, ArrayList<Integer>>> test_iter = testMap.entrySet().iterator();

                   intindex = 1;

                   while (test_iter.hasNext()) {

                            Map.Entry<String, ArrayList<Integer>> test_entry = (Entry<String, ArrayList<Integer>>) test_iter.next();

                            String test_key = (String) test_entry.getKey();  // 当前行号

                            ArrayList<Integer> testList = test_entry.getValue();  // 当前测试图片的向量

                            // 展示当前图片向量

                            display(testList);

                            intfinalDigit = classify(testList, k);

                            System.out.println("line:" + index + ",图片上的数字为:" + finalDigit);

                            index  ++;

                   }

         }

         publicstaticvoid main(String[] args) {

                   DigitClassification classification = new DigitClassification();

                   // 1.加载所有训练图片向量

                   classification.loadTrainData("D://train.csv");

                   // 2.加载所有测试图片向量

                   classification.loadTestData("D://test.csv");

                   // 3.利用KNN算法思想,对测试图片进行归类其中K值取20

                   classification.getRealDigits(20);

         }

}