Commit 4acc0d7a authored by Chris Müller's avatar Chris Müller
Browse files

Add infrastructure for datasets

parent 6f58e846
5.1,3.5,1.4,0.2,setosa
4.9,3.0,1.4,0.2,setosa
4.7,3.2,1.3,0.2,setosa
4.6,3.1,1.5,0.2,setosa
5.0,3.6,1.4,0.2,setosa
5.4,3.9,1.7,0.4,setosa
4.6,3.4,1.4,0.3,setosa
5.0,3.4,1.5,0.2,setosa
4.4,2.9,1.4,0.2,setosa
4.9,3.1,1.5,0.1,setosa
5.4,3.7,1.5,0.2,setosa
4.8,3.4,1.6,0.2,setosa
4.8,3.0,1.4,0.1,setosa
4.3,3.0,1.1,0.1,setosa
5.8,4.0,1.2,0.2,setosa
5.7,4.4,1.5,0.4,setosa
5.4,3.9,1.3,0.4,setosa
5.1,3.5,1.4,0.3,setosa
5.7,3.8,1.7,0.3,setosa
5.1,3.8,1.5,0.3,setosa
5.4,3.4,1.7,0.2,setosa
5.1,3.7,1.5,0.4,setosa
4.6,3.6,1.0,0.2,setosa
5.1,3.3,1.7,0.5,setosa
4.8,3.4,1.9,0.2,setosa
5.0,3.0,1.6,0.2,setosa
5.0,3.4,1.6,0.4,setosa
5.2,3.5,1.5,0.2,setosa
5.2,3.4,1.4,0.2,setosa
4.7,3.2,1.6,0.2,setosa
4.8,3.1,1.6,0.2,setosa
5.4,3.4,1.5,0.4,setosa
5.2,4.1,1.5,0.1,setosa
5.5,4.2,1.4,0.2,setosa
4.9,3.1,1.5,0.1,setosa
5.0,3.2,1.2,0.2,setosa
5.5,3.5,1.3,0.2,setosa
4.9,3.1,1.5,0.1,setosa
4.4,3.0,1.3,0.2,setosa
5.1,3.4,1.5,0.2,setosa
5.0,3.5,1.3,0.3,setosa
4.5,2.3,1.3,0.3,setosa
4.4,3.2,1.3,0.2,setosa
5.0,3.5,1.6,0.6,setosa
5.1,3.8,1.9,0.4,setosa
4.8,3.0,1.4,0.3,setosa
5.1,3.8,1.6,0.2,setosa
4.6,3.2,1.4,0.2,setosa
5.3,3.7,1.5,0.2,setosa
5.0,3.3,1.4,0.2,setosa
7.0,3.2,4.7,1.4,versicolor
6.4,3.2,4.5,1.5,versicolor
6.9,3.1,4.9,1.5,versicolor
5.5,2.3,4.0,1.3,versicolor
6.5,2.8,4.6,1.5,versicolor
5.7,2.8,4.5,1.3,versicolor
6.3,3.3,4.7,1.6,versicolor
4.9,2.4,3.3,1.0,versicolor
6.6,2.9,4.6,1.3,versicolor
5.2,2.7,3.9,1.4,versicolor
5.0,2.0,3.5,1.0,versicolor
5.9,3.0,4.2,1.5,versicolor
6.0,2.2,4.0,1.0,versicolor
6.1,2.9,4.7,1.4,versicolor
5.6,2.9,3.6,1.3,versicolor
6.7,3.1,4.4,1.4,versicolor
5.6,3.0,4.5,1.5,versicolor
5.8,2.7,4.1,1.0,versicolor
6.2,2.2,4.5,1.5,versicolor
5.6,2.5,3.9,1.1,versicolor
5.9,3.2,4.8,1.8,versicolor
6.1,2.8,4.0,1.3,versicolor
6.3,2.5,4.9,1.5,versicolor
6.1,2.8,4.7,1.2,versicolor
6.4,2.9,4.3,1.3,versicolor
6.6,3.0,4.4,1.4,versicolor
6.8,2.8,4.8,1.4,versicolor
6.7,3.0,5.0,1.7,versicolor
6.0,2.9,4.5,1.5,versicolor
5.7,2.6,3.5,1.0,versicolor
5.5,2.4,3.8,1.1,versicolor
5.5,2.4,3.7,1.0,versicolor
5.8,2.7,3.9,1.2,versicolor
6.0,2.7,5.1,1.6,versicolor
5.4,3.0,4.5,1.5,versicolor
6.0,3.4,4.5,1.6,versicolor
6.7,3.1,4.7,1.5,versicolor
6.3,2.3,4.4,1.3,versicolor
5.6,3.0,4.1,1.3,versicolor
5.5,2.5,4.0,1.3,versicolor
5.5,2.6,4.4,1.2,versicolor
6.1,3.0,4.6,1.4,versicolor
5.8,2.6,4.0,1.2,versicolor
5.0,2.3,3.3,1.0,versicolor
5.6,2.7,4.2,1.3,versicolor
5.7,3.0,4.2,1.2,versicolor
5.7,2.9,4.2,1.3,versicolor
6.2,2.9,4.3,1.3,versicolor
5.1,2.5,3.0,1.1,versicolor
5.7,2.8,4.1,1.3,versicolor
6.3,3.3,6.0,2.5,virginica
5.8,2.7,5.1,1.9,virginica
7.1,3.0,5.9,2.1,virginica
6.3,2.9,5.6,1.8,virginica
6.5,3.0,5.8,2.2,virginica
7.6,3.0,6.6,2.1,virginica
4.9,2.5,4.5,1.7,virginica
7.3,2.9,6.3,1.8,virginica
6.7,2.5,5.8,1.8,virginica
7.2,3.6,6.1,2.5,virginica
6.5,3.2,5.1,2.0,virginica
6.4,2.7,5.3,1.9,virginica
6.8,3.0,5.5,2.1,virginica
5.7,2.5,5.0,2.0,virginica
5.8,2.8,5.1,2.4,virginica
6.4,3.2,5.3,2.3,virginica
6.5,3.0,5.5,1.8,virginica
7.7,3.8,6.7,2.2,virginica
7.7,2.6,6.9,2.3,virginica
6.0,2.2,5.0,1.5,virginica
6.9,3.2,5.7,2.3,virginica
5.6,2.8,4.9,2.0,virginica
7.7,2.8,6.7,2.0,virginica
6.3,2.7,4.9,1.8,virginica
6.7,3.3,5.7,2.1,virginica
7.2,3.2,6.0,1.8,virginica
6.2,2.8,4.8,1.8,virginica
6.1,3.0,4.9,1.8,virginica
6.4,2.8,5.6,2.1,virginica
7.2,3.0,5.8,1.6,virginica
7.4,2.8,6.1,1.9,virginica
7.9,3.8,6.4,2.0,virginica
6.4,2.8,5.6,2.2,virginica
6.3,2.8,5.1,1.5,virginica
6.1,2.6,5.6,1.4,virginica
7.7,3.0,6.1,2.3,virginica
6.3,3.4,5.6,2.4,virginica
6.4,3.1,5.5,1.8,virginica
6.0,3.0,4.8,1.8,virginica
6.9,3.1,5.4,2.1,virginica
6.7,3.1,5.6,2.4,virginica
6.9,3.1,5.1,2.3,virginica
5.8,2.7,5.1,1.9,virginica
6.8,3.2,5.9,2.3,virginica
6.7,3.3,5.7,2.5,virginica
6.7,3.0,5.2,2.3,virginica
6.3,2.5,5.0,1.9,virginica
6.5,3.0,5.2,2.0,virginica
6.2,3.4,5.4,2.3,virginica
5.9,3.0,5.1,1.8,virginica
package de.ruuns;
import org.ejml.data.DenseMatrix64F;
/**
 * Common contract for classification models (e.g. NeuralNetwork, NaiveBayes).
 *
 * Implementations map a feature vector to per-class activation scores; the
 * predicted class is conventionally the argmax of the returned vector
 * (see Evaluation.error / Instance.argmax).
 */
interface Classifier {
    /**
     * Classifies a single feature vector.
     *
     * @param input column feature vector
     * @return per-class activation scores
     */
    // NOTE: interface members are implicitly public; the redundant modifier was removed.
    DenseMatrix64F classify(DenseMatrix64F input);

    /**
     * Fits the model to the given dataset.
     *
     * @param D training dataset (expected to be supervised, i.e. labeled)
     */
    void train(Dataset D);
}
package de.ruuns;
import org.ejml.data.DenseMatrix64F;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Map;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
/**
 * In-memory dataset: a list of Instances plus derived metadata
 * (feature-vector length and number of distinct class labels).
 */
public class Dataset {
    /** fromCSV mode: the last column of each row is the class label. */
    public static final int SUPERVISED = 1;
    /** fromCSV mode: every column is a feature; instances stay unlabeled. */
    public static final int UNSUPERVISED = 2;

    public List<Instance> data;  // all instances
    public int classes;          // number of distinct labels observed in data
    public int features;         // feature count, taken from the first instance

    /**
     * Wraps a non-empty instance list and derives feature/class counts.
     *
     * @param data instances; must contain at least one element
     * @throws IllegalArgumentException if data is empty (the original used a
     *         plain assert, which is disabled by default at runtime)
     */
    public Dataset(List<Instance> data) {
        if(data.isEmpty())
            throw new IllegalArgumentException("Dataset must contain at least one instance");
        // Renamed from "classes" to avoid shadowing the field of the same name.
        Set<Integer> labels = new TreeSet<Integer>();
        for(Instance in: data)
            labels.add(in.y);
        this.features = data.get(0).x.getNumElements();
        this.classes = labels.size();
        this.data = data;
    }

    /** @return the number of instances in this dataset */
    public int size() {
        return data.size();
    }

    /**
     * Loads a dataset from a comma/whitespace-separated text file.
     *
     * Numeric fields are parsed directly; non-numeric fields are mapped to a
     * stable per-column index (first distinct symbol seen becomes 0.0, the
     * next 1.0, and so on). In SUPERVISED mode the last column is split off
     * as the integer class label.
     *
     * @param file path to the CSV file
     * @param type SUPERVISED or UNSUPERVISED
     * @return the parsed dataset
     * @throws Exception on I/O failure or if the file contains no rows
     */
    public static Dataset fromCSV(String file, int type) throws Exception {
        BufferedReader csv = new BufferedReader(new FileReader(file));
        List<Instance> data = new ArrayList<Instance>();
        // One symbol->index table per column, grown lazily so the file is read
        // in a single pass (the original read the first line only to count
        // columns, closed the reader, and reopened the file).
        List<Map<String, Double>> groups = new ArrayList<Map<String, Double>>();
        try {
            String entry;
            while((entry = csv.readLine()) != null) {
                entry = entry.trim();
                if(entry.isEmpty())
                    continue; // skip blank lines instead of producing bogus instances
                String[] V = entry.split(",|\\s");
                while(groups.size() < V.length)
                    groups.add(new TreeMap<String, Double>());
                DenseMatrix64F x = new DenseMatrix64F(V.length, 1);
                for(int i = 0; i < V.length; ++i) {
                    String attribute = V[i].trim();
                    try {
                        x.set(i, Double.parseDouble(attribute));
                    } catch(NumberFormatException e) {
                        // Symbolic attribute: assign the next free index on first sight.
                        Map<String, Double> symbols = groups.get(i);
                        if(!symbols.containsKey(attribute))
                            symbols.put(attribute, (double) symbols.size());
                        x.set(i, symbols.get(attribute));
                    }
                }
                if(type == SUPERVISED) {
                    // Last column is the (already symbol-mapped) class label;
                    // shrink the vector to drop it from the features.
                    int target = (int) x.get(x.getNumElements() - 1);
                    x.reshape(x.getNumElements() - 1, 1);
                    data.add(new Instance(x, target));
                } else {
                    data.add(new Instance(x));
                }
            }
        } finally {
            csv.close();
        }
        return new Dataset(data);
    }
}
package de.ruuns;
import java.util.List;
import org.ejml.data.DenseMatrix64F;
import de.ruuns.Classifier;
import de.ruuns.Loader;
/** Model-evaluation helpers. */
public class Evaluation {
    /**
     * Computes the misclassification rate of classifier C on dataset D:
     * the fraction of instances whose argmax-activation does not match the
     * true label. Returns a value in [0, 1] (NaN if D is empty).
     *
     * @param C classifier to evaluate
     * @param D labeled dataset
     * @return error rate in [0, 1]
     */
    public static double error(Classifier C, Dataset D) {
        // Removed the unused DenseMatrix64F local "y" the original allocated.
        int errors = 0;
        for(Instance p: D.data) {
            DenseMatrix64F h = C.classify(p.x);
            // Predicted class = index of the strongest activation.
            if(Instance.argmax(h) != p.y)
                errors++;
        }
        return errors / (double) D.size();
    }
}
package de.ruuns;
import org.ejml.data.DenseMatrix64F;
/**
 * A single data point: a feature vector plus an optional integer class label.
 */
public class Instance {
    /** Sentinel label for unlabeled (unsupervised) instances. */
    public static final int NO_CLASS = -1;

    public DenseMatrix64F x; // feature vector
    public int y;            // class label, or NO_CLASS if unlabeled

    /** Creates a labeled instance. */
    public Instance(DenseMatrix64F x, int y) {
        this.x = x;
        this.y = y;
    }

    /** Creates an unlabeled instance (label set to NO_CLASS). */
    public Instance(DenseMatrix64F x) {
        this(x, NO_CLASS);
    }

    /**
     * Writes a one-hot target vector into v: all zeros except position y.
     *
     * @param v vector to overwrite; must have more than y elements
     * @param y the hot index; must not be NO_CLASS
     */
    public static void target(DenseMatrix64F v, int y) {
        assert(y != NO_CLASS && y < v.getNumElements());
        v.zero();
        v.set(y, 1.0);
    }

    /**
     * Returns the index of the largest element of v.
     * Ties resolve to the earliest index, matching the original behavior.
     *
     * @param v non-empty vector
     * @return index of the maximum element
     */
    public static int argmax(DenseMatrix64F v) {
        assert(v.getNumElements() > 0);
        int best = 0;
        for(int i = 1; i < v.getNumElements(); ++i) {
            if(v.get(i) > v.get(best))
                best = i;
        }
        return best;
    }
}
package de.ruuns;
import java.util.List;
import org.ejml.data.DenseMatrix64F;
import de.ruuns.Dataset;
import de.ruuns.Evaluation;
import de.ruuns.Instance;
import de.ruuns.NeuralNetwork;
public class Main {
public static final int HIDDEN_NODES = 8;
public final static void main(String[] args) {
if(args.length < 1) {
System.err.println("No Dataset chosen");
System.exit(1);
}
try {
Dataset D = Dataset.fromCSV(args[0], Dataset.SUPERVISED);
Classifier mlp = new NeuralNetwork(0.05, 4, HIDDEN_NODES, D.classes);
System.out.printf("Accuracy %f\n", Evaluation.error(mlp, D));
} catch(Exception e) {
e.printStackTrace(System.err);
}
}
}
package de.ruuns;
import de.ruuns.Classifier;
import org.ejml.data.DenseMatrix64F;
import java.util.List;
/**
 * Naive Bayes classifier — currently an unimplemented stub: classify()
 * returns null and train() is a no-op. The mean/variance fields suggest a
 * planned Gaussian model, but no estimation code exists yet.
 */
public class NaiveBayes implements Classifier {
    // Per-feature statistics; never assigned by the current code.
    private DenseMatrix64F mean = null;
    private DenseMatrix64F variance = null;

    public NaiveBayes() {
        // Nothing to initialize yet.
    }

    /** Stub: always returns null until the model is implemented. */
    public DenseMatrix64F classify(DenseMatrix64F input) {
        return null;
    }

    /** Stub: training is not implemented yet. */
    public void train(Dataset dataset) {
    }
}
package de.ruuns;
import de.ruuns.Classifier;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;
import java.util.ArrayList;
import java.util.Random;
class Layer {
static Random RNG = new Random();
static Random RNG = new Random(System.currentTimeMillis());
public Layer(int I, int O, double var) {
W = new DenseMatrix64F(I + 1, O);
a = new DenseMatrix64F(O + 1);
this.W = new DenseMatrix64F(I + 1, O);
this.a = new DenseMatrix64F(O + 1);
this.a.reshape(1, O);
for(int i = 0; i < W.getNumElements(); ++i)
W.set(i, RNG.nextGaussian() * var);
for(int i = 0; i < W.getNumElements(); ++i) {
this.W.set(i, RNG.nextGaussian() * var);
}
};
public DenseMatrix64F W;
......@@ -22,23 +25,32 @@ class Layer {
public class NeuralNetwork {
public ArrayList<Layer> layers;
public class NeuralNetwork implements Classifier {
public ArrayList<Layer> layers = new ArrayList<Layer>();
public NeuralNetwork(double init_variance, int ... dims) {
assert(dims.length > 0);
int I = dims[0];
for(int i = 1; i < dims.length; ++i)
for(int i = 1; i < dims.length; ++i) {
layers.add(new Layer(I, dims[i], init_variance));
I = dims[i];
}
}
// Classifier interface: forward-propagate the input and return the
// activation vector of the output layer.
public DenseMatrix64F classify(DenseMatrix64F Input) {
return propagateForward(Input);
}
// Classifier interface: training is not implemented yet (stub).
// NOTE(review): backpropagation presumably belongs here or delegates to
// optimizeGradientDescent — confirm against the full file.
public void train(Dataset D) {
}
public DenseMatrix64F propagateForward(DenseMatrix64F X_input) {
DenseMatrix64F z = X_input.copy();
public DenseMatrix64F propagateForward(DenseMatrix64F x_input) {
assert(x_input.getNumCols() == 1);
DenseMatrix64F z = x_input.copy();
CommonOps.transpose(z);
for(Layer L: layers){
......@@ -47,24 +59,50 @@ public class NeuralNetwork {
z.set(z.getNumElements() - 1, 1.0);
// perform forward propagation
CommonOps.multTransA(z, L.W, L.a);
CommonOps.mult(z, L.W, L.a);
// remove bias value again
z.reshape(1, z.getNumElements() - 1, true);
// apply activation function g(a)
z = sigmoid(L.a);
z = logistics(L.a);
}
return z;
}
private DenseMatrix64F sigmoid(DenseMatrix64F X) {
for(int i = 0; i < X.getNumElements(); ++i) {
X.set(i, 1.0 / (1.0 + Math.exp(- X.get(i))));
public void optimizeGradientDescent(DenseMatrix64F X, DenseMatrix64F Y, double alpha) {
assert(Y.getNumCols() == 1);
}
private DenseMatrix64F logistics(DenseMatrix64F x) {
for(int i = 0; i < x.getNumElements(); ++i) {
double value = x.get(i);
if(value < -45.0)
x.set(i, 0.0);
else if(value > 45.0)
x.set(i, 1.0);
else
x.set(i, 1.0 / (1.0 + Math.exp(- x.get(i))));
}
return X;
return x;
}
// In-place numerically stable softmax over the elements of x.
private DenseMatrix64F softmax(DenseMatrix64F x) {
// Subtract the maximum before exponentiating so exp() cannot overflow.
double max = CommonOps.elementMax(x);
for(int i = 0; i < x.getNumElements(); ++i) {
x.set(i, Math.exp(x.get(i) - max));
}
// Normalize so the elements sum to 1.
// NOTE(review): verify this CommonOps.divide(alpha, a) overload computes
// a_i = a_i / alpha (element divided by the sum) and not alpha / a_i —
// the two overload semantics differ between EJML versions.
CommonOps.divide(CommonOps.elementSum(x), x);
return x;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment