public class NeuralNetwork
extends java.lang.Object

Nested Class Summary

| Modifier and Type | Class | Description |
|---|---|---|
| static class | NeuralNetwork.Builder | Builder for creating neural network instances. |

Method Summary

| Modifier and Type | Method | Description |
|---|---|---|
| static int | argMax(Matrix output) | Returns the index of the most probable class in an output array. |
| double | crossEntropyError(Matrix x, Matrix y) | Calculates the cross-entropy error of the network's prediction for input x against the expected output y. |
| Matrix | getWeights(int layer) | Returns the weight matrix of a layer of the neural network. |
| void | load(java.io.File file) | Loads a neural network from a file. |
| void | load(java.io.InputStream is) | Loads a neural network from an InputStream. |
| void | load(java.lang.String filename) | Loads a neural network from a file. |
| Matrix | predict(Matrix input) | Predicts the output for the given input. |
| void | save(java.io.File file) | Saves the neural network to a file (CSV). |
| void | save(java.lang.String filename) | Saves the neural network to a file (CSV). |
| void | setWeights(int layer, Matrix weights) | Sets the weight matrix of a layer of the neural network. |
| int | size() | Returns the size of the neural network. |
| double | squaredError(Matrix x, Matrix y) | Calculates the squared error of the network's prediction for input x against the expected output y. |
| double | train(Matrix input, Matrix output, double learningRate) | Trains the neural network to predict an output given some input. |

Method Detail

public double crossEntropyError(Matrix x, Matrix y)

Parameters:
- x - The input to the neural network.
- y - The expected output.

public double squaredError(Matrix x, Matrix y)

Parameters:
- x - The input to the neural network.
- y - The expected output.

public Matrix predict(Matrix input)

Parameters:
- input - The input to the neural network; its size must equal the number of input neurons.

public static int argMax(Matrix output)

Parameters:
- output - The output of the neural network (produced by a Softmax output layer).

public Matrix getWeights(int layer)

Parameters:
- layer - The layer number of the neural network.

public void setWeights(int layer, Matrix weights)

Parameters:
- layer - The layer number of the neural network.
- weights - The new weight matrix for the layer.

public int size()
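The entries above describe per-layer weight matrices (getWeights, setWeights) and a predict method whose output argMax treats as a Softmax distribution. A minimal sketch of such a feed-forward pass follows. None of its internals are confirmed by this documentation: it assumes each weight matrix has one row per output neuron and one column per input neuron, that there are no bias terms, that hidden layers use the logistic sigmoid, and that only the last layer applies Softmax. ForwardPassSketch and its helpers are hypothetical, and weights.length merely stands in for a layer count; whether size() reports that value is not documented.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of a feed-forward pass over per-layer weight matrices.
 * Assumptions (not confirmed by the API docs): weights[l] has one row per output neuron
 * and one column per input neuron, no bias terms, sigmoid hidden layers, Softmax output.
 */
public class ForwardPassSketch {

    // Matrix-vector product: out[r] = sum_c w[r][c] * x[c].
    static double[] multiply(double[][] w, double[] x) {
        double[] out = new double[w.length];
        for (int r = 0; r < w.length; r++) {
            double s = 0.0;
            for (int c = 0; c < x.length; c++) {
                s += w[r][c] * x[c];
            }
            out[r] = s;
        }
        return out;
    }

    // Element-wise logistic sigmoid (assumed hidden-layer activation).
    static double[] sigmoid(double[] z) {
        double[] a = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            a[i] = 1.0 / (1.0 + Math.exp(-z[i]));
        }
        return a;
    }

    // Softmax with max-subtraction for numerical stability.
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double sum = 0.0;
        double[] e = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            e[i] = Math.exp(z[i] - max);
            sum += e[i];
        }
        for (int i = 0; i < e.length; i++) e[i] /= sum;
        return e;
    }

    // Applies each layer in turn; the last layer uses Softmax, earlier ones sigmoid.
    static double[] predict(double[][][] weights, double[] input) {
        double[] a = input;
        for (int l = 0; l < weights.length; l++) {
            double[] z = multiply(weights[l], a);
            a = (l == weights.length - 1) ? softmax(z) : sigmoid(z);
        }
        return a;
    }

    public static void main(String[] args) {
        double[][][] weights = {
            {{0.5, -0.5}, {0.25, 0.75}, {-1.0, 1.0}}, // layer 0: 2 inputs -> 3 hidden
            {{1.0, 0.0, -1.0}, {0.0, 1.0, 1.0}}       // layer 1: 3 hidden -> 2 outputs
        };
        double[] out = predict(weights, new double[] {1.0, 0.0});
        System.out.println(Arrays.toString(out));
    }
}
```

Subtracting the maximum before exponentiating leaves the Softmax result mathematically unchanged but prevents overflow for large activations.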

public double train(Matrix input, Matrix output, double learningRate)

Parameters:
- input - The input to the neural network.
- output - The target output for the given input.
- learningRate - The rate at which the neural network learns; a typical value is 0.01.

public void save(java.io.File file)

Parameters:
- file - The file to save the weights to.

public void save(java.lang.String filename)

Parameters:
- filename - The name of the file to save the weights to.

public void load(java.io.InputStream is)

Parameters:
- is - The InputStream to read the weights from.

public void load(java.io.File file)

Parameters:
- file - The file to read the weights from.

public void load(java.lang.String filename)

Parameters:
- filename - The name of the file to read the weights from.
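save writes the weights as CSV, and load reads them back from a File, a filename, or an InputStream. The actual file layout is not documented, so the round-trip sketch below assumes one matrix row per line with comma-separated values. CsvWeightsSketch, writeCsv, and readCsv are hypothetical names; readCsv takes an InputStream to mirror load(java.io.InputStream), and File or filename variants could wrap a FileInputStream around the same logic.

```java
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.StringWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical sketch of a CSV round trip for one weight matrix.
 * Assumed layout (not documented by the library): one matrix row per line,
 * values separated by commas.
 */
public class CsvWeightsSketch {

    static void writeCsv(double[][] matrix, Writer out) throws IOException {
        for (double[] row : matrix) {
            StringBuilder line = new StringBuilder();
            for (int c = 0; c < row.length; c++) {
                if (c > 0) line.append(',');
                // append(double) uses Double.toString, which parseDouble reads back exactly.
                line.append(row[c]);
            }
            out.write(line.append('\n').toString());
        }
        out.flush();
    }

    static double[][] readCsv(InputStream is) throws IOException {
        List<double[]> rows = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(is, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) continue; // skip blank lines
                String[] cells = line.split(",");
                double[] row = new double[cells.length];
                for (int c = 0; c < cells.length; c++) {
                    row[c] = Double.parseDouble(cells[c].trim());
                }
                rows.add(row);
            }
        }
        return rows.toArray(new double[0][]);
    }

    public static void main(String[] args) throws IOException {
        double[][] weights = {{0.5, -0.25}, {1.0e-3, 2.0}};
        StringWriter sw = new StringWriter();
        writeCsv(weights, sw);
        double[][] restored = readCsv(
            new ByteArrayInputStream(sw.toString().getBytes(StandardCharsets.UTF_8)));
        System.out.println(Arrays.deepEquals(weights, restored));
    }
}
```

Writing each value with Double.toString and reading it with Double.parseDouble round-trips finite doubles exactly, so the restored matrix equals the original.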