文章目录

  • 简介:
  • 使用场景
  • UDF
  • spark UDF
  • 源码:
  • 语法:
  • 实现方法:
  • 案例
  • Hive UDF
  • 实现步骤
  • 案例:
  • UDAF
  • Spark UDAF(User Defined Aggregate Function)
  • Spark UDAF 实现方法:
  • Spark UDAF 实现步骤:
  • 案例:
  • 继承`UserDefinedAggregateFunction`:
  • 继承`Aggregator`
  • Hive UDAF(User Defined Aggregate Function)
  • Hive UDAF 实现步骤:
  • Hive UDAF实现方式一:
  • Hive UDAF 方式一示例:
  • Hive UDAF实现方式二:
  • 相关抽象类介绍
  • GenericUDAFEvaluator的方法
  • Hive UDAF 方式二实现步骤
  • Hive UDAF 方式二实现实例
  • Hive UDAF 使用
  • UDTF
  • 自定义UDTF实现方法:
  • UDTF的基本语法
  • UDTF的使用
  • UDTF函数实现实例:
  • 总结:


简介:

做数仓的小伙伴应该深有体会,我们在做复杂业务时经常遇到一些比较复杂的逻辑或者复杂的数据结构,它们无法使用hive或者spark天然提供的函数进行解析,在这个时候我们就会想到如果可以自定义一个像hivespark自身提供的函数对数据对处理那就方便多了,为此hivespark使用UDF、UDAF、UDTF几种接口,供我们自定义函数解决此类问题,下面笔者以自身实践为基础对UDF、UDAF、UDTF函数进行简单的介绍。

使用场景

我们先来看下它们各自的定义:

  • UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date
  • UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg等
  • UDTF(User-Defined Table-Generating Functions),用户自定义生成函数,有点像stream里面的flatMap

其实虽然我们将这三种函数放到一起,但是在Hive和Spark中对他们的支持是不一样的, Hive中是全部支持是这三种函数,但是Spark中其实是仅支持UDF和UDAF,而UDTF在spark中其实是完全使用了Hive的UDTF函数。那么我们来回顾下Hive中自定义函数的三种类型:

  • 第一种:UDF(User-Defined-Function) 函数
    一对一的关系,输入一个值经过函数以后输出一个值;
    Hive中继承UDF类,方法名称为evaluate,返回值不能为void,其实就是实现一个方法;
  • 第二种:UDAF(User-Defined Aggregation Function) 聚合函数
    多对一的关系,输入多个值输出一个值,通常与groupBy联合使用;
  • 第三种:UDTF(User-Defined Table-Generating Functions) 函数
    一对多的关系,输入一个值输出多个值(一行变为多行);
    用户自定义生成函数,有点像flatMap;

既然说到两个框架对三个自定义函数的支持,那么我们就来简单了解spark几个版本对函数的支持变化:

Spark版本

Spark SQL UDF(Python,Java,Scala)

Spark SQL UDAF(Java、Scala)

Spark SQL UDF(R)

Hive UDF、UDAF、UDTF

1.1 - 1.4

✔️

✔️

1.5

✔️

experimental

✔️

1.6

✔️

✔️

✔️

2.0

✔️

✔️

✔️

✔️

在SparkSQL中,目前仅仅支持UDF函数和UDAF函数:

  • UDF函数:一对一关系;
  • UDAF函数:聚合函数,通常与group by 分组函数连用,多对一关系

由于SparkSQL数据分析有两种方式:DSL编程和SQL编程,所以定义UDF函数也有两种方式,不同方式可以在不同分析中使用。

UDF

spark UDF

源码:

我们从源码知道spark提供了23个UDF相关的接口。如下图所示:

spark sql udf spark sql udf hive udf_UDTF

其实它们之间区别就与接口中定义参数的多少,这些udf能支持的传入的参数的个数从[0,22]分别对应每个UDF函数
为了方便大家了解,我们从源码中引入部分代码:

UDF0函数源码:

/**
 * A Spark SQL UDF that has 0 arguments.
 */
@InterfaceStability.Stable
public interface UDF0<R> extends Serializable {
    R call() throws Exception;
}

UDF1函数源码:

/**
 * A Spark SQL UDF that has 1 arguments.
 */
@InterfaceStability.Stable
public interface UDF1<T1, R> extends Serializable {
  R call(T1 t1) throws Exception;
}

UDF2函数源码:

/**
 * A Spark SQL UDF that has 2 arguments.
 */
@InterfaceStability.Stable
public interface UDF2<T1, T2, R> extends Serializable {
  R call(T1 t1, T2 t2) throws Exception;
}

UDF22函数源码:

/**
 * A Spark SQL UDF that has 22 arguments.
 */
@InterfaceStability.Stable
public interface UDF22<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, R> extends Serializable {
  R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception;
}

由源码我们可以看到我们在使用时需要明确好两点来选择不同UDF:
1、需要传入参数的个数。
2、需要返回值类型。

语法:

使用SparkSessionudf方法定义和注册函数,在SQL中使用,使用如下方式定义:

sparkSession.udf.register(
		"udfName", //自定义函数的名称
		(UDF1<Long, Double>) (parameter) -> parameter*0.2, // 匿名函数
		DataTypes.DoubleType //返回值的类型
						);

其中register()该注册函数的参数解释如下:

  • 第一个参数udfName就是你的udf的名字。
  • 第二个参数中的parameter就是传入的UDF的参数。
  • 第三个参数就是处理完的返回的数据类型。

注:特别说明的UDF1<Long,Double> 中的Long表示传入的参数的数据类型Double表示返回的参数的数据类型,这个必须与上面提到的注册函数的第三个参数保持一致。

实现方法:

自定义udf的方式有两种:

  • SQLContext.udf.register(),SQL方式。
  • 创建UserDefinedFunction,DSL方式。

案例

将小写数据转换为大写:

//TODO:通过sparkSession进行UDF的注册,将我们的小写转换成大写
    //1.SQL方式:
    sparkSession.udf.register("smallToBigger", new UDF1[String,String]() {
      @throws[Exception]
      override def call(t1: String): String = {
        t1.toUpperCase()
      }
    }, DataTypes.StringType)
  //2.DSL方式:
   val smallToBigger: UserDefinedFunction = udf((str: String) => t1.toUpperCase())
    //使用UDF函数
    sparkSession.sql("select line, smallToBigger(line) as biggerLine from small_table").show()

Hive UDF

实现步骤

1.继承org.apache.hadoop.hive.ql.exec.UDF 2.实现evaluate方法 注意:该方法必须有返回值,可以为null
3.将jar包上传到hdfs /xx/xx/udf/hive-udf.jar将jar包重命名为hive-udf.jar
4.在集群的客户端,打开hive shell执行命令:

CREATE [TEMPORARY] FUNCTION json_array_string AS 
'org.myfunctions.udf.JsonArrayUDF' 
USING JAR 'hdfs:///xx/xx/udf/hive-udf.jar';

案例:

这里仅提供代码,详细步骤参照上面即可。

package com.aiyunxiao.bigdata.udf;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.slf4j.Logger;

public class JsonArrayUDF extends UDF {
    private Logger LOG = org.slf4j.LoggerFactory.getLogger(JsonArrayUDF.class);
    /**
     * 1. Implement one or more methods named "evaluate" which will be called by Hive.
     *
     * 2. "evaluate" should never be a void method. However it can return "null" if needed.
     */
    // CREATE FUNCTION udf_name as 'org.myfunctions...' USING JAR 'hdfs:///xx/xx/udf/hive-udf.jar';
    public String evaluate(String str){
        try {
            JSONArray jsonArray = JSON.parseArray(str);
            StringBuilder sb = new StringBuilder();
            for (Object o : jsonArray) {
                sb.append(JSON.toJSONString(o)).append("||");
            }
            return sb.substring(0,sb.length()-2);
        }catch (Exception e){
            //LOG.error(e.getMessage());
            return null;
        }
    }
}

UDAF

  • UDAF简介
    UDAF(User Defined Aggregate Function),即用户定义的聚合函数,聚合函数和普通函数的区别是什么呢,普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值想办法聚合一下。
  • 关于UDAF的一个误区
    我们可能下意识的认为UDAF是需要和group by一起使用的,实际上UDAF可以跟group by一起使用,也可以不跟group by一起使用,这个其实比较好理解,联想到mysql中的max、min等函数,可以:
select min(a) from table group by b;

表示根据b字段分组,然后求每个分组的最小值,这时候的分组有很多个,使用这个函数对每个分组进行处理,也可以:

select min(a) from table;

这种情况可以将整张表看做是一个分组,然后在这个分组(实际上就是一整张表)中求最小值。所以聚合函数实际上是对分组做处理,而不关心分组中记录的具体数量

Spark UDAF(User Defined Aggregate Function)

Spark UDAF 实现方法:

Spark实现UDAF有两个办法,如下:
1.继承UserDefinedAggregateFunction 2.继承Aggregator

Spark UDAF 实现步骤:

使用继承UserDefinedAggregateFunction实现UDAF的步骤:

  1. 自定义类继承UserDefinedAggregateFunction,对每个阶段方法做实现
  2. spark中注册UDAF,为其绑定一个名字
  3. 然后就可以在sql语句中使用上面绑定的名字调用

案例:

我们写一个计算学生平均分值的UDAF例子方便大家理解

继承UserDefinedAggregateFunction
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/** *
 *
 * @author Saodiseng
 * @date 2021/5/22 6:25 下午 周六
 * @jdk jdk1.8.0
 * @version 1.0
 * **/

object AverageUserDefinedAggregateFunction extends UserDefinedAggregateFunction {

   // 聚合函数的输入数据结构
  override def inputSchema: StructType = StructType(StructField("input", DoubleType) :: Nil)

  // 缓存区数据结构
  override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)

  // 聚合函数返回值数据结构
  override def dataType: DataType = DoubleType

  // 聚合函数是否是幂等的,即相同输入是否总是能得到相同输出
  override def deterministic: Boolean = true

  // 初始化缓冲区
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
    buffer(1) = 0L
  }

  // 给聚合函数传入一条新数据进行处理
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (input.isNullAt(0)) return
    buffer(0) = buffer.getDouble(0) + input.getDouble(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // 合并聚合函数缓冲区
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // 计算最终结果
  override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getLong(1)
  
}

然后注册并使用它:

import org.apache.spark.sql.SparkSession
 
object SparkSqlUDAFDemo_001 {

 def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[*]").appName("StudentScoreAvg").getOrCreate()

    spark.read.json("metastore_db/tmp/score").createOrReplaceTempView("student_score")

    spark.udf.register("s_avg", AverageUserDefinedAggregateFunction)

    // 将整张表看做是一个分组对求所有人的平均年龄
    spark.sql("select count(1) as count, s_avg(age) as avg_score from student_score").show()

    // 按照性别分组求平均
    spark.sql("select class, count(1) as count, s_avg(score) as avg_score from student_score group by class").show()

    spark.close()
  }
}

使用到的数据集:

{"student_id": 1001, "student_name": "xiaoming", "class": "2", "sex": "woman", "score": 56.5}
{"student_id": 1002, "student_name": "xiaoqiang", "class": "2", "sex": "woman", "score": 59.5}
{"student_id": 1003, "student_name": "qunqun", "class": "1", "sex": "man", "score": 100}
{"student_id": 1004, "student_name": "lulu", "class": "1", "sex": "man", "score": 99}
{"student_id": 1005, "student_name": "xiaolong", "class": "3", "sex": "man", "score": 99}
{"student_id": 1006, "student_name": "luting", "class": "3", "sex": "man", "score": 98}

运营结果:

spark sql udf spark sql udf hive udf_spark sql udf_02

spark sql udf spark sql udf hive udf_大数据_03

继承Aggregator

还有另一种方式就是继承Aggregator这个类,优点是可以带类型:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}
 
/**
  * 计算平均值
  */
object AverageAggregator extends Aggregator[User, Average, Double] {
 
  // 初始化buffer
  override def zero: Average = Average(0.0, 0L)
 
  // 处理一条新的记录
  override def reduce(b: Average, a: User): Average = {
    b.sum += a.score
    b.count += 1L
    b
  }
 
  // 合并聚合buffer
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
 
  // 减少中间数据传输
  override def finish(reduction: Average): Double = reduction.sum / reduction.count
 
  override def bufferEncoder: Encoder[Average] = Encoders.product
 
  // 最终输出结果的类型
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
 
}
 
/**
  * 计算平均值过程中使用的Buffer
  *
  * @param sum
  * @param count
  */
case class Average(var sum: Double, var count: Long) {
}
 
case class User(student_id: Long, student_name: String, sex: String, score: Double) {
}

调用:

import org.apache.spark.sql.SparkSession
 
object AverageAggregatorDemo_001 {
 
  def main(args: Array[String]): Unit = {
 
    val spark = SparkSession.builder().master("local[*]").appName("StudentScoreAvg").getOrCreate()
    import spark.implicits._
    
    val user = spark.read.json("metastore_db/tmp/score").as[User]
    
    user.select(AverageAggregator.toColumn.name("avg")).show()
  }
}

运行结果:

spark sql udf spark sql udf hive udf_UDTF_04

Hive UDAF(User Defined Aggregate Function)

Hive UDAF 实现步骤:

Hive UDAF实现方式一:

UDAF其实已经过时,但这里我们也进行简单介绍:

想要实现自定义UDAF需要使用以下两类:

  • import org.apache.hadoop.hive.ql.exec.UDAF
  • import org.apache.hadoop.hive.ql.exec.UDAFEvaluator

步骤:

  • 1、函数类需要继承UDAF类,计算类Evaluator实现UDAFEvaluator接口
  • 2、Evaluator需要实现UDAFEvaluatorinititerateterminatePartialmergeterminate这几个函数。
    a)init函数实现接口UDAFEvaluatorinit函数。
    b)iterate接收传入的参数,并进行内部的迭代。其返回类型为boolean
    c)terminatePartial无参数,其为iterate函数遍历结束后,返回遍历得到的数据,terminatePartial类似于 hadoopCombiner
    d)merge接收terminatePartial的返回结果,进行数据merge操作,其返回类型为boolean
    e)terminate返回最终的聚集函数结果。
Hive UDAF 方式一示例:
import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;

/***
 * @author Saodiseng
 * @date 2021/5/24 11:00 上午 周一
 * @jdk jdk1.8.0
 * @version 1.0
 ***/
public class Avg extends UDAF {

    /**
     * 定义静态内部类AvgState
     */
    public static class AvgState {
        private long mCount;
        private double mSum;
    }

    public static class AvgEvaluator implements UDAFEvaluator {

        //初始化AvgState对象
        AvgState state;

        //创建AvgEvaluator无参构造函数
        public AvgEvaluator(){

            super();

            state = new AvgState();

            init();
        }

        /**
         * init函数类似于构造函数,用于UDAF的初始化
         */
        @Override
        public void init() {

            //设置mCount初始值
            state.mCount = 0;

            //设置mSum初始值
            state.mSum = 0;
        }

        /**
         * iterate接收传入的参数,并进行内部的轮转。其返回类型为boolean
         * @param o
         * @return
         */
        public boolean iterate(Double o) {
            if (o != null) {
                state.mSum += o;
                state.mCount++;
            }
            return true;
        }

        /**
         * terminatePartial无参数,其为iterate函数遍历结束后,返回轮转数据,
         * terminatePartial类似于hadoop的Combiner
         * @return
         */
        public AvgState terminatePartial() {
            // combiner
            return state.mCount == 0 ? null : state;
        }

        /**
         * merge接收terminatePartial的返回结果,进行数据merge操作,其返回类型为boolean
         * @param
         * @return
         */
        public boolean merge(AvgState avgState) {
            if (avgState != null) {
                state.mCount += avgState.mCount;
                state.mSum += avgState.mSum;
            }
            return true;
        }

        /**
         * terminate返回最终的聚集函数结果
         * @return
         */
        public Double terminate() {
            return state.mCount == 0 ? null : Double.valueOf(state.mSum / state.mCount);
        }
    }
}
Hive UDAF实现方式二:

Hive自定义聚类函数-GenericUDAFUDAF开发主要涉及到以下两个抽象类,创建一个GenericUDAF必须先了解以下两个抽象类:

  • org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver
  • org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
相关抽象类介绍

为了更好理解上述抽象类的API,要记住hive只是mapreduce函数,只不过hive已经帮助我们写好并隐藏mapreduce,向上提供简洁的sql函数,所以我们要结合Mapper、Combiner与Reducer来帮助我们理解这个函数。要记住在hadoop集群中有若干台机器,在不同的机器上Mapper与Reducer任务独立运行。所以大体上来说,这个UDAF函数读取数据(mapper),聚集一堆mapper输出到部分聚集结果(combiner),并且最终创建一个最终的聚集结果(reducer)。因为我们跨域多个combiner进行聚集,所以我们需要保存部分聚集结果。

AbstractGenericUDAFResolver:
Resolver很简单,要覆盖实现下面方法,该方法会根据sql传人的参数数据格式指定调用哪个Evaluator进行处理。

public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException;

GenericUDAFEvaluator:
UDAF逻辑处理主要发生在Evaluator中,要实现该抽象类的几个方法。在理解Evaluator之前,必须先理解objectInspector接口与GenericUDAFEvaluator中的内部类Model
ObjectInspector:
作用主要是解耦数据使用与数据格式,使得数据流在输入输出端切换不同的输入输出格式,不同的Operator上使用不同的格式。简单来说,ObjectInspector接口使得Hive可以不拘泥于一种特定数据格式,使得数据流 1在输入端和输出端切换不同的输入/输出格式 2。
Model
Model代表了UDAFmapreduce的各个阶段。

public static enum Mode {  
    /** 
     * PARTIAL1: 这个是mapreduce的map阶段:从原始数据到部分数据聚合 
     * 将会调用iterate()和terminatePartial() 
     */  
    PARTIAL1,  
        /** 
     * PARTIAL2: 这个是mapreduce的map端的Combiner阶段,负责在map端合并map的数据::从部分数据聚合到部分数据聚合: 
     * 将会调用merge() 和 terminatePartial() 
     */  
    PARTIAL2,  
        /** 
     * FINAL: mapreduce的reduce阶段:从部分数据的聚合到完全聚合  
     * 将会调用merge()和terminate() 
     */  
    FINAL,  
   /** 
     * COMPLETE: 如果出现了这个阶段,表示mapreduce只有map,没有reduce,所以map端就直接出结果了:从原始数据直接到完全聚合 
      * 将会调用 iterate()和terminate() 
     */  
    COMPLETE  
  };

一般情况下,完整的UDAF逻辑是一个mapreduce过程,

  • 如果有mapperreducer,就会经历PARTIAL1(mapper),FINAL(reducer)
  • 如果还有combiner,那就会经历PARTIAL1(mapper),PARTIAL2(combiner),FINAL(reducer)
  • 而有一些情况下的mapreduce,只有mapper,而没有reducer,所以就会只有COMPLETE阶段,这个阶段直接输入原始数据,出结果。

代码执行过程:

  1. PARTIAL1(阶段1:map):init() >> iterate() >> terminatePartial()
  2. PARTIAL2(阶段2:combine):init() >> merge() >> terminatePartial()
  3. FINAL(最终阶段:reduce):init() >> merge() >> terminate()
  4. COMPLETE(直接输出阶段:只有map):init() >> iterate() >> terminate()

注:每个阶段都会执行init()初始化操作。

GenericUDAFEvaluator的方法
// 确定各个阶段输入输出参数的数据格式ObjectInspectors  
public  ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException;  
  
// 保存数据聚集结果的类  
abstract AggregationBuffer getNewAggregationBuffer() throws HiveException;  
  
// 重置聚集结果  
public void reset(AggregationBuffer agg) throws HiveException;  
  
// map阶段,迭代处理输入sql传过来的列数据  
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException;  
  
// map与combiner结束返回结果,得到部分数据聚集结果  
public Object terminatePartial(AggregationBuffer agg) throws HiveException;  
  
// combiner合并map返回的结果,还有reducer合并mapper或combiner返回的结果。  
public void merge(AggregationBuffer agg, Object partial) throws HiveException;  
  
// reducer阶段,输出最终结果  
public Object terminate(AggregationBuffer agg) throws HiveException;
Hive UDAF 方式二实现步骤

1)需继承AbstractGenericUDAFResolver抽象类,重写方法getEvaluator(TypeInfo[] parameters)
2)内部静态类需继承GenericUDAFEvaluator抽象类,重写方法init(),实现方法getNewAggregationBuffer(),reset(),iterate(),terminatePartial(),merge(),terminate()

Hive UDAF 方式二实现实例
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.util.StringUtils;

import java.util.ArrayList;

/***
 * @author Saodiseng
 * @date 2021/5/24 2:54 下午 周一
 * @jdk jdk1.8.0
 * @version 1.0
 ***/

@Description(name = "myavg", value = "_FUNC_(x) - Returns the mean of a set of numbers")
public class GenericUDAFAverage extends AbstractGenericUDAFResolver {

    public static final Log LOG = LogFactory.getLog(GenericUDAFAverage.class.getName());

    /**
     * 读入参数类型校验,满足条件时返回聚合函数数据处理对象
     * @param info
     * @return
     * @throws SemanticException
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {

        if (info.length != 1) {
            throw new  UDFArgumentTypeException(info.length - 1,
                    "Exactly one argument is expected.");
        }

        if (info[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0,
                    "Only primitive type arguments are accepted but "
                            + info[0].getTypeName() + " is passed.");
        }

        switch (((PrimitiveTypeInfo) info[0]).getPrimitiveCategory()) {
            case BYTE:
            case SHORT:
            case INT:
            case LONG:
            case FLOAT:
            case DOUBLE:
            case STRING:
            case TIMESTAMP:
                return new GenericUDAFAverageEvaluator();
            case BOOLEAN:
            default:
                throw new UDFArgumentTypeException(0,
                        "Only numeric or string type arguments are accepted but "
                                + info[0].getTypeName() + " is passed.");
        }
    }

    /**
     * GenericUDAFAverageEvaluator.
     * 自定义静态内部类:数据处理类,继承GenericUDAFEvaluator抽象类
     */
    public static class GenericUDAFAverageEvaluator extends GenericUDAFEvaluator {

        //1.1.定义全局输入输出数据的类型OI实例,用于解析输入输出数据
        // input For PARTIAL1 and COMPLETE
        PrimitiveObjectInspector inputOI;

        // input For PARTIAL2 and FINAL
        // output For PARTIAL1 and PARTIAL2
        StructObjectInspector soi;
        StructField countField;
        StructField sumField;
        LongObjectInspector countFieldOI;
        DoubleObjectInspector sumFieldOI;

        //1.2.定义全局输出数据的类型,用于存储实际数据
        // output For PARTIAL1 and PARTIAL2
        Object[] partialResult;

        // output For FINAL and COMPLETE
        DoubleWritable result;

        /*
         * 初始化:对各个模式处理过程,提取输入数据类型OI,返回输出数据类型OI
         * .每个模式(Mode)都会执行初始化
         * 1.输入参数parameters:
         * .1.1.对于PARTIAL1 和COMPLETE模式来说,是原始数据(单值)
         *    .设定了iterate()方法的输入参数的类型OI为:
         *    .		 PrimitiveObjectInspector 的实现类 WritableDoubleObjectInspector 的实例
         *    .		 通过输入OI实例解析输入参数值
         * .1.2.对于PARTIAL2 和FINAL模式来说,是模式聚合数据(双值)
         *    .设定了merge()方法的输入参数的类型OI为:
         *    .		 StructObjectInspector 的实现类 StandardStructObjectInspector 的实例
         *    .		 通过输入OI实例解析输入参数值
         * 2.返回值OI:
         * .2.1.对于PARTIAL1 和PARTIAL2模式来说,是设定了方法terminatePartial()返回值的OI实例
         *    .输出OI为 StructObjectInspector 的实现类 StandardStructObjectInspector 的实例
         * .2.2.对于FINAL 和COMPLETE模式来说,是设定了方法terminate()返回值的OI实例
         *    .输出OI为 PrimitiveObjectInspector 的实现类 WritableDoubleObjectInspector 的实例
         */
        @Override
        public ObjectInspector init(Mode mode, ObjectInspector[] parameters)
                throws HiveException {
            assert (parameters.length == 1);
            super.init(mode, parameters);

            // init input
            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
                inputOI = (PrimitiveObjectInspector) parameters[0];
            } else {
                //部分数据作为输入参数时,用到的struct的OI实例,指定输入数据类型,用于解析数据
                soi = (StructObjectInspector) parameters[0];
                countField = soi.getStructFieldRef("count");
                sumField = soi.getStructFieldRef("sum");
                //数组中的每个数据,需要其各自的基本类型OI实例解析
                countFieldOI = (LongObjectInspector) countField.getFieldObjectInspector();
                sumFieldOI = (DoubleObjectInspector) sumField.getFieldObjectInspector();
            }

            // init output
            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
                // The output of a partial aggregation is a struct containing
                // a "long" count and a "double" sum.
                //部分聚合结果是一个数组
                partialResult = new Object[2];
                partialResult[0] = new LongWritable(0);
                partialResult[1] = new DoubleWritable(0);
                /*
                 * .构造Struct的OI实例,用于设定聚合结果数组的类型
                 * .需要字段名List和字段类型List作为参数来构造
                 */
                ArrayList<String> fname = new ArrayList<String>();
                fname.add("count");
                fname.add("sum");
                ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
                //注:此处的两个OI类型 描述的是 partialResult[] 的两个类型,故需一致
                foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
            } else {
                //FINAL 最终聚合结果为一个数值,并用基本类型OI设定其类型
                result = new DoubleWritable(0);
                return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            }
        }

        /*
         * .聚合数据缓存存储结构
         */
        static class AverageAgg implements AggregationBuffer {
            long count;
            double sum;
        };

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            AverageAgg result = new AverageAgg();
            reset(result);
            return result;
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            AverageAgg myagg = (AverageAgg) agg;
            myagg.count = 0;
            myagg.sum = 0;
        }

        boolean warned = false;

        /*
         * .遍历原始数据
         */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters)
                throws HiveException {
            assert (parameters.length == 1);
            Object p = parameters[0];
            if (p != null) {
                AverageAgg myagg = (AverageAgg) agg;
                try {
                    //通过基本数据类型OI解析Object p的值
                    double v = PrimitiveObjectInspectorUtils.getDouble(p, inputOI);
                    myagg.count++;
                    myagg.sum += v;
                } catch (NumberFormatException e) {
                    if (!warned) {
                        warned = true;
                        LOG.warn(getClass().getSimpleName() + " "
                                + StringUtils.stringifyException(e));
                        LOG.warn(getClass().getSimpleName()
                                + " ignoring similar exceptions.");
                    }
                }
            }
        }

        /*
         * .得出部分聚合结果
         */
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            AverageAgg myagg = (AverageAgg) agg;
            ((LongWritable) partialResult[0]).set(myagg.count);
            ((DoubleWritable) partialResult[1]).set(myagg.sum);
            return partialResult;
        }

        /*
         * .合并部分聚合结果
         * .注:Object[] 是 Object 的子类,此处 partial 为 Object[]数组
         */
        @Override
        public void merge(AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial != null) {
                AverageAgg myagg = (AverageAgg) agg;
                //通过StandardStructObjectInspector实例,分解出 partial 数组元素值
                Object partialCount = soi.getStructFieldData(partial, countField);
                Object partialSum = soi.getStructFieldData(partial, sumField);
                //通过基本数据类型的OI实例解析Object的值
                myagg.count += countFieldOI.get(partialCount);
                myagg.sum += sumFieldOI.get(partialSum);
            }
        }

        /*
         * .得出最终聚合结果
         */
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            AverageAgg myagg = (AverageAgg) agg;
            if (myagg.count == 0) {
                return null;
            } else {
                result.set(myagg.sum / myagg.count);
                return result;
            }
        }
    }
}

Hive UDAF 使用

  1. java文件编译成udaf_avg.jar
  2. 进入hive客户端添加jar包
hive>add jar /home/hadoop/udaf_avg.jar
  1. 创建函数
hive>create function udaf_avg AS 'hive.udaf.Avg'
  1. 查询语句
hive>select udaf_avg(age) from user
  1. 销毁函数
hive>drop function udaf_avg

UDTF

其实Hive官方为我们提供了很多的UDTF函数如:explode、json_tuple、get_splits等。实际上UDTF就是把一列数据转换为多列多行
在Spark的源码中其实是没有定义UDTF相关接口的,这让我知道Spark的UDTF函数其实来自于Hive

自定义UDTF实现方法:

直接继承org.apache.hadoop.hive.ql.udf.generic.GenericUDTF就行。

UDTF的基本语法

//创建class类继承GenericUDTF,重写initialize、process、close
class UDTF类名 extends GenericUDTF {}

UDTF的使用

spark中UDTF是没有办法使用register函数进行注册的。

sparkSession.udf.register()

所以我们要特别说明下这个函数的使用,下面我们说明其使用方法

//在获取SparkSession实例时需要加上.enableHiveSupport(),否则无法使用
val spark = SparkSession.builder().appName("UDTF").master("local[*]").enableHiveSupport().getOrCreate()

//注册UDTF
spark.sql("CREATE [TEMPORARY] FUNCTION 自定义UDTF别名 AS 'UDTF类名'")

UDTF函数实现实例:

需求:将ls的Hadoop scala kafka hive hbase Oozie生成如下形式:

//      type           --(表头)
  //      Hadoop
  //      scala
  //      kafaka
  //       hive
  //      hbase
  //      Oozie

数据:

01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop

代码:

import java.util

import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, PrimitiveObjectInspector, StructObjectInspector}
import org.apache.spark.sql.SparkSession

object SparkUDTFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .enableHiveSupport()		//需要hive支持
      .appName("SparkUDTFDemo")
      .getOrCreate()
    val sc = spark.sparkContext

    import spark.implicits._

    val lines = sc.textFile("D:\\test\\t\\udtf.txt")
    val stuDF = lines.map(_.split("//")).filter(x => x(1).equals("ls"))
      .map(x => (x(0), x(1), x(2))).toDF("id", "name", "class")
    //stuDF.printSchema()
    //stuDF.show()

    stuDF.createTempView("student")
    
    spark.sql("CREATE TEMPORARY FUNCTION myUDTF AS 'kb09.sql.myUDTF'")
    //注意AS后面的类如果在包里一定要加包名!!!
    val resultDF = spark.sql("select myUDTF(class) from student")

    resultDF.show()
  }
}

class myUDTF extends GenericUDTF{

  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    if (argOIs.length!=1){
      throw new UDFArgumentException("有且只能有一个参数传入")
    }
    if (argOIs(0).getCategory!=ObjectInspector.Category.PRIMITIVE){
      throw new UDFArgumentException("参数类型不匹配")
    }
    val fieldNames =new util.ArrayList[String]
    val fieldOIs =new util.ArrayList[ObjectInspector]

    fieldNames.add("type")
    //这里定义的是输出字段的类型
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,fieldOIs)


  }
	//传入 Hadoop scala kafaka hive hbase Oozie
  override def process(objects: Array[AnyRef]): Unit ={
    //将字符串切分成单个字符的数组
    val strings = objects(0).toString.split(" ")
    println(strings)
    for (str<- strings){
      val tmp = new Array[String](1)
      tmp(0)=str
      forward(tmp)
    }
  }
  override def close(): Unit = {}
}
+------+
|  type|
+------+
|Hadoop|
| scala|
| kafka|
|  hive|
| hbase|
| Oozie|
+------+

总结:

实现UDTFf还需要注意(基于spark1.5,可能已过时):

  • udtf,process方法中对参数需要使用toString,String强转没用
  • sparksql子查询必须要有别名
  • 算子内部使用竖线切分字符串时,需要转义
  • udtf调用forward方法,必须传字符串数组,即使只有一个元素