1.聚合函数概念
聚合函数:将一个表的一个或多个行并且具有一个或多个属性聚合为标量值。
聚合函数理解:假设一个关于饮料的表。表里面有三个字段,分别是 id、name、price,表里有 5 行数据。假设你需要找到所有饮料里最贵的饮料的价格,即执行一个 max() 聚合。你需要遍历所有 5 行数据,而结果就只有一个数值。
2.聚合函数实现
聚合函数主要通过扩展AggregateFunction类实现。
AggregateFunction工作原理:
- 定义一个累加器,它是保存聚合的中间结果的数据结构。
- 利用AggregateFunction的createAccumulator()方法创建一个空的累加器。
- 对每个输入行调用函数的accumulate()方法来更新累加器。
- 调用函数的getValue()方法来计算并返回最终结果。
对应必须实现以下三个方法:
- createAccumulator():创建累加器
- accumulate():更新累加器
- getValue():返回聚合结果
除了上面的方法,还有几个方法可以选择实现。这些方法有些可以让查询更加高效,而有些是在某些特定场景下必须要实现的。
- retract():在有界OVER窗口上聚合是必需的
- merge() :许多批处理聚合和会话窗口聚合都需要
- resetAccumulator():在许多批式聚合中是必须实现的
注意:
- AggregateFunction 的所有方法都必须是 public 的,不能是 static 的,而且名字必须跟上面写的一样。
- createAccumulator、getValue、getResultType 以及 getAccumulatorType 这几个函数是在抽象类 AggregateFunction 中定义的,而其他函数都是约定的方法。
- 如果要定义一个聚合函数,需要扩展 org.apache.flink.table.functions.AggregateFunction,并且实现一个(或者多个)accumulate 方法。
- accumulate 方法可以重载,每个方法的参数类型不同,并且支持变长参数。
具体可参考以下两个具体实现。
3.GroupConcatList实现
需求:group_concat_list(list, separator),separator不填则默认为逗号,类似于group_concat,以指定间隔符拼接字符串
具体实现:
import org.apache.commons.collections.CollectionUtils;
import org.apache.flink.table.functions.AggregateFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
/**
* @description: group_concat_list(list, separator),separator不填则默认为逗号,类似于group_concat,以指定间隔符拼接字符串
* AggregateFunction<T, ACC>, T表示聚合输出的结果类型,ACC表示聚合的中间状态类型
*/
public class GroupConcatList extends AggregateFunction<String, GroupConcatList.AggregateList> {
private static final Logger LOG = LoggerFactory.getLogger(GroupConcatList.class);
public static class AggregateList {
public List<String> columnList;
public String delimiter; // 间隔符
}
/**
* 返回聚合结果
* @param acc ACC类型的累加器
* @return
*/
@Override
public String getValue(GroupConcatList.AggregateList acc) {
if (CollectionUtils.isEmpty(acc.columnList)) {
return "";
}
return String.join(acc.delimiter, acc.columnList);
}
/**
* 创建累加器
* @return 累加器类型ACC
*/
@Override
public GroupConcatList.AggregateList createAccumulator() {
GroupConcatList.AggregateList acc = new GroupConcatList.AggregateList();
acc.columnList = new ArrayList<>();
return acc;
}
/**
* 更新累加器
* @param acc 当前累加器,类型为ACC
* @param param 可变字符串,第一个字符串为值,第二个字符串为间隔符,若无间隔符则默认为逗号
*/
public void accumulate(GroupConcatList.AggregateList acc, String... param) {
if (param.length == 1) {
acc.columnList.add(param[0]);
acc.delimiter = String.valueOf(','); //默认为逗号
} else if (param.length == 2) {
acc.columnList.add(param[0]);
acc.delimiter = param[1];
} else {
LOG.error("param number error, not support");
}
}
/**
* 回撤相关操作
* @param acc
* @param param
*/
public void retract(GroupConcatList.AggregateList acc, String... param) {
acc.columnList.remove(param[0]);
}
public void resetAccumulator(GroupConcatList.AggregateList acc) {
acc.columnList.clear();
}
}
4.GroupConcatSet实现
需求:group_concat_set(list, separator),separator不填则默认为逗号,类似于group_concat并去重,以指定间隔符拼接字符串。
具体实现:
import org.apache.commons.collections.CollectionUtils;
import org.apache.flink.table.functions.AggregateFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
* @description: group_concat_set(list, separator),separator不填则默认为逗号,类似于group_concat并去重,以指定间隔符拼接字符串
* AggregateFunction<T, ACC>, T表示聚合输出的结果类型,ACC表示聚合的中间状态类型
*/
public class GroupConcatSet extends AggregateFunction<String, GroupConcatSet.AggregateList> {
private static final Logger LOG = LoggerFactory.getLogger(GroupConcatSet.class);
public static class AggregateList {
public List<String> columnList;
public String delimiter; // 间隔符
}
/**
* 返回聚合结果
* @param acc ACC类型的累加器
* @return
*/
@Override
public String getValue(GroupConcatSet.AggregateList acc) {
if (CollectionUtils.isEmpty(acc.columnList)) {
return "";
}
Set set = new HashSet(acc.columnList);
return String.join(acc.delimiter, set);
}
/**
* 创建累加器
* @return 累加器类型ACC
*/
@Override
public GroupConcatSet.AggregateList createAccumulator() {
GroupConcatSet.AggregateList acc = new GroupConcatSet.AggregateList();
acc.columnList = new ArrayList<>();
return acc;
}
/**
* 更新累加器
* @param acc 当前累加器,类型为ACC
* @param param 可变字符串,第一个字符串为值,第二个字符串为间隔符,若无间隔符则默认为逗号
*/
public void accumulate(GroupConcatSet.AggregateList acc, String... param) {
if (param.length == 1) {
acc.columnList.add(param[0]);
acc.delimiter = String.valueOf(','); //默认为逗号
} else if (param.length == 2) {
acc.columnList.add(param[0]);
acc.delimiter = param[1];
} else {
LOG.error("param number error, not support");
}
}
/**
* 回撤相关操作
* @param acc
* @param param
*/
public void retract(GroupConcatSet.AggregateList acc, String... param) {
acc.columnList.remove(param[0]);
}
public void resetAccumulator(GroupConcatSet.AggregateList acc) {
acc.columnList.clear();
}
}