一般情况下,Spark
算子每个节点之间函数中用到的变量是独立拷贝的,互不影响,即使更改之后也不会被拉回到 Driver
端,支持跨 task
之间共享变量通常是低效的, 但是 Spark
对共享变量也提供了两种支持:
- 广播变量
- 累加器
1. 广播变量
广播变量只会在每个节点上保存一份只读变量的缓存,而不是每个 task
都 copy
一份(每个节点可能有多个 task
,可以理解为之前已每个线程有一份 copy
,现在是每个进程缓存一份,多个线程之间进行共享)
广播变量是只读的,一般用来广播大变量(最好不好超过 1 G),小变量直接传递即可
import org.apache.spark.sql.SparkSession
object Broadcast {
def main(args: Array[String]): Unit = {
val session = SparkSession.builder.appName("create_rdd").master("local[2]").getOrCreate()
val sc = session.sparkContext
val list1 = 1 to 100000 toArray
val bd = sc.broadcast(list1)
val list2 = List(1, 2, 3, 5, 6, 9999)
val rdd = sc.parallelize(list2).filter(x => bd.value.contains(x))
rdd.collect().foreach(println)
sc.stop()
}
}
2. 累加器
累加器用来对信息进行聚合,广播变量是可读的,如果想在每个 task
上影响这个变量,就只能使用累加器, Spark
内置了一些简单的累加器,如:add
,用户也可以自定义复杂的累加器。
2.1 内置累加器
给 rdd
中,每个元素都加 1:
import org.apache.spark.sql.SparkSession
object AddCalc {
def main(args: Array[String]): Unit = {
val session = SparkSession.builder.appName("create_rdd").master("local[2]").getOrCreate()
val sc = session.sparkContext
// 定义一个 Long 类型的累加器
val acc = sc.longAccumulator("oneAcc")
val rdd = sc.parallelize(List(1, 2, 3, 4, 5, 6), 3).map(x => {
acc.add(1)
x
})
rdd.collect
println(acc.value) // 获取累加器的值
sc.stop()
}
}
累加器的更新操作最好放在 action
中, Spark
可以保证每个 task
只执行一次. 如果放在 transformations
操作中则不能保证只更新一次,有可能会被重复执行:
// 将累加操作放在 action 中
val rdd = sc.parallelize(List(1, 2, 3, 4, 5, 6), 3)
rdd.foreach(x => {
acc.add(1)
x
})
println(acc.value)
2.2 自定义累加器
可以通过继承 AccumulatorV2
来自定义累加器:
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.AccumulatorV2
object AccDefined {
def main(args: Array[String]): Unit = {
val session = SparkSession.builder.appName("create_rdd").master("local[2]").getOrCreate()
val sc = session.sparkContext
// 先注册自定义的累加器
val acc = new MyAcc
sc.register(acc, "first_acc")
// 将累加操作放在 action 中
val rdd = sc.parallelize(List(1, 2, 3, 4, 5, 6), 3)
rdd.foreach(x => {
acc.add(1)
x
})
println(acc.value)
sc.stop()
}
}
class MyAcc extends AccumulatorV2[Int, Int] {
private var sum = 0
// 判"零", 对缓冲区值进行判"零"
override def isZero: Boolean = sum == 0
// 把当前的累加赋值为一个新的累加器
override def copy(): AccumulatorV2[Int, Int] = {
val acc = new MyAcc
acc.sum = sum
acc
}
// 重置累加器(就是把缓冲区的值重置为"零")
override def reset(): Unit = sum = 0
// 真正的累加方法(分区内的累加)
override def add(v: Int): Unit = sum += v
// 分区间的合并 把other的sum合并到this的sum中
override def merge(other: AccumulatorV2[Int, Int]): Unit = other match {
case acc: MyAcc => this.sum += acc.sum
case _ => this.sum += 0
}
// 返回累加后的最终值
override def value: Int = sum
}
需求:
- 返回一个 map,该累加器同时包含
sum, count, avg
- 输出结果类似:
Map(sum -> 370.0, count -> 10, avg -> 37.0)
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.AccumulatorV2
object MapAcc {
def main(args: Array[String]): Unit = {
val session = SparkSession.builder.appName("create_rdd").master("local[2]").getOrCreate()
val sc = session.sparkContext
val list1 = List(30, 50, 70, 60, 10, 20, 10, 30, 40, 50)
val rdd = sc.parallelize(list1, 2)
val acc = new MapAcc
sc.register(acc) // 注册自定义累加器
rdd.foreach(
x => acc.add(x)
)
println(acc.value)
sc.stop()
}
}
/*
输入一个 Double,返回一个 map,该累加器同时包含 sum, count, avg
输出:Map(sum -> 370.0, count -> 10, avg -> 37.0)
*/
class MapAcc extends AccumulatorV2[Double, Map[String, Any]] {
private var map = Map[String, Any]()
override def isZero: Boolean = map.isEmpty
override def copy(): AccumulatorV2[Double, Map[String, Any]] = {
val acc = new MapAcc
acc.map = map
acc
}
override def reset(): Unit = map = {
Map[String, Any]()
}
override def add(v: Double): Unit = {
// sum、count 累加,avg 最后的 value 函数计算
// asInstanceOf[Double] 将其转换为 Double 类型
// map += key -> value Map 添加新 key
println("add v====>" + v)
map += "sum" -> (map.getOrElse("sum", 0D).asInstanceOf[Double] + v)
map += "count" -> (map.getOrElse("count", 0L).asInstanceOf[Long] + 1L)
}
override def merge(other: AccumulatorV2[Double, Map[String, Any]]): Unit = {
// 合并两个map
other match {
case o: MapAcc =>
map +=
"sum" -> (map.getOrElse("sum", 0D).asInstanceOf[Double] + o.map.getOrElse("sum", 0D).asInstanceOf[Double])
map +=
"count" -> (map.getOrElse("count", 0L).asInstanceOf[Long] + o.map.getOrElse("count", 0L).asInstanceOf[Long])
case _ => throw new UnsupportedOperationException
}
}
override def value: Map[String, Any] = {
map += "avg" -> (map.getOrElse("sum", 0D).asInstanceOf[Double] / map.getOrElse("count", 0L).asInstanceOf[Long])
map
}
}