本篇继续进阶一点,写一下 梯度提升决策树(Gradient Boosting Decision Tree)

还是先上代码,梯度提升决策树是能够支持多种损失函数的,关于 损失函数的定义,老规矩,自己搜。既然要支持多种损失函数,因此先写个接口类,然后再来个实现,后面会用到

损失函数接口类

public interface LossFunction {

    public double loss(double y, double predict);
    
    /**
     * 负梯度
     * 详见:《The Elements of Statistical Learning .pdf》 的第10章的表 TABLE 10.2 Gradients for commonly used loss functions.
     * @param y 实际值
     * @param predict 预测值
     * @return
     */
    public double negativePartialDerivative(double y, double predict);
}

损失函数的实现 - 差平方

public class SquareErrorLoss implements LossFunction{

    @Override
    public double loss(double y, double predict) {
        double loss = y - predict;
        return loss * loss;
    }

    /**
     * 负梯度
     * 详见:《The Elements of Statistical Learning .pdf》 的第10章的表 TABLE 10.2 Gradients for commonly used loss functions.
     */
    @Override
    public double negativePartialDerivative(double y, double predict) {
        return y - predict;
    }

}

生成树的节点的代码借用上一篇提升树的代码,计算损失函数之和部分调整了一下。

/**
 * 依据给定的X和Y数据,基于最小二乘回归树生成 1 个二叉树(1个节点)
 * 选择最优的切分点
 * @param xdata
 * @param ydata
 * @return
 */
BinaryTreeNode generateRegressTreeNode(double[] xdata, double[] ydata) {
    BinaryTreeNode brn = null;
    int dataLength = xdata.length;
    double minSum = 0;
    // 遍历输入值,将xdata分为2个部分
    for (int i = 0; i < dataLength; i++) {
        // X数据的每一个值都可作为切分点
        double splitPoint = xdata[i];
            
        int[] r1Idx = new int[dataLength];
        int[] r2Idx = new int[dataLength];
            
        for (int j = 0; j < dataLength; j++) {
            r1Idx[j] = -1;
            r2Idx[j] = -1;
            if (xdata[j] > splitPoint) {
                r2Idx[j] = j;
            } else {
                r1Idx[j] = j;
            }
        }
        // 切分点左侧Y的数据均值
        double c1 = meanMatrix1DByIdx(ydata, r1Idx);
        // 切分点右侧Y的数据均值
        double c2 = meanMatrix1DByIdx(ydata, r2Idx);
            
        // 左侧和右侧值的损失函数之和
        double sumsl = sumLoss(ydata, c1, r1Idx, c2, r2Idx); // 更改的地方
            
        // 找最小的和(冒泡方式)
        if (i == 0 ||  minSum > sumsl) {
            minSum = sumsl;
            brn = new BinaryTreeNode(splitPoint, c1, c2);
            //brn.setLeftIdx(r1Idx);
            //brn.setRightIdx(r2Idx); // 索引不再需要了,因为只需要一层
        }
    }
    return brn;
}

/**
 * 偏差值使用损失函数实现
 */
double sumLoss(double[] data, double c1, int[] r1IndexArray, double c2, int[] r2IndexArray) {
    double sum = 0.0d;
    for (int idx : r1IndexArray) {
        if (idx > -1) {
            sum = sum + lossFunction.loss(data[idx], c1);
        }
    }
        
    for (int idx : r2IndexArray) {
        if (idx > -1) {
            sum = sum + lossFunction.loss(data[idx], c2);
        }
    }    
    return sum;
}

它的训练和预测的逻辑比提升树复杂一点

LossFunction lossFunction; // 损失函数
List<BinaryTreeNode> binaryTreeNodeList;
double f0; // 初始值

public GradientBoostingDecisionTree(LossFunction lossFunction){
    this.lossFunction = lossFunction;
    binaryTreeNodeList = new ArrayList<BinaryTreeNode>();
}
/**
 * 初始化时,估计使损失函数极小化的常数值,它是只有一个根节点的树,即gama是一个常数值。
 * 在此取均值
 * @param data
 * @return
 */
private double retrieveArgmin(double[] data){
    double total = 0.0d;
    for (int i = 0; i < data.length; i++){
        total = total + data[i];
    }
    
    return total / data.length;
}

/**
 * 依据梯度提升树算法进行预测计算
 * @param newXData
 * @return
 */
double[] predict(double[] newXData) {
    double[] ret = new double[newXData.length];
    for (int i = 0; i < newXData.length; i++) {
        ret[i] = f0;
        for (int j = 0; j < binaryTreeNodeList.size(); j++) {
            BinaryTreeNode btn = binaryTreeNodeList.get(j);
            if (newXData[i] > btn.getSplitPoint()) {
                ret[i] = ret[i] + btn.getRightValue();
            } else {
                ret[i] = ret[i] + btn.getLeftValue();
            } 
        }
    }
    return ret;
}
/**
 * 基于回归树,根据深度,得到多级的二叉树
 * @param xdata
 * @param ydata
 * @param level
 * @return
 */
void train(double[] xdata, double[] ydata, int level) {
    f0 = retrieveArgmin(ydata);
    double[] temp = null;
    for (int i = 0; i < level; i++) {
        if (i == 0) {
            // 计算第一次差值
            temp = calculateFirstResidual(ydata, f0);
        }
            
        BinaryTreeNode btn = generateRegressTreeNode(xdata, temp);
        binaryTreeNodeList.add(btn);
        temp = calculateByNode(xdata, temp, btn);
    }
}

// 计算每一个Y与对比值fm的负梯度值
double[] calculateFirstResidual(double[] ydata, double fm) {
    double[] ret = new double[ydata.length];
    for (int i = 0; i < ydata.length; i++) {
        ret[i] = lossFunction.negativePartialDerivative(ydata[i], fm);
    }
}
// 计算每一个Y与切分点两侧值的负梯度值
double[] calculateByNode(double[] xdata, double[] ydata, BinaryTreeNode binaryTreeNode) {
    double[] ret = new double[xdata.length];
       
    for (int i = 0; i < xdata.length; i++) {
        if (xdata[i] > binaryTreeNode.getSplitPoint()) {
            ret[i] = lossFunction.negativePartialDerivative(ydata[i], binaryTreeNode.getRightValue());
        } else {
            ret[i] = lossFunction.negativePartialDerivative(ydata[i], binaryTreeNode.getLeftValue());
        }
    }
   
    return ret;
}

最后该验证了

double[] xdata = {1,    2,    3,    4,    5,    6,    7,    8,    9,    10};
double[] ydata = {5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05};
GradientBoostingDecisionTree gbdt = new GradientBoostingDecisionTree(new SquareErrorLoss());
        
gbdt.train(xdata, ydata, 6);

double[] nxdata = {8.6, 5.5, 4.4, 3.4, 2.4};
        
double[] predictRT = gbdt.predict(nxdata);

System.out.println("nxdata:" + Arrays.toString(nxdata));
System.out.println("预测结果predict SquareErrorLoss :" + Arrays.toString(predictRT));

------------------------------------------------------------
最终结果:
nxdata: 8.6  5.5  4.4  3.4  2.4  

预测结果SquareErrorLoss: 8.9502  6.8197  6.8197  6.5516  5.8183

总结,梯度提升树(解决回归问题):上述2个都是基于最小二乘的,但是对于其它的 损失函数 就不适用了。而梯度提升就是通过近似的方式,能够支持多种 损失函数,要求损失函数能够做1阶偏导数。实现流程与提升树类似。f0 选用的是ydata的均值, 因此第一次的输出值就使用的是 ydata的每一个值和f0,通过损失函数的偏导数方法计算的结果,预测操作与提升树类似,差别是 f0 的选取。