文章目录

  • DataSource
  • Spark 对外暴漏的读写文件的入口:
  • writer.save() 方法
  • DataFrameReader.load() 方法
  • java.util.ServiceLoader
  • 扩展Spark 支持的DataSource


DataSource

DataSource 是Spark用来描述对应的数据文件格式的入口,对应的Delta也是一种数据文件格式,所以了解DataSource实现原理,是了解 Delta 的基础。

Spark 对外暴漏的读写文件的入口:

// 创建一个writer再调用writer.save()方法保存 DF
val writer = df.write.format("delta").mode("append")
writer.save(path)

// 创建一个DataFrameReader,再调用 load() 方法,返回DF
sparkSession.read.format("delta").load(path)

在上面的save方法和load方法中有一部分相同的代码,val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) 这个是根据用户指定的文件格式来找到对应的 DataSource 入口;

  • 如果writer的DataSource 是 DataSourceV2 子类,返回case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan)(WriteToDataSourceV2 已经被废弃,已经被AppendData 这样具体的实现类代替);
  • 如果writer的DataSource 是 v1 版本,是根据相关元数据信息,生成一个 DataSource对象,然后来执行数据写入;
  • 如果reader的DataSource 是 DataSourceV2 子类,会根据dataSource类和数据文件路径,还有相关元信息,创建一个DataSourceV2Relation对象,然后通过 Dataset.ofRows() 方法来读取;
  • 如果reader的DataSource 是 v1 版本,创建一个根据相关元数据信息,创建一个BaseRelation 对象,再继续封装为 LogicalRelation 对象,再解析和读取;

writer.save() 方法

def save(): Unit = {
    if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
      throw new AnalysisException("Hive data source can only be used with tables, you can not " +
        "write files of Hive data source directly.")
    }

    assertNotBucketed("save")

    val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
    if (classOf[DataSourceV2].isAssignableFrom(cls)) {
      val source = cls.newInstance().asInstanceOf[DataSourceV2]
      source match {
        case ws: WriteSupport =>
          val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
            source,
            df.sparkSession.sessionState.conf)
          val options = sessionOptions ++ extraOptions

          val writer = ws.createWriter(
            UUID.randomUUID.toString, df.logicalPlan.output.toStructType, mode,
            new DataSourceOptions(options.asJava))

          if (writer.isPresent) {
            runCommand(df.sparkSession, "save") {
              WriteToDataSourceV2(writer.get, df.logicalPlan)
            }
          }

        // Streaming also uses the data source V2 API. So it may be that the data source implements
        // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
        // as though it's a V1 source.
        case _ => saveToV1Source()
      }
    } else {
      saveToV1Source()
    }
  }

  private def saveToV1Source(): Unit = {
    //....

    // Code path for data source v1.
    runCommand(df.sparkSession, "save") {
      DataSource(
        sparkSession = df.sparkSession,
        className = source,
        partitionColumns = partitioningColumns.getOrElse(Nil),
        options = extraOptions.toMap).planForWriting(mode, df.logicalPlan)
    }
  }

DataFrameReader.load() 方法

def load(paths: String*): DataFrame = {
    if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
      throw new AnalysisException("Hive data source can only be used with tables, you can not " +
        "read files of Hive data source directly.")
    }

    val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
    if (classOf[DataSourceV2].isAssignableFrom(cls)) {
      val ds = cls.newInstance().asInstanceOf[DataSourceV2]
      if (ds.isInstanceOf[ReadSupport]) {
        val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
          ds = ds, conf = sparkSession.sessionState.conf)
        val pathsOption = {
          val objectMapper = new ObjectMapper()
          DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray)
        }
        Dataset.ofRows(sparkSession, DataSourceV2Relation.create(
          ds, sessionOptions ++ extraOptions.toMap + pathsOption,
          userSpecifiedSchema = userSpecifiedSchema))
      } else {
        loadV1Source(paths: _*)
      }
    } else {
      loadV1Source(paths: _*)
    }
  }

  private def loadV1Source(paths: String*) = {
    // Code path for data source v1.
    sparkSession.baseRelationToDataFrame(
      DataSource.apply(
        sparkSession,
        paths = paths,
        userSpecifiedSchema = userSpecifiedSchema,
        className = source,
        options = extraOptions.toMap).resolveRelation())
  }

java.util.ServiceLoader

在继续讲解之前,先学习Java的一个工具类ServiceLoader
如果我们定义了一个接口,并在我们系统内部实现了一部分子类,满足了系统的需求,可是过了一段时间,我们有了一个新的计算逻辑,或者一个新的外部扩展需要实现,这个时候,我们需要修改我们的原始代码,通过类名来反射实例化出我们需要的实现类的对象,再操作吗?如果项目很大,修改原始代码就变的太重了,而且,我们可能还有很多第三方的实现,扩展功能将在第三方变的封闭;
幸好,Java给我们提供了ServiceLoader 这个工具类,来看一下Java是如果来实现系统功能的自动扩展的把:

  • 创建Java接口或scala trait,这里使用的是 trait DataSourceRegister
  • 在项目资源文件夹下的META-INF/services/文件夹创建以类名为名字的配置文件 $delta/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister.
    在配置文件中保存当前接口的所有实现类 org.apache.spark.sql.delta.sources.DeltaDataSource
  • 代码中load当前接口的所有实现类 ServiceLoader.load(classOf[DataSourceRegister], loader)
  • 返回结果serviceLoader 实现了Iterable接口,可以进行遍历返回结果

扩展Spark 支持的DataSource

Spark在最开始设计的时候,就预备了不仅支持内部自定义的数据文件格式,还提供了扩展 DataSource来支持各种文件格式的入口;还是从上面的加载DataSource类的入口出发:val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)

现在看这个代码,逻辑就非常清晰了:

  • 通过 ServiceLoader 去加载所有的DataSource 子类,并根据类中的 shortName() 方法和外部传入的文件格式进行过滤
  • 如果找到一个,直接返回该类
  • 如果找到多个,需要返回类名以 org.apache.spark 的那一个,如果还是有多个,查找失败
  • 如果用户传入的是一种 FileFormat 子类,则尝试直接去加载该子类或该子类的DefaultSource对应类,加载成功,直接返回
// DataSource.scala
/** Given a provider name, look up the data source class definition. */
  def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
    // 通过传入的参数,找到对应的 FileFormatClass类,传入的参数也可以是orc, 或者类似 org.apache.spark.sql.parquet 这样的简写
    // orc 在native 和 hive 中使用不同的FileFormat来存取
    val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match {
      case name if name.equalsIgnoreCase("orc") &&
          conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native" =>
        classOf[OrcFileFormat].getCanonicalName
      case name if name.equalsIgnoreCase("orc") &&
          conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "hive" =>
        "org.apache.spark.sql.hive.orc.OrcFileFormat"
      case "com.databricks.spark.avro" if conf.replaceDatabricksSparkAvroEnabled =>
        "org.apache.spark.sql.avro.AvroFileFormat"
      case name => name
    }
    val provider2 = s"$provider1.DefaultSource"
    // 使用 ServiceLoader 类来加载实现类 DeltaDataSource 类
    val loader = Utils.getContextOrSparkClassLoader
    val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)

    try {
      serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider1)).toList match {
        // the provider format did not match any given registered aliases
        case Nil =>
          try {
            Try(loader.loadClass(provider1)).orElse(Try(loader.loadClass(provider2))) match {
              case Success(dataSource) =>
                // Found the data source using fully qualified path
                dataSource
              case Failure(error) =>
                //...
            }
          } catch {
            case e: NoClassDefFoundError => // This one won't be caught by Scala NonFatal
              // NoClassDefFoundError's class name uses "/" rather than "." for packages
              val className = e.getMessage.replaceAll("/", ".")
              if (spark2RemovedClasses.contains(className)) {
                throw new ClassNotFoundException(s"$className was removed in Spark 2.0. " +
                  "Please check if your library is compatible with Spark 2.0", e)
              } else {
                throw e
              }
          }
        case head :: Nil =>
          // there is exactly one registered alias
          head.getClass
        case sources =>
          // There are multiple registered aliases for the input. If there is single datasource
          // that has "org.apache.spark" package in the prefix, we use it considering it is an
          // internal datasource within Spark.
          val sourceNames = sources.map(_.getClass.getName)
          val internalSources = sources.filter(_.getClass.getName.startsWith("org.apache.spark"))
          if (internalSources.size == 1) {
            logWarning(s"Multiple sources found for $provider1 (${sourceNames.mkString(", ")}), " +
              s"defaulting to the internal datasource (${internalSources.head.getClass.getName}).")
            internalSources.head.getClass
          } else {
            throw new AnalysisException(s"Multiple sources found for $provider1 " +
              s"(${sourceNames.mkString(", ")}), please specify the fully qualified class name.")
          }
      }
    } catch {
      // ...
    }
  }

backwardCompatibilityMap 中的 DataSource 映射关系

0 = {Tuple2@14493} "(org.apache.spark.sql.hive.orc.DefaultSource,org.apache.spark.sql.hive.orc.OrcFileFormat)"
1 = {Tuple2@14494} "(org.apache.spark.sql.execution.datasources.json,org.apache.spark.sql.execution.datasources.json.JsonFileFormat)"
2 = {Tuple2@14495} "(org.apache.spark.sql.execution.streaming.RateSourceProvider,org.apache.spark.sql.execution.streaming.sources.RateStreamProvider)"
3 = {Tuple2@14496} "(org.apache.spark.sql.execution.datasources.json.DefaultSource,org.apache.spark.sql.execution.datasources.json.JsonFileFormat)"
4 = {Tuple2@14497} "(org.apache.spark.ml.source.libsvm.DefaultSource,org.apache.spark.ml.source.libsvm.LibSVMFileFormat)"
5 = {Tuple2@14498} "(org.apache.spark.ml.source.libsvm,org.apache.spark.ml.source.libsvm.LibSVMFileFormat)"
6 = {Tuple2@14499} "(org.apache.spark.sql.execution.datasources.orc.DefaultSource,org.apache.spark.sql.execution.datasources.orc.OrcFileFormat)"
7 = {Tuple2@14500} "(org.apache.spark.sql.jdbc.DefaultSource,org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider)"
8 = {Tuple2@14501} "(org.apache.spark.sql.json.DefaultSource,org.apache.spark.sql.execution.datasources.json.JsonFileFormat)"
9 = {Tuple2@14502} "(org.apache.spark.sql.json,org.apache.spark.sql.execution.datasources.json.JsonFileFormat)"
10 = {Tuple2@14503} "(org.apache.spark.sql.execution.datasources.jdbc,org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider)"
11 = {Tuple2@14504} "(org.apache.spark.sql.parquet,org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat)"
12 = {Tuple2@14505} "(org.apache.spark.sql.parquet.DefaultSource,org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat)"
13 = {Tuple2@14506} "(org.apache.spark.sql.execution.datasources.parquet.DefaultSource,org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat)"
14 = {Tuple2@14507} "(com.databricks.spark.csv,org.apache.spark.sql.execution.datasources.csv.CSVFileFormat)"
15 = {Tuple2@14508} "(org.apache.spark.sql.hive.orc,org.apache.spark.sql.hive.orc.OrcFileFormat)"
16 = {Tuple2@14509} "(org.apache.spark.sql.execution.datasources.orc,org.apache.spark.sql.execution.datasources.orc.OrcFileFormat)"
17 = {Tuple2@14510} "(org.apache.spark.sql.jdbc,org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider)"
18 = {Tuple2@14511} "(org.apache.spark.sql.execution.datasources.jdbc.DefaultSource,org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider)"
19 = {Tuple2@14512} "(org.apache.spark.sql.execution.streaming.TextSocketSourceProvider,org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider)"
20 = {Tuple2@14513} "(org.apache.spark.sql.execution.datasources.parquet,org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat)"