在上一篇笔记中我们一起学习了单层感知器的原理,但针对于异或问题。我们的单层神经网络就束手无策了
*异或运算: 0 0 输出为0,0 1输出为1, 1 0输出为 1, 1 1输出为 0
即,针对以下四个点(0,0)(0,1)(1,0)(1,1)对应标签0 ,1,1,0时,我们无法完成分类。
运算结果如图所示
这时候我们就要引入一个新的东西-线性神经网络
我们都知道感知器的激活函数时sign函数,即
该函数只有两个输出值,1和 -1。而线性神经网络的激活函数为 y = x, 即输出值可以为任意实数
线性神经网络采用的时LMS算法来调整网络权值和偏置
LMS算法简介
该算法的学习信号为
与感知器的非常相似,只是这里的实际输出为Wj .T*X ,相比于感知器,这里少了sign函数。其他则都差不太多。
线性神经网络结构
这里在输出的时候同时运用了线性函数和sign函数。为什么呢?
因为在运算时我们可以通过线性函数得到更好的结果,而输出的时候,因为我们的标签只有两个。假设我们不用sign函数,那输出结果可能有0.2 0.8 -0.3 -0.7,这样就无法分类了
为了完成我们的异或问题分类任务,我们先看下预备知识:
Delta学习规则:该学习规则也被称为来纳许感知器学习规则,Delta学习规则是利用梯度下降法的一般性学习规则。
代价函数(损失函数 Lost Function):
其中误差E是权值 Wj 的函数,若思想让误差E最小,Wj 则应与误差的负梯度成正比,即
有代价函数,可推导出误差梯度为
手写推导过程如下:
梯度下降法:
针对梯度下降法的讨论在这里不做解释了,不太理解的朋友可以看下这个
梯度下降(Gradient Descent)小结www.cnblogs.com
针对梯度下降法的问题也有几个,1.学习率难以选取 2.容易进入局部最优解。这篇文章中不对局部最优解问题进行讨论。大家知道这个概念就可以了
线性网络解决非线性问题:
通过对神经元添加非线性输入,从而引入非线性成分,使等效的输入维度变大。即从X1,X2扩展为 X1,X1^2,X1*X2,X2^2,X2
代码实现:
通过修改单层感知器的代码就可以做出线性网络了。
对于四个点(0,0)(0,1)(1,0)(1,1)根据其标签[0,1,1,0]进行分类
这里X的输入为[X0,X1,X2,X1^2,X1*X2,X2^2] ,对应的权值也随机了6个。
对于实际输出O进行了修改,去掉了sign函数。
样本修改为四个点。
之后我们要设定一个函数用来计算两条分割线 def calculate(x,root)
作图中我们要注意一下,这里的分类线不再是直线了
首先我们推导一下算式
(字丑见笑了)
大家都是接受过初中教育的,对于一元二次方程解法应该不陌生,通常一元二次方程有两个解(根),即root。通过计算我们可以看到a b c分别对应了哪些元素。
然后通过求根公式画出线段
最后运行代码
可以看到最后循环了一千次结束的循环(即并没有得到与期望值相同的实际输出)
源码:
import numpy as np
import matplotlib.pyplot as plt
# 输入数据
X = np.array([[1, 0, 0, 0, 0, 0], # 解决异或问题
[1, 0, 1, 0, 0, 1],
[1, 1, 0, 1, 0, 0],
[1, 1, 1, 1, 1, 1]])
# 存为标签(一一对应数据)
Y = np.array([-1, 1, 1, -1])
# 随机权值,三行一列,取值范围(-1,1)
W = (np.random.random(6) - 0.5) * 2
print('W=', W)
# 设置学习率
lr = 0.11
# 设置迭代次数
n = 0
# 设置输出值
O = 0
def update():
global X, Y, W, lr, n
n += 1
O = np.dot(X, W.T)
W_C = lr * ((Y - O.T).dot(X)) / X.shape[0] # 平均权值
W = W_C + W # 修改权值
for _ in range(1000):
update() # 更新权值
print(W)
print(n)
O = np.dot(X, W.T) # 计算神经网络输出
if (O == Y.T).all(): # 如果实际输出等于期望输出,模型收敛
print("finished")
print("epoch:", n)
break
# 正样本(标签为1)
X1 = [1, 0]
Y1 = [0, 1]
# 负样本(标签为0)
X2 = [0, 1]
Y2 = [0, 1]
def calculate(x, root):
a = W[5]
b = W[2] + x * W[4]
c = W[0] + x * W[1] + x * x * W[3]
if root == 1:
return (-b + np.sqrt(b * b - 4 * a * c)) / (2 * a)
if root == 2:
return (-b - np.sqrt(b * b - 4 * a * c)) / (2 * a)
# 作图
xdata = np.linspace(-1, 2)
plt.figure()
plt.plot(xdata, calculate(xdata, 1), 'r')
plt.plot(xdata, calculate(xdata, 2), 'r')
plt.plot(X1, Y1, 'bo')
plt.plot(X2, Y2, 'yo')
plt.show()