Commit 340333dc authored by Chris Müller's avatar Chris Müller
Browse files

Add NaiveBayes Classifier.

parent 4acc0d7a
......@@ -6,4 +6,6 @@ interface Classifier {
public DenseMatrix64F classify(DenseMatrix64F input);
public void train(Dataset D);
public void debug();
}
package de.ruuns;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;
public class Instance {
public static final int NO_CLASS = -1;
......@@ -25,6 +26,11 @@ public class Instance {
v.set(y, 1.0);
}
public static void normalize(DenseMatrix64F v) {
double sum = CommonOps.elementSum(v);
CommonOps.divide(sum, v);
}
public static int argmax(DenseMatrix64F v) {
assert(v.getNumElements() > 0);
......
......@@ -7,6 +7,7 @@ import de.ruuns.Dataset;
import de.ruuns.Evaluation;
import de.ruuns.Instance;
import de.ruuns.NeuralNetwork;
import de.ruuns.NaiveBayes;
public class Main {
......@@ -22,8 +23,14 @@ public class Main {
Dataset D = Dataset.fromCSV(args[0], Dataset.SUPERVISED);
Classifier mlp = new NeuralNetwork(0.05, 4, HIDDEN_NODES, D.classes);
Classifier bayes = new NaiveBayes();
System.out.printf("Accuracy %f\n", Evaluation.error(mlp, D));
bayes.train(D);
bayes.debug();
System.out.printf("MLP Accuracy %7.3f %%\n", 100 * (1.0 - Evaluation.error(mlp, D)));
System.out.printf("NB Accuracy %7.3f %%\n", 100 * (1.0 - Evaluation.error(bayes, D)));
} catch(Exception e) {
......
package de.ruuns;
import de.ruuns.Classifier;
import org.ejml.data.DenseMatrix64F;
import de.ruuns.Instance;
import java.util.List;
import java.util.ArrayList;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;
public class NaiveBayes implements Classifier {
public NaiveBayes() {
class ClassParameter {
public ClassParameter(int features) {
mean = new DenseMatrix64F(features, 1);
variance = new DenseMatrix64F(features, 1);
samples = 0;
mean.zero();
variance.zero();
}
public DenseMatrix64F mean;
public DenseMatrix64F variance;
public int samples;
public double probability;
}
public NaiveBayes() {}
public DenseMatrix64F classify(DenseMatrix64F input) {
return null;
assert params.size() > 0;
DenseMatrix64F output = new DenseMatrix64F(params.size(), 1);
for(int c = 0; c < params.size(); ++c) {
double prop = 1.0;
for(int i = 0; i < input.getNumElements(); ++i) {
double mean = params.get(c).mean.get(i);
double var = params.get(c).variance.get(i);
double x = input.get(i);
double factor = 1.0 / Math.sqrt(2.0 * Math.PI * var);
double e = Math.exp(- (0.5 * (mean - x) * (mean - x)) / var);
prop *= (factor * e);
}
output.set(c, params.get(c).probability * prop);
}
Instance.normalize(output);
return output;
}
public void train(Dataset dataset) {
public void train(Dataset D) {
params.clear();
for(int i = 0; i < D.classes; ++i) {
params.add(new ClassParameter(D.features));
}
// calculate sum of all features
for(Instance in: D.data) {
CommonOps.addEquals(params.get(in.y).mean, in.x);
params.get(in.y).samples += 1;
}
// calculate mean for all classes
for(ClassParameter p: params) {
CommonOps.divide((double) p.samples, p.mean);
}
// calculate variance sum for all features
for(Instance in: D.data) {
DenseMatrix64F abs = new DenseMatrix64F(D.features, 1);
abs.zero();
CommonOps.sub(in.x, params.get(in.y).mean, abs);
CommonOps.elementMult(abs, abs);
CommonOps.addEquals(params.get(in.y).variance, abs);
}
// calculate variance for all classes
for(ClassParameter p: params) {
CommonOps.divide((double) p.samples, p.variance);
// calculate P(C)
p.probability = p.samples / (double) D.size();
}
}
public void debug() {
int class_id = 0;
for(ClassParameter p: params) {
DenseMatrix64F mean = p.mean.copy();
DenseMatrix64F var = p.variance.copy();
CommonOps.transpose(mean);
CommonOps.transpose(var);
System.out.printf("Naive Bayes Parameters for Class: %d (%d samples)\n", class_id, p.samples);
System.out.print(mean);
System.out.print(var);
System.out.printf("P(C: %d) = %f\n", class_id, p.probability);
System.out.println();
class_id += 1;
}
System.out.flush();
}
private DenseMatrix64F mean = null;
private DenseMatrix64F variance = null;
private ArrayList<ClassParameter> params = new ArrayList<ClassParameter>();
}
......@@ -104,6 +104,9 @@ public class NeuralNetwork implements Classifier {
return x;
}
public void debug() {
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment