Task.run -> ShuffleMapTask.runTask -> writer.write
The writer is either a HashShuffleWriter or a SortShuffleWriter;
this chapter analyzes HashShuffleWriter.
Shuffle Write
/**
* Write a bunch of records to this task's output
* Writes the partition data of the new RDD computed by each ShuffleMapTask to local disk
*/
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
/**
* First, determine whether map-side local aggregation is needed.
* For operations like reduceByKey, dep.aggregator.isDefined is true,
* and map-side local aggregation will be performed.
*/
val iter = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// Local aggregation, e.g. (hello,1), (hello,1) ---> (hello,2)
dep.aggregator.get.combineValuesByKey(records, context)
} else {
records
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
records
}
/**
* If local aggregation is required, it has been done above; now iterate over the data.
* For each record, call the partitioner (HashPartitioner by default) to generate a bucketId,
* which determines the bucket each record is written into.
*/
for (elem <- iter) {
val bucketId = dep.partitioner.getPartition(elem._1)
/**
* With the bucketId in hand, shuffle.writers (created via shuffleBlockManager.forMapTask())
* yields the writer for that bucket, which then writes the record into the bucket
*/
shuffle.writers(bucketId).write(elem)
}
}
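For reference, the default HashPartitioner assigns a bucket by taking the key's hashCode modulo the number of partitions, forced non-negative. A minimal sketch of that logic (a simplified stand-in, not Spark's actual class):
// Simplified sketch of HashPartitioner's bucket assignment
class SimpleHashPartitioner(val numPartitions: Int) {
  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ =>
      val mod = key.hashCode % numPartitions // Java's % can be negative
      if (mod < 0) mod + numPartitions else mod
  }
}
// e.g. new SimpleHashPartitioner(3).getPartition("hello") always lands in [0, 3)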
=> shuffle.writers -> FileShuffleBlockManager.forMapTask.writers
/**
* Get a ShuffleWriterGroup for each map task
*/
def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer,
writeMetrics: ShuffleWriteMetrics) = {
new ShuffleWriterGroup {
shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
private val shuffleState = shuffleStates(shuffleId)
private var fileGroup: ShuffleFileGroup = null
/**
* The two shuffle write modes:
* 1) Consolidation enabled (consolidateShuffleFiles = true): instead of a separate file
* per bucket, each bucket gets a writer into a shared ShuffleFileGroup.
* 2) Consolidation disabled (consolidateShuffleFiles = false): one independent file per bucket.
*/
val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
fileGroup = getUnusedFileGroup()
Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
/**
* First, build a unique ShuffleBlockId from shuffleId, mapId, and bucketId (the reduceId).
* Then call the ShuffleFileGroup's apply method with the bucketId to get the file
* for this bucket within the group.
*/
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
/**
* Get a writer against the ShuffleFileGroup.
* With consolidation enabled, each bucket gets a writer into the shared ShuffleFileGroup
* rather than into an independent ShuffleBlockFile.
* This is how the outputs of multiple ShuffleMapTasks get merged into the same files.
*/
blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
writeMetrics)
}
} else {
Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
// Get the blockFile representing the local disk file this bucket will be written to
val blockFile = blockManager.diskBlockManager.getFile(blockId)
// Because of previous failures, the shuffle file may already exist on this machine.
// If so, remove it.
if (blockFile.exists) {
if (blockFile.delete()) {
logInfo(s"Removed existing shuffle file $blockFile")
} else {
logWarning(s"Failed to remove existing shuffle file $blockFile")
}
}
blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
}
}
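The point of consolidation is the shuffle-file count. Without it, every map task writes one file per bucket; with it, map tasks that run one after another in the same slot append to the same ShuffleFileGroup. A back-of-the-envelope sketch (assuming, roughly, one file group per core):
// Illustrative shuffle-file counts for the two modes
val numMapTasks = 1000
val numBuckets  = 100 // = number of reduce tasks
val numCores    = 10
val filesWithoutConsolidation = numMapTasks * numBuckets // 100,000 files
val filesWithConsolidation    = numCores * numBuckets    //   1,000 files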
==>blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,writeMetrics)
bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 // defaults to 32 KB
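Since every open bucket writer holds one such buffer, this knob trades memory for fewer flushes to disk. An illustrative tuning example (the value 64 is arbitrary):
import org.apache.spark.SparkConf
// Enlarge the shuffle write buffer from the default 32 KB to 64 KB
val conf = new SparkConf().set("spark.shuffle.file.buffer.kb", "64")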
--> BlockManager.getDiskWriter
new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites,
writeMetrics)
DiskBlockObjectWriter.write -> open
override def open(): BlockObjectWriter = {
if (hasBeenClosed) {
throw new IllegalStateException("Writer already closed. Cannot be reopened.")
}
// Java file output stream (opened in append mode)
fos = new FileOutputStream(file, true)
ts = new TimeTrackingOutputStream(fos)
channel = fos.getChannel()
/**
* Java buffered stream, created with bufferSize as its buffer size: once the buffered
* data reaches this size, it is flushed to disk.
* So Spark's shuffle write ultimately writes through a BufferedOutputStream.
*/
bs = compressStream(new BufferedOutputStream(ts, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
initialized = true
this
}
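Putting the stack together: open() layers a serialization stream over an (optionally compressed) BufferedOutputStream over the raw FileOutputStream. A self-contained sketch of the same layering using plain Java streams (Spark uses its own Serializer and an optional compression codec instead of ObjectOutputStream):
import java.io.{BufferedOutputStream, FileOutputStream, ObjectOutputStream}
val fos = new FileOutputStream("/tmp/shuffle_0_0_0.data", true) // append mode, like open()
val bs  = new BufferedOutputStream(fos, 32 * 1024)              // 32 KB buffer
val out = new ObjectOutputStream(bs)
out.writeObject(("hello", 2)) // bytes hit disk when the buffer fills or on close
out.close()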
Shuffle Read
ShuffledRDD.compute calls the ShuffleReader
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
/**
* When a ResultTask or ShuffleMapTask reaches a ShuffledRDD, it calls the ShuffledRDD's
* compute method to compute the data of the RDD's current partition.
* Here it calls the ShuffleManager's getReader method to get a HashShuffleReader,
* then calls its read method to fetch the data this ResultTask/ShuffleMapTask needs to aggregate.
*/
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
// TODO HashShuffleReader.read
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[(K, C)]]
}
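For context, a ShuffledRDD is what shuffle operators such as reduceByKey produce, so computing any of its partitions enters the compute method above. A minimal example (assumes an existing SparkContext sc and an input path):
// reduceByKey creates a ShuffleDependency: the map stage runs the shuffle
// write path above, and computing `counts` runs the shuffle read path below
val counts = sc.textFile("hdfs://...")
  .flatMap(_.split(" "))
  .map(word => (word, 1))
  .reduceByKey(_ + _)
counts.collect()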
=> HashShuffleReader.read
override def read(): Iterator[Product2[K, C]] = {
val ser = Serializer.getSerializer(dep.serializer)
/**
* TODO fetch
* When pulling data, a ResultTask uses BlockStoreShuffleFetcher to obtain, from the
* driver's MapOutputTrackerMaster, the metadata about the data it needs, and then
* uses the BlockManager underneath to fetch that data from the corresponding locations.
*/
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
==>BlockStoreShuffleFetcher.fetch
def fetch[T](
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
: Iterator[T] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
val startTime = System.currentTimeMillis
/**
* Key point:
* Get a reference to the global MapOutputTracker, then call its getServerStatuses
* method with the shuffleId and reduceId.
* The shuffleId identifies the stage preceding the current one; a shuffle spans two stages:
* shuffle write happens in the previous stage,
* shuffle read happens in the current stage.
*
* To understand it:
* the shuffleId first narrows things down to the MapStatus outputs of all the
* ShuffleMapTasks of the previous stage; then the reduceId (bucketId) selects, from
* each MapStatus, the output-file information this ResultTask needs from each ShuffleMapTask.
*
* Unless the statuses are already cached locally, getServerStatuses goes over the network,
* since it has to contact the MapOutputTrackerMaster on the driver.
*
* // TODO getServerStatuses
*/
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
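/**
* Group the (address, size) pairs by BlockManagerId: the index of each entry in
* statuses is the map task id, so for every executor we collect the (mapId, size)
* outputs it holds, then turn them into concrete ShuffleBlockIds below.
*/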
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
for (((address, size), index) <- statuses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
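/**
* unpackBlock turns each fetched (BlockId, Try[Iterator]) pair into a record iterator;
* a failed fetch becomes a FetchFailedException so the scheduler knows which map
* output was lost and can resubmit the stage that produced it.
*/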
def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
case Success(block) => {
block.asInstanceOf[Iterator[T]]
}
case Failure(e) => {
blockId match {
case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block", e)
}
}
}
}
/**
* As soon as the ShuffleBlockFetcherIterator is constructed, it uses the location
* information obtained above to have the BlockManager pull the data from the remote
* BlockManagers on the nodes where the ShuffleMapTasks ran.
*
* TODO ShuffleBlockFetcherIterator.initialize
*/
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
serializer,
SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)
val itr = blockFetcherItr.flatMap(unpackBlock)
// Finally, wrap and transform the fetched data before returning it
val completionIter = CompletionIterator[T, Iterator[T]](itr, {
context.taskMetrics.updateShuffleReadMetrics()
})
===>SparkEnv.get.mapOutputTracker.getServerStatuses -> MapOutputTracker.getServerStatuses
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
// Someone else is fetching it; wait for them to be done
// As long as another thread is still fetching the statuses for this shuffleId, loop and wait
while (fetching.contains(shuffleId)) {
try {
fetching.wait()
} catch {
case e: InterruptedException =>
}
}
// Either while we waited the fetch happened successfully, or
// someone fetched it in between the get and the fetching.synchronized.
fetchedStatuses = mapStatuses.get(shuffleId).orNull
if (fetchedStatuses == null) {
// We have to do the fetch, get others to wait for us.
fetching += shuffleId
}
}
if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
// This try-finally prevents hangs due to timeouts:
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
}
if (fetchedStatuses != null) {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
} else {
logError("Missing all output locations for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
}
} else {
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
}
}
}
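The fetching set together with wait()/notifyAll() is a classic "first caller fetches, later callers wait" pattern. A self-contained sketch of the same idea, where fetchRemote is a hypothetical stand-in for the askTracker round trip:
import scala.collection.mutable

object OnceFetcher {
  private val cache    = new mutable.HashMap[Int, Array[Byte]]
  private val fetching = new mutable.HashSet[Int]

  def get(id: Int)(fetchRemote: Int => Array[Byte]): Array[Byte] = {
    fetching.synchronized {
      // Someone else is fetching: wait until they finish, then re-check the cache
      while (fetching.contains(id)) fetching.wait()
      cache.get(id) match {
        case Some(v) => return v        // fetched while we waited
        case None    => fetching += id  // we won the race; others will wait on us
      }
    }
    try {
      val v = fetchRemote(id)
      fetching.synchronized { cache(id) = v }
      v
    } finally {
      // Always clear the in-progress flag and wake waiters, even on failure
      fetching.synchronized { fetching -= id; fetching.notifyAll() }
    }
  }
}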
=> Back to new ShuffleBlockFetcherIterator -> ShuffleBlockFetcherIterator.initialize
private[this] def initialize(): Unit = {
// Add a task completion callback (called in both success case and failure case) to cleanup.
context.addTaskCompletionListener(_ => cleanup())
// Split local and remote blocks.
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
/**
* Send out initial requests for blocks, up to our maxBytesInFlight
*
* Loop: while queued fetch requests remain, keep sending them to remote nodes,
* as long as the total bytes in flight stay under the cap.
* Tuning parameter: spark.reducer.maxMbInFlight caps how much block data may be
* in flight (being fetched) at once.
*/
while (fetchRequests.nonEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
val numFetches = remoteRequests.size - fetchRequests.size
logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
// Get Local Blocks
// Once the remote fetch requests have been sent, fetch the local blocks (data locality)
fetchLocalBlocks()
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
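The while condition above is the throttle: new requests go out only while the total in-flight bytes stay under maxBytesInFlight; as responses arrive, budget is freed and the queue is drained again. A stripped-down sketch of that mechanism (FetchRequest and send are hypothetical stand-ins for Spark's internals):
import scala.collection.mutable

case class FetchRequest(size: Long)

class FetchThrottle(maxBytesInFlight: Long) {
  private val queue = new mutable.Queue[FetchRequest]
  private var bytesInFlight = 0L

  def enqueue(r: FetchRequest): Unit = queue.enqueue(r)

  // Send queued requests while the in-flight total stays under the cap;
  // the first request is always allowed so an oversized block can't stall the queue
  def drain(send: FetchRequest => Unit): Unit = {
    while (queue.nonEmpty &&
        (bytesInFlight == 0 || bytesInFlight + queue.front.size <= maxBytesInFlight)) {
      val r = queue.dequeue()
      bytesInFlight += r.size
      send(r)
    }
  }

  // Called when a request completes: free its budget and send more
  def onComplete(r: FetchRequest, send: FetchRequest => Unit): Unit = {
    bytesInFlight -= r.size
    drain(send)
  }
}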