BP神经网络传播计算

1. 简介

BP神经网络是一种常用的人工神经网络模型,常用于数据分类、函数拟合等任务。在本文中,我将向你介绍BP神经网络的传播计算过程,并提供相应的代码示例。

2. 流程

下面是BP神经网络传播计算的整体流程:

步骤 描述
1 初始化网络参数(权重和偏置)
2 输入样本,计算神经网络的输出
3 计算网络输出与期望输出之间的误差
4 根据误差调整网络参数
5 重复步骤2-4直到达到停止条件

3. 代码示例

3.1 初始化网络参数

在这一步中,我们需要为神经网络的每个连接(权重)和每个神经元(偏置)赋予一个初始值。以下是使用Python实现的示例代码:

import numpy as np

def initialize_parameters(layer_dims):
    parameters = {}
    L = len(layer_dims) - 1  # 网络的层数

    for l in range(1, L + 1):
        parameters['W' + str(l)] = np.random.randn(layer_dims[l], layer_dims[l-1]) * 0.01
        parameters['b' + str(l)] = np.zeros((layer_dims[l], 1))

    return parameters

3.2 计算神经网络的输出

在这一步中,我们需要将输入样本通过神经网络前向传播,并计算出网络的输出。以下是使用Python实现的示例代码:

def forward_propagation(X, parameters):
    A = X
    caches = []
    L = len(parameters) // 2  # 网络的层数

    # 前L-1层使用ReLU激活函数
    for l in range(1, L):
        W = parameters['W' + str(l)]
        b = parameters['b' + str(l)]
        Z = np.dot(W, A) + b
        A = np.maximum(0, Z)
        cache = (W, b, Z, A)
        caches.append(cache)

    # 最后一层使用sigmoid激活函数
    WL = parameters['W' + str(L)]
    bL = parameters['b' + str(L)]
    ZL = np.dot(WL, A) + bL
    AL = 1 / (1 + np.exp(-ZL))
    cacheL = (WL, bL, ZL, AL)
    caches.append(cacheL)

    return AL, caches

3.3 计算误差

在这一步中,我们需要计算神经网络输出与期望输出之间的误差。对于二分类任务,可以使用交叉熵作为损失函数。以下是使用Python实现的示例代码:

def compute_cost(AL, Y):
    m = Y.shape[1]
    cost = -np.sum(Y * np.log(AL) + (1 - Y) * np.log(1 - AL)) / m
    cost = np.squeeze(cost)  # 从数组中删除所有单维度的条目,以符合后面的期望形状

    return cost

3.4 调整网络参数

在这一步中,我们需要根据误差计算梯度,并使用梯度下降算法来更新网络参数。以下是使用Python实现的示例代码:

def backward_propagation(AL, Y, caches):
    grads = {}
    L = len(caches)  # 网络的层数
    m = AL.shape[1]

    dAL = - (np.divide(Y, AL) - np.divide(1 - Y, 1 - AL))

    # 反向传播最后一层(sigmoid激活函数)
    cacheL = caches[L - 1]
    WL, bL, ZL, AL = cacheL
    dZL = AL - Y
    dWL = np.dot(dZL, caches[L - 2][3].T) / m
    dbL = np.sum(dZL, axis=1, keepdims=True) / m
    dAL = np.dot(WL.T, dZL)

    grads['dW' + str(L)] = dWL
    grads['db' + str(L)] = dbL