1.聚合函数概念

聚合函数:将一个表的一个或多个行并且具有一个或多个属性聚合为标量值。

聚合函数理解:假设一个关于饮料的表。表里面有三个字段,分别是 id、name、price,表里有 5 行数据。假设你需要找到所有饮料里最贵的饮料的价格,即执行一个 max() 聚合。你需要遍历所有 5 行数据,而结果就只有一个数值。

2.聚合函数实现

聚合函数主要通过扩展AggregateFunction类实现。

AggregateFunction工作原理:

  1. 定义一个累加器,它是保存聚合的中间结果的数据结构。
  2. 利用AggregateFunction的createAccumulator()方法创建一个空的累加器。
  3. 对每个输入行调用函数的accumulate()方法来更新累加器。
  4. 调用函数的getValue()方法来计算并返回最终结果。

对应必须实现以下三个方法:

  • createAccumulator():创建累加器
  • accumulate():更新累加器
  • getValue():返回聚合结果

除了上面的方法,还有几个方法可以选择实现。这些方法有些可以让查询更加高效,而有些是在某些特定场景下必须要实现的。

  • retract():在有界OVER窗口上聚合是必需的
  • merge() :许多批处理聚合和会话窗口聚合都需要
  • resetAccumulator():在许多批式聚合中是必须实现的

注意:

  1. AggregateFunction 的所有方法都必须是 public 的,不能是 static 的,而且名字必须跟上面写的一样。
  2. createAccumulator、getValue、getResultType 以及 getAccumulatorType 这几个函数是在抽象类 AggregateFunction 中定义的,而其他函数都是约定的方法。
  3. 如果要定义一个聚合函数,需要扩展 org.apache.flink.table.functions.AggregateFunction,并且实现一个(或者多个)accumulate 方法。
  4. accumulate 方法可以重载,每个方法的参数类型不同,并且支持变长参数。

具体可参考以下两个具体实现。

3.GroupConcatList实现

需求:group_concat_list(list, separator),separator不填则默认为逗号,类似于group_concat,以指定间隔符拼接字符串

具体实现:

import org.apache.commons.collections.CollectionUtils;
import org.apache.flink.table.functions.AggregateFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;

/**
 * @description: group_concat_list(list, separator),separator不填则默认为逗号,类似于group_concat,以指定间隔符拼接字符串
 * AggregateFunction<T, ACC>, T表示聚合输出的结果类型,ACC表示聚合的中间状态类型
 */
public class GroupConcatList extends AggregateFunction<String, GroupConcatList.AggregateList> {

    private static final Logger LOG = LoggerFactory.getLogger(GroupConcatList.class);

    public static class AggregateList {
        public List<String> columnList;
        public String delimiter;     // 间隔符
    }

    /**
     * 返回聚合结果
     * @param acc ACC类型的累加器
     * @return
     */
    @Override
    public String getValue(GroupConcatList.AggregateList acc) {
        if (CollectionUtils.isEmpty(acc.columnList)) {
            return "";
        }
        return String.join(acc.delimiter, acc.columnList);
    }

    /**
     * 创建累加器
     * @return 累加器类型ACC
     */
    @Override
    public GroupConcatList.AggregateList createAccumulator() {
        GroupConcatList.AggregateList acc = new GroupConcatList.AggregateList();
        acc.columnList = new ArrayList<>();
        return acc;
    }

    /**
     * 更新累加器
     * @param acc 当前累加器,类型为ACC
     * @param param 可变字符串,第一个字符串为值,第二个字符串为间隔符,若无间隔符则默认为逗号
     */
    public void accumulate(GroupConcatList.AggregateList acc, String... param) {
        if (param.length == 1) {
            acc.columnList.add(param[0]);
            acc.delimiter = String.valueOf(',');    //默认为逗号
        } else if (param.length == 2) {
            acc.columnList.add(param[0]);
            acc.delimiter = param[1];
        } else {
            LOG.error("param number error, not support");
        }
    }

    /**
     * 回撤相关操作
     * @param acc
     * @param param
     */
    public void retract(GroupConcatList.AggregateList acc, String... param) {
        acc.columnList.remove(param[0]);
    }

    public void resetAccumulator(GroupConcatList.AggregateList acc) {
        acc.columnList.clear();
    }
}

4.GroupConcatSet实现

需求:group_concat_set(list, separator),separator不填则默认为逗号,类似于group_concat并去重,以指定间隔符拼接字符串。

具体实现:

import org.apache.commons.collections.CollectionUtils;
import org.apache.flink.table.functions.AggregateFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * @description: group_concat_set(list, separator),separator不填则默认为逗号,类似于group_concat并去重,以指定间隔符拼接字符串
 * AggregateFunction<T, ACC>, T表示聚合输出的结果类型,ACC表示聚合的中间状态类型
 */
public class GroupConcatSet extends AggregateFunction<String, GroupConcatSet.AggregateList> {

    private static final Logger LOG = LoggerFactory.getLogger(GroupConcatSet.class);

    public static class AggregateList {
        public List<String> columnList;
        public String delimiter;     // 间隔符
    }

    /**
     * 返回聚合结果
     * @param acc ACC类型的累加器
     * @return
     */
    @Override
    public String getValue(GroupConcatSet.AggregateList acc) {
        if (CollectionUtils.isEmpty(acc.columnList)) {
            return "";
        }
        Set set = new HashSet(acc.columnList);
        return String.join(acc.delimiter, set);
    }

    /**
     * 创建累加器
     * @return 累加器类型ACC
     */
    @Override
    public GroupConcatSet.AggregateList createAccumulator() {
        GroupConcatSet.AggregateList acc = new GroupConcatSet.AggregateList();
        acc.columnList = new ArrayList<>();
        return acc;
    }

    /**
     * 更新累加器
     * @param acc 当前累加器,类型为ACC
     * @param param 可变字符串,第一个字符串为值,第二个字符串为间隔符,若无间隔符则默认为逗号
     */
    public void accumulate(GroupConcatSet.AggregateList acc, String... param) {
        if (param.length == 1) {
            acc.columnList.add(param[0]);
            acc.delimiter = String.valueOf(',');    //默认为逗号
        } else if (param.length == 2) {
            acc.columnList.add(param[0]);
            acc.delimiter = param[1];
        } else {
            LOG.error("param number error, not support");
        }
    }

    /**
     * 回撤相关操作
     * @param acc
     * @param param
     */
    public void retract(GroupConcatSet.AggregateList acc, String... param) {
        acc.columnList.remove(param[0]);

    }

    public void resetAccumulator(GroupConcatSet.AggregateList acc) {
        acc.columnList.clear();
    }
}