Complete notes and code: https://gitee.com/yinuo112/AI/tree/master/机器学习/嘿马机器学习(算法篇)/note.md
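Linear Regression
Learning objectives
- Master the implementation process of linear regression
- Apply LinearRegression or SGDRegressor to make regression predictions
- Know the evaluation metrics for regression algorithms and their formulas
- Know the causes of overfitting and underfitting and how to address them
- Know the principle of ridge regression and how it differs from linear regression
- Apply Ridge to make regression predictions
- Apply joblib to save and load models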
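2.5 Introduction to Gradient Descent Methods
Learning objectives
- Master the derivation of gradient descent
- Know the principle of full gradient descent
- Know the principle of stochastic gradient descent
- Know the principle of stochastic average gradient descent
- Know the principle of mini-batch gradient descent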
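The previous section introduced the basic workflow of gradient descent. This section walks through the derivation of the algorithm in detail and introduces the common gradient descent variants.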
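1 Gradient descent in detail
1.1 Review of concepts related to gradient descent
Before looking at the gradient descent algorithm in detail, let us review a few related concepts.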
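- Step size (learning rate):
  - **The step size determines how far each iteration of gradient descent moves along the negative gradient direction.** In the earlier downhill-walking analogy, the step size is the length of the step taken from the current position in the steepest, easiest direction of descent.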
- Feature:
  - The input part of a sample. For example, given two single-feature samples $(x^{(0)},y^{(0)}),(x^{(1)},y^{(1)})$, the feature of the first sample is $x^{(0)}$ and its output is $y^{(0)}$.
- Hypothesis function:
  - **In supervised learning, the hypothesis function is the function used to fit the input samples, written $h_\theta(x)$.** For example, for $m$ single-feature samples $(x^{(i)},y^{(i)})\ (i=1,2,\dots,m)$, one can use the fitting function $h_\theta(x)=\theta_0+\theta_1x$.
- Loss function:
  - To evaluate how well a model fits, **a loss function is usually used to measure the quality of the fit.** Minimizing the loss function gives the best fit, and the corresponding model parameters are the optimal parameters.
  - In linear regression, the loss function is usually the square of the difference between the sample output and the hypothesis function. For example, for $m$ samples $(x_i,y_i)\ (i=1,2,\dots,m)$ fitted by linear regression, the loss function is:

$$J(\theta_0,\theta_1)=\sum_{i=1}^{m}\left(h_\theta(x_i)-y_i\right)^2$$

where $x_i$ is the feature of the $i$-th sample, $y_i$ is the output of the $i$-th sample, and $h_\theta(x_i)$ is the hypothesis function.
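The hypothesis and loss functions reviewed above can be sketched in a few lines of Python. This is a minimal illustration that assumes NumPy is available; the function names `hypothesis` and `squared_loss` and the toy data are hypothetical and are not part of the original notes.

```python
import numpy as np

def hypothesis(theta0, theta1, x):
    """Single-feature linear hypothesis h_theta(x) = theta0 + theta1 * x."""
    return theta0 + theta1 * x

def squared_loss(theta0, theta1, x, y):
    """Sum of squared differences between the hypothesis and the sample outputs."""
    residuals = hypothesis(theta0, theta1, x) - y
    return float(np.sum(residuals ** 2))

# Hypothetical toy data generated by y = 1 + 2x, so (theta0, theta1) = (1, 2) fits exactly.
x = np.array([0.0, 1.0, 2.0, 3.0])
y = np.array([1.0, 3.0, 5.0, 7.0])

loss_at_optimum = squared_loss(1.0, 2.0, x, y)  # every residual is zero
loss_at_zero = squared_loss(0.0, 0.0, x, y)     # 1 + 9 + 25 + 49
```

The parameters that generated the data give the smallest possible loss, while the all-zero parameters give a much larger one. This is the sense in which minimizing the loss corresponds to the best fit.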
1.2 Derivation of gradient descent
1) Prerequisite: determine the hypothesis function and the loss function of the model to be optimized.
For linear regression, for example, the hypothesis function is $h_\theta(x_1,x_2,\dots,x_n)=\theta_0+\theta_1x_1+\dots+\theta_nx_n$, where $\theta_i\ (i=0,1,2,\dots,n)$ are the model parameters and $x_i\ (i=1,2,\dots,n)$ are the $n$ feature values of each sample. This notation can be simplified by adding a feature $x_0=1$, so that

$$h_\theta(x_0,x_1,\dots,x_n)=\sum_{i=0}^{n}\theta_ix_i$$
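For the same linear regression, the loss function corresponding to the hypothesis function above is:

$$J(\theta_0,\theta_1,\dots,\theta_n)=\frac{1}{2m}\sum_{j=1}^{m}\left(h_\theta(x_0^{(j)},x_1^{(j)},\dots,x_n^{(j)})-y_j\right)^2$$

Here the sum of squares is scaled by $\frac{1}{2m}$, a constant factor that simplifies the derivative and leaves the minimizer unchanged.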
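The simplified notation with $x_0=1$ maps directly onto a matrix-vector product. The sketch below assumes NumPy and assumes the squared loss is scaled by $\frac{1}{2m}$ (a common convention, taken here as an assumption). The helper names `add_bias_column`, `hypothesis`, and `loss` and the sample values are hypothetical.

```python
import numpy as np

def add_bias_column(X):
    """Prepend the constant feature x_0 = 1 to every sample."""
    return np.hstack([np.ones((X.shape[0], 1)), X])

def hypothesis(theta, X_b):
    """h_theta for all samples at once: each row gives sum_i theta_i * x_i."""
    return X_b @ theta

def loss(theta, X_b, y):
    """Squared loss scaled by 1/(2m); the scaling is an assumed convention."""
    m = X_b.shape[0]
    residuals = hypothesis(theta, X_b) - y
    return float(residuals @ residuals) / (2 * m)

# Hypothetical data: 3 samples with n = 2 features each.
X = np.array([[1.0, 2.0],
              [2.0, 0.0],
              [3.0, 1.0]])
y = np.array([0.5, 2.5, 2.5])
theta = np.array([0.5, 1.0, -1.0])  # theta_0, theta_1, theta_2

X_b = add_bias_column(X)
predictions = hypothesis(theta, X_b)
current_loss = loss(theta, X_b, y)
```

Treating $\theta_0$ as the weight of a constant feature means the intercept needs no special handling in the prediction or in later gradient computations.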
2) Initialize the algorithm parameters.
This mainly means initializing $\theta_0,\theta_1,\dots,\theta_n$, the termination distance $\varepsilon$, and the step size $\alpha$. Without any prior knowledge, I like to initialize every $\theta$ to 0 and the step size to 1, then adjust them during tuning.
3) Algorithm procedure:
3.1) Determine the gradient of the loss function at the current position. For $\theta_i$, the gradient expression is:
$$\frac{\partial}{\partial\theta_i}J(\theta_0,\theta_1,\dots,\theta_n)$$
3.2) Multiply the gradient of the loss function by the step size to get the distance to descend from the current position, i.e.
$$\alpha\frac{\partial}{\partial\theta_i}J(\theta_0,\theta_1,\dots,\theta_n)$$
This corresponds to a single step in the earlier downhill example.
3.3) Check whether the descent distance is smaller than $\varepsilon$ for every $\theta_i$. If it is, the algorithm terminates and the current $\theta_i\ (i=0,1,\dots,n)$ are the final result. Otherwise, go to step 4.
4) Update all $\theta$. For $\theta_i$, the update expression is given below. After the update, return to step 3.1.
$$\theta_i=\theta_i-\alpha\frac{\partial}{\partial\theta_i}J(\theta_0,\theta_1,\dots,\theta_n)$$
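The loop in steps 3.1 to 4 can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: the objective `grad_J`, the function name `gradient_descent`, and the values `alpha=0.1`, `eps=1e-6`, and `max_iter` are hypothetical choices, not part of the original tutorial. A step size of 1 would oscillate on this particular objective, which is why a smaller value is used here.

```python
import numpy as np

def gradient_descent(grad, theta0, alpha=0.1, eps=1e-6, max_iter=10_000):
    """Generic gradient descent that stops once every coordinate's step is below eps."""
    theta = np.asarray(theta0, dtype=float)
    for it in range(max_iter):
        step = alpha * grad(theta)        # steps 3.1 and 3.2: step size times gradient
        if np.all(np.abs(step) < eps):    # step 3.3: termination check on every theta_i
            return theta, it
        theta = theta - step              # step 4: update all theta_i at once
    return theta, max_iter

# Hypothetical objective J(theta) = (theta_0 - 3)^2 + (theta_1 + 1)^2
def grad_J(theta):
    return 2.0 * (theta - np.array([3.0, -1.0]))

theta_hat, n_iter = gradient_descent(grad_J, theta0=[0.0, 0.0])
print(theta_hat, n_iter)
```

All coordinates are updated from the same gradient evaluation, which matches the requirement that every $\theta_i$ is updated together before the next round.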
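Let us now describe gradient descent concretely with the linear regression example. Suppose our samples are:
$$(x_1^{(1)},x_2^{(1)},\dots,x_n^{(1)},y_1),\ (x_1^{(2)},x_2^{(2)},\dots,x_n^{(2)},y_2),\ \dots,\ (x_1^{(m)},x_2^{(m)},\dots,x_n^{(m)},y_m)$$
The loss function, as stated in the prerequisites above, is:
$$J(\theta_0,\theta_1,\dots,\theta_n)=\frac{1}{2m}\sum_{j=1}^{m}\left(h_\theta(x_0^{(j)},x_1^{(j)},\dots,x_n^{(j)})-y_j\right)^2$$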
Then, in step 3.1 of the algorithm procedure, the partial derivative with respect to $\theta_i$ is:
$$\frac{\partial}{\partial\theta_i}J(\theta_0,\theta_1,\dots,\theta_n)=\frac{1}{m}\sum_{j=1}^{m}\left(h_\theta(x_0^{(j)},x_1^{(j)},\dots,x_n^{(j)})-y_j\right)x_i^{(j)}$$
Since the samples contain no $x_0$, every $x_0^{(j)}$ in the expression above is set to 1.
In step 4, the update expression for $\theta_i$ is:
$$\theta_i=\theta_i-\alpha\frac{1}{m}\sum_{j=1}^{m}\left(h_\theta(x_0^{(j)},x_1^{(j)},\dots,x_n^{(j)})-y_j\right)x_i^{(j)}$$
This example shows that the gradient direction at the current point is determined by all of the samples. The factor $\frac{1}{m}$ is included only to make the expression easier to understand. Because the step size is also a constant, their product is a constant as well, so $\alpha\frac{1}{m}$ can be represented by a single constant here.
The next section covers the variants of gradient descent in detail. They differ mainly in how they use the samples. Here we use all of the samples.
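As a rough sketch of the worked example, the all-samples update for linear regression might look like the following. The function name `fit_linear_fg`, the synthetic data, and the values of `alpha`, `eps`, and `max_iter` are assumptions made for illustration. The column of ones plays the role of $x_0^{(j)}=1$.

```python
import numpy as np

def fit_linear_fg(X, y, alpha=0.1, eps=1e-8, max_iter=50_000):
    """Linear regression by gradient descent using all m samples per update."""
    m = X.shape[0]
    Xb = np.hstack([np.ones((m, 1)), X])   # prepend x_0 = 1 for the intercept theta_0
    theta = np.zeros(Xb.shape[1])          # initialise every theta_i to 0
    for _ in range(max_iter):
        residual = Xb @ theta - y          # h_theta(x^(j)) - y_j for every sample j
        grad = Xb.T @ residual / m         # (1/m) * sum_j residual_j * x_i^(j)
        step = alpha * grad
        if np.all(np.abs(step) < eps):     # stop when every step is below eps
            break
        theta -= step
    return theta

# Hypothetical noise-free data: y = 4 + 2*x_1 - 3*x_2
rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(200, 2))
y = 4.0 + 2.0 * X[:, 0] - 3.0 * X[:, 1]
theta_fg = fit_linear_fg(X, y)
print(theta_fg)
```

The vectorised expression `Xb.T @ residual / m` computes the partial derivative for every $\theta_i$ at once, so no explicit loop over parameters is needed.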
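2 The gradient descent family
First, let us look at the common gradient descent algorithms:
- Full gradient descent (FG)
- Stochastic gradient descent (SG)
- Mini-batch gradient descent (mini-batch)
- Stochastic average gradient descent (SAG)
All of them adjust the weight vector by computing a gradient for each weight and updating the weights so that the objective function becomes as small as possible. They differ in how they use the samples.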
2.1 Full gradient descent (FG)
Full gradient descent, also called batch gradient descent, is the most commonly used form of gradient descent: every parameter update uses all of the samples.
It computes the error of every sample in the training set, then sums and averages these errors to form the objective function.
The weight vector moves in the direction opposite to the gradient, which is the direction in which the current objective function decreases fastest.
The gradient of the loss function with respect to the parameters $\theta$ is computed over the entire training set:
$$\theta_i=\theta_i-\alpha\frac{1}{m}\sum_{j=1}^{m}\left(h_\theta(x_0^{(j)},x_1^{(j)},\dots,x_n^{(j)})-y_j\right)x_i^{(j)}$$
Because there are m samples, the gradient here uses the gradient data of all m samples.
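Note:
- Each update requires computing the gradients over the whole data set, so batch gradient descent is slow. It also cannot handle data sets that exceed the available memory.
- Batch gradient descent cannot update the model online either, that is, new samples cannot be added while it is running.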
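2.2 Stochastic gradient descent (SG)
Every FG update has to compute the error of all samples, and real problems often have hundreds of millions of training samples, so FG is inefficient and can easily get stuck in a local optimum. Stochastic gradient descent was proposed to address this.
In each round the objective is no longer the error over all samples but the error of a single sample: the gradient of one sample's objective is used to update the weights, then the next sample is taken and the process repeats, until the loss stops decreasing or falls below a tolerable threshold.
The procedure is simple and efficient, and it usually does a fairly good job of keeping the iterations from converging to a local optimum. Its iteration form is:
$$\theta_i=\theta_i-\alpha\left(h_\theta(x_0^{(j)},x_1^{(j)},\dots,x_n^{(j)})-y_j\right)x_i^{(j)}$$
However, because SG uses only one sample per iteration, a noisy sample can pull the update in a poor direction and leave the algorithm stuck in a local optimum.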
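The single-sample update can be sketched as below. The function name `fit_linear_sg`, the per-epoch shuffling, the loss threshold `tol`, and all hyperparameter values are illustrative assumptions rather than the tutorial's own code. The stopping rule follows the text: training ends when the loss drops below a tolerable threshold, or after a fixed number of passes.

```python
import numpy as np

def fit_linear_sg(X, y, alpha=0.05, n_epochs=200, tol=1e-10, seed=0):
    """Linear regression by stochastic gradient descent: one sample per update."""
    rng = np.random.default_rng(seed)
    m = X.shape[0]
    Xb = np.hstack([np.ones((m, 1)), X])        # x_0 = 1 for the intercept
    theta = np.zeros(Xb.shape[1])
    for _ in range(n_epochs):
        for j in rng.permutation(m):            # visit the samples in random order
            error = Xb[j] @ theta - y[j]        # h_theta(x^(j)) - y_j for one sample
            theta -= alpha * error * Xb[j]      # update every theta_i from that sample
        loss = np.mean((Xb @ theta - y) ** 2) / 2.0
        if loss < tol:                          # stop once the loss is small enough
            break
    return theta

# Hypothetical noise-free data: y = 4 + 2*x_1 - 3*x_2
rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(200, 2))
y = 4.0 + 2.0 * X[:, 0] - 3.0 * X[:, 1]
theta_sg = fit_linear_sg(X, y)
print(theta_sg)
```

Each inner-loop update costs the same regardless of m, which is the efficiency gain over FG. On noisy data the iterates would keep jittering around the optimum instead of settling exactly.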
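2.3 Mini-batch gradient descent (mini-batch)
Mini-batch gradient descent is a compromise between FG and SG, and to some extent combines the advantages of both.
At each step it randomly draws a small subset from the training set and applies an FG update to the weights on that subset.
The number of samples in the subset is called batch_size. It is usually set to a power of 2, which suits GPU acceleration.
In particular, batch_size=1 gives SG and batch_size=m gives FG. Its iteration form is:
$$\theta_i=\theta_i-\frac{\alpha}{x}\sum_{j=t}^{t+x-1}\left(h_\theta(x_0^{(j)},x_1^{(j)},\dots,x_n^{(j)})-y_j\right)x_i^{(j)}$$
In the expression above, x samples out of the m samples are selected for each iteration (1<x<m).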
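A minimal mini-batch sketch follows, assuming the gradient is averaged over the batch and that batches come from a reshuffled ordering each epoch. The function name `fit_linear_minibatch`, the synthetic data, and the hyperparameters are hypothetical.

```python
import numpy as np

def fit_linear_minibatch(X, y, alpha=0.1, batch_size=32, n_epochs=200, seed=0):
    """Linear regression by mini-batch gradient descent."""
    rng = np.random.default_rng(seed)
    m = X.shape[0]
    Xb = np.hstack([np.ones((m, 1)), X])             # x_0 = 1 for the intercept
    theta = np.zeros(Xb.shape[1])
    for _ in range(n_epochs):
        order = rng.permutation(m)                   # reshuffle before every pass
        for start in range(0, m, batch_size):
            idx = order[start:start + batch_size]    # indices of the current batch
            residual = Xb[idx] @ theta - y[idx]
            grad = Xb[idx].T @ residual / len(idx)   # FG-style gradient on the batch
            theta -= alpha * grad
    return theta

# Hypothetical noise-free data: y = 4 + 2*x_1 - 3*x_2
rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(200, 2))
y = 4.0 + 2.0 * X[:, 0] - 3.0 * X[:, 1]
theta_mb = fit_linear_minibatch(X, y)
print(theta_mb)
```

Setting `batch_size=1` makes each update use a single sample, as in SG, and `batch_size=len(X)` makes each update use every sample, as in FG.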
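2.4 Stochastic average gradient descent (SAG)
SG avoids the high computational cost, but on large training sets its results are often unsatisfactory, because each gradient update is completely unrelated to the data and gradients of the previous round.
Stochastic average gradient descent overcomes this problem. It keeps an old gradient in memory for every sample, randomly picks the i-th sample and refreshes only that sample's gradient while the other stored gradients stay unchanged, then takes the average of all stored gradients and uses it to update the parameters.
In this way each update computes the gradient of only one sample, so the cost is the same as SG, but convergence is much faster.
Its iteration form is: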
$$\theta_i=\theta_i-\frac{\alpha}{n}\left(h_\theta(x_0^{(j)},x_1^{(j)},\dots,x_n^{(j)})-y_j\right)x_i^{(j)}$$
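The per-sample gradient memory described above can be sketched as follows. This assumes the variant that stores one gradient for each of the m samples and averages over all of them. The function name `fit_linear_sag`, the zero-initialised gradient table, and the hyperparameters are illustrative assumptions, not the tutorial's own implementation.

```python
import numpy as np

def fit_linear_sag(X, y, alpha=0.02, n_iters=60_000, seed=0):
    """Linear regression by stochastic average gradient (per-sample gradient memory)."""
    rng = np.random.default_rng(seed)
    m = X.shape[0]
    Xb = np.hstack([np.ones((m, 1)), X])      # x_0 = 1 for the intercept
    theta = np.zeros(Xb.shape[1])
    grad_table = np.zeros_like(Xb)            # one stored (old) gradient per sample
    grad_sum = np.zeros(Xb.shape[1])          # running sum of all stored gradients
    for _ in range(n_iters):
        j = rng.integers(m)                   # pick one sample at random
        new_grad = (Xb[j] @ theta - y[j]) * Xb[j]
        grad_sum += new_grad - grad_table[j]  # swap the old gradient for the new one
        grad_table[j] = new_grad
        theta -= alpha * grad_sum / m         # step along the average stored gradient
    return theta

# Hypothetical noise-free data: y = 4 + 2*x_1 - 3*x_2
rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(200, 2))
y = 4.0 + 2.0 * X[:, 0] - 3.0 * X[:, 1]
theta_sag = fit_linear_sag(X, y)
print(theta_sag)
```

Keeping a running sum means each iteration evaluates only one new gradient, so the per-update cost matches SG, at the price of storing m gradients in memory.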
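- We know that SGD obtains the new weights by subtracting the step size times the gradient from the current weights. The "A" in SAG stands for "average": at the k-th iteration, the update averages the gradient of the current step with the previous n-1 gradients, so the new weights are the current weights minus the step size times the mean of the most recent n gradients.
- n is chosen by the user. When n=1, this is plain SGD.
- The idea is very simple: it adds determinism to the randomness, which is similar in effect to mini-batch SGD. The difference is that SAG does not evaluate more samples. It only reuses gradients that were computed earlier, so its per-iteration cost is far lower than that of mini-batch SGD and comparable to SGD. In terms of results, SAG converges much faster than SGD. The paper below describes and proves this in detail.
- SAG paper link: [
- Further reading: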
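3 Summary
- Full gradient descent (FG) [know]
  - Computes the average error over all samples and uses it as the objective function
- Stochastic gradient descent (SG) [know]
  - Uses only one sample for each update
- Mini-batch gradient descent (mini-batch) [know]
  - Uses a subset of the samples for each update
- Stochastic average gradient descent (SAG) [know]
  - Keeps a stored gradient for every sample and uses the average of the stored gradients in later updates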