/*

 * This class is for artificial neural network prediction.

 * You can load Visual Gene Developer's trained neural network file (*.vgn)

 * then use it to predict output values for any given input values.

 * Currently, this class does not include modules for training networks.

 * In addition, only the hyperbolic tangent transfer function is supported.

 *

 * THIS PROGRAM IS DISTRIBUTED "AS IS".

 * NO WARRANTY OF ANY KIND IS EXPRESSED OR IMPLIED.

 * YOU USE THE PROGRAM AT YOUR OWN RISK.

 * THE AUTHOR WILL NOT BE LIABLE FOR DATA LOSS, DAMAGES,

 * LOSS OF PROFITS OR ANY OTHER KIND OF LOSS WHILE USING

 * OR MISUSING THIS SOFTWARE.

 * ANYONE CAN USE AND MODIFY CODES WITHOUT CHARGE.

 */

 

import java.io.File;

import java.util.List;

import javax.swing.JOptionPane;

 

/**

 * @author SangKyu

 */

public class NeuralNet {

 

    public int maxLayerCount = 5;

    public int maxNodeCount = 200;

    public int hiddenLayerCount;

    public int inputCount;

    public int outputCount;

    public int[] nodeCountInHiddenlayers = new int[maxLayerCount + 1];

    public double[][] nodeValue = new double[maxLayerCount + 1][maxNodeCount];

    public double[][] nodeThreshold = new double[maxLayerCount + 1][maxNodeCount];

    public double[][] nodeTotAct = new double[maxLayerCount + 1][maxNodeCount];

    public double[][][] connectWeightFactor

            = new double[maxLayerCount + 1][maxNodeCount][maxNodeCount];

 

   

   

    public void test() {

        //User your own path

        String trainedNetworkFile = "D:\\....."

                + "\\Sample SinCos - Trained network.vgn";

        String retMsg = openTrainedNetworkFile(trainedNetworkFile);

        if (retMsg.equals("")) {

            double[] outputValues = predict(new double[]{0.37146, 0.88627, 0.41384});

            if (outputValues != null )

                    JOptionPane.showMessageDialog(null, "Output2= " + outputValues[1]);

            //Correct result: 0.6740054...

        } else {

            JOptionPane.showMessageDialog(null, retMsg);

        }

    }

 

    //Get transfer function

    //Hyperbolic tangent function is calculated

    public double computeTransferFunction(double inX) {

        //Hyperbolic tangent

        return (Math.exp(inX) - Math.exp(-inX)) / (Math.exp(inX) + Math.exp(-inX));

    }

 

    //Predict output values for given inputValues

    public double[] predict(double[] inputValues) {

        if (inputValues.length != inputCount) return null;

        

        //Assign input values

        System.arraycopy(inputValues, 0, nodeValue[0], 0, inputCount);

 

        propagateNetwork();

 

        //Get output values

        double[] outputValues = new double[outputCount];

        System.arraycopy(nodeValue[hiddenLayerCount + 2], 0,

                outputValues, 0, outputCount);

 

        return outputValues;

    }

 

    //Propagate network

    private void propagateNetwork() {

        for (int w_n = 0; w_n < nodeCountInHiddenlayers[1]; w_n++) {

            nodeValue[1][w_n]

                    = computeTransferFunction(nodeValue[0][w_n]

                            + nodeThreshold[1][w_n]);

        }

 

        double current_Sum;

        for (int cur_Layer = 2; cur_Layer <= hiddenLayerCount + 2; cur_Layer++) {

            for (int w_n = 0; w_n < nodeCountInHiddenlayers[cur_Layer]; w_n++) {

                current_Sum = 0;

                for (int w_n_1 = 0; w_n_1 < nodeCountInHiddenlayers[cur_Layer - 1]; w_n_1++) {

                    current_Sum = current_Sum

                            + connectWeightFactor[cur_Layer][w_n_1][w_n]

                            * nodeValue[cur_Layer - 1][w_n_1];

                }

                current_Sum = current_Sum + nodeThreshold[cur_Layer][w_n];

                nodeValue[cur_Layer][w_n]

                        = computeTransferFunction(current_Sum);

 

                nodeTotAct[cur_Layer][w_n] = current_Sum;

            }

        }

    }

 

    //Return "" if everything is OK

    //Return error message if something is wrong

    public String openTrainedNetworkFile(String fileName) {

        List<String> linesList

                = Utilities.getLinesFromFile(new File(fileName));

 

        if (linesList == null) {

            return "Empty file";

        }

 

        String[] strLines = linesList.toArray(new String[linesList.size()]);

 

        if (!strLines[1].equals("Name=Visual Gene Developer - Neural Network")) {

            return "Not valid trained neuralnet file";

        }

 

        for (int curLine = 0; curLine < strLines.length; curLine++) {

            if (strLines[curLine].startsWith("Total input=")) {

                inputCount = getIntegerFromString(strLines[curLine], "=");

            } else if (strLines[curLine].startsWith("Total output=")) {

                outputCount = getIntegerFromString(strLines[curLine], "=");

            } else if (strLines[curLine].startsWith("Total layer=")) {

                hiddenLayerCount = getIntegerFromString(strLines[curLine], "=") - 2;

            } else if (strLines[curLine].startsWith("Transfer function=")) {

                if (!strLines[curLine].equals("Transfer function=Hyperbolic tangent")) {

                    return "Transfer function is not hyperbolic tangent";

                }

            } else if (strLines[curLine].startsWith("layer=total node")) {

                for (int i = 0; i < hiddenLayerCount; i++) {

                    curLine++;

                    nodeCountInHiddenlayers[i + 2]

                            = getIntegerFromString(strLines[curLine], "=");

                }

 

            } else if (strLines[curLine].startsWith("layer-node=threshold value")) {

                curLine++;

                do {

                    int curLayer = getInteger(strLines[curLine].substring(0, 2));

                    int node = getInteger(strLines[curLine].substring(3, 5));

                    double threshold = getDoubleFromString(strLines[curLine], "=");

                    nodeThreshold[curLayer][node - 1] = threshold;

                    curLine++;

                    if (curLine >= strLines.length) {

                        break;

                    }

                } while (strLines[curLine].equals("") == false);

 

            } else if (strLines[curLine].startsWith(

                    "layer-node(layer n-1)-node(layer n)=weight factor")) {

                curLine++;

                do {

                    int curLayer = getInteger(strLines[curLine].substring(0, 2));

                    int node1 = getInteger(strLines[curLine].substring(3, 5));

                    int node2 = getInteger(strLines[curLine].substring(6, 8));

                    double weightFactor = getDoubleFromString(strLines[curLine], "=");

                    connectWeightFactor[curLayer][node1 - 1][node2 - 1] = weightFactor;

                    curLine++;

                    if (curLine >= strLines.length) {

                        break;

                    }

                } while (strLines[curLine].equals("") == false);

 

            }

        }

 

        nodeCountInHiddenlayers[0] = inputCount;

        nodeCountInHiddenlayers[1] = inputCount;

        nodeCountInHiddenlayers[hiddenLayerCount + 2] = outputCount;

 

        return "";

    }

 

    private int getIntegerFromString(String srcStr, String valueSeparator) {

        return getInteger(srcStr.substring(srcStr.indexOf(valueSeparator) + 1));

    }

 

    private double getDoubleFromString(String srcStr, String valueSeparator) {

        return getDouble(srcStr.substring(srcStr.indexOf(valueSeparator) + 1));

    }

 

    public Integer getInteger(String str) {

        str = str.trim();

        if (str == null) {

            return null;

        }

        Integer ret;

        try {

            ret = new Integer(str);

        } catch (NumberFormatException e) {

            return null;

        }

        return ret;

    }

 

    public Double getDouble(String str) {

        str = str.trim();

        if (str == null) {

            return null;

        }

        Double ret;

        try {

            ret = new Double(str);

        } catch (NumberFormatException e) {

            return null;

        }

        return ret;

    }

 

}