Code implementation:
1. The first method
The first method has a C++ implementation on zhangchaoyang's blog, although that code only handles scalar data (one-dimensional input and output). Matlab also ships an improved version of this first method (my own impression; you can run open newrb in Matlab and read the source yourself).
The function Matlab provides is newrb(). Its trick is that it can automatically add hidden-layer neurons until either the mean squared error reaches the accuracy we ask for or the neuron count hits its maximum (which is the number of training samples we supply; when the neuron count equals the sample count, the RBF network's mean squared error on the training set is 0). It is also very simple to use:
rbf = newrb(train_x, train_y);
output = rbf(test_x);
Just hand it the training samples and you get an RBF network back; then feed it inputs and it returns the network's outputs.
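To make the zero-training-error claim above concrete, here is a small sketch in plain Python (our own toy code, not newrb's actual algorithm): with one Gaussian unit centred on each training sample, the interpolation matrix is square, and solving it for the output weights reproduces every training target exactly.

```python
import math

def gaussian(x, c, delta):
    # Gaussian radial basis function for scalar inputs
    return math.exp(-(x - c) ** 2 / (2 * delta ** 2))

def solve(A, b):
    """Tiny Gaussian elimination with partial pivoting for a small square system A w = b."""
    n = len(b)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for k in range(col, n + 1):
                M[r][k] -= f * M[col][k]
    w = [0.0] * n
    for r in range(n - 1, -1, -1):
        w[r] = (M[r][n] - sum(M[r][k] * w[k] for k in range(r + 1, n))) / M[r][r]
    return w

train_x = [1.0, 2.0, 3.0, 4.0]
train_y = [2 * x for x in train_x]     # same toy target, y = 2x
delta = 1.0
# one centre per sample -> square interpolation matrix G
G = [[gaussian(x, c, delta) for c in train_x] for x in train_x]
w = solve(G, train_y)
# the fitted network reproduces every training target up to round-off
max_err = max(abs(sum(wj * gaussian(x, c, delta) for wj, c in zip(w, train_x)) - y)
              for x, y in zip(train_x, train_y))
print("max training residual:", max_err)
```

The residual is at floating-point noise level, which is exactly the "as many neurons as samples gives zero error" regime described above.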
2. The second method
The second method also has a C++ implementation on zhangchaoyang's blog, again only for scalar data (one-dimensional input and output). But I work with images, so my network needs to accept high-dimensional input, and in Matlab vectorized operations are much faster than for loops. So I wrote my own version that accepts vector input and vector output and is trained with supervised BP. For the BP algorithm, see: BackpropagationAlgorithm; the main work is computing the residual (error term) of every node in every layer. My code does pass gradient checking, yet on some training sets the cost rises with the iteration count, which is odd; lowering the learning rate did not help. On some simpler training sets it does work, although the training error is still fairly large (it does not fully fit the training samples). So if you spot a mistake in the code, please let me know.
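A minimal sketch of the gradient check mentioned above, written in Python for a toy single-sample, single-unit network (the setup and names are ours, not part of the Matlab code): compare the analytic gradient of the cost with respect to the output weight against a central-difference estimate.

```python
import math

def cost(w, x, c, delta, y):
    # half squared error of a network with one Gaussian hidden unit
    g = math.exp(-((x - c) ** 2) / (2 * delta ** 2))
    return 0.5 * (y - w * g) ** 2

def analytic_grad(w, x, c, delta, y):
    # dJ/dw = -(y - w*g) * g, i.e. delta3 * Green in trainRBF.m
    g = math.exp(-((x - c) ** 2) / (2 * delta ** 2))
    return -(y - w * g) * g

w, x, c, delta, y = 0.3, 1.0, 0.5, 1.0, 2.0
eps = 1e-6
numeric = (cost(w + eps, x, c, delta, y) - cost(w - eps, x, c, delta, y)) / (2 * eps)
gap = abs(numeric - analytic_grad(w, x, c, delta, y))
print("gradient check gap:", gap)
```

If the analytic formula is right, the gap is tiny (far below the step size eps); a mismatch here is the usual first symptom of a backprop bug.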
The main code is below:
learnRBF.m
%// This is a RBF network trained by BP algorithm
%// Author : zouxy
%// Date : 2013-10-28
close all; clear; clc;
%%% ************************************************
%%% ************ step 0: load data ****************
display('step 0: load data...');
% train_x = [1 2 3 4 5 6 7 8]; % each sample arranged as a column of train_x
% train_y = 2 * train_x;
train_x = rand(5, 10);
train_y = 2 * train_x;
test_x = train_x;
test_y = train_y;
%% from matlab
% rbf = newrb(train_x, train_y);
% output = rbf(test_x);
%%% ************************************************
%%% ******** step 1: initialize parameters ********
display('step 1: initialize parameters...');
numSamples = size(train_x, 2);
rbf.inputSize = size(train_x, 1);
rbf.hiddenSize = numSamples; % num of Radial Basis function
rbf.outputSize = size(train_y, 1);
rbf.alpha = 0.1; % learning rate (should not be large!)
%% centre of RBF
% pick distinct samples to initialize the centres of the RBFs
% (sampling without replacement avoids duplicate centres, which would
% make hidden units redundant)
perm = randperm(numSamples);
for i = 1 : rbf.hiddenSize
rbf.center(:, i) = train_x(:, perm(i));
end
%% delta of RBF
rbf.delta = rand(1, rbf.hiddenSize) + 0.5; % keep widths away from zero: the gradients divide by delta^2 and delta^3
%% weight of RBF
r = 1.0; % random number between [-r, r]
rbf.weight = rand(rbf.outputSize, rbf.hiddenSize) * 2 * r - r;
%%% ************************************************
%%% ************ step 2: start training ************
display('step 2: start training...');
maxIter = 400;
preCost = 0;
for i = 1 : maxIter
fprintf(1, 'Iteration %d, ', i);
rbf = trainRBF(rbf, train_x, train_y);
fprintf(1, 'the cost is %f\n', rbf.cost);
curCost = rbf.cost;
if abs(curCost - preCost) < 1e-8
disp('Cost change below tolerance; stopping training.');
break;
end
preCost = curCost;
end
%%% ************************************************
%%% ************ step 3: start testing ************
display('step 3: start testing...');
Green = zeros(rbf.hiddenSize, 1);
output = zeros(rbf.outputSize, size(test_x, 2)); % preallocate network outputs
for i = 1 : size(test_x, 2)
for j = 1 : rbf.hiddenSize
Green(j, 1) = green(test_x(:, i), rbf.center(:, j), rbf.delta(j));
end
output(:, i) = rbf.weight * Green;
end
disp(test_y);
disp(output);
trainRBF.m
function [rbf] = trainRBF(rbf, train_x, train_y)
%%% step 1: calculate gradient
numSamples = size(train_x, 2);
Green = zeros(rbf.hiddenSize, 1);
output = zeros(rbf.outputSize, 1);
delta_weight = zeros(rbf.outputSize, rbf.hiddenSize);
delta_center = zeros(rbf.inputSize, rbf.hiddenSize);
delta_delta = zeros(1, rbf.hiddenSize);
rbf.cost = 0;
for i = 1 : numSamples
%% Feed forward
for j = 1 : rbf.hiddenSize
Green(j, 1) = green(train_x(:, i), rbf.center(:, j), rbf.delta(j));
end
output = rbf.weight * Green;
%% Back propagation
delta3 = -(train_y(:, i) - output);
rbf.cost = rbf.cost + sum(delta3.^2);
delta_weight = delta_weight + delta3 * Green';
delta2 = rbf.weight' * delta3 .* Green;
for j = 1 : rbf.hiddenSize
delta_center(:, j) = delta_center(:, j) + delta2(j) .* (train_x(:, i) - rbf.center(:, j)) ./ rbf.delta(j)^2;
delta_delta(j) = delta_delta(j) + delta2(j) * sum((train_x(:, i) - rbf.center(:, j)).^2) ./ rbf.delta(j)^3;
end
end
%%% step 2: update parameters
rbf.cost = 0.5 * rbf.cost ./ numSamples;
rbf.weight = rbf.weight - rbf.alpha .* delta_weight ./ numSamples;
rbf.center = rbf.center - rbf.alpha .* delta_center ./ numSamples;
rbf.delta = rbf.delta - rbf.alpha .* delta_delta ./ numSamples;
end
function greenValue = green(x, c, delta)
% Gaussian radial basis function: exp(-||x - c||^2 / (2 * delta^2))
greenValue = exp(-1.0 * sum((x - c).^2) / (2 * delta^2));
end