【版本介绍】
本次问题所使用的代码版本是spark 2.2.0 和 elasticsearch-spark-20_2.11
【情景介绍】
今天公司的小伙伴发现了一个问题,在spark 中,使用 elasticsearch-spark 读取es的数据,"" 这种空字符串的值,在spark中会被转成null,导致计算结果异常
代码如下:
1 def getTable()(implicit spark:SparkSession)={
2 var query=
3 s"""
4 |{
5 | "query": {
6 | "bool": {
7 | "must": [
8 | {
9 | "term": {
10 | "revise_status": {
11 | "value": ""
12 | }
13 | }
14 | }
15 | ]
16 | }
17 | }
18 |}
19 """.stripMargin
20 //读取es数据
21 EsSparkSQL.esDF(spark,s"""aaa/bbb""", query)
22 }
23
24 def main(args: Array[String]): Unit = {
25 implicit val spark = SparkAndRelevantCptUtil.getSparkSession("test", "local[3]")
26 getTable().select("revise_status").show(1000, false)
27 }
显示的结果
1 +-------------+
2 |revise_status|
3 +-------------+
4 |null |
5 |null |
6 |null |
7 |null |
8 |null |
9 +-------------+
按理来说,空字符串和null是两个概念差很多的东西,elasticsearch-spark 读取出来为什么会转成null呢,第一个想法会不会是什么参数没配上?在百度谷歌都没找到答案的情况下,只能看是看源码了
【解决方法】
这个问题的原因是因为 EsSparkSQL 自己代码中,将空字符串识别为 无效数据,是一个bug,解决该问题的步骤如下:
1、将 org.elasticsearch.spark.serialization.ScalaValueReader 的内容拷贝到一个文本编辑器中
2、在自己的项目中创建一个 org.elasticsearch.spark.serialization.ScalaValueReader 的Scala类,包名必须一致
3、将刚刚拷贝的 原版 ScalaValueReader 内容,全部粘贴到这个 刚创建的 ScalaValueReader 类中
4、修改新 ScalaValueReader 类里面的 checkNull() 方法 的代码
def checkNull(converter: (String, Parser) => Any, value: String, parser: Parser) = {
if (value != null) {
//解决掉es误把空字符串弄成null的bug
//if (!StringUtils.hasText(value) && emptyAsNull) {
if (!"".equals(value) && !StringUtils.hasText(value) && emptyAsNull) {
nullValue()
}
else {
converter(value, parser).asInstanceOf[AnyRef]
}
}
else {
nullValue()
}
}
5、重新启动项目测试即可
【排查步骤】
从代码的 EsSparkSQL.esDF(spark,s"""aaa/bbb""", query) 开始研究
【stage1】
在 org.elasticsearch.spark.sql.EsSparkSQL#esDF 下
这句没啥可看,继续追踪esDF() 方法
1
【stage2】
在 org.elasticsearch.spark.sql.EsSparkSQL#esDF 下
这句没啥可看,继续追踪esDF() 方法
1 def esDF(sc: SQLContext, resource: String, query: String, cfg: Map[String, String]): DataFrame = {
2 //继续追踪里面的代码
3 esDF(sc, collection.mutable.Map(cfg.toSeq: _*) += (ES_RESOURCE_READ -> resource, ES_QUERY -> query))
4 }
【stage3】
在org.elasticsearch.spark.sql.EsSparkSQL#esDF下
1 def esDF(sc: SQLContext, cfg: Map[String, String]): DataFrame = {
2 //获取spark的配置
3 val esConf = new SparkSettingsManager().load(sc.sparkContext.getConf).copy()
4 //外部如果有传入参数,就合并到esConf里面来,将所有参数整一块
5 esConf.merge(cfg.asJava)
6
7 /**
8 * 通过format方法设置数据格式的实现类
9 * 用options方法传入配置
10 * 【重点】load方法生成DataFrame
11 */
12 sc.read.format("org.elasticsearch.spark.sql").options(esConf.asProperties.asScala.toMap).load
13 }
【stage4】
load方法明显是核心的逻辑,所以我们追踪一下load方法
在 org.apache.spark.sql.DataFrameReader#load 下
1 def load(): DataFrame = {
2 //追踪下去
3 load(Seq.empty: _*) // force invocation of `load(...varargs...)`
4 }
【stage5】
在 org.apache.spark.sql.DataFrameReader#load 下
load方法是很重要,他负责生成DataFrame,我们在看这个load方法时,里面主要两个内容重要:
1、首先看到了:sparkSession.baseRelationToDataFrame( ... DataSource实例 ... ) ,它的作用是将内部的 DataSource实例参数转化成一个DataFrame,具体做法会在后面讲解
2、这个DataSource实例是由 DataSource.apply( ... ).resolveRelation() 生成
(1)、apply() 方法是用来构造DataSource这个类
(2)、resolveRelation()方法作用 是使用反射创建出对应 DataSource 实例
1 /**
2 * load方法最重要的功能就是将baseRelation转换成DataFrame,
3 * 该功能是通过sparkSession的 def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame
4 * 接口实现的,其中的参数baseRelation通过DataSource类的resolveRelation方法提供。
5 */
6 def load(paths: String*): DataFrame = {
7 if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
8 throw new AnalysisException("Hive data source can only be used with tables, you can not " +
9 "read files of Hive data source directly.")
10 }
11
12 /**
13 * baseRelationToDataFrame() 方法 接受 baseRelation 参数返回 DataFrame,是通过 Dataset.ofRows(sparkSession,logicalPlan) 方法实现的,
14 * 其中的参数 logicPlan 是由 LogicalRelation(baseRelation) 得到。
15 */
16 sparkSession.baseRelationToDataFrame(
17 DataSource
18 //创建一个DataSource元数据信息类
19 .apply(
20 sparkSession,
21 paths = paths,
22 userSpecifiedSchema = userSpecifiedSchema,
23 className = source,
24 options = extraOptions.toMap
25 )
26
27 /**
28 * DataSource的resolveRelation() 方法中使用反射创建出对应 DataSource 实例,协同用户指定的 userSpecifiedSchema 进行匹配,匹配成功返回对应的 baseRelation:
29 * 1、如果是基于文件的,返回HadoopFsRelation实例
30 * 2、非文件的,返回如KafkaRelation或者JDBCRelation
31 */
32 .resolveRelation()
33 )
34 }
【stage6】
现在我们先关注 DataSouce 的 resolveRelation() 方法
在org.apache.spark.sql.execution.datasources.DataSource#resolveRelation下
1 def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = {
2 //【追踪点1】这里的 providingClass 是 DataSource 的类,providingClass.newInstance() 就是数据源用反射的方式创建 DataSource 实例
3 val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
4 ...
5
6 //由于 EsSparkSQL 并没有提供设置schema的方法,所以schema为空,如果有兴趣的小伙伴可以自己改造 EsSparkSQL ,给他加上设置 schema 的方法,就可以显示设置字段类型
7 case (dataSource: RelationProvider, None) =>
8 //【追踪点2】 用 org.elasticsearch.spark.sql.DefaultSource 的实例 创建 ElasticsearchRelation
9 //ElasticsearchRelation 是一种提供数据获取buildScan,数据插入更新insert等操作的数据源实际操作类实例
10 dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
11
12 ...
13 }
14
15 //返回 ElasticsearchRelation
16 relation
17 }
【stage7】
我们先看【追踪点1】,这里的 providingClass 是 DataSource 的类,那它是什么时候被赋值的呢?
在 org.apache.spark.sql.execution.datasources.DataSource#providingClass 下
1 //DataSource.lookupDataSource 用了 className 去查找 提供数据源支持的真正的那个类,这边的className就是 org.elasticsearch.spark.sql
2 lazy val providingClass: Class[_] = DataSource.lookupDataSource(sparkSession, className)
这里我们可以看到className,它的值是,那么这个值是从哪里来的呢?就是在 【stage3】的时候,在 formt( className ) 方法设置的
1 sc.read.format("org.elasticsearch.spark.sql")...
那我们继续追踪到 DataSource 的 lookupDataSource 方法中
【stage8】
在 org.apache.spark.sql.execution.datasources.DataSource#lookupDataSource 下
在这里,我们要拿到数据源的类
1 //查找DataSource的类,注意这时候的 provider 的值是 org.elasticsearch.spark.sql
2 def lookupDataSource(sparkSession: SparkSession, provider: String): Class[_] = {
3 //backwardCompatibilityMap 会保存一些过时的 数据源类,如果在这之中,就会替换成最新的 数据源类,否则还是按照用来之前的类名
4 var provider1 = backwardCompatibilityMap.getOrElse(provider, provider)
5 //如果是orc、org.apache.spark.sql.hive.orc.OrcFileFormat 这两种特殊情况,设置 数据源类为 OrcFileFormat
6 if (Seq("orc", "org.apache.spark.sql.hive.orc.OrcFileFormat").contains(provider1.toLowerCase) &&
7 sparkSession.conf.get(SQLConf.ORC_ENABLED)) {
8 logInfo(s"$provider1 is replaced with ${classOf[OrcFileFormat].getCanonicalName}")
9 provider1 = classOf[OrcFileFormat].getCanonicalName
10 }
11
12 //对 provider1 加工出完整的 数据源类
13 val provider2 = s"$provider1.DefaultSource"
14 //拿到spark当前线程上下文中的类加载器,如果没有,就用当前创建Utils类的类加载器
15 val loader = Utils.getContextOrSparkClassLoader
16 //拿到所有已注册的格式集合,比如TEXT、JSON、CSV等等
17 val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)
18
19 try {
20 //过滤出符合spark内置格式的数据
21 serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider1)).toList match {
22 // the provider format did not match any given registered aliases
23 //由于 provider1 值是:org.elasticsearch.spark.sql,不是spark提供的常规格式,所以进入到这步骤
24 case Nil =>
25 try {
26 //注意此刻的 provider1 值是:org.elasticsearch.spark.sql, provider2是 org.elasticsearch.spark.sql.DefaultSource
27 · //尝试用类加载器去加载 provider1 和 provider2类,谁能加载成功,就用谁做数据源,
28 //由于 provider1 的值是 org.elasticsearch.spark.sql,是scala 中objct类型,并不是一个类,所以无法加载成功,所以最终加载成功的是 provider2
29 Try(loader.loadClass(provider1)).orElse(Try(loader.loadClass(provider2))) match {
30 case Success(dataSource) =>
31 // Found the data source using fully qualified path
32 //返回 类型为 provider2,即 org.elasticsearch.spark.sql.DefaultSource
33 dataSource
34 ...
35 }
36 } catch {
37 ...
38 }
39 ...
40 }
41 } catch {
42 ...
43 }
44 }
【stage9】
所以回到 【stage6】 中,providingClass 的值就是 org.elasticsearch.spark.sql.DefaultSource,然后再看【追踪点2】的这段代码
1 dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
createRelation() 方法,所以接下来我们要追踪的就是 org.elasticsearch.spark.sql.DefaultSource#createRelation
【stage10】
在 org.elasticsearch.spark.sql.DefaultSource#createRelation 下
1 /**
2 * ElasticsearchRelation 是一种提供数据获取buildScan,数据插入更新insert等操作的数据源实际操作类实例
3 */
4 override def createRelation(@transient sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
5 //创建 ElasticsearchRelation
6 ElasticsearchRelation(params(parameters), sqlContext)
7 }
【stage11】
在 org.elasticsearch.spark.sql.ElasticsearchRelation 下
buildScan() 方法是本次的重点,里面包含了本次问题的关键代码
1 //【本次核心代码】创建一个读取es数据的RDD
2 def buildScan(requiredColumns: Array[String], filters: Array[Filter]) { ... }
3 //插入更新
4 def insert(data: DataFrame, overwrite: Boolean): Unit = { ... }
【stage12】
回到【stage5】,已经得出结论 DataSource.apply( ... ).resolveRelation() 生成的是 org.elasticsearch.spark.sql.ElasticsearchRelation 的实例,那我们接着看
sparkSession.baseRelationToDataFrame( ... DataSource实例 ... ) ,刚刚也知道它的作用是将内部的 DataSource实例参数转化成一个DataFrame,接下来我们重点分析 sparkSession 的 baseRelationToDataFrame( ... ) 这个方法
在 org.apache.spark.sql.SparkSession#baseRelationToDataFrame 下
1 def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
2 /**
3 * 这里做两件重要的事
4 * 1、用LogicalRelation(baseRelation) 生成计划任务
5 * 2、用 Dataset.ofRows 获取 DataFrame
6 */
7 Dataset.ofRows(self, LogicalRelation(baseRelation))
8 }
【stage13】
我们首先看 LogicalRelation(baseRelation), 它的作用事生成计划任务,所以看看它究竟在做什么
在 org.apache.spark.sql.execution.datasources.LogicalRelation#apply 下
1 def apply(relation: BaseRelation): LogicalRelation =
2 //创建一个 LogicalRelation 计划任务
3 LogicalRelation(relation, relation.schema.toAttributes, None)
【stage14】
这个 LogicalRelation 内部具体什么东西就先不看了,因为现在似乎看不出什么东西,先进入下一步
1 case class LogicalRelation(
2 relation: BaseRelation,
3 output: Seq[AttributeReference],
4 catalogTable: Option[CatalogTable])
5 extends LeafNode with MultiInstanceRelation {
6
7
8 override def equals(other: Any): Boolean
9
10 override def hashCode: Int
11
12 override def preCanonicalized: LogicalPlan
13
14 @transient override def computeStats(conf: SQLConf): Statistics
15
16 override def newInstance(): LogicalRelation
17
18 override def refresh(): Unit
19
20 override def simpleString: String
21 }
【stage15】
回到 【stage12】 中,查看代码 Dataset.ofRows( ... }
在 org.apache.spark.sql.Dataset#ofRows 下
1 def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
2 //执行逻辑计划,此处为懒加载,只新建QueryExecution实例,并不会触发实际动作。需要注意的是QueryExecution其实是包含了SQL解析执行的4个阶段计划(解析、分析、优化、执行)
3 val qe = sparkSession.sessionState.executePlan(logicalPlan)
4 //触发语法分析,得到分析计划(Analyzed Logical Plan)
5 qe.assertAnalyzed()
6 //新建一个DataSet 来 获取数据,并将Dataset返回成DataFrame
7 new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
8 }
【stage16】
我们先对 【stage15】这段代码进行分析
1 val qe = sparkSession.sessionState.executePlan(logicalPlan)
在 org.apache.spark.sql.internal.SessionState#executePlan 下
1 //执行逻辑计划,此处为懒加载,只新建QueryExecution实例,并不会触发实际动作。需要注意的是QueryExecution其实是包含了SQL解析执行的4个阶段计划(解析、分析、优化、执行)
2 def executePlan(plan: LogicalPlan): QueryExecution = createQueryExecution(plan)
【stage17】
其中 createQueryExecution 是来源于SessionState 类的构造函数
在 org.apache.spark.sql.internal.SessionState 下
这里createQueryExecution: LogicalPlan => QueryExecution是一枚函数,将logicalplan转换为QueryExecution,它其执行整个workflow
这里顺便备注一下,planner 就是 【stage12】里的 LogicalRelation(baseRelation),可以从 【stage17】 倒看得到 【stage12】,就能看出来,这个东西在后面有用到
private[sql] class SessionState(
...
val planner: SparkPlanner,
...
createQueryExecution: LogicalPlan => QueryExecution,
...) { ... }
【stage18】
看到 createQueryExecution 这个参数是 在创建 SessionState 传进来的,所以我们要去找 SessionState 的创建代码,
所以回到 【stage15】我们看一下这段代码中的 sparkSession.sessionState 是怎么来的
1 val qe = sparkSession.sessionState.executePlan(logicalPlan)
在 org.apache.spark.sql.SparkSession#sessionState 下
1 lazy val sessionState: SessionState = {
2 parentSessionState
3 .map(_.clone(this))
4 .getOrElse {
5 //用反射的方式把它实例化出来一个Builder,然后再通过build()方法创建一个 SessionState实例,然后返回
6 val state = SparkSession.instantiateSessionState(
7 SparkSession.sessionStateClassName(sparkContext.conf),
8 self)
9 initialSessionOptions.foreach { case (k, v) => state.conf.setConfString(k, v) }
10 state
11 }
12 }
【stage19】
我们先追踪 SparkSession.sessionStateClassName(sparkContext.conf) 这句代码
在 org.apache.spark.sql.SparkSession#instantiateSessionState 下
这个是获取session状态的类名
1 private def sessionStateClassName(conf: SparkConf): String = {
2 conf.get(CATALOG_IMPLEMENTATION) match {
3 case "hive" =>
4 if (isLLAPEnabled(conf)) {
5 LLAP_SESSION_STATE_BUILDER_CLASS_NAME
6 }
7 else {
8 //【会到这步】这个静态变量的值是:"org.apache.spark.sql.hive.HiveSessionStateBuilder"
9 HIVE_SESSION_STATE_BUILDER_CLASS_NAME
10 }
11 case "in-memory" => classOf[SessionStateBuilder].getCanonicalName
12 }
13 }
【stage20】
回到【stage18】,追踪 SparkSession.instantiateSessionState( ... ) 代码
1 private def instantiateSessionState(
2 className: String,
3 sparkSession: SparkSession): SessionState = {
4 try {
5 // invoke `new [Hive]SessionStateBuilder(SparkSession, Option[SessionState])`
6 //【看这里】className的值是:org.apache.spark.sql.hive.HiveSessionStateBuilder ,用反射的方式把它实例化出来一个Builder,然后再通过build()方法创建一个 SessionState实例,然后返回
7 val clazz = Utils.classForName(className)
8 val ctor = clazz.getConstructors.head
9 ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build()
10 } catch {
11 case NonFatal(e) =>
12 throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
13 }
14 }
【stage21】
接下来我们根据上面的代码
1 ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build()
就知道了这就是调用了 org.apache.spark.sql.hive.HiveSessionStateBuilder 的 build() 方法,那我们追踪下这个代码
由于build() 方法是 HiveSessionStateBuilder 父类 BaseSessionStateBuilder 的方法,所以我们到 BaseSessionStateBuilder 下查看
在 org.apache.spark.sql.internal.BaseSessionStateBuilder#build 下
1 def build(): SessionState = {
2 new SessionState(
3 session.sharedState,
4 conf,
5 experimentalMethods,
6 functionRegistry,
7 udfRegistration,
8 catalog,
9 sqlParser,
10 analyzer,
11 optimizer,
12 planner,
13 streamingQueryManager,
14 listenerManager,
15 resourceLoader,
16
17 //这就是我们在追寻的 createQueryExecution
18 createQueryExecution,
19 createClone)
20 }
【stage22】
在 org.apache.spark.sql.internal.BaseSessionStateBuilder#createQueryExecution 下
这里会返回一个函数 LogicalPlan => QueryExecution ,这就是 【stage16】 中, def executePlan(plan: LogicalPlan): QueryExecution = createQueryExecution(plan) 中的 createQueryExecution 这个元素的具体实现
1 protected def createQueryExecution: LogicalPlan => QueryExecution = { plan =>
2 //追踪这里
3 new QueryExecution(session, plan)
4 }
【stage23】
回到 【stage15】 ,val qe = sparkSession.sessionState.executePlan(logicalPlan) 这句代码,经过 【stage16】 到 【stage22】 的分析,已经知道 qe就是 一个 QueryExecution 的实例,相当于
1 val qe = new QueryExecution(session, logicalPlan)
【stage24】
接着回到 【stage15】 ,继续看这段代码
1 //触发语法分析,得到分析计划(Analyzed Logical Plan)
2 qe.assertAnalyzed()
我们已经知道了qe就是 QueryExecution, 所以我们要去看 QueryExecution 下的 assertAnalyzed() 方法
在 org.apache.spark.sql.execution.QueryExecution#assertAnalyzed 下
1 def assertAnalyzed(): Unit = {
2 // Analyzer is invoked outside the try block to avoid calling it again from within the
3 // catch block below.
4 //【追踪】analyzed是个懒加载的属性,执行去加载它,它的作用是对逻辑计划进行分析,得到分析后的逻辑计划
5 analyzed
6 try {
7 //检查
8 sparkSession.sessionState.analyzer.checkAnalysis(analyzed)
9 } catch {
10 case e: AnalysisException =>
11 val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed))
12 ae.setStackTrace(e.getStackTrace)
13 throw ae
14 }
15 }
【stage25】
在 org.apache.spark.sql.execution.QueryExecution#analyzed 下
对逻辑计划进行分析,得到分析后的逻辑计划,即分析计划,分析计划的生成逻辑这里就不再细追下去,有兴趣自己看即可
1 lazy val analyzed: LogicalPlan = {
2 SparkSession.setActiveSession(sparkSession)
3 //对逻辑计划进行分析,得到分析后的逻辑计划,这里就不再细追下去
4 sparkSession.sessionState.analyzer.execute(logical)
5 }
得到分析计划后,现在先暂时止步到这里,当然后面还有 逻辑计划转换成一个或多个物理执行计划 等操作,后面会讲到
【stage26】
回到 【stage15】中,我们接着看下这段代码
1 //新建一个DataSet 来 获取数据
2 new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
这个代码就是把 sparkSession、 逻辑计划、 数据行结构 作为参数,生成一个Dataset,具体也不进行解析了,自己看即可
现在上面的步骤,分析了从调用 EsSparkSQL.esDF( ... ) 这个方法开始,到最后输出 DataFrame 的每一步过程
接下来,我们会从 DataFrrame.show(false) 来讲解步骤
【stage27】
现在 EsSparkSQL.esDF(spark,s"""aaa/bbb""", query) .show(false) 这句代码,我们已经分析完 EsSparkSQL.esDF(spark,s"""aaa/bbb""", query) 部分,现在我们想分析 .show(false) 方法
在 def show(truncate: Boolean): Unit = show(20, truncate) 下
1//继续追踪show() 方法
2 def show(truncate: Boolean): Unit = show(20, truncate)
【stage28】
在 org.apache.spark.sql.Dataset#show 下
1 def show(numRows: Int, truncate: Boolean): Unit = if (truncate) {
2 println(showString(numRows, truncate = 20))
3 } else {
4 //走这步showString() 方法
5 println(showString(numRows, truncate = 0))
6 }
【stage29】
在 org.apache.spark.sql.Dataset#showString 下
1 private[sql] def showString(_numRows: Int, truncate: Int = 20): String = {
2 val numRows = _numRows.max(0)
3
4 //获取数据
5 val takeResult = toDF().take(numRows + 1)
6 ...
7 }
【stage30】
在 org.apache.spark.sql.Dataset#take 下
1 //head就是获取头几条数据
2 def take(n: Int): Array[T] = head(n)
【stage31】
在 org.apache.spark.sql.Dataset#head 下
1 //追踪withAction
2 def head(n: Int): Array[T] = withAction("head", limit(n).queryExecution)(collectFromPlan)
【stage32】
在 org.apache.spark.sql.Dataset#withAction 下
1 private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
2 try {
3 /**
4 * 【继续追踪】
5 * 1、拿到缓存的解析计划,使用遍历优化器执行解析计划,得到若干优化计划。
6 * 2、获取第一个优化计划,遍历执行前优化获得物理执行计划,这是已经可以执行的计划了。
7 */
8 qe.executedPlan.foreach { plan =>
9 plan.resetMetrics()
10 }
11 val start = System.nanoTime()
12
13 //执行物理计划,返回实际结果。至此,这条sql之旅就结束了
14 val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
15 action(qe.executedPlan)
16 }
17 val end = System.nanoTime()
18 sparkSession.listenerManager.onSuccess(name, qe, end - start)
19 result
20 } catch {
21 case e: Exception =>
22 sparkSession.listenerManager.onFailure(name, qe, e)
23 throw e
24 }
25 }
【stage33】
在 org.apache.spark.sql.execution.QueryExecution#executedPlan 下
//这里有两个地方要注意看,一个是sparkPlan,另一个是prepareForExecution() 方法
lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
【stage34】
首先看 sparkPlan
在 org.apache.spark.sql.execution.QueryExecution#sparkPlan 下
这里的重点是把优化后的逻辑计划(即执行计划)转换成一个或多个物理执行计划
1 lazy val sparkPlan: SparkPlan = {
2 //设置当前活动的sparkSession
3 SparkSession.setActiveSession(sparkSession)
4 // TODO: We use next(), i.e. take the first plan returned by the planner, here for now,
5 // but we will implement to choose the best plan.
6
7 //QueryExecution获取一个sparkSession.sessionState.planner,这是一个优化器,其实现类是SparkPlanner, 该planner会把一个优化后的逻辑计划转换成一个或多个物理执行计划。
8 //注意看 optimizedPlan,先从这里继续跟踪
9 planner.plan(ReturnAnswer(optimizedPlan)).next()
10 }
【stage35】
在 org.apache.spark.sql.execution.QueryExecution#optimizedPlan 下
对分析后的逻辑计划(分析计划)进行优化,得到优化后的逻辑执行计划,即执行计划
1 //optimizedPlain是通过sparkSession.sessionState.optimizer对逻辑执行计划进行优化,得到优化后的逻辑执行计划
2 //在这里,我们关注 withCachedData
3 lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData)
【stage36】
在 org.apache.spark.sql.execution.QueryExecution#withCachedData 下
缓存分析后的逻辑计划(分析计划)
1 //对分析后的逻辑计划进行缓存,更新缓存中的计划
2 lazy val withCachedData: LogicalPlan = {
3 assertAnalyzed()
4 assertSupported()
5 //通过sparkSession.sharedState.cacheManager.useCachedData把analyzed进行缓存,或更新cacheManager中的检查后的逻辑执行计划。
6 sparkSession.sharedState.cacheManager.useCachedData(analyzed)
7 }
【stage37】
在 org.apache.spark.sql.execution.QueryExecution#analyzed 下
得到分析后的逻辑计划(分析计划)
1 //对逻辑计划进行分析,得到分析后的逻辑计划,即分析计划
2 lazy val analyzed: LogicalPlan = {
3 SparkSession.setActiveSession(sparkSession)
4 //通过sparkSession.sessionState.analyzer.executeAndCheck来检查一个逻辑执行计划,并得到一个分析和检查后的逻辑计划:analyzed。
5 sparkSession.sessionState.analyzer.execute(logical)
6 }
【核心stage38】
回到【stage34】中,查看下面代码
整体做的事就是把优化后的逻辑计划(即执行计划)转换成一个或多个物理执行计划
1 planner.plan(ReturnAnswer(optimizedPlan)).next()
我们接着分析 planner.plan( ... ) 这个方法
在 org.apache.spark.sql.catalyst.planning.QueryPlanner#plan 下
1 def plan(plan: LogicalPlan): Iterator[PhysicalPlan] = {
2 // Obviously a lot to do here still...
3
4 // Collect physical plan candidates.
5 /**
6 * 【重点】
7 * 将strategies 进行遍历,将strategies逐个去应用在逻辑计划上,然后做flat操作,返回一个PhysicalPlan的iterator。
8 * 因为这个 strategies 是 DataSourceStrategy,Spark针对DataSource预定义了四种scan接口,
9 * 1、TableScan
10 * 2、PrunedScan
11 * 3、PrunedFilteredScan
12 * 4、CatalystScan(其中CatalystScan是unstable的,也是不常用的),
13 *
14 * 如果开发者(用户)自己实现的DataSource是实现了这四种接口之一的,在scan到执行计划的底层Relation时,就会调用来扫描文件。
15 * 这样最终得到一个Iterator[SparkPlan],每个SparkPlan就是可执行的物理操作了。
16 *
17 * strategies 的值此刻有包含 DataSourceStrategy 类,所以会执行 DataSourceStrategy 的 apply的方法
18 */
19 val candidates = strategies.iterator.flatMap(_(plan))
20
21 // The candidates may contain placeholders marked as [[planLater]],
22 // so try to replace them by their child plans.
23 //这个没仔细看,和这次的逻辑不太相关,大概是childPlan代替什么占位符
24 val plans = candidates.flatMap { candidate =>
25 val placeholders = collectPlaceholders(candidate)
26
27 if (placeholders.isEmpty) {
28 // Take the candidate as is because it does not contain placeholders.
29 Iterator(candidate)
30 } else {
31 // Plan the logical plan marked as [[planLater]] and replace the placeholders.
32 placeholders.iterator.foldLeft(Iterator(candidate)) {
33 case (candidatesWithPlaceholders, (placeholder, logicalPlan)) =>
34 // Plan the logical plan for the placeholder.
35 val childPlans = this.plan(logicalPlan)
36
37 candidatesWithPlaceholders.flatMap { candidateWithPlaceholders =>
38 childPlans.map { childPlan =>
39 // Replace the placeholder by the child plan
40 candidateWithPlaceholders.transformUp {
41 case p if p == placeholder => childPlan
42 }
43 }
44 }
45 }
46 }
47 }
48
49 val pruned = prunePlans(plans)
50
51 assert(pruned.hasNext, s"No plan for $plan")
52
53 //最终得到一个Iterator[SparkPlan],每个SparkPlan就是可执行的物理操作了。
54 pruned
55 }
【核心stage39】
在 org.apache.spark.sql.execution.datasources.DataSourceStrategy#apply 下
可以看到 读取数据的核心的逻辑了,就是 t.buildScan(a.map(_.name).toArray,这个代码会作为读取数据的主要逻辑被封装到一个rowRdd中
1 def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
2 ...
3
4 //程序会进入这段代码
5 case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _, _)) =>
6 pruneFilterProject(
7 l,
8 projects,
9 filters,
10
11 //【重点】重点看toCatalystRDD() 方法 和 t.buildScan() 方法
12 (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil
13
14 ...
15 }
【核心stage40】
在 org.apache.spark.sql.execution.datasources.DataSourceStrategy#toCatalystRDD 下
在这里可以看到rdd被转换成rowRdd的代码
1 private[this] def toCatalystRDD(
2 relation: LogicalRelation,
3 output: Seq[Attribute],
4 rdd: RDD[Row]): RDD[InternalRow] = {
5 if (relation.relation.needConversion) {
6 //将rdd,转换成RowRdd,这里的每行读取数据的核心代码,也就是外面传入的 t.buildScan()
7 execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
8 } else {
9 rdd.asInstanceOf[RDD[InternalRow]]
10 }
11 }
【核心stage41】
这里开始查看【核心stage39】 中的t.buildScan( ... ) 方法
这个t的类是 org.elasticsearch.spark.sql.ElasticsearchRelation,这里要查看 ElasticsearchRelation 下的buildScan方法
在 org.elasticsearch.spark.sql.ElasticsearchRelation#buildScan 下
1 def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = {
2 ...
3
4 //【重点】输出一个ScalaEsRowRDD,这就是读取es数据的核心RDD
5 new ScalaEsRowRDD(sqlContext.sparkContext, paramWithScan, lazySchema)
6 }
【核心stage42】
在 org.elasticsearch.spark.sql.ScalaEsRowRDDIterator 下
这时候,重点关注一下ScalaEsRowRDDIterator的 父类 AbstractEsRDDIterator
1 private[spark] class ScalaEsRowRDDIterator(
2 context: TaskContext,
3 partition: PartitionDefinition,
4 schema: SchemaUtils.Schema)
5 extends AbstractEsRDDIterator[Row](context, partition) { //【重点关注】继承了AbstractEsRDDIterator
6
7 override def getLogger() = LogFactory.getLog(classOf[ScalaEsRowRDD])
8
9 //初始化reader
10 override def initReader(settings: Settings, log: Log) = {
11 InitializationUtils.setValueReaderIfNotSet(settings, classOf[ScalaRowValueReader], log)
12
13 // parse the structure and save the order (requested by Spark) for each Row (root and nested)
14 // since the data returned from Elastic is likely to not be in the same order
15 SchemaUtils.setRowInfo(settings, schema.struct)
16 }
17
18 //输出es的value值
19 override def createValue(value: Array[Object]): Row = {
20 // drop the ID
21 value(1).asInstanceOf[ScalaEsRow]
22 }
23 }
【核心stage43】
【核心stage44】 看bug所在代码即可
在 org.elasticsearch.spark.rdd.AbstractEsRDDIterator 下,我们要关注的是hasNext() 代码
因为在读取数据的时候,会调用 AbstractEsRDDIterator 下的 hasNext() 方法,继续追踪代码
在 org.elasticsearch.spark.rdd.AbstractEsRDDIterator#hasNext 下
读取数据依靠 reader.hasNext() 来读取,这个reader就是
1 def hasNext: Boolean = {
2 if (CompatUtils.isInterrupted(context)) {
3 throw new TaskKilledException
4 }
5
6 //【重点】重点看 reader.hasNext()
7 !finished && reader.hasNext()
8 }
在 org.elasticsearch.hadoop.rest.ScrollQuery#hasNext 下
可以看到用scroll 方式去从es读取数据过来,距离我们拿到数据的代码很近了
1 public boolean hasNext() {
2 ...
3
4 if (!initialized) {
5 initialized = true;
6
7 try {
8 //【重点】这里用 scroll 方式去从es读取数据过来,query是我们传入的dsl语句,body是查询体
9 Scroll scroll = repository.scroll(query, body, reader);
10 // size is passed as a limit (since we can't pass it directly into the request) - if it's not specified (<1) just scroll the whole index
11 size = (size < 1 ? scroll.getTotalHits() : size);
12 scrollId = scroll.getScrollId();
13 batch = scroll.getHits();
14 } catch (IOException ex) {
15 throw new EsHadoopIllegalStateException(String.format("Cannot create scroll for query [%s/%s]", query, body), ex);
16 }
17
18 // no longer needed
19 body = null;
20 query = null;
21 }
22
23 ...
24
25 return true;
26 }
在 org.elasticsearch.hadoop.rest.RestRepository#scroll(java.lang.String, org.elasticsearch.hadoop.util.BytesArray, org.elasticsearch.hadoop.serialization.ScrollReader) 下
1 Scroll scroll(String query, BytesArray body, ScrollReader reader) throws IOException {
2 InputStream scroll = client.execute(POST, query, body).body();
3 try {
4 //【重点】读取数据
5 return reader.read(scroll);
6 } finally {
7 if (scroll instanceof StatsAware) {
8 stats.aggregate(((StatsAware) scroll).stats());
9 }
10 }
11 }
在 org.elasticsearch.hadoop.serialization.ScrollReader#read(java.io.InputStream) 下
1 public Scroll read(InputStream content) throws IOException {
2 Assert.notNull(content);
3
4 BytesArray copy = null;
5
6 if (log.isTraceEnabled() || returnRawJson) {
7 //copy content
8 copy = IOUtils.asBytes(content);
9 content = new FastByteArrayInputStream(copy);
10 log.trace("About to parse scroll content " + copy);
11 }
12
13 this.parser = new JacksonJsonParser(content);
14
15 try {
16 //【重点】读取数据
17 return read(copy);
18 } finally {
19 parser.close();
20 }
21 }
在 org.elasticsearch.hadoop.serialization.ScrollReader#read(org.elasticsearch.hadoop.util.BytesArray) 下
1 private Scroll read(BytesArray input) {
2 ...
3
4 for (token = parser.nextToken(); token != Token.END_ARRAY; token = parser.nextToken()) {
5 //【重点】从hit读取数据
6 results.add(readHit());
7 }
8
9 ...
10 }
在 org.elasticsearch.hadoop.serialization.ScrollReader#readHit 下
1 private Object[] readHit() {
2 Token t = parser.currentToken();
3 Assert.isTrue(t == Token.START_OBJECT, "expected object, found " + t);
4 //【重点】这里走readHitAsMap(),因为我们读取的数据是正常数据,没有特别设置返回json格式
5 return (returnRawJson ? readHitAsJson() : readHitAsMap());
6 }
在 org.elasticsearch.hadoop.serialization.ScrollReader#readHitAsMap 下
1 private Object[] readHitAsMap() {
2 Object[] result = new Object[2];
3 Object metadata = null;
4 Object id = null;
5
6 ...
7
8 //读取数据出来
9 data = read(StringUtils.EMPTY, t, null);
10
11 ...
12 }
在 org.elasticsearch.hadoop.serialization.ScrollReader#read(java.lang.String, org.elasticsearch.hadoop.serialization.Parser.Token, java.lang.String) 下
这里的 fieldMapping 就是我们要查询的字段 revise_status
1 protected Object read(String fieldName, Token t, String fieldMapping) {
2 ...
3 return map(fieldMapping);
4 ...
5 }
在 org.elasticsearch.hadoop.serialization.ScrollReader#map 下
1 protected Object map(String fieldMapping) {
2 ...
3 //reader读取数据,存入一个map,在方法的最后会输出出去,这里继续追踪read() 方法
4 reader.addToMap(map, fieldName, read(absoluteName, parser.nextToken(), nodeMapping));
5 ...
6 }
在 org.elasticsearch.hadoop.serialization.ScrollReader#read(java.lang.String, org.elasticsearch.hadoop.serialization.Parser.Token, java.lang.String) 下
1 protected Object read(String fieldName, Token t, String fieldMapping) {
2 ...
3
4 if (t.isValue()) {
5 String rawValue = parser.text();
6 try {
7 if (isArrayField(fieldMapping)) {
8 return singletonList(fieldMapping, parseValue(esType));
9 } else {
10 //【重点】按照字段类型来解析数据值,这里的esType值是keyword,对应es的mapping中的结构
11 return parseValue(esType);
12 }
13 } catch (Exception ex) {
14 throw new EsHadoopParsingException(String.format(Locale.ROOT, "Cannot parse value [%s] for field [%s]", rawValue, fieldName), ex);
15 }
16 }
17 return null;
18 }
在 org.elasticsearch.hadoop.serialization.ScrollReader#parseValue 下
private Object parseValue(FieldType esType) {
Object obj;
// special case of handing null (as text() will return "null")
if (parser.currentToken() == Token.VALUE_NULL) {
obj = null;
}
else {
//【重点】读取值
obj = reader.readValue(parser, parser.text(), esType);
}
parser.nextToken();
return obj;
}
在 org.elasticsearch.spark.sql.ScalaRowValueReader#readValue 下
1 override def readValue(parser: Parser, value: String, esType: FieldType) = {
2 sparkRowField = if (getCurrentField == null) null else getCurrentField.getFieldName
3
4 if (sparkRowField == null) {
5 sparkRowField = Utils.ROOT_LEVEL_NAME
6 }
7
8 //【重点】读取数据值
9 super.readValue(parser, value, esType)
10 }
在 org.elasticsearch.spark.serialization.ScalaValueReader#readValue 下
1 def readValue(parser: Parser, value: String, esType: FieldType) = {
2 if (esType == null || parser.currentToken() == VALUE_NULL) {
3 nullValue()
4
5 } else {
6 esType match {
7 case NULL => nullValue()
8 case STRING => textValue(value, parser)
9 case TEXT => textValue(value, parser)
10 //【重点】从这里进入
11 case KEYWORD => textValue(value, parser)
12 ...
13 }
14 }
15 }
在 org.elasticsearch.spark.serialization.ScalaValueReader#textValue 下
//【重点】追踪checkNull() 方法
def textValue(value: String, parser: Parser) = { checkNull (parseText, value, parser) }
【核心stage44】
在 org.elasticsearch.spark.serialization.ScalaValueReader#checkNull 下
这里就是bug的所在了
1 def checkNull(converter: (String, Parser) => Any, value: String, parser: Parser) = {
2 if (value != null) {
3 //【重点】当我的value是"" 空字符串的时候,StringUtils.hasText(value) 会判断为 false,即不认为空字符串是有效值
4 if (!StringUtils.hasText(value) && emptyAsNull) {
5 nullValue()
6 }
7 else {
8 converter(value, parser).asInstanceOf[AnyRef]
9 }
10 }
11 else {
12 nullValue()
13 }
14 }
修复方案就是在这句话之前判断空字符串为false即可,然后在项目中创建一个一模一样的类,在相同包名下,运行的时候就会自动覆盖原来的类,就能解决问题
1 def checkNull(converter: (String, Parser) => Any, value: String, parser: Parser) = {
2 if (value != null) {
3 //【修复】加了一个判断,!"".equals(value) ,当值为空字符串的时候,就到elsed 逻辑输出值即可
4 if (!"".equals(value) && !StringUtils.hasText(value) && emptyAsNull) {
5 nullValue()
6 }
7 else {
8 converter(value, parser).asInstanceOf[AnyRef]
9 }
10 }
11 else {
12 nullValue()
13 }
14 }
这里虽然解决了问题,但是觉得还可以继续再讲一讲后续是如何把rdd提交到Job执行
【stage45】
回到【stage33】中,回头看这句话,就知道 executedPlan 就是一个或多个物理执行计划
1 lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
再回到【stage32】中,看这段代码,执行物理计划,返回实际结果
1 //执行物理计划,返回实际结果。至此,这条sql之旅就结束了
2 val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
3 action(qe.executedPlan)
4 }
【stage46】
在 org.apache.spark.sql.execution.SQLExecution#withNewExecutionId 下
执行body函数,通过 【stage31】 可以知道,body就是 collectFromPlan() 方法,我们接着追踪 collectFromPlan() 方法
1 def withNewExecutionId[T](
2 sparkSession: SparkSession,
3 queryExecution: QueryExecution)(body: => T): T = {
4
5 ...
6 //执行body 动作,也就是外面传进来的action
7 body
8 ...
9 }
【stage47】
在 org.apache.spark.sql.Dataset#collectFromPlan 下
1 private def collectFromPlan(plan: SparkPlan): Array[T] = {
2 plan.executeCollect().map(boundEnc.fromRow)
3 }
【stage48】
在 org.apache.spark.sql.execution.CollectLimitExec#executeCollect 下
1 //child为SparkPlan,所以是调用SparkPlan.executeTake(limit)
2 override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
【stage49】
在 org.apache.spark.sql.execution.SparkPlan#executeTake 下
在这段代码中,做了很重要的两件事
1、获取前面计划任务生成的RowRdd的方法
2、提交job,获取结果
1 def executeTake(n: Int): Array[InternalRow] = {
2 if (n == 0) {
3 return new Array[InternalRow](0)
4 }
5
6 //【重点】获取前面计划任务生成的RowRdd的方法
7 val childRDD = getByteArrayRdd(n)
8
9 val buf = new ArrayBuffer[InternalRow]
10 val totalParts = childRDD.partitions.length
11 var partsScanned = 0
12 while (buf.size < n && partsScanned < totalParts) {
13 // The number of partitions to try in this iteration. It is ok for this number to be
14 // greater than totalParts because we actually cap it at totalParts in runJob.
15 var numPartsToTry = 1L
16 if (partsScanned > 0) {
17 // If we didn't find any rows after the previous iteration, quadruple and retry.
18 // Otherwise, interpolate the number of partitions we need to try, but overestimate
19 // it by 50%. We also cap the estimation in the end.
20 val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2)
21 if (buf.isEmpty) {
22 numPartsToTry = partsScanned * limitScaleUpFactor
23 } else {
24 // the left side of max is >=1 whenever partsScanned >= 2
25 numPartsToTry = Math.max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1)
26 numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
27 }
28 }
29
30 val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
31 val sc = sqlContext.sparkContext
32
33 //【重点】提交job
34 val res = sc.runJob(childRDD,
35 (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty[Byte], p)
36
37 buf ++= res.flatMap(decodeUnsafeRows)
38
39 partsScanned += p.size
40 }
41
42 if (buf.size > n) {
43 buf.take(n).toArray
44 } else {
45 buf.toArray
46 }
47 }
至此,EsSparkSQL获取数据的源码解析基本完毕