This article is a brief introduction to how Spark stores data; it is a set of reading notes on the book 《图解Spark核心技术与案例实战》.
Components
Spark's storage model is a master/slave model: the Driver is the master and the Executors are the slaves. The Driver manages the metadata of the data blocks, while the slaves store the data and carry out the data operations the Driver sends them.
Driver
When an application starts, SparkContext creates a SparkEnv on the Driver side. Inside the SparkEnv it creates a BlockManager and a BlockManagerMaster, and inside the BlockManagerMaster it creates a BlockManagerMasterEndpoint to handle communication.
Executor
When an Executor starts it also creates a SparkEnv, which creates a BlockManager and a BlockTransferService. During BlockManager initialization, the Executor obtains a reference to the BlockManagerMasterEndpoint, creates a BlockManagerSlaveEndpoint, and registers a reference to that endpoint with the Driver, so that the Driver and the Executor can talk to each other.
Metadata
BlockManagerMasterEndpoint holds the metadata of the data blocks in three fields:
1 blockManagerInfo: private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]
A HashMap keyed by BlockManagerId whose values are BlockManagerInfo. BlockManagerInfo records the Executor's memory usage, its block usage, the blocks it currently holds, and a reference to the endpoint used to communicate with the Executor.
2 blockManagerIdByExecutor: private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
Also a HashMap, mapping each executor id to its BlockManagerId.
3 blockLocations: private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
Maps each block id to the set of block manager ids that hold the block; it is consulted when looking up a block's locations.
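To make the roles of these three maps concrete, here is a minimal standalone sketch (simplified types; the registration helpers are hypothetical, while the real logic lives in BlockManagerMasterEndpoint) of how a block registration and a location lookup would use them:

import scala.collection.mutable

// Simplified stand-ins for the real Spark types.
case class BlockManagerId(executorId: String, host: String, port: Int)
case class BlockId(name: String)
class BlockManagerInfo(val maxMem: Long)

object MetadataSketch {
  val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]
  val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
  val blockLocations = new mutable.HashMap[BlockId, mutable.HashSet[BlockManagerId]]

  // Hypothetical: register an executor's BlockManager ...
  def register(id: BlockManagerId, maxMem: Long): Unit = {
    blockManagerIdByExecutor(id.executorId) = id
    blockManagerInfo(id) = new BlockManagerInfo(maxMem)
  }

  // ... then record that it holds a given block.
  def updateBlockInfo(id: BlockManagerId, blockId: BlockId): Unit =
    blockLocations.getOrElseUpdate(blockId, mutable.HashSet.empty[BlockManagerId]) += id

  // The lookup performed when a GetLocations message arrives.
  def getLocations(blockId: BlockId): Seq[BlockManagerId] =
    blockLocations.get(blockId).map(_.toSeq).getOrElse(Seq.empty)
}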
Storage levels
The persist() method
The cache() and persist() methods explicitly persist data in memory or on disk; cache() is just persist() fixed at the MEMORY_ONLY level.
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def persist(): this.type = persist(StorageLevel.MEMORY_ONLY)
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): this.type = persist()
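A quick usage sketch (the input path is hypothetical, and sc is the usual SparkContext): because the level can be assigned only once, the choice between cache() and an explicit persist() must be made before the first action:

import org.apache.spark.storage.StorageLevel

val rdd = sc.textFile("hdfs://namenode:8020/data/input.txt") // hypothetical path
rdd.persist(StorageLevel.MEMORY_AND_DISK) // or rdd.cache() for MEMORY_ONLY
rdd.count() // first action: computes the RDD and caches its partitions
// Calling rdd.persist(StorageLevel.MEMORY_ONLY) here would throw
// UnsupportedOperationException, since a level is already assigned.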
The implementation of persist() is as follows:
private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = {
// TODO: Handle changes of StorageLevel
// Once an RDD's storage level has been set, it cannot be changed
if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) {
throw new UnsupportedOperationException(
"Cannot change storage level of an RDD after it was already assigned a level")
}
// If this is the first time this RDD is marked for persisting, register it
// with the SparkContext for cleanups and accounting. Do this only once.
if (storageLevel == StorageLevel.NONE) {
sc.cleaner.foreach(_.registerRDDForCleanup(this))
sc.persistRDD(this)
}
storageLevel = newLevel
this
}
The first time the RDD is computed, persist() applies the caching strategy given by the storage-level argument; after that the strategy cannot be changed.
All the storage levels are listed below.
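These are defined in Spark's StorageLevel companion object; the constructor flags are (useDisk, useMemory, useOffHeap, deserialized[, replication]):

val NONE = new StorageLevel(false, false, false, false)
val DISK_ONLY = new StorageLevel(true, false, false, false)
val DISK_ONLY_2 = new StorageLevel(true, false, false, false, 2)
val MEMORY_ONLY = new StorageLevel(false, true, false, true)
val MEMORY_ONLY_2 = new StorageLevel(false, true, false, true, 2)
val MEMORY_ONLY_SER = new StorageLevel(false, true, false, false)
val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, false, 2)
val MEMORY_AND_DISK = new StorageLevel(true, true, false, true)
val MEMORY_AND_DISK_2 = new StorageLevel(true, true, false, true, 2)
val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, false)
val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, false, 2)
val OFF_HEAP = new StorageLevel(true, true, true, false, 1)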
The storage call path
Data is only read or written while a job is running; the real entry point is getOrCompute(), called from RDD.iterator():
/**
* Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached.
*/
private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
val blockId = RDDBlockId(id, partition.index)
var readCachedBlock = true
// This method is called on executors, so we need call SparkEnv.get instead of sc.env.
SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
// If the block is not in memory, compute the partition or read it from a checkpoint
readCachedBlock = false
computeOrReadCheckpoint(partition, context)
}) match {
// The block was read successfully
case Left(blockResult) =>
if (readCachedBlock) {
val existingMetrics = context.taskMetrics().inputMetrics
existingMetrics.incBytesRead(blockResult.bytes)
new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
override def next(): T = {
existingMetrics.incRecordsRead(1)
delegate.next()
}
}
} else {
new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
}
case Right(iter) =>
new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
}
}
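Note the block naming: each cached partition is addressed by an RDDBlockId built from the RDD id and the partition index. A minimal sketch of the scheme (simplified from org.apache.spark.storage.RDDBlockId):

case class RDDBlockId(rddId: Int, splitIndex: Int) {
  // This name is what appears in the UI and as the on-disk file name.
  def name: String = "rdd_" + rddId + "_" + splitIndex
}

println(RDDBlockId(7, 3).name) // rdd_7_3: partition 3 of RDD 7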
The read and write logic lives in getOrElseUpdate():
/**
* Retrieve the given block if it exists, otherwise call the provided `makeIterator` method
* to compute the block, persist it, and return its values.
*
* @return either a BlockResult if the block was successfully cached, or an iterator if the block
* could not be cached.
*/
def getOrElseUpdate[T](
blockId: BlockId,
level: StorageLevel,
classTag: ClassTag[T],
makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
// Attempt to read the block from local or remote storage. If it's present, then we don't need
// to go through the local-get-or-put path.
// Read path: try to fetch the block
get(blockId) match {
case Some(block) =>
return Left(block)
case _ =>
// Need to compute the block.
}
// Initially we hold no locks on this block.
// Write path entry point
doPutIterator(blockId, makeIterator, level, classTag, keepReadLock = true) match {
case None =>
// doPut() didn't hand work back to us, so the block already existed or was successfully
// stored. Therefore, we now hold a read lock on the block.
val blockResult = getLocalValues(blockId).getOrElse {
// Since we held a read lock between the doPut() and get() calls, the block should not
// have been evicted, so get() not returning the block indicates some internal error.
releaseLock(blockId)
throw new SparkException(s"get() failed for block $blockId even though we held a lock")
}
// We already hold a read lock on the block from the doPut() call and getLocalValues()
// acquires the lock again, so we need to call releaseLock() here so that the net number
// of lock acquisitions is 1 (since the caller will only call release() once).
releaseLock(blockId)
Left(blockResult)
case Some(iter) =>
// The put failed, likely because the data was too large to fit in memory and could not be
// dropped to disk. Therefore, we need to pass the input iterator back to the caller so
// that they can decide what to do with the values (e.g. process them without caching).
Right(iter)
}
}
Reading data
The call chain is as follows. BlockManager's get() method is the entry point for reads: depending on whether the block is available locally, it either reads it from the local store or fetches it from a remote node through the BlockTransferService:
/**
* Get a block from the block manager (either local or remote).
*
* This acquires a read lock on the block if the block was stored locally and does not acquire
* any locks if the block was fetched from a remote block manager. The read lock will
* automatically be freed once the result's `data` iterator is fully consumed.
*/
def get(blockId: BlockId): Option[BlockResult] = {
val local = getLocalValues(blockId)
if (local.isDefined) {
logInfo(s"Found block $blockId locally")
return local
}
val remote = getRemoteValues(blockId)
if (remote.isDefined) {
logInfo(s"Found block $blockId remotely")
return remote
}
None
}
The concrete read logic lives in the two functions getLocalValues() and getRemoteValues(); let's look at each in turn:
def getLocalValues(blockId: BlockId): Option[BlockResult] = {
logDebug(s"Getting local block $blockId")
// Acquire a read lock
blockInfoManager.lockForReading(blockId) match {
case None =>
logDebug(s"Block $blockId was not found")
None
case Some(info) =>
val level = info.level
logDebug(s"Level for block $blockId is $level")
// Read the data from memory
if (level.useMemory && memoryStore.contains(blockId)) {
// Stored deserialized: the block holds object values, so use getValues()
val iter: Iterator[Any] = if (level.deserialized) {
memoryStore.getValues(blockId).get
} else {
// Stored serialized: read the bytes with getBytes() and deserialize the stream
serializerManager.dataDeserializeStream(
blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag)
}
val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
// Return the result
Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
} else if (level.useDisk && diskStore.contains(blockId)) {
// The level includes disk: read from the disk store
val iterToReturn: Iterator[Any] = {
// Read the raw bytes first
val diskBytes = diskStore.getBytes(blockId)
if (level.deserialized) {
val diskValues = serializerManager.dataDeserializeStream(
blockId,
diskBytes.toInputStream(dispose = true))(info.classTag)
// Deserialize the bytes, then possibly cache the values in memory
maybeCacheDiskValuesInMemory(info, blockId, level, diskValues)
} else {
// Possibly cache the raw bytes in memory first
val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes)
.map {_.toInputStream(dispose = false)}
.getOrElse { diskBytes.toInputStream(dispose = true) }
// Deserialize the resulting stream
serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
}
}
val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId))
Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
} else {
handleLocalReadFailure(blockId)
}
}
}
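Both branches wrap the returned iterator in a CompletionIterator so that the read lock acquired at the top of the method is released exactly when the iterator has been fully consumed. A minimal sketch of that pattern (simplified; the real class is org.apache.spark.util.CompletionIterator):

// Runs `completion` once, as soon as the wrapped iterator is exhausted.
class CompletionIterator[A](sub: Iterator[A], completion: => Unit) extends Iterator[A] {
  private var completed = false
  def next(): A = sub.next()
  def hasNext: Boolean = {
    val r = sub.hasNext
    if (!r && !completed) { completed = true; completion } // e.g. releaseLock(blockId)
    r
  }
}

val it = new CompletionIterator(Iterator(1, 2, 3), println("lock released"))
it.foreach(println) // prints 1, 2, 3, then "lock released"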
Reading from memory
Starting with the memory path, look at the MemoryStore's getBytes() and getValues() methods:
def getBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
val entry = entries.synchronized { entries.get(blockId) }
entry match {
case null => None
case e: DeserializedMemoryEntry[_] =>
throw new IllegalArgumentException("should only call getBytes on serialized blocks")
case SerializedMemoryEntry(bytes, _, _) => Some(bytes)
}
}
def getValues(blockId: BlockId): Option[Iterator[_]] = {
val entry = entries.synchronized { entries.get(blockId) }
entry match {
case null => None
case e: SerializedMemoryEntry[_] =>
throw new IllegalArgumentException("should only call getValues on deserialized blocks")
case DeserializedMemoryEntry(values, _, _) =>
val x = Some(values)
x.map(_.iterator)
}
}
Both methods look the block up by blockId in entries, a LinkedHashMap, so under the hood Spark keeps the in-memory data in a LinkedHashMap. A LinkedHashMap preserves insertion order, so when memory runs short the earliest-inserted blocks are spilled to disk first, giving FIFO eviction order.
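A small demonstration of the resulting order, independent of Spark (the eviction loop here is only a rough analogue of what the MemoryStore's eviction code does): iteration follows insertion order, so the oldest blocks become eviction victims first:

import scala.collection.mutable

val entries = new mutable.LinkedHashMap[String, Long] // blockId -> estimated size
entries("rdd_0_0") = 100L
entries("rdd_0_1") = 200L
entries("rdd_0_2") = 300L

// Choose victims in insertion (FIFO) order until 250 bytes are freed.
var freed = 0L
val victims = mutable.ListBuffer[String]()
val it = entries.iterator
while (freed < 250 && it.hasNext) {
  val (id, size) = it.next()
  victims += id
  freed += size
}
println(victims) // ListBuffer(rdd_0_0, rdd_0_1)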
Reading from disk
Spark configures where files are stored through spark.local.dir. By default there is one top-level directory, and under it at most 64 second-level directories are created, named 00 through 3f (hex); a block's file name is its blockId.name field, which uniquely identifies the block.
def getBytes(blockId: BlockId): ChunkedByteBuffer = {
val file = diskManager.getFile(blockId.name)
// Access the file through an NIO channel
val channel = new RandomAccessFile(file, "r").getChannel
Utils.tryWithSafeFinally {
// For small files, directly read rather than memory map
// Small files are read straight into a buffer
if (file.length < minMemoryMapBytes) {
val buf = ByteBuffer.allocate(file.length.toInt)
channel.position(0)
while (buf.remaining() != 0) {
if (channel.read(buf) == -1) {
throw new IOException("Reached EOF before filling buffer\n" +
s"offset=0\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}")
}
}
buf.flip()
new ChunkedByteBuffer(buf)
} else {
// Large files are memory-mapped
new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length))
}
} {
channel.close()
}
}
Before reading the data, getFile() is used to locate the file that holds it:
/** Looks up a file by hashing it into one of our local subdirectories. */
// This method should be kept in sync with
// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getFile().
def getFile(filename: String): File = {
// Figure out which local directory it hashes to, and which subdirectory in that
// filename is the blockId's name
val hash = Utils.nonNegativeHash(filename)
val dirId = hash % localDirs.length
// Work out which second-level directory this block's data lives in
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
// Create the subdirectory if it doesn't already exist
val subDir = subDirs(dirId).synchronized {
val old = subDirs(dirId)(subDirId)
if (old != null) {
old
} else {
val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
if (!newDir.exists() && !newDir.mkdir()) {
throw new IOException(s"Failed to create local dir in $newDir.")
}
subDirs(dirId)(subDirId) = newDir
newDir
}
}
new File(subDir, filename)
}
Remote reads
getRemoteValues() first calls getRemoteBytes() to fetch the data:
/**
* Get block from remote block managers.
*
* This does not acquire a lock on this block in this JVM.
*/
private def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
getRemoteBytes(blockId).map { data =>
val values =
serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))
new BlockResult(values, DataReadMethod.Network, data.size)
}
}
getRemoteBytes() first calls getLocations() to find out exactly where the block is stored:
/**
* Return a list of locations for the given block, prioritizing the local machine since
* multiple block managers can share the same host.
*/
private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
// Ask the master for the locations
val locs = Random.shuffle(master.getLocations(blockId))
val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host }
preferredLocs ++ otherLocs
}
/** Get locations of the blockId from the driver */
def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
// Send the driver a message asking where the block is stored
driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetLocations(blockId))
}
// Called when BlockManagerMasterEndpoint receives a GetLocations message: look up
// the blockLocations HashMap and return all block manager ids that hold this block
private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty
}
Next, getRemoteBytes():
/**
* Get block from remote block managers as serialized bytes.
*/
def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
logDebug(s"Getting remote block $blockId")
require(blockId != null, "BlockId is null")
var runningFailureCount = 0
var totalFailureCount = 0
// First look up where the block lives, obtaining every block manager id that holds it
val locations = getLocations(blockId)
val maxFetchFailures = locations.size
var locationIterator = locations.iterator
while (locationIterator.hasNext) {
val loc = locationIterator.next()
logDebug(s"Getting remote block $blockId from $loc")
val data = try {
// Try to fetch the block from this location
blockTransferService.fetchBlockSync(
loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer()
} catch {
case NonFatal(e) =>
runningFailureCount += 1
totalFailureCount += 1
if (totalFailureCount >= maxFetchFailures) {
// Give up trying anymore locations. Either we've tried all of the original locations,
// or we've refreshed the list of locations from the master, and have still
// hit failures after trying locations from the refreshed list.
throw new BlockFetchException(s"Failed to fetch block after" +
s" ${totalFailureCount} fetch failures. Most recent failure cause:", e)
}
logWarning(s"Failed to fetch remote block $blockId " +
s"from $loc (failed attempt $runningFailureCount)", e)
// If there is a large number of executors then locations list can contain a
// large number of stale entries causing a large number of retries that may
// take a significant amount of time. To get rid of these stale entries
// we refresh the block locations after a certain number of fetch failures
if (runningFailureCount >= maxFailuresBeforeLocationRefresh) {
locationIterator = getLocations(blockId).iterator
logDebug(s"Refreshed locations from the driver " +
s"after ${runningFailureCount} fetch failures.")
runningFailureCount = 0
}
// This location failed, so we retry fetch from a different one by returning null here
null
}
if (data != null) {
return Some(new ChunkedByteBuffer(data))
}
logDebug(s"The value of block $blockId is null")
}
logDebug(s"Block $blockId not found")
None
}
fetchBlockSync() reads the data identified by a BlockManagerId. Note that BlockManagerId is not a plain field but a class with host, port, executor id, and other fields; the name is easy to misread.
/**
* A special case of [[fetchBlocks]], as it fetches only one block and is blocking.
*
* It is also only available after [[init]] is invoked.
*/
def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = {
// A monitor for the thread to wait on.
val result = Promise[ManagedBuffer]()
fetchBlocks(host, port, execId, Array(blockId),
new BlockFetchingListener {
override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
result.failure(exception)
}
override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
val ret = ByteBuffer.allocate(data.size.toInt)
ret.put(data.nioByteBuffer())
ret.flip()
result.success(new NioManagedBuffer(ret))
}
})
ThreadUtils.awaitResult(result.future, Duration.Inf)
}
The data is downloaded asynchronously, with the listener used as a callback. The concrete logic is in fetchBlocks(); it is abstract here, and the implementation actually invoked is the one in NettyBlockTransferService:
override def fetchBlocks(
host: String,
port: Int,
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
// Create the communication client
val client = clientFactory.createClient(host, port)
// Fetch the blocks via a one-for-one stream
new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
}
}
val maxRetries = transportConf.maxIORetries()
if (maxRetries > 0) {
// Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
// a bug in this code. We should remove the if statement once we're sure of the stability.
new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
} else {
blockFetchStarter.createAndStart(blockIds, listener)
}
} catch {
case e: Exception =>
logError("Exception while beginning fetchBlocks", e)
blockIds.foreach(listener.onBlockFetchFailure(_, e))
}
}
The start() method reached from fetchBlocks() uses RPC to send a message to the executor that holds the blocks:
public void start() {
if (blockIds.length == 0) {
throw new IllegalArgumentException("Zero-sized blockIds array");
}
client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
…………
}
@Override
public void onFailure(Throwable e) {
…………
}
});
}
The message is received by receive() in the corresponding Executor's NettyBlockRpcServer, which calls getBlockData() to read the data:
override def receive(
client: TransportClient,
rpcMessage: ByteBuffer,
responseContext: RpcResponseCallback): Unit = {
val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
logTrace(s"Received request: $message")
message match {
case openBlocks: OpenBlocks =>
val blocks: Seq[ManagedBuffer] =
openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)
case uploadBlock: UploadBlock =>
……
}
}
The call at the end of getBlockData() checks whether the request is for shuffle output. If not, it falls into the same path as the local read in getLocalValues(); if it is, the read strategy depends on the shuffle type, hash-based or sort-based.
Writing data
[Figure: class/method call graph for the write path.]
doPutIterator(), reached from getOrElseUpdate(), is the entry point for writing data:
/**
* Put the given block according to the given level in one of the block stores, replicating
* the values if necessary.
*
* If the block already exists, this method will not overwrite it.
*
* @param keepReadLock if true, this method will hold the read lock when it returns (even if the
* block already exists). If false, this method will hold no locks when it
* returns.
* @return None if the block was already present or if the put succeeded, or Some(iterator)
* if the put failed.
*/
private def doPutIterator[T](
blockId: BlockId,
iterator: () => Iterator[T],
level: StorageLevel,
classTag: ClassTag[T],
tellMaster: Boolean = true,
keepReadLock: Boolean = false): Option[PartiallyUnrolledIterator[T]] = {
doPut(blockId, level, classTag, tellMaster = tellMaster, keepReadLock = keepReadLock) { info =>
val startTimeMs = System.currentTimeMillis
var iteratorFromFailedMemoryStorePut: Option[PartiallyUnrolledIterator[T]] = None
// Size of the block in bytes
var size = 0L
if (level.useMemory) {
// Put it in memory first, even if it also has useDisk set to true;
// We will drop it to disk later if the memory store can't hold it.
if (level.deserialized) {
memoryStore.putIteratorAsValues(blockId, iterator(), classTag) match {
case Right(s) =>
size = s
case Left(iter) =>
// Not enough space to unroll this block; drop to disk if applicable
if (level.useDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
diskStore.put(blockId) { fileOutputStream =>
serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
}
size = diskStore.getSize(blockId)
} else {
iteratorFromFailedMemoryStorePut = Some(iter)
}
}
} else { // !level.deserialized
memoryStore.putIteratorAsBytes(blockId, iterator(), classTag, level.memoryMode) match {
case Right(s) =>
size = s
case Left(partiallySerializedValues) =>
// Not enough space to unroll this block; drop to disk if applicable
if (level.useDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
diskStore.put(blockId) { fileOutputStream =>
partiallySerializedValues.finishWritingToStream(fileOutputStream)
}
size = diskStore.getSize(blockId)
} else {
iteratorFromFailedMemoryStorePut = Some(partiallySerializedValues.valuesIterator)
}
}
}
} else if (level.useDisk) {
diskStore.put(blockId) { fileOutputStream =>
serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag)
}
size = diskStore.getSize(blockId)
}
val putBlockStatus = getCurrentBlockStatus(blockId, info)
val blockWasSuccessfullyStored = putBlockStatus.storageLevel.isValid
if (blockWasSuccessfullyStored) {
// Now that the block is in either the memory, externalBlockStore, or disk store,
// tell the master about it.
info.size = size
if (tellMaster) {
reportBlockStatus(blockId, info, putBlockStatus)
}
Option(TaskContext.get()).foreach { c =>
c.taskMetrics().incUpdatedBlockStatuses(blockId -> putBlockStatus)
}
logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
if (level.replication > 1) {
val remoteStartTime = System.currentTimeMillis
val bytesToReplicate = doGetLocalBytes(blockId, info)
try {
replicate(blockId, bytesToReplicate, level, classTag)
} finally {
bytesToReplicate.dispose()
}
logDebug("Put block %s remotely took %s"
.format(blockId, Utils.getUsedTimeMs(remoteStartTime)))
}
}
assert(blockWasSuccessfullyStored == iteratorFromFailedMemoryStorePut.isEmpty)
iteratorFromFailedMemoryStorePut
}
}
The logic inside doPutIterator() is:
if the storage level uses memory
    if the level stores deserialized objects
        memoryStore.putIteratorAsValues
        if there is not enough memory to unroll the block and the level also allows disk
            diskStore.put(blockId)
        else
            hand the leftover iterator back to the caller
    else (the level stores serialized bytes)
        memoryStore.putIteratorAsBytes
        if there is not enough memory to unroll the block and the level also allows disk
            diskStore.put(blockId)
        else
            hand the leftover iterator back to the caller
else if the storage level uses disk
    diskStore.put(blockId)
update the block status
if replication > 1
    replicate(blockId, bytesToReplicate, level, classTag)
As the outline shows, doPutIterator() picks a write method according to the configuration, then updates the block's status, and finally handles replication.
Writing to memory
The outline above mentions "unrolling" a block. What is that? Instead of writing the whole block into memory at once, the data is written incrementally; before each step the code checks whether the remaining memory is enough to hold the block, and if not it tries to spill data already in memory to disk, freeing space for the new data. Let's look at the details in putIteratorAsValues():
/**
* Attempt to put the given block in memory store as values.
*
* It's possible that the iterator is too large to materialize and store in memory. To avoid
* OOM exceptions, this method will gradually unroll the iterator while periodically checking
* whether there is enough free memory. If the block is successfully materialized, then the
* temporary unroll memory used during the materialization is "transferred" to storage memory,
* so we won't acquire more memory than is actually needed to store the block.
* @return in case of success, the estimated size of the stored data. In case of
* failure, return an iterator containing the values of the block. The returned iterator
* will be backed by the combination of the partially-unrolled block and the remaining
* elements of the original input iterator. The caller must either fully consume this
* iterator or call `close()` on it in order to free the storage memory consumed by the
* partially-unrolled block.
*/
private[storage] def putIteratorAsValues[T](
blockId: BlockId,
values: Iterator[T],
classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = {
require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
// Number of elements unrolled so far
var elementsUnrolled = 0
// Whether there is still enough memory for us to continue unrolling this block
var keepUnrolling = true
// Initial per-task memory to request for unrolling blocks (bytes).
val initialMemoryThreshold = unrollMemoryThreshold
// How often to check whether we need to request more memory
val memoryCheckPeriod = 16
// Memory currently reserved by this task for this particular unrolling operation
var memoryThreshold = initialMemoryThreshold
// Memory to request as a multiple of current vector size
// Growth factor: each request asks for (memoryGrowthFactor * currentSize) - memoryThreshold bytes
val memoryGrowthFactor = 1.5
// Keep track of unroll memory used by this particular block / putIterator() operation
var unrollMemoryUsedByThisBlock = 0L
// Underlying vector for unrolling the block
// The SizeTrackingVector also estimates its memory footprint as elements are appended
var vector = new SizeTrackingVector[T]()(classTag)
// Request enough memory to begin unrolling
keepUnrolling =
reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, MemoryMode.ON_HEAP)
if (!keepUnrolling) {
logWarning(s"Failed to reserve initial memory threshold of " +
s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
} else {
unrollMemoryUsedByThisBlock += initialMemoryThreshold
}
// Unroll this block safely, checking whether we have exceeded our threshold periodically
while (values.hasNext && keepUnrolling) {
vector += values.next()
// Every memoryCheckPeriod (16) elements, check whether the reserved memory is exceeded
if (elementsUnrolled % memoryCheckPeriod == 0) {
// If our vector's size has exceeded the threshold, request more memory
val currentSize = vector.estimateSize()
// Not enough reserved memory
if (currentSize >= memoryThreshold) {
val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong
// Request the additional memory
keepUnrolling =
reserveUnrollMemoryForThisTask(blockId, amountToRequest, MemoryMode.ON_HEAP)
if (keepUnrolling) {
unrollMemoryUsedByThisBlock += amountToRequest
}
// New threshold is currentSize * memoryGrowthFactor
memoryThreshold += amountToRequest
}
}
elementsUnrolled += 1
}
// The block was fully unrolled; estimate the space it occupies in memory
if (keepUnrolling) {
// We successfully unrolled the entirety of this block
val arrayValues = vector.toArray
vector = null
val entry =
new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag)
val size = entry.size
def transferUnrollToStorage(amount: Long): Unit = {
// Synchronize so that transfer is atomic
// Convert unroll memory into storage memory: release the unroll memory, then acquire storage memory for the block
memoryManager.synchronized {
releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, amount)
val success = memoryManager.acquireStorageMemory(blockId, amount, MemoryMode.ON_HEAP)
assert(success, "transferring unroll memory to storage memory failed")
}
}
// Acquire storage memory if necessary to store this block in memory.
// Check whether there is enough storage memory to keep this block
val enoughStorageMemory = {
// The unroll memory used is no larger than the block's size
if (unrollMemoryUsedByThisBlock <= size) {
// Acquire the extra space
val acquiredExtra =
memoryManager.acquireStorageMemory(
blockId, size - unrollMemoryUsedByThisBlock, MemoryMode.ON_HEAP)
// On success, move the unroll memory over to storage with transferUnrollToStorage
if (acquiredExtra) {
transferUnrollToStorage(unrollMemoryUsedByThisBlock)
}
acquiredExtra
} else { // unrollMemoryUsedByThisBlock > size
// If this task attempt already owns more unroll memory than is necessary to store the
// block, then release the extra memory that will not be used.
// More unroll memory was used than the block needs: release the excess first, then transfer the rest
val excessUnrollMemory = unrollMemoryUsedByThisBlock - size
releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, excessUnrollMemory)
transferUnrollToStorage(size)
true
}
}
// If there is enough memory, store the entry
if (enoughStorageMemory) {
entries.synchronized {
entries.put(blockId, entry)
}
logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(
blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
Right(size)
} else {
assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock,
"released too much unroll memory")
Left(new PartiallyUnrolledIterator(
this,
unrollMemoryUsedByThisBlock,
unrolled = arrayValues.toIterator,
rest = Iterator.empty))
}
} else {
// We ran out of space while unrolling the values for this block
logUnrollFailureMessage(blockId, vector.estimateSize())
Left(new PartiallyUnrolledIterator(
this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values))
}
}
Writing to disk
Writing to disk is done by diskStore.put():
def put(blockId: BlockId)(writeFunc: FileOutputStream => Unit): Unit = {
if (contains(blockId)) {
throw new IllegalStateException(s"Block $blockId is already present in the disk store")
}
logDebug(s"Attempting to put block $blockId")
val startTime = System.currentTimeMillis
val file = diskManager.getFile(blockId)
val fileOutputStream = new FileOutputStream(file)
var threwException: Boolean = true
try {
writeFunc(fileOutputStream)
threwException = false
} finally {
try {
Closeables.close(fileOutputStream, threwException)
} finally {
if (threwException) {
remove(blockId)
}
}
}
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
file.getName,
Utils.bytesToString(file.length()),
finishTime - startTime))
}
The flow mirrors the read path, except that it calls back into the writeFunc passed as a parameter to write the data into the file. Here is one such writeFunc:
/**
* Finish writing this block to the given output stream by first writing the serialized values
* and then serializing the values from the original input iterator.
*/
def finishWritingToStream(os: OutputStream): Unit = {
// `unrolled`'s underlying buffers will be freed once this input stream is fully read:
ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
redirectableOutputStream.setOutputStream(os)
while (rest.hasNext) {
serializationStream.writeObject(rest.next())(classTag)
}
serializationStream.close()
}
It calls writeObject() to write the remaining values to the stream.
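The redirectableOutputStream seen above is what lets the same serialization stream keep writing while its sink is switched from the in-memory buffer to the file. A minimal sketch of that idea (hypothetical, modeled on Spark's RedirectableOutputStream):

import java.io.OutputStream

// Delegates all writes to a swappable underlying sink.
class RedirectableOutputStream extends OutputStream {
  private var sink: OutputStream = _
  def setOutputStream(os: OutputStream): Unit = { sink = os }
  override def write(b: Int): Unit = sink.write(b)
  override def write(b: Array[Byte], off: Int, len: Int): Unit = sink.write(b, off, len)
  override def flush(): Unit = sink.flush()
  override def close(): Unit = sink.close()
}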