Spark SQL provides a dedicated external data source API for reading files on HDFS; spark-parquet, csv, json and so on are all implemented this way.

DefaultSource

This is the entry class used to set up the external data source connection. Spark SQL looks for this class name by default, so do not rename it. Almost all of the interfaces are implemented in this class.

private[tsfile] class DefaultSource extends FileFormat with DataSourceRegister {

  class TSFileDataSourceException(message: String, cause: Throwable)
    extends Exception(message, cause) {
    def this(message: String) = this(message, null)
  }

  override def equals(other: Any): Boolean = other match {
    case _: DefaultSource => true
    case _ => false
  }
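
For reference, a minimal read call that picks up this class, assuming a SparkSession named spark and an illustrative path: when the caller passes the package name to format(...), Spark appends .DefaultSource to it to resolve the entry class.

```
// Spark resolves "cn.edu.tsinghua.tsfile" + ".DefaultSource" to this class
val df = spark.read
  .format("cn.edu.tsinghua.tsfile")
  .load("hdfs:///data/a.tsfile")   // illustrative path
```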
	

The inferSchema interface

Returns the table schema corresponding to the files.

Parameter notes:

files is the list of files being connected, and the return value is the schema of one Spark SQL table.

With xxx standing for the file suffix, different call styles produce different arguments:

  • A path with a wildcard
read.xxx("hdfs:///data/*/*.xxx")

In this case files contains all the files matched by the pattern.

  • A directory
read.xxx("hdfs:///data/")

In this case files contains every file under the /data directory.

  • A single file
read.xxx("hdfs:///data/a.xxx")

In this case files contains only the single file a.xxx.

  // Return the table schema: a StructType holds a list of StructFields, and each StructField is one column
  override def inferSchema(
                            spark: SparkSession,
                            options: Map[String, String],
                            files: Seq[FileStatus]): Option[StructType] = {
    val conf = spark.sparkContext.hadoopConfiguration

    //check if the path is given
    options.getOrElse(DefaultSource.path, throw new TSFileDataSourceException(s"${DefaultSource.path} must be specified for cn.edu.thu.tsfile DataSource"))

    val fields = new ListBuffer[StructField]()
    fields += StructField(SQLConstant.RESERVED_TIME, LongType, nullable = false)
    // ... the measurement columns read from the TsFile metadata are appended to `fields` here (omitted in this excerpt) ...

    val tsfileSchema = SchemaType(StructType(fields.toList), nullable = false)
    tsfileSchema.dataType match {
      case t: StructType => Some(t)
      case _ => throw new RuntimeException(
        s"""TSFile schema cannot be converted to a Spark SQL StructType:
           |${tsfileSchema.toString}
           |""".stripMargin)
    }
  }
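
The excerpt above only adds the reserved time column; in the full connector the measurement columns taken from the TsFile metadata are appended to fields as well. A hedged sketch of that step, reusing the fields buffer and the imports from the snippet above (the measurements map here is made up for illustration):

```
// Hypothetical mapping from measurement name to Spark SQL type; the real code
// would derive this from the metadata of the TsFiles in `files`.
val measurements: Map[String, DataType] = Map(
  "sensor_1" -> FloatType,
  "sensor_2" -> IntegerType
)
measurements.foreach { case (name, dataType) =>
  fields += StructField(name, dataType, nullable = true)
}
```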


The isSplitable interface

Whether files can be split by the HDFS block size. If this returns true, each PartitionedFile obtained below is one block in size.

  override def isSplitable(
                            sparkSession: SparkSession,
                            options: Map[String, String],
                            path: Path): Boolean = {
    true
  }
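
Because this returns true, one physical file may arrive as several PartitionedFile splits, each with its own start offset and length, so the reader built in buildReader should only emit the rows that belong to its own byte range. A minimal sketch of that guard (illustration only, not the connector's actual logic):

```
// Inside the (file: PartitionedFile) => Iterator[InternalRow] function
// returned by buildReader:
val splitStart = file.start               // byte offset where this split begins
val splitEnd   = file.start + file.length // exclusive end offset of this split
// Only data blocks that start inside [splitStart, splitEnd) should be read,
// so that two splits of the same file never produce the same row twice.
```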

The buildReader interface

This method returns a function: the returned function takes a PartitionedFile, i.e. one file split with a start offset and a length, processes that split, and returns the data as an Iterator[InternalRow]. The main work is implementing hasNext and next, handing the data to Spark SQL one row at a time according to the table schema.

A few things to note about the parameters:

  • requiredSchema
spark.read.xxx(path).createOrReplaceTempView("table_name")
spark.sql("select * from table_name").show()

In this case requiredSchema contains all the columns of the table.

spark.read.xxx(path).createOrReplaceTempView("table_name")
spark.sql("select * from table_name where c1 = 1").count()

In this case requiredSchema contains only the columns that appear in the where clause.

spark.read.xxx(path).createOrReplaceTempView("table_name")
spark.sql("select * from table_name").count()

In this case requiredSchema is empty. Spark does not actually fetch any data for the final count; it only calls hasNext and next count times. One optimization is to skip reading real data and simply produce count empty InternalRows, which is much faster (see the sketch after the buildReader code below).

  • partitionSchema

When the directory layout is as follows, where baseFolder, a=1 and a=2 are all directory names:
baseFolder/a=1/file1, baseFolder/a=2/file2

When the baseFolder path is passed in as path, partitionSchema contains the column a with type IntegerType. A column a is automatically added to the table, with value 1 or 2, as in the sketch below.
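
For example (the paths are illustrative, and the read call uses the tsfile implicit registered at the end of this document):

```
// Layout: baseFolder/a=1/file1.tsfile, baseFolder/a=2/file2.tsfile
val df = spark.read.tsfile("hdfs:///baseFolder")
df.printSchema()                  // includes an extra column "a" of IntegerType
df.select("a").distinct().show()  // values 1 and 2
```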

  override def buildReader(
                            sparkSession: SparkSession,
                            dataSchema: StructType,
                            partitionSchema: StructType,
                            requiredSchema: StructType,
                            filters: Seq[Filter],
                            options: Map[String, String],
                            hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
    val broadcastedConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

    (file: PartitionedFile) => {
      val log = LoggerFactory.getLogger(classOf[DefaultSource])
      log.info(file.toString())

      val conf = broadcastedConf.value.value
      val in = new HDFSInputStream(new Path(new URI(file.filePath)), conf)

      // Log the task info and the file split info
      Option(TaskContext.get()).foreach { taskContext => {
        taskContext.addTaskCompletionListener { _ => in.close() }
        log.info("task Id: " + taskContext.taskAttemptId() + " partition Id: " + taskContext.partitionId())
      }
      }


     
      var curRecord: Record = null

      // Build the Iterator, overriding hasNext and next
      new Iterator[InternalRow] {
        private val rowBuffer = Array.fill[Any](requiredSchema.length)(null)

        private val safeDataRow = new GenericRow(rowBuffer)

        // Used to convert a `Row` into an `InternalRow`
        private val encoderForDataColumns = RowEncoder(requiredSchema)

        private var deltaObjectId = "null"

        override def hasNext: Boolean = {
          // The real implementation checks whether the TsFile query for this
          // split still has records to return; this excerpt simply stops.
          false
        }

        override def next(): InternalRow = {
          // Fill every field of rowBuffer from the current record, e.g.
          //   rowBuffer(index) = columns(columnIndex)
          // (the TsFile-specific value extraction is omitted in this excerpt)

          // Convert safeDataRow, which wraps rowBuffer, into an InternalRow
          encoderForDataColumns.toRow(safeDataRow)
        }
      }
    }
  }
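
As noted above for the empty-requiredSchema count() case, one possible optimization is to skip reading values entirely and emit the right number of empty rows. A sketch under that assumption (emptyRowIterator and the record count lookup are hypothetical, not the connector's actual code):

```
import org.apache.spark.sql.catalyst.InternalRow

// Hypothetical helper: iterator returned for a count(*)-style query, where
// requiredSchema is empty. Spark only calls hasNext/next and never looks at
// the row contents, so empty rows are enough. numRecords would come from the
// TsFile metadata of this split.
def emptyRowIterator(numRecords: Long): Iterator[InternalRow] = {
  val emptyRow = InternalRow.empty
  (0L until numRecords).iterator.map(_ => emptyRow)
}

// Inside the (file: PartitionedFile) => Iterator[InternalRow] function:
//   if (requiredSchema.isEmpty) emptyRowIterator(recordCountOfThisSplit)
//   else /* the record-by-record iterator shown above */
```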

Alias

  override def shortName(): String = "tsfile"
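
The short name lets callers write format("tsfile") instead of the full package name. Spark discovers DataSourceRegister implementations through Java's ServiceLoader, so the alias only takes effect if DefaultSource is listed in a META-INF/services/org.apache.spark.sql.sources.DataSourceRegister file on the classpath (whether this project ships that file is not shown here). Usage, with illustrative paths:

```
// The full provider name always works:
spark.read.format("cn.edu.tsinghua.tsfile").load("hdfs:///data/a.tsfile")

// The short alias works once the DataSourceRegister service file is in place:
spark.read.format("tsfile").load("hdfs:///data/a.tsfile")
```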


Serializing the Configuration


private[tsfile] object DefaultSource {
  // This key must be "path"; it matches the file path passed to the read API and cannot be given a different name
  val path = "path"
  val columnNames = new ArrayBuffer[String]()

  // This class does not need to be changed; Hadoop's Configuration is not Serializable, so this wrapper is used to serialize it
  class SerializableConfiguration(@transient var value: Configuration) extends Serializable {
    private def writeObject(out: ObjectOutputStream): Unit = {
      out.defaultWriteObject()
      value.write(out)
    }

    private def readObject(in: ObjectInputStream): Unit = {
      value = new Configuration(false)
      value.readFields(in)
    }
  }

}
Registering the implicit methods

package object tsfile {

  /**
    * add a method 'tsfile' to DataFrameReader to read tsfile
    */
  implicit class TsFileDataFrameReader(reader: DataFrameReader) {
    def tsfile: String => DataFrame = reader.format("cn.edu.tsinghua.tsfile").load
  }

  /**
    * add a method 'tsfile' to DataFrameWriter to write tsfile
    */
  implicit class TsFileDataFrameWriter[T](writer: DataFrameWriter[T]) {
    def tsfile: String => Unit = writer.format("cn.edu.tsinghua.tsfile").save
  }
}

Usage example:

```
val df = spark.read.tsfile("hdfs:")
```
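
The DataFrameWriter implicit is used the same way (the paths here are purely illustrative):

```
val df = spark.read.tsfile("hdfs:///data/input.tsfile")
df.write.tsfile("hdfs:///data/output.tsfile")
```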

Full code:

https://github.com/thulab/tsfile-spark-connector