//package einzeln;

// Enthaelt Testdaten fuer ein neuronales Netz. Definition von Trainings- und Testbeispielen

import java.io.File;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import java.io.RandomAccessFile;
import java.lang.reflect.Method;
import java.util.Properties;

public class NetzDaten {
	static double[][] TRAININGSBEISPIELE = null;
	static double[][] TESTBEISPIELE = null;
	static int[] KLASSEN = null;
	static int[] KLASSEN2 = null;
	static int GROESSE_EINGABESCHICHT = -1;
	static int GROESSE_AUSGABESCHICHT = -1;
	public Properties p;

	public NetzDaten() throws Exception {
		p = new Properties();
		p.load(NetzDaten.class.getResourceAsStream("netzdaten.properties"));
		String funktion = p.getProperty("funktion");
		Method m[] = this.getClass().getDeclaredMethods();
		for (Method lauf : m) {
			if (lauf.getName().equals(funktion)) {
				lauf.invoke(this);
				break;
			}
		}
		if (TRAININGSBEISPIELE.length != (KLASSEN == null ? GROESSE_AUSGABESCHICHT
				: KLASSEN.length))
			throw new RuntimeException(
					"Anzahl der Testmuster nicht gleich Größe der Ausgabeschicht");
	}

	public static void main(String[] args) throws Exception {
		NetzDaten nd = new NetzDaten(); // Daten wählen

		Netz n = new Netz(GROESSE_EINGABESCHICHT, GROESSE_AUSGABESCHICHT,
				TRAININGSBEISPIELE, TESTBEISPIELE, KLASSEN);

		// n.mitKonvergenz = 3; // Konvergenz gegen best. Muster
		n.ANZAHL_RUNDEN = Integer.parseInt(nd.p.getProperty("ANZAHL_RUNDEN"));
		n.LERNRATE = Float.parseFloat(nd.p.getProperty("LERNRATE"));
		n.mitGradient = "true".equals(nd.p.getProperty("mitGradient"));
		if (nd.p.getProperty("anzahlThreads") != null) {
			n.anzThreads = Integer.parseInt(nd.p.getProperty("anzahlThreads"));
		}

		if ("backpropagation".equals(nd.p.getProperty("modell"))) {
			// Eine Zwischenschicht, also 3 Schichten insgesamt
			int groesseLernkapazitaet = Math.max(GROESSE_EINGABESCHICHT,
					GROESSE_AUSGABESCHICHT);
			n.initialisieren(GROESSE_EINGABESCHICHT, groesseLernkapazitaet,
					GROESSE_AUSGABESCHICHT);
			n.trainierenBackpropagation();
		} else {
			n.initialisieren(GROESSE_EINGABESCHICHT, GROESSE_AUSGABESCHICHT);
			n.trainierenHebb();
		}

		String fu = nd.p.getProperty("funktion");
		// Trainingsmuster und rechnen
		System.out.println("\nAbschluss Trainingsphase " + fu + " :");
		for (int i = 0; i < TRAININGSBEISPIELE.length; ++i) {
			int erg = n.rateMuster(TRAININGSBEISPIELE[i]);
			int soll = KLASSEN != null ? KLASSEN[i] : i;
			System.out.println("Trainingsmuster #" + i + " ist " + erg
					+ ", erwartet " + soll);
		}
		// Anderes Testmuster und rechnen
		System.out.println("\nDurchführen Testphase " + fu + " :");
		for (int i = 0; i < TESTBEISPIELE.length; ++i) {
			int erg = n.rateMuster(TESTBEISPIELE[i]);
			int soll = KLASSEN2 != null ? KLASSEN2[i] : i;
			System.out.print("Testmuster #" + i + " ist am besten " + erg);
			System.out
					.println(soll != erg ? (", Differenz " + erg + " und " + soll)
							: ", Treffer");
		}
		// Threads ggf. beenden
		n.beenden();

		// Das ganze Netz per ObjectOutputStream ausgeben
		if ("true".equals(nd.p.getProperty("dumpModell"))) {
			ObjectOutputStream oos = new ObjectOutputStream(
					new FileOutputStream("netz.bin"));
			oos.writeObject(n);
			oos.close();
			System.out.println("... netz.bin geschrieben.");
		}
	}

	// Erkennung von Schwarz-/Weissbildern
	private void bilderNehmen() {
		GROESSE_AUSGABESCHICHT = 4;
		GROESSE_EINGABESCHICHT = 100;
		TRAININGSBEISPIELE = new double[GROESSE_AUSGABESCHICHT][GROESSE_EINGABESCHICHT];
		TRAININGSBEISPIELE[0] = new double[] { 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
				0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
				0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
				1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
				1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
				1, 0, 0, 0, 0 }; // Kreuz
		TRAININGSBEISPIELE[1] = new double[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1,
				1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1,
				1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				0, 0, 0, 0, 0 }; // Lachendes Gesicht

		TRAININGSBEISPIELE[2] = new double[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1,
				1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
				1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
				0, 0, 0, 0, 0 }; // Weinendes Gesicht

		TRAININGSBEISPIELE[3] = new double[] { 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
				0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
				1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0,
				0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
				1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1,
				1, 0, 0, 0, 0 }; // Vier

		TESTBEISPIELE = new double[1][GROESSE_EINGABESCHICHT];
		TESTBEISPIELE[0] = new double[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1,
				0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
				1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1,
				1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
				0, 0, 0, 0 }; // Anders lachendes Gesicht

	}

	private int hoehe = 5;
	private int breite = 5;
	private double a = 1.0, b = 0.0; // Parameter Trenngerade
	private final boolean oneCode = false; // false deutl. schlechtere
											// Konvergenz, aber weniger
											// overfitting
											// !

	// Trenngerade
	private double trennFunktion(double s) {
		return a * s + b;
	}

	// Zuweisung Index zu 1(über Gerade) oder 0(unter Gerade)
	private int trennwert(int index) {
		double z = hoehe - (index / breite) - 0.5;
		double s = index % breite + 0.5;
		double wert = trennFunktion(s);
		return (z >= wert) ? 1 : 0;
	}

	// Eingabefeld zu Index, wo 1 gesetzt wird. Index 0 beginnt visuell oben
	// links
	private double[] eingabefeldIndex(int index) {
		int anzahl = oneCode ? hoehe * breite
				: (int) (Math.log(hoehe * breite) / Math.log(2)) + 1;
		double feld[] = new double[anzahl];
		if (oneCode)
			feld[index] = 1;
		else {
			int i = 0;
			while (index > 0) {
				feld[i] = index % 2;
				index /= 2;
				i++;
			}
		}
		return feld;
	}

	// Erkennung von Punkten in einem rechteckigen Feld, welche durch eine
	// Funktion z.B. eine Gerade
	// getrennt werden
	// |0|1|2|3|4|
	// |5|6|7|8|9|
	// |10|...|14|
	// |15|...|19|
	// |20|...|24|
	private void punkteNehmen() {
		GROESSE_AUSGABESCHICHT = 2;
		// Hinweis: folgendes wirkt sich schlecht auf Konvergenz aus(binär
		// verpackt);
		GROESSE_EINGABESCHICHT = !oneCode ? ((int) (Math.log(hoehe * breite) / Math
				.log(2)) + 1) : hoehe * breite;

		// verringern
		TRAININGSBEISPIELE = new double[17][GROESSE_EINGABESCHICHT];
		KLASSEN = new int[17];

		// über und unter der Trenngerade
		int indizes[] = { 1, 2, 3, 5, 7, 10, 11, 15, 9, 13, 14, 17, 19, 21, 22,
				23, 24 };
		for (int i = 0; i < indizes.length; ++i) {
			TRAININGSBEISPIELE[i] = eingabefeldIndex(indizes[i]);
			KLASSEN[i] = trennwert(indizes[i]);
		}

		TESTBEISPIELE = new double[2][GROESSE_EINGABESCHICHT];
		KLASSEN2 = new int[2];
		TESTBEISPIELE[0] = eingabefeldIndex(18); // => deutlich drunter,
													// Klasse
													// 0
		KLASSEN2[0] = 0;
		TESTBEISPIELE[1] = eingabefeldIndex(6); // => deutlich drüber, Klasse 1
		KLASSEN2[1] = 1;
	}

	// Erkennung von Quadratzahlen wie 1,4,9 etc.
	private void beispieleQuadratNehmen() {
		// Erkennung von Quadratzahlen
		GROESSE_AUSGABESCHICHT = 6; // 15;
		GROESSE_EINGABESCHICHT = 4;
		TRAININGSBEISPIELE = new double[GROESSE_AUSGABESCHICHT][GROESSE_EINGABESCHICHT];

		// einfacher Test fuer 2 Quadrate
		// TRAININGSBEISPIELE[0] = new double[] { 1, 0, 0 }; // 1
		// TRAININGSBEISPIELE[1] = new double[] { 0, 0, 1 }; // 4

		// Klassifizierung
		KLASSEN = new int[GROESSE_AUSGABESCHICHT];
		TRAININGSBEISPIELE[0] = new double[] { 1, 0, 0, 0 };
		KLASSEN[0] = 1;// 1
		TRAININGSBEISPIELE[1] = new double[] { 0, 1, 0, 0 };
		KLASSEN[1] = 0;// 2
		TRAININGSBEISPIELE[2] = new double[] { 1, 1, 0, 0 };
		KLASSEN[2] = 0;// 3
		TRAININGSBEISPIELE[3] = new double[] { 0, 0, 1, 0 };
		KLASSEN[3] = 1;// 4
		TRAININGSBEISPIELE[4] = new double[] { 1, 0, 1, 0 };
		KLASSEN[4] = 0;// 5
		TRAININGSBEISPIELE[5] = new double[] { 0, 1, 1, 0 };
		KLASSEN[5] = 0;// 6

		// TRAININGSBEISPIELE[6] = new double[] { 1, 1, 1, 0 };
		// KLASSEN[6] = 0;// 7
		// TRAININGSBEISPIELE[7] = new double[] { 0, 0, 0, 1 };
		// KLASSEN[7] = 0;// 8
		// TRAININGSBEISPIELE[8] = new double[] { 1, 0, 0, 1 };
		// KLASSEN[8] = 1;// 9
		// TRAININGSBEISPIELE[9] = new double[] { 0, 1, 0, 1 };
		// KLASSEN[9] = 0;// 10
		// TRAININGSBEISPIELE[10] = new double[] { 1, 1, 0, 1 };
		// KLASSEN[10] = 0;// 11
		// TRAININGSBEISPIELE[11] = new double[] { 0, 0, 1, 1 };
		// KLASSEN[11] = 0;// 12
		// TRAININGSBEISPIELE[12] = new double[] { 1, 0, 1, 1 };
		// KLASSEN[12] = 0;// 13
		// TRAININGSBEISPIELE[13] = new double[] { 0, 1, 1, 1 };
		// KLASSEN[13] = 0;// 14
		// TRAININGSBEISPIELE[14] = new double[] { 1, 1, 1, 1 };
		// KLASSEN[14] = 0;// 15

		GROESSE_AUSGABESCHICHT = 2; // da Quadrat oder Nichtquadrat
		// KLASSEN = null; // doch nicht !

		// TRAININGSBEISPIELE[0] = new double[] { 1, 0, 0, 0, 0, 0 }; // 1
		// TRAININGSBEISPIELE[1] = new double[] { 0, 0, 1, 0, 0, 0 }; // 4
		// TRAININGSBEISPIELE[2] = new double[] { 1, 0, 0, 1, 0, 0 }; // 9
		// TRAININGSBEISPIELE[3] = new double[] { 0, 0, 0, 0, 1, 0 }; // 16
		// TRAININGSBEISPIELE[4] = new double[] { 0, 0, 1, 0, 0, 1 }; // 36
		// TRAININGSBEISPIELE[5] = new double[] { 1, 0, 0, 0, 1, 1 }; // 49

		TESTBEISPIELE = new double[1][GROESSE_EINGABESCHICHT];
		// TESTBEISPIELE[0] = new double[] { 1, 0, 0, 1, 1, 0 }; // 25
		TESTBEISPIELE[0] = new double[] { 0, 0, 1, 0 }; // 4
		KLASSEN2 = new int[1];
		KLASSEN2[0] = 1;
	}

	// Nimmt die Trainings- und Testdaten von mnist/Tariq Rashid
	private void mnistNehmen() throws Exception {
		// siehe
		// https://raw.githubusercontent.com/makeyourownneuralnetwork/makeyourownneuralnetwork/
		// master/mnist_dataset/mnist_train_100.csv
		// siehe
		// https://raw.githubusercontent.com/makeyourownneuralnetwork/makeyourownneuralnetwork/
		// master/mnist_dataset/mnist_test_10.csv
		// In Jeder Zeile steht zuerst die Zahl(0-9) zur Klassif., dann die
		// Farbwerte(0-255),
		// welche auf 0.0-1.0 normiert werden
		final String pfad = "/home/jurgen/workspace/Zahlentheorie/src/";
		GROESSE_AUSGABESCHICHT = 100;
		GROESSE_EINGABESCHICHT = 784; // 28 * 28
		TRAININGSBEISPIELE = new double[GROESSE_AUSGABESCHICHT][GROESSE_EINGABESCHICHT];
		KLASSEN = new int[GROESSE_AUSGABESCHICHT];
		File f = new File(pfad + "mnist_train_100.csv");
		RandomAccessFile raf = new RandomAccessFile(f, "r");
		String zeile;
		int nr = 0;
		while ((zeile = raf.readLine()) != null) {
			String werte[] = zeile.split(",");
			if (werte.length != GROESSE_EINGABESCHICHT + 1)
				throw new Exception("Zeilenformat passt nicht");
			for (int i = 1; i < GROESSE_EINGABESCHICHT + 1; ++i) {
				TRAININGSBEISPIELE[nr][i - 1] = 0.01 + Integer
						.parseInt(werte[i]) / 255.0;
			}
			KLASSEN[nr] = Integer.parseInt(werte[0]);
			++nr;
		}
		if (nr != GROESSE_AUSGABESCHICHT)
			throw new Exception("Nicht genügend Datensätze");
		raf.close();

		TESTBEISPIELE = new double[10][GROESSE_EINGABESCHICHT];
		KLASSEN2 = new int[10];
		f = new File(pfad + "mnist_test_10.csv");
		raf = new RandomAccessFile(f, "r");
		nr = 0;
		while ((zeile = raf.readLine()) != null) {
			String werte[] = zeile.split(",");
			if (werte.length != GROESSE_EINGABESCHICHT + 1)
				throw new Exception("Zeilenformat passt nicht");
			for (int i = 1; i < GROESSE_EINGABESCHICHT + 1; ++i) {
				TESTBEISPIELE[nr][i - 1] = 0.01 + Integer.parseInt(werte[i]) / 255.0;
			}
			KLASSEN2[nr] = Integer.parseInt(werte[0]);
			++nr;
		}
		if (nr != 10)
			throw new Exception("Nicht genügend Datensätze");
		raf.close();
	}
}
