1. Introduction

To write a UDF (user-defined function) in Hive you extend either the UDF class or the GenericUDF class. As for the difference between the two, the class-level comment on GenericUDF gives the answer:

/**
   * A Generic User-defined function (GenericUDF) for the use with Hive.
   *
   * New GenericUDF classes need to inherit from this GenericUDF class.
   *
   * The GenericUDF are superior to normal UDFs in the following ways: 1. It can
   * accept arguments of complex types, and return complex types. 2. It can accept
   * variable length of arguments. 3. It can accept an infinite number of function
   * signature - for example, it's easy to write a GenericUDF that accepts
   * array<int>, array<array<int>> and so on (arbitrary levels of nesting). 4. It
   * can do short-circuit evaluations using DeferedObject.
   */

The differences boil down to four points:

  1. It can accept and return complex types (arrays, maps, structs, unions).
  2. It can accept a variable number of arguments.
  3. It can accept an unlimited number of function signatures, for example array<int>, array<array<int>> and so on, with arbitrary nesting.
  4. It can do short-circuit evaluation via DeferredObject.

In short, points 1 through 3 mean a GenericUDF can handle complex data types and variable-length argument lists, while point 4 means argument evaluation is deferred and can be short-circuited through DeferredObject, as the sketch below illustrates.
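A minimal hypothetical sketch of that short-circuit behavior (the class and function name are made up for illustration and are not part of Hive): the second argument is only pulled out of its DeferredObject when the first one turns out to be null, so its evaluation can be skipped entirely.

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

// Hypothetical example: pick_first_nonnull(a, b) returns a if it is non-null, otherwise b.
public class GenericUDFPickFirstNonNull extends GenericUDF {

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        // a real implementation would also keep the argument ObjectInspectors here
        if (arguments.length != 2) {
            throw new UDFArgumentException("pick_first_nonnull expects exactly two arguments");
        }
        return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        Object first = arguments[0].get();      // forces evaluation of the first argument only
        if (first != null) {
            return first.toString();
        }
        Object second = arguments[1].get();     // evaluated only on this branch: the short circuit
        return second == null ? null : second.toString();
    }

    @Override
    public String getDisplayString(String[] children) {
        return "pick_first_nonnull(" + String.join(", ", children) + ")";
    }
}

A plain UDF, by contrast, always receives arguments that have already been evaluated.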

Next, let's lift the veil on UDF and GenericUDF with two kinds of examples.

To write a plain UDF you only need to implement an evaluate method. Note that this is implementing, not overriding (@Override), evaluate. One cannot help asking: why must the method be named evaluate rather than, say, cat or dog? Opening the source code, we find the following comment:

* Implement one or more methods named {@code evaluate} which will be called by Hive (the exact
   * way in which Hive resolves the method to call can be configured by setting a custom {@link
   * UDFMethodResolver}). The following are some examples:

The comment tells us that the evaluate name is resolved through UDFMethodResolver. The default implementation of the UDFMethodResolver interface is DefaultUDFMethodResolver, and its getEvalMethod shows exactly where the name is wired up, which answers the question:

@Override
public Method getEvalMethod(List<TypeInfo> argClasses) throws UDFArgumentException {
  return FunctionRegistry.getMethodInternal(udfClass, "evaluate", false,
      argClasses);
}

2. Hands-on

2.1 UDF

At work we often need to extract a range of levels from a hierarchical, delimiter-separated string, so we want a function that behaves as follows:

eg:get_levels('A/B/C/D','/',1,3) -> 'A/B/C'

get_levels('A/B/C/D','/',2) -> 'B/C/D'

// StringUtils / ArrayUtils come from Apache commons-lang here; the commons-lang3 equivalents work the same way
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.hive.ql.exec.UDF;

/**
 * eg: get_levels('A/B/C/D','/',1,3) -> 'A/B/C'
 *     get_levels('A/B/C/D','/',2)   -> 'B/C/D'
 */
public class UDFGetLevels extends UDF {

    /**
     * source is the input string, sep is the separator, start is the 1-based
     * start level, end is the end level (inclusive).
     *
     * Note: String.split treats sep as a regex, so separators such as "." or "|"
     * would need to be escaped (e.g. with Pattern.quote).
     *
     * @param source
     * @param sep
     * @param start
     * @param end
     * @return
     */
    public String evaluate(String source, String sep, int start, int end) {
        if (source == null || sep == null) {
            return null;
        }

        String[] arr = source.split(sep);
        return StringUtils.join(ArrayUtils.subarray(arr, start - 1, end), sep);
    }

    /**
     * The end position defaults to the last level.
     *
     * @param source
     * @param sep
     * @param start
     * @return
     */
    public String evaluate(String source, String sep, int start) {
        if (source == null || sep == null) {
            return null;
        }

        String[] arr = source.split(sep);
        return StringUtils.join(ArrayUtils.subarray(arr, start - 1, arr.length), sep);
    }

    /**
     * The separator defaults to "/".
     *
     * @param source
     * @param start
     * @param end
     * @return
     */
    public String evaluate(String source, int start, int end) {
        return evaluate(source, "/", start, end);
    }
}

Running it in Hive:

hive> add jar /Users/liufeifei/hive/jar/hive.jar;
Added [/Users/liufeifei/hive/jar/hive.jar] to class path
Added resources: [/Users/liufeifei/hive/jar/hive.jar]
hive> create temporary function get_levels as 'com.practice.hive.udf.UDFGetLevels';
OK
Time taken: 0.002 seconds
hive> select get_levels('A/B/C/D','/',1,3);
OK
A/B/C
Time taken: 0.041 seconds, Fetched: 1 row(s)
hive> select get_levels('A/B/C/D','/',2);
OK
B/C/D
Time taken: 0.034 seconds, Fetched: 1 row(s)

2.2 GenericUDF

Writing a GenericUDF mainly means implementing three methods: initialize, evaluate and getDisplayString.

initialize is called exactly once per instance; it validates the input arguments (described by ObjectInspectors), declares the return type, and is where any argument conversion is set up. evaluate holds the per-row logic and receives the arguments as DeferredObjects. getDisplayString supplies the text shown in explain output. The Javadoc is quoted below, followed by a minimal skeleton.

/**
     * Initialize this GenericUDF. This will be called once and only once per
     * GenericUDF instance.
     *
     * @param arguments
     *          The ObjectInspector for the arguments
     * @throws UDFArgumentException
     *           Thrown when arguments have wrong types, wrong length, etc.
     * @return The ObjectInspector for the return value
     */
    public abstract ObjectInspector initialize(ObjectInspector[] arguments)
        throws UDFArgumentException;
        
    /**
     * Evaluate the GenericUDF with the arguments.
     *
     * @param arguments
     *          The arguments as DeferedObject, use DeferedObject.get() to get the
     *          actual argument Object. The Objects can be inspected by the
     *          ObjectInspectors passed in the initialize call.
     * @return The
     */
    public abstract Object evaluate(DeferredObject[] arguments)
        throws HiveException;
  
  
    /**
     * Get the String to be displayed in explain.
     */
    public abstract String getDisplayString(String[] children);
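Here is a minimal hypothetical skeleton, a made-up my_upper function rather than anything from this article's jar, showing how the three methods fit together: initialize checks that there is exactly one primitive argument and declares a string return type, evaluate upper-cases the value for each row, and getDisplayString is what explain prints.

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

// Hypothetical skeleton: my_upper(s) upper-cases a single string argument.
public class GenericUDFMyUpper extends GenericUDF {

    private PrimitiveObjectInspector inputOI;

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        // called once per instance: validate argument count/types and remember the inspector
        if (arguments.length != 1
                || arguments[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentException("my_upper takes exactly one primitive argument");
        }
        inputOI = (PrimitiveObjectInspector) arguments[0];
        // declare the return type: a java String
        return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        // called once per row: pull the lazily evaluated argument and do the work
        Object arg = arguments[0].get();
        if (arg == null) {
            return null;
        }
        String s = PrimitiveObjectInspectorUtils.getString(arg, inputOI);
        return s == null ? null : s.toUpperCase();
    }

    @Override
    public String getDisplayString(String[] children) {
        // what shows up in EXPLAIN output
        return "my_upper(" + (children.length > 0 ? children[0] : "") + ")";
    }
}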

The ObjectInspector parameter is the interesting part here; its source comment reads:

/**
   * ObjectInspector helps us to look into the internal structure of a complex
   * object.
   *
   * A (probably configured) ObjectInspector instance stands for a specific type
   * and a specific way to store the data of that type in the memory.
   *
   * For native java Object, we can directly access the internal structure through
   * member fields and methods. ObjectInspector is a way to delegate that
   * functionality away from the Object, so that we have more control on the
   * behavior of those actions.
   *
   * An efficient implementation of ObjectInspector should rely on factory, so
   * that we can make sure the same ObjectInspector only has one instance. That
   * also makes sure hashCode() and equals() methods of java.lang.Object directly
   * works for ObjectInspector as well.
   */
  public interface ObjectInspector extends Cloneable

As the comment says, an ObjectInspector stands for a specific type and a specific way of storing values of that type in memory, and it lets Hive look into a value's internal structure without depending on its concrete Java class; inspectors are obtained from factories so that each one effectively has a single instance.
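As a quick standalone illustration (a sketch written for this article, not code from the jar above), the snippet below obtains a list inspector from the factories and uses it to read an ordinary Java list:

import java.util.Arrays;
import java.util.List;

import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

// The inspector describes how a value is laid out, so the same List<String> can be
// read without the caller knowing its concrete class.
public class ObjectInspectorDemo {
    public static void main(String[] args) {
        ObjectInspector elementOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;

        // factory call for array<string>
        ListObjectInspector listOI =
                ObjectInspectorFactory.getStandardListObjectInspector(elementOI);

        Object data = Arrays.asList("A", "B", "C");

        System.out.println(listOI.getTypeName());            // array<string>
        System.out.println(listOI.getListLength(data));      // 3
        System.out.println(listOI.getListElement(data, 1));  // B

        List<?> all = listOI.getList(data);
        System.out.println(all);                             // [A, B, C]
    }
}

Calling the factory again with the same element inspector is intended to hand back the same cached instance, which is what the comment above means by relying on a factory.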

Examples

2.2.1 array_to_map

eg:select array_to_map(array('zhangsan','18','90')); -> {"score":"90","name":"zhangsan","age":"18"}

The input is an array of length 3 whose elements are a user's name, age and score; the function returns the corresponding map.

The code is as follows:

import java.util.List;
import java.util.Map;

import com.google.common.collect.Maps;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

/**
 * eg: select array_to_map(array('zhangsan','18','90')); -> {"score":"90","name":"zhangsan","age":"18"}
 */
public class GenericUDFArrayToMap extends GenericUDF {

    private ListObjectInspector listInspector;

    private MapObjectInspector mapInspector;

    /**
     * Validate the input arguments.
     *
     * @param arguments
     * @return
     * @throws UDFArgumentException
     */
    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        if (arguments.length != 1) {
            throw new UDFArgumentException("Must have one parameter");
        // check that the argument is an array
        } else if (!(arguments[0] instanceof StandardListObjectInspector)) {
            throw new UDFArgumentException("Must be array");
        }

        // ListObjectInspector for Array<String>
        listInspector = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector);

        // MapObjectInspector for Map<String,String>
        mapInspector = ObjectInspectorFactory.getStandardMapObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector, PrimitiveObjectInspectorFactory.javaStringObjectInspector);

        // the return type is Map<String,String>
        return mapInspector;
    }

    /**
     * Per-row logic.
     *
     * @param arguments
     * @return
     * @throws HiveException
     */
    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        Map<String, String> resMap = Maps.newHashMap();

        // read the input list through listInspector
        List<String> arr = (List<String>) listInspector.getList(arguments[0].get());

        resMap.put("name", arr.get(0));
        resMap.put("age", arr.get(1));
        resMap.put("score", arr.get(2));
        return resMap;
    }

    /**
     * String displayed in explain output.
     *
     * @param children
     * @return
     */
    @Override
    public String getDisplayString(String[] children) {
        return "function ArrayToMap";
    }
}

Package the jar and run it in Hive:

hive> add jar /Users/liufeifei/hive/jar/hive.jar;
Added [/Users/liufeifei/hive/jar/hive.jar] to class path
Added resources: [/Users/liufeifei/hive/jar/hive.jar]
hive> create temporary function array_to_map as 'com.practice.hive.udf.generic.GenericUDFArrayToMap';
OK
Time taken: 0.002 seconds
hive> select array_to_map(array('zhangsan','18','90'));
OK
{"score":"90","name":"zhangsan","age":"18"}
Time taken: 0.039 seconds, Fetched: 1 row(s)
hive> explain select array_to_map(array('zhangsan','18','90'));
OK
STAGE DEPENDENCIES:
  Stage-0 is a root stage

STAGE PLANS:
  Stage: Stage-0
    Fetch Operator
      limit: -1
      Processor Tree:
        TableScan
          alias: _dummy_table
          Row Limit Per Split: 1
          Statistics: Num rows: 1 Data size: 1 Basic stats: COMPLETE Column stats: COMPLETE
          Select Operator
            expressions: function ArrayToMap (type: map<string,string>)
            outputColumnNames: _col0
            Statistics: Num rows: 1 Data size: 772 Basic stats: COMPLETE Column stats: COMPLETE
            ListSink

Time taken: 0.044 seconds, Fetched: 18 row(s)

2.2.2 filter_array_struct

Here is one more example.

The input is an Array<Struct>; we need to filter the structs by a given key and return an Array<Bigint> built from another field.

For example, given
[{attr_id=100162, attr_val_id=1447, formatted_value=Floral, status=1}, {attr_id=100162, attr_val_id=1463, formatted_value=Plain, status=1}, {attr_id=100163, attr_val_id=1464, formatted_value=Plain, status=1}]
we want to keep only the entries with attr_id = 100162 and collect their attr_val_id values, i.e.

filter_array_struct(data, "attr_id", 100162, "attr_val_id") returns [1447,1463]


The code is as follows:

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.LongWritable;

import java.util.ArrayList;
import java.util.List;

/**
 * @author lff
 * @datetime 2023/3/24 18:00
 * <p>
 * Arguments: (arr, "inKey", value, "outKey")
 * <p>
 * Registered as filter_array_struct
 */
public class ArrayStructFilter extends GenericUDF {

    private List<String> nameList;

    @Override
    public ObjectInspector initialize(ObjectInspector[] objectInspectors) throws UDFArgumentException {
        if (objectInspectors.length != 4) {
            throw new UDFArgumentException("Must have four parameters");
            // check that the first argument is an array
        } else if (!(objectInspectors[0] instanceof StandardListObjectInspector)) {
            throw new UDFArgumentException("Must be array");
        }

        StandardListObjectInspector listObjectInspector = (StandardListObjectInspector) objectInspectors[0];
        StructObjectInspector structObjectInspector = (StructObjectInspector) listObjectInspector.getListElementObjectInspector();
        List<? extends StructField> allStructFieldRefs = structObjectInspector.getAllStructFieldRefs();
        nameList = new ArrayList<>();

        // record the schema (field names) of the incoming struct
        for (StructField allStructFieldRef : allStructFieldRefs) {
            String fieldName = allStructFieldRef.getFieldName();
            nameList.add(fieldName);
        }

        // the return type is array<bigint> (a List<Long>)
        return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaLongObjectInspector);
    }

    @Override
    public Object evaluate(DeferredObject[] deferredObjects) throws HiveException {
        // each struct element arrives as a List of its field values
        List<List> inList = (List<List>) (deferredObjects[0].get());
        if (inList == null) {
            return null;
        }

        String inKey = deferredObjects[1].get().toString();
        long value = ((LongWritable) deferredObjects[2].get()).get();
        String outKey = deferredObjects[3].get().toString();

        // positions of the filter key and the output key within the struct
        int inKeyIndex = nameList.indexOf(inKey);
        int outKeyIndex = nameList.indexOf(outKey);
        if (inKeyIndex < 0 || outKeyIndex < 0) {
            return null;
        }

        // keep outKey values from the structs whose inKey equals the given value
        ArrayList<Long> resList = new ArrayList<>();
        for (List dataList : inList) {
            long dataValue = (Long) dataList.get(inKeyIndex);
            if (dataValue == value) {
                resList.add((Long) (dataList.get(outKeyIndex)));
            }
        }

        return resList.size() == 0 ? null : resList;
    }

    @Override
    public String getDisplayString(String[] strings) {
        return "filter_array_struct usage : filter_array_struct(data,inKey,value,outKey) ";
    }

}
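To sanity-check the class without starting a Hive session, it can be driven directly from Java through GenericUDF.DeferredJavaObject. Below is a hypothetical local test; the struct field types (bigint/bigint/string/int), the demo class name and the assumption that ArrayStructFilter sits in the same package are all made up for illustration based on the sample data above.

import java.util.Arrays;
import java.util.List;

import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.LongWritable;

// Hypothetical local sanity check for ArrayStructFilter, bypassing Hive entirely.
public class ArrayStructFilterDemo {
    public static void main(String[] args) throws Exception {
        ArrayStructFilter udf = new ArrayStructFilter();

        // array<struct<attr_id:bigint, attr_val_id:bigint, formatted_value:string, status:int>>
        ObjectInspector structOI = ObjectInspectorFactory.getStandardStructObjectInspector(
                Arrays.asList("attr_id", "attr_val_id", "formatted_value", "status"),
                Arrays.<ObjectInspector>asList(
                        PrimitiveObjectInspectorFactory.javaLongObjectInspector,
                        PrimitiveObjectInspectorFactory.javaLongObjectInspector,
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                        PrimitiveObjectInspectorFactory.javaIntObjectInspector));

        udf.initialize(new ObjectInspector[]{
                ObjectInspectorFactory.getStandardListObjectInspector(structOI),
                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                PrimitiveObjectInspectorFactory.writableLongObjectInspector,
                PrimitiveObjectInspectorFactory.javaStringObjectInspector});

        // three structs, flattened to lists in the same field order as the inspector above
        List<?> data = Arrays.asList(
                Arrays.asList(100162L, 1447L, "Floral", 1),
                Arrays.asList(100162L, 1463L, "Plain", 1),
                Arrays.asList(100163L, 1464L, "Plain", 1));

        Object result = udf.evaluate(new DeferredObject[]{
                new DeferredJavaObject(data),
                new DeferredJavaObject("attr_id"),
                new DeferredJavaObject(new LongWritable(100162L)),
                new DeferredJavaObject("attr_val_id")});

        System.out.println(result);
    }
}

Running main should print [1447, 1463], matching the expected output of the example.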

Conclusion

There are plenty of UDF tutorials online, but most of them lack a complete and detailed treatment, so this article analyses UDFs by combining my own study notes with hands-on practice. I deliberately quoted a lot of the English source comments: partly to show that much of this knowledge is already spelled out in the code's own documentation, and partly because translating them directly would inevitably distort the meaning.