Alink Ramblings (21): Source Code Analysis of Regression Evaluation
0x00 Summary
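Alink is Alibaba's new-generation machine learning algorithm platform built on the real-time compute engine Flink. It is the industry's first machine learning platform to support both batch and streaming algorithms. This article analyzes how regression evaluation is implemented in Alink.
This is the easiest installment of my Alink analysis so far, because both the concepts and the implementation logic here are very clear.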
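0x01 Background Concepts
1.1 Feature Overview
Regression evaluation assesses the quality of a regression algorithm's predictions. It supports the metrics below, most of which are standard statistical concepts.
1.2 Metrics
Alink provides the following metrics, where y_i is the true label, f_i the prediction, ȳ the mean of the labels, and N the number of rows:
count: number of rows.
SST: Sum of Squares for Total, which measures how dispersed Y is in the sample.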
\[SST=\sum_{i=1}^{N}(y_i-\bar{y})^2 \]
SSE: Sum of Squares for Error, which measures the residual variation, i.e. the part of the variation that the model leaves unexplained.
\[SSE=\sum_{i=1}^{N}(y_i-f_i)^2 \]
SSR: Sum of Squares for Regression, which measures how the predictions vary around ȳ, i.e. the part of the variation explained by the regression.
\[SSR=\sum_{i=1}^{N}(f_i-\bar{y})^2 \]
R^2: Coefficient of Determination. It estimates how well the regression equation fits the sample data and so provides a goodness-of-fit measure for the estimated regression equation.
\[R^2=1-\dfrac{SSE}{SST} \]
R: Multiple Correlation Coefficient, a measure of the linear dependence between one random variable and a set of other random variables.
\[R=\sqrt{R^2} \]
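MSE: Mean Squared Error. Like the standard deviation and the variance, it describes dispersion, in this case the dispersion of the prediction errors.
MSE is a convenient way to measure the "average error" and reflects how far the predictions deviate from the true values; it is commonly used in forecast evaluation. The name spells out the computation: take the error between each prediction and its true value, square it, and average over all rows.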
\[MSE=\dfrac{1}{N}\sum_{i=1}^{N}(f_i-y_i)^2 \]
RMSE: Root Mean Squared Error.
\[RMSE=\sqrt{MSE} \]
SAE/SAD: Sum of Absolute Error/Difference.
\[SAE=\sum_{i=1}^{N}|f_i-y_i| \]
MAE/MAD: Mean Absolute Error/Difference.
\[MAE=\dfrac{1}{N}\sum_{i=1}^{N}|f_i-y_i| \]
MAPE: Mean Absolute Percentage Error. It is undefined when any y_i is 0.
\[MAPE=\dfrac{100}{N}\sum_{i=1}^{N}|\dfrac{f_i-y_i}{y_i}| \]
Explained Variance:
\[\text{Explained Variance}=\dfrac{SSR}{N} \]
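To make the definitions above concrete, here is a small plain-Java sketch that computes every listed metric directly from its formula. It makes two passes: the first computes ȳ and the second accumulates the sums. The class `RegressionMetricsByDefinition`, its `compute` method and the map keys are hypothetical illustrations, not part of Alink.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical helper (not part of Alink): computes the regression metrics
 * straight from their textbook definitions, in two passes over the data.
 */
public class RegressionMetricsByDefinition {

    public static Map<String, Double> compute(double[] y, double[] f) {
        if (y.length != f.length || y.length == 0) {
            throw new IllegalArgumentException("y and f must be non-empty and of equal length");
        }
        int n = y.length;

        // Pass 1: mean of the labels
        double yMean = 0;
        for (double v : y) {
            yMean += v;
        }
        yMean /= n;

        // Pass 2: sums used by the definitions
        double sst = 0, sse = 0, ssr = 0, sae = 0, apeSum = 0;
        for (int i = 0; i < n; i++) {
            double err = f[i] - y[i];
            sst += (y[i] - yMean) * (y[i] - yMean);
            sse += err * err;
            ssr += (f[i] - yMean) * (f[i] - yMean);
            sae += Math.abs(err);
            apeSum += Math.abs(err / y[i]); // infinite or NaN when y[i] == 0
        }
        double r2 = 1 - sse / sst;
        double mse = sse / n;

        Map<String, Double> m = new LinkedHashMap<>();
        m.put("count", (double) n);
        m.put("SST", sst);
        m.put("SSE", sse);
        m.put("SSR", ssr);
        m.put("R2", r2);
        m.put("R", Math.sqrt(r2)); // NaN when R2 < 0
        m.put("MSE", mse);
        m.put("RMSE", Math.sqrt(mse));
        m.put("SAE", sae);
        m.put("MAE", sae / n);
        m.put("MAPE", 100 * apeSum / n);
        m.put("ExplainedVariance", ssr / n);
        return m;
    }

    public static void main(String[] args) {
        // Same (label, prediction) pairs as the Alink example in 0x02
        double[] y = {0.4, 0.3, 0.2, 0.6, 0.1};
        double[] f = {0.5, 0.5, 0.6, 0.7, 0.5};
        compute(y, f).forEach((k, v) -> System.out.println(k + " = " + v));
    }
}
```

On the five rows from the 0x02 example, this sketch gives the same values as Alink up to floating-point rounding. That includes a negative R2, so R comes out as NaN.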
0x02 Example Code
Here is Alink's own example code.
```java
import com.alibaba.alink.operator.batch.evaluation.EvalRegressionBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.common.evaluation.RegressionMetrics;
import org.apache.flink.types.Row;

public class EvalRegressionBatchOpExp {
    public static void main(String[] args) throws Exception {
        // Each row is (label, prediction)
        Row[] data = new Row[] {
            Row.of(0.4, 0.5),
            Row.of(0.3, 0.5),
            Row.of(0.2, 0.6),
            Row.of(0.6, 0.7),
            Row.of(0.1, 0.5)
        };
        MemSourceBatchOp input = new MemSourceBatchOp(data, new String[] {"label", "pred"});

        RegressionMetrics metrics = new EvalRegressionBatchOp()
            .setLabelCol("label")
            .setPredictionCol("pred")
            .linkFrom(input)
            .collectMetrics();

        System.out.println(metrics.getRmse());
        System.out.println(metrics.getR2());
        System.out.println(metrics.getSse());
        System.out.println(metrics.getMape());
        System.out.println(metrics.getMae());
        System.out.println(metrics.getSsr());
        System.out.println(metrics.getSst());
    }
}
```
The output (RMSE, R2, SSE, MAPE, MAE, SSR, SST, in that order) is:

```
0.27568097504180444
-1.5675675675675653
0.38
141.66666666666669
0.24
0.31999999999999973
0.14800000000000013
```
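Several of these printed values can be checked by hand from the five (label, pred) rows:

\[
\begin{aligned}
\bar{y} &= \dfrac{0.4+0.3+0.2+0.6+0.1}{5} = 0.32 \\
SSE &= 0.1^2+0.2^2+0.4^2+0.1^2+0.4^2 = 0.38 \\
SST &= 0.08^2+0.02^2+0.12^2+0.28^2+0.22^2 = 0.148 \\
SSR &= 0.18^2+0.18^2+0.28^2+0.38^2+0.18^2 = 0.32 \\
R^2 &= 1-\dfrac{0.38}{0.148} \approx -1.5676 \\
RMSE &= \sqrt{0.38/5} = \sqrt{0.076} \approx 0.2757 \\
MAPE &= \dfrac{100}{5}\left(\dfrac{0.1}{0.4}+\dfrac{0.2}{0.3}+\dfrac{0.4}{0.2}+\dfrac{0.1}{0.6}+\dfrac{0.4}{0.1}\right) \approx 141.67
\end{aligned}
\]

A negative R² simply means SSE exceeds SST. In other words, these predictions fit worse than always predicting the mean label ȳ.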
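0x03 Overall Logic
The overall flow is:
- CalcLocal computes the statistics within each partition;
- reduce merges the per-partition statistics with ReduceBaseMetrics;
- SaveDataAsParams converts the merged result into metrics and stores them.
getLabelCol is the label column y, and getPredictionCol is the prediction column y_hat.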
```java
public EvalRegressionBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    // Locate the label column (y) and the prediction column (y_hat)
    TableUtil.findColIndexWithAssertAndHint(in.getColNames(), this.getLabelCol());
    TableUtil.findColIndexWithAssertAndHint(in.getColNames(), this.getPredictionCol());
    // Both columns must be numeric; the metrics are built from y and y_hat
    TableUtil.assertNumericalCols(in.getSchema(), this.getLabelCol(), this.getPredictionCol());

    DataSet<Row> out = in.select(new String[] {this.getLabelCol(), this.getPredictionCol()})
        .getDataSet()
        .rebalance()                                     // spread rows evenly across partitions
        .mapPartition(new CalcLocal())                   // running sums per partition
        .reduce(new EvaluationUtil.ReduceBaseMetrics())  // merge the partial sums
        .flatMap(new EvaluationUtil.SaveDataAsParams()); // compute metrics and serialize them

    this.setOutputTable(DataSetConversionUtil.toTable(getMLEnvironmentId(), out,
        new TableSchema(new String[] {"regression_eval_result"},
            new TypeInformation[] {Types.STRING})));
    return this;
}
```
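The data flow above can be imitated without Flink. It consists of per-partition sums, a reduce that merges them, and a final conversion into metrics. What follows is a minimal plain-Java sketch under stated assumptions:
- `PartitionedRegressionEval` and `LocalSummary` are hypothetical names.
- `LocalSummary` mirrors only a subset of the fields of `RegressionMetricsSummary`.
- Its `merge` is assumed to be a field-wise sum, because the body of the real `merge` is not shown in this article.
- Round-robin splitting only loosely mimics `rebalance()`.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Plain-Java sketch (no Flink) of the mapPartition -> reduce -> finalize
 * pattern. LocalSummary is a hypothetical stand-in for RegressionMetricsSummary.
 */
public class PartitionedRegressionEval {

    static final class LocalSummary {
        double ySum, ySum2, predSum, predSum2, sae, sse;
        long total;

        // "mapPartition" step: accumulate running sums over one partition
        static LocalSummary of(List<double[]> rows) {
            LocalSummary s = new LocalSummary();
            for (double[] r : rows) {
                double y = r[0], f = r[1], diff = Math.abs(y - f);
                s.ySum += y;
                s.ySum2 += y * y;
                s.predSum += f;
                s.predSum2 += f * f;
                s.sae += diff;
                s.sse += diff * diff;
                s.total++;
            }
            return s;
        }

        // "reduce" step (assumed): every field is a plain sum, so merging is addition
        LocalSummary merge(LocalSummary o) {
            LocalSummary m = new LocalSummary();
            m.ySum = ySum + o.ySum;
            m.ySum2 = ySum2 + o.ySum2;
            m.predSum = predSum + o.predSum;
            m.predSum2 = predSum2 + o.predSum2;
            m.sae = sae + o.sae;
            m.sse = sse + o.sse;
            m.total = total + o.total;
            return m;
        }

        // "finalize" step: same closed forms that toMetrics() uses
        double sst() { return ySum2 - ySum * ySum / total; }
        double ssr() { return predSum2 - 2 * ySum * predSum / total + ySum * ySum / total; }
        double r2() { return 1 - sse / sst(); }
    }

    static LocalSummary evaluate(double[][] rows, int numPartitions) {
        // Round-robin assignment of rows to partitions
        List<List<double[]>> parts = new ArrayList<>();
        for (int p = 0; p < numPartitions; p++) {
            parts.add(new ArrayList<>());
        }
        for (int i = 0; i < rows.length; i++) {
            parts.get(i % numPartitions).add(rows[i]);
        }
        return parts.stream()
            .filter(part -> !part.isEmpty()) // Alink returns null for an empty partition
            .map(LocalSummary::of)
            .reduce(LocalSummary::merge)
            .orElseThrow(() -> new IllegalArgumentException("no rows"));
    }

    public static void main(String[] args) {
        double[][] rows = {{0.4, 0.5}, {0.3, 0.5}, {0.2, 0.6}, {0.6, 0.7}, {0.1, 0.5}};
        for (int p : new int[] {1, 2, 3}) {
            LocalSummary s = evaluate(rows, p);
            System.out.printf("partitions=%d SSE=%.4f SST=%.4f SSR=%.4f R2=%.4f%n",
                p, s.sse, s.sst(), s.ssr(), s.r2());
        }
    }
}
```

Because every field is a plain sum, the merged result does not depend on how the rows were split. Only floating-point rounding may differ between splits. That is what allows the distributed job to use an arbitrary `rebalance()` followed by a single `reduce`.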
0x04 Computing Statistics per Partition
CalcLocal computes the statistics within each partition by calling getRegressionStatistics.
```java
/**
 * Get the label sum, predResult sum, SSE, MAE, MAPE of one partition.
 */
public static class CalcLocal implements MapPartitionFunction<Row, RegressionMetricsSummary> {
    @Override
    public void mapPartition(Iterable<Row> rows, Collector<RegressionMetricsSummary> collector) throws Exception {
        collector.collect(getRegressionStatistics(rows));
    }
}
```
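getRegressionStatistics iterates over the input rows and accumulates running sums within the current partition. These are the label sum and sum of squares, the prediction sum and sum of squares, the absolute error, the squared error, the absolute percentage error, and the row count. The later steps turn these sums into metrics. For an empty partition it returns null.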
```java
/**
 * Calculate the RegressionMetrics from local data.
 *
 * @param rows Input rows, the first field is label value, the second field is prediction value.
 * @return RegressionMetricsSummary.
 */
public static RegressionMetricsSummary getRegressionStatistics(Iterable<Row> rows) {
    RegressionMetricsSummary regressionSummary = new RegressionMetricsSummary();
    for (Row row : rows) {
        if (checkRowFieldNotNull(row)) {
            double yVal = ((Number) row.getField(0)).doubleValue();
            double predictVal = ((Number) row.getField(1)).doubleValue();
            double diff = Math.abs(yVal - predictVal);
            regressionSummary.ySumLocal += yVal;
            regressionSummary.ySum2Local += yVal * yVal;
            regressionSummary.predSumLocal += predictVal;
            regressionSummary.predSum2Local += predictVal * predictVal;
            regressionSummary.maeLocal += diff;
            regressionSummary.sseLocal += diff * diff;
            regressionSummary.mapeLocal += Math.abs(diff / yVal);
            regressionSummary.total++;
        }
    }
    return regressionSummary.total == 0 ? null : regressionSummary;
}
```
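0x05 Merging the Statistics
reduce merges the per-partition summaries with ReduceBaseMetrics. The reducer guards against a null t1 and otherwise delegates to merge: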
```java
/**
 * Merge the BaseMetrics calculated locally.
 */
public static class ReduceBaseMetrics implements ReduceFunction<BaseMetricsSummary> {
    @Override
    public BaseMetricsSummary reduce(BaseMetricsSummary t1, BaseMetricsSummary t2) throws Exception {
        return null == t1 ? t2 : t1.merge(t2);
    }
}
```
0x06 Storing the Results
SaveDataAsParams converts the merged summary into metrics and stores them in serialized form:
```java
/**
 * After merging all the BaseMetrics, we get the total BaseMetrics. Calculate the indexes and save them into params.
 */
public static class SaveDataAsParams implements FlatMapFunction<BaseMetricsSummary, Row> {
    @Override
    public void flatMap(BaseMetricsSummary t, Collector<Row> collector) throws Exception {
        collector.collect(t.toMetrics().serialize());
    }
}
```
0x07 toMetrics
Finally, toMetrics derives every metric from the accumulated sums.
```java
public RegressionMetrics toMetrics() {
    Params params = new Params();
    // SST = sum(y^2) - (sum(y))^2 / N
    params.set(RegressionMetrics.SST, ySum2Local - ySumLocal * ySumLocal / total);
    params.set(RegressionMetrics.SSE, sseLocal);
    // SSR = sum(f^2) - 2 * sum(y) * sum(f) / N + (sum(y))^2 / N
    params.set(RegressionMetrics.SSR,
        predSum2Local - 2 * ySumLocal * predSumLocal / total + ySumLocal * ySumLocal / total);
    params.set(RegressionMetrics.R2, 1 - params.get(RegressionMetrics.SSE) / params.get(RegressionMetrics.SST));
    // Math.sqrt of a negative R2 yields NaN
    params.set(RegressionMetrics.R, Math.sqrt(params.get(RegressionMetrics.R2)));
    params.set(RegressionMetrics.MSE, params.get(RegressionMetrics.SSE) / total);
    params.set(RegressionMetrics.RMSE, Math.sqrt(params.get(RegressionMetrics.MSE)));
    params.set(RegressionMetrics.SAE, maeLocal);
    params.set(RegressionMetrics.MAE, params.get(RegressionMetrics.SAE) / total);
    params.set(RegressionMetrics.COUNT, (double) total);
    params.set(RegressionMetrics.MAPE, mapeLocal * 100 / total);
    params.set(RegressionMetrics.Y_MEAN, ySumLocal / total);
    params.set(RegressionMetrics.PREDICTION_MEAN, predSumLocal / total);
    params.set(RegressionMetrics.EXPLAINED_VARIANCE, params.get(RegressionMetrics.SSR) / total);
    return new RegressionMetrics(params);
}
```
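The closed forms used for SST and SSR follow from expanding the squares with ȳ = Σy_i / N. They let the metrics be computed from the single-pass sums, with no second pass over the data. In the code, `ySumLocal` = Σy_i, `ySum2Local` = Σy_i², `predSumLocal` = Σf_i, `predSum2Local` = Σf_i² and `total` = N:

\[
\begin{aligned}
SST &= \sum_{i=1}^{N}(y_i-\bar{y})^2 = \sum_{i=1}^{N}y_i^2-2\bar{y}\sum_{i=1}^{N}y_i+N\bar{y}^2 = \sum_{i=1}^{N}y_i^2-\dfrac{\left(\sum_{i=1}^{N}y_i\right)^2}{N} \\
SSR &= \sum_{i=1}^{N}(f_i-\bar{y})^2 = \sum_{i=1}^{N}f_i^2-2\bar{y}\sum_{i=1}^{N}f_i+N\bar{y}^2 = \sum_{i=1}^{N}f_i^2-\dfrac{2\sum_{i=1}^{N}y_i\sum_{i=1}^{N}f_i}{N}+\dfrac{\left(\sum_{i=1}^{N}y_i\right)^2}{N}
\end{aligned}
\]

This one-pass form trades some numerical robustness for a single scan. Subtracting two large, nearly equal sums can lose precision when the values are large compared with their spread.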
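The final result, as seen in the debugger, is shown below. R is NaN because R2 is negative, and Math.sqrt of a negative number returns NaN. The trailing digits of R2, SSR and SST also differ slightly from the output in 0x02, presumably because the floating-point summation order differs between runs.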
```
Params {
  R2=-1.5675675675675693,
  predictionMean=0.5599999999999999,
  SSE=0.38,
  count=5.0,
  MAPE=141.66666666666666,
  RMSE=0.27568097504180444,
  MAE=0.24,
  R=NaN,
  SSR=0.3200000000000002,
  yMean=0.32,
  SST=0.1479999999999999,
  SAE=1.2,
  Explained Variance=0.06400000000000003,
  MSE=0.076
}
```