- 我们此篇使用的树都是User.json这个,具体如下图
自定义UDF{“username”: “zhangsan”,“age”: 20}
{“username”: “lisi”,“age”: 21}
{“username”: “wangwu”,“age”: 19}
UDF的简介
UDF: 输入一行, 返回一个结果. 一对一关系,放入函数一个值, 就返回一个值, 而不会返回多个值 。如下面的例子就可以看出:
(x: String) => "Name=" + x
这个函数, 入参为一个, 返回也是一个, 而不会返回多个值
具体实现
object UDF {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.master("local[*]")
.appName("UTF")
.getOrCreate()
val df = spark.read
.json("data/user.json")
df.createOrReplaceTempView("user")
//注册udf
spark.udf.register("prefixName", (name: String) => {
"Name:" + name
})
spark.sql("select age,prefixName(username) from user").show()
spark.close()
}
}
结果展示
解释
- UDF在使用之前,需要先注册
spark.udf.register
自定义UDAF
UDAF的简介
UDAF主要可以分为强类型和弱类型
- 强弱类型的主要区别就是强类型要注意数据的类型
强类型的 Dataset
和弱类型的 DataFrame
都提供了相关的聚合函数, 如 count()
,countDistinct()
,avg()
,max()
,min()
。除此之外,用户可以设定自己的自定义聚合函数。通过继承 UserDefinedAggregateFunction
来实现用户自定义弱类型聚合函数。如今UserDefinedAggregateFunction
已经不推荐使用了。可以统一采用强类型聚合函数Aggregator
弱类型的UDAF
自定义UDAF
class MyAvgUDAF extends UserDefinedAggregateFunction {
/**
* 输入数据的结构,我们这里是求年龄的平均值,所以输入的数据是年龄
* 由于是聚合函数,肯定时输入一个数组的数据,最后返回一个数据也就是平均值
* 所以输入的是一个数组,数据的类别名叫age,数据的类型是longType
*/
override def inputSchema: StructType = {
StructType(
Array(
StructField("age", LongType)
)
)
}
/**
* 缓冲区
* 缓冲区是用来暂时存储数据,数据会在这里进行暂时的存储、运算然后才输出数据
* 例如求平均值:数据在缓冲区进行求和和计算数量,求出平均值后输出
*
* @return
*/
override def bufferSchema: StructType = {
StructType(
Array(
StructField("total", LongType),
StructField("count", LongType)
)
)
}
/**
* 函数输出的数据类型就是是计算结果的数据类型
*
* @return
*/
override def dataType: DataType = LongType
/**
* 函数的稳定性
*
* @return
*/
override def deterministic: Boolean = true
/**
* 缓冲区的初始换
*
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//这里就是如何该初始哈缓冲区的数据(也就是归零),这里有两个方法来归零
//方法一
//buffer(0) = 0l
//buffer(1) = 0l
//方法二
buffer.update(0, 0l)
buffer.update(1, 0l)
}
/**
* 根据输入的数据来更新缓冲区的数据,也就是缓冲区的计算规则
*
* @param buffer
* @param input
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//第一个数据就是求和,缓冲区里的数据加上输入的数据
buffer.update(0, buffer.getLong(0) + input.getLong(0))
//第二个数据就是计算总数,每次加一即可
buffer.update(1, buffer.getLong(1) + 1)
}
/**
* 缓冲区的数据合并
* 保留1
*
* @param buffer1
* @param buffer2
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
}
/**
* 计算平均值
*
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Any = (buffer.getLong(0) / buffer.getLong(1))
}
主要步骤:
- 继承
UserDefinedAggregateFunction
类 - 实现他的方法
方法的含义各是什么?
inputSchema
:输入数据的结构。由于是聚合,输入数据肯定是一个数组bufferSchema
:缓冲区数据的结构,缓冲区就是编写计算规则的,如选哟计算平均值,那么就需要在缓冲区中计算出总数和总和dataType
:输出的数据结构,即输出结果的数据结构deterministic
:函数的稳定性,确保一致性, 一般用trueinitialize
:缓冲区的初始化即归零update
:根据输入的数据来更新缓冲区的数据,也就是缓冲区的计算规则merge
:缓冲区的合并evaluate
:计算平均值
注册并且使用
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.master("local[*]")
.appName("UDAF")
.getOrCreate()
val df = spark.read
.json("data/user.json")
df.createOrReplaceTempView("user")
//注册函数
spark.udf.register("ageAvg",new MyAvgUDAF())
spark.sql("select ageAvg(age) from user").show()
spark.close()
}
运行结果
强类型的UDAF
自定义两个样例类
//存储缓冲区的数据
case class Buff(var total: Long, var count: Long)
//存储输入数据
case class User(var username: String, var age: Long)
自定义强类型UDAF类
class MyAvgAgeUDAF extends Aggregator[User, Buff, Long] {
/**
* 初始值或者是零值
* 缓冲区的初始化
*
* @return
*/
override def zero: Buff = {
Buff(0l, 0l)
}
/**
* 根据输入的数据来更新缓冲区的数据
*
* @param b
* @param a
* @return
*/
override def reduce(b: Buff, a: User): Buff = {
b.total += a.age
b.count += 1
b
}
/**
* 合并缓冲区
*
* @param b1
* @param b2
* @return
*/
override def merge(b1: Buff, b2: Buff): Buff = {
b1.total += b2.total
b1.count += b2.count
b1
}
/**
* 计算结果
*
* @param reduction
* @return
*/
override def finish(reduction: Buff): Long = (reduction.total / reduction.count)
/**
* 这是固定的写法,若是自定义的类那么就是:product
* 缓冲区的编码操作
*
* @return
*/
override def bufferEncoder: Encoder[Buff] = Encoders.product
/**
* 这也是固定的写法,若是scala存在的类(如long,int,string……)就是选择对应的即可
* 输出的编码操作
*
* @return
*/
override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
解释
- 继承
Aggregator
类 - 实现方法
- 与弱类型相比,此时这里需要定义输入、缓冲区和输出数据的泛型
方法的简绍
zero
:缓冲区的初始化reduce
:根据输入的数据来更新缓冲区的数据,也就是计算总数据数和数据和merge
:合并缓冲区数据finsh
:计算结果bufferEncoder
和·outputEncoder
:这两个分别是缓冲区和输出的编码格式,其实是由固定格式的,若再次阶段输出的数据是自定义的那么就是Encoders.product
,若输出的数据是scala
自带的那么就是Encoders.scalaLong
后面的long
根据自己输出的数据类型而定
注册并且使用
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.master("local[*]")
.appName("UDAF")
.getOrCreate()
import spark.implicits._
val df = spark.read
.json("data/user.json")
df.createOrReplaceTempView("user")
val ds = df.as[User]
//将UDAF变成查询的列对象
val udafCol = new MyAvgAgeUDAF().toColumn
ds.select(udafCol).show()
spark.close()
}
结果展示