Commit 0e1085fe authored by Chris Müller

Implement Back-propagation algorithm

parent 340333dc
 package de.ruuns;
+import java.util.ArrayList;
 import org.ejml.data.DenseMatrix64F;
 interface Classifier {
@@ -7,5 +8,11 @@ interface Classifier {
     public void train(Dataset D);
+    public ArrayList<double[]> weights();
+    public ArrayList<double[]> derivatives();
+    public int parameters();
     public void debug();
 }
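The three new Classifier methods expose each model's flattened weight and gradient buffers and its total parameter count. A minimal sketch of how a caller might use them, e.g. to log the gradient norm after a training step, follows; it is illustrative only, not part of the commit, and assumes a trained NeuralNetwork instance named mlp (NaiveBayes returns null here, hence the guard):

    // Illustrative use of the new accessors (not part of this commit).
    ArrayList<double[]> grads = mlp.derivatives();
    if(grads != null) {
        double norm = 0.0;
        for(double[] g : grads)
            for(double v : g)
                norm += v * v;
        System.out.printf("parameters=%d  |grad|=%.6f\n", mlp.parameters(), Math.sqrt(norm));
    }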
@@ -21,4 +21,5 @@ public class Evaluation {
         return errors / (double) D.size();
     }
 }
@@ -27,6 +27,9 @@ public class Main {
         bayes.train(D);
+        for(int i = 0; i < 10000; ++i)
+            mlp.train(D);
         bayes.debug();
         System.out.printf("MLP Accuracy %7.3f %%\n", 100 * (1.0 - Evaluation.error(mlp, D)));
...
@@ -109,5 +109,17 @@ public class NaiveBayes implements Classifier {
         System.out.flush();
     }
+    public ArrayList<double[]> weights() {
+        return null;
+    }
+    public ArrayList<double[]> derivatives() {
+        return null;
+    }
+    public int parameters() {
+        return 0;
+    }
     private ArrayList<ClassParameter> params = new ArrayList<ClassParameter>();
 }
@@ -10,17 +10,31 @@ class Layer {
     static Random RNG = new Random(System.currentTimeMillis());
     public Layer(int I, int O, double var) {
+        this.I = I;
+        this.O = O;
         this.W = new DenseMatrix64F(I + 1, O);
-        this.a = new DenseMatrix64F(O + 1);
-        this.a.reshape(1, O);
+        this.W_grad = new DenseMatrix64F(I + 1, O);
+        this.a_output = new DenseMatrix64F(1, O);
+        this.a_input = new DenseMatrix64F(1, I + 1);
+        this.error = new DenseMatrix64F(1, O + 1);
+        this.W_grad.zero();
         for(int i = 0; i < W.getNumElements(); ++i) {
             this.W.set(i, RNG.nextGaussian() * var);
         }
     };
+    public int I;
+    public int O;
     public DenseMatrix64F W;
-    public DenseMatrix64F a;
+    public DenseMatrix64F W_grad;
+    public DenseMatrix64F a_input;
+    public DenseMatrix64F a_output;
+    public DenseMatrix64F error;
 }
@@ -45,35 +59,100 @@ public class NeuralNetwork implements Classifier {
     }
     public void train(Dataset D) {
+        double learning_rate = 0.05;
+        for(Layer L: layers) {
+            L.W_grad.zero();
+        }
+        for(Instance ins: D.data)
+            propagateBackward(ins);
+        for(Layer L: layers) {
+            CommonOps.scale(1.0 / D.data.size(), L.W_grad);
+            CommonOps.addEquals(L.W, learning_rate, L.W_grad);
+        }
     }
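Each call to train(D) therefore performs one full-batch gradient step: the per-example gradients accumulated in W_grad by propagateBackward are averaged over the dataset and applied with a fixed learning rate. As a sketch of what the loop computes for every layer l (the positive sign is correct because the output error below uses t - h(x)):

    W_l \leftarrow W_l + \alpha \cdot \frac{1}{|D|} \sum_{(x,\,t) \in D} \Delta W_l, \qquad \alpha = 0.05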
-    public DenseMatrix64F propagateForward(DenseMatrix64F x_input) {
-        assert(x_input.getNumCols() == 1);
-        DenseMatrix64F z = x_input.copy();
+    public int parameters() {
+        int i = 0;
+        for(Layer L: layers) {
+            i += L.W.getNumElements();
+        }
+        return i;
+    }
+    public ArrayList<double[]> weights() {
+        ArrayList<double[]> w = new ArrayList<double[]>();
+        for(Layer L: layers) {
+            w.add(L.W.getData());
+        }
+        return w;
+    }
+    public ArrayList<double[]> derivatives() {
+        ArrayList<double[]> d = new ArrayList<double[]>();
+        for(Layer L: layers) {
+            d.add(L.W_grad.getData());
+        }
+        return d;
+    }
+    public DenseMatrix64F propagateForward(DenseMatrix64F x) {
+        assert(x.getNumCols() == 1);
+        DenseMatrix64F z = x.copy();
         CommonOps.transpose(z);
         for(Layer L: layers){
             // Bias value to input vector [a1 a2 a3 ... 1.0]
-            z.reshape(1, z.getNumElements() + 1, true);
-            z.set(z.getNumElements() - 1, 1.0);
+            CommonOps.insert(z, L.a_input, 0, 0);
+            L.a_input.set(L.I - 1, 1.0);
             // perform forward propagation
-            CommonOps.mult(z, L.W, L.a);
-            // remove bias value again
-            z.reshape(1, z.getNumElements() - 1, true);
+            CommonOps.mult(L.a_input, L.W, L.a_output);
             // apply activation function g(a)
-            z = logistics(L.a);
+            z = logistics(L.a_output);
         }
         return z;
     }
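The rewritten forward pass caches, per layer, the bias-extended input a_input and the pre-activation output a_output so that propagateBackward can reuse them. With g the logistic activation, each layer l computes:

    a_{\text{in}}^{(l)} = [\, z^{(l-1)},\; 1 \,], \qquad a_{\text{out}}^{(l)} = a_{\text{in}}^{(l)} \, W^{(l)}, \qquad z^{(l)} = g\bigl(a_{\text{out}}^{(l)}\bigr)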
-    public void optimizeGradientDescent(DenseMatrix64F X, DenseMatrix64F Y, double alpha) {
-        assert(Y.getNumCols() == 1);
+    public void propagateBackward(Instance ins) {
+        DenseMatrix64F y = propagateForward(ins.x);
+        DenseMatrix64F v = new DenseMatrix64F(1, y.getNumElements());
+        Layer L_last = layers.get(layers.size() - 1);
+        L_last.error.reshape(1, L_last.O);
+        Instance.target(v, ins.y);
+        // Error(L_output) = (h(x) - t(x)) * g'(a_j)
+        CommonOps.add(v, -1.0, y, L_last.error);
+        CommonOps.elementMult(L_last.error, logistics_derivative(y.copy()));
+        for(int i = layers.size() - 1; i >= 0; --i) {
+            Layer L = layers.get(i);
+            // Grad(W) = Grad(W) + a(l-1)' * Error(L)
+            CommonOps.multAddTransA(L.a_input, L.error, L.W_grad);
+            if(i > 0) {
+                Layer L_prev = layers.get(i - 1);
+                // Error(L_k) = W * Error(L_k+1) * g'(a_k)
+                L_prev.error.reshape(1, L_prev.O + 1);
+                CommonOps.multTransB(L.error, L.W, L_prev.error);
+                CommonOps.elementMult(L_prev.error, logistics_derivative(L.a_input.copy()));
+                L_prev.error.reshape(1, L_prev.O);
+            }
+        }
     }
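propagateBackward is the back-propagation step named in the commit message. With t the target vector filled by Instance.target (presumably a one-hot encoding of ins.y), h(x) the network output, and \odot element-wise multiplication, the code computes the standard recurrences below; the reshape calls simply drop the bias entry of each propagated error, and the sign convention t - h(x) pairs with the positive weight update in train():

    \delta^{(L)} = \bigl(t - h(x)\bigr) \odot g'\bigl(a_{\text{out}}^{(L)}\bigr)
    \delta^{(l-1)} = \bigl(\delta^{(l)} \, W^{(l)\top}\bigr) \odot g'\bigl(a_{\text{out}}^{(l-1)}\bigr)
    \Delta W^{(l)} \mathrel{+}= a_{\text{in}}^{(l)\top} \, \delta^{(l)}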
@@ -92,6 +171,15 @@ public class NeuralNetwork implements Classifier {
         return x;
     }
+    /** f(z)' = f(z) * (1.0 - f(z)) = a * (1.0 - a) */
+    private DenseMatrix64F logistics_derivative(DenseMatrix64F a) {
+        for(int i = 0; i < a.getNumElements(); ++i) {
+            a.set(i, a.get(i) * (1.0 - a.get(i)));
+        }
+        return a;
+    }
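The identity in the Javadoc is the usual logistic-derivative shortcut; since the forward pass already stores the activations a = g(z), the derivative needs no extra exponentials:

    g(z) = \frac{1}{1 + e^{-z}}, \qquad g'(z) = \frac{e^{-z}}{(1 + e^{-z})^{2}} = g(z)\bigl(1 - g(z)\bigr) = a\,(1 - a)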
     private DenseMatrix64F softmax(DenseMatrix64F x) {
         double max = CommonOps.elementMax(x);
...
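The softmax hunk is truncated here, but extracting the element-wise maximum first is the usual overflow guard: subtracting it inside the exponentials leaves the result unchanged while keeping every exponent non-positive. Presumably the omitted lines evaluate something equivalent to:

    \operatorname{softmax}(x)_i = \frac{e^{\,x_i - \max_j x_j}}{\sum_k e^{\,x_k - \max_j x_j}}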