【深度学习】BP算法分类iris数据集_初始化
Network:

package test2;

import java.util.Random;

/**
 * Three-layer (input, hidden, output) fully connected back-propagation network.
 *
 * <p>Both layers use the logistic {@link #sigmoid(double)} activation. Training is online
 * (one sample per {@link #train(double[], double[])} call) with a momentum term, and the
 * network has no bias/threshold terms. {@link #init(int, int, int)} must be called before
 * any training or inference method.
 */
public class Network {

    /** Learning rate installed by {@link #init(int, int, int)}. */
    private static final double DEFAULT_RATE = 0.2;
    /** Momentum installed by {@link #init(int, int, int)}. */
    private static final double DEFAULT_MOMENTUM = 0.3;

    private double[] input; // input layer values (private copy of the last loaded sample)
    private double[] hidden; // hidden layer activations
    private double[] output; // output layer activations
    private double[] target; // expected output vector (private copy)
    private double[][] i_h_weight; // input->hidden weights, indexed [input][hidden]
    private double[][] h_o_weight; // hidden->output weights, indexed [hidden][output]
    private double[][] i_h_weightUpdate; // last applied input->hidden deltas, reused as momentum
    private double[][] h_o_weightUpdate; // last applied hidden->output deltas, reused as momentum
    private double[] outputError; // output layer deltas
    private double[] hiddenError; // hidden layer deltas
    private double outputErrorSum; // sum of |output deltas| from the last training step
    private double hiddenErrorSum; // sum of |hidden deltas| from the last training step
    private double rate = DEFAULT_RATE;
    private double momentum = DEFAULT_MOMENTUM;

    private Random random;

    /**
     * Allocates all layers and weight matrices and fills the weights with random values,
     * using the default learning rate (0.2) and momentum (0.3).
     *
     * @param inputSize number of input neurons, must be positive
     * @param hiddenSize number of hidden neurons, must be positive
     * @param outputSize number of output neurons, must be positive
     * @throws IllegalArgumentException if any size is not positive
     */
    public void init(int inputSize, int hiddenSize, int outputSize) {
        init(inputSize, hiddenSize, outputSize, DEFAULT_RATE, DEFAULT_MOMENTUM);
    }

    /**
     * Allocates all layers and weight matrices and fills the weights with random values.
     *
     * @param inputSize number of input neurons, must be positive
     * @param hiddenSize number of hidden neurons, must be positive
     * @param outputSize number of output neurons, must be positive
     * @param rate learning rate applied to each gradient step
     * @param momentum fraction of the previous weight delta added to the current one
     * @throws IllegalArgumentException if any size is not positive
     */
    public void init(int inputSize, int hiddenSize, int outputSize, double rate, double momentum) {
        if (inputSize <= 0 || hiddenSize <= 0 || outputSize <= 0) {
            throw new IllegalArgumentException("Layer sizes must be positive: input=" + inputSize
                    + ", hidden=" + hiddenSize + ", output=" + outputSize);
        }
        input = new double[inputSize];
        hidden = new double[hiddenSize];
        output = new double[outputSize];
        target = new double[outputSize];

        i_h_weight = new double[inputSize][hiddenSize];
        h_o_weight = new double[hiddenSize][outputSize];
        i_h_weightUpdate = new double[inputSize][hiddenSize];
        h_o_weightUpdate = new double[hiddenSize][outputSize];

        outputError = new double[outputSize];
        hiddenError = new double[hiddenSize];
        outputErrorSum = 0;
        hiddenErrorSum = 0;

        this.rate = rate;
        this.momentum = momentum;

        random = new Random();
        randomWeights(i_h_weight);
        randomWeights(h_o_weight);
    }

    /**
     * Fills a weight matrix with values drawn uniformly from [-1, 1).
     *
     * @param matrix matrix to overwrite in place
     */
    private void randomWeights(double[][] matrix) {
        for (double[] row : matrix) {
            for (int j = 0; j < row.length; j++) {
                // Symmetric around zero so neither sign (nor magnitude band) is favoured.
                row[j] = random.nextDouble() * 2 - 1;
            }
        }
    }

    /**
     * Performs one online training step: forward pass, error back-propagation and
     * weight update for a single sample.
     *
     * @param trainData input vector; length must equal the input layer size
     * @param target expected output vector; length must equal the output layer size
     * @throws IllegalArgumentException if a vector length does not match its layer
     * @throws IllegalStateException if {@link #init(int, int, int)} was not called
     */
    public void train(double[] trainData, double[] target) {
        loadInput(trainData);
        loadTarget(target);
        forward();
        calculateError();
        adjustWeight();
    }

    /**
     * Runs a forward pass on {@code inData} without changing any weights.
     *
     * @param inData input vector; length must equal the input layer size
     * @return a fresh copy of the output layer activations
     * @throws IllegalArgumentException if the input length does not match the input layer
     * @throws IllegalStateException if {@link #init(int, int, int)} was not called
     */
    public double[] test(double[] inData) {
        loadInput(inData);
        forward();
        return getNetworkOutput();
    }

    /**
     * Returns a copy of the current output layer, so callers cannot mutate network state.
     *
     * @return copy of the output activations
     */
    private double[] getNetworkOutput() {
        return output.clone();
    }

    /**
     * Copies the expected output vector into the network's own buffer.
     *
     * @param target expected output vector
     */
    private void loadTarget(double[] target) {
        ensureInitialized();
        if (this.target.length != target.length) {
            throw new IllegalArgumentException("长度不匹配.");
        }
        // Copy rather than alias, so later changes to the caller's array cannot leak in.
        System.arraycopy(target, 0, this.target, 0, target.length);
    }

    /**
     * Copies an input vector into the network's input layer.
     *
     * @param input input vector
     */
    private void loadInput(double[] input) {
        ensureInitialized();
        if (this.input.length != input.length) {
            throw new IllegalArgumentException("长度不匹配.");
        }
        // Copy rather than alias, so later changes to the caller's array cannot leak in.
        System.arraycopy(input, 0, this.input, 0, input.length);
    }

    /**
     * Propagates one layer forward: layer1[j] = sigmoid(sum_i weight[i][j] * layer0[i]).
     *
     * @param layer0 source layer activations
     * @param layer1 destination layer, overwritten in place
     * @param weight weights indexed [layer0 index][layer1 index]
     */
    private void forward(double[] layer0, double[] layer1, double[][] weight) {
        for (int j = 0; j < layer1.length; j++) {
            double sum = 0;
            for (int i = 0; i < layer0.length; i++) {
                sum += weight[i][j] * layer0[i];
            }
            layer1[j] = sigmoid(sum);
        }
    }

    /**
     * Runs a full forward pass (input -> hidden -> output) on the currently loaded input.
     *
     * @throws IllegalStateException if {@link #init(int, int, int)} was not called
     */
    public void forward() {
        ensureInitialized();
        forward(input, hidden, i_h_weight);
        forward(hidden, output, h_o_weight);
    }

    /**
     * Computes output deltas o * (1 - o) * (t - o), i.e. the sigmoid derivative times the
     * output error, and records the sum of their absolute values.
     */
    private void outputError() {
        double errSum = 0;
        for (int i = 0; i < outputError.length; i++) {
            double o = output[i];
            outputError[i] = o * (1d - o) * (target[i] - o);
            errSum += Math.abs(outputError[i]);
        }
        outputErrorSum = errSum;
    }

    /**
     * Back-propagates the output deltas through the (not yet updated) hidden->output
     * weights to obtain hidden deltas, and records the sum of their absolute values.
     */
    private void hiddenError() {
        double errSum = 0;
        for (int i = 0; i < hiddenError.length; i++) {
            double o = hidden[i];
            double sum = 0;
            for (int j = 0; j < outputError.length; j++) {
                sum += h_o_weight[i][j] * outputError[j];
            }
            hiddenError[i] = o * (1d - o) * sum;
            errSum += Math.abs(hiddenError[i]);
        }
        hiddenErrorSum = errSum;
    }

    /**
     * Computes output deltas first, because the hidden deltas are derived from them.
     */
    private void calculateError() {
        outputError();
        hiddenError();
    }

    /**
     * Applies one gradient step with momentum to a weight matrix:
     * delta = momentum * previousDelta + rate * error[i] * layer[j].
     *
     * @param error deltas of the destination layer
     * @param layer activations of the source layer
     * @param weight weights indexed [source][destination], updated in place
     * @param prevWeight previous deltas with the same shape; overwritten with the new deltas
     */
    private void adjustWeight(double[] error, double[] layer, double[][] weight, double[][] prevWeight) {
        for (int i = 0; i < error.length; i++) {
            for (int j = 0; j < layer.length; j++) {
                double newVal = momentum * prevWeight[j][i] + rate * error[i] * layer[j];
                weight[j][i] += newVal;
                prevWeight[j][i] = newVal;
            }
        }
    }

    /**
     * Updates both weight matrices using the deltas from {@link #calculateError()}.
     */
    private void adjustWeight() {
        // Shapes: hiddenError[hidden], input[input], i_h_weight / i_h_weightUpdate [input][hidden].
        adjustWeight(hiddenError, input, i_h_weight, i_h_weightUpdate);
        adjustWeight(outputError, hidden, h_o_weight, h_o_weightUpdate);
    }

    /**
     * Returns the sum of absolute output-layer deltas from the most recent training step.
     *
     * @return output error sum (0 before any training)
     */
    public double getOutputErrorSum() {
        return outputErrorSum;
    }

    /**
     * Returns the sum of absolute hidden-layer deltas from the most recent training step.
     *
     * @return hidden error sum (0 before any training)
     */
    public double getHiddenErrorSum() {
        return hiddenErrorSum;
    }

    /**
     * Logistic activation. Its output lies in (0, 1) and is point-symmetric about (0, 0.5).
     *
     * @param x pre-activation value
     * @return 1 / (1 + e^-x)
     */
    public double sigmoid(double x) {
        return 1 / (1 + Math.exp(-x));
    }

    /**
     * Hyperbolic tangent activation. Its output lies in (-1, 1) and is point-symmetric
     * about (0, 0).
     *
     * <p>Delegates to {@link Math#tanh(double)}. The explicit exponential formula overflows
     * to NaN for large negative inputs.
     *
     * @param x pre-activation value
     * @return tanh(x)
     */
    public double tanh(double x) {
        return Math.tanh(x);
    }

    /**
     * Fails fast with a clear message instead of a NullPointerException when the
     * network is used before {@link #init(int, int, int)}.
     */
    private void ensureInitialized() {
        if (input == null) {
            throw new IllegalStateException("Network.init(...) must be called before use");
        }
    }

}

Main:

package test2;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import test.BP;

/**
 * Trains a {@link Network} on the iris data set and reports its accuracy on a test file.
 */
public class Main {

    /** Class names, indexed by the position of the corresponding network output. */
    private static final String[] CATEGORIES = { "Iris-setosa", "Iris-versicolor", "Iris-virginica" };

    /** Number of consecutive training rows assumed to belong to each class. */
    private static final int SAMPLES_PER_CLASS = 50;

    /**
     * Loads the training data, trains the network and prints test-set accuracy.
     *
     * @param args unused
     * @throws IOException if a data file cannot be read
     */
    public static void main(String[] args) throws IOException {
        System.out.println("->读取样本数据");
        ReadData rd = new ReadData();
        List<double[]> data = rd.loadData("data/iris.txt", 0, 3, ",");
        System.out.println("->读取完成");
        System.out.println("->初始化神经网络");
        int ipt = 4;
        int opt = CATEGORIES.length;
        int hid = (int) (Math.sqrt(ipt + opt) + 10);
        Network bp = new Network();
        bp.init(ipt, hid, opt);
        System.out.println("->初始化完成");
        int maxLearn = 10000;
        System.out.println("->最大学习次数:" + maxLearn);
        System.out.println("->开始训练");

        // Target vectors never change between epochs, so build them once.
        double[][] targets = new double[data.size()][];
        for (int i = 0; i < data.size(); i++) {
            targets[i] = targetForIndex(i, opt);
        }

        long start = System.currentTimeMillis();
        for (int j = 0; j < maxLearn; j++) {
            for (int i = 0; i < data.size(); i++) {
                bp.train(data.get(i), targets[i]);
            }
        }
        long end = System.currentTimeMillis();
        System.out.println("->训练完成,用时:" + (end - start) + "ms");

        System.out.println("-------------");
        List<double[]> testData = rd.loadData("data/test.txt", 0, 3, ",");
        // Read the expected labels once instead of re-reading the file for every sample.
        List<?> expected = rd.getColumn("data/test.txt", 4, ",");
        int correct = 0;
        int error = 0;
        for (int i = 0; i < testData.size(); i++) {
            double[] result = bp.test(testData.get(i));
            if (classify(result).equals(expected.get(i))) {
                correct++;
            } else {
                error++;
            }
        }
        System.out.println("->测试数据:" + (correct + error) + "条," + "正确 " + correct + "条");
        System.out.println("->正确率:" + (float) correct / (correct + error));
    }

    /**
     * Builds the one-hot target vector for a training row from its position.
     *
     * <p>NOTE(review): this assumes data/iris.txt lists {@link #SAMPLES_PER_CLASS} rows of each
     * class in {@link #CATEGORIES} order (the standard UCI layout). Rows beyond the last class
     * get an all-zero target. Confirm this, or derive labels from the file's class column.
     *
     * @param index zero-based row index in the training data
     * @param outputSize length of the target vector
     * @return one-hot target vector
     */
    private static double[] targetForIndex(int index, int outputSize) {
        double[] target = new double[outputSize];
        int cls = index / SAMPLES_PER_CLASS;
        if (cls < outputSize) {
            target[cls] = 1;
        }
        return target;
    }

    /**
     * Maps a network output vector to the class name of its largest component.
     *
     * @param result network output
     * @return the matching class name, or {@code ""} if no component is a comparable number
     *     or the winning index has no class name
     */
    private static String classify(double[] result) {
        // Negative infinity, not -Integer.MIN_VALUE: negating MIN_VALUE overflows back to itself.
        double max = Double.NEGATIVE_INFINITY;
        int idx = -1;
        for (int i = 0; i < result.length; i++) {
            if (result[i] > max) {
                max = result[i];
                idx = i;
            }
        }
        return (idx >= 0 && idx < CATEGORIES.length) ? CATEGORIES[idx] : "";
    }

}

结果:
【深度学习】BP算法分类iris数据集_java_02

网络上志同道合,我们一起学习网络安全,一起进步,QQ群:694839022