一般情况下,Spark 算子每个节点之间函数中用到的变量是独立拷贝的,互不影响,即使更改之后也不会被拉回到 Driver 端,支持跨 task 之间共享变量通常是低效的, 但是 Spark 对共享变量也提供了两种支持:

  • 广播变量
  • 累加器

1. 广播变量

广播变量只会在每个节点上保存一份只读变量的缓存,而不是每个 taskcopy 一份(每个节点可能有多个 task,可以理解为之前已每个线程有一份 copy,现在是每个进程缓存一份,多个线程之间进行共享)

广播变量是只读的,一般用来广播大变量(最好不好超过 1 G),小变量直接传递即可

import org.apache.spark.sql.SparkSession

object Broadcast {
  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder.appName("create_rdd").master("local[2]").getOrCreate()
    val sc = session.sparkContext

    val list1 = 1 to 100000 toArray
    val bd = sc.broadcast(list1)
    val list2 = List(1, 2, 3, 5, 6, 9999)

    val rdd = sc.parallelize(list2).filter(x => bd.value.contains(x))

    rdd.collect().foreach(println)
    sc.stop()
  }
}

2. 累加器

累加器用来对信息进行聚合,广播变量是可读的,如果想在每个 task 上影响这个变量,就只能使用累加器, Spark 内置了一些简单的累加器,如:add,用户也可以自定义复杂的累加器。

2.1 内置累加器

rdd 中,每个元素都加 1:

import org.apache.spark.sql.SparkSession

object AddCalc {
  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder.appName("create_rdd").master("local[2]").getOrCreate()
    val sc = session.sparkContext

    // 定义一个 Long 类型的累加器
    val acc = sc.longAccumulator("oneAcc")

    val rdd = sc.parallelize(List(1, 2, 3, 4, 5, 6), 3).map(x => {
      acc.add(1)
      x
    })

    rdd.collect
    println(acc.value) // 获取累加器的值

    sc.stop()
    
  }
}

累加器的更新操作最好放在 action 中, Spark 可以保证每个 task 只执行一次. 如果放在 transformations 操作中则不能保证只更新一次,有可能会被重复执行:

// 将累加操作放在 action 中
val rdd = sc.parallelize(List(1, 2, 3, 4, 5, 6), 3)
rdd.foreach(x => {
    acc.add(1)
    x
})
println(acc.value)

2.2 自定义累加器

可以通过继承 AccumulatorV2 来自定义累加器:

import org.apache.spark.sql.SparkSession
import org.apache.spark.util.AccumulatorV2

object AccDefined {
  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder.appName("create_rdd").master("local[2]").getOrCreate()
    val sc = session.sparkContext

    // 先注册自定义的累加器
    val acc = new MyAcc
    sc.register(acc, "first_acc")

    // 将累加操作放在 action 中
    val rdd = sc.parallelize(List(1, 2, 3, 4, 5, 6), 3)
    rdd.foreach(x => {
      acc.add(1)
      x
    })
    println(acc.value)
    sc.stop()
  }
}

class MyAcc extends AccumulatorV2[Int, Int] {
  private var sum = 0

  // 判"零", 对缓冲区值进行判"零"
  override def isZero: Boolean = sum == 0

  // 把当前的累加赋值为一个新的累加器
  override def copy(): AccumulatorV2[Int, Int] = {
    val acc = new MyAcc
    acc.sum = sum
    acc
  }

  // 重置累加器(就是把缓冲区的值重置为"零")
  override def reset(): Unit = sum = 0

  // 真正的累加方法(分区内的累加)
  override def add(v: Int): Unit = sum += v

  // 分区间的合并  把other的sum合并到this的sum中
  override def merge(other: AccumulatorV2[Int, Int]): Unit = other match {
    case acc: MyAcc => this.sum += acc.sum
    case _ => this.sum += 0
  }

  // 返回累加后的最终值
  override def value: Int = sum
}

需求:

  • 返回一个 map,该累加器同时包含 sum, count, avg
  • 输出结果类似:Map(sum -> 370.0, count -> 10, avg -> 37.0)
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.AccumulatorV2

object MapAcc {
  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder.appName("create_rdd").master("local[2]").getOrCreate()
    val sc = session.sparkContext
    val list1 = List(30, 50, 70, 60, 10, 20, 10, 30, 40, 50)
    val rdd = sc.parallelize(list1, 2)

    val acc = new MapAcc
    sc.register(acc) // 注册自定义累加器

    rdd.foreach(
      x => acc.add(x)
    )

    println(acc.value)

    sc.stop()
  }
}

/*
输入一个 Double,返回一个 map,该累加器同时包含 sum, count, avg
输出:Map(sum -> 370.0, count -> 10, avg -> 37.0)
 */
class MapAcc extends AccumulatorV2[Double, Map[String, Any]] {
  private var map = Map[String, Any]()

  override def isZero: Boolean = map.isEmpty

  override def copy(): AccumulatorV2[Double, Map[String, Any]] = {
    val acc = new MapAcc
    acc.map = map
    acc
  }

  override def reset(): Unit = map = {
    Map[String, Any]()
  }

  override def add(v: Double): Unit = {
    // sum、count 累加,avg 最后的 value 函数计算
    // asInstanceOf[Double] 将其转换为 Double 类型
    // map += key -> value   Map 添加新 key
    println("add v====>" + v)
    map += "sum" -> (map.getOrElse("sum", 0D).asInstanceOf[Double] + v)
    map += "count" -> (map.getOrElse("count", 0L).asInstanceOf[Long] + 1L)
  }

  override def merge(other: AccumulatorV2[Double, Map[String, Any]]): Unit = {
    // 合并两个map
    other match {
      case o: MapAcc =>
        map +=
          "sum" -> (map.getOrElse("sum", 0D).asInstanceOf[Double] + o.map.getOrElse("sum", 0D).asInstanceOf[Double])
        map +=
          "count" -> (map.getOrElse("count", 0L).asInstanceOf[Long] + o.map.getOrElse("count", 0L).asInstanceOf[Long])
      case _ => throw new UnsupportedOperationException
    }
  }

  override def value: Map[String, Any] = {
    map += "avg" -> (map.getOrElse("sum", 0D).asInstanceOf[Double] / map.getOrElse("count", 0L).asInstanceOf[Long])
    map
  }
}