实验资源
1998.csv
airports.csv
实验环境
VMware Workstation
Ubuntu 16.04
spark-2.4.5
scala-2.12.10
实验内容
“我们很抱歉地通知您,您乘坐的由 XX 飞往 XX 的 XXXX 航班延误。”
相信很多在机场等待飞行的旅客都不愿意听到这句话。随着乘坐飞机这种交通方式的逐渐普及,航延延误问题也一直困扰着我们。航班延误通常会造成两种结果,一种是航班取消,另一种是航班晚点。
在本次实验中,我们将通过 Spark 提供的 DataFrame、 SQL 和机器学习框架等工具,基于 D3.js 数据可视化技术,对航班起降的记录数据进行分析,尝试找出造成航班延误的原因,以及对航班延误情况进行预测。
实验步骤
一、数据集简介及准备
1、数据集简介
本节实验用到的航班数据集是 2009 年 Data Expo 上提供的飞行准点率统计数据。此次我们选用 1998 年的数据集。
该数据集的各个字段解释如下:
此外,我们还会用到一些补充信息。如机场信息数据集等。
2、下载数据集
(1)在虚拟机中输入如下命令
wget https://labfile.oss.aliyuncs.com/courses/610/1998.csv.bz2
(2)然后使用解压缩命令对其进行解压
bunzip2 1998.csv.bz2
解压后的 CSV 数据文件位于你使用解压命令时的工作目录中,默认情况是在 /home/用户名
目录中。
(3)同样地,下载 airports 机场信息数据集,命令如下所示
wget https://labfile.oss.aliyuncs.com/courses/610/airports.csv
3、数据清洗
由于 airports 数据集中含有一些非常用字符,我们需要对其进行清洗处理,以防止部分记录字符的不能被识别错误引起后续检索的错误。
OpenRefine 是 Google 主导开发的一款开源数据清洗工具。我们先在环境中安装它:
wget https://labfile.oss.aliyuncs.com/courses/610/openrefine-linux-3.2.tar.gz
tar -zxvf openrefine-linux-3.2.tar.gz
cd openrefine-3.2
# 启动命令
./refine
当出现下图所示的提示信息后,在浏览器中打开 URL http://127.0.0.1:3333/
。
Open Refine 启动成功的标志是出现 Point your browser to http://127.0.0.1:3333 to start using Refine
的提示。
浏览器中会出现 OpenRefine 的应用网页,如下图所示。请选择刚刚下载的机场信息数据集,并点击 Next 按钮进入下一步。
在数据解析步骤中,直接点击右上角的 Create Project
按钮创建数据清洗项目。
稍作等待,项目创建完成后,就可以对数据进行各种操作。在稍后会提供 OpenRefine 的详细教程,此处只需要按照提示对数据集进行相应操作即可。
点击 airport 列旁边的下拉菜单按钮,然后在菜单中选择 Edit Column -> Remove this column 选项,以移除 airport 列。具体操作如下图所示。
请按照同样的方法,移除 lat 和 long 列。最后的数据集应只包含 iata 、city、state、country 四列。
最后我们点击右上角的 Export 按钮导出数据集。导出选项选择 Comma-separated value,即 CSV 文件。
然后在弹出的下载提示对话框中选择“保存文件”,并确定。
该文件位于 /home/用户名/下载
目录中,请在文件管理器中将其剪切至 /home/用户名
目录,并覆盖源文件。步骤如下图所示。
首先双击打开桌面上的 主文件夹
,找到其中的 下载
目录。右键点击 CSV 文件,选择剪切。
然后回到主目录,在空白处右键点击,选择“粘贴”即可。
最后关闭浏览器和运行着 OpenRefine 的终端即可。
4、启动 Spark Shell
为了更好地处理 CSV 格式的数据集,我们可以直接使用由 DataBricks 公司提供的第三方 Spark CSV 解析库来读取。
首先是启动 Spark Shell。在启动的同时,附上参数--packages com.databricks:spark-csv_2.11:1.1.0
spark-shell --packages com.databricks:spark-csv_2.11:1.1.0
5、导入数据及处理格式
等待 Spark Shell 启动完成后,输入以下命令来导入数据集。
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
val flightData = sqlContext.read.format("com.databricks.spark.csv").option("header","true").load("/home/shiyanlou/1998.csv")
上述命令中,我们调用了 sqlContext 提供的 read 接口,指定加载格式 format 为第三方库中定义的格式 com.databricks.spark.csv
。同时设置了一个读取选项 header 为 true
,这表示将数据集中的首行内容解析为字段名称。最后在 load 方法中 指明了待读取的数据集文件为我们刚刚下载的这个数据集。
此时, flightData
的数据类型为 Spark SQL 中常用的 DataFrame。着将 flightData 其注册为临时表,命令为:
flightData.registerTempTable("flights")
使用相同的方法导入机场信息数据集 airports.csv ,并将其注册为临时表。
val airportData = sqlContext.read.format("com.databricks.spark.csv").option("header","true").load("/home/shiyanlou/airports-csv.csv")
airportData.registerTempTable("airports")
二、数据探索
1、问题设计
在探索数据之前,我们已经知道该数据共有 29 个字段。根据出发时间、出发 / 抵达延误时间等信息,我们可以大胆地提出下面这些问题:
- **每天航班最繁忙的时间段是哪些?**通常早晚都容易有大雾等极端天气,是否中午的时候到港和离港航班更多呢?
- **飞哪最准时?**在设计旅行方案时,如果达到某个目的地有两个相邻的机场,我们似乎可以比较到哪里更准时,以减少可能发生的延误给我们出行带来的影响。
- **出发延误的重灾区都有哪些?**同样,从哪些地方出发最容易遭到延误?下次再要从这些地方出发的时候,就要考虑是不是要改乘地面交通工具了。
2、问题解答
我们已经把数据注册为临时表,对于上述问题的解答实际上就变成了如何设计合适的 SQL 查询语句。在数据量非常大的时候,Spark SQL 的使用尤为方便,它能够直接从 HDFS 等分布式存储系统中拉取数据进行查询,并且能够最大化地利用集群性能进行查询计算。
(1)每天航班最繁忙的时间段是哪些
分析某个问题时,要想办法将解答问题的来源落实在数据集的各个指标上。当问题不够详细时,可以取一些具有代表性的值作为该问题的答案。
例如,航班分为到港(Arrive)和离港(Depart)航班,若统计所有机场在每天的某个时间段内离港航班数量,就能在一定程序上反映这个时段的航班是否繁忙。
数据集中的每一条记录都朴实地反映了航班的基本情况,但它们并不会直接告诉我们每一天、每一个时段都发生了什么。为了得到后者这样的信息,我们需要对数据进行筛选和统计。
于是我们会顺理成章地用到 AVG(平均值)、COUNT(计数)和 SUM(求和)等统计函数。
为了分时间段统计航班数量,我们可以大致地将一天的时间分为以下五段:
- 凌晨(00:00 - 06:00):大部分人在这个时段都在休息,所以我们可以合理假设该时间段内航班数量较少。
- 早上(06:01 - 10:00):一些早班机会选择在此时间出发,机场也通常从这个时间段起逐渐进入高峰。
- 中午(10:01 - 14:00):早上从居住地出发的人们通常在这个时候方便抵达机场,因此选择在该时间段出发的航班可能更多。
- 下午(14:01 - 19:00):同样,在下午出发更为方便,抵达目的地是刚好是晚上,又不至于太晚,方便找到落脚之处。
- 晚上(19:01 - 23:59):在一天结束之际,接近凌晨的航班数量可能会更少。
当我们所需的数据不是单个离散的数据而是基于一定范围的时候,我们可以用关键字 BETWEEN x AND y
来设置数据的起止范围。
有了上述准备,我们可以尝试写出统计离港时间在 0 点 至 6 点 间的航班总数。首先选取的目标是 flights 这张表,即 FROM flights
。航班总数可以对 FlightNum 进行统计(使用 COUNT 函数),即 COUNT(FlightNum)
。限定的条件是离港时间在 0 (代表 00:00)至 600 (代表 6:00)之间,即 WHERE DepTime BETWEEN 0 AND 600
。所以我们要写出的语句是:
val queryFlightNumResult = sqlContext.sql("SELECT COUNT(FlightNum) FROM flights WHERE DepTime BETWEEN 0 AND 600")
查看其中 1 条结果:
queryFlightNumResult.take(1)
在此基础上我们可以细化一下,计算出每天的平均离港航班数量,并且每次只选择 1 个月的数据。这里我们选择的时间段为 10:00 至 14:00 。
// COUNT(DISTINCT DayofMonth) 的作用是计算每个月的天数
val queryFlightNumResult1 = sqlContext.sql("SELECT COUNT(FlightNum)/COUNT(DISTINCT DayofMonth) FROM flights WHERE Month = 1 AND DepTime BETWEEN 1001 AND 1400")
查询得到的结果只有一条,即该月每天的平均离港航班数量。查看一下:
queryFlightNumResult1.take(1)
你可以尝试计算出其他时间段的平均离港航班数量,并作记录。
最终统计的结果表明:1998 年 1 月,每天最繁忙的时段为下午。该时段的平均离港航班数量为 4356.7 个。
(2)飞哪最准时
要看飞哪最准时,实际上就是统计航班到港准点率。可以先来查询到港延误时间为 0 的航班都是飞往哪里的。
在上面这句话中,有几个信息:
- 要查询的主要信息为目的地代码。
- 信息的来源为 flights 表。
- 查询的条件为到港延误时间(ArrDelay)为 0 。
在面对任何一个问题时,我们都可以仿照上面的思路对问题进行拆解,然后将每一条信息转化为对应的 SQL 语句。
于是最终我们可以得到这样的查询代码:
val queryDestResult = sqlContext.sql("SELECT DISTINCT Dest, ArrDelay FROM flights WHERE ArrDelay = 0")
取出其中 5 条结果来看看。
queryDestResult.head(5)
在此基础上,我们尝试加入更多的限定条件。
我们可以统计出到港航班延误时间为 0 的次数(准点次数),并且最终输出的结果为 [目的地, 准点次数] ,并且按照降序对它们进行排列。
val queryDestResult2 = sqlContext.sql("SELECT DISTINCT Dest, COUNT(ArrDelay) AS delayTimes FROM flights where ArrDelay = 0 GROUP BY Dest ORDER BY delayTimes DESC")
查看其中 10 条结果。
queryDestResult2.head(10)
在美国,一个州通常会有多个机场。我们在上一步得到的查询结果都是按照目的地的机场代码进行输出的。那么抽象到每一个州都有多少个准点的到港航班呢?
我们可以在上一次查询的基础上,再次进行嵌套的查询。并且,我们会用到另一个数据集 airports 中的信息:目的地中的三字代码(Dest)即该数据集中的 IATA 代码(iata),而每个机场都给出了它所在的州的信息(state)。我们可以通过一个联结操作将 airports 表加入到查询中。
val queryDestResult3 = sqlContext.sql("SELECT DISTINCT state, SUM(delayTimes) AS s FROM (SELECT DISTINCT Dest, COUNT(ArrDelay) AS delayTimes FROM flights WHERE ArrDelay = 0 GROUP BY Dest ) a JOIN airports b ON a.Dest = b.iata GROUP BY state ORDER BY s DESC")
查看其中 10 条结果。
queryDestResult3.head(10)
最后还可以将结果输出为 CSV 格式,保存在用户主目录下。
// QueryDestResult.csv只是保存结果的文件夹名
queryDestResult3.rdd.saveAsTextFile("/home/shiyanlou/QueryDestResult.csv")
保存完毕后,我们还需要手动将其合并为一个文件。新打开一个终端,在终端中输入以下命令来进行文件合并。
# 进入到结果文件的目录
cd ~/QueryDestResult.csv/
# 使用通配符将每个part文件中的内容追加到 result.csv 文件中
cat part-* >> result.csv
最后打开 result.csv 文件就能看到最终结果,如下图所示。
(3)出发延误的重灾区都有哪些
可以大胆地设置查询条件为离港延误时间大于 60 分钟,写出查询语句如下:
val queryOriginResult = sqlContext.sql("SELECT DISTINCT Origin, DepDelay FROM flights where DepDelay > 60 ORDER BY DepDelay DESC")
因为数据已经按照降序的形式进行排列,所以我们取出前 10 个查询结果即为 1998 年内,延误最严重的十次航班及所在的离港机场。
queryOriginResult.head(10)
三、航班延误时间预测
1、引言
历史数据是对于过去已经发生的事情的一种记录,我们可以根据历史数据对过去进行总结。那么我们是否还能否据此来展望未来呢?
也许你首先想到的方法就是预测。谈到预测就不得不提到当下最热门的学科之一——机器学习。预测也是机器学习相关知识能够完成的任务之一。
作为数据分析人员,我们学习机器学习的主要目的不是对机器学习算法进行各方面的改进(机器学习专家们在为此努力),最低的要求应当是能够将机器学习算法应用到实际的数据分析问题中。
2、引入相关的包
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
3、DataFrame 转换为 RDD
Spark ML 中的操作大部分是基于 RDD (分布式弹性数据集)来进行的。而之前我们读进来的数据集的数据类型为 DataFrame 。在 DataFrame 中的每一条记录即对应于 RDD 中的每一行值。为此,我们需要将 DataFrame 转换为 RDD。
首先数据从 DataFrame 类型转换为 RDD 类型。从 row 中取值时是按照数据集中各个字段取出 Flight 类中对应字段的值。例如排在第二的 row(3)
取出的是 DataFrame 中 DayofWeek 字段的值,对应的是 Flight 类中的 dayOfWeek
成员变量。
val tmpFlightDataRDD = flightData.map(row => row(2).toString+","+row(3).toString+","+row(5).toString+","+row(7).toString+","+row(8).toString+","+row(12).toString+","+ row(16).toString+","+row(17).toString+","+row(14).toString+","+row(15).toString).rdd
接着需要建立一个类,将 RDD 中的部分字段映射到类的成员变量中。
case class Flight(dayOfMonth:Int, dayOfWeek:Int, crsDepTime:Double, crsArrTime:Double, uniqueCarrier:String, crsElapsedTime:Double, origin:String, dest:String, arrDelay:Int, depDelay:Int, delayFlag:Int)
在类 Flight
中,最后一个成员变量为 delayFlag
。通过对数据的观察,我们知道部分航班的延误时间仅仅为几分钟(无论是出发还是抵达时),而通常此类延误都是可以容忍的。为了减少待处理的数据量,我们可以将延误定义为出发或抵达的延迟时间大于半个小时(即 30 分钟),从而将抵达延误时间和出发延误时间简化为延误标记 delayFlag
。
可以先尝试写出伪代码:
if ArrDelayTime or DepDelayTime > 30
delayFlag = True
else
delayFlag = False
接着我们按照上述逻辑,定义一个解析方法。该方法用于将 DataFrame 中的记录转换为 RDD 。
def parseFields(input: String): Flight = {
val line = input.split(",")
// 针对可能出现的无效值“NA”进行过滤
var dayOfMonth = 0
if(line(0) != "NA"){
dayOfMonth = line(0).toInt
}
var dayOfWeek = 0
if(line(1) != "NA"){
dayOfWeek = line(1).toInt
}
var crsDepTime = 0.0
if(line(2) != "NA"){
crsDepTime = line(2).toDouble
}
var crsArrTime = 0.0
if(line(3) != "NA"){
crsArrTime = line(3).toDouble
}
var crsElapsedTime = 0.0
if(line(5) != "NA"){
crsElapsedTime = line(5).toDouble
}
var arrDelay = 0
if(line(8) != "NA"){
arrDelay = line(8).toInt
}
var depDelay = 0
if(line(9) != "NA"){
depDelay = line(9).toInt
}
// 根据延迟时间决定延迟标志是否为1
var delayFlag = 0
if(arrDelay > 30 || depDelay > 30){
delayFlag = 1
}
Flight(dayOfMonth, dayOfWeek, crsDepTime, crsArrTime, line(4), crsElapsedTime, line(6), line(7), arrDelay, depDelay, delayFlag)
}
解析方法定义完成后,我们就是用 map 操作来解析 RDD 中的各个字段。
val flightRDD = tmpFlightDataRDD.map(parseFields)
可以尝试随机取出一个值检查解析是否成功。
flightRDD.take(1)
4、提取特征
为了建立分类模型,我们需要提取出航班数据的特征。在刚刚解析数据的一步中,我们设立 delayFlag 的目的就是为了定义两个类用于分类。因此你可以将其称之为标签(Label),这是分类中常用的一个手段。标签有两种,如果 delayFlag 为 1 ,则代表航班有延误;如果为 0 ,则代表没有延误。区分延误与否的标准正如之前所讨论的那样:抵达或出发的延误时间是否超过了 30 分钟。
对于数据集中的每条记录,现在它们都包含了标签和特征信息。特征则是每条记录在 Flight 类中对应的所有属性(从 dayOfMonth 一直到 delayFlag)。
下面,我们需要将上述的这些特征转换为数值特征。在 Flight 类中,有些属性已经是数值特征,而诸如 crsDepTime 和 uniqueCarrier 等属性还不是数值特征。在这一步中我们都要将它们转换为数值特征。例如,uniqueCarrier 这个特征通常是航空公司代码( “WN” 等),我们按照先后顺序为它们进行编号,将字符串类型的特征转换为含有唯一 ID 的数值特征(如 “AA” 变成了 0,“AS” 变成了 1 ,以此类推,实际运算时是按照字母先后顺序进行标记的)。
var id: Int = 0
var mCarrier: Map[String, Int] = Map()
flightRDD.map(flight => flight.uniqueCarrier).distinct.collect.foreach(x => {mCarrier += (x -> id); id += 1})
计算完成后,我们来查看一下 carrier 中是否已经完成了从表示航空公司代码的字符串到对应的唯一 ID 之间的转换。
mCarrier.toString
按照同样的逻辑,我们要为出发地 Origin 、目的地 Dest 进行字符串到数值的转换。
先是对 Origin 进行转换:
var id_1: Int = 0
var mOrigin: Map[String, Int] = Map()
// 这里的origin相当于一个“全局”变量,在每次map中我们都在对其进行修改
flightRDD.map(flight => flight.origin).distinct.collect.foreach(x => {mOrigin += (x -> id_1); id_1 += 1})
最后是对 Dest 进行转换,不要忘了我们转换的目的是为了建立数值特征。
var id_2: Int = 0
var mDest: Map[String, Int] = Map()
flightRDD.map(flight => flight.dest).distinct.collect.foreach(x => {mDest += (x -> id_2); id_2 += 1})
至此我们就将所有的特征都准备好了。
5、定义特征数组
我们在上一步用不同的数字代表了不同的特征,这些特征最后都将放入数组中,可以将其理解为建立了特征向量。
接下来,我们将所有的标签(延迟与否)和特征都以数值的形式存储到一个新的 RDD 中,用作机器学习算法的输入。
val featuredRDD = flightRDD.map(flight => {
val vDayOfMonth = flight.dayOfMonth - 1
val vDayOfWeek = flight.dayOfWeek - 1
val vCRSDepTime = flight.crsDepTime
val vCRSArrTime = flight.crsArrTime
val vCarrierID = mCarrier(flight.uniqueCarrier)
val vCRSElapsedTime = flight.crsElapsedTime
val vOriginID = mOrigin(flight.origin)
val vDestID = mDest(flight.dest)
val vDelayFlag = flight.delayFlag
// 返回值中,将所有字段都转换成Double类型以利于建模时使用相关API
Array(vDelayFlag.toDouble, vDayOfMonth.toDouble, vDayOfWeek.toDouble, vCRSDepTime.toDouble, vCRSArrTime.toDouble, vCarrierID.toDouble, vCRSElapsedTime.toDouble, vOriginID.toDouble, vDestID.toDouble)
})
经历这个 map 阶段后,我们得到了包含所有信息的特征数组,并且这些特征都是数值类型的。
尝试取出其中一个值来查看转换是否成功。
featuredRDD.take(1)
6、创建标记点
此步骤中,我们需要将含有特征数组的 featuredRDD 转换为含有 org.apache.spark.mllib.regression.LabeledPoint
包中定义的标记点 LabeledPoints 的新 RDD 。在分类中,标记点含有两类信息,一是代表了数据点的标记,二是代表了特征向量类。
下面我们来完成这个转换。
// Label设定为 DelayFlag,Features设定为其他所有字段的值
val LabeledRDD = featuredRDD.map(x => LabeledPoint(x(0), Vectors.dense(x(1), x(2), x(3), x(4), x(5), x(6), x(7), x(8))))
尝试取出其中一个值来查看转换是否成功。
LabeledRDD.take(1)
回顾一下之前所做的工作:我们得到了含有延误标记 DelayFlag 的数据,所有的航班都可以被标记为延误了或者没有延误。下面我们会将上述数据使用随机划分的方法,划分为训练集和测试集。
以下是详细比例说明:
- 在 LabeledRDD 中,数据标记为 DelayFlag = 0 的数据为未延迟航班;数据标记为 DelayFlag = 1 的数据为已延迟航班。
- 未延迟航班总数的 80% ,将与所有的已延迟航班组成新的数据集。新数据集的 70% 和 30% 将被划分为训练集和测试集。
- 不直接使用 LabeledRDD 中的数据来划分训练集和测试集的目的是:尽可能提高已延迟航班在测试集中的比例,让训练得到的模型能更精确地描述延迟的情况。
因此,我们首先来提取 LabeledRDD 中的所有未延迟航班,再随机提取其中的 80% 。
// 末尾的(0)是为了取这 80% 的部分
val notDelayedFlights = LabeledRDD.filter(x => x.label == 0).randomSplit(Array(0.8, 0.2))(0)
接着我们提取所有的已延迟航班。
val delayedFlights = LabeledRDD.filter(x => x.label == 1)
将上述二者组合成新的数据集,用于后续划分训练集和测试集。
val tmpTTData = notDelayedFlights ++ delayedFlights
最后我们将这个数据集按照约定的比例随机划分为训练集和测试集。
// TT意为Train & Test
val TTData = tmpTTData.randomSplit(Array(0.7, 0.3))
val trainingData = TTData(0)
val testData = TTData(1)
7、训练模型
接下来,我们将会从训练集中提取特征(Feature)。这里会用到 Spark MLlib 中的决策树。决策树是一个预测模型,代表的是对象属性与对象值之间的一种映射关系。你可以在百度百科中详细了解决策树的原理。
希望在进行接下来的工作之前,你能够利用一些时间了解决策树,以便于更好地理解各项参数设置的含义。
在官方文档中,决策树的参数分为三类:
- 问题规格参数(Problem specification parameters):这些参数描述了待求解问题和数据集。我们需要设置
categoricalFeaturesInfo
这一项,它指明了哪些特征是已经明确的,以及这些特征都可以取多少明确的值。返回值是一个 Map 。例如Map(0 -> 2, 4 -> 10)
表示特征0
的取值有 2 个(0 和 1),特征4
的取值有 10 个(从 0 到 9)。 - 停止准则(Stopping criteria):这些参数决定了树的构造在什么时候停止(即停止添加新节点)。我们需要设置
maxDepth
这一项,它表示树的最大深度。更深的树可能更有表达力,但它也更难训练并且容易过拟合。 - 可调参数(Tunable parameters):这些参数都是可选的。我们需要设置两个。第一个是
maxBins
,表示离散连续特征时的桶信息数量。第二个是impurity
,表示在选择候选分支时的杂质度。
我们要训练的模型是通过建立输入特征与已标记的输出间的联系。要用到的方法是决策树类 DecisionTree
自带的 trainClassifier
方法。通过使用该方法,我们能够得到一个决策树模型。
下面来尝试构造训练逻辑。
// 仿照 API 文档中的提示,构造各项参数
var paramCateFeaturesInfo = Map[Int, Int]()
// 第一个特征信息:下标为 0 ,表示 dayOfMonth 有 0 到 30 的取值。
paramCateFeaturesInfo += (0 -> 31)
// 第二个特征信息:下标为 1 ,表示 dayOfWeek 有 0 到 6 的取值。
paramCateFeaturesInfo += (1 -> 7)
// 第三、四个特征是出发和抵达时间,这里我们不会用到,故省略。
// 第五个特征信息:下标为 4 ,表示 uniqueCarrier 的所有取值。
paramCateFeaturesInfo += (4 -> mCarrier.size)
// 第六个特征信息为飞行时间,同样忽略。
// 第七个特征信息:下标为 6 ,表示 origin 的所有取值。
paramCateFeaturesInfo += (6 -> mOrigin.size)
// 第八个特征信息:下标为 7, 表示 dest 的所有取值。
paramCateFeaturesInfo += (7 -> mDest.size)
// 分类的数量为 2,代表已延误航班和未延误航班。
val paramNumClasses = 2
// 下面的参数设置为经验值
val paramMaxDepth = 9
val paramMaxBins = 7000
val paramImpurity = "gini"
参数构造完成后,我们调用 trainClassfier 方法进行训练。
val flightDelayModel = DecisionTree.trainClassifier(trainingData, paramNumClasses, paramCateFeaturesInfo, paramImpurity, paramMaxDepth, paramMaxBins)
等待训练完成后,我们可以尝试打印出这棵决策树。
val tmpDM = flightDelayModel.toDebugString
print(tmpDM)
执行结果如下图所示,此处未显示所有的结果。
决策树的内容看起来大致是多重的分支结构。如果有足够的耐心,你可以在草稿纸上将决策的条件逐一画出来。按照上述这些条件,我们就能对今后的一个输入值作出预测了。当然,预测的结果就是会延误或者不会延误。
8、测试模型
在模型训练完成之后,我们还需要检验模型的构造效果。因此,最后一步是使用测试集对模型进行测试。
// 使用决策树模型的predict方法按照输入进行预测,预测结果临时存放于 tmpPredictResult 中。最后与输入信息的标记组成元祖,作为最终的返回结果。
val predictResult = testData.map{flight =>
val tmpPredictResult = flightDelayModel.predict(flight.features)
(flight.label, tmpPredictResult)
}
尝试取出 10 组预测结果,看一下效果。
predictResult.take(10)
执行结果如下图所示。
可以看到, 若 Label 的 0.0 与 PredictResult 的 0.0 是对应的,则表明预测结果是正确的。并且不是每一条预测值都是准确的。
val numOfCorrectPrediction = predictResult.filter{case (label, result) => (label == result)}.count()
执行结果如下图所示。
最后计算预测的正确率:
// 使用toDouble是为了提高正确率的精度,否则两个long值相除仍然是long值。
val predictAccuracy = numOfCorrectPrediction/testData.count().toDouble
执行结果如下图所示。
我们得到了该模型的预测正确率约为 85.77% ,可以说在实际的预测中还是有一定的应用价值的。为了提高预测的正确率,你可以考虑使用更多的数据进行模型的训练,并且在建立决策树时将参数调至最优。
因数据集在每次随机划分过程中均会有差异,此处的预测正确率仅供参考。结果在 80% 以上即可视为完成本项目的目标。
9、D3.js可视化编程
数据可视化的目的是为了让数据更可信。如果我们盯着一堆表格和数字看,很可能会忽视某些重要的信息。而数据可视化作为数据的一种表达方式,能够帮助我们发现一些从数据表面不容易看到的信息。
其实在这之前我们已经多多少少接触到了数据可视化。最简单的数据可视化就是我们在 Microsoft Office 的 Excel 中制作的那些柱状图、折线图、饼图等。数据可视化不在于制作的图表多么炫酷,而在于如何更加生动地表达复杂的数据。
D3 的全称是 Data-Driven Documents ,字面意思是数据驱动文档。它本身是一个 JavaScript 的函数库,可以用于数据可视化。
我们通过一个轻量的例子来学习如何在美国地图上表现每个州的准点航班数量和它们在总体中的繁忙程度。
(1)创建项目目录及文件
在桌面上双击打开终端,下载课程所需要的 D3.js 和其他 javascript 脚本文件。
wget https://labfile.oss.aliyuncs.com/courses/610/js.tar.gz
然后对其进行解压缩。
tar zxvf js.tar.gz
下面我们来为创建项目所需要的目录和文件。首先新建一个名为 DataVisualization
的项目目录。所有的网页、js 文件和数据我们都存放于该目录中。
mkdir DataVisualization
进入到新创建的目录中,再分别创建两个名为 data
和 js
的目录用于存放 CSV 数据和 js 文件,最后再新建一个名为 index.html
的网页文件。
cd DataVisualization
mkdir data
mkdir js
touch index.html
数据可视化的数据来自于我们保存在 /home/shiyanlou/QueryDestResult.csv/
目录下的 result.csv
文件。我们将它复制到当前的项目目录下的 data
文件夹中。
cp ~/QueryDestResult.csv/result.csv ~/DataVisualization/data
就整个项目而言,我们需要在 index.html 网页中调用 D3.js 的相关 API ,来编写主要的数据可视化逻辑。并且会添加相应的 html 元素来显示数据可视化结果。因此我们还需要将之前解压的两个 js 文件复制到项目目录的 js 文件夹下。
cp ~/js/* ~/DataVisualization/js
最后我们使用 tree 命令来查看项目目录的文件结构。
tree .
(2)数据完备性检查
首先去除 result.csv 文件中每行首尾的中括号:
sed -i "s/\[//g" ~/DataVisualization/data/result.csv
sed -i "s/\]//g" ~/DataVisualization/data/result.csv
通常,由于数据含有无效值或者是缺失值,绘制就容易出现错误。
避免这类错误的方法有两个:一是通过程序进行容错处理;二是对数据进行完备性检查,将无效或缺失的值补齐。
简便起见,我们直接通过文本编辑器来对数据进行修订。
请使用 gedit 文本编辑器打开数据文件。
gedit ~/DataVisualization/data/result.csv
第一步是为 CSV 文件添加字段名称,请在首行分别为两列数据添加字段:StateName
和 OnTimeFlightsNum
。具体如下图所示(注意大小写)。
对照美国州名列表,我们可以发现 result.csv 中还缺少华盛顿特区(DC)和特拉华州(DE)的信息,我们在文件末尾补充上它们,并将准点航班数设置为 0 ,如下图所示。
(3)编辑index.html
index.html 中,我们需要插入一些基本的 HTML 元素,以让它看起来更“像”一个网页。
请在 index.html 中加入以下内容。
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>US OnTime Flights Map</title>
</head>
<body>
</body>
</html>
因为我们会用 d3.js 的相关 API 来进行绘图,所以这里需要在 HTML 代码中引用它。此外,我们还需要一个名为 uStates.js 的脚本,该脚本中已经包含了美国各个州的轮廓数据。
请在 <head>
标签之间插入以下语句来引用相应的 js 文件。
插入的代码为:
<script src="js/d3.min.js"></script>
<script src="js/uStates.js"></script>
我们在 src 属性中设置了 js 文件的来源为网页根目录下的 js 目录中的对应 js 文件。
之后,我们需要在 <body>
标签之间插入一些 HTML 元素。
下面的代码中,svg 标签用于显示美国地图,div 标签用于装载提示框的相关信息。我们设置提示框所在元素的 id 为 tooltip
,设置地图所在元素的 id 为 statesvg
,并且设置宽和高分别为 960 和 800 像素。
<!-- 该 div 标签用于装载提示框 -->
<div id="tooltip"></div>
<!-- SVG 标签用于绘制地图 -->
<svg width="960" height="800" id="statesvg"></svg>
最后,我们需要准备一个 script 标签,用于书写我们的数据可视化相关逻辑代码。
请在 <body>
标签之间插入以下内容。
<script>
</script>
(4)实现数据加载和可视化准备相关功能
第一步是需要从 CSV 文件中读取数据,这里我们直接使用 d3.js 的 API 函数 d3.csv()
。
在该函数中,第一个参数为待读取的文件路径。我们设其为 data/result.csv
,表示项目目录下的 data 目录下的 result.csv 文件。
第二个参数为读取之后的匿名的回调函数,即读取完毕之后(无论成功与否)都会执行该回调函数中的内容。error
和 csvData
是两个必须传递的参数。当读取失败时,我们可以通过 console.log(error)
的方式在浏览器的控制台输出错误信息。当读取成功时,csvData 即为装载了 CSV 文件内容的变量。
请在添加的 script 标签中插入下面的代码。
d3.csv("data/result.csv", function(error, csvData) {
// TODO: add text here.
});
我们之后所要做的事情,都会在这个回调函数中完成。
为了方便起见,我们在注释中继续讲解相关的知识点。
下面的代码需要插入至回调函数 function(error, csvData)
中。
// 创建一个 Object,用于存放处理之后的绘图数据
// 可将其理解为含有 key - value 的 map 对象
var mapData = {};
// 变量 sum 用于存放准点航班总数
var sum = 0.0;
// 第一次对 csvData 进行遍历,求取准点航班总数
csvData.forEach(function(d){
// 在forEach函数中,用一个匿名函数处理每次遍历得到的数据记录 d
// OnTimeFlightsNum 为 CSV 文件中我们设置的字段名称
// 取出的值还是字符串类型,我们要将其转换为浮点型
sum += parseFloat(d.OnTimeFlightsNum);
});
// 第二次对 csvData 进行遍历,用于设置绘图数据
csvData.forEach(function(d){
// d.StateName 取出每一条记录的 StateName 字段的值,并转换为字符串,作为 map 对象的 key
var key = d.StateName.toString();
// d.OnTimeFlightsNum 取出每一条记录 OnTimeFlightsNum 字段的值,转换为浮点型
var vNumOfOnTimeFlights = parseFloat(d.OnTimeFlightsNum);
// 这里是为不同的数据设置不同程度的颜色
// 调用了d3.js的插值API:d3.interpolate()
// 参数 "#57d2f7" 和 "#726dd1" 均为 HEX 类型的16进制颜色代码,每两位分别为 RGB 通道的颜色深浅
// 用 vNumOfOnTimeFlights / sum 计算当前值占总数的比例,乘以10是为了让颜色区分更明显
var vColor = d3.interpolate("#57d2f7", "#726dd1")(vNumOfOnTimeFlights / sum * 10);
// 对于每条记录,将 StateName 字段的值作为 mapData 的键,将准点航班数量和颜色代码作为它们的值。
mapData[key] = {num:vNumOfOnTimeFlights,
color:vColor};
});
// 绘图数据准备完成后,调用 uStates 对象的 draw 函数进行绘图。
// 第一个参数为选取的绘图对象,即我们设置的 HTML 标签:statesvg
// 第二个参数为我们计算得到的绘图数据
uStates.draw("#statesvg", mapData);
完成后的代码如下图所示。
(5)实现 uStates.js 中的绘图功能
在 index.html 中计算得到了绘图所需要的数据,并且最终调用了 uStates 对象的 draw 函数。这里我们就需要完善这个绘图的逻辑。
现在编辑 js/uStates.js
文件,这是提前为大家准备好的一个 js 文件,需要我们完善部分细节。
在目前的代码中,有大量的内容是变量 uStatePaths
的定义。在这个 map 中,id
为美国各州的缩写,我们稍后会通过它来查找绘图数据 mapData 中对应的准点航班数量和颜色代码;n
为每个州对应的全称,我们会将其放在鼠标划过时显示的提示框内;d
为每个州在地图上的轮廓。
需要进一步说明的是,d
中所存放的轮廓数据来自于根据美国地图制作的 SVG 文件。制作的步骤为:首先根据地图服务商提供的 GeoJSON 数据(如 Google Map)制作投影。GeoJSON 中都是一些经纬度、海报等信息,然后利用 D3.js 的投影函数对其进行投影(d3.geo.mercator()
),在投影过程中可以进行缩放和平移等。投影之后便得到了二维数据,但它们都是一些点。地图的轮廓都是闭合的线,因此我们还需要利用 D3.js 的路径生成器(d3.geo.path()
)来链接每个区块的二维数据点,形成最终的地图轮廓。
通常我们会把轮廓数据存放在 SVG 文件中,SVG 意为可缩放矢量图形。本例中为了简化操作步骤,是直接将该数据赋予了各个州的信息中,也就是大家看到的这些描述信息。
翻到这段代码的结束部分,可以看到 uStates.draw
中的细节尚未完成,这就是接下来需要进行的工作。
请在 uStates.draw = function (id, data) { ... };
的函数定义中补充以下内容。同样地,相关讲解会以注释的形式给出。
// vData 用于装载绘图数据
var vData = data;
// 该方法用于创建鼠标划过时显示的提示框的 HTML 内容( div 元素)
function addTooltipHtml(n, d) {
// 传入的 n 为每个州的全称,用 h4 标签表示
// 传入的 d 为每个州的绘图数据,d.num为准点航班数量
return "<h4>" + n + "</h4><table>" +
"<tr><td>On Time Flights:</td><td>" + d.num + "</td></tr>"
"</table>";
}
// 当鼠标指针位于元素上方时,会发生 mouseover 事件
// 此处定义一个 mouseOver 函数作为发生该事件的回调函数
function mouseOver(d) {
// 传入的 d 为 uStatePaths 中的元素
var key = d.id;
var vData = data;
// 使用 d3.js 的 select() 函数选中 HTML 元素中 id 为 tooltip 的 div 元素
// transition() 函数用于启动转变效果,可以用于制作动画
// duration(200) 用于设置动画的持续时间为 200 毫秒
// style("opacity", .9) 用于设置 div 元素的不透明级别,不透明度为 .9 表示 90% 的不透明度
d3.select("#tooltip").transition().duration(200).style("opacity", .9);
// 同样,利用 html()函数调用 addTooltipHtml 函数来向 tooltip 元素注入 html 代码
// d.n 表示 sStatePaths 中的 n 成员,即每个州的全称
// vData[key] 表示用 key 去查询绘图数据中的每个州的准点航班数量
d3.select("#tooltip").html(addTooltipHtml(d.n, vData[key]))
.style("left", (d3.event.pageX) + "px")
.style("top", (d3.event.pageY - 28) + "px");
}
// 当鼠标指针不再位于元素上方时,会发生 mouseout 事件
// 此处定义一个 mouseOut 函数作为发生该事件的回调函数
function mouseOut() {
// 设置不透明度为 0 ,即相当于让该 div 元素消失
d3.select("#tooltip").transition().duration(500).style("opacity", 0);
}
// 此处的 id 即将会表示 svgstate 元素
// 此步骤用于设置所有州的颜色
// .data(uStatePaths).enter().append("path") 表示利用 uStatePaths 中的数据绘制路径
// .style("fill", function(d) { }) 表示按照绘图数据中每个州的颜色代码进行填充
// .on("mouseover", mouseOver).on("mouseout", mouseOut) 表示设置鼠标覆盖和鼠标移开事件的监听器及回调函数
d3.select(id).selectAll(".state")
.data(uStatePaths).enter().append("path").attr("class", "state").attr("d", function(d) {
return d.d;
})
.style("fill", function(d) {
key = d.id;
vColor = data[key].color;
return vColor;
})
.on("mouseover", mouseOver).on("mouseout", mouseOut);
完成后的代码如下图所示:
(6)设置网页元素样式
最后我们需要做一些美化工作。
回到 index.html 中,在 <head>
标签中插入以下内容。
<style>
.state{
fill: none;
stroke: #888888;
stroke-width: 1;
}
.state:hover{
fill-opacity:0.5;
}
#tooltip {
position: absolute;
text-align: center;
padding: 20px;
margin: 10px;
font: 12px sans-serif;
background: lightsteelblue;
border: 1px;
border-radius: 2px;
border:1px solid grey;
border-radius:5px;
pointer-events: none;
background:rgba(0,0,0,0.9);
font-size:14px;
width:auto;
padding:4px;
color:white;
opacity:0;
}
#tooltip h4{
margin:0;
font-size:20px;
}
#tooltip tr td:nth-child(1){
width:120px;
}
</style>
关于每个属性的作用,请查阅 W3school 。完成后的代码如下图所示。
至此,我们所有的代码编辑工作就已经完成。
最终的 index.html 页面代码如下所示:
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>US OnTime Flights Map</title>
<script src="js/d3.min.js"></script>
<script src="js/uStates.js"></script>
<style>
.state{
fill: none;
stroke: #888888;
stroke-width: 1;
}
.state:hover{
fill-opacity:0.5;
}
#tooltip {
position: absolute;
text-align: center;
padding: 20px;
margin: 10px;
font: 12px sans-serif;
background: lightsteelblue;
border: 1px;
border-radius: 2px;
border:1px solid grey;
border-radius:5px;
pointer-events: none;
background:rgba(0,0,0,0.9);
font-size:14px;
width:auto;
padding:4px;
color:white;
opacity:0;
}
#tooltip h4{
margin:0;
font-size:20px;
}
#tooltip tr td:nth-child(1){
width:120px;
}
</style>
</head>
<body>
<div id="tooltip"></div>
<svg width="960" height="800" id="statesvg"></svg>
<script>
d3.csv("data/result.csv", function(error, csvData) {
var mapData = {};
var sum = 0.0;
csvData.forEach(function(d){
sum += parseFloat(d.OnTimeFlightsNum);
});
csvData.forEach(function(d){
var key = d.StateName.toString();
var vNumOfOnTimeFlights = parseFloat(d.OnTimeFlightsNum);
var vColor = d3.interpolate("#57d2f7", "#726dd1")(vNumOfOnTimeFlights / sum * 10);
mapData[key] = {num:vNumOfOnTimeFlights,
color:vColor};
});
uStates.draw("#statesvg", mapData);
});
</script>
</body>
</html>
(7)项目预览
由于静态页面无法读取本地文件,直接打开 index.html 页面是无法查看页面内容的,我们这里简单的使用 flask 开启一个服务。
实验环境中已经安装了 flask,我们可以直接使用。
在 /home/shiyanlou 目录下新建 flask-web 文件夹,并进入该文件夹新建 demo.py 文件、static 文件夹、templates 文件夹,并将 DataVisualization 文件夹中的内容复制过来,最终形成的文件目录结构如下所示:
在 demo.py 文件中写入如下代码:
from flask import Flask
from flask import render_template
app = Flask(__name__)
@app.route('/')
def home():
return render_template("index.html")
并且修改 index.html 的路径,主要修改内容如下:
<script src="../static/js/d3.min.js"></script>
<script src="../static/js/uStates.js"></script>
...
d3.csv("/static/data/result.csv", function(error, csvData) {
...
在 flask-web 目录下执行如下命令,开启服务:
export FLASK_APP=demo.py
flask run
结果如下所示:
是否感受到数据带来的强烈冲击了呢?
可以看到,图中颜色最深的两个州为 CA 和 TX 。一定程度上表明这两个州的航空活动更加活跃。
实验总结
在本实验中,我们通过 Spark ,基于常用的 DataFrame 和 SQL 操作对航班起降的记录数据进行分析,找出了造成航班延误的原因,并利用机器学习算法,对航班延误情况进行了预测。
并利用 D3.js ,将得到的美国各州准点航班数量进行了数据可视化操作。实验中涉及了 D3.js 对于读取数据、插值、选取元素、设置属性等 API 的用法。
如果你有兴趣对其它年份的数据进行分析,你有可能会发现下面这些有趣的现象:
- 夏季由于雷雨等恶劣天气增多,航班延迟情况严重。
- 冬季由于恶劣天气较少,气候稳定,航班延迟较少。
/js/d3.min.js">
…
d3.csv("/static/data/result.csv", function(error, csvData) {
…
在 flask-web 目录下执行如下命令,开启服务:
```python
export FLASK_APP=demo.py
flask run
结果如下所示:
是否感受到数据带来的强烈冲击了呢?
可以看到,图中颜色最深的两个州为 CA 和 TX 。一定程度上表明这两个州的航空活动更加活跃。
实验总结
在本实验中,我们通过 Spark ,基于常用的 DataFrame 和 SQL 操作对航班起降的记录数据进行分析,找出了造成航班延误的原因,并利用机器学习算法,对航班延误情况进行了预测。
并利用 D3.js ,将得到的美国各州准点航班数量进行了数据可视化操作。实验中涉及了 D3.js 对于读取数据、插值、选取元素、设置属性等 API 的用法。
如果你有兴趣对其它年份的数据进行分析,你有可能会发现下面这些有趣的现象:
- 夏季由于雷雨等恶劣天气增多,航班延迟情况严重。
- 冬季由于恶劣天气较少,气候稳定,航班延迟较少。
- 9·11 恐怖袭击事件(2001 年 9 月 11 日)发生后,航班数量急剧减少。9·11 事件中,美国航空公司和联合航空公司各自损失两架飞机,整个空运停顿了 3 天。恢复飞行以后,由于受到事件的惊吓,美国航空乘客人数短期内剧烈收缩,甚至出现了一班飞机只有一位乘客的情况。