背景
我根据算子输入输出之间的关系来理解算子分类:
UDF——输入一行,输出一行
UDAF——输入多行,输出一行
UDTF——输入一行,输出多行
本文主要是整理这三种自定义算子的具体实现方式
使用的数据集——用户行为日志user_log.csv,csv中自带首行列头信息,字段定义如下:
1. user_id | 买家id
2. item_id | 商品id
3. cat_id | 商品类别id
4. merchant_id | 卖家id
5. brand_id | 品牌id
6. month | 交易时间:月
7. day | 交易事件:日
8. action | 行为
9. age_range | 买家年龄分段
10. gender | 性别
11. province| 收获地址省份
新手上路,有任何搞错的地方,或者走了弯路,还请大家不吝指出,帮我进步。
SparkSQL算子分类
- 1. UDF
- 2. UDAF
- 3. UDTF
- ► 小结
1. UDF
通过匿名函数的方式注册自定义算子
object UserAnalysis {
def main(args:Array[String]): Unit ={
//测试数据所在的本地路径
val userDataPath = "file:///home/hadoop/data_format/small_user_log.csv"
//创建sparksession
val sparkSession = SparkSession
.builder
.master("local")
.appName("UserAnalysis")
.enableHiveSupport() //启用hive
.getOrCreate()
//sparksession直接读取csv,可设置分隔符delimitor.
val userDF = sparkSession.read
.option("header","true")
.csv(userDataPath)
//将DataFrame注册成视图,然后即可使用hql访问
userDF.createOrReplaceTempView("userDF")
//通过匿名函数的方式注册自定义算子:将0和1分别转换成female和male
sparkSession.udf.register("getGender",(gender:Integer)=>{
var result="unknown"
if (gender==0){
result="female"
}else if(gender==1){
result="male"
}
result
})
val genderDF = sparkSession.sql("select getGender(gender) as A from userDF")
//显示DataFrame内容
genderDF.show(10)
}
}
通过实名函数的方式注册自定义算子
object UserAnalysis {
def main(args:Array[String]): Unit ={
//测试数据所在的本地路径
val userDataPath = "file:///home/hadoop/data_format/small_user_log.csv"
//创建sparksession
val sparkSession = SparkSession
.builder
.master("local")
.appName("UserAnalysis")
.enableHiveSupport() //启用hive
.getOrCreate()
//sparksession直接读取csv,可设置分隔符delimitor.
val userDF = sparkSession.read
.option("header","true")
.csv(userDataPath)
//将DataFrame注册成视图,然后即可使用hql访问
userDF.createOrReplaceTempView("userDF")
/*
通过实名函数的方式注册自定义算子
Scala中方法和函数是两个不同的概念,方法无法作为参数进行传递,
也无法赋值给变量,但是函数是可以的。在Scala中,利用下划线可以将方法转换成函数:
*/
sparkSession.udf.register("getGender",getGender _)
val genderDF = sparkSession.sql("select getGender(gender) as A from userDF")
//显示DataFrame内容
genderDF.show(10)
}
//将0和1分别转换成female和male
def getGender(gender:Integer): String ={
var result="unknown"
if (gender==0){
result="female"
}else if(gender==1){
result="male"
}
result
}
}
通过以上两种方式实现相同算子,得到相同的结果:
2. UDAF
通过实现抽象类org.apache.spark.sql.expressions.UserDefinedAggregateFunction来自定义UDAF算子
class UserDefinedMax extends UserDefinedAggregateFunction{
//定义输入数据的类型,两种写法都可以
//override def inputSchema: StructType = StructType(Array(StructField("input", IntegerType, true)))
override def inputSchema: StructType = StructType(StructField("input", IntegerType) :: Nil)
//定义聚合过程中所处理的数据类型
override def bufferSchema: StructType = StructType(Array(StructField("cache", IntegerType, true)))
//定义输入数据的类型
override def dataType: DataType = IntegerType
//规定一致性
override def deterministic: Boolean = true
//在聚合之前,每组数据的初始化操作
override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) =0}
//每组数据中,当新的值进来的时候,如何进行聚合值的计算
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if(input.getInt(0)> buffer.getInt(0))
buffer(0)=input.getInt(0)
}
//合并各个分组的结果
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
if(buffer2.getInt(0)> buffer1.getInt(0)){
buffer1(0)=buffer2.getInt(0)
}
}
//返回最终结果
override def evaluate(buffer: Row): Any = {buffer.getInt(0)}
}
测试代码
object UserAnalysis {
def main(args:Array[String]): Unit ={
//测试数据所在的本地路径
val userDataPath = "file:///home/hadoop/data_format/small_user_log.csv"
//创建sparksession
val sparkSession = SparkSession
.builder
.master("local")
.appName("UserAnalysis")
.enableHiveSupport() //启用hive
.getOrCreate()
//sparksession直接读取csv,可设置分隔符delimitor.
var userDF = sparkSession.read
.option("header","true")
.csv(userDataPath)
//转换dataframe字段类型或字段名
import org.apache.spark.sql.functions._
userDF = userDF .withColumn("item_id", col("item_id").cast(IntegerType))
//将DataFrame注册成视图,然后即可使用hql访问
userDF.createOrReplaceTempView("userDF")
//注册算子,如果UserDefinedMax是object,不用new
sparkSession.udf.register("UserDefinedMax", new UserDefinedMax)
//测试sparksql内嵌max算子结果
val MaxDF = sparkSession.sql("select max(item_id) from userDF")
MaxDF.show
//测试用户自定义max算子结果
val UserDefinedMaxDF = sparkSession.sql("select UserDefinedMax(item_id) from userDF")
UserDefinedMaxDF.show
}
}
可以看到两个max算子的输出相同:
3. UDTF
通过实现抽象类org.apache.hadoop.hive.ql.udf.generic.GenericUDTF来自定义UDTF算子
class UserDefinedUDTF extends GenericUDTF{
//这个方法的作用:1.输入参数校验 2. 输出列定义,可以多于1列,相当于可以生成多行多列数据
override def initialize(args:Array[ObjectInspector]): StructObjectInspector = {
if (args.length != 1) {
throw new UDFArgumentLengthException("UserDefinedUDTF takes only one argument")
}
if (args(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentException("UserDefinedUDTF takes string as a parameter")
}
val fieldNames = new util.ArrayList[String]
val fieldOIs = new util.ArrayList[ObjectInspector]
//这里定义的是输出列默认字段名称
fieldNames.add("col1")
//这里定义的是输出列字段类型
fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
}
//这是处理数据的方法,入参数组里只有1行数据,即每次调用process方法只处理一行数据
override def process(args: Array[AnyRef]): Unit = {
//将字符串切分成单个字符的数组
val strLst = args(0).toString.split("")
for(i <- strLst){
var tmp:Array[String] = new Array[String](1)
tmp(0) = i
//调用forward方法,必须传字符串数组,即使只有一个元素
forward(tmp)
}
}
override def close(): Unit = {}
}
测试代码
object UserAnalysis {
def main(args:Array[String]): Unit ={
//测试数据所在的本地路径
val userDataPath = "file:///home/hadoop/data_format/zxc/small1.csv"
//创建sparksession
val sparkSession = SparkSession
.builder
.master("local")
.appName("UserAnalysis")
.enableHiveSupport() //启用hive
.getOrCreate()
//sparksession直接读取csv,可设置分隔符delimitor.
var userDF = sparkSession.read
.option("header","true")
.csv(userDataPath)
//将DataFrame注册成视图,然后即可使用hql访问
userDF.createOrReplaceTempView("userDF")
//注册utdf算子,这里无法使用sparkSession.udf.register()
sparkSession.sql("CREATE TEMPORARY FUNCTION UserDefinedUDTF as 'com.zxc.sparkAppTest.udtf.UserDefinedUDTF'")
//使用UDTF算子处理原表userDF
val UserDefinedUDTFDF = sparkSession.sql(
"select " +
"user_id," +
"item_id," +
"cat_id," +
"merchant_id," +
"brand_id," +
"month," +
"day," +
"action," +
"age_range," +
"gender," +
"UserDefinedUDTF(province) " +
"from " +
"userDF"
)
UserDefinedUDTFDF.show
}
}
对比原表和经UDTF算子处理之后的结果表:
► 小结
- 关于UDF
简单粗暴的理解,它就是输入一行输出一行的自定义算子
我们可以通过实名函数或匿名函数的方式来实现,并使用sparkSession.udf.register()注册
需要注意,截至目前(spark2.4)最多只支持22个输入参数的UDF
另外还有一种实现方案(基于spark1.5,spark2.4待测试):
继承org.apache.hadoop.hive.ql.exec.UDF - 关于UDAF
简单粗暴的理解,它就是输入多行输出一行的自定义算子,比UDF的功能强大一些
通过实现抽象类org.apache.spark.sql.expressions.UserDefinedAggregateFunction来实现UDAF算子,并使用sparkSession.udf.register()注册
另外还有一种实现方案(基于spark1.5,spark2.4待测试):
先继承org.apache.hadoop.hive.ql.exec.UDAF
内部静态类实现org.apache.hadoop.hive.ql.exec.UDAFEvaluator - 关于UDTF
简单粗暴的理解,它就是输入一行输出多行的自定义算子,可输出多行多列,又被称为 “表生成函数”
通过实现抽象类org.apache.hadoop.hive.ql.udf.generic.GenericUDTF来实现UDTF算子,但是似乎无法使用sparkSession.udf.register()注册。注册方法如下:
sparkSession.sql("CREATE TEMPORARY FUNCTION 自定义算子名称 as '算子实现类全限定名称'")
实现UDTFf还需要注意(基于spark1.5,可能已过时):
udtf,process方法中对参数需要使用toString,String强转没用
sparksql子查询必须要有别名
算子内部使用竖线切分字符串时,需要转义
udtf调用forward方法,必须传字符串数组,即使只有一个元素