Network:
package test2;
import java.util.Random;
/**
 * A three-layer (input, hidden, output) feed-forward neural network trained
 * by back-propagation with a momentum term.
 *
 * <p>Every neuron uses the logistic sigmoid activation and there are no
 * bias/threshold terms. {@link #init(int, int, int)} must be called before
 * {@link #train(double[], double[])}, {@link #test(double[])} or
 * {@link #forward()}.
 */
public class Network {
    /** Learning rate applied by {@link #init(int, int, int)}. */
    private static final double DEFAULT_RATE = 0.2;
    /** Momentum factor applied by {@link #init(int, int, int)}. */
    private static final double DEFAULT_MOMENTUM = 0.3;

    private double[] input; // input layer activations
    private double[] hidden; // hidden layer activations
    private double[] output; // output layer activations
    private double[] target; // expected output vector for the current sample
    private double[][] i_h_weight; // input -> hidden weights, indexed [input][hidden]
    private double[][] h_o_weight; // hidden -> output weights, indexed [hidden][output]
    private double[][] i_h_weightUpdate; // previous input -> hidden weight deltas (for momentum)
    private double[][] h_o_weightUpdate; // previous hidden -> output weight deltas (for momentum)
    private double[] outputError; // per-neuron error term of the output layer
    private double[] hiddenError; // per-neuron error term of the hidden layer
    // Sums of absolute error terms from the latest calculateError() call.
    // They are recorded here but not read anywhere inside this class.
    private double outputErrorSum;
    private double hiddenErrorSum;
    private double rate = DEFAULT_RATE;
    private double momentum = DEFAULT_MOMENTUM;
    private Random random;

    /**
     * Allocates all layers and weight matrices, resets the learning rate and
     * momentum to their defaults and fills the weights with random values.
     *
     * @param inputSize  number of input neurons
     * @param hiddenSize number of hidden neurons
     * @param outputSize number of output neurons
     */
    public void init(int inputSize, int hiddenSize, int outputSize) {
        input = new double[inputSize];
        hidden = new double[hiddenSize];
        output = new double[outputSize];
        target = new double[outputSize];
        i_h_weight = new double[inputSize][hiddenSize];
        h_o_weight = new double[hiddenSize][outputSize];
        i_h_weightUpdate = new double[inputSize][hiddenSize];
        h_o_weightUpdate = new double[hiddenSize][outputSize];
        outputError = new double[outputSize];
        hiddenError = new double[hiddenSize];
        rate = DEFAULT_RATE;
        momentum = DEFAULT_MOMENTUM;
        random = new Random();
        randomWeights(i_h_weight);
        randomWeights(h_o_weight);
    }

    /**
     * Fills a weight matrix with values drawn uniformly from [-1, 1).
     *
     * <p>The sign is independent of the magnitude, so the initial weights are
     * centred on zero (the previous scheme tied the sign to the magnitude and
     * produced values only in (-0.5, 0] and (0.5, 1)).
     *
     * @param matrix matrix to fill in place
     */
    private void randomWeights(double[][] matrix) {
        for (double[] row : matrix) {
            for (int j = 0; j < row.length; j++) {
                row[j] = 2 * random.nextDouble() - 1;
            }
        }
    }

    /**
     * Runs one training step on a single sample: forward pass, error
     * calculation and weight adjustment.
     *
     * @param trainData input vector; its length must equal the input layer size
     * @param target    expected output vector; its length must equal the output layer size
     * @throws IllegalArgumentException if either length does not match
     * @throws IllegalStateException    if {@link #init(int, int, int)} has not been called
     */
    public void train(double[] trainData, double[] target) {
        ensureInitialized();
        loadInput(trainData);
        loadTarget(target);
        forward();
        calculateError();
        adjustWeight();
    }

    /**
     * Feeds one input vector through the network without changing any weights.
     *
     * @param inData input vector; its length must equal the input layer size
     * @return a copy of the output layer activations
     * @throws IllegalArgumentException if the length does not match
     * @throws IllegalStateException    if {@link #init(int, int, int)} has not been called
     */
    public double[] test(double[] inData) {
        ensureInitialized();
        loadInput(inData);
        forward();
        return getNetworkOutput();
    }

    /**
     * Returns a copy of the current output layer so callers cannot modify
     * the network's internal state.
     *
     * @return copy of the output activations
     */
    private double[] getNetworkOutput() {
        return output.clone();
    }

    /**
     * Copies the expected output vector into the network.
     *
     * @param target expected output vector
     * @throws IllegalArgumentException if the length differs from the output layer size
     */
    private void loadTarget(double[] target) {
        if (this.target.length != target.length) {
            throw new IllegalArgumentException("长度不匹配.");
        }
        // Copy instead of aliasing so later changes to the caller's array
        // cannot leak into the network.
        System.arraycopy(target, 0, this.target, 0, target.length);
    }

    /**
     * Copies an input vector into the input layer.
     *
     * @param input input vector
     * @throws IllegalArgumentException if the length differs from the input layer size
     */
    private void loadInput(double[] input) {
        if (this.input.length != input.length) {
            throw new IllegalArgumentException("长度不匹配.");
        }
        System.arraycopy(input, 0, this.input, 0, input.length);
    }

    /**
     * Propagates activations from one layer to the next: each target neuron
     * receives the sigmoid of the weighted sum of the source layer.
     *
     * @param layer0 source layer activations
     * @param layer1 destination layer, overwritten in place
     * @param weight weights indexed [source][destination]
     */
    private void forward(double[] layer0, double[] layer1, double[][] weight) {
        for (int j = 0; j < layer1.length; j++) {
            double sum = 0;
            for (int i = 0; i < layer0.length; i++) {
                sum += weight[i][j] * layer0[i];
            }
            layer1[j] = sigmoid(sum);
        }
    }

    /**
     * Runs a full forward pass (input -> hidden -> output) on the currently
     * loaded input.
     *
     * @throws IllegalStateException if {@link #init(int, int, int)} has not been called
     */
    public void forward() {
        ensureInitialized();
        forward(input, hidden, i_h_weight);
        forward(hidden, output, h_o_weight);
    }

    /**
     * Computes the output layer error terms, {@code o * (1 - o) * (target - o)},
     * and records the sum of their absolute values.
     */
    private void outputError() {
        double errSum = 0;
        for (int i = 0; i < outputError.length; i++) {
            double o = output[i];
            outputError[i] = o * (1d - o) * (target[i] - o);
            errSum += Math.abs(outputError[i]);
        }
        outputErrorSum = errSum;
    }

    /**
     * Computes the hidden layer error terms by back-propagating the output
     * errors through the hidden -> output weights, and records the sum of
     * their absolute values.
     */
    private void hiddenError() {
        double errSum = 0;
        for (int i = 0; i < hiddenError.length; i++) {
            double o = hidden[i];
            double sum = 0;
            for (int j = 0; j < outputError.length; j++) {
                sum += h_o_weight[i][j] * outputError[j];
            }
            hiddenError[i] = o * (1d - o) * sum;
            errSum += Math.abs(hiddenError[i]);
        }
        hiddenErrorSum = errSum;
    }

    /**
     * Computes the error terms of both layers. The output layer must be done
     * first because the hidden layer errors are derived from it.
     */
    private void calculateError() {
        outputError();
        hiddenError();
    }

    /**
     * Applies one momentum-based update to a weight matrix.
     *
     * @param error      error terms of the destination layer
     * @param layer      activations of the source layer
     * @param weight     weights indexed [source][destination], updated in place
     * @param prevWeight deltas applied in the previous step, updated in place
     */
    private void adjustWeight(double[] error, double[] layer, double[][] weight, double[][] prevWeight) {
        for (int i = 0; i < error.length; i++) {
            for (int j = 0; j < layer.length; j++) {
                double delta = momentum * prevWeight[j][i] + rate * error[i] * layer[j];
                weight[j][i] += delta;
                prevWeight[j][i] = delta;
            }
        }
    }

    /**
     * Updates both weight matrices from the error terms computed by
     * {@link #calculateError()}.
     */
    private void adjustWeight() {
        adjustWeight(hiddenError, input, i_h_weight, i_h_weightUpdate);
        adjustWeight(outputError, hidden, h_o_weight, h_o_weightUpdate);
    }

    /**
     * Logistic sigmoid activation. The output lies between 0 and 1 and the
     * curve is point-symmetric about (0, 0.5).
     *
     * @param x pre-activation value
     * @return {@code 1 / (1 + e^-x)}
     */
    public double sigmoid(double x) {
        return 1 / (1 + Math.exp(-x));
    }

    /**
     * Hyperbolic tangent activation. The output lies between -1 and 1 and the
     * curve is point-symmetric about the origin.
     *
     * <p>Delegates to {@link Math#tanh(double)}: the explicit formula
     * {@code (1 - e^-2x) / (1 + e^-2x)} evaluates to {@code NaN} for large
     * negative {@code x} because {@code e^-2x} overflows to infinity.
     *
     * @param x pre-activation value
     * @return the hyperbolic tangent of {@code x}
     */
    public double tanh(double x) {
        return Math.tanh(x);
    }

    /**
     * Fails fast with a descriptive error when the network is used before
     * its arrays have been allocated.
     */
    private void ensureInitialized() {
        if (input == null) {
            throw new IllegalStateException("init(inputSize, hiddenSize, outputSize) must be called first");
        }
    }
}
Main:
package test2;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import test.BP;
/**
 * Trains a {@link Network} on the samples in {@code data/iris.txt} and
 * reports how many samples of {@code data/test.txt} it classifies correctly.
 */
public class Main {
    /** Class labels, indexed by the output neuron that represents them. */
    private static final String[] CATEGORIES = { "Iris-setosa", "Iris-versicolor", "Iris-virginica" };

    /**
     * Number of consecutive training samples that share a class label.
     * NOTE(review): this assumes the training file is ordered by class in
     * groups of 50 - confirm against data/iris.txt.
     */
    private static final int SAMPLES_PER_CLASS = 50;

    /**
     * Loads the training data, trains the network for a fixed number of
     * passes, then evaluates it on the test file and prints the accuracy.
     *
     * @param args unused
     * @throws IOException declared for the data-loading calls
     */
    public static void main(String[] args) throws IOException {
        System.out.println("->读取样本数据");
        ReadData rd = new ReadData();
        List<double[]> data = rd.loadData("data/iris.txt", 0, 3, ",");
        System.out.println("->读取完成");

        System.out.println("->初始化神经网络");
        int ipt = 4;
        int opt = 3;
        int hid = (int) (Math.sqrt(ipt + opt) + 10);
        Network bp = new Network();
        bp.init(ipt, hid, opt);
        System.out.println("->初始化完成");

        int maxLearn = 10000;
        System.out.println("->最大学习次数:" + maxLearn);
        System.out.println("->开始训练");
        // Timestamps are whole milliseconds, so keep them as long.
        long start = System.currentTimeMillis();
        for (int j = 0; j < maxLearn; j++) {
            for (int i = 0; i < data.size(); i++) {
                // One-hot target: samples 0-49 -> class 0, 50-99 -> class 1,
                // 100-149 -> class 2; any further samples keep an all-zero target.
                double[] target = new double[opt];
                int cls = i / SAMPLES_PER_CLASS;
                if (cls < opt) {
                    target[cls] = 1;
                }
                bp.train(data.get(i), target);
            }
        }
        long end = System.currentTimeMillis();
        System.out.println("->训练完成,用时:" + (end - start) + "ms");
        System.out.println("-------------");

        List<double[]> testData = rd.loadData("data/test.txt", 0, 3, ",");
        // Read the expected labels once instead of once per test sample.
        // NOTE(review): assumes getColumn returns a java.util.List - confirm against ReadData.
        List<?> expected = rd.getColumn("data/test.txt", 4, ",");
        int correct = 0;
        int error = 0;
        for (int i = 0; i < testData.size(); i++) {
            double[] result = bp.test(testData.get(i));
            if (classify(result).equals(expected.get(i))) {
                correct++;
            } else {
                error++;
            }
        }
        System.out.println("->测试数据:" + (correct + error) + "条," + "正确 " + correct + "条");
        System.out.println("->正确率:" + (float) correct / (correct + error));
    }

    /**
     * Maps a network output vector to the label of its largest component.
     *
     * @param result output layer activations
     * @return the label whose output is largest, or an empty string when no
     *         component can be selected (empty vector, all NaN) or the index
     *         has no known label
     */
    private static String classify(double[] result) {
        int best = -1;
        // Start below every finite value; "-Integer.MIN_VALUE" overflows back
        // to Integer.MIN_VALUE and only worked by accident.
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < result.length; i++) {
            if (result[i] > max) {
                max = result[i];
                best = i;
            }
        }
        return (best >= 0 && best < CATEGORIES.length) ? CATEGORIES[best] : "";
    }
}
结果: