问题背景
在项目中,有一个DataFrame,名为df1,它有40亿行数据,一共有100列。其中有一列,名为ids,类型为Seq[Long],最大长度为900。另外每一行数据都可以由unique_id这一列唯一标识。现在有另一个DataFrame,名为df2,它可能有多行数据。其中有两列,第一列列名为flag,类型为String;第二列列名为targeting_ids,类型为Seq[Long],长度最大为4w。现在的需求是:在df1中新增一列flags,类型是Seq[String],如果ids和targeting_ids有交集,则在df1中的flags值中添加targeting_ids对应的flag。它们分别长这样:
df1:
+---------+---------+
|unique_id|ids |
+---------+---------+
|1 |[1, 2, 3]|
|2 |[5, 6, 7]|
|3 |[2, 5] |
|4 |[8] |
+---------+---------+
df2:
+-----+-------------+
|flag |targeting_ids|
+-----+-------------+
|flag1|[2, 3, 4] |
|flag2|[5] |
+-----+-------------+
项目环境
Spark 3.0.1
解决方案一
通过从抽取df1抽取unique_id和ids字段与df2作full join,将结果按照unique_id字段分组,对flag作collect_list聚合,最后将该结果与df1做join,即可得到正确结果。代码如下:
// solution1
val tmpDF = df1.select("unique_id", "ids").join(broadcast(df2)) // full join
.withColumn("flag",
when(size(array_intersect($"ids", $"targeting_ids")) > 0, $"flag")
.otherwise(lit(""))
).filter("flag != ''")
.groupBy("unique_id").agg(collect_list("flag") as "flags")
val df3 = df1.join(broadcast(tmpDF), Seq("unique_id"), "left")
df3.show(false)
结果如下:
+---------+---------+--------------+
|unique_id|ids |flags |
+---------+---------+--------------+
|1 |[1, 2, 3]|[flag1] |
|2 |[5, 6, 7]|[flag2] |
|3 |[2, 5] |[flag1, flag2]|
|4 |[8] |null |
+---------+---------+--------------+
上述方案,由于会用到groupBy这个shuffle算子,因此对于原始40亿数据来说并不是很友好。现在尝试去掉shuffle。
解决方案二
利用自定义UDF以及广播变量,即可避免shuffle。我们可以通过将df2的信息转成一个Map(flag -> targeting_ids),并广播出去。之后在我们定义的UDF中使用该map,得到flags。代码如下:
// solution2
val targetingIds2FlagMapBC = spark.sparkContext.broadcast(
df2.map(
r => {
r.getAs[String]("flag") -> r.getAs[Seq[Long]]("targeting_ids").toSet
}
).rdd.collectAsMap()
)
val generateFlags = udf((ids: Seq[Long]) => {
val flags = new ListBuffer[String]()
targetingIds2FlagMapBC.value.foreach(
item => {
if (ids.toSet.intersect(item._2).nonEmpty) {
flags.append(item._1)
}
}
)
flags
})
val df4 = df1.withColumn("flags", generateFlags($"ids"))
df4.show(false)
结果如下:
+---------+---------+--------------+
|unique_id|ids |flags |
+---------+---------+--------------+
|1 |[1, 2, 3]|[flag1] |
|2 |[5, 6, 7]|[flag2] |
|3 |[2, 5] |[flag1, flag2]|
|4 |[8] |[] |
+---------+---------+--------------+
该方案成功地避免了shuffle,但是对于40亿大数据来说,还是不行。其原因如下:在方案一的array_intersect以及方案二的intersect都必须至少完全遍历一个对象才能返回结果,而我们实际需要的是如果有一个id既在ids中又在targeting ids中,我们完全可以提前结束遍历,也不需要返回列表。
解决方案三
从上面的思路分析中,我们可以想到我们可以使用break:当找到一个元素后,跳出当前遍历。但是Scala中不支持break关键字,我们可以使用抛出并捕获异常的方法快速进入下一片代码,跳出当前遍历。代码如下:
// solution3
val generateFlagsForSolution3 = udf((ids: Seq[Long]) => {
val flags = new ListBuffer[String]()
targetingIds2FlagMapBC.value.foreach(
item =>{
try {
ids.foreach(
id => {
if (item._2.contains(id)) {
throw new Exception(s"match: ${id}")
}
}
)
} catch {
case e: Exception => {
// println(s"get exception: ${e.getMessage}") // 2. can't use println, because multi thread will compete it(IO)
flags.append(item._1)
}
}
}
)
flags
})
val df5 = df1.withColumn("flags", generateFlagsForSolution3($"ids"))
df5.show(false)
结果如下:
+---------+---------+--------------+
|unique_id|ids |flags |
+---------+---------+--------------+
|1 |[1, 2, 3]|[flag1, flag1]|
|2 |[5, 6, 7]|[flag2] |
|3 |[2, 5] |[flag1, flag2]|
|4 |[8] |[] |
+---------+---------+--------------+
需要注意的是,在代码里我解释两个注释:
- use ids.foreach not item._2.foreach。 这是因为item._2的size(4w)比ids的size(最大800)大,因此对较小的list进行遍历可以加快遍历速度。
- can’t use println, because multi thread will compete it(IO)。这是因为我们的数据是40亿条,如果每条都需要println,那各个线程很有可能去竞争printStream对象从而导致线程阻塞。而实际上println这个方法也是阻塞调用的。
// java.io.PrintStream
public void println(Object x) {
String s = String.valueOf(x);
synchronized (this) {
print(s);
newLine();
}
}
完整代码
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import scala.collection.mutable.ListBuffer
object OptimizeIntersect {
def main(args: Array[String]): Unit = {
val spark = new SparkSession.Builder().appName("OptimizeIntersect")
.master("local[*]")
.config("spark.driver.host", "127.0.0.1")
.getOrCreate()
import spark.implicits._
val df1 = spark.sparkContext.parallelize(
Seq(
(1, Seq(1L, 2L, 3L)),
(2, Seq(5L, 6L, 7L)),
(3, Seq(2L, 5L)),
(4, Seq(8L)),
)
).toDF("unique_id", "ids")
println("df1:")
df1.show(false)
val df2 = spark.sparkContext.parallelize(
Seq(
("flag1", Seq(2L, 3L, 4L)),
("flag2", Seq(5L))
)
).toDF("flag", "targeting_ids")
println("df2:")
df2.show(false)
// solution1
val tmpDF = df1.select("unique_id", "ids").join(broadcast(df2)) // full join
.withColumn("flag",
when(size(array_intersect($"ids", $"targeting_ids")) > 0, $"flag")
.otherwise(lit(""))
).filter("flag != ''")
.groupBy("unique_id").agg(collect_list("flag") as "flags")
val df3 = df1.join(broadcast(tmpDF), Seq("unique_id"), "left")
// df3.show(false)
// solution2
val targetingIds2FlagMapBC = spark.sparkContext.broadcast(
df2.map(
r => {
r.getAs[String]("flag") -> r.getAs[Seq[Long]]("targeting_ids").toSet
}
).rdd.collectAsMap()
)
val generateFlags = udf((ids: Seq[Long]) => {
val flags = new ListBuffer[String]()
targetingIds2FlagMapBC.value.foreach(
item => {
if (ids.toSet.intersect(item._2).nonEmpty) {
flags.append(item._1)
}
}
)
flags
})
val df4 = df1.withColumn("flags", generateFlags($"ids"))
// df4.show(false)
// solution3
val generateFlagsForSolution3 = udf((ids: Seq[Long]) => {
val flags = new ListBuffer[String]()
targetingIds2FlagMapBC.value.foreach(
item =>{
try {
ids.foreach(
id => {
if (item._2.contains(id)) {
throw new Exception(s"match: ${id}")
}
}
)
} catch {
case e: Exception => {
// println(s"get exception: ${e.getMessage}") // 2. can't use println, because multi thread will compete it(IO)
flags.append(item._1)
}
}
}
)
flags
})
val df5 = df1.withColumn("flags", generateFlagsForSolution3($"ids"))
df5.show(false)
}
}
总结
- 有时可以使用广播变量避免shuffle。
- 使用抛出并且捕获异常的方式可以实现break功能。
- 当数据量很大时,应该避免对每天数据都做一次打印(println)或者log(logger.info)之类会阻塞线程的事情。
- 应该先分析清楚自己想要的是什么,再去想最好的解决方案。