在 Spark 处理数据的过程中,虽然 DataSet 下的算子不多,但已经可以处理大多数的数据需求,但仍有少数需求需要自定义函数。UDF(User Defined Functions) 是普通的不会产生 Shuffle 不会划分新的阶段的用户自定义函数,UDAF(User Defined Aggregator Functions) 则会打乱分区,用户自定义聚合函数。

UDF

因为 UDF 不需要打乱分区,直接对 RDD 每个分区中的数据进行处理并返回当前分区,所以可以直接注册 UDF 函数,甚至可以传入匿名函数。

import org.apache.spark.sql.functions  // DSL中定义UDF需要

val rdd: RDD[User] = spark.sparkContext.makeRDD(
  List(User("Bob", 23), User("Alice", 22), User("John", 24)))
val ds: Dataset[User] = rdd.toDS
ds.createOrReplaceTempView("user")

// SQL中使用就需要注册UDF
spark.udf.register("add_name", (str: String) => { "Name: " + str })
spark.sql("select name, add_name(name) as new_name from user").show()

// 使用DSL则不用注册,定义好直接使用即可
val add_name2: UserDefinedFunction = functions.udf((str: String) => {
  "Name: " + str
})
ds.withColumn("name", add_name2($"name")).show()

UDAF

相比较 UDF 而言因为 UDAF 是聚合函数所以要打乱分区,所以也就比较复杂,并且需要重写指定的方法来定义。需要注意的是针对弱类型的 UserDefinedAggregateFunction 已经弃用,普遍使用强类型的 Aggregator ,同时若想在 Spark3.0 版本之前使用强类型 UDAF 和 Spark3.0 版本之后的定义方式略有不同。数据如下,计算每家门店的用户数量以及总付款额

store,user,payment
1,Bob,12.00
1,Alice,44.12
1,John,23.20
2,Davin,79.00
2,Lim,33.30
...
UserDefinedAggregateFunction

首先是已经弃用的 UserDefinedAggregateFunction ,以防生产环境中仍有使用老版本的 Spark 。它使用的是弱类型,所以在编写过程中你会看到使用0或者1来指定位置,这十分不方便。先是数据构建和调用部分

// 注册并调用UDAF,写在main方法中
val ds: Dataset[Record] = spark
  .sparkContext
  .makeRDD(
    List(Record(1, "Bob", 12.00), Record(1, "Alice", 44.12), Record(1, "John", 23.20),
      Record(2, "Davin", 79.00), Record(2, "Lim", 33.30))).toDS

ds.createOrReplaceTempView("record")
spark.udf.register("myudaf01", new MyUDAF01)  // 注册UDAF函数
spark.sql(
  """
    |select store, myudaf01(payment) as summary from record group by store
    |""".stripMargin).show(truncate = false)

// 样例类,写在main方法外
case class Record(store: Int, name: String, payment: Double)

下面是 UDAF 类

// UDAF部分,写在main方法外
class MyUDAF01 extends UserDefinedAggregateFunction {
  
	// 聚合函数输入参数的数据类型
  override def inputSchema: StructType = {
    StructType(Array(StructField("payment", DoubleType)))
  }
  
	// 聚合函数缓冲区中值的类型
  override def bufferSchema: StructType = {
    StructType(Array(
      StructField("total_user", IntegerType),
      StructField("total_payment", DoubleType)
    ))
  }
  
	// 函数返回的数据类型
  override def dataType: DataType = StringType
  
	// 对于相同的输入是否一直返回相同的输出
  override def deterministic: Boolean = true
  
	// 函数buffer缓冲区初始化,初始值
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0  // 人数初始值
    buffer(1) = 0.00  // 总额初始值
  }
  
	// 更新缓冲区中的数据
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + 1
    buffer(1) = buffer.getDouble(1) + input.getDouble(0)
  }
  
	// 合并缓冲区(类似于reduce,属于两个元素的合并规则)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }
  
	// 计算最终结果
  override def evaluate(buffer: Row): Any = 
  "user: " + buffer.getInt(0) + ",payment: " + buffer.getDouble(1)
}

Aggregator(Spark3.0版本以后)

接下来是 Aggregator ,使用的是强类型编写,需要混入不同的特质,将输入、Buffer 以及输出全部定义在了泛型中,这样编写过程中就不需要使用位置来定位了,而且重写方法也简单易懂。

spark.udf.register("myudaf02", functions.udaf(new MyUDAF02))  // 注册UDAF函数
spark.sql(
  """
    |select store, myudaf02(payment) from record group by store
    |""".stripMargin).show(truncate = false)

下面是 UDAF 类

case class StoreSummary(var user: Int, var payment: Double)  // 强类型UDAF函数Buffer类型

class MyUDAF02 extends Aggregator[Double, StoreSummary, String] {
  
  // 初始化Buffer中的字段
  override def zero: StoreSummary = {
    StoreSummary(0, 0.00)
  }

  // 输入到Buffer的聚合
  override def reduce(b: StoreSummary, a: Double): StoreSummary = {
    b.user += 1
    b.payment += a
    b
  }

  // 合并Buffer
  override def merge(b1: StoreSummary, b2: StoreSummary): StoreSummary = {
    b1.user += b2.user
    b1.payment += b2.payment
    b1
  }

  // 最终的计算结果
  override def finish(reduction: StoreSummary): String = {
    "user: " + reduction.user + ",payment: " + reduction.payment
  }

  // Dataset默认编码器,用于序列化,固定写法
  override def bufferEncoder: Encoder[StoreSummary] = Encoders.product

  override def outputEncoder: Encoder[String] = Encoders.STRING
}

可以看到不管是从代码量还是调用参数相比于弱类型便捷了很多,需要注意的是注册函数时需要调用 functions 下的 UDAF 方法,还有一点就是这是 Spark3.0 以后的写法,Spark3.0 以前如果想用强类型有其他的写法。

输出使用强弱类型 UDAF 查询的结果

+-----+----------------------+----------------------+
|store|myudaf01(payment)     |myudaf02(payment)     |
+-----+----------------------+----------------------+
|1    |user: 3,payment: 79.32|user: 3,payment: 79.32|
|2    |user: 2,payment: 112.3|user: 2,payment: 112.3|
+-----+----------------------+----------------------+

Aggregator(Spark3.0版本以前)

早期版本中不能在 SQL 中使用强类型 UDAF ,但是可以在 DSL 中使用,代码编写和调用方式都有所不同,DSL 注重的是类型,所以在 UDAF 输入类型这里传入的应该是 DataSet 每一行的类型,而不是固定字段的某个类型。

val myudaf03: TypedColumn[Double, String] = (new MyUDAF03).toColumn
ds.select(myudaf03).show  // 输出的并没有按门店分组,与预想结果不同,没深究

下面是 UDAF 类

class MyUDAF03 extends Aggregator[Record, StoreSummary, String] {
    override def zero: StoreSummary = {
      StoreSummary(0, 0.00)
    }

    override def reduce(b: StoreSummary, a: Record): StoreSummary = {
      b.user += 1
      b.payment += a.payment
      b
    }

    override def merge(b1: StoreSummary, b2: StoreSummary): StoreSummary = {
      b1.user += b2.user
      b1.payment += b2.payment
      b1
    }

    override def finish(reduction: StoreSummary): String = {
      "user: " + reduction.user + ",payment: " + reduction.payment
    }

    override def bufferEncoder: Encoder[StoreSummary] = Encoders.product

    override def outputEncoder: Encoder[String] = Encoders.STRING
  }

其实就是将输入类型改成了 DataSet 的类型,代码中再调用指定的字段即可。