public class NeuralNetwork
extends java.lang.Object
Modifier and Type | Class and Description |
---|---|
static class | NeuralNetwork.Builder: Builder for creating neural network instances. |

Modifier and Type | Method and Description |
---|---|
static int | argMax(Matrix output): Get the position of the most probable entry in an output array. |
double | crossEntropyError(Matrix x, Matrix y): Calculate the cross entropy error of the neural network. |
Matrix | getWeights(int layer): Get the weight matrix of a layer of the neural network. |
void | load(java.io.File file): Loads a neural network from a file. |
void | load(java.io.InputStream is): Loads a neural network from an InputStream. |
void | load(java.lang.String filename): Loads a neural network from a file. |
Matrix | predict(Matrix input): Give a prediction based on some input. |
void | save(java.io.File file): Saves the neural network to a file (CSV). |
void | save(java.lang.String filename): Saves the neural network to a file (CSV). |
void | setWeights(int layer, Matrix weights): Set the weight matrix of a layer of the neural network. |
int | size(): Get the size of the neural network. |
double | squaredError(Matrix x, Matrix y): Calculate the squared error of the neural network. |
double | train(Matrix input, Matrix output, double learningRate): Train the neural network to predict an output given some input. |
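The summary lists two error measures, crossEntropyError and squaredError, without stating their formulas. The sketch below shows the textbook definitions on plain double arrays, as an illustration only. The class name ErrorSketch and the use of double[] in place of Matrix are assumptions made here, and the library may scale or average its errors differently.

```java
public class ErrorSketch {
    // Textbook sum of squared differences (assumed definition, not taken from the library).
    static double squaredError(double[] predicted, double[] expected) {
        double sum = 0.0;
        for (int i = 0; i < predicted.length; i++) {
            double diff = predicted[i] - expected[i];
            sum += diff * diff;
        }
        return sum;
    }

    // Textbook cross entropy: -sum(expected[i] * ln(predicted[i])).
    // Terms with expected[i] == 0 are skipped so that ln(0) never contributes.
    static double crossEntropyError(double[] predicted, double[] expected) {
        double sum = 0.0;
        for (int i = 0; i < predicted.length; i++) {
            if (expected[i] != 0.0) {
                sum -= expected[i] * Math.log(predicted[i]);
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] predicted = {0.7, 0.2, 0.1}; // e.g. a softmax output
        double[] expected = {1.0, 0.0, 0.0};  // one-hot target
        System.out.println("squared error: " + squaredError(predicted, expected));
        System.out.println("cross entropy: " + crossEntropyError(predicted, expected));
    }
}
```

Both values shrink toward zero as the prediction approaches the one-hot target, which is the property a training loop relies on.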
public double crossEntropyError(Matrix x, Matrix y)
x - The input to the neural network.
y - The expected output.

public double squaredError(Matrix x, Matrix y)
x - The input to the neural network.
y - The expected output.

public Matrix predict(Matrix input)
input - The input to the neural network which is equal in size to the number of input neurons.

public static int argMax(Matrix output)
output - The output of the neural network (using Softmax).

public Matrix getWeights(int layer)
layer - The layer number of the neural network.

public void setWeights(int layer, Matrix weights)
layer - The layer number of the neural network.
weights - The new weight matrix for the layer.

public int size()

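To make the argMax description concrete, here is a minimal sketch of picking the position of the largest value in a softmax output. It works on a plain double[] rather than the library's Matrix type, and the class name ArgMaxSketch and the first-occurrence tie-breaking rule are assumptions made for this example.

```java
public class ArgMaxSketch {
    // Returns the index of the largest value; ties resolve to the first occurrence.
    static int argMax(double[] output) {
        if (output.length == 0) {
            throw new IllegalArgumentException("output must not be empty");
        }
        int best = 0;
        for (int i = 1; i < output.length; i++) {
            if (output[i] > output[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] softmaxOutput = {0.1, 0.7, 0.2};
        System.out.println(argMax(softmaxOutput)); // prints 1
    }
}
```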
public double train(Matrix input, Matrix output, double learningRate)
input - The input to the neural network.
output - The target output for the given input.
learningRate - The rate at which the neural network learns. This is normally 0.01.

public void save(java.io.File file)
file - The file in which to save the weights.

public void save(java.lang.String filename)
filename - The name of the file in which to save the weights.

public void load(java.io.InputStream is)
is - The InputStream to retrieve the weights from.

public void load(java.io.File file)
file - The file to retrieve the weights from.

public void load(java.lang.String filename)
filename - The name of the file to retrieve the weights from.
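The save and load methods store weights as CSV, but the file layout is not documented here. The following sketch shows one plausible round trip for a single weight matrix, assuming one matrix row per line with comma-separated values. The class name WeightCsvSketch, the double[][] representation, and the layout itself are all assumptions, not the library's actual format.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class WeightCsvSketch {
    // Writes one matrix row per line, values separated by commas (assumed layout).
    static void save(double[][] weights, Path file) throws IOException {
        List<String> lines = new ArrayList<>();
        for (double[] row : weights) {
            StringBuilder sb = new StringBuilder();
            for (int j = 0; j < row.length; j++) {
                if (j > 0) {
                    sb.append(',');
                }
                sb.append(row[j]); // Double.toString output parses back to the same value
            }
            lines.add(sb.toString());
        }
        Files.write(file, lines);
    }

    // Reads the same layout back into a matrix.
    static double[][] load(Path file) throws IOException {
        List<String> lines = Files.readAllLines(file);
        double[][] weights = new double[lines.size()][];
        for (int i = 0; i < lines.size(); i++) {
            String[] cells = lines.get(i).split(",");
            weights[i] = new double[cells.length];
            for (int j = 0; j < cells.length; j++) {
                weights[i][j] = Double.parseDouble(cells[j].trim());
            }
        }
        return weights;
    }

    public static void main(String[] args) throws IOException {
        Path file = Files.createTempFile("weights", ".csv");
        double[][] weights = {{0.5, -1.25}, {2.0, 0.125}};
        save(weights, file);
        double[][] restored = load(file);
        System.out.println(Arrays.deepEquals(weights, restored)); // prints true
        Files.delete(file);
    }
}
```

A real network has one weight matrix per layer, so an actual file would also need some way to separate layers, which this sketch leaves out.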