目录
- 一.数据源
- 二.自定义 UDF 函数
- 三.用户自定义聚合函数
- sum()聚合
- avg()聚合
- 四.自定义强类型聚合函数(了解)
一.数据源
{"name":"lisi","age":20}
{"name":"ww","age":10}
{"name":"zl","age":15}
{"name":"zy","age":30}
二.自定义 UDF 函数
import org.apache.spark.sql.SparkSession
object UDFDemo {
  def main(args: Array[String]): Unit = {
    // Local SparkSession for the demo.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("UDFDemo")
      .getOrCreate()

    // Read the JSON data source into a DataFrame.
    val df = spark.read.json("D:\\idea\\spark-sql\\input\\user.json")

    // Register a UDF: "toUpper" is the SQL-visible name,
    // the lambda is the implementation (upper-cases a string).
    spark.udf.register("toUpper", (s: String) => s.toUpperCase)

    // Expose the DataFrame as a temp view so the UDF can be used from SQL.
    df.createOrReplaceTempView("user")
    spark.sql("select toUpper(name),age from user").show()

    spark.close()
  }
}
结果
三.用户自定义聚合函数
继承 UserDefinedAggregateFunction（注：自 Spark 3.0 起该接口已废弃，官方推荐改用第四节介绍的强类型 Aggregator）
sum()聚合
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}
import scala.collection.immutable.Nil
object UDAFDemo {
  def main(args: Array[String]): Unit = {
    // Local SparkSession for the demo.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("RDD2DF")
      .getOrCreate()
    import spark.implicits._

    // Load the JSON source and expose it as a temp view.
    val df = spark.read.json("D:\\idea\\spark-sql\\input\\user.json")
    df.createOrReplaceTempView("user")

    // Register the custom aggregate function under the name "mySum".
    spark.udf.register("mySum", new MySum)
    spark.sql("select mySum(age) from user").show

    spark.close()
  }
}
/** Untyped UDAF that sums a Double column (e.g. 10.1, 12.2). */
class MySum extends UserDefinedAggregateFunction {
  /** Input: a single Double column. */
  override def inputSchema: StructType =
    StructType(StructField("ele", DoubleType) :: Nil)

  /** Buffer: one slot holding the running sum. */
  override def bufferSchema: StructType =
    StructType(StructField("sum", DoubleType) :: Nil)

  /** Final result type of the aggregation. */
  override def dataType: DataType = DoubleType

  /** Same input always produces the same output. */
  override def deterministic: Boolean = true

  /** Start the running sum at zero. */
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, 0D) // same as buffer(0) = 0D

  /** Within-partition step: add each non-null input value to the sum. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // `input` wraps the argument passed to the aggregate call in SQL.
    if (!input.isNullAt(0)) { // the column value may be null
      val v = input.getDouble(0)
      buffer.update(0, buffer.getDouble(0) + v)
    }
  }

  /** Cross-partition step: fold buffer2 into buffer1. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))

  /** Final output: the accumulated sum. */
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}
avg()聚合
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import scala.collection.immutable.Nil
object UDAFDemo1 {
  def main(args: Array[String]): Unit = {
    // Local SparkSession for the demo.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("RDD2DF")
      .getOrCreate()
    import spark.implicits._

    // Load the JSON source and expose it as a temp view.
    val df = spark.read.json("D:\\idea\\spark-sql\\input\\user.json")
    df.createOrReplaceTempView("user")

    // Register the custom aggregate function under the name "myAvg".
    spark.udf.register("myAvg", new MyAvg)
    spark.sql("select myAvg(age) from user").show

    spark.close()
  }
}
/** Untyped UDAF that averages a Double column (sum / count). */
class MyAvg extends UserDefinedAggregateFunction {
  /** Input: a single Double column (e.g. 10.1, 12.2). */
  override def inputSchema: StructType = StructType(StructField("ele",DoubleType)::Nil)

  /** Buffer: running sum (Double) and element count (Long). */
  override def bufferSchema: StructType = StructType(StructField("sum",DoubleType)::(StructField("count",LongType)
    ::Nil))

  /** Final result type of the aggregation. */
  override def dataType: DataType = DoubleType

  /** Same input always produces the same output. */
  override def deterministic: Boolean = true

  /** Initialize the buffer: sum = 0.0, count = 0. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0D // same as buffer.update(0, 0D)
    buffer(1) = 0L
  }

  /** Within-partition step: accumulate each non-null value and bump the count. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // `input` wraps the argument passed to the aggregate call in SQL.
    if (!input.isNullAt(0)){ // the column value may be null
      val v = input.getAs[Double](0) // same as getDouble(0)
      buffer(0) = buffer.getDouble(0) + v
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  /** Cross-partition step: fold buffer2's sum and count into buffer1. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /** Final output: sum / count, or null for an empty / all-null input. */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    // Guard the empty case: without it, 0.0 / 0 yields NaN,
    // whereas SQL AVG over no rows yields null.
    if (count == 0L) null else buffer.getDouble(0) / count
  }
}
四.自定义强类型聚合函数(了解)
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
// Input record for the typed aggregator: a dog's name and age.
case class Dog(name:String,age:Int)
// Aggregation buffer for the typed aggregator: running age sum and record count.
case class AgeAvg(sum:Int,count:Int){
  // Average age. NOTE(review): when count == 0 (empty input) this yields
  // Double NaN (0.0 / 0) — confirm that is acceptable for the caller.
  def avg =sum.toDouble/count
}
object UDAFDemo2 {
  def main(args: Array[String]): Unit = {
    // Local SparkSession for the demo.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("RDD2DF")
      .getOrCreate()
    import spark.implicits._

    // Build a typed Dataset from sample records.
    val ds = List(Dog("大黄", 6), Dog("小黄", 2), Dog("中黄", 4)).toDS()

    // Strongly-typed usage: turn the Aggregator into a named column.
    val avg = new MyAvg2().toColumn.name("avg")
    val result = ds.select(avg)
    result.show()

    spark.close()
  }
}
/** Strongly-typed aggregator: averages the ages of a Dataset of Dog. */
class MyAvg2 extends Aggregator[Dog,AgeAvg,Double] {
  /** Fresh (empty) buffer. */
  override def zero: AgeAvg = AgeAvg(0, 0)

  /** Within-partition step: fold one record into the buffer. */
  override def reduce(b: AgeAvg, a: Dog): AgeAvg = a match {
    // A real Dog: add its age, bump the count.
    case Dog(_, age) => AgeAvg(b.sum + age, b.count + 1)
    // A null record: buffer unchanged.
    case _ => b
  }

  /** Cross-partition step: combine two buffers component-wise. */
  override def merge(b1: AgeAvg, b2: AgeAvg): AgeAvg =
    AgeAvg(b1.sum + b2.sum, b1.count + b2.count)

  /** Final output: the buffer's average. */
  override def finish(reduction: AgeAvg): Double = reduction.avg

  /** Encoder for the buffer — Encoders.product works for any case class. */
  override def bufferEncoder: Encoder[AgeAvg] = Encoders.product

  /** Encoder for the Double result. */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}