神经网络算法是很早就提出的一种方法,挺起来很高大上,其实说白了就是通过各种数据变换,把输入数据和输出数据建立一种联系。如果是函数的拟合,那么这个联系可以理解为一种复杂的函数关系,如果是分类可以理解通过维度变换以及函数变换找到一个能区分两类数据的超平面(多维)。神经网络方法就是通过建立训练数据与目标的联系对新的输入参数进行预测的过程。

BP神经网络是一种按误差反向传播(简称误差反传)训练的多层前馈网络,其算法称为BP算法,它的基本思想是梯度下降法,利用梯度搜索技术,以期使网络的实际输出值和期望输出值的误差均方差为最小。

        整个过程为,首先进行一次正向传播:输入层---->隐含层-->输出层,根据对比输出结果和目标的误差,反向计算层传递权值的误差,调整权值,再进行一次正向传播,反复迭代,实现目标拟合。

以2个参数的三层网络为例子, 通过手动计算一遍整BP神经网络的参数传递过程,来理解整神经网络原理。

神经网络 bp参数设置 bp神经网络算法详解_数据分析


三层BP神经网络结构模型


为了简化计算,将模型进行简化,输入层只有 2 个数据点,隐含层和输出层均为两个节点,结构图变为。


神经网络 bp参数设置 bp神经网络算法详解_数据分析_02

输入数据为x1=0.10,x2=0.88,

目标值为:y1=0.55,y2=1

相互之间的权值全部初始为w1-w8=1

正向传递的公式为: y和x是向量,将式子简化令a=1,b=1,正向传播简化为:y=wx

正向传播:


神经网络 bp参数设置 bp神经网络算法详解_神经网络_03




隐含层部分(包含输出部分以及激活函数的输出):


神经网络 bp参数设置 bp神经网络算法详解_神经网络_04



输出层部分(包含输出部分以及激活函数的输出):

神经网络 bp参数设置 bp神经网络算法详解_数据分析_05


正向传播完成,计算损失函数:


反向传播:


计算 Lost 对 w5 , w6 , w7 , w8 (隐含层 —> 输出层)的偏导数,这几个权重对最终误差产生的影响 , (注意 s ( x )的导数)


神经网络 bp参数设置 bp神经网络算法详解_神经网络 bp参数设置_06



神经网络 bp参数设置 bp神经网络算法详解_神经网络 bp参数设置_07



输入层—>隐含层)

计算Lost对w1,w2,w3,w4的偏导数,这一层参数同时影响两个输出结果

神经网络 bp参数设置 bp神经网络算法详解_数据分析_08

神经网络 bp参数设置 bp神经网络算法详解_神经网络_09



神经网络 bp参数设置 bp神经网络算法详解_数据分析_10



调整权值:

梯度下降方法不断优化参数:θ为下降速率

神经网络 bp参数设置 bp神经网络算法详解_数据分析_11

第一次传播后lost=0.051896,Y1_=0.81064649903,Y2=0.81064649903


迭代10次后,lost=0.0409254668911,迭代100次后lost=0.00504936047507,

迭代1000次后lost=0.000467508796192,输出结果为Y1_=0.55002398472,Y2=0.969421952039,基本拟合

附上Python的计算过程代码,帮助理解计算过程


import numpy as np
def sigmoid(inX):
    return 1.0/(1+np.exp(-inX))
x1 = 0.1
x2 = 0.88
y1 = 0.55
y2 = 1
w1,w2,w3,w4,w5,w6,w7,w8 =1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
for i inrange(1):
#-------------开始计算
    h_in_1 = w1*x1+w2*x2
    h_in_2 = w3*x1+w4*x2
#h_in为传递到隐层的值
    print('h_in_1 = %f,h_in_2 = %f'%(h_in_1,h_in_2))

    h_out_1 = sigmoid(h_in_1)
    h_out_2 = sigmoid(h_in_2)
#h_out从隐层传出的值
    print('h_out_1 = %f,h_out_2 = %f'%(h_out_1,h_out_2))

    o_in_1 = w5*h_out_1+w6*h_out_2
    o_in_2 = w7*h_out_1+w8*h_out_2
#o_in为传入输出层的值
    print('o_in_1 = %f,o_in_2 = %f'%(o_in_1,o_in_2))

    o_out_1 = sigmoid(o_in_1)
    o_out_2 = sigmoid(o_in_2)
    #y1,y2为输出的结果
    print('o_out_1 = %f,o_out_2 = %f'%(o_out_1,o_out_2))

    lost1 = (y1-o_out_1)**2/2
    lost2 = (y2-o_out_2)**2/2
    lost = lost1+lost2
    print('lost = %f'%(lost))

    diff_w5 = -(y1-o_out_1)*o_out_1*(1-o_out_1)*h_out_1
    diff_w6 = -(y1-o_out_1)*o_out_1*(1-o_out_1)*h_out_2
    diff_w7 = -(y2-o_out_2)*o_out_2*(1-o_out_2)*h_out_2
    diff_w8 = -(y2-o_out_2)*o_out_2*(1-o_out_2)*h_out_1

    diff_w1 = (-(y1-o_out_1)*o_out_1*(1-o_out_1)*w5-(y2-o_out_2)*o_out_2*(1-o_out_2)*w7)*(1-h_out_1)*h_out_1*x1
    diff_w2 = (-(y1-o_out_1)*o_out_1*(1-o_out_1)*w5-(y2-o_out_2)*o_out_2*(1-o_out_2)*w7)*(1-h_out_1)*h_out_1*x2
    diff_w3 = (-(y2-o_out_2)*o_out_2*(1-o_out_2)*w6-(y1-o_out_1)*o_out_1*(1-o_out_1)*w8)*(1-h_out_2)*h_out_2*x1
    diff_w4 = (-(y2-o_out_2)*o_out_2*(1-o_out_2)*w6-(y1-o_out_1)*o_out_1*(1-o_out_1)*w8)*(1-h_out_2)*h_out_2*x2
    print('diff_w5 = %f, diff_w6 = %f, diff_w7 = %f, diff_w8= %f'%(diff_w5, diff_w6, diff_w7, diff_w8))
    print('diff_w1 = %f, diff_w2 = %f, diff_w3 =%f, diff_w4 = %f'%(diff_w1, diff_w2, diff_w3, diff_w4))
    #diff_w为w对代价函数的偏导数

    theta = 0.5
    update_w5 = w5-theta*diff_w5
    update_w6 = w6-theta*diff_w6
    update_w7 = w7-theta*diff_w7
    update_w8 = w8-theta*diff_w8
    update_w1 = w1-theta*diff_w1
    update_w2 = w2-theta*diff_w2
    update_w3 = w3-theta*diff_w3
    update_w4 = w4-theta*diff_w4

    print('update_w5 = %f, update_w6 = %f, update_w7 =%f, update_w8 = %f'%(update_w5,update_w6,update_w7,update_w8))
    print('update_w1 = %f, update_w2 = %f, update_w3 =%f, update_w4 = %f'%(update_w1,update_w2,update_w3,update_w4))
    w5 = update_w5
    w6 = update_w6
    w7 = update_w7
    w8 = update_w8
    w1 = update_w1
    w2 = update_w2
    w3 = update_w3
    w4 = update_w4
print(lost)
print(o_out_1, o_out_2)


运行结果:h_in_1 = 0.980000,h_in_2 = 0.980000 

 h_out_1 = 0.727108,h_out_2 = 0.727108 

 o_in_1 = 1.454216,o_in_2 = 1.454216 

 o_out_1 = 0.810646,o_out_2 = 0.810646 

 lost = 0.051896 

 diff_w5 = 0.029091,diff_w6 = 0.029091,diff_w7 = -0.021134,diff_w8 = -0.021134 

 diff_w1 = 0.000217,diff_w2 = 0.001911,diff_w3 = 0.000217,diff_w4 = 0.001911 

 update_w5 = 0.985455,update_w6 = 0.985455,update_w7 = 1.010567,update_w8 = 1.010567 

 update_w1 = 0.999891,update_w2 = 0.999045,update_w3 = 0.999891,update_w4 = 0.999045 

 0.0518956728931 

 0.81064649903 0.81064649903