简介

本示例说明如何使用分位数误差实现贝叶斯优化以调整回归树的随机森林的超参数。 如果计划使用模型来预测条件分位数而不是条件均值,则使用分位数误差而不是均方误差来调整模型是合适的。查找关于树复杂性和要使用贝叶斯优化在每个节点上采样的预测变量数量,实现最小,受罚的袋外分位数误差的模型。 将期望的改进加功能指定为获取功能。

加载和预处理数据

加载carsmall数据集。 假设一个模型,该模型根据加速度、汽缸数、发动机排量、马力、制造商、年和重量来预测汽车的平均燃油经济性。 将Cylinders,Mfg和Model_Year视为类别变量。

例子

clc
clear all
close all

load carsmall
% 将Cylinders、Mfg和Model_Year视为类别变量
Cylinders = categorical(Cylinders);
Mfg = categorical(cellstr(Mfg));
Model_Year = categorical(Model_Year);
X = table(Acceleration,Cylinders,Displacement,Horsepower,Mfg,...
    Model_Year,Weight,MPG);
rng('default'); % For reproducibility

% 调整参数,考虑调整:
% 森林中树木的复杂程度(深度)。 深树倾向于过度拟合,而浅树倾向于欠拟合。 因此,指定每片叶子的最少观察数为20。
% 生长树木时,在每个节点上采样的预测变量的数量。 指定从1到所有预测变量的采样。
maxMinLS = 20;
minLS = optimizableVariable('minLS',[1,maxMinLS],'Type','integer');
numPTS = optimizableVariable('numPTS',[1,size(X,2)-1],'Type','integer');
hyperparametersRF = [minLS; numPTS];

% bayesopt是实现贝叶斯优化的函数,要求您将这些规范作为optimizableVariable对象传递。
% hyperparametersRF是OptimizableVariable对象的2比1数组。
% 还应该考虑调整集合中的树数。 bayesopt倾向于选择包含许多树木的随机森林,因为会更准确。 
% 如果考虑到可用的计算资源,并且您希望使用较少的树,则可以考虑与其他参数分开调整树的数量,或者对包含许多学习者的模型进行惩罚。
% 结果是一个BayesianOptimization对象,该对象除其他外包含目标函数的最小值和优化的超参数值。
% 显示观察到的目标函数最小值和优化的超参数值。
results = bayesopt(@(params)oobErrRF(params,X),hyperparametersRF,...
    'AcquisitionFunctionName','expected-improvement-plus','Verbose',0);

bestOOBErr = results.MinObjective
bestHyperparameters = results.XAtMinObjective
% 使用优化的超参数训练模型
% 使用整个数据集和优化的超参数值训练随机森林。
% Mdl是针对中位数预测优化的TreeBagger对象。 
% 您可以通过将Mdl和新数据传递给QuantilePredict,在给定预测器数据的情况下预测平均燃油经济性。
Mdl = TreeBagger(300,X,'MPG','Method','regression',...
    'MinLeafSize',bestHyperparameters.minLS,...
    'NumPredictorstoSample',bestHyperparameters.numPTS);

% 定义目标函数
% 为贝叶斯优化算法定义目标函数以进行优化。 该功能应:
% 接受要调谐的参数作为输入。
% 使用TreeBagger训练一个随机森林。 在TreeBagger调用中,指定要调整的参数,并指定返回袋外索引。
% 根据中位数估算袋外分位数误差。
% 返回袋外分位数误差。
% oobErrRF训练随机森林并估计袋外分位数误差oobErr使用X中的预测变量数据和参数中的参数指定来训练300个回归树的随机森林,
% 然后根据中位数返回袋外分位数误差 。 X是一个表,params是一个OptimizableVariable对象的数组,对应于最小叶子大小和要在每个节点上采样的预测变量数量。
function oobErr = oobErrRF(params,X)
randomForest = TreeBagger(300,X,'MPG','Method','regression',...
    'OOBPrediction','on','MinLeafSize',params.minLS,...
    'NumPredictorstoSample',params.numPTS);
oobErr = oobQuantileError(randomForest);
end




随机森林交叉验证matlab代码_随机森林



随机森林交叉验证matlab代码_随机森林交叉验证matlab代码_02