Commit 0e1085fe authored by Chris Müller

Implement Back-propagation algorithm

parent 340333dc
package de.ruuns;
import java.util.ArrayList;
import org.ejml.data.DenseMatrix64F;
interface Classifier {
......@@ -7,5 +8,11 @@ interface Classifier {
public void train(Dataset D);
public ArrayList<double[]> weights();
public ArrayList<double[]> derivatives();
public int parameters();
public void debug();
}
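/*
 * Illustrative sketch, not part of this commit: the new weights(), derivatives()
 * and parameters() accessors make it easy to inspect parameter and gradient
 * magnitudes from the outside. The helper class and its name are hypothetical;
 * it relies only on the interface declared above.
 */
class GradientReport {
    static void print(Classifier c) {
        java.util.ArrayList<double[]> w = c.weights();
        java.util.ArrayList<double[]> g = c.derivatives();
        if(w == null || g == null)
            return;                               // e.g. NaiveBayes returns null below
        System.out.printf("%d parameters in %d blocks%n", c.parameters(), w.size());
        for(int i = 0; i < w.size(); ++i) {
            double wn = 0.0, gn = 0.0;
            for(double v: w.get(i)) wn += v * v;  // squared L2 norm of the weights
            for(double v: g.get(i)) gn += v * v;  // squared L2 norm of the accumulated gradient
            System.out.printf("block %d: |W| = %.4f  |grad| = %.4f%n",
                              i, Math.sqrt(wn), Math.sqrt(gn));
        }
    }
}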
......@@ -21,4 +21,5 @@ public class Evaluation {
return errors / (double) D.size();
}
}
......@@ -27,6 +27,9 @@ public class Main {
bayes.train(D);
for(int i = 0; i < 10000; ++i)
mlp.train(D);
bayes.debug();
System.out.printf("MLP Accuracy %7.3f %%\n", 100 * (1.0 - Evaluation.error(mlp, D)));
......
......@@ -109,5 +109,17 @@ public class NaiveBayes implements Classifier {
System.out.flush();
}
public ArrayList<double[]> weights() {
return null;
}
public ArrayList<double[]> derivatives() {
return null;
}
public int parameters() {
return 0;
}
private ArrayList<ClassParameter> params = new ArrayList<ClassParameter>();
}
......@@ -10,17 +10,31 @@ class Layer {
static Random RNG = new Random(System.currentTimeMillis());
public Layer(int I, int O, double var) {
this.I = I;
this.O = O;
this.W = new DenseMatrix64F(I + 1, O);
this.W_grad = new DenseMatrix64F(I + 1, O);
this.a_output = new DenseMatrix64F(1, O);
this.a_input = new DenseMatrix64F(1, I + 1);
this.error = new DenseMatrix64F(1, O + 1);
this.W_grad.zero();
for(int i = 0; i < W.getNumElements(); ++i) {
this.W.set(i, RNG.nextGaussian() * var);
}
};
public int I;
public int O;
public DenseMatrix64F W;
public DenseMatrix64F W_grad;
public DenseMatrix64F a_input;
public DenseMatrix64F a_output;
public DenseMatrix64F error;
}
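/*
 * Shape summary (not part of the diff, inferred from the constructor above):
 * for a layer with I inputs and O outputs, W and W_grad are (I+1) x O, where the
 * extra row holds the bias weights; a_input is 1 x (I+1) (previous activations
 * plus a trailing 1.0); a_output is 1 x O; error is allocated 1 x (O+1) so it can
 * temporarily carry the bias component during back-propagation. Weights start as
 * N(0, 1) samples scaled by var. Example (hypothetical sizes):
 *
 *     Layer hidden = new Layer(784, 30, 0.01);   // 784 inputs, 30 units
 *     Layer output = new Layer(30, 10, 0.01);    // 30 inputs, 10 classes
 */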
......@@ -45,35 +59,100 @@ public class NeuralNetwork implements Classifier {
}
public void train(Dataset D) {
double learning_rate = 0.05;
for(Layer L: layers) {
L.W_grad.zero();
}
for(Instance ins: D.data)
propagateBackward(ins);
for(Layer L: layers) {
CommonOps.scale(1.0 / D.data.size(), L.W_grad);
CommonOps.addEquals(L.W, learning_rate, L.W_grad);
}
}
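/*
 * Sketch, not part of this commit: the two CommonOps calls above amount to the
 * plain averaged-gradient step below. Because the accumulated error terms are
 * (t - h), adding the averaged gradient moves the weights downhill on the
 * squared error. Method and parameter names are illustrative only.
 */
static void stepSketch(Layer L, double rate, int count) {
    double[] w = L.W.getData();          // backing array of W
    double[] g = L.W_grad.getData();     // accumulated gradient
    for(int k = 0; k < w.length; ++k) {
        g[k] /= count;                   // CommonOps.scale(1.0 / D.data.size(), L.W_grad)
        w[k] += rate * g[k];             // CommonOps.addEquals(L.W, learning_rate, L.W_grad)
    }
}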
public int parameters() {
int i = 0;
for(Layer L: layers) {
i += L.W.getNumElements();
}
return i;
}
public ArrayList<double[]> weights() {
ArrayList<double[]> w = new ArrayList<double[]>();
for(Layer L: layers) {
w.add(L.W.getData());
}
return w;
}
public ArrayList<double[]> derivatives() {
ArrayList<double[]> d = new ArrayList<double[]>();
for(Layer L: layers) {
d.add(L.W_grad.getData());
}
return d;
}
public DenseMatrix64F propagateForward(DenseMatrix64F x) {
assert(x.getNumCols() == 1);
DenseMatrix64F z = x.copy();
CommonOps.transpose(z);
for(Layer L: layers) {
// Bias value to input vector [a1 a2 a3 ... 1.0]
CommonOps.insert(z, L.a_input, 0, 0);
L.a_input.set(L.I, 1.0);
// perform forward propagation
CommonOps.mult(L.a_input, L.W, L.a_output);
// apply activation function g(a)
z = logistics(L.a_output);
}
return z;
}
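/*
 * Sketch, not part of this commit: one forward step in isolation. Given a
 * 1 x (I+1) row [x1 ... xI 1] and an (I+1) x O weight matrix, the layer output
 * is the logistic of their product -- exactly what the loop above computes per
 * layer via L.a_input, L.W and L.a_output. The method name is hypothetical.
 */
static DenseMatrix64F forwardStepSketch(DenseMatrix64F rowWithBias, DenseMatrix64F W) {
    DenseMatrix64F out = new DenseMatrix64F(1, W.getNumCols());
    CommonOps.mult(rowWithBias, W, out);                      // pre-activations a = x * W
    for(int i = 0; i < out.getNumElements(); ++i)
        out.set(i, 1.0 / (1.0 + Math.exp(-out.get(i))));      // logistic g(a)
    return out;
}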
public void propagateBackward(Instance ins) {
DenseMatrix64F y = propagateForward(ins.x);
DenseMatrix64F v = new DenseMatrix64F(1, y.getNumElements());
Layer L_last = layers.get(layers.size() - 1);
L_last.error.reshape(1, L_last.O);
Instance.target(v, ins.y);
// Error(L_output) = (t(x) - h(x)) * g'(a_out)
CommonOps.add(v, -1.0, y, L_last.error);
CommonOps.elementMult(L_last.error, logistics_derivative(y.copy()));
for(int i = layers.size() - 1; i >= 0; --i) {
Layer L = layers.get(i);
// Grad(W) = Grad(W) + a(l-1)' * Error(L)
CommonOps.multAddTransA(L.a_input, L.error, L.W_grad);
if(i > 0) {
Layer L_prev = layers.get(i - 1);
// Error(L_k) = Error(L_k+1) * W' .* g'(a_k)
L_prev.error.reshape(1, L_prev.O + 1);
CommonOps.multTransB(L.error, L.W, L_prev.error);
CommonOps.elementMult(L_prev.error, logistics_derivative(L.a_input.copy()));
L_prev.error.reshape(1, L_prev.O);
}
}
}
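/*
 * Sketch, not part of this commit: a finite-difference check of a single
 * gradient entry against the value accumulated by propagateBackward(). It
 * assumes W.getData() exposes the live backing array, that Instance.target()
 * writes a one-hot target as above, and uses squared error as the loss; since
 * the accumulated error terms are (t - h), W_grad holds the negative loss
 * gradient. Method names are hypothetical.
 */
double gradientCheckEntry(Dataset D, int layer, int index, double eps) {
    for(Layer L: layers)
        L.W_grad.zero();
    for(Instance ins: D.data)
        propagateBackward(ins);                               // analytic (negative) gradient
    double analytic = layers.get(layer).W_grad.get(index);
    double[] w = layers.get(layer).W.getData();
    double old = w[index];
    w[index] = old + eps; double ep = squaredErrorSketch(D);
    w[index] = old - eps; double em = squaredErrorSketch(D);
    w[index] = old;                                           // restore the weight
    double numeric = -(ep - em) / (2.0 * eps);                // minus: W_grad stores -dE/dW
    return Math.abs(analytic - numeric);
}
double squaredErrorSketch(Dataset D) {
    double e = 0.0;
    for(Instance ins: D.data) {
        DenseMatrix64F h = propagateForward(ins.x);
        DenseMatrix64F t = new DenseMatrix64F(1, h.getNumElements());
        Instance.target(t, ins.y);
        for(int i = 0; i < h.getNumElements(); ++i) {
            double d = t.get(i) - h.get(i);
            e += 0.5 * d * d;
        }
    }
    return e;
}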
......@@ -92,6 +171,15 @@ public class NeuralNetwork implements Classifier {
return x;
}
/** f'(z) = f(z) * (1.0 - f(z)) = a * (1.0 - a) */
private DenseMatrix64F logistics_derivative(DenseMatrix64F a) {
for(int i = 0; i < a.getNumElements(); ++i) {
a.set(i, a.get(i) * (1.0 - a.get(i)));
}
return a;
}
private DenseMatrix64F softmax(DenseMatrix64F x) {
double max = CommonOps.elementMax(x);
......