问题背景

在项目中,有一个DataFrame,名为df1,它有40亿行数据一共有100列。其中有一列,名为ids,类型为Seq[Long],最大长度为900另外每一行数据都可以由unique_id这一列唯一标识。现在有另一个DataFrame,名为df2,它可能有多行数据。其中有两列,第一列列名为flag,类型为String;第二列列名为targeting_ids,类型为Seq[Long],长度最大为4w。现在的需求是:在df1中新增一列flags,类型是Seq[String],如果ids和targeting_ids有交集,则在df1中的flags值中添加targeting_ids对应的flag。它们分别长这样:

df1:
+---------+---------+
|unique_id|ids      |
+---------+---------+
|1        |[1, 2, 3]|
|2        |[5, 6, 7]|
|3        |[2, 5]   |
|4        |[8]      |
+---------+---------+

df2:
+-----+-------------+
|flag |targeting_ids|
+-----+-------------+
|flag1|[2, 3, 4]    |
|flag2|[5]          |
+-----+-------------+

项目环境

Spark 3.0.1

解决方案一

通过从抽取df1抽取unique_id和ids字段与df2作full join,将结果按照unique_id字段分组,对flag作collect_list聚合,最后将该结果与df1做join,即可得到正确结果。代码如下:

// solution1
    val tmpDF = df1.select("unique_id", "ids").join(broadcast(df2)) // full join
      .withColumn("flag",
        when(size(array_intersect($"ids", $"targeting_ids")) > 0, $"flag")
          .otherwise(lit(""))
      ).filter("flag != ''")
      .groupBy("unique_id").agg(collect_list("flag") as "flags")
    val df3 = df1.join(broadcast(tmpDF), Seq("unique_id"), "left")
    df3.show(false)

结果如下:

+---------+---------+--------------+
|unique_id|ids      |flags         |
+---------+---------+--------------+
|1        |[1, 2, 3]|[flag1]       |
|2        |[5, 6, 7]|[flag2]       |
|3        |[2, 5]   |[flag1, flag2]|
|4        |[8]      |null          |
+---------+---------+--------------+

上述方案,由于会用到groupBy这个shuffle算子,因此对于原始40亿数据来说并不是很友好。现在尝试去掉shuffle。

解决方案二

利用自定义UDF以及广播变量,即可避免shuffle。我们可以通过将df2的信息转成一个Map(flag -> targeting_ids),并广播出去。之后在我们定义的UDF中使用该map,得到flags。代码如下:

// solution2
    val targetingIds2FlagMapBC = spark.sparkContext.broadcast(
      df2.map(
        r => {
          r.getAs[String]("flag") -> r.getAs[Seq[Long]]("targeting_ids").toSet
        }
      ).rdd.collectAsMap()
    )
    val generateFlags = udf((ids: Seq[Long]) => {
      val flags = new ListBuffer[String]()
      targetingIds2FlagMapBC.value.foreach(
        item => {
          if (ids.toSet.intersect(item._2).nonEmpty) {
            flags.append(item._1)
          }
        }
      )
      flags
    })
    val df4 = df1.withColumn("flags", generateFlags($"ids"))
    df4.show(false)

结果如下:

+---------+---------+--------------+
|unique_id|ids      |flags         |
+---------+---------+--------------+
|1        |[1, 2, 3]|[flag1]       |
|2        |[5, 6, 7]|[flag2]       |
|3        |[2, 5]   |[flag1, flag2]|
|4        |[8]      |[]            |
+---------+---------+--------------+

该方案成功地避免了shuffle,但是对于40亿大数据来说,还是不行。其原因如下:在方案一的array_intersect以及方案二的intersect都必须至少完全遍历一个对象才能返回结果,而我们实际需要的是如果有一个id既在ids中又在targeting ids中,我们完全可以提前结束遍历,也不需要返回列表。

解决方案三

从上面的思路分析中,我们可以想到我们可以使用break:当找到一个元素后,跳出当前遍历。但是Scala中不支持break关键字,我们可以使用抛出并捕获异常的方法快速进入下一片代码,跳出当前遍历。代码如下:

// solution3
    val generateFlagsForSolution3 = udf((ids: Seq[Long]) => {
      val flags = new ListBuffer[String]()
      targetingIds2FlagMapBC.value.foreach(
        item =>{
          try {
            ids.foreach(
              id => {
                if (item._2.contains(id)) {
                  throw new Exception(s"match: ${id}")
                }
              }
            )
          } catch {
            case e: Exception => {
              // println(s"get exception: ${e.getMessage}") // 2. can't use println, because multi thread will compete it(IO)
              flags.append(item._1)
            }
          }
        }
      )
      flags
    })
    val df5 = df1.withColumn("flags", generateFlagsForSolution3($"ids"))
    df5.show(false)

结果如下:

+---------+---------+--------------+
|unique_id|ids      |flags         |
+---------+---------+--------------+
|1        |[1, 2, 3]|[flag1, flag1]|
|2        |[5, 6, 7]|[flag2]       |
|3        |[2, 5]   |[flag1, flag2]|
|4        |[8]      |[]            |
+---------+---------+--------------+

需要注意的是,在代码里我解释两个注释:

  1. use ids.foreach not item._2.foreach。 这是因为item._2的size(4w)比ids的size(最大800)大,因此对较小的list进行遍历可以加快遍历速度。
  2. can’t use println, because multi thread will compete it(IO)。这是因为我们的数据是40亿条,如果每条都需要println,那各个线程很有可能去竞争printStream对象从而导致线程阻塞。而实际上println这个方法也是阻塞调用的。
// java.io.PrintStream
public void println(Object x) {
        String s = String.valueOf(x);
        synchronized (this) {
            print(s);
            newLine();
        }
    }

完整代码

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import scala.collection.mutable.ListBuffer

object OptimizeIntersect {
  def main(args: Array[String]): Unit = {
    val spark = new SparkSession.Builder().appName("OptimizeIntersect")
      .master("local[*]")
      .config("spark.driver.host", "127.0.0.1")
      .getOrCreate()
    import spark.implicits._
    val df1 = spark.sparkContext.parallelize(
      Seq(
        (1, Seq(1L, 2L, 3L)),
        (2, Seq(5L, 6L, 7L)),
        (3, Seq(2L, 5L)),
        (4, Seq(8L)),
      )
    ).toDF("unique_id", "ids")
    println("df1:")
    df1.show(false)
    val df2 = spark.sparkContext.parallelize(
      Seq(
        ("flag1", Seq(2L, 3L, 4L)),
        ("flag2", Seq(5L))
      )
    ).toDF("flag", "targeting_ids")
    println("df2:")
    df2.show(false)

    // solution1
    val tmpDF = df1.select("unique_id", "ids").join(broadcast(df2)) // full join
      .withColumn("flag",
        when(size(array_intersect($"ids", $"targeting_ids")) > 0, $"flag")
          .otherwise(lit(""))
      ).filter("flag != ''")
      .groupBy("unique_id").agg(collect_list("flag") as "flags")
    val df3 = df1.join(broadcast(tmpDF), Seq("unique_id"), "left")
//    df3.show(false)

    // solution2
    val targetingIds2FlagMapBC = spark.sparkContext.broadcast(
      df2.map(
        r => {
          r.getAs[String]("flag") -> r.getAs[Seq[Long]]("targeting_ids").toSet
        }
      ).rdd.collectAsMap()
    )
    val generateFlags = udf((ids: Seq[Long]) => {
      val flags = new ListBuffer[String]()
      targetingIds2FlagMapBC.value.foreach(
        item => {
          if (ids.toSet.intersect(item._2).nonEmpty) {
            flags.append(item._1)
          }
        }
      )
      flags
    })
    val df4 = df1.withColumn("flags", generateFlags($"ids"))
//    df4.show(false)

    // solution3
    val generateFlagsForSolution3 = udf((ids: Seq[Long]) => {
      val flags = new ListBuffer[String]()
      targetingIds2FlagMapBC.value.foreach(
        item =>{
          try {
            ids.foreach(
              id => {
                if (item._2.contains(id)) {
                  throw new Exception(s"match: ${id}")
                }
              }
            )
          } catch {
            case e: Exception => {
              // println(s"get exception: ${e.getMessage}") // 2. can't use println, because multi thread will compete it(IO)
              flags.append(item._1)
            }
          }
        }
      )
      flags
    })
    val df5 = df1.withColumn("flags", generateFlagsForSolution3($"ids"))
    df5.show(false)
  }
}

总结

  1. 有时可以使用广播变量避免shuffle。
  2. 使用抛出并且捕获异常的方式可以实现break功能。
  3. 当数据量很大时,应该避免对每天数据都做一次打印(println)或者log(logger.info)之类会阻塞线程的事情。
  4. 应该先分析清楚自己想要的是什么,再去想最好的解决方案。