Background

PySpark Performance Enhancements: [SPARK-22216][SPARK-21187] Significant improvements in python performance and interoperability by fast data serialization and vectorized execution.

SPARK-22216: mainly implements vectorized pandas UDF processing and resolves related pandas/Arrow issues;
SPARK-21187: an issue that is still not fully resolved as of this writing; the Arrow conversion still does not support BinaryType, MapType, ArrayType of TimestampType, and nested StructType.

This issue is fairly involved, touching pyspark, Arrow, pandas, and Spark SQL, and each component is worth studying on its own. This post mainly walks through how PySpark runs; follow-up posts will gradually analyze which optimizations the issue applies to improve performance.
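
As a concrete illustration of what SPARK-22216 enables, here is a minimal sketch of a scalar pandas UDF, assuming Spark 2.3+ with pandas and pyarrow installed (the function plus_one and the column v are just for illustration):

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, PandasUDFType

spark = SparkSession.builder.appName("pandas-udf-demo").getOrCreate()

# A scalar pandas UDF receives and returns a pandas.Series; batches are shipped
# between the JVM and the Python worker as Arrow record batches.
@pandas_udf("double", PandasUDFType.SCALAR)
def plus_one(v):
    return v + 1

df = spark.range(0, 100).withColumn("v", col("id").cast("double"))
df.withColumn("v_plus_one", plus_one(col("v"))).show(5)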

Overview

Driver side

  • When the user instantiates a Python SparkContext in PySpark, a Scala SparkContext is ultimately instantiated in the JVM;

Executor side

  • Py4j is not needed here, because the Task logic that runs on the Executor is sent over by the Driver as serialized bytecode.
  • That logic may contain user-defined Python functions or lambda expressions, and Py4j cannot call Python methods from Java. To run user-defined Python functions or lambdas on the Executor, a separate Python process is started for each Task, and the function or lambda is sent to that process over a socket for execution (see the sketch below).
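
A minimal sketch of this flow, using nothing beyond the public PySpark API: the lambda below is pickled on the Driver, shipped with the task, and executed by a Python worker process on each Executor.

from pyspark import SparkContext

sc = SparkContext(appName="python-worker-demo")   # also creates the JVM SparkContext via Py4j
rdd = sc.parallelize(range(10))
# The lambda is serialized by the driver and run inside a Python worker process
# on the executor; results come back to the JVM over a local socket.
print(rdd.map(lambda x: x * x).collect())
sc.stop()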

After a Python application is submitted through SparkSubmit, you can see which main class it ultimately runs:

    // If we're running a python app, set the main class to our specific python runner
    if (args.isPython && deployMode == CLIENT) {
      if (args.primaryResource == PYSPARK_SHELL) {
        args.mainClass = "org.apache.spark.api.python.PythonGatewayServer"
      } else {
        // If a python file is provided, add it to the child arguments and list of files to deploy.
        // Usage: PythonAppRunner <main python file> <extra python files> [app arguments]
        args.mainClass = "org.apache.spark.deploy.PythonRunner"
        args.childArgs = ArrayBuffer(localPrimaryResource, localPyFiles) ++ args.childArgs
        if (clusterManager != YARN) {
          // The YARN backend distributes the primary file differently, so don't merge it.
          args.files = mergeFileLists(args.files, args.primaryResource)
        }
      }
      if (clusterManager != YARN) {
        // The YARN backend handles python files differently, so don't merge the lists.
        args.files = mergeFileLists(args.files, args.pyFiles)
      }
      if (localPyFiles != null) {
        sparkConf.set("spark.submit.pyFiles", localPyFiles)
      }
    }

Driver side

Once the user's Python script is up, it first instantiates the Python SparkContext, and during that instantiation two things happen:

  • A Py4j GatewayClient is instantiated and connects to the Py4j GatewayServer in the JVM; all subsequent calls from Python into Java go through this Py4j gateway
  • The SparkContext object is instantiated in the JVM through the Py4j gateway

After these two steps, the SparkContext is fully initialized and the Driver is up: it starts requesting Executor resources and scheduling tasks. The chain of processing logic defined in the user's Python script triggers a job submission once an action is invoked. The job is submitted by calling the Java method PythonRDD.runJob directly through Py4j, which in the JVM delegates to sparkContext.runJob. When the job finishes, the JVM opens a local socket and waits for the Python process to pull the results; correspondingly, after calling PythonRDD.runJob the Python process fetches the results over that socket.
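
The Py4j bridge described above is exposed on the Python SparkContext as the internal _gateway and _jvm attributes. The following sketch shows a call going from Python into the JVM through that gateway (it relies on internal attributes, so treat it as illustration rather than public API):

from pyspark import SparkContext

sc = SparkContext(appName="py4j-gateway-demo")
# sc._gateway is the Py4j JavaGateway and sc._jvm its JVMView; any JVM class can
# be reached through them, e.g. java.lang.System:
print(sc._jvm.java.lang.System.getProperty("java.version"))
sc.stop()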

PythonRunner source code analysis

The main function of PythonRunner, the entry point, does two main things:

  • Start a Py4j GatewayServer
  • Run the user-submitted Python script in a child process started from Java

The relevant source, with inline comments, follows:

object PythonRunner {
  def main(args: Array[String]) {
    val pythonFile = args(0)
    val pyFiles = args(1)
    val otherArgs = args.slice(2, args.length)
    val sparkConf = new SparkConf()
    val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON)
      .orElse(sparkConf.get(PYSPARK_PYTHON))
      .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON"))
      .orElse(sys.env.get("PYSPARK_PYTHON"))
      .getOrElse("python")

    // Format python file paths before adding them to the PYTHONPATH
    val formattedPythonFile = formatPath(pythonFile)
    val formattedPyFiles = formatPaths(pyFiles)

    // Start the Py4j gateway server in a separate daemon thread so that the user's Python process can call into this JVM.
    // Launch a Py4J gateway server for the process to connect to; this will let it see our
    // Java system properties and such
    val gatewayServer = new py4j.GatewayServer(null, 0)
    val thread = new Thread(new Runnable() {
      override def run(): Unit = Utils.logUncaughtExceptions {
        gatewayServer.start()
      }
    })
    thread.setName("py4j-gateway-init")
    thread.setDaemon(true)
    thread.start()

    // Wait until the gateway server has started, so that we know which port it is bound to.
    // `gatewayServer.start()` will start a new thread and run the server code there, after
    // initializing the socket, so the thread started above will end as soon as the server is
    // ready to serve connections.
    // Note: we must wait here until the gateway server has finished starting before we can know which port it is bound to.
    thread.join()

    // Build up a PYTHONPATH that includes the Spark assembly (where this class is), the
    // python directories in SPARK_HOME (if set), and any files in the pyFiles argument
    val pathElements = new ArrayBuffer[String]
    pathElements ++= formattedPyFiles
    pathElements += PythonUtils.sparkPythonPath
    pathElements += sys.env.getOrElse("PYTHONPATH", "")
    val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*)

    // Launch Python process
    // This builds a process that runs the python executable on the user-submitted Python file.
    val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava)
    val env = builder.environment()
    env.put("PYTHONPATH", pythonPath)
    // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
    env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
    env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
    // pass conf spark.pyspark.python to python process, the only way to pass info to
    // python process is through environment variable.
    sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
    sys.env.get("PYTHONHASHSEED").foreach(env.put("PYTHONHASHSEED", _))
    builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
    try {
      val process = builder.start()

      new RedirectThread(process.getInputStream, System.out, "redirect output").start()

      val exitCode = process.waitFor()
      if (exitCode != 0) {
        throw new SparkUserAppException(exitCode)
      }
    } finally {
      gatewayServer.shutdown()
    }
  }
}
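
On the Python side, pyspark's java_gateway module reads the PYSPARK_GATEWAY_PORT environment variable set above and connects back to this server. The following is a simplified sketch of that handshake using plain py4j; the real module also handles launching spark-submit itself when no port is set and, in later Spark versions, an authentication secret.

import os
from py4j.java_gateway import JavaGateway, GatewayParameters

# PYSPARK_GATEWAY_PORT is exported by PythonRunner before the Python process starts.
port = int(os.environ["PYSPARK_GATEWAY_PORT"])
gateway = JavaGateway(gateway_parameters=GatewayParameters(port=port, auto_convert=True))

# Once connected, JVM classes are reachable through gateway.jvm:
print(gateway.jvm.java.lang.System.getProperty("java.version"))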

Executor side

For the details, look at the pyspark module in the Spark source tree. The gist is: when a program written against the PySpark Python API creates the (Python) SparkContext, it initializes a _gateway variable (a JavaGateway object) and a _jvm variable (a JVMView object), on top of which the Python wrappers of the Spark operators are built.

Also note that the Spark packages imported into the gateway below can be used directly from PySpark (a short example follows the snippet):

    # Import the classes used by PySpark
    java_import(gateway.jvm, "org.apache.spark.SparkConf")
    java_import(gateway.jvm, "org.apache.spark.api.java.*")
    java_import(gateway.jvm, "org.apache.spark.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
    # TODO(davies): move into sql
    java_import(gateway.jvm, "org.apache.spark.sql.*")
    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
    java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
    java_import(gateway.jvm, "scala.Tuple2")

Usage notes

If you want to use Arrow and pandas, note that Spark does not bundle these packages by default (the py4j dependency is shipped with Spark), so pyarrow and pandas must be installed separately. The version requirements are:

extras_require={
    'ml': ['numpy>=1.7'],
    'mllib': ['numpy>=1.7'],
    'sql': [
        'pandas>=%s' % _minimum_pandas_version,
        'pyarrow>=%s' % _minimum_pyarrow_version,
    ]
},

_minimum_pandas_version = "0.19.2"
_minimum_pyarrow_version = "0.8.0"
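
A quick local sanity check (not part of Spark itself) to confirm that the installed versions meet these minimums:

import pandas as pd
import pyarrow as pa

print("pandas:", pd.__version__)    # needs >= 0.19.2 for the pyspark 'sql' extra
print("pyarrow:", pa.__version__)   # needs >= 0.8.0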

When calling df.toPandas(), see https://issues.apache.org/jira/browse/SPARK-13534, which uses Arrow to speed up the conversion to a pandas DataFrame. The relevant configuration flag is defined as follows:

val ARROW_EXECUTION_ENABLE =
    buildConf("spark.sql.execution.arrow.enabled")
      .doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " +
        "for use with pyspark.sql.DataFrame.toPandas, and " +
        "pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " +
        "The following data types are unsupported: " +
        "BinaryType, MapType, ArrayType of TimestampType, and nested StructType.")
      .booleanConf
      .createWithDefault(false)
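
A short sketch of enabling this flag from PySpark before calling toPandas(), assuming pandas and pyarrow are installed:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("arrow-topandas-demo").getOrCreate()
# Enable Arrow-based columnar transfer for toPandas() and for
# createDataFrame() when the input is a pandas DataFrame.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

pdf = spark.range(0, 1000).toPandas()   # rows are transferred as Arrow record batches
print(pdf.head())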