SortShuffleWriter

Overview

SortShuffleWriter first decides whether a map-side combine is required. If aggregation is needed it uses a PartitionedAppendOnlyMap; if no combine is performed it uses a PartitionedPairBuffer to hold the records in memory. In either case it checks whether memory is sufficient: if it is not and no additional memory can be acquired, the data is spilled to temporary files on local disk. Finally, the in-memory data and the spilled temporary files are merged, with a merge-sort if spills occurred; if nothing was spilled, no merge-sort is needed because everything is still in memory. The result is a merged data file plus an index file.
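A rough sketch of how the final data/index pair fits together (illustration only, with made-up numbers; the exact on-disk format is owned by IndexShuffleBlockResolver): the data file stores each partition's records back to back, and the index file stores cumulative byte offsets, so the reduce side can locate partition i as the byte range [offsets(i), offsets(i + 1)).

// Illustration only: partition lengths as produced by writePartitionedFile
val partitionLengths = Array(120L, 0L, 300L, 80L)
// The index file conceptually stores the cumulative offsets (one more entry than partitions)
val offsets = partitionLengths.scanLeft(0L)(_ + _)   // Array(0, 120, 120, 420, 500)
// A reducer fetching partition 2 reads bytes [offsets(2), offsets(3)) = [120, 420) of the data file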

The write method

The method is implemented as follows:

1. Create an ExternalSorter; whether the aggregator and keyOrdering arguments are passed in depends solely on whether a map-side combine is needed.

2. Call insertAll on the ExternalSorter instance to insert the records.

   If the in-memory collection holding the records inside the ExternalSorter reaches its size threshold, the records are spilled to a disk file in sorted order.

3. Build the final output file, whose name (with reduceId fixed to 0) is "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId.

4. Append a UUID to the output file name to mark the file as being written; it is renamed once writing finishes.

5. Call writePartitionedFile on the ExternalSorter instance to sort the records inserted into the sorter and write them to the output file.

   The records inserted into the sorter may live in the in-memory collection or in spill files.

6. Write each partition's offset into the index file so that the reduce side can fetch its data.

7. Wrap the relevant information into a MapStatus and return it.

/** Write a bunch of records to this task's output */
  override def write(records: Iterator[Product2[K, V]]): Unit = {
    sorter = if (dep.mapSideCombine) {
      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
      new ExternalSorter[K, V, C](
        context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
    } else {
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      new ExternalSorter[K, V, V](
        context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
    }
    sorter.insertAll(records)

    // Don't bother including the time to open the merged output file in the shuffle write time,
    // because it just opens a single file, so is typically too fast to measure accurately
    // (see SPARK-3570).
    // Build the final output file; its name (with reduceId = 0) is "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
    // Append a UUID to the output file name to mark it as in progress; it is renamed when writing completes
    val tmp = Utils.tempFileWith(output)
    try {
      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
      // Sort the records and write them to the output file
      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
      // Write each partition's offset into the index file so the reduce side can fetch its data
      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
    } finally {
      if (tmp.exists() && !tmp.delete()) {
        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
      }
    }
  }

ExternalSorter

Overview

Sorts, and possibly merges, a large number of (K, V) key-value pairs to produce (K, C) key-combiner pairs. A partitioner is used to group keys into partitions, and then, optionally, a custom comparator sorts the keys within each partition. Finally, each partition's (K, V) pairs are written out as a distinct byte range of a single output file, ready to be fetched during shuffle.

If combining is disabled, the type C must equal V: we cast the objects at the end.

Note: although ExternalSorter is a fairly generic sorter, some of its configuration is tied to its use in sort-based shuffle. For example, its block compression is controlled by "spark.shuffle.compress". If ExternalSorter were used in a non-shuffle context, this class might need to be revisited to use different configuration settings.

Several important constructor parameters of this class are:

@param aggregator optional Aggregator with the combine functions used to merge data
@param partitioner optional; if given, sort by partition ID and then by key
@param ordering optional; the ordering to use when sorting keys within each partition; should be a total ordering
@param serializer the serializer to use when spilling to disk

Note that if an ordering is given, we will always sort with it, so only provide it if you actually want the output keys to be sorted. In a map task without map-side combine, for example, you probably want to pass None as the ordering to avoid unnecessary sorting. On the other hand, if you do want to combine, having an ordering is more efficient than not having one.

Callers interact with this class as follows:

  1. Instantiate an ExternalSorter.
  2. Call insertAll with a batch of records.
  3. Call iterator() to iterate over the sorted/aggregated records, or call writePartitionedFile() to write the sorted/aggregated records to an output file in sort shuffle (a minimal usage sketch follows this list).
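A minimal usage sketch of these three steps (it mirrors the SortShuffleWriter.write code shown earlier; context, dep, records, blockId and tmp are assumed to already be in scope, so this is a sketch rather than standalone runnable code):

// 1. Instantiate the sorter (aggregator/ordering only when a map-side combine is wanted)
val sorter = new ExternalSorter[K, V, C](
  context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
// 2. Insert a batch of records
sorter.insertAll(records)
// 3. Sort/aggregate everything and write it out to a single file, partition by partition
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
// Finally, release memory and delete any intermediate spill files
sorter.stop()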

Internally, the class works as follows:

Data is repeatedly filled into a PartitionedAppendOnlyMap (when combining by key is needed) or a PartitionedPairBuffer (when it is not), which serve as in-memory buffers. Within these buffers, elements are sorted by partition ID and possibly also by key. To avoid calling the partitioner multiple times for each key, the partition ID is stored alongside each record.

When a buffer reaches the memory limit, it is spilled to a file. The file is sorted first by partition ID and then, if aggregation is wanted, by key or by the key's hash code. For each file, we track how many objects of each partition are in memory, so we do not have to write the partition ID for every element.

When the user requests an iterator or a file output, the spill files are merged together with the remaining in-memory data, using the sort order defined above (unless both sorting and aggregation are disabled). If aggregation by key is needed, we either use the total ordering from the ordering parameter, or read keys with the same hash code and compare them with each other for equality in order to merge their values.

Users are expected to call stop() at the end to delete all intermediate files.

The parent class of ExternalSorter

Spillable is the parent class of ExternalSorter; Spillable itself is a subclass of MemoryConsumer.

The Spillable class spills the contents of the in-memory collection to disk when memory exceeds a threshold.
Here, "in-memory collection" refers to the PartitionedAppendOnlyMap or PartitionedPairBuffer data structure.

 

Member variables

  • serializerBatchSize: the number of objects per batch when reading objects from, or writing objects to, the serializer. Objects are written in batches, and each batch uses its own serialization stream; this cuts down the initial size of the reference-tracking map built when deserializing a stream. Note that setting this too low causes excessive copying during serialization, because some serializers grow their internal data structures by growing + copying every time the number of objects doubles.
  • map (PartitionedAppendOnlyMap) and buffer (PartitionedPairBuffer): the in-memory collections that store records before spilling. Which one is used depends on whether aggregation is needed: with map-side aggregation the records go into the PartitionedAppendOnlyMap, where they are combined; otherwise they go into the PartitionedPairBuffer.
  • keyComparator: a comparator for keys K, used to order keys within a partition to allow aggregation or sorting. If no comparator is supplied through the ordering parameter, a default comparator that gives a partial ordering by hash code is used. A partial ordering means that equal keys have comparator.compare(k, k) == 0, but some non-equal keys compare as 0 as well, so a later pass is needed to find the truly equal keys. (Keys that are equal by equals() always have equal hashCode(), but keys with equal hashCode() are not necessarily equal, which is why comparing hash codes can only give a partial ordering; see the sketch after this list.)
  • spills: when the in-memory collection reaches its size threshold, its records are spilled to a disk file in sorted order. This ArrayBuffer[SpilledFile] keeps track of the spill files.
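A tiny illustration of why hash-code comparison only gives a partial ordering (the two strings below are a well-known JVM hashCode collision):

val a = "Aa"
val b = "BB"
// Different keys, same hash code (both 2112), so the default keyComparator
// returns compare(a, b) == 0 even though a != b.
assert(a.hashCode == b.hashCode && a != b)
// This is why mergeWithAggregation later does an extra equality pass over
// keys that the partial-order comparator reports as equal.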
// Size of object batches when reading/writing from serializers.
  //
  // Objects are written in batches, with each batch using its own serialization stream. This
  // cuts down on the size of reference-tracking maps constructed when deserializing a stream.
  //
  // NOTE: Setting this too low can cause excessive copying when serializing, since some serializers
  // grow internal data structures by growing + copying every time the number of objects doubles.
  private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)

  // Data structures to store in-memory objects before we spill. Depending on whether we have an
  // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
  // store them in an array buffer.
  @volatile private var map = new PartitionedAppendOnlyMap[K, C]
  @volatile private var buffer = new PartitionedPairBuffer[K, C]

 // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
  // Can be a partial ordering by hash code if a total ordering is not provided through by the
  // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
  // non-equal keys also have this, so we need to do a later pass to find truly equal keys).
  // Note that we ignore this if no aggregator and no ordering are given.
  private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
    override def compare(a: K, b: K): Int = {
      val h1 = if (a == null) 0 else a.hashCode()
      val h2 = if (b == null) 0 else b.hashCode()
      if (h1 < h2) -1 else if (h1 == h2) 0 else 1
    }
  })

 // Information about a spilled file. Includes sizes in bytes of "batches" written by the
  // serializer as we periodically reset its stream, as well as number of elements in each
  // partition, used to efficiently keep track of partitions when merging.
  private[this] case class SpilledFile(
    file: File,
    blockId: BlockId,
    serializerBatchSizes: Array[Long],
    elementsPerPartition: Array[Long])

  private val spills = new ArrayBuffer[SpilledFile]

Note that keyComparator is ignored if neither an aggregator nor an ordering is given.

The comparator method returns Some(keyComparator) when either an ordering or an aggregator is defined; keyComparator is the user-supplied ordering if one was given, or the default hash-based comparator otherwise. If neither is defined, comparator returns None.

private def comparator: Option[Comparator[K]] = {
    if (ordering.isDefined || aggregator.isDefined) {
      Some(keyComparator)
    } else {
      None
    }
  }

Inserting records into ExternalSorter

The insertAll method

The method is implemented as follows:

1. If map-side aggregation is required:

   Get the aggregator's mergeValue and createCombiner functions and build an update function from them. The update function merges the incoming value into the existing combiner when one exists, and otherwise creates a new combiner with createCombiner.

   Iterate over the records, compute each record's partition, and call PartitionedAppendOnlyMap#changeValue to apply the update function.

   Finally, call maybeSpillCollection to decide whether the data needs to be spilled to disk.

2. If map-side aggregation is not required:

   Iterate over the records, compute each record's partition, and call PartitionedPairBuffer#insert to append the record to the buffer.

   Finally, call maybeSpillCollection to decide whether the data needs to be spilled to disk.

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
    // TODO: stop combining if we find that the reduction factor isn't high
    val shouldCombine = aggregator.isDefined

    if (shouldCombine) {
      // Combine values in-memory first using our AppendOnlyMap
      // The aggregator's mergeValue function merges a new value into the existing aggregate
      val mergeValue = aggregator.get.mergeValue
      // The aggregator's createCombiner function builds the initial value of the aggregation
      val createCombiner = aggregator.get.createCombiner
      var kv: Product2[K, V] = null
      // The update function applies mergeValue when a value already exists, otherwise createCombiner
      val update = (hadValue: Boolean, oldValue: C) => {
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      }
      while (records.hasNext) {
        // Count one more element as read (used by the spill checks)
        addElementsRead()
        // Take the next (key, value) pair
        kv = records.next()
        // Compute the key's partition, then merge the value into the map
        map.changeValue((getPartition(kv._1), kv._1), update)
        // Spill the in-memory data to disk if necessary
        maybeSpillCollection(usingMap = true)
      }
    } else { // No map-side combine
      // Stick values into our buffer
      while (records.hasNext) {
        // Count one more element as read (used by the spill checks)
        addElementsRead()
        // Take the next (key, value) pair
        val kv = records.next()
        // Append the record to the PartitionedPairBuffer
        buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
        // Spill the in-memory data to disk if necessary
        maybeSpillCollection(usingMap = false)
      }
    }
  }

The maybeSpillCollection method

The method is implemented as follows:

1. If map-side aggregation is required:

   Estimate the size of the map and decide from that estimate whether to spill. If a spill happens, a new PartitionedAppendOnlyMap is created afterwards.

2. If map-side aggregation is not required:

   Estimate the size of the buffer and decide from that estimate whether to spill. If a spill happens, a new PartitionedPairBuffer is created afterwards.

/**
   * Spill the current in-memory collection to disk if needed.
   *
   * @param usingMap whether we're using a map or buffer as our current in-memory collection
   */
  private def maybeSpillCollection(usingMap: Boolean): Unit = {
    var estimatedSize = 0L
    if (usingMap) {  // Using the PartitionedAppendOnlyMap
      // Estimate the size of the map
      estimatedSize = map.estimateSize()
      // Decide from the estimated size whether a spill is needed
      if (maybeSpill(map, estimatedSize)) {
        // After spilling, start a fresh PartitionedAppendOnlyMap
        map = new PartitionedAppendOnlyMap[K, C]
      }
    } else { // Using the PartitionedPairBuffer
      // Estimate the size of the buffer
      estimatedSize = buffer.estimateSize()
      // Call Spillable.maybeSpill to decide from the estimated size whether a spill is needed
      if (maybeSpill(buffer, estimatedSize)) {
        // After spilling, start a fresh PartitionedPairBuffer
        buffer = new PartitionedPairBuffer[K, C]
      }
    }

    if (estimatedSize > _peakMemoryUsedBytes) {
      _peakMemoryUsedBytes = estimatedSize
    }
  }

The maybeSpill method

maybeSpillCollection calls maybeSpill, which is defined in the parent class Spillable.

Based on the estimated size of the in-memory collection, the method decides whether a spill is needed; if so, it calls spill().

The method is implemented as follows:

Every 32 elements read, if the current memory usage has reached the memory threshold (5 MB by default), the method first tries to acquire (2 * currentMemory - myMemoryThreshold) bytes of execution memory from the TaskMemoryManager. If enough memory is granted to push the threshold above the current usage, no spill happens and records keep accumulating in the buffer; otherwise spill() is called to write the buffer's contents out to a disk file. A spill is also forced once the number of elements read exceeds numElementsForceSpillThreshold.
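A worked example with made-up numbers: suppose the check fires with currentMemory = 8 MB while myMemoryThreshold is still at the default 5 MB.

// amountToRequest = 2 * currentMemory - myMemoryThreshold
val amountToRequest = 2 * 8 - 5          // request 11 MB from the execution memory pool
// If the full 11 MB is granted:  myMemoryThreshold becomes 5 + 11 = 16 MB > 8 MB  -> no spill
// If only 1 MB can be granted:   myMemoryThreshold becomes 5 + 1  =  6 MB <= 8 MB -> spill()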

/**
   * Spills the current in-memory collection to disk if needed. Attempts to acquire more
   * memory before spilling.
   *
   * @param collection collection to spill to disk
   * @param currentMemory estimated size of the collection in bytes
   * @return true if `collection` was spilled to disk; false otherwise
   */
  protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
    var shouldSpill = false
    if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
      // Claim up to double our current memory from the shuffle memory pool
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      // Under the hood this calls TaskMemoryManager.acquireExecutionMemory
      val granted = acquireMemory(amountToRequest)
      // Raise the current memory threshold by the amount granted
      myMemoryThreshold += granted
      // If we were granted too little memory to grow further (either tryToAcquire returned 0,
      // or we already had more memory than myMemoryThreshold), spill the current collection
      // Check again whether current memory still reaches the threshold; if so, spill
      shouldSpill = currentMemory >= myMemoryThreshold
    }
    shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
    // Actually spill
    if (shouldSpill) {
      _spillCount += 1
      logSpillage(currentMemory)
      // Perform the spill
      spill(collection)
      _elementsRead = 0
      _memoryBytesSpilled += currentMemory
      releaseMemory()
    }
    shouldSpill
  }

The spill method

This method spills the contents of the in-memory collection to a disk file, in the order given by the comparator. It is called when the in-memory collection reaches the size threshold.

The method is implemented as follows:

1. Get the comparator. The comparator method returns Some(keyComparator) when an ordering or an aggregator is defined; otherwise it returns None.

2. Get an iterator over the in-memory collection, sorted by that comparator.

3. Spill the data of the in-memory collection to a temporary file on disk.

4. Append the resulting SpilledFile to the spills array so it can be found during the merge.

/**
   * Spill our in-memory collection to a sorted file that we can merge later.
   * We add this file into `spilledFiles` to find it later.
   *
   * @param collection whichever collection we're using (map or buffer)
   */
  override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
    // Get an iterator over the collection, sorted by the given comparator.
    // The comparator method returns Some(keyComparator) when an ordering or aggregator is defined,
    // and None otherwise.
    val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
    // Spill the in-memory collection's data to a temporary file on disk
    val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
    // Remember the spill file so it can be merged later
    spills += spillFile
  }

The spillMemoryIteratorToDisk method

1. Create a temporary file.

2. Create a DiskBlockObjectWriter for writing the temporary file.

3. Iterate over the in-memory collection's inMemoryIterator and write each record with the DiskBlockObjectWriter. Whenever the number of records written reaches the batch threshold (serializerBatchSize), flush the disk writer's buffered contents to disk.

/**
   * Spill contents of in-memory iterator to a temporary file on disk.
   */
  private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
      : SpilledFile = {
    // Because these files may be read during shuffle, their compression must be controlled by
    // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
    // createTempShuffleBlock here; see SPARK-3426 for more context.
    // Create a temporary file
    val (blockId, file) = diskBlockManager.createTempShuffleBlock()

    // These variables are reset after each flush
    var objectsWritten: Long = 0
    val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics
    // Create a DiskBlockObjectWriter for writing the temporary file
    val writer: DiskBlockObjectWriter =
      blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)

    // List of batch sizes (bytes) in the order they are written to disk
    val batchSizes = new ArrayBuffer[Long]

    // How many elements we have in each partition
    val elementsPerPartition = new Array[Long](numPartitions)

    // Flush the disk writer's contents to disk, and update relevant variables.
    // The writer is committed at the end of this process.
    // Flush the disk writer's buffered contents to disk and update the relevant variables
    def flush(): Unit = {
      val segment = writer.commitAndGet()
      batchSizes += segment.length
      _diskBytesSpilled += segment.length
      objectsWritten = 0
    }

    var success = false
    try {
      // Iterate over the collection's sorted, partition-aware WritablePartitionedIterator
      while (inMemoryIterator.hasNext) {
        // Partition ID of the next record
        val partitionId = inMemoryIterator.nextPartition()
        require(partitionId >= 0 && partitionId < numPartitions,
          s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
        // Write the current record with the DiskBlockObjectWriter
        inMemoryIterator.writeNext(writer)
        // One more record for this partition
        elementsPerPartition(partitionId) += 1
        // One more record in the current batch
        objectsWritten += 1

        // If the batch size threshold is reached, flush the writer's buffer to disk
        if (objectsWritten == serializerBatchSize) {
          flush()
        }
      }
      // After the iteration, flush any records still buffered in the disk writer
      if (objectsWritten > 0) {
        flush()
      } else {
        writer.revertPartialWritesAndClose()
      }
      success = true
    } finally {
      if (success) {
        writer.close()
      } else {
        // This code path only happens if an exception was thrown above before we set success;
        // close our stuff and let the exception be thrown further
        writer.revertPartialWritesAndClose()
        if (file.exists()) {
          if (!file.delete()) {
            logWarning(s"Error deleting ${file}")
          }
        }
      }
    }
   
    // Build and return the SpilledFile
    SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
  }

Sorting the records inserted into the ExternalSorter and writing them to a disk file

The writePartitionedFile method

This method writes all the records inserted into the ExternalSorter to a single disk file. Those records may live in the in-memory collection or in spill files.

1. If there are no spill files, memory was sufficient and nothing had to be spilled: get a sorted iterator over the in-memory data, write it partition by partition into the temporary data file, commit each partition as a FileSegment, and record each partition's length in an array of partition lengths.

2. If there are spill files, the spill files and the in-memory data must be merged, which requires a merge-sort; the merged data is written into the temporary data file, each partition is committed as a FileSegment, and each partition's length is recorded in the array.

3. Return the array of partition lengths.

/**
   * Write all the data added into this ExternalSorter into a file in the disk store. This is
   * called by the SortShuffleWriter.
   *
   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
   */
  def writePartitionedFile(
      blockId: BlockId,
      outputFile: File): Array[Long] = {

    // Track location of each range in the output file
    val lengths = new Array[Long](numPartitions)
    val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
      context.taskMetrics().shuffleWriteMetrics)

    // If there are no spill files
    if (spills.isEmpty) {
      // Case where we only have in-memory data
      // Pick the in-memory collection according to whether map-side aggregation is defined
      val collection = if (aggregator.isDefined) map else buffer
      // Get the collection's sorted, partition-aware iterator;
      // the underlying elements are keyed by (partitionId, key)
      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
      // Iterate over the elements
      while (it.hasNext) {
        // Partition ID of the next group of records
        val partitionId = it.nextPartition()
        // Inner loop: write all records that belong to this partition
        while (it.hasNext && it.nextPartition() == partitionId) {
          it.writeNext(writer)
        }
        // Commit all records of this partition as one file segment
        val segment = writer.commitAndGet()
        lengths(partitionId) = segment.length
      }
    } else {
      // We must perform merge-sort; get an iterator by partition and write everything directly.
      for ((id, elements) <- this.partitionedIterator) {
        if (elements.hasNext) {
          for (elem <- elements) {
            writer.write(elem._1, elem._2)
          }
          val segment = writer.commitAndGet()
          lengths(id) = segment.length
        }
      }
    }

    writer.close()
    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)

    lengths
  }

The partitionedIterator method

This method returns an iterator over all records inserted into the ExternalSorter, grouped by partition ID and aggregated with the aggregator's functions when one is defined.

The iterator's type is Iterator[(Int, Iterator[Product2[K, C]])]. The Int is the partition ID, and each partition has its own iterator over that partition's records. The per-partition iterators must be consumed in order: you cannot skip ahead to a new partition without finishing the current one.
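A short sketch of how a caller is expected to consume this iterator (essentially what the merge branch of writePartitionedFile above does); sorter is assumed to be an ExternalSorter[K, V, C]:

for ((partitionId, records) <- sorter.partitionedIterator) {
  // Each partition's inner iterator must be drained before moving to the next partition
  records.foreach { case (key, combiner) =>
    // process this partition's (key, combiner) pairs here, in partition-ID order
  }
}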

/**
   * Return an iterator over all the data written to this object, grouped by partition and
   * aggregated by the requested aggregator. For each partition we then have an iterator over its
   * contents, and these are expected to be accessed in order (you can't "skip ahead" to one
   * partition without reading the previous one). Guaranteed to return a key-value pair for each
   * partition, in order of partition ID.
   *
   * For now, we just merge all the spilled files in once pass, but this can be modified to
   * support hierarchical merging.
   * Exposed for testing.
   */
  def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
    // Whether a map-side combine is needed
    val usingMap = aggregator.isDefined
    // Pick the in-memory collection accordingly
    val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
    // If nothing has been spilled to disk
    if (spills.isEmpty) {
      // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
      // we don't even need to sort by anything other than partition ID
      // ...and no key ordering is required
      if (!ordering.isDefined) {
        // The user hasn't requested sorted keys, so only sort by partition ID, not key
        // Sort only by partition ID; keys are not sorted
        groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
      } else { // A key ordering is required
        // We do need to sort by both partition ID and key
        // Sort by partition ID first, then by key within each partition
        groupByPartition(destructiveIterator(
          collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
      }
    } else {
      // Merge spilled and in-memory data
      // Spills happened, so merge the spill files on disk with the in-memory collection
      merge(spills, destructiveIterator(
        collection.partitionedDestructiveSortedIterator(comparator)))
    }
  }

The merge method

Merges the data in the on-disk spill files with the data in the in-memory collection.

This method is called when spill files exist. It returns an Iterator[(Int, Iterator[Product2[K, C]])] that iterates over all partitions and, within each partition, over all of that partition's records.

/**
   * Merge a sequence of sorted files, giving an iterator over partitions and then over elements
   * inside each partition. This can be used to either write out a new file or return data to
   * the user.
   *
   * Returns an iterator over all the data written to this object, grouped by partition. For each
   * partition we then have an iterator over its contents, and these are expected to be accessed
   * in order (you can't "skip ahead" to one partition without reading the previous one).
   * Guaranteed to return a key-value pair for each partition, in order of partition ID.
   */
  private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
      : Iterator[(Int, Iterator[Product2[K, C]])] = {
    val readers = spills.map(new SpillReader(_))
    // buffered returns a BufferedIterator over the in-memory data
    val inMemBuffered = inMemory.buffered
    // For each partition ID, build the merged iterator for that partition
    (0 until numPartitions).iterator.map { p =>
      // Iterator over the in-memory records that belong to partition p
      val inMemIterator = new IteratorForPartition(p, inMemBuffered)
      val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
      if (aggregator.isDefined) {
        // Perform partial aggregation across partitions
        (p, mergeWithAggregation(
          iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
      } else if (ordering.isDefined) {
        // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
        // sort the elements without trying to merge them
        (p, mergeSort(iterators, ordering.get))
      } else {
        (p, iterators.iterator.flatten)
      }
    }
  }

The mergeWithAggregation method

Merges a sequence of (K, C) iterators by aggregating the values for each key, assuming every iterator is sorted by key with the given comparator. If the comparator is only a partial ordering (for example ordering by hash code, where distinct keys may compare as equal), keys that compare as equal are additionally tested for real equality before their values are merged.

/**
   * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each
   * iterator is sorted by key with a given comparator. If the comparator is not a total ordering
   * (e.g. when we sort objects by hash code and different keys may compare as equal although
   * they're not), we still merge them by doing equality tests for all keys that compare as equal.
   */
  private def mergeWithAggregation(
      iterators: Seq[Iterator[Product2[K, C]]],
      mergeCombiners: (C, C) => C,
      comparator: Comparator[K],
      totalOrder: Boolean)
      : Iterator[Product2[K, C]] =
  {
    if (!totalOrder) {
      // We only have a partial ordering, e.g. comparing the keys by hash code, which means that
      // multiple distinct keys might be treated as equal by the ordering. To deal with this, we
      // need to read all keys considered equal by the ordering at once and compare them.
      new Iterator[Iterator[Product2[K, C]]] {
        val sorted = mergeSort(iterators, comparator).buffered

        // Buffers reused across elements to decrease memory allocation
        val keys = new ArrayBuffer[K]
        val combiners = new ArrayBuffer[C]

        override def hasNext: Boolean = sorted.hasNext

        override def next(): Iterator[Product2[K, C]] = {
          if (!hasNext) {
            throw new NoSuchElementException
          }
          keys.clear()
          combiners.clear()
          val firstPair = sorted.next()
          keys += firstPair._1
          combiners += firstPair._2
          val key = firstPair._1
          while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
            val pair = sorted.next()
            var i = 0
            var foundKey = false
            while (i < keys.size && !foundKey) {
              if (keys(i) == pair._1) {
                combiners(i) = mergeCombiners(combiners(i), pair._2)
                foundKey = true
              }
              i += 1
            }
            if (!foundKey) {
              keys += pair._1
              combiners += pair._2
            }
          }

          // Note that we return an iterator of elements since we could've had many keys marked
          // equal by the partial order; we flatten this below to get a flat iterator of (K, C).
          keys.iterator.zip(combiners.iterator)
        }
      }.flatMap(i => i)
    } else {
      // We have a total ordering, so the objects with the same key are sequential.
      new Iterator[Product2[K, C]] {
        val sorted = mergeSort(iterators, comparator).buffered

        override def hasNext: Boolean = sorted.hasNext

        override def next(): Product2[K, C] = {
          if (!hasNext) {
            throw new NoSuchElementException
          }
          val elem = sorted.next()
          val k = elem._1
          var c = elem._2
          while (sorted.hasNext && sorted.head._1 == k) {
            val pair = sorted.next()
            c = mergeCombiners(c, pair._2)
          }
          (k, c)
        }
      }
    }
  }

SpillReader

An internal helper for reading a spilled file partition by partition; it expects all partitions to be requested in order.

/**
   * An internal class for reading a spilled file partition by partition. Expects all the
   * partitions to be requested in order.
   */
  private[this] class SpillReader(spill: SpilledFile) {
    // Serializer batch offsets; size will be batchSize.length + 1
    val batchOffsets = spill.serializerBatchSizes.scanLeft(0L)(_ + _)
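    // For example, serializerBatchSizes = [8000, 8000, 3000] gives
    // batchOffsets = [0, 8000, 16000, 19000]: each entry is the byte offset at which a batch
    // starts in the spill file, and the last entry is the end of the final batch.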

    // Track which partition and which batch stream we're in. These will be the indices of
    // the next element we will read. We'll also store the last partition read so that
    // readNextPartition() can figure out what partition that was from.
    var partitionId = 0
    var indexInPartition = 0L
    var batchId = 0
    var indexInBatch = 0
    var lastPartitionId = 0

    skipToNextPartition()

    // Intermediate file and deserializer streams that read from exactly one batch
    // This guards against pre-fetching and other arbitrary behavior of higher level streams
    var fileStream: FileInputStream = null
    var deserializeStream = nextBatchStream()  // Also sets fileStream

    var nextItem: (K, C) = null
    var finished = false

    /** Construct a stream that only reads from the next batch */
    def nextBatchStream(): DeserializationStream = {
      // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
      // we're still in a valid batch.
      if (batchId < batchOffsets.length - 1) {
        if (deserializeStream != null) {
          deserializeStream.close()
          fileStream.close()
          deserializeStream = null
          fileStream = null
        }

        val start = batchOffsets(batchId)
        fileStream = new FileInputStream(spill.file)
        fileStream.getChannel.position(start)
        batchId += 1

        val end = batchOffsets(batchId)

        assert(end >= start, "start = " + start + ", end = " + end +
          ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))

        val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))

        val wrappedStream = serializerManager.wrapStream(spill.blockId, bufferedStream)
        serInstance.deserializeStream(wrappedStream)
      } else {
        // No more batches left
        cleanup()
        null
      }
    }

    /**
     * Update partitionId if we have reached the end of our current partition, possibly skipping
     * empty partitions on the way.
     */
    private def skipToNextPartition() {
      while (partitionId < numPartitions &&
          indexInPartition == spill.elementsPerPartition(partitionId)) {
        partitionId += 1
        indexInPartition = 0L
      }
    }

    /**
     * Return the next (K, C) pair from the deserialization stream and update partitionId,
     * indexInPartition, indexInBatch and such to match its location.
     *
     * If the current batch is drained, construct a stream for the next batch and read from it.
     * If no more pairs are left, return null.
     */
    private def readNextItem(): (K, C) = {
      if (finished || deserializeStream == null) {
        return null
      }
      val k = deserializeStream.readKey().asInstanceOf[K]
      val c = deserializeStream.readValue().asInstanceOf[C]
      lastPartitionId = partitionId
      // Start reading the next batch if we're done with this one
      indexInBatch += 1
      if (indexInBatch == serializerBatchSize) {
        indexInBatch = 0
        deserializeStream = nextBatchStream()
      }
      // Update the partition location of the element we're reading
      indexInPartition += 1
      skipToNextPartition()
      // If we've finished reading the last partition, remember that we're done
      if (partitionId == numPartitions) {
        finished = true
        if (deserializeStream != null) {
          deserializeStream.close()
        }
      }
      (k, c)
    }

    var nextPartitionToRead = 0

    def readNextPartition(): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] {
      val myPartition = nextPartitionToRead
      nextPartitionToRead += 1

      override def hasNext: Boolean = {
        if (nextItem == null) {
          nextItem = readNextItem()
          if (nextItem == null) {
            return false
          }
        }
        assert(lastPartitionId >= myPartition)
        // Check that we're still in the right partition; note that readNextItem will have returned
        // null at EOF above so we would've returned false there
        lastPartitionId == myPartition
      }

      override def next(): Product2[K, C] = {
        if (!hasNext) {
          throw new NoSuchElementException
        }
        val item = nextItem
        nextItem = null
        item
      }
    }

    // Clean up our open streams and put us in a state where we can't read any more data
    def cleanup() {
      batchId = batchOffsets.length  // Prevent reading any other batch
      val ds = deserializeStream
      deserializeStream = null
      fileStream = null
      if (ds != null) {
        ds.close()
      }
      // NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop().
      // This should also be fixed in ExternalAppendOnlyMap.
    }
  }

Reference: Spark源码分析之Sort-Based Shuffle读写流程 (Spark source code analysis: the read/write flow of sort-based shuffle)