Java8 Stream 惰性求值实现分析

Stream 类的继承关系和职责

在上一篇博客中,我主要给大家介绍了Stream 的并行计算的实现原理。这一次,我主要讲讲Stream 惰性求值的实现方式。这次我主要介绍一下Stream 惰性求值的实现方式。下图是Stream package里面重要的类继承关系:

 

因为由于Java 的泛型不能作用于原始类型,所以为了提高性能,避免自动装箱,Java 为了int、long、double 专门构造了属于自己的Stream 类型,而对于一般的对象类型定义了Stream 和ReferencePipeline。

  • BaseStream 中规定了流的基本接口,比如返回迭代器、返回分片迭代器、判断流是否是并行还是串行、流是否是有序的(注意,不是排序的)
  • Stream 中规定了map、filter、flatMap 等经典的函数式接口,这里的接口方法是Java 用户最主要关注的
  • PipelineHelper 中主要是流水线的执行帮助方法,这里是Stream 内部实现的主要依赖方法,虽然在Java 编程中我们不需要关心这部分内容,但是在探索内部实现中,这部分是重中之重

三种Stream 类型(执行阶段)

  1. Head: Stream源,可以是集合或者自定义的Supplier 生成的有限长度或无限长度的流。一般由Stream 的工厂类StreamSupport.stream 接口,
  1. 可以由splititerator 构建有限的流
  2. 由supplier 根据规定方法构造无限的流
  1. StatelessOp/StatefulOp: 中间数据处理操作:
  1. StatelessOp: map/ flatMap/ filter 等无状态的操作
  2. StatefulOp: distinct/ slice/ sorted 等有状态的操作
  1. TerminalOp: find/ foreach/ match/ reduce 等终止操作,当遇到这种操作的时候Stream 会立刻进行求值运算

  Stream工厂类主要接口

public static <T> Stream<T> stream(Spliterator<T> spliterator, boolean parallel);
public static <T> Stream<T> stream(Supplier<? extends Spliterator<T>> supplier,
	                          int characteristics,
	                          boolean parallel);

  

我们从这里可以看出这些Stream 很类似于Scala 中的Stream 和Spark 中的RDD (当然啦,因为是Spark 也是从Scala 接口中借鉴了很多)。

其中Head 类似于Spark 中的读取文件等的源RDD,比如由sparkContext 的textfile, parallelize 生产的RDD。StatelessOp 类似于spark 中的MapPartitionsRDD ,StatefulOp 接近于spark 中的ShuffleRDD 等。而TerminalOp 和Spark 中的collect、reduce等操作效果完全相同。

那我们很自然地想问几个问题:

  1. Java 中的惰性求值是怎么实现的?
  2. Java 中是如何实现Spark 中DAGScheduler 的操作的?

1. 惰性求值的思路

对于Java 惰性求值是什么思路,其实这个跟Spark RDD 处理方式完全一样,我们可以观察到Stream 中我们执行各种函数接口调用,实际上只是包装成一个新的Stream,然后返回给调用端,直到执行collect/ reduce等操作的时候调用真正的求值evaluate 操作。我们可以如下图比较以下两者的逻辑:

Stream的实现

public final <R> Stream<R> map(Function<? super P_OUT, ? extends R> mapper) {
  Objects.requireNonNull(mapper); 
  return new StatelessOp<P_OUT, R>(this, StreamShape.REFERENCE, StreamOpFlag.NOT_SORTED | StreamOpFlag.NOT_DISTINCT) {
    @Override
    Sink<P_OUT> opWrapSink(int flags, Sink<R> sink) {
      return new Sink.ChainedReference<P_OUT, R>(sink) {
        @Override
        public void accept(P_OUT u) { downstream.accept(mapper.apply(u)); }
      };
    }
  };
}

public final Stream<P_OUT> sorted() {
  return SortedOps.makeRef(this);
}

public final <R> R collect(Supplier<R> supplier, BiConsumer<R, ? super P_OUT> accumulator,
               BiConsumer<R, R> combiner) {
  return evaluate(ReduceOps.makeRef(supplier, accumulator, combiner));
}

  其中我们看到在进行map 操作的时候,java并没有做任何的数据操作过程,而仅仅是返回了一个StatelessOp。而sorted 实际上调用了SortedOps 的makeRef,这里实际上也只是返回了一个覆盖一些具体运算方法StatefulOp。而真正会进行求值的collect 方法,实际上调用了evaluate 方法。

Stream惰性求值实现

def map[U: ClassTag](f: T => U): RDD[U] = withScope {
  val cleanF = sc.clean(f)
  new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
}

def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
    : RDD[(K, V)] = self.withScope {
  val part = new RangePartitioner(numPartitions, self, ascending)
  new ShuffledRDD[K, V, V](self, part)
    .setKeyOrdering(if (ascending) ordering else ordering.reverse)
}

def collect(): Array[T] = withScope {
  val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
  Array.concat(results: _*)
}

  

对照Spark 的相应方法的实现,我们看到map sortByKey 也都仅仅是返回相应的RDD,而没有执行操作,而collect 中调用了SparkContext 中的runJob 方法,类似于Stream 中的evaluate 方法。

2. 任务依赖关系和并行计算

在Stream 的计算中有很多运算天然支持的并行运算的,这种操作我们可以称之为一种无状态操作,map/ filter/ flatMap 等等。比如下图,输入分成任意个分块,同时执行这几组操作完全没有问题。

简单并行的操作

Arrays.asList(2,4,5,6).stream()
  .parallelize()
  .map(_+1)
  .filter(_>5)
  .flatMap(x=>List(x, x+1))

  而有些操作却必须依赖,前一步在整个数据集合全局完成了之后才能进行下去,不能等待某个分区上有序了就直接进行下一步操作。我们可以考虑下一段代码,如果在输入数据实现每个分片上并行执行会发生什么?

Arrays.asList(-1, 2, 3, 4)
        .stream()
        .map(x -> -x)
        .sorted()
        .map(x -> x > 0);

  为了解答这个问题,我们首先还是要回顾一下上一次提到过的,stream 求值的evaluate 的实现逻辑:

final <R> R evaluate(TerminalOp<E_OUT, R> terminalOp) {
    assert getOutputShape() == terminalOp.inputShape();
    if (linkedOrConsumed)
        throw new IllegalStateException(MSG_STREAM_LINKED);
    linkedOrConsumed = true;

    return isParallel()
           ? terminalOp.evaluateParallel(this, sourceSpliterator(terminalOp.getOpFlags()))
           : terminalOp.evaluateSequential(this, sourceSpliterator(terminalOp.getOpFlags()));
}

  

这里我们看到evaluate 过程中,首先会调用sourceSpliterator 方法获取源分区迭代器。如下图的代码所示:

  1. 在串行的条件下这里只是获取了Stream Source 的迭代器(一个迭代器或者一个supplier),然后标志了一下流水线的最终状态
  2. 在并行条件下就变得复杂多了,如果流水线中包含有任意一个Stateful 的操作的话,它就会依次遍历所有的stage,当遇到Stateful stage 的时候
  • 去除短路的flag,保证并行的时候能够正确获得最终结果
  • 对Stateful 的stage 进行求值,这里尽量会以惰性的方式(如上文中进行splititerator 的包装,而非直接进行计算,如skip/ limit 等slice操作),但对于sorted 这种必须全局的结果,就要直接进行求值。如66、67行所示,进行了并行collect 数据和并行排序
  • 将Stateful 的结果作为源迭代器,传递给下一步操作,进行当前的操作的求值

sourceSpliterator实现方式

// 获取Pipeline最源头的split迭代器
// 如果是并行运算的话,会把Stateful的运算尽量以惰性的方式求值
// 然后以求值的结果当做源split 迭代器
private Spliterator<?> sourceSpliterator(int terminalFlags) {
  Spliterator<?> spliterator = null;
  if (sourceStage.sourceSpliterator != null) { 
    spliterator = sourceStage.sourceSpliterator;
    sourceStage.sourceSpliterator = null;
  }
  else if (sourceStage.sourceSupplier != null) { 
    spliterator = (Spliterator<?>) sourceStage.sourceSupplier.get();
    sourceStage.sourceSupplier = null;
  }
  else { throw new IllegalStateException(MSG_CONSUMED); }

  if (isParallel() && sourceStage.sourceAnyStateful) {
    // 如果是并行计算的流,且流的stage中包含了stateful 的操作
    int depth = 1;
    for (AbstractPipeline u = sourceStage, p = sourceStage.nextStage, e = this;
       u != e; u = p, p = p.nextStage) {

      int thisOpFlags = p.sourceOrOpFlags;
      if (p.opIsStateful()) {
        // 如果是stateful 操作
        depth = 0;

        if (StreamOpFlag.SHORT_CIRCUIT.isKnown(thisOpFlags)) {
          // 如果当前的操作有短路的flag,则去除该flag
          thisOpFlags = thisOpFlags & ~StreamOpFlag.IS_SHORT_CIRCUIT;
        }

        // 尽量以惰性的方式获得p 的结果splititerator,如果
        // 遇到必须求值的情况,会及时进行并行求值
        spliterator = p.opEvaluateParallelLazy(u, spliterator);

        // 根据splititerator 标记流是有限的还是无限的
        thisOpFlags = spliterator.hasCharacteristics(Spliterator.SIZED)
          ? (thisOpFlags & ~StreamOpFlag.NOT_SIZED) | StreamOpFlag.IS_SIZED
          : (thisOpFlags & ~StreamOpFlag.IS_SIZED) | StreamOpFlag.NOT_SIZED;
      }
      p.depth = depth++;
      p.combinedFlags = StreamOpFlag.combineOpFlags(thisOpFlags, u.combinedFlags);
    }
  }

  if (terminalFlags != 0)  {
    combinedFlags = StreamOpFlag.combineOpFlags(terminalFlags, combinedFlags);
  }

  return spliterator;
}

// 惰性并行求值
<P_IN> Spliterator<E_OUT> opEvaluateParallelLazy(PipelineHelper<E_OUT> helper, Spliterator<P_IN> spliterator) {
  return opEvaluateParallel(helper, spliterator, i -> (E_OUT[]) new Object[i]).spliterator();
}

// 排序操作的并行实现方法
public <P_IN> Node<T> opEvaluateParallel(PipelineHelper<T> helper, Spliterator<P_IN> spliterator, IntFunction<T[]> generator) {
  // 如果上游的流已经是有序的,那么执行上游的evaluate方法,collect结果
  if (StreamOpFlag.SORTED.isKnown(helper.getStreamAndOpFlags()) && isNaturalSort) {
    return helper.evaluate(spliterator, false, generator);
  }
  else {
    // 如果上游的流还没有排序完成,那么就先collect数据,而后进行并行排序
    T[] flattenedData = helper.evaluate(spliterator, true, generator).asArray(generator);
    Arrays.parallelSort(flattenedData, comparator);
    return Nodes.node(flattenedData);
  }
}

  

而进入了并行运算之后,就如我们上一篇博客中的方式调用ForkJoin 框架进行了任务的拆分和调度。所以我们可以最终得出结论

  1. Java Stream 中对于所有的操作在调用的时候仅仅是返回一个流的对象,并不求值
  2. 在获取最终结果的时候会进行求值,如果当前是并行求值的话,Java 会在计算流的链条上,找出所有的Stateful 操作,完成所有的求值工作后,作为数据源传递给下一个操作。
  3. 串行的时候调度的处理方式略有不同,这个大家可以参看Sink.ChainedReference 类,来观察是如何完成调度依赖的。但总体上比并行的处理方式要简单很多。

和Spark 比较呢,其实我觉得思想上它们是高度一致的,大家看过之后应该会有类似的感受吧~~^_^

 

后话:

我为什么想讲讲这个呢,因为现在一直在做Spark 相关的工作,Spark Core 中经典的RDD 模型的一个重要特点就是基于函数式编程以及惰性求值。在执行map/ filter/ flatMap 的时候往往并不是直接进行运算,而是记录下来执行计划,在真正需要结果的时候将数据拉取、扫描、变形、过滤等操作一次性完成。这样的好处在于,我们可以在不真正触碰数据的情况下,将数据处理逻辑单元(函数)进行高层次的规划、组合。

现在Java8 的Stream 接口也提供了类似的接口,对我们处理数据带来了非常大的便利性。多理解几种这样的实现方式,多多比较它们实现的异同,对我们理解并设计自己的程序有着非常大的指导作用。