SortShuffleWriter
Overview
SortShuffleWriter first determines whether a map-side combine is required. If map-side aggregation is needed, it uses a PartitionedAppendOnlyMap; otherwise it uses a PartitionedPairBuffer to accumulate records in memory. In either case it checks whether enough memory is available; if memory runs short and no additional memory can be acquired, the in-memory data is spilled to temporary files on local disk. Finally, the in-memory data and the spilled temporary files are merged, performing a merge sort when spills occurred (no merge sort is needed when everything stayed in memory), and the merged result is written out as a data file plus an index file.
The write method
The method works as follows:
1. Create an ExternalSorter; whether the aggregator and keyOrdering arguments are passed to it depends only on whether a map-side combine is required.
2. Call insertAll on the ExternalSorter instance to insert the records. If the in-memory collection that the sorter uses to hold records reaches its size threshold, the records are spilled to a disk file in sorted order.
3. Build the final output file, whose name is "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId (with reduceId fixed to 0).
4. Append a UUID to the output file name to mark it as being written; the file is renamed when writing finishes.
5. Call writePartitionedFile on the ExternalSorter instance to sort the records inserted into the sorter and write them to the output file. Those records may reside in the in-memory collection or in spill files.
6. Write each partition's offset into the index file so that the reduce side can fetch its data.
7. Wrap the relevant information into a MapStatus and return it.
/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
new ExternalSorter[K, V, C](
context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
// if the operation being run is sortByKey.
new ExternalSorter[K, V, V](
context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
sorter.insertAll(records)
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
// Build the final output file; its name (with reduceId = 0) is
// "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
//Append a UUID to the output file name to mark that the file is being written; it is renamed when writing finishes
val tmp = Utils.tempFileWith(output)
try {
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
//Write the sorted records to the output file
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
//Write each partition's offset into the index file so that the reduce side can fetch its data
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
} finally {
if (tmp.exists() && !tmp.delete()) {
logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
}
}
}
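The partitionLengths array returned by writePartitionedFile is what writeIndexFileAndCommit turns into the index file. As a hedged sketch (assuming the index file conceptually stores the cumulative offsets of the partition lengths; IndexOffsetsExample and the lengths below are made up, not from a real run), each partition's byte range in the data file can be derived like this:
// Illustration only; not the actual IndexShuffleBlockResolver code.
object IndexOffsetsExample {
  def main(args: Array[String]): Unit = {
    val partitionLengths = Array(120L, 0L, 300L, 45L)   // per-partition lengths from writePartitionedFile
    val offsets = partitionLengths.scanLeft(0L)(_ + _)  // 0, 120, 120, 420, 465
    println(offsets.mkString(", "))
    // Partition i occupies bytes [offsets(i), offsets(i + 1)) of the data file,
    // which is how the reduce side locates the partition it needs to fetch.
  }
}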
ExternalSorter
Overview
ExternalSorter sorts, and possibly combines, a large number of (K, V) pairs to produce pairs of type (K, C), where C is the combiner type. It uses a Partitioner to group keys into partitions, and then optionally sorts the keys within each partition with a custom Comparator. Finally, it writes each partition to a single output file, with each partition occupying a distinct byte range, so that the shuffle fetch can read it.
If combining is disabled, the type C must equal V; the objects are simply cast at the end.
Note: although ExternalSorter is a fairly generic sorter, some of its configuration is tied to its use in sort-based shuffle. For example, block compression is controlled by "spark.shuffle.compress". If ExternalSorter were used in a non-shuffle context, it might be worth revisiting this class with different configuration settings.
The important constructor parameters are:
@param aggregator optional; an Aggregator with combine functions used to merge data
@param partitioner optional; if given, records are sorted first by partition ID and then by key
@param ordering optional; a total ordering used to sort keys within each partition
@param serializer the serializer to use when spilling to disk
Note that if an ordering is given, it is always used for sorting, so only provide it if you actually want the output keys to be sorted. In a map task without map-side aggregation, you probably want to pass None as the ordering to avoid unnecessary sorting. On the other hand, if you do want combining, providing an ordering is more efficient than not providing one.
Callers interact with this class as follows:
- Instantiate an ExternalSorter;
- Call insertAll on the instance to insert a batch of records;
- Call iterator() to iterate over the sorted or aggregated records, or call writePartitionedFile() (in sort shuffle) to write the sorted or aggregated records to an output file.
Internally, the class works as follows:
Records are repeatedly filled into an in-memory buffer: a PartitionedAppendOnlyMap when they need to be combined by key, or a PartitionedPairBuffer when they do not. Inside these buffers, elements are sorted by partition ID and possibly also by key. To avoid calling the partitioner many times per key, the partition ID is stored alongside each record.
When a buffer reaches the memory limit, it is spilled to a file. That file is sorted first by partition ID and then, if aggregation is needed, by key or by the key's hash code. For each file, we track how many objects each partition contains, so a partition ID does not need to be written for every element.
When the user requests an iterator or a file output, the spill files are merged together with the remaining in-memory data, using the sort order defined above (unless both sorting and aggregation are disabled). If aggregation by key is needed, we either use the total ordering from the ordering parameter, or read keys with the same hash code and compare them for real equality in order to merge their values.
Users are expected to call stop() at the end to delete all intermediate files.
ExternalSorter's parent class
Spillable is the parent class of ExternalSorter, and Spillable itself is a subclass of MemoryConsumer.
The Spillable class spills the contents of an in-memory collection to disk when its memory usage exceeds a threshold.
Here, the in-memory collection refers to the PartitionedAppendOnlyMap or PartitionedPairBuffer data structure.
Member variables
- serializerBatchSize: the number of objects per batch when reading objects from, or writing them to, a serializer. Objects are written in batches, with each batch using its own serialization stream; this reduces the size of the reference-tracking map built while deserializing a stream. Note that setting this value too low causes excessive copying during serialization, since some serializers grow their internal data structures by growing + copying every time the number of objects doubles.
- map (PartitionedAppendOnlyMap) and buffer (PartitionedPairBuffer): the in-memory collections that hold records before spilling. Which one is used depends on whether aggregation is needed: with map-side aggregation, records go into the PartitionedAppendOnlyMap, where they are combined; otherwise they go into the PartitionedPairBuffer.
- keyComparator: a comparator for keys of type K that orders them within a partition to allow aggregation or sorting. If no comparator is supplied through the ordering parameter, a default comparator based on hash codes is used, which only gives a partial ordering. A partial ordering means that equal keys have comparator.compare(k, k) == 0, but some non-equal keys also do, so a later pass is needed to find the truly equal keys. (Keys that are equal by equals() must have equal hashCode(); keys with equal hashCode() are not necessarily equal by equals(), which is why comparing hash codes only yields a partial ordering.)
- spills: when the in-memory collection reaches its size threshold, its records are spilled to a disk file in sorted order. This ArrayBuffer[SpilledFile] holds the metadata of those spill files.
// Size of object batches when reading/writing from serializers.
//
// Objects are written in batches, with each batch using its own serialization stream. This
// cuts down on the size of reference-tracking maps constructed when deserializing a stream.
//
// NOTE: Setting this too low can cause excessive copying when serializing, since some serializers
// grow internal data structures by growing + copying every time the number of objects doubles.
private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
// Data structures to store in-memory objects before we spill. Depending on whether we have an
// Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
// store them in an array buffer.
@volatile private var map = new PartitionedAppendOnlyMap[K, C]
@volatile private var buffer = new PartitionedPairBuffer[K, C]
// A comparator for keys K that orders them within a partition to allow aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not provided through by the
// user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
// non-equal keys also have this, so we need to do a later pass to find truly equal keys).
// Note that we ignore this if no aggregator and no ordering are given.
private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
override def compare(a: K, b: K): Int = {
val h1 = if (a == null) 0 else a.hashCode()
val h2 = if (b == null) 0 else b.hashCode()
if (h1 < h2) -1 else if (h1 == h2) 0 else 1
}
})
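To see why this only gives a partial ordering, consider two distinct strings that happen to share a hash code. The following standalone sketch (PartialOrderingDemo is a hypothetical name, not part of Spark) reuses the comparator logic above:
// "Aa" and "BB" have the same hashCode, so the comparator treats them as equal
// even though the keys differ; mergeWithAggregation later re-checks real equality.
import java.util.Comparator

object PartialOrderingDemo {
  private val hashComparator: Comparator[String] = new Comparator[String] {
    override def compare(a: String, b: String): Int = {
      val h1 = if (a == null) 0 else a.hashCode()
      val h2 = if (b == null) 0 else b.hashCode()
      if (h1 < h2) -1 else if (h1 == h2) 0 else 1
    }
  }

  def main(args: Array[String]): Unit = {
    println(hashComparator.compare("Aa", "BB")) // 0: the hash codes are equal
    println("Aa" == "BB")                       // false: the keys themselves are not
  }
}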
// Information about a spilled file. Includes sizes in bytes of "batches" written by the
// serializer as we periodically reset its stream, as well as number of elements in each
// partition, used to efficiently keep track of partitions when merging.
private[this] case class SpilledFile(
file: File,
blockId: BlockId,
serializerBatchSizes: Array[Long],
elementsPerPartition: Array[Long])
private val spills = new ArrayBuffer[SpilledFile]
Note that keyComparator is ignored if neither an aggregator nor an ordering is given.
The comparator method returns Some(keyComparator) (the ordering if one was supplied, otherwise the default hash-code comparator) when either an ordering or an aggregator is defined; otherwise it returns None.
private def comparator: Option[Comparator[K]] = {
if (ordering.isDefined || aggregator.isDefined) {
Some(keyComparator)
} else {
None
}
}
Inserting records into ExternalSorter
The insertAll method
The method works as follows:
1. If map-side aggregation is required:
Obtain the aggregator's mergeValue and createCombiner functions and build an update function from them: if a value already exists for the key, apply mergeValue; otherwise apply createCombiner.
Iterate over the records, compute each record's partition, and call PartitionedAppendOnlyMap#changeValue with the update function.
Finally, call maybeSpillCollection to decide whether the in-memory data needs to be spilled to disk.
2. If map-side aggregation is not required:
Iterate over the records, compute each record's partition, and call PartitionedPairBuffer#insert to add the record to the buffer.
Finally, call maybeSpillCollection to decide whether the in-memory data needs to be spilled to disk.
def insertAll(records: Iterator[Product2[K, V]]): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
val shouldCombine = aggregator.isDefined
if (shouldCombine) {
// Combine values in-memory first using our AppendOnlyMap
// Get the aggregator's mergeValue function, used to merge a new value into an existing aggregate
val mergeValue = aggregator.get.mergeValue
// Get the aggregator's createCombiner function, used to create the initial value of an aggregate
val createCombiner = aggregator.get.createCombiner
var kv: Product2[K, V] = null
//Build the update function: apply mergeValue if a value already exists, otherwise createCombiner
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
}
while (records.hasNext) {
//Count one more element read (used by the spill checks)
addElementsRead()
//Take the next (key, value) pair
kv = records.next()
// Compute the key's partition, then merge the value into the map
map.changeValue((getPartition(kv._1), kv._1), update)
// Spill the in-memory data to disk if necessary
maybeSpillCollection(usingMap = true)
}
} else { // No map-side combine required
// Stick values into our buffer
while (records.hasNext) {
//Count one more element read (used by the spill checks)
addElementsRead()
// Take the next (key, value) pair
val kv = records.next()
// Add the record to the PartitionedPairBuffer
buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
// Spill the in-memory data to disk if necessary
maybeSpillCollection(usingMap = false)
}
}
}
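To make the update closure concrete, here is a hedged, simplified sketch that uses a plain mutable HashMap in place of the real PartitionedAppendOnlyMap (UpdateFunctionDemo and the sample records are made up). It shows createCombiner being applied the first time a key appears and mergeValue afterwards, as in a sum aggregation:
// Simplified illustration only; PartitionedAppendOnlyMap is partition-aware and more efficient.
import scala.collection.mutable

object UpdateFunctionDemo {
  def main(args: Array[String]): Unit = {
    val createCombiner = (v: Int) => v          // initial combiner for a new key
    val mergeValue = (c: Int, v: Int) => c + v  // merge a new value into an existing combiner
    val map = mutable.HashMap.empty[String, Int]
    val records = Seq("a" -> 1, "b" -> 2, "a" -> 3)
    for ((k, v) <- records) {
      // Plays the role of map.changeValue(k, update): combine if present, create otherwise
      map(k) = map.get(k) match {
        case Some(old) => mergeValue(old, v)
        case None => createCombiner(v)
      }
    }
    println(map) // e.g. HashMap(a -> 4, b -> 2)
  }
}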
The maybeSpillCollection method
The method works as follows:
1. If map-side aggregation is required:
Estimate the size of the map and use that estimate to decide whether to spill. If a spill happens, create a new PartitionedAppendOnlyMap afterwards.
2. If map-side aggregation is not required:
Estimate the size of the buffer and use that estimate to decide whether to spill. If a spill happens, create a new PartitionedPairBuffer afterwards.
/**
* Spill the current in-memory collection to disk if needed.
*
* @param usingMap whether we're using a map or buffer as our current in-memory collection
*/
private def maybeSpillCollection(usingMap: Boolean): Unit = {
var estimatedSize = 0L
if (usingMap) { //Using the PartitionedAppendOnlyMap
//Estimate the size of the map
estimatedSize = map.estimateSize()
//Decide whether to spill based on the estimated size of the map
if (maybeSpill(map, estimatedSize)) {
//After spilling, create a new PartitionedAppendOnlyMap
map = new PartitionedAppendOnlyMap[K, C]
}
} else { //Using the PartitionedPairBuffer
//Estimate the size of the buffer
estimatedSize = buffer.estimateSize()
//Call the parent class Spillable's maybeSpill to decide, from the estimated buffer size, whether to spill
if (maybeSpill(buffer, estimatedSize)) {
//After spilling, create a new PartitionedPairBuffer
buffer = new PartitionedPairBuffer[K, C]
}
}
if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
}
}
The maybeSpill method
maybeSpillCollection calls maybeSpill, which is defined in the parent class Spillable.
This method decides, based on the estimated collection size, whether a spill is needed; if so, it calls the spill method.
The method works as follows:
If the number of elements read is a multiple of 32 and the current memory usage is at least the threshold (5 MB by default),
it first tries to acquire (2 * currentMemory - myMemoryThreshold) bytes from the TaskMemoryManager.
If enough memory is granted that the updated threshold covers the current usage, no spill happens and records keep accumulating in the collection;
otherwise spill() is called to write the collection's contents to a disk file. A spill is also forced once the number of elements read exceeds numElementsForceSpillThreshold.
/**
* Spills the current in-memory collection to disk if needed. Attempts to acquire more
* memory before spilling.
*
* @param collection collection to spill to disk
* @param currentMemory estimated size of the collection in bytes
* @return true if `collection` was spilled to disk; false otherwise
*/
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
var shouldSpill = false
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
// Claim up to double our current memory from the shuffle memory pool
val amountToRequest = 2 * currentMemory - myMemoryThreshold
//Ultimately calls TaskMemoryManager's acquireExecutionMemory to allocate memory
val granted = acquireMemory(amountToRequest)
// Update the current memory threshold
myMemoryThreshold += granted
// If we were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold), spill the current collection
//Check again whether the current memory usage still reaches the (updated) threshold; if so, spill
shouldSpill = currentMemory >= myMemoryThreshold
}
shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
// Actually spill
if (shouldSpill) {
_spillCount += 1
logSpillage(currentMemory)
//Perform the spill
spill(collection)
_elementsRead = 0
_memoryBytesSpilled += currentMemory
releaseMemory()
}
shouldSpill
}
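As a worked example with assumed numbers (the 5 MB figure is the default spark.shuffle.spill.initialMemoryThreshold; the 6 MB collection size and the object name are made up for illustration):
// Hypothetical numbers, not taken from a real run.
object MaybeSpillArithmetic {
  def main(args: Array[String]): Unit = {
    val myMemoryThreshold = 5L * 1024 * 1024        // default initial spill threshold
    val currentMemory = 6L * 1024 * 1024            // estimated size of the in-memory collection
    val amountToRequest = 2 * currentMemory - myMemoryThreshold
    println(amountToRequest / (1024 * 1024))        // 7: MB requested from the memory manager
    val granted = amountToRequest                   // assume the full request is granted
    val newThreshold = myMemoryThreshold + granted  // 12 MB
    println(currentMemory >= newThreshold)          // false: no spill, keep buffering in memory
  }
}
If the grant had been 0 (or less than 1 MB), the updated threshold would stay below the 6 MB in use and the collection would be spilled.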
The spill method
This method spills the contents of the in-memory collection, in comparator order, to a disk file. It is called when the in-memory collection reaches its size threshold.
The method works as follows:
1. Obtain the comparator. The comparator method returns Some(keyComparator) when an ordering or aggregator is defined; otherwise it returns None.
2. Obtain an iterator over the in-memory collection, sorted by that comparator.
3. Spill the in-memory collection's data to a temporary file on disk.
4. Record the spilled temporary file in spills.
/**
* Spill our in-memory collection to a sorted file that we can merge later.
* We add this file into `spilledFiles` to find it later.
*
* @param collection whichever collection we're using (map or buffer)
*/
override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
//Return an iterator sorted by the given comparator.
//comparator returns Some(keyComparator), i.e. the ordering if one was supplied or the default hash-code comparator;
//if neither an ordering nor an aggregator is defined, it returns None.
val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
// Spill the in-memory collection's data to a temporary file on disk
val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
// Record the newly spilled temporary file
spills += spillFile
}
The spillMemoryIteratorToDisk method
1. Create a temporary file.
2. Create a DiskBlockObjectWriter for writing the temporary file.
3. Iterate over the in-memory collection via inMemoryIterator and write each record with the DiskBlockObjectWriter. Whenever the number of records written reaches the batch-size threshold, flush the disk writer's buffer to disk.
/**
* Spill contents of in-memory iterator to a temporary file on disk.
*/
private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
: SpilledFile = {
// Because these files may be read during shuffle, their compression must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
// createTempShuffleBlock here; see SPARK-3426 for more context.
//Create the temporary file
val (blockId, file) = diskBlockManager.createTempShuffleBlock()
// These variables are reset after each flush
var objectsWritten: Long = 0
val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics
//Create a DiskBlockObjectWriter for writing the temporary file
val writer: DiskBlockObjectWriter =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
// List of batch sizes (bytes) in the order they are written to disk
val batchSizes = new ArrayBuffer[Long]
// How many elements we have in each partition
val elementsPerPartition = new Array[Long](numPartitions)
// Flush the disk writer's contents to disk, and update relevant variables.
// The writer is committed at the end of this process.
//Flush the disk writer's buffered contents to disk and update the related variables
def flush(): Unit = {
val segment = writer.commitAndGet()
batchSizes += segment.length
_diskBytesSpilled += segment.length
objectsWritten = 0
}
var success = false
try {
//Iterate over the in-memory collection's sorted, partition-writable iterator (WritablePartitionedIterator)
while (inMemoryIterator.hasNext) {
// Get the partition ID of the next record
val partitionId = inMemoryIterator.nextPartition()
require(partitionId >= 0 && partitionId < numPartitions,
s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
//Write the current record with the DiskBlockObjectWriter
inMemoryIterator.writeNext(writer)
//Increment the record count for this partition
elementsPerPartition(partitionId) += 1
//Increment the number of records written in the current batch
objectsWritten += 1
//If the number of records written reaches the batch size, flush the disk writer's buffer to disk
if (objectsWritten == serializerBatchSize) {
flush()
}
}
//After the iteration, flush any records still buffered in the disk writer
if (objectsWritten > 0) {
flush()
} else {
writer.revertPartialWritesAndClose()
}
success = true
} finally {
if (success) {
writer.close()
} else {
// This code path only happens if an exception was thrown above before we set success;
// close our stuff and let the exception be thrown further
writer.revertPartialWritesAndClose()
if (file.exists()) {
if (!file.delete()) {
logWarning(s"Error deleting ${file}")
}
}
}
}
//Create and return the SpilledFile
SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
}
Sorting the records inserted into the ExternalSorter and writing them to a single disk file
The writePartitionedFile method
This method writes all records inserted into the ExternalSorter to a single disk file. Those records may reside in the in-memory collection or in spill files.
1. If there are no spill files, everything fits in memory and no spill has occurred: obtain a sorted iterator over the in-memory collection, write its records to the temporary data file partition by partition, commit each partition as a FileSegment, and fill in the array of partition lengths.
2. If there are spill files, the spill files and the in-memory data must be merged, which requires a merge sort: write the merged records to the temporary data file partition by partition, commit each partition as a FileSegment, and fill in the array of partition lengths.
3. Return the array of partition lengths.
/**
* Write all the data added into this ExternalSorter into a file in the disk store. This is
* called by the SortShuffleWriter.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
def writePartitionedFile(
blockId: BlockId,
outputFile: File): Array[Long] = {
// Track location of each range in the output file
val lengths = new Array[Long](numPartitions)
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics().shuffleWriteMetrics)
//If the array of spill-file metadata is empty
if (spills.isEmpty) {
// Case where we only have in-memory data
//Pick the in-memory collection according to whether map-side aggregation is defined
val collection = if (aggregator.isDefined) map else buffer
//Get the collection's sorted, partition-writable iterator;
//the underlying iterator yields elements of type ((Int, K), C)
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
//Iterate over the elements
while (it.hasNext) {
//Partition ID of the current element
val partitionId = it.nextPartition()
//Inner loop: write all records belonging to this partition ID
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(writer)
}
//Commit all records of this partition ID as a single block
val segment = writer.commitAndGet()
lengths(partitionId) = segment.length
}
} else {
// We must perform merge-sort; get an iterator by partition and write everything directly.
for ((id, elements) <- this.partitionedIterator) {
if (elements.hasNext) {
for (elem <- elements) {
writer.write(elem._1, elem._2)
}
val segment = writer.commitAndGet()
lengths(id) = segment.length
}
}
}
writer.close()
context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
lengths
}
The partitionedIterator method
This method returns an iterator over all records inserted into the ExternalSorter, grouped by partition ID and, if an aggregator was given, aggregated with its functions.
The iterator's type is Iterator[(Int, Iterator[Product2[K, C]])]. The Int is the partition ID, and each partition has its own iterator over its records. The partition iterators must be consumed in order; you cannot skip ahead to a new partition before finishing the current one.
/**
* Return an iterator over all the data written to this object, grouped by partition and
* aggregated by the requested aggregator. For each partition we then have an iterator over its
* contents, and these are expected to be accessed in order (you can't "skip ahead" to one
* partition without reading the previous one). Guaranteed to return a key-value pair for each
* partition, in order of partition ID.
*
* For now, we just merge all the spilled files in once pass, but this can be modified to
* support hierarchical merging.
* Exposed for testing.
*/
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
// Whether a map-side combine is required
val usingMap = aggregator.isDefined
// Pick the in-memory collection according to whether a map-side combine is required
val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
//If no spill to disk has occurred
if (spills.isEmpty) {
// Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
// we don't even need to sort by anything other than partition ID
// and no key ordering is required
if (!ordering.isDefined) {
// The user hasn't requested sorted keys, so only sort by partition ID, not key
groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
} else { //A key ordering is required
// We do need to sort by both partition ID and key
//Sort first by partition ID, then by key within each partition
groupByPartition(destructiveIterator(
collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
}
} else {
// Merge spilled and in-memory data
// Spills occurred, so the on-disk spill files must be merged with the in-memory collection
merge(spills, destructiveIterator(
collection.partitionedDestructiveSortedIterator(comparator)))
}
}
The merge method
Merges the data in the on-disk spill files with the data in the in-memory collection.
This method is called when spill files exist. It returns an iterator of type Iterator[(Int, Iterator[Product2[K, C]])], which iterates over all partitions and, within each partition, over all of its records.
/**
* Merge a sequence of sorted files, giving an iterator over partitions and then over elements
* inside each partition. This can be used to either write out a new file or return data to
* the user.
*
* Returns an iterator over all the data written to this object, grouped by partition. For each
* partition we then have an iterator over its contents, and these are expected to be accessed
* in order (you can't "skip ahead" to one partition without reading the previous one).
* Guaranteed to return a key-value pair for each partition, in order of partition ID.
*/
private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
: Iterator[(Int, Iterator[Product2[K, C]])] = {
val readers = spills.map(new SpillReader(_))
//buffered returns a BufferedIterator over the in-memory data
val inMemBuffered = inMemory.buffered
//Iterate over the partition IDs, mapping each one to its merged record iterator, and return an iterator of the results
(0 until numPartitions).iterator.map { p =>
//Iterator over the in-memory records belonging to partition p
val inMemIterator = new IteratorForPartition(p, inMemBuffered)
val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
if (aggregator.isDefined) {
// Perform partial aggregation across partitions
(p, mergeWithAggregation(
iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
} else if (ordering.isDefined) {
// No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
// sort the elements without trying to merge them
(p, mergeSort(iterators, ordering.get))
} else {
(p, iterators.iterator.flatten)
}
}
}
The mergeWithAggregation method
/**
* Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each
* iterator is sorted by key with a given comparator. If the comparator is not a total ordering
* (e.g. when we sort objects by hash code and different keys may compare as equal although
* they're not), we still merge them by doing equality tests for all keys that compare as equal.
*/
private def mergeWithAggregation(
iterators: Seq[Iterator[Product2[K, C]]],
mergeCombiners: (C, C) => C,
comparator: Comparator[K],
totalOrder: Boolean)
: Iterator[Product2[K, C]] =
{
if (!totalOrder) {
// We only have a partial ordering, e.g. comparing the keys by hash code, which means that
// multiple distinct keys might be treated as equal by the ordering. To deal with this, we
// need to read all keys considered equal by the ordering at once and compare them.
new Iterator[Iterator[Product2[K, C]]] {
val sorted = mergeSort(iterators, comparator).buffered
// Buffers reused across elements to decrease memory allocation
val keys = new ArrayBuffer[K]
val combiners = new ArrayBuffer[C]
override def hasNext: Boolean = sorted.hasNext
override def next(): Iterator[Product2[K, C]] = {
if (!hasNext) {
throw new NoSuchElementException
}
keys.clear()
combiners.clear()
val firstPair = sorted.next()
keys += firstPair._1
combiners += firstPair._2
val key = firstPair._1
while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
val pair = sorted.next()
var i = 0
var foundKey = false
while (i < keys.size && !foundKey) {
if (keys(i) == pair._1) {
combiners(i) = mergeCombiners(combiners(i), pair._2)
foundKey = true
}
i += 1
}
if (!foundKey) {
keys += pair._1
combiners += pair._2
}
}
// Note that we return an iterator of elements since we could've had many keys marked
// equal by the partial order; we flatten this below to get a flat iterator of (K, C).
keys.iterator.zip(combiners.iterator)
}
}.flatMap(i => i)
} else {
// We have a total ordering, so the objects with the same key are sequential.
new Iterator[Product2[K, C]] {
val sorted = mergeSort(iterators, comparator).buffered
override def hasNext: Boolean = sorted.hasNext
override def next(): Product2[K, C] = {
if (!hasNext) {
throw new NoSuchElementException
}
val elem = sorted.next()
val k = elem._1
var c = elem._2
while (sorted.hasNext && sorted.head._1 == k) {
val pair = sorted.next()
c = mergeCombiners(c, pair._2)
}
(k, c)
}
}
}
}
SpillReader
/**
* An internal class for reading a spilled file partition by partition. Expects all the
* partitions to be requested in order.
*/
private[this] class SpillReader(spill: SpilledFile) {
// Serializer batch offsets; size will be batchSize.length + 1
val batchOffsets = spill.serializerBatchSizes.scanLeft(0L)(_ + _)
// Track which partition and which batch stream we're in. These will be the indices of
// the next element we will read. We'll also store the last partition read so that
// readNextPartition() can figure out what partition that was from.
var partitionId = 0
var indexInPartition = 0L
var batchId = 0
var indexInBatch = 0
var lastPartitionId = 0
skipToNextPartition()
// Intermediate file and deserializer streams that read from exactly one batch
// This guards against pre-fetching and other arbitrary behavior of higher level streams
var fileStream: FileInputStream = null
var deserializeStream = nextBatchStream() // Also sets fileStream
var nextItem: (K, C) = null
var finished = false
/** Construct a stream that only reads from the next batch */
def nextBatchStream(): DeserializationStream = {
// Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
// we're still in a valid batch.
if (batchId < batchOffsets.length - 1) {
if (deserializeStream != null) {
deserializeStream.close()
fileStream.close()
deserializeStream = null
fileStream = null
}
val start = batchOffsets(batchId)
fileStream = new FileInputStream(spill.file)
fileStream.getChannel.position(start)
batchId += 1
val end = batchOffsets(batchId)
assert(end >= start, "start = " + start + ", end = " + end +
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
val wrappedStream = serializerManager.wrapStream(spill.blockId, bufferedStream)
serInstance.deserializeStream(wrappedStream)
} else {
// No more batches left
cleanup()
null
}
}
/**
* Update partitionId if we have reached the end of our current partition, possibly skipping
* empty partitions on the way.
*/
private def skipToNextPartition() {
while (partitionId < numPartitions &&
indexInPartition == spill.elementsPerPartition(partitionId)) {
partitionId += 1
indexInPartition = 0L
}
}
/**
* Return the next (K, C) pair from the deserialization stream and update partitionId,
* indexInPartition, indexInBatch and such to match its location.
*
* If the current batch is drained, construct a stream for the next batch and read from it.
* If no more pairs are left, return null.
*/
private def readNextItem(): (K, C) = {
if (finished || deserializeStream == null) {
return null
}
val k = deserializeStream.readKey().asInstanceOf[K]
val c = deserializeStream.readValue().asInstanceOf[C]
lastPartitionId = partitionId
// Start reading the next batch if we're done with this one
indexInBatch += 1
if (indexInBatch == serializerBatchSize) {
indexInBatch = 0
deserializeStream = nextBatchStream()
}
// Update the partition location of the element we're reading
indexInPartition += 1
skipToNextPartition()
// If we've finished reading the last partition, remember that we're done
if (partitionId == numPartitions) {
finished = true
if (deserializeStream != null) {
deserializeStream.close()
}
}
(k, c)
}
var nextPartitionToRead = 0
def readNextPartition(): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] {
val myPartition = nextPartitionToRead
nextPartitionToRead += 1
override def hasNext: Boolean = {
if (nextItem == null) {
nextItem = readNextItem()
if (nextItem == null) {
return false
}
}
assert(lastPartitionId >= myPartition)
// Check that we're still in the right partition; note that readNextItem will have returned
// null at EOF above so we would've returned false there
lastPartitionId == myPartition
}
override def next(): Product2[K, C] = {
if (!hasNext) {
throw new NoSuchElementException
}
val item = nextItem
nextItem = null
item
}
}
// Clean up our open streams and put us in a state where we can't read any more data
def cleanup() {
batchId = batchOffsets.length // Prevent reading any other batch
val ds = deserializeStream
deserializeStream = null
fileStream = null
if (ds != null) {
ds.close()
}
// NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop().
// This should also be fixed in ExternalAppendOnlyMap.
}
}
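The batchOffsets computed with scanLeft at the top of SpillReader are simply the cumulative byte offsets of the serializer batches, with a leading 0. A small standalone illustration with made-up batch sizes (BatchOffsetsDemo is not part of Spark):
// Hypothetical batch sizes; batchOffsets always has length numBatches + 1.
object BatchOffsetsDemo {
  def main(args: Array[String]): Unit = {
    val serializerBatchSizes = Array(100L, 250L, 80L)
    val batchOffsets = serializerBatchSizes.scanLeft(0L)(_ + _)
    println(batchOffsets.mkString(", ")) // 0, 100, 350, 430
    // Batch i occupies bytes [batchOffsets(i), batchOffsets(i + 1)) of the spill file,
    // which is how nextBatchStream() positions the FileInputStream for each batch.
  }
}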
Reference: "Spark源码分析之Sort-Based Shuffle读写流程" (Spark source code analysis: the sort-based shuffle read/write flow)