文章目录


UDF

UDF 接受一个参数返回一个结果

    spark.udf.register("toUppperCaseUdf",(cloumn:String) => cloumn.toUpperCase)
    spark.sql("select toUppperCaseUdf(name) from  t_user")

UDAF

多进一出,比如系统函数sum

无泛型约束的UDAF

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}object AverageUDAF extends UserDefinedAggregateFunction {
  /**
   * 通过inputSchema指定调用自定义函数传入的参数类型
   *numInput 为类型名称,可以任意指定,
   * StructField("numInput", DoubleType, nullable = true) :: Nil 等同
   *  List("numInput", DoubleType, nullable = true) :: Nil
   * @return
   */
  override def inputSchema: StructType = {
    StructType(
      StructField("numInput", DoubleType, nullable = true) :: Nil    )
  }

  /**
   * 缓冲数据
   * 对于求平均数而言,不断累加的是年龄总人数以及年龄总和
   * @return
   */
  override def bufferSchema: StructType = {
    StructType(
    StructField("buff1", DoubleType, nullable = true) :: StructField("buff2", LongType, nullable = true) :: Nil    )
  }

  /**
   * 自定义UDAF函数返回的数据类型
   * @return
   */
  override def dataType: DataType = DoubleType  /**
   * 判断UDAF函数与返回的函数类型是否一致
   * @return
   */
  override def deterministic: Boolean = ???

  /**
   * 初始化值
   * @param buffer
   *               等价 buffer(0) = 0.0
   *               buffer(1) = 0L
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0,0.0)
    buffer.update(1,0L)
  }

  /**
   * 控制具体的聚合逻辑,在同一个分区中,每次只取一行数据,将原表中每一行参与运算列累加到聚合缓冲区
   * @param buffer 缓冲中数据ROW
   * @param input 表中的ROW,0代表存放累加的年龄,1代表当前参数累加的年龄
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0,buffer.getDouble(0)+input.getDouble(0))
    buffer.update(1,buffer.getLong(1)+1)
  }

  /**
   * 每一个分区都有自己的缓冲区,通过merge将聚合缓冲区中数据合并到一个聚合缓冲区中
   * @param buffer1
   * @param buffer2
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0,buffer1.getDouble(0)+buffer2.getDouble(0))
    buffer1.update(1,buffer1.getLong(1)+buffer2.getLong(1))
  }

  /**
   * 对最终聚合缓冲区中数据进行最后一次运算
   * @param buffer
   * @return
   */
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0) / buffer.getLong(1)
  }}

import org.apache.spark.sql.SparkSession

object TestUDF {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("test")
      .getOrCreate()
    spark.udf.register("AverageUDAF",(cloumn:Any) => cloumn.toString.toDouble)
    spark.udf.register("AverageUDAF",AverageUDAF) 
    spark.sql("select AverageUDAF(age) from  t_user group by sex")
  }}

有泛型约束的UDAF

原理一致,但是调用该UDAF时允许添加泛型,保障函数更加安全.但是这种UDAF不可直接在SQL中被调用运算

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, Row}/**
 * 泛型IN,BUF,OUT 对应类型部分为Row,自定义样例类,数据类型
 * IN: 在聚合前每一个待聚合数据的类型
 * BUF: 每个分区的聚合缓冲区的类型,在本例中为自定义样例类Buffer
 * OUT: 聚合后最终返回的结果类型
 */object AverageFemaleUDAF extends Aggregator[Row, Buffer, Double] {
  /**
   * 初始化聚合缓冲区的初始值
   * @return
   */
  override def zero: Buffer = Buffer(0.0, 0L)

  /**
   * 用于聚合当前分区中每一行的值到聚合缓冲区中Buffer中,在buffer中,age属性用于累加年龄,count用于累加人数
   * @param b 缓冲区
   * @param a 表中数据
   * @return
   */
  override def reduce(b: Buffer, a: Row): Buffer = {
    if (a.getString(2) == "Female") {
      b.age += a.getInt(1)
      b.count += 1
    }
    b  }

  /**
   * 合并多个聚合缓冲区中的值
   * @param b1
   * @param b2
   * @return
   */
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    b1.age += b2.age
    b1.count += b2.count
    b1  }

  /**
   * 对于最终的聚合缓冲区中的数据进行最后一次运算,得到UDAF的最终结果
   * @param reduction
   * @return
   */
  override def finish(reduction: Buffer): Double = reduction.age / reduction.count  /**
   * 聚合缓冲区类型解码器
   * @return
   */
  override def bufferEncoder: Encoder[Buffer] = Encoders.product  /**
   * 最终结果的数据类型解码器
   * @return
   */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble}/**
 * 通过样例类充当聚合缓冲区
 *
 * @param age
 * @param count
 */case class Buffer(var age: Double, var count: Long)

调用

   val femalAvg = ("avg")
    df.select(femalAvg)

UDTF

一进多出去,一行数据中某一列数据展开比如flatMap

import java.util

import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession}import org.apache.spark.sql.types.{StringType, StructField, StructType}import scala.collection.mutable.ListBuffer

object TestUDTF {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("test")
      .getOrCreate()
    val schema = StructType(List(
      StructField("movie", StringType, nullable = false),
      StructField("category", StringType, nullable = false)
    ))

    val rows = new util.ArrayList[Row]()
    rows.add(Row("", "战争,历史"))
    rows.add(Row("", "科幻,丧尸"))
    val df1 = spark.createDataFrame(rows, schema)
    df1.show()

    implicit val flatMapEncoder: Encoder[(String, String)] = Encoders.kryo[(String, String)]
    val tableArray = df1.flatMap(row => {
      val tableArray = new ListBuffer[(String, String)]()
      val categoryArray = row.getString(1).split(",")
      for (c <- categoryArray) {
        tableArray.append((row.getString(0), c))
      }
      tableArray    }).collect()
    val df2 = spark.createDataFrame(tableArray).toDF("movie", "category")
    df2.show()
    spark.stop()
  }}

结果

±-----±-------+
| movie|category|
±-----±-------+
|