Java8 Stream 惰性求值实现分析
Stream 类的继承关系和职责
在上一篇博客中,我主要给大家介绍了Stream 的并行计算的实现原理。这一次,我主要讲讲Stream 惰性求值的实现方式。这次我主要介绍一下Stream 惰性求值的实现方式。下图是Stream package里面重要的类继承关系:
因为由于Java 的泛型不能作用于原始类型,所以为了提高性能,避免自动装箱,Java 为了int、long、double 专门构造了属于自己的Stream 类型,而对于一般的对象类型定义了Stream 和ReferencePipeline。
- BaseStream 中规定了流的基本接口,比如返回迭代器、返回分片迭代器、判断流是否是并行还是串行、流是否是有序的(注意,不是排序的)
- Stream 中规定了map、filter、flatMap 等经典的函数式接口,这里的接口方法是Java 用户最主要关注的
- PipelineHelper 中主要是流水线的执行帮助方法,这里是Stream 内部实现的主要依赖方法,虽然在Java 编程中我们不需要关心这部分内容,但是在探索内部实现中,这部分是重中之重
三种Stream 类型(执行阶段)
- Head: Stream源,可以是集合或者自定义的Supplier 生成的有限长度或无限长度的流。一般由Stream 的工厂类StreamSupport.stream 接口,
- 可以由splititerator 构建有限的流
- 由supplier 根据规定方法构造无限的流
- StatelessOp/StatefulOp: 中间数据处理操作:
- StatelessOp: map/ flatMap/ filter 等无状态的操作
- StatefulOp: distinct/ slice/ sorted 等有状态的操作
- TerminalOp: find/ foreach/ match/ reduce 等终止操作,当遇到这种操作的时候Stream 会立刻进行求值运算
Stream工厂类主要接口
public static <T> Stream<T> stream(Spliterator<T> spliterator, boolean parallel);
public static <T> Stream<T> stream(Supplier<? extends Spliterator<T>> supplier,
int characteristics,
boolean parallel);
我们从这里可以看出这些Stream 很类似于Scala 中的Stream 和Spark 中的RDD (当然啦,因为是Spark 也是从Scala 接口中借鉴了很多)。
其中Head 类似于Spark 中的读取文件等的源RDD,比如由sparkContext 的textfile, parallelize 生产的RDD。StatelessOp 类似于spark 中的MapPartitionsRDD ,StatefulOp 接近于spark 中的ShuffleRDD 等。而TerminalOp 和Spark 中的collect、reduce等操作效果完全相同。
那我们很自然地想问几个问题:
- Java 中的惰性求值是怎么实现的?
- Java 中是如何实现Spark 中DAGScheduler 的操作的?
1. 惰性求值的思路
对于Java 惰性求值是什么思路,其实这个跟Spark RDD 处理方式完全一样,我们可以观察到Stream 中我们执行各种函数接口调用,实际上只是包装成一个新的Stream,然后返回给调用端,直到执行collect/ reduce等操作的时候调用真正的求值evaluate 操作。我们可以如下图比较以下两者的逻辑:
Stream的实现
public final <R> Stream<R> map(Function<? super P_OUT, ? extends R> mapper) {
Objects.requireNonNull(mapper);
return new StatelessOp<P_OUT, R>(this, StreamShape.REFERENCE, StreamOpFlag.NOT_SORTED | StreamOpFlag.NOT_DISTINCT) {
@Override
Sink<P_OUT> opWrapSink(int flags, Sink<R> sink) {
return new Sink.ChainedReference<P_OUT, R>(sink) {
@Override
public void accept(P_OUT u) { downstream.accept(mapper.apply(u)); }
};
}
};
}
public final Stream<P_OUT> sorted() {
return SortedOps.makeRef(this);
}
public final <R> R collect(Supplier<R> supplier, BiConsumer<R, ? super P_OUT> accumulator,
BiConsumer<R, R> combiner) {
return evaluate(ReduceOps.makeRef(supplier, accumulator, combiner));
}
其中我们看到在进行map 操作的时候,java并没有做任何的数据操作过程,而仅仅是返回了一个StatelessOp。而sorted 实际上调用了SortedOps 的makeRef,这里实际上也只是返回了一个覆盖一些具体运算方法StatefulOp。而真正会进行求值的collect 方法,实际上调用了evaluate 方法。
Stream惰性求值实现
def map[U: ClassTag](f: T => U): RDD[U] = withScope {
val cleanF = sc.clean(f)
new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
}
def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
: RDD[(K, V)] = self.withScope {
val part = new RangePartitioner(numPartitions, self, ascending)
new ShuffledRDD[K, V, V](self, part)
.setKeyOrdering(if (ascending) ordering else ordering.reverse)
}
def collect(): Array[T] = withScope {
val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
Array.concat(results: _*)
}
对照Spark 的相应方法的实现,我们看到map sortByKey 也都仅仅是返回相应的RDD,而没有执行操作,而collect 中调用了SparkContext 中的runJob 方法,类似于Stream 中的evaluate 方法。
2. 任务依赖关系和并行计算
在Stream 的计算中有很多运算天然支持的并行运算的,这种操作我们可以称之为一种无状态操作,map/ filter/ flatMap 等等。比如下图,输入分成任意个分块,同时执行这几组操作完全没有问题。
简单并行的操作
Arrays.asList(2,4,5,6).stream()
.parallelize()
.map(_+1)
.filter(_>5)
.flatMap(x=>List(x, x+1))
而有些操作却必须依赖,前一步在整个数据集合全局完成了之后才能进行下去,不能等待某个分区上有序了就直接进行下一步操作。我们可以考虑下一段代码,如果在输入数据实现每个分片上并行执行会发生什么?
Arrays.asList(-1, 2, 3, 4)
.stream()
.map(x -> -x)
.sorted()
.map(x -> x > 0);
为了解答这个问题,我们首先还是要回顾一下上一次提到过的,stream 求值的evaluate 的实现逻辑:
final <R> R evaluate(TerminalOp<E_OUT, R> terminalOp) {
assert getOutputShape() == terminalOp.inputShape();
if (linkedOrConsumed)
throw new IllegalStateException(MSG_STREAM_LINKED);
linkedOrConsumed = true;
return isParallel()
? terminalOp.evaluateParallel(this, sourceSpliterator(terminalOp.getOpFlags()))
: terminalOp.evaluateSequential(this, sourceSpliterator(terminalOp.getOpFlags()));
}
这里我们看到evaluate 过程中,首先会调用sourceSpliterator 方法获取源分区迭代器。如下图的代码所示:
- 在串行的条件下这里只是获取了Stream Source 的迭代器(一个迭代器或者一个supplier),然后标志了一下流水线的最终状态
- 在并行条件下就变得复杂多了,如果流水线中包含有任意一个Stateful 的操作的话,它就会依次遍历所有的stage,当遇到Stateful stage 的时候
- 去除短路的flag,保证并行的时候能够正确获得最终结果
- 对Stateful 的stage 进行求值,这里尽量会以惰性的方式(如上文中进行splititerator 的包装,而非直接进行计算,如skip/ limit 等slice操作),但对于sorted 这种必须全局的结果,就要直接进行求值。如66、67行所示,进行了并行collect 数据和并行排序
- 将Stateful 的结果作为源迭代器,传递给下一步操作,进行当前的操作的求值
sourceSpliterator实现方式
// 获取Pipeline最源头的split迭代器
// 如果是并行运算的话,会把Stateful的运算尽量以惰性的方式求值
// 然后以求值的结果当做源split 迭代器
private Spliterator<?> sourceSpliterator(int terminalFlags) {
Spliterator<?> spliterator = null;
if (sourceStage.sourceSpliterator != null) {
spliterator = sourceStage.sourceSpliterator;
sourceStage.sourceSpliterator = null;
}
else if (sourceStage.sourceSupplier != null) {
spliterator = (Spliterator<?>) sourceStage.sourceSupplier.get();
sourceStage.sourceSupplier = null;
}
else { throw new IllegalStateException(MSG_CONSUMED); }
if (isParallel() && sourceStage.sourceAnyStateful) {
// 如果是并行计算的流,且流的stage中包含了stateful 的操作
int depth = 1;
for (AbstractPipeline u = sourceStage, p = sourceStage.nextStage, e = this;
u != e; u = p, p = p.nextStage) {
int thisOpFlags = p.sourceOrOpFlags;
if (p.opIsStateful()) {
// 如果是stateful 操作
depth = 0;
if (StreamOpFlag.SHORT_CIRCUIT.isKnown(thisOpFlags)) {
// 如果当前的操作有短路的flag,则去除该flag
thisOpFlags = thisOpFlags & ~StreamOpFlag.IS_SHORT_CIRCUIT;
}
// 尽量以惰性的方式获得p 的结果splititerator,如果
// 遇到必须求值的情况,会及时进行并行求值
spliterator = p.opEvaluateParallelLazy(u, spliterator);
// 根据splititerator 标记流是有限的还是无限的
thisOpFlags = spliterator.hasCharacteristics(Spliterator.SIZED)
? (thisOpFlags & ~StreamOpFlag.NOT_SIZED) | StreamOpFlag.IS_SIZED
: (thisOpFlags & ~StreamOpFlag.IS_SIZED) | StreamOpFlag.NOT_SIZED;
}
p.depth = depth++;
p.combinedFlags = StreamOpFlag.combineOpFlags(thisOpFlags, u.combinedFlags);
}
}
if (terminalFlags != 0) {
combinedFlags = StreamOpFlag.combineOpFlags(terminalFlags, combinedFlags);
}
return spliterator;
}
// 惰性并行求值
<P_IN> Spliterator<E_OUT> opEvaluateParallelLazy(PipelineHelper<E_OUT> helper, Spliterator<P_IN> spliterator) {
return opEvaluateParallel(helper, spliterator, i -> (E_OUT[]) new Object[i]).spliterator();
}
// 排序操作的并行实现方法
public <P_IN> Node<T> opEvaluateParallel(PipelineHelper<T> helper, Spliterator<P_IN> spliterator, IntFunction<T[]> generator) {
// 如果上游的流已经是有序的,那么执行上游的evaluate方法,collect结果
if (StreamOpFlag.SORTED.isKnown(helper.getStreamAndOpFlags()) && isNaturalSort) {
return helper.evaluate(spliterator, false, generator);
}
else {
// 如果上游的流还没有排序完成,那么就先collect数据,而后进行并行排序
T[] flattenedData = helper.evaluate(spliterator, true, generator).asArray(generator);
Arrays.parallelSort(flattenedData, comparator);
return Nodes.node(flattenedData);
}
}
而进入了并行运算之后,就如我们上一篇博客中的方式调用ForkJoin 框架进行了任务的拆分和调度。所以我们可以最终得出结论
- Java Stream 中对于所有的操作在调用的时候仅仅是返回一个流的对象,并不求值
- 在获取最终结果的时候会进行求值,如果当前是并行求值的话,Java 会在计算流的链条上,找出所有的Stateful 操作,完成所有的求值工作后,作为数据源传递给下一个操作。
- 串行的时候调度的处理方式略有不同,这个大家可以参看Sink.ChainedReference 类,来观察是如何完成调度依赖的。但总体上比并行的处理方式要简单很多。
和Spark 比较呢,其实我觉得思想上它们是高度一致的,大家看过之后应该会有类似的感受吧~~^_^
后话:
我为什么想讲讲这个呢,因为现在一直在做Spark 相关的工作,Spark Core 中经典的RDD 模型的一个重要特点就是基于函数式编程以及惰性求值。在执行map/ filter/ flatMap 的时候往往并不是直接进行运算,而是记录下来执行计划,在真正需要结果的时候将数据拉取、扫描、变形、过滤等操作一次性完成。这样的好处在于,我们可以在不真正触碰数据的情况下,将数据处理逻辑单元(函数)进行高层次的规划、组合。
现在Java8 的Stream 接口也提供了类似的接口,对我们处理数据带来了非常大的便利性。多理解几种这样的实现方式,多多比较它们实现的异同,对我们理解并设计自己的程序有着非常大的指导作用。