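Machine Learning in R with mlr3: Modeling, Training, Prediction, and Evaluation (Random Forest, Logistic Regression)
This article uses the mlr3 package to fit several machine learning models (logistic regression and random forests) to the German credit dataset.
1. Setting up the R environment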
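# Install packages from CRAN; most commonly used packages can be installed this way
# install.packages("tidyverse")
# install.packages("bruceR")
#
# # Install packages hosted on GitHub
# # (the devtools or remotes package must be installed first)
#
# install.packages("devtools")
# devtools::install_github("tidyverse/tidyverse")
# remotes::install_github("tidyverse/tidyverse")
# Load packages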
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.0 ──
## ✓ ggplot2 3.3.3 ✓ purrr 0.3.4
## ✓ tibble 3.1.1 ✓ dplyr 1.0.5
## ✓ tidyr 1.1.3 ✓ stringr 1.4.0
## ✓ readr 1.4.0 ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(data.table)
##
## Attaching package: 'data.table'
## The following objects are masked from 'package:dplyr':
##
## between, first, last
## The following object is masked from 'package:purrr':
##
## transpose
library(mlr3)
library(mlr3learners)
library(mlr3viz)
library(ggplot2)
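2. Data description
German credit data
The German credit data is available from the rchallenge package. The goal is to predict the binary target credit risk (good/bad) from 20 explanatory variables.
2.1 Importing the data
# install.packages("rchallenge")
data("german", package = "rchallenge")
# Inspect the data
glimpse(german) # column types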
## Rows: 1,000
## Columns: 21
## $ status <fct> no checking account, no checking account, ... …
## $ duration <int> 18, 9, 12, 12, 12, 10, 8, 6, 18, 24, 11, 30, 6…
## $ credit_history <fct> all credits at this bank paid back duly, all c…
## $ purpose <fct> car (used), others, retraining, others, others…
## $ amount <int> 1049, 2799, 841, 2122, 2171, 2241, 3398, 1361,…
## $ savings <fct> unknown/no savings account, unknown/no savings…
## $ employment_duration <fct> < 1 yr, 1 <= ... < 4 yrs, 4 <= ... < 7 yrs, 1 …
## $ installment_rate <ord> < 20, 25 <= ... < 35, 25 <= ... < 35, 20 <= ..…
## $ personal_status_sex <fct> female : non-single or male : single, male : m…
## $ other_debtors <fct> none, none, none, none, none, none, none, none…
## $ present_residence <ord> >= 7 yrs, 1 <= ... < 4 yrs, >= 7 yrs, 1 <= ...…
## $ property <fct> car or other, unknown / no property, unknown /…
## $ age <int> 21, 36, 23, 39, 38, 48, 39, 40, 65, 23, 36, 24…
## $ other_installment_plans <fct> none, none, none, none, bank, none, none, none…
## $ housing <fct> for free, for free, for free, for free, rent, …
## $ number_credits <ord> 1, 2-3, 1, 2-3, 2-3, 2-3, 2-3, 1, 2-3, 1, 2-3,…
## $ job <fct> skilled employee/official, skilled employee/of…
## $ people_liable <fct> 0 to 2, 3 or more, 0 to 2, 3 or more, 0 to 2, …
## $ telephone <fct> no, no, no, no, no, no, no, no, no, no, no, no…
## $ foreign_worker <fct> no, no, no, yes, yes, yes, yes, yes, no, no, n…
## $ credit_risk <fct> good, good, good, good, good, good, good, good…
dim(german) # dimensions
## [1] 1000 21
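The dataset has 1,000 observations and 21 attributes (columns). The target we want to predict is credit_risk (good or bad). There are 20 explanatory variables: duration, age, and amount are numeric, and the remaining 17 are factors (three of them ordered).
The skimr package gives a more detailed overview of each variable.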
# install.packages("skimr")
skimr::skim(german)
Table: Data summary
| | |
|:---|:---|
| Name | german |
| Number of rows | 1000 |
| Number of columns | 21 |
| _______________________ | |
| Column type frequency: | |
| factor | 18 |
| numeric | 3 |
| ________________________ | |
| Group variables | None |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|:---|---:|---:|:---|---:|:---|
| status | 0 | 1 | FALSE | 4 | …: 394, no : 274, …: 269, 0<=: 63 |
| credit_history | 0 | 1 | FALSE | 5 | no : 530, all: 293, exi: 88, cri: 49 |
| purpose | 0 | 1 | FALSE | 10 | fur: 280, oth: 234, car: 181, car: 103 |
| savings | 0 | 1 | FALSE | 5 | unk: 603, …: 183, …: 103, 100: 63 |
| employment_duration | 0 | 1 | FALSE | 5 | 1 <: 339, >= : 253, 4 <: 174, < 1: 172 |
| installment_rate | 0 | 1 | TRUE | 4 | < 2: 476, 25 : 231, 20 : 157, >= : 136 |
| personal_status_sex | 0 | 1 | FALSE | 4 | mal: 548, fem: 310, fem: 92, mal: 50 |
| other_debtors | 0 | 1 | FALSE | 3 | non: 907, gua: 52, co-: 41 |
| present_residence | 0 | 1 | TRUE | 4 | >= : 413, 1 <: 308, 4 <: 149, < 1: 130 |
| property | 0 | 1 | FALSE | 4 | bui: 332, unk: 282, car: 232, rea: 154 |
| other_installment_plans | 0 | 1 | FALSE | 3 | non: 814, ban: 139, sto: 47 |
| housing | 0 | 1 | FALSE | 3 | ren: 714, for: 179, own: 107 |
| number_credits | 0 | 1 | TRUE | 4 | 1: 633, 2-3: 333, 4-5: 28, >= : 6 |
| job | 0 | 1 | FALSE | 4 | ski: 630, uns: 200, man: 148, une: 22 |
| people_liable | 0 | 1 | FALSE | 2 | 0 t: 845, 3 o: 155 |
| telephone | 0 | 1 | FALSE | 2 | no: 596, yes: 404 |
| foreign_worker | 0 | 1 | FALSE | 2 | no: 963, yes: 37 |
| credit_risk | 0 | 1 | FALSE | 2 | goo: 700, bad: 300 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|:---|---:|---:|---:|---:|---:|---:|---:|---:|---:|:---|
| duration | 0 | 1 | 20.90 | 12.06 | 4 | 12.0 | 18.0 | 24.00 | 72 | ▇▇▂▁▁ |
| amount | 0 | 1 | 3271.25 | 2822.75 | 250 | 1365.5 | 2319.5 | 3972.25 | 18424 | ▇▂▁▁▁ |
| age | 0 | 1 | 35.54 | 11.35 | 19 | 27.0 | 33.0 | 42.00 | 75 | ▇▆▃▁▁ |
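The summary counts 18 factor columns (17 predictors plus the target credit_risk) and 3 numeric ones. As a minimal base-R sketch, the helper below (a hypothetical name, not part of skimr) tallies column types for any data frame, treating ordered factors as factors. It is demonstrated on a small toy data frame; `count_column_types(german)` should reproduce the 18/3 split reported above.

```r
# Tally columns by type: "factor" (incl. ordered factors), "numeric"
# (integer or double), or "other". Base R only.
count_column_types <- function(df) {
  types <- vapply(df, function(x) {
    if (is.factor(x)) "factor"
    else if (is.numeric(x)) "numeric"
    else "other"
  }, character(1))
  table(types)
}

# Hypothetical toy data, not the German credit data
toy <- data.frame(
  age    = c(21L, 36L),
  amount = c(1049, 2799),
  risk   = factor(c("good", "bad")),
  rate   = factor(c("low", "high"), levels = c("low", "high"), ordered = TRUE)
)
count_column_types(toy)   # factor 2, numeric 2
```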
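3. Modeling
We use mlr3 to solve the credit risk classification problem. Typical questions that arise when building a machine learning workflow are:
- What problem are we trying to solve?
- Which learning algorithms are appropriate?
- How do we define "good" performance?
In mlr3, these questions are addressed more systematically by five components:
- Task definition (Task)
- Learner definition (Learner)
- Model training (Training)
- Prediction (Prediction)
- Evaluation with one or more performance measures (Evaluation)
3.1 Task definition
First, we have to define the modeling objective. Most supervised machine learning problems are regression or classification problems. mlr3 distinguishes them through tasks: a classification problem is described by a classification task, TaskClassif, and a regression problem by a regression task, TaskRegr.
In our case, the goal is clearly to model, or predict, the binary factor variable credit_risk. We therefore define a TaskClassif:
# "germancredit" is the task id (any label works), german is the data set,
# and target names the target variable
task = TaskClassif$new("germancredit", german, target = "credit_risk")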
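Before modeling, it helps to know the class balance. According to the skim summary, credit_risk has 700 good and 300 bad cases, so a trivial classifier that always predicts "good" already reaches a classification error of 0.30. That is a useful reference point for the error rates reported later. The sketch below computes this majority-class baseline in base R (the function name is illustrative); mlr3 also ships a `classif.featureless` learner for such baselines.

```r
# Majority-class baseline: always predict the most frequent class.
# Its classification error is 1 minus the share of that class.
baseline_ce <- function(y) {
  shares <- prop.table(table(y))
  1 - max(shares)
}

# Class counts as reported by skimr for credit_risk: 700 good, 300 bad
y <- factor(rep(c("good", "bad"), times = c(700, 300)), levels = c("bad", "good"))
prop.table(table(y))   # bad 0.3, good 0.7
baseline_ce(y)         # 0.3
```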
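3.2 Learner definition
Once the modeling objective is fixed, we need to decide how to model, that is, which learning algorithms (Learners) are appropriate. Prior knowledge (for example, knowing that this is a classification task, or assuming that the classes are linearly separable) eventually narrows the choice down to one or more suitable learners.
Many learners are available through the mlr3learners package, and many more are provided by the mlr3extralearners package on GitHub. Together, these two resources cover a large share of the standard learning algorithms.
All available learners (that is, all learners installed from mlr3, mlr3learners, mlr3extralearners, or written by yourself) are stored in the dictionary mlr_learners: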
mlr_learners
## <DictionaryLearner> with 29 stored values
## Keys: classif.cv_glmnet, classif.debug, classif.featureless,
## classif.glmnet, classif.kknn, classif.lda, classif.log_reg,
## classif.multinom, classif.naive_bayes, classif.nnet, classif.qda,
## classif.ranger, classif.rpart, classif.svm, classif.xgboost,
## regr.cv_glmnet, regr.featureless, regr.glmnet, regr.kknn, regr.km,
## regr.lm, regr.ranger, regr.rpart, regr.svm, regr.xgboost,
## surv.cv_glmnet, surv.glmnet, surv.ranger, surv.xgboost
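For our problem, suitable learners include logistic regression, CART, and random forests, among others.
A learner is initialized with the lrn() function and the learner's key, for example lrn("classif.xxx"). The help page of a learner is opened with ?mlr_learners_<key>, for example ?mlr_learners_classif.log_reg.
For example, logistic regression can be initialized as follows (the learner wraps R's glm() function and is provided by the mlr3learners package):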
library("mlr3learners")
learner_logreg = lrn("classif.log_reg")
print(learner_logreg)
## <LearnerClassifLogReg:classif.log_reg>
## * Model: -
## * Parameters: list()
## * Packages: stats
## * Predict Type: response
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: twoclass, weights
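3.3 Training
Training is the process of fitting a model to the (training) data.
- Logistic regression
We start with logistic regression. As you will see, the procedure carries over directly to any other learner.
An initialized learner is trained on data with $train():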
learner_logreg$train(task)
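In machine learning we usually do not train on all available data but on a subset, the so-called training data. A simple way to split the data is:
# randomly draw 80% of the row ids for training (no seed is set, so the split changes between runs)
train_set = sample(task$row_ids, 0.8 * task$nrow)
# the remaining 20% form the test set
test_set = setdiff(task$row_ids, train_set)
80% of the data are used for training; the remaining 20% are held back for evaluation. train_set is an integer vector referring to the selected rows of the original data set: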
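Because no seed is set, the split above (and every result derived from it) differs between runs. A minimal base-R sketch, assuming a hypothetical helper `split_rows()`, that makes an 80/20 split reproducible and confirms that training and test rows are disjoint and together cover all rows:

```r
# Reproducible train/test split of row ids. set.seed() resets the global
# random number generator, so sample() returns the same split every time.
split_rows <- function(ids, ratio = 0.8, seed = 1) {
  set.seed(seed)
  train <- sample(ids, size = round(ratio * length(ids)))
  list(train = train, test = setdiff(ids, train))
}

ids <- 1:1000                    # stands in for task$row_ids
parts <- split_rows(ids, ratio = 0.8, seed = 42)
length(parts$train)              # 800
length(parts$test)               # 200
```

More recent mlr3 releases (newer than the 0.11.0 used here) also provide a `partition()` function for this purpose; check its documentation for the installed version.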
head(train_set)
## [1] 410 864 543 236 958 851
In mlr3, training on a subset of the data is declared with the additional argument row_ids = train_set:
learner_logreg$train(task, row_ids = train_set)
The fitted model can be displayed with:
learner_logreg$model
##
## Call: stats::glm(formula = task$formula(), family = "binomial", data = task$data(),
## model = FALSE)
##
## Coefficients:
## (Intercept)
## -0.1819216
## age
## 0.0056873
## amount
## -0.0001196
## credit_historycritical account/other credits elsewhere
## -1.0951994
## credit_historyno credits taken/all credits paid back duly
## 0.3816992
## credit_historyexisting credits paid back duly till now
## 0.9330591
## credit_historyall credits at this bank paid back duly
## 1.3556494
## duration
## -0.0271785
## employment_duration< 1 yr
## -0.0150296
## employment_duration1 <= ... < 4 yrs
## 0.2004790
## employment_duration4 <= ... < 7 yrs
## 0.9713337
## employment_duration>= 7 yrs
## 0.3789241
## foreign_workerno
## -1.2704600
## housingrent
## 0.6250064
## housingown
## 0.6444397
## installment_rate.L
## -0.5924806
## installment_rate.Q
## 0.0909648
## installment_rate.C
## 0.0636166
## jobunskilled - resident
## -0.8209089
## jobskilled employee/official
## -0.7988798
## jobmanager/self-empl./highly qualif. employee
## -0.9088915
## number_credits.L
## -0.4671141
## number_credits.Q
## 0.0976312
## number_credits.C
## 0.0062673
## other_debtorsco-applicant
## -0.9178934
## other_debtorsguarantor
## 1.3397823
## other_installment_plansstores
## 0.1427722
## other_installment_plansnone
## 0.4974245
## people_liable0 to 2
## 0.2534176
## personal_status_sexfemale : non-single or male : single
## -0.0183188
## personal_status_sexmale : married/widowed
## 0.6102816
## personal_status_sexfemale : single
## 0.0759193
## present_residence.L
## -0.1602614
## present_residence.Q
## 0.4513743
## present_residence.C
## -0.3567466
## propertycar or other
## -0.2797497
## propertybuilding soc. savings agr./life insurance
## -0.1006801
## propertyreal estate
## -0.7330205
## purposecar (new)
## 1.6559118
## purposecar (used)
## 0.8993030
## purposefurniture/equipment
## 0.8574892
## purposeradio/television
## -0.0496272
## purposedomestic appliances
## -0.0426126
## purposerepairs
## 0.0285772
## purposevacation
## 0.7196447
## purposeretraining
## 0.7088115
## purposebusiness
## 2.3256145
## savings... < 100 DM
## 0.2495854
## savings100 <= ... < 500 DM
## 0.5232586
## savings500 <= ... < 1000 DM
## 1.3157498
## savings... >= 1000 DM
## 0.9884852
## status... < 0 DM
## 0.1314611
## status0<= ... < 200 DM
## 0.8973969
## status... >= 200 DM / salary for at least 1 year
## 1.6226985
## telephoneyes (under customer name)
## 0.3142853
##
## Degrees of Freedom: 799 Total (i.e. Null); 745 Residual
## Null Deviance: 982.4
## Residual Deviance: 700.6 AIC: 810.6
We can also inspect the class of the fitted logistic regression model and its summary:
class(learner_logreg$model)
## [1] "glm" "lm"
summary(learner_logreg$model)
##
## Call:
## stats::glm(formula = task$formula(), family = "binomial", data = task$data(),
## model = FALSE)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.7481 -0.6573 0.3599 0.6823 2.0764
##
## Coefficients:
## Estimate Std. Error
## (Intercept) -1.819e-01 1.313e+00
## age 5.687e-03 1.045e-02
## amount -1.196e-04 5.297e-05
## credit_historycritical account/other credits elsewhere -1.095e+00 6.830e-01
## credit_historyno credits taken/all credits paid back duly 3.817e-01 4.971e-01
## credit_historyexisting credits paid back duly till now 9.331e-01 5.441e-01
## credit_historyall credits at this bank paid back duly 1.356e+00 4.897e-01
## duration -2.718e-02 1.083e-02
## employment_duration< 1 yr -1.503e-02 4.935e-01
## employment_duration1 <= ... < 4 yrs 2.005e-01 4.693e-01
## employment_duration4 <= ... < 7 yrs 9.713e-01 5.181e-01
## employment_duration>= 7 yrs 3.789e-01 4.733e-01
## foreign_workerno -1.270e+00 7.304e-01
## housingrent 6.250e-01 2.761e-01
## housingown 6.444e-01 5.408e-01
## installment_rate.L -5.925e-01 2.489e-01
## installment_rate.Q 9.096e-02 2.255e-01
## installment_rate.C 6.362e-02 2.311e-01
## jobunskilled - resident -8.209e-01 7.516e-01
## jobskilled employee/official -7.989e-01 7.274e-01
## jobmanager/self-empl./highly qualif. employee -9.089e-01 7.380e-01
## number_credits.L -4.671e-01 8.489e-01
## number_credits.Q 9.763e-02 6.951e-01
## number_credits.C 6.267e-03 5.218e-01
## other_debtorsco-applicant -9.179e-01 4.757e-01
## other_debtorsguarantor 1.340e+00 4.751e-01
## other_installment_plansstores 1.428e-01 5.116e-01
## other_installment_plansnone 4.974e-01 2.944e-01
## people_liable0 to 2 2.534e-01 2.831e-01
## personal_status_sexfemale : non-single or male : single -1.832e-02 4.396e-01
## personal_status_sexmale : married/widowed 6.103e-01 4.300e-01
## personal_status_sexfemale : single 7.592e-02 5.179e-01
## present_residence.L -1.603e-01 2.457e-01
## present_residence.Q 4.514e-01 2.304e-01
## present_residence.C -3.567e-01 2.293e-01
## propertycar or other -2.797e-01 2.881e-01
## propertybuilding soc. savings agr./life insurance -1.007e-01 2.790e-01
## propertyreal estate -7.330e-01 4.750e-01
## purposecar (new) 1.656e+00 4.260e-01
## purposecar (used) 8.993e-01 3.057e-01
## purposefurniture/equipment 8.575e-01 2.807e-01
## purposeradio/television -4.963e-02 9.327e-01
## purposedomestic appliances -4.261e-02 6.641e-01
## purposerepairs 2.858e-02 4.360e-01
## purposevacation 7.196e-01 1.287e+00
## purposeretraining 7.088e-01 3.815e-01
## purposebusiness 2.326e+00 9.776e-01
## savings... < 100 DM 2.496e-01 3.377e-01
## savings100 <= ... < 500 DM 5.233e-01 4.443e-01
## savings500 <= ... < 1000 DM 1.316e+00 5.692e-01
## savings... >= 1000 DM 9.885e-01 2.983e-01
## status... < 0 DM 1.315e-01 2.558e-01
## status0<= ... < 200 DM 8.974e-01 4.427e-01
## status... >= 200 DM / salary for at least 1 year 1.623e+00 2.681e-01
## telephoneyes (under customer name) 3.143e-01 2.305e-01
## z value Pr(>|z|)
## (Intercept) -0.139 0.889817
## age 0.544 0.586361
## amount -2.259 0.023910 *
## credit_historycritical account/other credits elsewhere -1.604 0.108806
## credit_historyno credits taken/all credits paid back duly 0.768 0.442612
## credit_historyexisting credits paid back duly till now 1.715 0.086353 .
## credit_historyall credits at this bank paid back duly 2.768 0.005636 **
## duration -2.511 0.012052 *
## employment_duration< 1 yr -0.030 0.975704
## employment_duration1 <= ... < 4 yrs 0.427 0.669230
## employment_duration4 <= ... < 7 yrs 1.875 0.060842 .
## employment_duration>= 7 yrs 0.801 0.423408
## foreign_workerno -1.739 0.081956 .
## housingrent 2.264 0.023571 *
## housingown 1.192 0.233383
## installment_rate.L -2.380 0.017307 *
## installment_rate.Q 0.403 0.686685
## installment_rate.C 0.275 0.783095
## jobunskilled - resident -1.092 0.274757
## jobskilled employee/official -1.098 0.272063
## jobmanager/self-empl./highly qualif. employee -1.231 0.218137
## number_credits.L -0.550 0.582157
## number_credits.Q 0.140 0.888294
## number_credits.C 0.012 0.990417
## other_debtorsco-applicant -1.930 0.053659 .
## other_debtorsguarantor 2.820 0.004806 **
## other_installment_plansstores 0.279 0.780181
## other_installment_plansnone 1.689 0.091136 .
## people_liable0 to 2 0.895 0.370704
## personal_status_sexfemale : non-single or male : single -0.042 0.966764
## personal_status_sexmale : married/widowed 1.419 0.155847
## personal_status_sexfemale : single 0.147 0.883447
## present_residence.L -0.652 0.514230
## present_residence.Q 1.959 0.050116 .
## present_residence.C -1.556 0.119724
## propertycar or other -0.971 0.331602
## propertybuilding soc. savings agr./life insurance -0.361 0.718219
## propertyreal estate -1.543 0.122757
## purposecar (new) 3.887 0.000101 ***
## purposecar (used) 2.942 0.003265 **
## purposefurniture/equipment 3.055 0.002251 **
## purposeradio/television -0.053 0.957566
## purposedomestic appliances -0.064 0.948839
## purposerepairs 0.066 0.947743
## purposevacation 0.559 0.576176
## purposeretraining 1.858 0.063142 .
## purposebusiness 2.379 0.017359 *
## savings... < 100 DM 0.739 0.459905
## savings100 <= ... < 500 DM 1.178 0.238878
## savings500 <= ... < 1000 DM 2.312 0.020794 *
## savings... >= 1000 DM 3.313 0.000922 ***
## status... < 0 DM 0.514 0.607376
## status0<= ... < 200 DM 2.027 0.042667 *
## status... >= 200 DM / salary for at least 1 year 6.052 1.43e-09 ***
## telephoneyes (under customer name) 1.363 0.172751
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 982.41 on 799 degrees of freedom
## Residual deviance: 700.57 on 745 degrees of freedom
## AIC: 810.57
##
## Number of Fisher Scoring iterations: 5
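Logistic regression coefficients are on the log-odds scale; exponentiating them gives odds ratios, which are often easier to read. The sketch below uses coefficient values copied from the output above. Note that stats::glm() models the probability of the second level of a factor response, so check which credit_risk level that is (and whether mlr3learners reorders the levels before calling glm()) before interpreting the direction of an effect.

```r
# Odds ratios from log-odds coefficients (values copied from the glm output).
# exp(beta * k) is the multiplicative change in the odds of the modelled
# class when the feature increases by k units.
beta_duration <- -0.0271785   # per month of loan duration
beta_amount   <- -0.0001196   # per DM of credit amount

or_duration_1m  <- exp(beta_duration)        # one extra month
or_duration_12m <- exp(12 * beta_duration)   # one extra year
or_amount_1000  <- exp(1000 * beta_amount)   # 1000 DM more credit

round(c(or_duration_1m, or_duration_12m, or_amount_1000), 4)
```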
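- Random forest
Just like logistic regression, we can train a random forest, here using the fast implementation from the ranger package. Again, we first define the learner and then train it.
This time we additionally set the importance parameter (importance = "permutation"). This overrides the default and makes the learner compute permutation feature importance: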
learner_rf = lrn("classif.ranger", importance = "permutation")
learner_rf$train(task, row_ids = train_set)
The importance of each feature can be inspected with the $importance() method:
learner_rf$importance()
## status duration amount
## 0.0330947539 0.0175370797 0.0134572307
## credit_history savings age
## 0.0129659380 0.0095783381 0.0065733821
## property employment_duration purpose
## 0.0053766886 0.0053485974 0.0047822849
## other_debtors installment_rate personal_status_sex
## 0.0043989633 0.0036503334 0.0029137105
## present_residence number_credits housing
## 0.0022437675 0.0017202412 0.0013506399
## telephone people_liable job
## 0.0012456826 0.0007195306 0.0006561488
## other_installment_plans foreign_worker
## 0.0003107618 0.0001042939
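To plot the importance values, we convert them to a data.table and then plot them with ggplot2:
importance = as.data.table(learner_rf$importance(), keep.rownames = TRUE)
# rename the columns
colnames(importance) = c("Feature", "Importance")
# plot the importance values with ggplot2
ggplot(data = importance,
       aes(x = reorder(Feature, Importance), y = Importance)) +
  geom_col() + coord_flip() + xlab("")
In this run, the first seven variables (status, duration, amount, credit_history, savings, age, and property) contribute most to predicting the target.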
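Permutation importance measures how much a model's error increases when the values of one feature are randomly shuffled, which breaks that feature's link to the target. The sketch below illustrates the idea in base R on simulated data with one informative and one pure-noise feature, using glm() as the model and a plain hold-out set. ranger itself computes the measure on out-of-bag samples, so this is a simplified illustration, not ranger's implementation.

```r
set.seed(1)
n  <- 2000
x1 <- rnorm(n)                      # informative feature
x2 <- rnorm(n)                      # pure noise feature
y  <- factor(ifelse(x1 + rnorm(n, sd = 0.5) > 0, "good", "bad"),
             levels = c("bad", "good"))
dat   <- data.frame(y, x1, x2)
train <- dat[1:1000, ]
test  <- dat[1001:2000, ]

# glm() models P(y == "good"), the second factor level
fit <- glm(y ~ x1 + x2, family = binomial, data = train)

error_rate <- function(model, d) {
  p <- predict(model, newdata = d, type = "response")
  pred <- ifelse(p > 0.5, "good", "bad")
  mean(pred != as.character(d$y))
}

base_err <- error_rate(fit, test)
perm_imp <- sapply(c("x1", "x2"), function(v) {
  shuffled <- test
  shuffled[[v]] <- sample(shuffled[[v]])   # break the link between v and y
  error_rate(fit, shuffled) - base_err     # increase in error
})
perm_imp   # large for x1, close to 0 for x2
```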
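3.4 Prediction
Next, we use the trained models to make predictions. Prediction is usually the main purpose of a machine learning model.
In our case, the model can be used to classify new credit applicants by their credit risk (good vs. bad) based on their features. Machine learning models generally predict numeric values. For regression this is natural; for classification, most models predict scores or probabilities, from which class predictions are then derived.
- Predict classes
First, we predict the classes directly: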
pred_logreg = learner_logreg$predict(task, row_ids = test_set)
pred_rf = learner_rf$predict(task, row_ids = test_set)
pred_logreg
## <PredictionClassif> for 200 observations:
## row_ids truth response
## 2 good bad
## 3 good good
## 6 good good
## ---
## 986 bad good
## 998 bad good
## 1000 bad good
pred_rf
## <PredictionClassif> for 200 observations:
## row_ids truth response
## 2 good good
## 3 good good
## 6 good good
## ---
## 986 bad good
## 998 bad good
## 1000 bad good
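The $predict() method returns a Prediction object. To work with it further, it can be converted to a data.table with as.data.table().
The predictions can also be displayed as a confusion matrix: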
pred_logreg$confusion
## truth
## response bad good
## bad 28 26
## good 29 117
pred_rf$confusion
## truth
## response bad good
## bad 22 15
## good 35 128
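Accuracy and classification error follow directly from a confusion matrix, because correct predictions lie on its diagonal. The sketch below re-enters the two matrices printed above (rows are responses, columns are the truth) and computes the error on this single 200-row test set. Because the split was random, other runs will give somewhat different numbers.

```r
# Confusion matrices from the output above; matrix() fills column by column
cm_logreg <- matrix(c(28, 29, 26, 117), nrow = 2,
                    dimnames = list(response = c("bad", "good"),
                                    truth    = c("bad", "good")))
cm_rf <- matrix(c(22, 35, 15, 128), nrow = 2,
                dimnames = list(response = c("bad", "good"),
                                truth    = c("bad", "good")))

accuracy_from_cm <- function(cm) sum(diag(cm)) / sum(cm)
ce_from_cm       <- function(cm) 1 - accuracy_from_cm(cm)

c(logreg = ce_from_cm(cm_logreg), rf = ce_from_cm(cm_rf))   # 0.275 and 0.25
```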
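- Predict probabilities
Most learners can predict not only the class ("response") but also their degree of "confidence" or "uncertainty" about a given response. This is usually done by setting the learner's $predict_type to "prob". For some learners (classif.ranger, for example) this has to be done before training. Alternatively, the learner can be created with this option directly: lrn("classif.log_reg", predict_type = "prob")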
learner_logreg$predict_type = "prob"
learner_logreg$predict(task, row_ids = test_set)
## <PredictionClassif> for 200 observations:
## row_ids truth response prob.bad prob.good
## 2 good bad 0.5502737 0.4497263
## 3 good good 0.2432334 0.7567666
## 6 good good 0.1617924 0.8382076
## ---
## 986 bad good 0.1088596 0.8911404
## 998 bad good 0.1524203 0.8475797
## 1000 bad good 0.3172837 0.6827163
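The response column is derived from the probabilities: with two classes, the class with the higher probability is predicted, which amounts to a cut-off of 0.5 on prob.bad. The base-R sketch below (function name illustrative) applies that rule to the first three test rows shown above and then lowers the cut-off, which flags more applicants as bad. In mlr3, PredictionClassif objects offer a $set_threshold() method for this; consult its documentation for the exact arguments.

```r
# prob.bad for test rows 2, 3 and 6 (copied from the output above)
prob_bad <- c(0.5502737, 0.2432334, 0.1617924)

# Predict "bad" whenever prob.bad exceeds the threshold
classify <- function(p_bad, threshold = 0.5) {
  factor(ifelse(p_bad > threshold, "bad", "good"), levels = c("bad", "good"))
}

classify(prob_bad)                    # bad good good, matching the response column
classify(prob_bad, threshold = 0.2)   # bad bad good
```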
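3.5 Performance evaluation
To estimate how a learner performs on new data, we usually simulate unseen data by splitting the data into a training set and a test set. The training set is used to train the learner; the test set is used only for prediction and for evaluating the trained learner. Many resampling methods (cross-validation, bootstrap) repeat this splitting process in different ways.
In mlr3, the resampling strategy is specified with the rsmp() function: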
resampling = rsmp("holdout", ratio = 2/3)
print(resampling)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.6667
Here we use "holdout", a simple train-test split (only one iteration). The resampling is carried out with the resample() function:
res = resample(task, learner = learner_logreg, resampling = resampling)
## INFO [16:08:51.897] [mlr3] Applying learner 'classif.log_reg' on task 'germancredit' (iter 1/1)
res
## <ResampleResult> of 1 iterations
## * Task: germancredit
## * Learner: classif.log_reg
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations
The score of the default measure is returned by $aggregate():
res$aggregate()
## classif.ce
## 0.2612613
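The default measure here is the classification error (classif.ce), the share of misclassified observations; lower is better.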
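Classification error is simply the share of observations whose predicted class differs from the true class. The sketch below defines it in base R and, assuming the holdout split rounds 2/3 of the 1000 rows to 667 training rows, shows that the reported 0.2612613 corresponds to about 87 misclassified rows among the 333 test rows.

```r
# Classification error: share of mismatches between truth and prediction
class_error <- function(truth, response) mean(truth != response)

class_error(c("bad", "good", "good", "bad"),
            c("bad", "good", "bad",  "good"))   # 0.5

# Holdout with ratio = 2/3 on 1000 rows (assuming the training size is rounded)
n       <- 1000
n_train <- round(2/3 * n)   # 667
n_test  <- n - n_train      # 333
0.2612613 * n_test          # about 87 misclassified test rows
```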
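We can use other resampling strategies, such as repeated holdout ("subsampling") or cross-validation. Most of them repeat the train/predict cycle on different subsets of the data and aggregate the results (usually as the mean). Doing this manually would require writing a loop; mlr3 does it for us: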
resampling = rsmp("subsampling", repeats=10)
rr = resample(task, learner = learner_logreg, resampling = resampling)
rr$aggregate()
## classif.ce
## 0.2564565
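To see what mlr3 automates here, the following base-R sketch writes the repeated holdout loop out by hand: ten random 2/3 versus 1/3 splits, a glm() fit on each training part, and the classification error on each test part, averaged at the end. It uses simulated data, so the numbers are not comparable to the German credit results.

```r
set.seed(2)
n <- 600
x <- rnorm(n)
y <- factor(ifelse(x + rnorm(n) > 0, "good", "bad"), levels = c("bad", "good"))
dat <- data.frame(x, y)

errors <- vapply(1:10, function(i) {
  train_idx <- sample(n, size = round(2/3 * n))   # new random split each time
  fit  <- glm(y ~ x, family = binomial, data = dat[train_idx, ])
  test <- dat[-train_idx, ]
  p_good <- predict(fit, newdata = test, type = "response")
  pred   <- ifelse(p_good > 0.5, "good", "bad")
  mean(pred != as.character(test$y))              # error on this split
}, numeric(1))

mean(errors)   # aggregated score, analogous to rr$aggregate()
```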
We can also use cross-validation:
resampling = rsmp("cv", folds = 10)
rr = resample(task, learner = learner_logreg, resampling = resampling)
rr$aggregate()
## classif.ce
## 0.246
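In 10-fold cross-validation every observation is assigned to exactly one fold and is therefore used for testing exactly once. A minimal base-R sketch of the fold assignment (make_folds() is a hypothetical helper, not mlr3's implementation):

```r
# Assign each of n rows to one of k folds of (almost) equal size
make_folds <- function(n, k = 10, seed = 3) {
  set.seed(seed)
  sample(rep(seq_len(k), length.out = n))
}

fold <- make_folds(1000, k = 10)
table(fold)                  # 100 rows per fold
# Iteration i: rows with fold == i are the test set, all others the training set
test_rows_1  <- which(fold == 1)
train_rows_1 <- which(fold != 1)
```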
mlr3 offers many more measures. Here we compute the false positive rate with mlr_measures_classif.fpr and the false negative rate with mlr_measures_classif.fnr. Multiple measures can be passed as a list, which can be constructed directly with msrs():
# false positive rate
rr$aggregate(msr("classif.fpr"))
## classif.fpr
## 0.1345898
# false positive rate and false negative rate
measures = msrs(c("classif.fpr", "classif.fnr"))
rr$aggregate(measures)
## classif.fpr classif.fnr
## 0.1345898 0.5068602
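Both rates are computed from the confusion matrix once a positive class is fixed. By default mlr3 uses the first factor level of the target as the positive class, which here is "bad". The sketch below applies the definitions to the logistic-regression hold-out confusion matrix printed earlier. Its FNR of roughly 0.51 is in line with the cross-validated value: the model misses about half of the bad credit risks. Note that rr$aggregate() averages the per-fold values by default rather than pooling all folds into one matrix.

```r
# Logistic-regression confusion matrix from the prediction section
cm <- matrix(c(28, 29, 26, 117), nrow = 2,
             dimnames = list(response = c("bad", "good"),
                             truth    = c("bad", "good")))

# Positive class: "bad"
tp <- cm["bad", "bad"];  fn <- cm["good", "bad"]
fp <- cm["bad", "good"]; tn <- cm["good", "good"]

fpr <- fp / (fp + tn)   # good applicants wrongly flagged as bad: 26 / 143
fnr <- fn / (fn + tp)   # bad applicants that were missed: 29 / 57
round(c(fpr = fpr, fnr = fnr), 3)   # 0.182 and 0.509
```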
There are many more resampling methods and quite a few more measures (implemented in mlr3measures):
mlr_resamplings
## <DictionaryResampling> with 8 stored values
## Keys: bootstrap, custom, cv, holdout, insample, loo, repeated_cv,
## subsampling
# available performance measures
mlr_measures
## <DictionaryMeasure> with 54 stored values
## Keys: classif.acc, classif.auc, classif.bacc, classif.bbrier,
## classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
## classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
## classif.logloss, classif.mbrier, classif.mcc, classif.npv,
## classif.ppv, classif.prauc, classif.precision, classif.recall,
## classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
## classif.tp, classif.tpr, debug, oob_error, regr.bias, regr.ktau,
## regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse, regr.mse,
## regr.msle, regr.pbias, regr.rae, regr.rmse, regr.rmsle, regr.rrse,
## regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho, regr.sse,
## selected_features, time_both, time_predict, time_train
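3.6 Performance comparison and benchmarks
We could compare learners by running resample() manually for each of them. However, benchmark() automatically performs resampling evaluations for multiple learners and tasks. benchmark_grid() creates a fully crossed design: multiple learners are compared on multiple tasks with respect to multiple resampling strategies.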
learners = lrns(c("classif.log_reg", "classif.ranger"), predict_type = "prob")
bm_design = benchmark_grid(
tasks = task,
learners = learners,
resamplings = rsmp("cv", folds = 50)
)
bmr = benchmark(bm_design)
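benchmark_grid() builds the full cross product of tasks, learners, and resamplings, and each row of the design is then resampled. The base-R sketch below reproduces the size of the design used above with expand.grid() (the column values are illustrative labels, not mlr3 objects): two resample experiments, each with 50 cross-validation iterations.

```r
# Full cross product: 1 task x 2 learners x 1 resampling strategy
design <- expand.grid(
  task       = "germancredit",
  learner    = c("classif.log_reg", "classif.ranger"),
  resampling = "cv50",
  stringsAsFactors = FALSE
)
nrow(design)        # 2 resample experiments
nrow(design) * 50   # 100 train/predict iterations in total
```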
In the benchmark we can compare different measures. Here we look at the misclassification rate and the AUC:
measures = msrs(c("classif.ce", "classif.auc"))
performances = bmr$aggregate(measures)
performances[, c("learner_id", "classif.ce", "classif.auc")]
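Unlike the classification error, the AUC is computed from predicted probabilities, which is why both learners were created with predict_type = "prob". The AUC equals the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case. A minimal base-R sketch using the rank-based (Mann-Whitney) formula on hypothetical toy scores, with "bad" as the positive class:

```r
# AUC via ranks; ties receive average ranks from rank()
auc <- function(score, truth, positive = "bad") {
  pos   <- truth == positive
  n_pos <- sum(pos)
  n_neg <- sum(!pos)
  r <- rank(score)
  (sum(r[pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Toy prob.bad scores (hypothetical)
score <- c(0.9, 0.8, 0.4, 0.3, 0.5, 0.1)
truth <- c("bad", "bad", "bad", "good", "good", "good")
auc(score, truth)   # 8/9: 8 of the 9 bad/good pairs are ranked correctly
```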
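3.7 Hyperparameters: deviating from the defaults
The techniques shown so far form the backbone of a machine learning workflow in mlr3. In practice, however, one rarely stops there. Although many R packages choose their default settings carefully, these defaults will not be optimal in every situation. We can usually choose the values of such hyperparameters ourselves. A learner's (hyper)parameters are accessed and set through its ParamSet, $param_set: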
learner_rf$param_set
## <ParamSet>
## id class lower upper nlevels default
## 1: alpha ParamDbl -Inf Inf Inf 0.5
## 2: always.split.variables ParamUty NA NA Inf <NoDefault[3]>
## 3: class.weights ParamDbl -Inf Inf Inf
## 4: holdout ParamLgl NA NA 2 FALSE
## 5: importance ParamFct NA NA 4 <NoDefault[3]>
## 6: keep.inbag ParamLgl NA NA 2 FALSE
## 7: max.depth ParamInt -Inf Inf Inf
## 8: min.node.size ParamInt 1 Inf Inf 1
## 9: min.prop ParamDbl -Inf Inf Inf 0.1
## 10: minprop ParamDbl -Inf Inf Inf 0.1
## 11: mtry ParamInt 1 Inf Inf <NoDefault[3]>
## 12: num.random.splits ParamInt 1 Inf Inf 1
## 13: num.threads ParamInt 1 Inf Inf 1
## 14: num.trees ParamInt 1 Inf Inf 500
## 15: oob.error ParamLgl NA NA 2 TRUE
## 16: regularization.factor ParamUty NA NA Inf 1
## 17: regularization.usedepth ParamLgl NA NA 2 FALSE
## 18: replace ParamLgl NA NA 2 TRUE
## 19: respect.unordered.factors ParamFct NA NA 3 ignore
## 20: sample.fraction ParamDbl 0 1 Inf <NoDefault[3]>
## 21: save.memory ParamLgl NA NA 2 FALSE
## 22: scale.permutation.importance ParamLgl NA NA 2 FALSE
## 23: se.method ParamFct NA NA 2 infjack
## 24: seed ParamInt -Inf Inf Inf
## 25: split.select.weights ParamDbl 0 1 Inf <NoDefault[3]>
## 26: splitrule ParamFct NA NA 2 gini
## 27: verbose ParamLgl NA NA 2 TRUE
## 28: write.forest ParamLgl NA NA 2 TRUE
## id class lower upper nlevels default
## parents value
## 1:
## 2:
## 3:
## 4:
## 5: permutation
## 6:
## 7:
## 8:
## 9:
## 10:
## 11:
## 12: splitrule
## 13: 1
## 14:
## 15:
## 16:
## 17:
## 18:
## 19:
## 20:
## 21:
## 22: importance
## 23:
## 24:
## 25:
## 26:
## 27:
## 28:
## parents value
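# Note: assigning a list to $values replaces all previously set values
# (including importance = "permutation" from above) instead of adding to them
learner_rf$param_set$values = list(verbose = FALSE)
There are two ways to choose a learner's parameters. If we have prior knowledge of how the learner should be (hyper)parameterized, we can simply enter the values in its parameter set. In most cases, however, we want to tune the learner so that it searches for "good" model configurations by itself. For now, we only want to compare a few models.
To see which parameters can be set, we can consult the documentation of the underlying package or inspect the learner's parameter set: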
as.data.table(learner_rf$param_set)[,.(id, class, lower, upper)]
For random forests, two meaningful parameters that control model complexity are num.trees and mtry. num.trees defaults to 500, and mtry to floor(sqrt(ncol(data) - 1)), which is 4 in our case.
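The default value of 4 follows from the 20 features of the task (21 columns minus the target credit_risk):

```r
n_features <- 21 - 1              # all columns except credit_risk
default_mtry <- floor(sqrt(n_features))
default_mtry                      # 4, since sqrt(20) is about 4.47
```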
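Below we train three different learners:
- a random forest with default settings;
- a random forest with low num.trees and low mtry;
- a random forest with high num.trees and high mtry.
We benchmark their performance on the German credit data. To do so, we construct the three learners and set their parameters accordingly: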
rf_med = lrn("classif.ranger", id = "med", predict_type = "prob")
rf_low = lrn("classif.ranger", id = "low", predict_type = "prob",
num.trees = 5, mtry = 2)
rf_high = lrn("classif.ranger", id = "high", predict_type = "prob",
num.trees = 1000, mtry = 11)
Once the learners are defined, we can benchmark them:
learners = list(rf_low, rf_med, rf_high)
bm_design = benchmark_grid(
tasks = task,
learners = learners,
resamplings = rsmp("cv", folds = 10)
)
bmr = benchmark(bm_design)
print(bmr)
## <BenchmarkResult> of 30 rows with 3 resampling runs
## nr task_id learner_id resampling_id iters warnings errors
## 1 germancredit low cv 10 0 0
## 2 germancredit med cv 10 0 0
## 3 germancredit high cv 10 0 0
We compare the misclassification rate and AUC of the learners:
measures = msrs(c("classif.ce", "classif.auc"))
performances = bmr$aggregate(measures)
performances[, .(learner_id, classif.ce, classif.auc)]
autoplot(bmr)
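The "low" setting appears to underfit somewhat, and the "high" setting shows a larger standard deviation than the default "med" setting. Among the three configurations compared here, the model with default parameters therefore performs best.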
Session info
## R version 4.0.3 (2020-10-10)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Catalina 10.15.7
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] mlr3viz_0.5.3 mlr3learners_0.4.5 mlr3_0.11.0 data.table_1.14.0
## [5] forcats_0.5.1 stringr_1.4.0 dplyr_1.0.5 purrr_0.3.4
## [9] readr_1.4.0 tidyr_1.1.3 tibble_3.1.1 ggplot2_3.3.3
## [13] tidyverse_1.3.0
##
## loaded via a namespace (and not attached):
## [1] httr_1.4.2 sass_0.3.1 jsonlite_1.7.2
## [4] modelr_0.1.8 bslib_0.2.4 assertthat_0.2.1
## [7] lgr_0.4.2 highr_0.9 cellranger_1.1.0
## [10] yaml_2.2.1 mlr3misc_0.8.0 globals_0.14.0
## [13] pillar_1.6.0 backports_1.2.1 lattice_0.20-41
## [16] glue_1.4.2 uuid_0.1-4 digest_0.6.27
## [19] checkmate_2.0.0 rvest_1.0.0 colorspace_2.0-0
## [22] htmltools_0.5.1.1 Matrix_1.2-18 pkgconfig_2.0.3
## [25] mlr3measures_0.3.1 broom_0.7.6.9001 listenv_0.8.0
## [28] haven_2.3.1 scales_1.1.1 ranger_0.12.1
## [31] farver_2.0.3 generics_0.1.0 ellipsis_0.3.1
## [34] withr_2.4.1 repr_1.1.3 skimr_2.1.3
## [37] cli_2.4.0 magrittr_2.0.1 crayon_1.4.1
## [40] readxl_1.3.1 paradox_0.7.1 evaluate_0.14
## [43] future_1.21.0 fs_1.5.0 fansi_0.4.2
## [46] parallelly_1.24.0 xml2_1.3.2 palmerpenguins_0.1.0
## [49] tools_4.0.3 hms_1.0.0 lifecycle_1.0.0
## [52] munsell_0.5.0 reprex_2.0.0 compiler_4.0.3
## [55] jquerylib_0.1.3 rlang_0.4.10 grid_4.0.3
## [58] rstudioapi_0.13 base64enc_0.1-3 labeling_0.4.2
## [61] rmarkdown_2.7 codetools_0.2-16 gtable_0.3.0
## [64] DBI_1.1.1 R6_2.5.0 lubridate_1.7.9.2
## [67] knitr_1.33 future.apply_1.7.0 utf8_1.2.1
## [70] stringi_1.5.3 parallel_4.0.3 Rcpp_1.0.6
## [73] vctrs_0.3.7 dbplyr_2.1.0 tidyselect_1.1.0
## [76] xfun_0.22