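While using Spark with Parquet in practice, two things puzzled me:
- We had a single Parquet file (smaller than the HDFS block size), yet Spark created 4 tasks in one stage to process it.
- Of those 4 tasks, only one processed all the data; the others processed nothing.
Both puzzles come down to how Spark cuts a Parquet file into partitions, and which part of the data each partition ends up handling.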
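The conclusion first: in Spark, Parquet is splitable; see ParquetFileFormat#isSplitable. Does that mean the data gets chopped into fragments? No, because the smallest unit a Parquet file is split by is the Parquet row group. A side effect is that some partitions end up with no data. In the extreme case where the file contains a single row group, only one partition gets data, no matter how many partitions there are.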
Now let's start our tour of the source code:
Processing flow
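1. Generate partitions by cutting the Parquet file into chunks by file size:
In FileSourceScanExec#createNonBucketedReadRDD, if the file is splitable, it is cut into chunks of at most maxSplitBytes. The number of partitions produced at the end is the number of RDD partitions, which explains the first puzzle. The code: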
val maxSplitBytes = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
  s"open cost is considered as scanning $openCostInBytes bytes.")

val splitFiles = selectedPartitions.flatMap { partition =>
  partition.files.flatMap { file =>
    val blockLocations = getBlockLocations(file)
    if (fsRelation.fileFormat.isSplitable(
        fsRelation.sparkSession, fsRelation.options, file.getPath)) {
      (0L until file.getLen by maxSplitBytes).map { offset =>
        val remaining = file.getLen - offset
        val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
        val hosts = getBlockHosts(blockLocations, offset, size)
        PartitionedFile(
          partition.values, file.getPath.toUri.toString, offset, size, hosts)
      }
    } else {
      val hosts = getBlockHosts(blockLocations, 0, file.getLen)
      Seq(PartitionedFile(
        partition.values, file.getPath.toUri.toString, 0, file.getLen, hosts))
    }
  }
}.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)

val partitions = new ArrayBuffer[FilePartition]
val currentFiles = new ArrayBuffer[PartitionedFile]
var currentSize = 0L

/** Close the current partition and move to the next. */
def closePartition(): Unit = {
  if (currentFiles.nonEmpty) {
    val newPartition =
      FilePartition(
        partitions.size,
        currentFiles.toArray.toSeq) // Copy to a new Array.
    partitions += newPartition
  }
  currentFiles.clear()
  currentSize = 0
}

// Assign files to partitions using "First Fit Decreasing" (FFD)
splitFiles.foreach { file =>
  if (currentSize + file.length > maxSplitBytes) {
    closePartition()
  }
  // Add the given file to the current partition.
  currentSize += file.length + openCostInBytes
  currentFiles += file
}
closePartition()

new FileScanRDD(fsRelation.sparkSession, readFile, partitions)
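To make the arithmetic concrete, here is a minimal plain-Scala sketch of the same three steps: compute maxSplitBytes, cut a file at fixed offsets, then pack the chunks into partitions. It does not use Spark. SplitPlanningSketch, Chunk and the function names are hypothetical, and the way bytesPerCore is derived (total file bytes, each file padded by the open cost, divided by the default parallelism) is not shown in the excerpt above; it is an assumption based on my reading of the Spark 2.x source.

```scala
import scala.collection.mutable.ArrayBuffer

object SplitPlanningSketch {
  // One contiguous byte range of a file; plays the role of Spark's PartitionedFile.
  final case class Chunk(path: String, start: Long, length: Long)

  // Same min/max formula as in the excerpt. The bytesPerCore derivation is an
  // assumption: (sum of file sizes, each padded by the open cost) / parallelism.
  def maxSplitBytes(
      fileSizes: Seq[Long],
      defaultMaxSplitBytes: Long,
      openCostInBytes: Long,
      parallelism: Int): Long = {
    val totalBytes = fileSizes.map(_ + openCostInBytes).sum
    val bytesPerCore = totalBytes / parallelism
    math.min(defaultMaxSplitBytes, math.max(openCostInBytes, bytesPerCore))
  }

  // Cut a splitable file into chunks of at most maxSplit bytes, by offset only.
  def splitFile(path: String, fileLen: Long, maxSplit: Long): Seq[Chunk] =
    (0L until fileLen by maxSplit).map { offset =>
      Chunk(path, offset, math.min(maxSplit, fileLen - offset))
    }

  // Pack chunks, largest first, closing the current partition whenever the
  // next chunk would push it past maxSplit.
  def pack(chunks: Seq[Chunk], maxSplit: Long, openCost: Long): Seq[Seq[Chunk]] = {
    val partitions = ArrayBuffer.empty[Seq[Chunk]]
    val current = ArrayBuffer.empty[Chunk]
    var currentSize = 0L

    def close(): Unit = {
      if (current.nonEmpty) partitions += current.toList
      current.clear()
      currentSize = 0L
    }

    chunks.sortBy(c => -c.length).foreach { c =>
      if (currentSize + c.length > maxSplit) close()
      currentSize += c.length + openCost
      current += c
    }
    close()
    partitions.toList
  }
}
```

With illustrative numbers (one 40 MB file, a 128 MB maximum split size, a 4 MB open cost, and a default parallelism of 4), maxSplitBytes works out to 11 MB. The file is cut into chunks of 11, 11, 11 and 7 MB, and each chunk lands in its own partition, so there are 4 tasks even though the file is smaller than an HDFS block.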
2. Build the reader from a ParquetInputSplit:
In the implementation of ParquetFileFormat#buildReaderWithPartitionValues, the split is used to initialize the reader. Depending on configuration, the reader is either vectorized or not:
vectorizedReader.initialize(split, hadoopAttemptContext)
reader.initialize(split, hadoopAttemptContext)
The aside at the end of this post has more detailed code for step 2, but it has little to do with the main flow here, so we skip it for now.
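3. Assign the Parquet row groups to the different partitions
Step 1 divided the file evenly into partitions by size, but not all of those partitions end up with data.
Picking up from the initialize call in step 2: SpecificParquetRecordReaderBase#initialize passes a RangeMetadataFilter to readFooter. The range of this filter comes from the boundaries of your split, and this range is what finally decides which split each row group belongs to: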
public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
    throws IOException, InterruptedException {
  ...
  footer = readFooter(configuration, file, range(inputSplit.getStart(), inputSplit.getEnd()));
  ...
}
Parquet's ParquetFileReader#readFooter calls converter.readParquetMetadata(f, filter), where converter is a ParquetMetadataConverter. For a RangeMetadataFilter, readParquetMetadata does the following:
@Override
public FileMetaData visit(RangeMetadataFilter filter) throws IOException {
  return filterFileMetaDataByMidpoint(readFileMetaData(from), filter);
}
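We have finally reached the place where the split really happens. The key is the snippet below: whichever split contains the midpoint of a row group gets to process that row group.
Suppose we have a 40 MB file with a single row group, cut every 10 MB. There will be 4 partitions, but only one of them contains the midpoint of the row group, so only that partition will have data.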
long midPoint = startIndex + totalSize / 2;
if (filter.contains(midPoint)) {
  newRowGroups.add(rowGroup);
}
The complete code:
static FileMetaData filterFileMetaDataByMidpoint(FileMetaData metaData, RangeMetadataFilter filter) {
  List<RowGroup> rowGroups = metaData.getRow_groups();
  List<RowGroup> newRowGroups = new ArrayList<RowGroup>();
  for (RowGroup rowGroup : rowGroups) {
    long totalSize = 0;
    long startIndex = getOffset(rowGroup.getColumns().get(0));
    for (ColumnChunk col : rowGroup.getColumns()) {
      totalSize += col.getMeta_data().getTotal_compressed_size();
    }
    long midPoint = startIndex + totalSize / 2;
    if (filter.contains(midPoint)) {
      newRowGroups.add(rowGroup);
    }
  }
  metaData.setRow_groups(newRowGroups);
  return metaData;
}
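The midpoint rule is easy to try out without Parquet. The sketch below is a plain-Scala simulation, not parquet-mr code: MidpointOwnershipSketch, RowGroup, Split and assign are hypothetical names, and treating a split as the half-open range [start, end) is my assumption about how RangeMetadataFilter#contains behaves.

```scala
object MidpointOwnershipSketch {
  // A row group described only by where it starts and how many compressed bytes it spans.
  final case class RowGroup(startOffset: Long, compressedSize: Long) {
    // Same formula as in filterFileMetaDataByMidpoint.
    def midPoint: Long = startOffset + compressedSize / 2
  }

  // Byte range of one split; half-open [start, end) is an assumption.
  final case class Split(start: Long, end: Long) {
    def contains(offset: Long): Boolean = offset >= start && offset < end
  }

  // For every split, keep only the row groups whose midpoint falls inside it.
  def assign(rowGroups: Seq[RowGroup], splits: Seq[Split]): Seq[Seq[RowGroup]] =
    splits.map(split => rowGroups.filter(rg => split.contains(rg.midPoint)))
}
```

Take four 10 MB splits over a 40 MB file. A single row group that starts right after the 4-byte file header and spans almost the whole file has its midpoint just below 20 MB, so only the second split owns it and the other three are empty. If the same file instead held four row groups of roughly 10 MB each, every split would own exactly one.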
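Aside:
The code in step 2 is where Spark actually reads the file. The function it returns is eventually called by FileScanRDD, so it is worth a look along the way. The complete code: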
(file: PartitionedFile) => {
  assert(file.partitionValues.numFields == partitionSchema.size)

  val fileSplit =
    new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty)

  val split =
    new org.apache.parquet.hadoop.ParquetInputSplit(
      fileSplit.getPath,
      fileSplit.getStart,
      fileSplit.getStart + fileSplit.getLength,
      fileSplit.getLength,
      fileSplit.getLocations,
      null)

  val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
  val hadoopAttemptContext =
    new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId)

  // Try to push down filters when filter push-down is enabled.
  // Notice: This push-down is RowGroups level, not individual records.
  if (pushed.isDefined) {
    ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get)
  }
  val parquetReader = if (enableVectorizedReader) {
    val vectorizedReader = new VectorizedParquetRecordReader()
    vectorizedReader.initialize(split, hadoopAttemptContext)
    logDebug(s"Appending $partitionSchema ${file.partitionValues}")
    vectorizedReader.initBatch(partitionSchema, file.partitionValues)
    if (returningBatch) {
      vectorizedReader.enableReturningBatches()
    }
    vectorizedReader
  } else {
    logDebug(s"Falling back to parquet-mr")
    // ParquetRecordReader returns UnsafeRow
    val reader = pushed match {
      case Some(filter) =>
        new ParquetRecordReader[UnsafeRow](
          new ParquetReadSupport,
          FilterCompat.get(filter, null))
      case _ =>
        new ParquetRecordReader[UnsafeRow](new ParquetReadSupport)
    }
    reader.initialize(split, hadoopAttemptContext)
    reader
  }

  val iter = new RecordReaderIterator(parquetReader)
  Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))

  // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy.
  if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] &&
      enableVectorizedReader) {
    iter.asInstanceOf[Iterator[InternalRow]]
  } else {
    val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
    val joinedRow = new JoinedRow()
    val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema)

    // This is a horrible erasure hack... if we type the iterator above, then it actually check
    // the type in next() and we get a class cast exception. If we make that function return
    // Object, then we can defer the cast until later!
    if (partitionSchema.length == 0) {
      // There is no partition columns
      iter.asInstanceOf[Iterator[InternalRow]]
    } else {
      iter.asInstanceOf[Iterator[InternalRow]]
        .map(d => appendPartitionColumns(joinedRow(d, file.partitionValues)))
    }
  }
}
The returned (PartitionedFile) => Iterator[InternalRow] is used in FileSourceScanExec#inputRDD:
private lazy val inputRDD: RDD[InternalRow] = {
  val readFile: (PartitionedFile) => Iterator[InternalRow] =
    relation.fileFormat.buildReaderWithPartitionValues(
      sparkSession = relation.sparkSession,
      dataSchema = relation.dataSchema,
      partitionSchema = relation.partitionSchema,
      requiredSchema = requiredSchema,
      filters = pushedDownFilters,
      options = relation.options,
      hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))

  relation.bucketSpec match {
    case Some(bucketing) if relation.sparkSession.sessionState.conf.bucketingEnabled =>
      createBucketedReadRDD(bucketing, readFile, selectedPartitions, relation)
    case _ =>
      createNonBucketedReadRDD(readFile, selectedPartitions, relation)
  }
}
FileScanRDD
class FileScanRDD(
    @transient private val sparkSession: SparkSession,
    readFunction: (PartitionedFile) => Iterator[InternalRow],
    @transient val filePartitions: Seq[FilePartition])
  extends RDD[InternalRow](sparkSession.sparkContext, Nil) {

  override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = {
    private[this] val files = split.asInstanceOf[FilePartition].files.toIterator
    // Assigned via currentFile = files.next(); the details are omitted here,
    // have a look at the source if you are interested.
    private[this] var currentFile: PartitionedFile = null
    ...
    readFunction(currentFile)
    ...
  }
}
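Stripped of everything else the real class does, the core of compute is to walk the chunks of one partition and chain the iterators that readFunction returns. A minimal plain-Scala sketch of that idea, assuming nothing about Spark (FileChainingSketch and FileChunk are hypothetical names):

```scala
object FileChainingSketch {
  // Stand-in for PartitionedFile: a byte range of one file.
  final case class FileChunk(path: String, start: Long, length: Long)

  // Lazily chain the per-chunk iterators. readFunction is only invoked for a
  // chunk once the iterator of the previous chunk has been exhausted.
  def compute[T](files: Seq[FileChunk], readFunction: FileChunk => Iterator[T]): Iterator[T] =
    files.iterator.flatMap(readFunction)
}
```

A chunk that owns no row group simply yields an empty iterator here, which is how a task can exist and still process no rows.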
Conclusion
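A split only gets data if it owns at least one row group, so the number of row groups in a Parquet file caps how many tasks can do useful work on it. To raise Spark's parallelism, tune the row-group size threshold at write time so that a single Parquet file contains more row groups.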
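One way to experiment with this is to write a file with a smaller row-group size and then count the rows each partition reads. The sketch below assumes a local Spark session and that parquet-mr's parquet.block.size setting (the row-group size in bytes) is picked up from the Hadoop configuration. The output path, the 8 MB value and the generated data are hypothetical, and whether the file really ends up with several row groups depends on how much data is written.

```scala
import org.apache.spark.sql.SparkSession

object RowGroupSizeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("row-group-size-sketch")
      .getOrCreate()

    // Ask parquet-mr for smaller row groups (illustrative value, in bytes).
    spark.sparkContext.hadoopConfiguration.setInt("parquet.block.size", 8 * 1024 * 1024)

    // Hypothetical output location and synthetic data, written as a single file.
    val outputPath = "/tmp/row-group-size-sketch"
    spark.range(0L, 20000000L).toDF("id")
      .coalesce(1)
      .write.mode("overwrite").parquet(outputPath)

    // Count how many rows each read partition actually receives.
    val rowsPerPartition = spark.read.parquet(outputPath).rdd
      .mapPartitionsWithIndex((idx, rows) => Iterator((idx, rows.size)))
      .collect()

    rowsPerPartition.foreach { case (idx, n) => println(s"partition $idx: $n rows") }
    spark.stop()
  }
}
```

Partitions that report 0 rows are the ones whose byte range contains no row-group midpoint.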