Spark SQL操作之-自定义函数篇-下

  • 环境说明
  • 自定义函数分类
  • 用户自定义函数(UDF)
  • 用户自定义聚合函数(UDAF)

环境说明

1. JDK 1.8
2. Spark 2.1

自定义函数分类

不同的业务需要不同的处理函数,所以spark也支持用户自定义函数来做专用的处理。这里的自定义函数分两大类:用户已定义函数(UDF)和用户自定义聚合函数(UDAF)。

用户自定义函数(UDF)

用户自定义函数比较简单,写起来就是个普通的scala函数,只不过在spark中使用的时候需要单独注册一下。
直接看例子吧。

scala> val df=Range(0,10).toSeq.toDF("id")
df: org.apache.spark.sql.DataFrame = [id: int]

scala> df.show
+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
|  5|
|  6|
|  7|
|  8|
|  9|
+---+

##定义一个函数,对给定的整数列都加100
scala> def add100(value:Int):Int = { value + 100 }
add100: (value: Int)Int

##注册成自定义sql函数
scala> spark.udf.register("add100", add100(_:Int))
res1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(IntegerType)))

##调用上面写的自定义函数add100(value:Int)
scala> df.selectExpr("id", "add100(id) as new_id").show
+---+------+
| id|new_id|
+---+------+
|  0|   100|
|  1|   101|
|  2|   102|
|  3|   103|
|  4|   104|
|  5|   105|
|  6|   106|
|  7|   107|
|  8|   108|
|  9|   109|
+---+------+

要注意的是,用spark.udf.register注册的函数,不能用作dataset的函数使用。需要用udf类重新注册一下。

##直接用的话,会类型不匹配的。
scala> df.select(add100($"id")).show
<console>:28: error: type mismatch;
 found   : org.apache.spark.sql.ColumnName
 required: Int
       df.select(add100($"id")).show

##正确用法,用udf注册
scala> val add100_func=udf(add100 _)
add100_func: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(IntegerType)))

现在相当于有了一个add100_func的函数,类型是UserDefinedFunction
scala> df.select($"id", add100_func($"id").as("new_id")).show
+---+------+
| id|new_id|
+---+------+
|  0|   100|
|  1|   101|
|  2|   102|
|  3|   103|
|  4|   104|
|  5|   105|
|  6|   106|
|  7|   107|
|  8|   108|
|  9|   109|
| 10|   110|
| 11|   111|
| 12|   112|
| 13|   113|
+---+------+

好了,UDF就说到这,挺简单的。下面的UDAF比较起来,复杂多了。

用户自定义聚合函数(UDAF)

和UDF比起来,就多了一个A:聚合,Aggregation。其实聚合函数很常见,平时写SQL,求和啊,求均值啊这些都是。但是,自己写UDAF,比起写UDF可是麻烦多了。想想也是,一般的UDF,就是处理一行数据中的一列或多列,做个变换后返回。而UDAF是针对多行数据来处理的,最后只输出一行结果,操作本来就复杂些。

要实现一个UDAF功能,有两种方式:一种是从UserDefinedAggregateFunction类继承,一种是从Aggregator类继承。这两种方式基本上类似,前者是非类型安全的,但是比较灵活,不需要传入整行数据,只要传需要做聚合的列就可以了。后者是强类型,api看起来友好一些,但是,对于列很多的情况,比较麻烦。我个人比较倾向于使用UserDefinedAggregateFunction类的继承实现。

从UserDefinedAggregateFunction类继承,需要实现8个成员方法。

成员方法

释义

inputSchema: StructType

函数的输入参数的类型定义

dataType: DataType

函数的返回值类型定义

bufferSchema: StructType

内部缓存,记录临时变量等

deterministic: Boolean

这是一个确定性的指示。就是说,是否给定输入后,每次运行的结果都一致。通常都是true

initialize(buffer: MutableAggregationBuffer): Unit

初始化函数。典型的功能就是变量清零之类的

update(buffer: MutableAggregationBuffer, input: Row): Unit

更新函数。在同一个partition内的数据一行一行的被调用到该函数做更新处理

merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

合并函数。各个partition更新完所有数据后,通过merge函数合并

evaluate(buffer: Row): Any

最终的求值函数,输出为dataType类型

基本上就是按照上述的过程来实现功能。定义输入输出,定义中间缓存的数据结构,定义初始化,更新,合并,最后求值。
来看一个实现整型求和函数的代码示例:

import java.util.ArrayList

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._

import org.apache.spark.TaskContext

class UDAF_Sum extends UserDefinedAggregateFunction {

  //1.定义输入数据的类型
  override def inputSchema = StructType(Array(
    StructField("input", LongType)
  ))

  //2.定义中间数据的类型
  override def bufferSchema = StructType(Array(
    //temp_sum很明显是保存部分和
    StructField("temp_sum", LongType),
    //ele_array这里是用来记录当前处理了哪些元素,用来帮助观察整个计算过程。
    StructField("ele_array", DataTypes.createArrayType(DataTypes.LongType))
  ))

  //3.定义返回结果的类型
  override def dataType: DataType = LongType

  //4.输出的确定性指示,一般都是true
  override def deterministic = true

  //5.定义初始化函数,就是些初始值的处理。
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //初始化,因为是求和的,所以和的初值显然为0
    buffer(0) = 0L
    //记录当前已处理的所有输入的数
    buffer(1) = new ArrayList[Long]()
  }

  //6.定义update函数,对于一个partition来说,里面的每条数据都会经过update
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val par_id = TaskContext.getPartitionId()
    println(s"------partition $par_id update begin------")
    println(s"partition $par_id update input: $input")

    buffer(0) = buffer.getLong(0) + input.getLong(0)
    val tmpList =  new ArrayList(buffer.getList[Long](1))
    tmpList.add(input.getLong(0))
    buffer(1) = tmpList
    
    println(s"partition $par_id update output: buffer = $buffer")
    println(s"-----partition $par_id update end-----------")
  }

  //7.定义merge函数,处理所有partition的全局聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //每个分区计算的结果进行相加
    val par_id = TaskContext.getPartitionId()
    println(s"------partition $par_id merge begin------")
    println(s"partition $par_id merge input: buffer1 = $buffer1")
    println(s"partition $par_id merge input: buffer2 = $buffer2")
    
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    val tmpList = new ArrayList(buffer1.getList[Long](1))
    tmpList.addAll( buffer2.getList[Long](1))
    buffer1(1) = tmpList
    
    println(s"partition $par_id merge output: buffer1 = $buffer1")
    println(s"-----partition $par_id merge end-----------")
  }

  //8.定义evaluate函数,返回最终的结果
  override def evaluate(buffer: Row): Any = {
    println("evaluate: " + buffer)
    buffer.getLong(0)
  }
}

在spark shell里面,可以用:paste命令把整段代码一次性复制进去,我们来运行一下看看结果:

[root@ecs-930c spark-2.1.0-bin-hadoop2.7]# bin/spark-shell --master local[2]
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
19/07/21 16:55:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
19/07/21 16:55:51 WARN ObjectStore: Failed to get database global_temp, returning NoSuchObjectException
Spark context Web UI available at http://192.168.1.153:4040
Spark context available as 'sc' (master = local[2], app id = local-1563699348275).
Spark session available as 'spark'.
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 2.1.0
      /_/
         
Using Scala version 2.11.8 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_201)
Type in expressions to have them evaluated.
Type :help for more information.

scala> :paste
// Entering paste mode (ctrl-D to finish)

///
  上面的代码直接粘贴,就不重复了,粘贴后按Ctrl-D结束
///

// Exiting paste mode, now interpreting.

import java.util.ArrayList
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.TaskContext
defined class UDAF_Sum

//实例化一个函数对象出来
scala> val udaf_sum = new UDAF_Sum
udaf_sum: UDAF_Sum = UDAF_Sum@2caa9666

//生成测试数据的dataset,数字0到10,字段名"id"
scala> val df=Range(0,10).toSeq.toDF("id")
df: org.apache.spark.sql.DataFrame = [id: int]

//这里我把每个数字在哪个partition打印出来了。
//这里能看到是2个partition,partition 0里面包含了0,1,2,3,4,partition 1里面包含了5,6,7,8,9
scala> df.foreachPartition(par => par.foreach(x=>println("partition "+TaskContext.getPartitionId.toString+":"+x)))
partition 0:[0]
partition 1:[5]
partition 0:[1]
partition 1:[6]
partition 0:[2]
partition 1:[7]
partition 0:[3]
partition 1:[8]
partition 0:[4]
partition 1:[9]

//现在来调用我们创建的UDAF函数,注册的名字是udaf_sum,传入的列是id
scala> df.select(udaf_sum($"id")).show
------partition 0 update begin------
------partition 1 update begin------
partition 0 update input: [0]
partition 1 update input: [5]
partition 1 update output: buffer = [5,WrappedArray(5)]
partition 0 update output: buffer = [0,WrappedArray(0)]
-----partition 1 update end-----------
-----partition 0 update end-----------
------partition 1 update begin------
------partition 0 update begin------
partition 1 update input: [6]
partition 0 update input: [1]
partition 1 update output: buffer = [11,WrappedArray(5, 6)]
-----partition 1 update end-----------
------partition 1 update begin------
partition 1 update input: [7]
partition 0 update output: buffer = [1,WrappedArray(0, 1)]
-----partition 0 update end-----------
partition 1 update output: buffer = [18,WrappedArray(5, 6, 7)]
-----partition 1 update end-----------
------partition 0 update begin------
------partition 1 update begin------
partition 0 update input: [2]
partition 1 update input: [8]
partition 0 update output: buffer = [3,WrappedArray(0, 1, 2)]
partition 1 update output: buffer = [26,WrappedArray(5, 6, 7, 8)]
-----partition 0 update end-----------
-----partition 1 update end-----------
------partition 0 update begin------
partition 0 update input: [3]
------partition 1 update begin------
partition 1 update input: [9]
partition 0 update output: buffer = [6,WrappedArray(0, 1, 2, 3)]
-----partition 0 update end-----------
partition 1 update output: buffer = [35,WrappedArray(5, 6, 7, 8, 9)]      <-----到这里为止,partition 1更新完成,总共5条记录
------partition 0 update begin------
-----partition 1 update end-----------
partition 0 update input: [4]
partition 0 update output: buffer = [10,WrappedArray(0, 1, 2, 3, 4)]      <-----到这里为止,partition 0更新完成,总共也是5条记录
-----partition 0 update end-----------
------partition 0 merge begin------                                       <------这里开始进入merge阶段
partition 0 merge input: buffer1 = [0,WrappedArray()]
partition 0 merge input: buffer2 = [10,WrappedArray(0, 1, 2, 3, 4)]
partition 0 merge output: buffer1 = [10,WrappedArray(0, 1, 2, 3, 4)]
-----partition 0 merge end-----------
------partition 0 merge begin------
partition 0 merge input: buffer1 = [10,WrappedArray(0, 1, 2, 3, 4)]
partition 0 merge input: buffer2 = [35,WrappedArray(5, 6, 7, 8, 9)]
partition 0 merge output: buffer1 = [45,WrappedArray(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)]    <------ merge完成,总和是45,总共10个元素
-----partition 0 merge end-----------
evaluate: [45,WrappedArray(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)]
+------------+
|udaf_sum(id)|
+------------+
|          45|
+------------+

配个图看清楚一点:

sparksql取整函数 sparksql函数手册_sparksql取整函数


嗯,这个小系列拖拖拉拉的,总算是完结啦