自定义函数
类型
- UDF:一进一出
- UDAF:多进一出
UDF
流程
spark-sql中SQL的用法
- 1、自定义udf函数/类(类要注意需要序列化)
- 2、注册spark.udf.register(“名称”,自定义的函数/自定义的类 _)
- 3、调用查询方法
自定义udf函数并调用
import org.apache.spark.sql.SparkSession
import org.junit.Test
/**
* @ClassName: MyUDFdemo
* @Description: 将员工中id不满8位的补齐
* @Author: kele
* @Date: 2021/2/1 20:56
**/
/**
* 1、自定义udf函数/类(类要注意需要序列化)
* 2、注册spark.udf.register("名称",自定义的函数/自定义的类 _)
* 3、调用查询方法
*/
class MyUDFdemo extends Serializable{
@Test
def emp_info={
val spark = SparkSession.builder().master("local[4]").appName("UDFdemo").getOrCreate()
import spark.implicits._ //rddtoDF的隐式转换
val rdd1 = spark.sparkContext.parallelize(List(
("00123","zhangsan"),
("256","lisi"),
("0135","wangwu"),
("000368","qianqi"),
("00378","zhaoliu")
))
val df = rdd1.toDF("id","name")
/**
* 方式一:通过sql的方式查询 自定义函数
*
*/
// df.createOrReplaceTempView("user")
// spark.udf.register("fullId",fullUserId)
// spark.sql("""select fullId(id) from user """).show()
/**
* 自定义类,需要序列化
*
*/
df.createOrReplaceTempView("user")
spark.udf.register("fullId2",fullUserIdclass _)
spark.sql("""select fullId2(id) from user """).show()
/**
* 方式二:selectExpr的方式查找
*/
df.selectExpr("fullId2(id) id").show()
}
//自定义udf函数
val fullUserId = (id : String)=>{
s"${"0" *(8-id.length)}${id}"
}
//自定义udf类
def fullUserIdclass(id:String) ={
s"${"0" *(8-id.length)}${id}"
}
}
spark-sql中DataFram中的用法
在spark的DataFram的udf方法和spark sql的名字相同,但是属于不同的类,
import org.apache.spark.sql.functions._
//方法一:注册自定义函数(通过匿名函数)
val strLen = udf((str: String) => str.length())
//方法二:注册自定义函数(通过实名函数)
val udf_isAdult = udf(isAdult _)
UDAF
UDAF弱类型实现
总体流程
- 1、继承UserDefinedAggregateFunction( 没有泛型)
- 2、重写方法
- 1、指定带统计列表的类型
- 2、指定中间变量的类型
- 3、指定函数的返回类型
- 4、设置稳定性
- 5、初始化中间变量的值
- 6、求在一个task中的计算过程
- 7、求在分区间的计算过程
- 8、函数的返回值 - 3、注册spark.udf.register,为其绑定一个名字
自定义UDAF弱类型
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructType}
/**
* @ClassName: MyUDAF
* @Description:
* @Author: kele
* @Date: 2021/2/1 16:03
**/
class MyUDAF extends UserDefinedAggregateFunction{
/**
* 指定待统计的数据类型
* @return 返回值类型是StructType类型,
*/
override def inputSchema: StructType = new StructType().add("age",IntegerType)
/**
* 这里是求平均值,需要sum,和num,因此需要两个中间变量
* 指定中间变量的类型,数据进入是是一个个进
* @return
*/
override def bufferSchema: StructType = new StructType().add("sum",IntegerType)
.add("num",IntegerType)
/**
* 函数的返回类型
* @return
*/
override def dataType: DataType = DoubleType
/**
* 稳定性,同一组数据输入是否返回相同的值
* @return
*/
override def deterministic: Boolean = true
/**
* 初始化buffer的值
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0,0)
buffer.update(1,0)
}
/**
* 在一个task中的计算过程
* sum将age不断累加
* count+1
* @param buffer
* @param input
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0,buffer.getAs[Int](0)+input.getAs[Int](0))
buffer.update(1,buffer.getAs[Int](1)+1)
}
/**
* 分区间的计算方式
* @param buffer1
* @param buffer2
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0,buffer1.getAs[Int](0)+buffer2.getAs[Int](0))
buffer1.update(1,buffer1.getAs[Int](1)+buffer2.getAs[Int](1))
}
/**
* 返回值
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Any = buffer.getAs[Int](0).toDouble/buffer.getAs[Int](1)
}
调用过程
/**
* 调用弱类型
*/
@Test
def avg_Age={
val spark = SparkSession.builder()
.master("local[4]")
.appName("avg_age")
.getOrCreate()
val rdd = spark.sparkContext.parallelize(List(
("zhangsan",20,"开发部"),
("wanwu",25,"产品部"),
("aa",26,"开发部"),
("lisi",40,"开发部"),
("bb",30,"产品部"),
("cc",28,"产品部")
))
import spark.implicits._
val df = rdd.toDF("name","age","dept")
df.createOrReplaceTempView("user")
spark.udf.register("myavg",new MyUDAF)
spark.sql(
"""
|select myavg(age) from user group by dept
""".stripMargin).show()
}
UDAF强类型实现过程
- 1、自定义class继承Aggregator[统计的列的类型、中间变量类型,输出结果类型]
- 2、重写方法
- 1、初始化中间变量
- 2、每一个task中的统计过程
- 3、分区间计算过程
- 4、计算最终结果并返回
- 5、编码中间变量的类型,个人认为是为了保证中间数据传输
注意样例类的父类是product
- 3、注册spark.udf.register(函数名,udaf(自定义udaf对象))
- import org.apache.spark.sql.functions._ //必须调用该隐式转换,否则无法导入
自定义强类型
package com.atguigu.day05
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
/**
* @ClassName: MyUDAF2
* @Description: 强类型自定义类,Aggregator可以自定义泛型[输入类型,中间变量,输出类型]
* @Author: kele
* @Date: 2021/2/1 16:05
**/
/**
* 如果需要多个中间变量,可以考虑使用样例类
*
*/
case class InterVari(var sum:Int,var count:Int)
class MyUDAFStrong extends Aggregator[Int,InterVari,Double]{
/**
* 初始化中间变量
* @return
*/
override def zero: InterVari = InterVari(0,0)
/**
* 每一个task中的统计过程
* @param b
* @param a
* @return
*/
override def reduce(b: InterVari, a: Int): InterVari = {
b.sum = b.sum+a
b.count = b.count+1
b
}
/**
* 分区间计算过程
* @param b1
* @param b2
* @return
*/
override def merge(b1: InterVari, b2: InterVari): InterVari = {
b1.sum = b1.sum + b2.sum
b1.count = b1.count + b2.count
b1
}
/**
* 最终结果返回
* @param reduction
* @return
*/
override def finish(reduction: InterVari): Double = reduction.sum.toDouble/reduction.count
/**
* 编码中间变量的类型,个人认为是为了保证中间数据传输
* @return 样例类的父类是product
*/
override def bufferEncoder: Encoder[InterVari] = Encoders.product
/**
* 编码结果值的类型,个人认为是为了保证中间数据传输
* @return
*/
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
调用过程
/**
* 使用弱类型
*/
@Test
def avg_Age={
val spark = SparkSession.builder()
.master("local[4]")
.appName("avg_age")
.getOrCreate()
val rdd = spark.sparkContext.parallelize(List(
("zhangsan",20,"开发部"),
("wanwu",25,"产品部"),
("aa",26,"开发部"),
("lisi",40,"开发部"),
("bb",30,"产品部"),
("cc",28,"产品部")
))
import spark.implicits._
val df = rdd.toDF("name","age","dept")
df.createOrReplaceTempView("user")
spark.udf.register("myavg",new MyUDAF)
spark.sql(
"""
|select myavg(age) from user group by dept
""".stripMargin).show()
}