Spark SQL exposes an external data source interface specifically for reading files on HDFS; spark-parquet, csv, json and the like are all implemented this way.
DefaultSource
The entry class that establishes the external data source connection. Spark SQL looks for this class name by default, so do not rename it. Almost all of the interfaces live in this class.
private[tsfile] class DefaultSource extends FileFormat with DataSourceRegister {

  class TSFileDataSourceException(message: String, cause: Throwable)
    extends Exception(message, cause) {
    def this(message: String) = this(message, null)
  }

  override def equals(other: Any): Boolean = other match {
    case _: DefaultSource => true
    case _ => false
  }
The inferSchema interface
Returns the table schema for the given files.
Parameter notes:
files holds the files being connected; the method returns the schema of one Spark SQL table.
Using xxx to stand for the file suffix, different call styles lead to different contents of files:
- Wildcard path
read.xxx("hdfs:///data/*/*.xxx")
Here files contains every file matched by the pattern.
- Directory
read.xxx("hdfs:///data/")
Here files contains all files under the /data directory.
- Single file
read.xxx("hdfs:///data/a.xxx")
Here files contains only the single file a.xxx.
  // Returns the table schema: a StructType holding a list of StructField, one field per column
  override def inferSchema(
      spark: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = {
    val conf = spark.sparkContext.hadoopConfiguration
    // check that the path option is given
    options.getOrElse(DefaultSource.path, throw new TSFileDataSourceException(
      s"${DefaultSource.path} must be specified for cn.edu.thu.tsfile DataSource"))
    val fields = new ListBuffer[StructField]()
    fields += StructField(SQLConstant.RESERVED_TIME, LongType, nullable = false)
    val tsfileSchema = SchemaType(StructType(fields.toList), nullable = false)
    tsfileSchema.dataType match {
      case t: StructType => Some(t)
      case _ => throw new RuntimeException(
        s"""TSFile schema cannot be converted to a Spark SQL StructType:
           |${tsfileSchema.toString}
           |""".stripMargin)
    }
  }
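The snippet above only registers the reserved time column. In the full implementation, the measurement columns discovered in files are appended in the same way. A minimal sketch of that step, assuming the measurements have already been collected into (name, Spark type) pairs; the column names below are made up, and the code that actually scans the TsFile metadata is omitted:
```
import scala.collection.mutable.ListBuffer
import org.apache.spark.sql.types._

// Assumed input: measurement names and their Spark SQL types, e.g. gathered by
// scanning the metadata of every file in `files` (not shown here).
val measurements: Seq[(String, DataType)] =
  Seq("device_1.sensor_1" -> FloatType, "device_1.sensor_2" -> IntegerType)

val fields = new ListBuffer[StructField]()
fields += StructField("time", LongType, nullable = false) // SQLConstant.RESERVED_TIME
measurements.foreach { case (name, dataType) =>
  fields += StructField(name, dataType, nullable = true)  // a measurement may be absent in some rows
}
val schema = StructType(fields.toList)
```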
The isSplitable interface
Decides whether a file may be split by the HDFS block size. If it returns true, each PartitionedFile handed to the reader below covers at most one block.
  override def isSplitable(
      sparkSession: SparkSession,
      options: Map[String, String],
      path: Path): Boolean = {
    true
  }
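For intuition, here is a self-contained illustration of what splitting means; the 128 MB block size and 300 MB file length are made-up numbers:
```
// Illustration only: a 300 MB file with a 128 MB block size is handed to
// buildReader as three slices, each described by a (start offset, length) pair.
val blockSize  = 128L * 1024 * 1024
val fileLength = 300L * 1024 * 1024
val slices = (0L until fileLength by blockSize).map { start =>
  (start, math.min(blockSize, fileLength - start))
}
// slices: Vector((0,134217728), (134217728,134217728), (268435456,46137344))
```
The reader returned by buildReader is then responsible for emitting only the rows that belong to its own slice.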
The buildReader interface
This method returns a function. The returned function takes a PartitionedFile, i.e. one slice of a file with a start offset and a length; it processes that slice and returns the data as an Iterator[InternalRow]. The main work is implementing hasNext and next so that rows are handed to Spark SQL one at a time, following the table schema.
Things to note about the parameters:
- requiredSchema
spark.read.XXX(path).createOrReplaceTempView("table_name")
spark.sql("select * from table_name").show()
Here requiredSchema contains every column of the table.
spark.read.XXX(path).createOrReplaceTempView("table_name")
spark.sql("select * from table_name where c1 = 1").count()
Here requiredSchema contains only the columns that appear after where.
spark.read.XXX(path).createOrReplaceTempView("table_name")
spark.sql("select * from table_name").count()
Here requiredSchema is empty. In fact Spark never fetches any data for the final count; it only calls hasNext and next enough times to obtain the count. One option is to skip reading real data and simply produce that many empty InternalRows, which is much faster (see the sketch after this list).
- partitionSchema
When the directory layout is as follows (baseFolder, a=1 and a=2 are all directory names):
baseFolder/a=1/file1, baseFolder/a=2/file2
and the baseFolder path is passed in as path, partitionSchema contains the column a of IntegerType; a column a is automatically added to the table, with value 1 or 2.
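As noted in the requiredSchema item above, when no columns are required (a bare count()) the reader does not have to decode any data. A minimal sketch of that optimisation, to be placed inside the reader function below; countRecords is a hypothetical helper that only counts the records in the slice without decoding them:
```
// Sketch only: when no columns are requested, emit `count` empty rows instead of real data.
if (requiredSchema.isEmpty) {
  val count: Long = countRecords(in) // hypothetical helper, not part of the real API
  new Iterator[InternalRow] {
    private var remaining = count
    override def hasNext: Boolean = remaining > 0
    override def next(): InternalRow = { remaining -= 1; InternalRow.empty }
  }
} else {
  // fall through to the normal row-by-row reader shown below
  ???
}
```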
  override def buildReader(
      sparkSession: SparkSession,
      dataSchema: StructType,
      partitionSchema: StructType,
      requiredSchema: StructType,
      filters: Seq[Filter],
      options: Map[String, String],
      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
    val broadcastedConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
    (file: PartitionedFile) => {
      val log = LoggerFactory.getLogger(classOf[DefaultSource])
      log.info(file.toString())
      val conf = broadcastedConf.value.value
      val in = new HDFSInputStream(new Path(new URI(file.filePath)), conf)
      // log the task information and the file-slice information
      Option(TaskContext.get()).foreach { taskContext =>
        taskContext.addTaskCompletionListener { _ => in.close() }
        log.info("task Id: " + taskContext.taskAttemptId() + " partition Id: " + taskContext.partitionId())
      }
      var curRecord: Record = null
      // build the Iterator and override hasNext and next
      new Iterator[InternalRow] {
        private val rowBuffer = Array.fill[Any](requiredSchema.length)(null)
        private val safeDataRow = new GenericRow(rowBuffer)
        // used to convert a `Row` into an `InternalRow`
        private val encoderForDataColumns = RowEncoder(requiredSchema)
        private var deltaObjectId = "null"

        override def hasNext: Boolean = {
          // placeholder: return true while the current slice still has records to emit
          false
        }

        override def next(): InternalRow = {
          // fill every field of rowBuffer from the current record, e.g.
          // rowBuffer(index) = columns(columnIndex) for each required column,
          // then convert safeDataRow, which wraps rowBuffer, into an InternalRow
          encoderForDataColumns.toRow(safeDataRow)
        }
      }
    }
  }
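For completeness, a sketch of what a filled-in next() could look like, assuming (hypothetically) that the values of the current record are available as a Map from column name to value; the real TsFile Record API differs and is not shown here:
```
override def next(): InternalRow = {
  // Hypothetical view of the current record: column name -> value.
  val curValues: Map[String, Any] = Map("time" -> 42L /* plus the measurement values */)
  // Fill rowBuffer in the order dictated by requiredSchema; absent columns stay null.
  requiredSchema.fieldNames.zipWithIndex.foreach { case (name, i) =>
    rowBuffer(i) = curValues.getOrElse(name, null)
  }
  // safeDataRow wraps rowBuffer, so converting it yields the finished InternalRow.
  encoderForDataColumns.toRow(safeDataRow)
}
```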
Short name (alias)
  override def shortName(): String = "tsfile"
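With the short name registered (Spark discovers DataSourceRegister implementations through a META-INF/services/org.apache.spark.sql.sources.DataSourceRegister file), the alias can be used in place of the full class name; the path below is only illustrative:
```
val df = spark.read.format("tsfile").load("hdfs:///data/a.tsfile")
```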
Serializing the Configuration
private[tsfile] object DefaultSource {
  // this key must be "path"; it corresponds to the file path passed to the read call and must not be renamed
  val path = "path"
  val columnNames = new ArrayBuffer[String]()

  // this class does not need to be changed; it makes the Hadoop Configuration serializable
  class SerializableConfiguration(@transient var value: Configuration) extends Serializable {
    private def writeObject(out: ObjectOutputStream): Unit = {
      out.defaultWriteObject()
      value.write(out)
    }

    private def readObject(in: ObjectInputStream): Unit = {
      value = new Configuration(false)
      value.readFields(in)
    }
  }
}
- Registering the implicit methods
package object tsfile {

  /**
   * add a method 'tsfile' to DataFrameReader to read tsfile
   */
  implicit class TsFileDataFrameReader(reader: DataFrameReader) {
    def tsfile: String => DataFrame = reader.format("cn.edu.tsinghua.tsfile").load
  }

  /**
   * add a method 'tsfile' to DataFrameWriter to write tsfile
   */
  implicit class TsFileDataFrameWriter[T](writer: DataFrameWriter[T]) {
    def tsfile: String => Unit = writer.format("cn.edu.tsinghua.tsfile").save
  }
}
Usage example:
```
val df = spark.read.tsfile("hdfs:")
```
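The DataFrameWriter implicit works the same way; an illustrative write call (the output path is hypothetical):
```
df.write.tsfile("hdfs:///output/out.tsfile")
```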
Complete code: