UDF

User-defined functions (UDFs) are a key feature of most SQL environments and are used to extend a system's built-in functionality. UDFs let developers enable new functionality in a higher-level language such as SQL while abstracting away the lower-level language implementation. Apache Spark is no exception, and it provides several options for integrating UDFs into Spark SQL workflows.

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object UDF {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("UDF").master("local").getOrCreate()

    val names = Array("Leo", "Marry", "Jack", "Tom")

    val namesRDD = spark.sparkContext.parallelize(names, 5)

    val namesRowRDD = namesRDD.map { name => Row(name) }
    val structType = StructType(Array(StructField("name", StringType, true)))
    val namesDF = spark.createDataFrame(namesRowRDD, structType)

    namesDF.createOrReplaceTempView("names")

    // Define and register the custom function
    spark.udf.register("strLen", (str: String) => str.length())

    spark.sql("select name, strLen(name) from names").collect().foreach(println)

  }

}
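
The same strLen logic can also be used through the DataFrame API instead of SQL; below is a minimal sketch with org.apache.spark.sql.functions.udf, reusing the namesDF defined above (the name_len column name is illustrative):

import org.apache.spark.sql.functions.udf

val strLenUdf = udf((str: String) => str.length)

// Apply the UDF as a Column expression and add the result as a new column
namesDF.withColumn("name_len", strLenUdf(namesDF("name"))).show()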

UDAF

A UDAF (user-defined aggregate function) processes multiple rows at a time and returns a single result; it is usually combined with a GROUP BY clause (like COUNT or SUM).

count

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class StringCount extends UserDefinedAggregateFunction {

  // The data type of the input data
  def inputSchema: StructType = {
    StructType(Array(StructField("str", StringType, true)))
  }

  // The data type of the intermediate aggregation buffer
  def bufferSchema: StructType = {
    StructType(Array(StructField("count", IntegerType, true)))
  }

  // The data type of the function's return value
  def dataType: DataType = {
    IntegerType
  }

  // Whether the function always returns the same output for the same input
  def deterministic: Boolean = {
    true
  }

  // Initialize the aggregation buffer for each group
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  // How to update a group's aggregate value when a new input row arrives for that group
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  // Because Spark is distributed, the data for one group may be partially aggregated (via update) on different nodes;
  // those partial aggregates for the group must then be merged
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

  // Finally, compute the group's final aggregate value from the intermediate buffer
  def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)
  }

}
def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("UDF").master("local").getOrCreate()

    val names = Array("Leo", "Marry", "Jack", "Tom", "Tom", "Tom", "Leo")

    val namesRDD = spark.sparkContext.parallelize(names, 5)

    val namesRowRDD = namesRDD.map { name => Row(name) }
    val structType = StructType(Array(StructField("name", StringType, true)))
    val namesDF = spark.createDataFrame(namesRowRDD, structType)

    namesDF.createOrReplaceTempView("names")

    spark.udf.register("strCount", new StringCount)

    // Apply the custom aggregate function
    spark.sql("select name, strCount(name) from names group by name")
      .collect()
      .foreach(println)

  }

sum

class CustomerSum extends UserDefinedAggregateFunction {

  // The data type of the aggregate function's input parameter
  def inputSchema: StructType = {
    StructType(Array(StructField("inputColumn", LongType, true)))
  }

  // The data type of the intermediate aggregation buffer
  def bufferSchema: StructType = {
    StructType(Array(StructField("sum", LongType, true)))
  }

  // The data type of the function's return value
  def dataType: DataType = {
    LongType
  }

  // Initial value; if the Dataset contains no data, this value is returned
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
  }

  def deterministic: Boolean = {
    true
  }

  // How to update a group's aggregate value when a new input row arrives for that group
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
  }

  // Merge the partial aggregates from each partition
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
  }

  // Return the final aggregate value
  def evaluate(buffer: Row): Any = {
    buffer.getLong(0)
  }

}
def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("UDF").master("local").getOrCreate()

    val rdd = spark.sparkContext.parallelize(Array(3000L, 4500L, 3500L, 4000L), 5)

    val salaryRowRDD = rdd.map { salary => Row(salary) }
    val structType = StructType(Array(StructField("salary", LongType, true)))
    val dF = spark.createDataFrame(salaryRowRDD, structType)

    dF.createOrReplaceTempView("employees")

    spark.udf.register("customerSum", new CustomerSum)

    // Apply the custom aggregate function
    spark.sql("select customerSum(salary) from employees")
      .collect()
      .foreach(println)

  }
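
A registered UDAF can also be invoked from the DataFrame API; below is a minimal sketch using org.apache.spark.sql.functions.callUDF against the dF defined above:

import org.apache.spark.sql.functions.callUDF

// Same aggregation as the SQL query above, expressed as a DataFrame operation
dF.agg(callUDF("customerSum", dF("salary"))).show()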

average

class MyAverage extends UserDefinedAggregateFunction {

  // The data type of the aggregate function's input parameter
  def inputSchema: StructType = {
    StructType(Array(StructField("inputColumn", LongType, true)))
  }

  // The data type of the intermediate aggregation buffer: a running sum and a row count
  def bufferSchema: StructType = {
    StructType(Array(StructField("sum", LongType, true), StructField("count", LongType, true)))
  }

  // The data type of the function's return value
  def dataType: DataType = {
    DoubleType
  }

  // Initial value; if the Dataset contains no data, this value is returned
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  def deterministic: Boolean = {
    true
  }

  // How to update a group's aggregate value when a new input row arrives for that group
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    // Row count
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge the partial aggregates from each partition
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) // sum
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) // count
  }

  // Return the final aggregate value: the sum divided by the count
  def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }

}
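
A minimal driver for MyAverage, mirroring the other examples (the sample salaries and the employees view name are illustrative):

def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("UDF").master("local").getOrCreate()

    // Use Long values to match MyAverage's LongType input schema
    val rdd = spark.sparkContext.parallelize(Array(3000L, 4500L, 3500L, 4000L), 5)

    val salaryRowRDD = rdd.map { salary => Row(salary) }
    val structType = StructType(Array(StructField("salary", LongType, true)))
    val df = spark.createDataFrame(salaryRowRDD, structType)

    df.createOrReplaceTempView("employees")

    spark.udf.register("myAverage", new MyAverage)

    // Apply the custom aggregate function: average salary over all rows
    spark.sql("select myAverage(salary) from employees")
      .collect()
      .foreach(println)

  }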

group by max

/**
 * Compute the highest salary for each gender group
 */
class CustomerMax extends UserDefinedAggregateFunction {

  // The data type of the aggregate function's input parameter
  def inputSchema: StructType = {
    StructType(Array(StructField("inputColumn", LongType, true)))
  }

  // The data type of the intermediate aggregation buffer
  def bufferSchema: StructType = {
    StructType(Array(StructField("max", LongType, true)))
  }

  // The data type of the function's return value
  def dataType: DataType = {
    LongType
  }

  // Initial value; if the Dataset contains no data, this value is returned
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
  }

  def deterministic: Boolean = {
    true
  }

  // How to update a group's aggregate value when a new input row arrives for that group
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      if (input.getLong(0) > buffer.getLong(0)) {
        buffer(0) = input.getLong(0)
      }
    }
  }

  // Merge the partial aggregates from each partition
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (buffer2.getLong(0) > buffer1.getLong(0)) buffer1(0) = buffer2.getLong(0)
  }

  // Return the final aggregate value
  def evaluate(buffer: Row): Any = {
    buffer.getLong(0)
  }

}
def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("UDF").master("local").getOrCreate()

    import spark.implicits._

    val employees = Array(("male", "小王", 30000), ("female", "小丽", 50000), ("male", "小军", 80000), ("female", "小李", 90000)).toSeq

    val df = spark.sparkContext.parallelize(employees, 5).toDF("gender", "name", "salary")

    df.createOrReplaceTempView("employees")

    df.show()
    
    spark.udf.register("customerMax", new CustomerMax)

    // Apply the custom aggregate function
    spark.sql("select gender, customerMax(salary) from employees group by gender")
      .collect()
      .foreach(println)

  }
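
Note that UserDefinedAggregateFunction is deprecated since Spark 3.0; the recommended replacement is a typed Aggregator registered through functions.udaf. Below is a minimal sketch of the same max-by-gender aggregation in that style (Spark 3.x only; MaxSalary is an illustrative name):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

object MaxSalary extends Aggregator[Int, Long, Long] {
  // Start from 0L, assuming non-negative salaries (same assumption as CustomerMax)
  def zero: Long = 0L
  def reduce(buffer: Long, salary: Int): Long = math.max(buffer, salary.toLong)
  def merge(b1: Long, b2: Long): Long = math.max(b1, b2)
  def finish(reduction: Long): Long = reduction
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

spark.udf.register("maxSalary", functions.udaf(MaxSalary))
spark.sql("select gender, maxSalary(salary) from employees group by gender").show()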