本文目录
- 梯度下降算法
- 代码:
- 结果:
- 随机梯度下降SGD
- 代码:
- 结果:
- 二者区别
- 鞍点
- 学习资料:
- 系列文章索引
梯度下降算法
通过计算梯度就可以知道 w 的移动方向,应该让 w 向右走而不是向左走,也可以知道什么时候会到达最低点(梯度为0的地方)。此处引入一个学习率α,可以控制走的快慢,一般训练学习率α不能太大也不能太小,太小的话,可能导致迟迟走不到最低点,太大的话,会导致错过最低点!
代码:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0
# 定义线性模型
def forward(x):
return x * w
# 定义所有样本的平均平方误差
def cost(xs, ys):
cost = 0
for x, y in zip(xs, ys):
y_pred = forward(x)
cost += (y_pred - y) ** 2
return cost / len(xs)
# 定义梯度函数
def gradient(xs, ys):
grad = 0
for x, y in zip(xs, ys):
grad += 2 * x * (x * w - y)
return grad / len(xs)
epoch_list = []
loss_list = []
print('Predict (before training)', 4, forward(4))
for epoch in range(100):
cost_val = cost(x_data, y_data)
grad_val = gradient(x_data, y_data)
w -= 0.01 * grad_val # 0.01是学习率
print('Epoch:', epoch, 'w=', w, 'loss=', cost_val)
epoch_list.append(epoch)
loss_list.append(cost_val)
print('Predict (after training)', 4, forward(4))
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
结果:
随机梯度下降SGD
随机梯度下降就是随机选择一个样本计算loss
代码:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0
# 定义线性模型
def forward(x):
return x * w
# 定义单个样本的损失
def loss(x, y):
y_pred = forward(x)
return (y_pred - y) ** 2
# 定义梯度函数
def gradient(x, y):
return 2 * x * (x * w - y)
epoch_list = []
loss_list = []
print('Predict (before training)', 4, forward(4))
for epoch in range(100):
for x, y in zip(x_data, y_data):
grad = gradient(x, y) #单个样本梯度
w -= 0.01 * grad #权重更新
print("\tgrad:", x, y, grad)
l = loss(x, y)
print("progress", epoch, "w=", w, "loss=", l)
print("\n")
epoch_list.append(epoch)
loss_list.append(l)
print('Predict (after training)', 4, forward(4))
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
结果:
二者区别
- cost是所有样本算出来的损失,loss是随机的一个样本损失,引入了一个随机噪声,随机噪声可能会把我们向前推动,那么将来我们在更新的过程中,就有可能跨越鞍点,从而向最优值前进,而普通梯度下降可能只会停留在局部最优点停滞不前
- 随机梯度下降只需要求出一个样本均值 ,而普通梯度下降是计算所有数据
- 普通的是需要把计算所得的均值进行相加,随机梯度下降是对每一个样本来求梯度,然后进行更新
- 随机梯度算法更新的次数大于普通梯度算法
鞍点
一个不是局部极值点的驻点称为鞍点。
鞍点这词来自于不定二次型x2-y2的二维图形,像马鞍:x-轴方向往上曲,在y-轴方向往下曲。