//package einzeln;

// Ein neuronales Netz zum Erkennen von Bildern
// Autor: J. Gamenik, Dez. 2016
// Vorlage: Markus von Rimscha, "Algorithmen kompakt und verständlich", Springer Verlag, 3. Auflage 2008, Kapitel 5.3
// Mai 2018: Hebb'sche Regel bei fehlenden Zwischenschichten
// Oktober 2018: sigmoid, relative Gewichtsaufteilung

import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.*;

// Ein Netz hat eine Eingangsschicht, ggf. noch Zwischenschichten und eine Ausgangsschicht. Letztere
// gibt einen Wert zwischen 0.0 und 1.0(=Treffer) zurück, ob das Eingangsmuster erkannt wurde

// Hebb'sche Regel für Änderung der Gewichte:
// g(i) := g(i) + L * e(i) * a, L = 0.1 die Lernrate; Das Gewicht wächst also um so mehr, je größer
// sowohl der Eingangs- als auch Ausgangswert ist.

public class Netz implements Serializable {
	double[][] TRAININGSBEISPIELE;
	double[][] TESTBEISPIELE;
	int[] KLASSEN;
	int GROESSE_EINGABESCHICHT;
	int GROESSE_AUSGABESCHICHT;

	static final double EPS = 0.001;
	final static boolean DEBUGZSCHICHT = false; // Zeigt Veränderung der
												// Gewichte an
	int ANZAHL_RUNDEN = 1000;
	int mitKonvergenz = -1; // <0 -> keine Konvergenz gg.
							// ein Muster
	static Aktivierung mitFunktion = Aktivierung.IDENT;
	boolean mitGradient = true;
	double LERNRATE = 0.3;
	int anzThreads = 1;

	enum Aktivierung {
		IDENT, SIGMOID, TANH
	};

	private ArrayList<Schicht> schichten;
	private Random rand = new Random(System.currentTimeMillis());
	public static DecimalFormat format = new DecimalFormat("#0.000");

	// Konstruktor, welcher Eingabedaten uebernimmt
	public Netz(int groesseEingabeschicht, int groesseAusgabeschicht,
			double[][] trainingsbeispiele, double[][] testbeispiele,
			int[] klassen) {
		TRAININGSBEISPIELE = trainingsbeispiele;
		TESTBEISPIELE = testbeispiele;
		KLASSEN = klassen;
		GROESSE_EINGABESCHICHT = groesseEingabeschicht;
		GROESSE_AUSGABESCHICHT = groesseAusgabeschicht;

		if (TRAININGSBEISPIELE.length != (KLASSEN == null ? GROESSE_AUSGABESCHICHT
				: KLASSEN.length))
			throw new RuntimeException(
					"Anzahl der Testmuster nicht gleich Größe der Ausgabeschicht");
	}

	// Legt Schichten an mit n Neuronen an
	public void initialisieren(int... n) {
		schichten = new ArrayList<Schicht>(n.length);
		for (int i = 0; i < n.length; ++i) {
			if (n[i] < 2)
				throw new RuntimeException(
						"Schicht aus einem Neuron gefaehrdet die Konvergenz");
			schichten.add(new Schicht(n[i], i == 0 ? 0 : n[i - 1]));
		}
	}

	// Setzt die Gewichte ausser der Eingangsschicht
	public void gewichteZufaelligBelegen() {
		Set<Integer> leereEingaenge = pruefeTestdatenSignale(); // Eingabeknoten
																// ohne Signale
		for (int s = 1; s < schichten.size(); ++s) {
			for (int n = 0; n < schichten.get(s).anzNeuronen; ++n) {
				Neuron neuron = schichten.get(s).neuronen.get(n);
				double faktor = 1.0;
				// nach dem Buch von Tariq Rashid zwischen -1/wurzel(anzahl
				// gewichte) und +1/wurzel(anzahl gewichte)
				// faktor = Math.sqrt(neuron.gewicht.length - (s ==
				// 1 ? leereEingaenge.size() : 0));
				for (int v = 0; v < neuron.gewicht.length; ++v) {

					neuron.gewicht[v] = (s == 1 && leereEingaenge.contains(v)) ? 0.0
							: rand.nextDouble() * faktor;
				}
			}
		}
	}

	// Ermittelt Eingänge, wo bei allen Trainingsdaten es nur Eingänge mit 0
	// gibt
	private Set<Integer> pruefeTestdatenSignale() {
		Set<Integer> leere = new HashSet<Integer>();
		for (int i = 0; i < TRAININGSBEISPIELE[0].length; ++i) {
			boolean gesetzt = false;
			for (int j = 0; j < TRAININGSBEISPIELE.length && !gesetzt; ++j) {
				gesetzt = (TRAININGSBEISPIELE[j][i] != 0.0);
			}
			if (!gesetzt)
				leere.add(i);
		}
		return leere;
	}

	public double[] gibTrainingsBeispiel(int n) {
		if (TRAININGSBEISPIELE[n].length != GROESSE_EINGABESCHICHT)
			throw new RuntimeException(
					"Trainingsbeispiel passt nicht zur Größe der Eingabeschicht");
		double[] beispiel = new double[GROESSE_EINGABESCHICHT];
		for (int j = 0; j < GROESSE_EINGABESCHICHT; ++j)
			beispiel[j] = TRAININGSBEISPIELE[n][j];

		return beispiel;
	}

	public void belegeEingangsSchichtMitBeispiel(double[] beispiel) {
		for (int n = 0; n < GROESSE_EINGABESCHICHT; ++n) {
			Neuron eingangsNeuron = this.schichten.get(0).neuronen.get(n);
			eingangsNeuron.eingang[0] = beispiel[n];
			eingangsNeuron.berechneAusgang();
		}
	}

	private List<RechnenAusgabe> threads = new ArrayList();
	private Vector<Daten> pufferE = new Vector<Daten>();
	private Vector<Daten> pufferA = new Vector<Daten>();

	// Dies ist nach jvisualm die Methode mit der längsten Laufzeit. Deshalb
	// parallelisieren
	public void berechneKomplettesNetz() {
		while (threads.size() < anzThreads) {
			RechnenAusgabe t1 = new RechnenAusgabe(pufferE, pufferA,
					Thread.currentThread());
			t1.start();
			threads.add(t1);
		}

		for (int s = 0; s < this.schichten.size(); ++s) {
			Schicht schicht = this.schichten.get(s);
			List<Runnable> aufgaben = new ArrayList<Runnable>();
			int pMax = (anzThreads < 1) ? 1
					: (anzThreads > schicht.anzNeuronen ? schicht.anzNeuronen
							: anzThreads);
			pufferA.clear();
			pufferE.clear();
			for (int i = 0; i < pMax; ++i) {
				int von = (i == 0) ? 0 : (schicht.anzNeuronen / pMax);
				int bis = (i == pMax - 1) ? schicht.anzNeuronen - 1
						: (schicht.anzNeuronen / pMax) - 1;
				Daten d1 = new Daten();
				d1.von = von;
				d1.bis = bis;
				d1.schicht = schicht;
				d1.vorgaenger = (s == 0) ? null : this.schichten.get(s - 1);
				pufferE.add(d1);
			}
			benachrichtigen();
			// wenn in Ausgabe gleiche Anzahl wie in Eingabe, erkennt man, dass
			// beide Threads
			// ihre Aufgabe erledigt haben ?
			while (pufferE.size() > pufferA.size()) {
				try {
					Thread.sleep(100);
				} catch (InterruptedException ex) {
				}
			}
		}
	}

	private void benachrichtigen() {
		for (int i = 0; i < threads.size(); ++i) {
			((Thread) threads.get(i)).interrupt();
		}
	}

	public void beenden() {
		for (int i = 0; i < threads.size(); ++i) {
			((RechnenAusgabe) threads.get(i)).fertig = true;
			((Thread) threads.get(i)).interrupt();
		}
		threads.clear();
		if (!pufferE.isEmpty())
			throw new RuntimeException("Es sind noch offene Aufgaben");
		pufferA.clear();
	}

	// Bekommt ein Neuron und den Index in der Ausgabeschicht dazu. Bestimmt
	// Ausgabewert
	private double bestimmeSollwert(Neuron n, int neuronPosition,
			int indexBeispiel) {
		if (KLASSEN != null
				&& (KLASSEN[indexBeispiel] < 0
						|| KLASSEN[indexBeispiel] >= GROESSE_AUSGABESCHICHT
						|| neuronPosition < 0 || neuronPosition >= GROESSE_AUSGABESCHICHT))
			throw new RuntimeException("Ungültige Abbildung");

		double sollwert = ((KLASSEN == null && neuronPosition == indexBeispiel) || (KLASSEN != null && KLASSEN[indexBeispiel] == neuronPosition)) ? 1.0
				: 0.0;
		sollwert = Neuron.funktion(sollwert);
		return sollwert;
	}

	// rechnet alle oder nur die angestrebten Testdaten durch und gibt mittleren
	// quadratischen Fehler
	// zurück
	public double berechneAbstand(int runde) {
		// Annahme: Gewichte wurden schon initialisiert
		double summe = 0.0;
		Schicht ausgabeschicht = this.schichten.get(schichten.size() - 1);
		for (int nr = 0; nr < TRAININGSBEISPIELE.length; ++nr) {
			if (mitKonvergenz >= 0 && runde > ANZAHL_RUNDEN / 4
					&& nr != mitKonvergenz
					&& mitKonvergenz < TRAININGSBEISPIELE.length)
				continue;
			double[] beispiel = gibTrainingsBeispiel(nr); // danach beispiel
															// gefüllt
			belegeEingangsSchichtMitBeispiel(beispiel);
			berechneKomplettesNetz();
			double teilsumme = 0.0;
			for (int n = 0; n < ausgabeschicht.anzNeuronen; ++n) {
				Neuron neuron = ausgabeschicht.neuronen.get(n);
				double ist = neuron.ausgang;
				// Charakterisierung nach Klassen oder Musternummern
				double soll = bestimmeSollwert(neuron, n, nr);
				teilsumme += (soll - ist) * (soll - ist);
			}
			summe += Math.sqrt((teilsumme / ausgabeschicht.anzNeuronen));
		}
		if (mitKonvergenz < 0 || mitKonvergenz >= TRAININGSBEISPIELE.length)
			summe /= TRAININGSBEISPIELE.length;
		return summe;
	}

	// �berwachtes Lernen mit Hebb'scher Regel
	// F�r Ausgabeknoten a(j) gilt: g(i) = g(i) + L*e(i)*a(j)
	// d.h �nderung der Gewichte linear abh�ngig von Lernrate,
	// Eingabe- und Ausgabeimpuls. Keine Initialgewichte gebraucht
	public void trainierenHebb() {
		if (this.schichten.size() != 2)
			throw new RuntimeException(
					"Hebb'sche Regel nur ohne Zwischenschicht m�glich");

		// Gewichte stehen auf 0;
		Schicht s = schichten.get(1); // Ausgabeschicht
		for (int runde = 0; runde < ANZAHL_RUNDEN; ++runde) {
			int nr = (int) (Math.random() * TRAININGSBEISPIELE.length);
			double[] beispiel = gibTrainingsBeispiel(nr); // danach beispiel
															// gefüllt
			belegeEingangsSchichtMitBeispiel(beispiel);
			berechneKomplettesNetz();

			// Bestimme jedes Gewicht neu mithilfe des SollAusgangswerts
			boolean aenderung = false;
			for (int i = 0; i < GROESSE_EINGABESCHICHT; ++i) {
				for (int j = 0; j < GROESSE_AUSGABESCHICHT; ++j) {
					Neuron ak = s.neuronen.get(j);
					double sollwert = bestimmeSollwert(ak, j, nr);
					double abweichung = ak.eingang[i] * (sollwert - ak.ausgang)
							* LERNRATE;
					if (Math.abs(abweichung) < EPS)
						continue;
					System.out.println("Abweichung " + abweichung);
					ak.gewicht[i] += abweichung;
					aenderung = true;
				}
			}
			System.out.println("Ende Runde " + runde + "\n");
			if (!aenderung)
				break;
		}
	}

	// �berwachtes Lernen mit Backpropagation, Initialgewichte
	// gebraucht, da sonst Backprop. immer 0, wegen gewichteteEing�nge
	// des Neurons immer 0 ausserhalb der Eingabeschicht!
	// Ausgabeschicht: delta = f'(summe_g) * (a_soll - a_ist)
	// Zwischenschicht: delta = f'(summe_g) * summe[nn=Nachfolger von N]
	// (delta[nn] * g[N->nn])
	public void trainierenBackpropagation() {
		gewichteZufaelligBelegen();
		double abstandAnfang = berechneAbstand(0);
		System.out.println("Anfangsabstand: " + abstandAnfang);

		for (int runde = 0; runde < ANZAHL_RUNDEN; ++runde) {
			int nr = (int) (Math.random() * TRAININGSBEISPIELE.length);
			if (mitKonvergenz >= 0 && runde > ANZAHL_RUNDEN / 4) {
				if (mitKonvergenz >= TRAININGSBEISPIELE.length)
					throw new RuntimeException(
							"Kein gueltiges Konvergenzmuster");
				nr = mitKonvergenz; // Konvergenzmuster
			}
			double[] beispiel = gibTrainingsBeispiel(nr); // danach beispiel
															// gefüllt
			belegeEingangsSchichtMitBeispiel(beispiel);
			berechneKomplettesNetz();

			zeigeStatusAusgabeschicht(nr);

			// Netz rückw�rts anpassen ohne Eingangsschicht...
			for (int s = this.schichten.size() - 1; s > 0; s--) {
				Schicht schicht = schichten.get(s);
				for (int n = 0; n < schicht.anzNeuronen; ++n) {
					Neuron neuron = schicht.neuronen.get(n);
					double delta = 0.0;
					if (s == schichten.size() - 1) {
						// Ausgabeschicht; 1 Neuron wahr für Trainingsmuster
						// oder pro Beispiel ein Neuron als Ausgabesignal
						double sollAusgangsWert = bestimmeSollwert(neuron, n,
								nr);
						delta = (sollAusgangsWert - neuron.ausgang);
					} else {
						// Durchlaufe alle Nachfolgeneuronen, mult. deren dBp
						// mit Gewicht von N dort
						for (int nn = 0; nn < schichten.get(s + 1).anzNeuronen; ++nn) {
							Neuron nachfolgerNeuron = schichten.get(s + 1).neuronen
									.get(nn);
							delta += nachfolgerNeuron.deltaBackpropagation
									* nachfolgerNeuron.gewichtAlt[n];
						}
						// ??? delta /= schichten.get(s + 1).anzNeuronen; //
						// f.Konv. ?
					}
					// Speichere dBp
					neuron.deltaBackpropagation = (mitGradient ? Neuron
							.ableitung(neuron.gewichteteEingaenge) : 1) * delta;
					if (Math.abs(neuron.deltaBackpropagation) < EPS)
						continue;
					if (DEBUGZSCHICHT) {
						System.out.println("An " + s + "/" + n + ": dbP ist "
								+ format.format(neuron.deltaBackpropagation));
					}

					double summe = 0.0;
					if (!mitGradient) {
						for (int vn = 0; vn < neuron.gewicht.length; ++vn) {
							summe += Math.abs(neuron.gewicht[vn]); // kein 0
						}
					}

					// Neues Eingangsgewicht; dBp ersetzt Ausgabewert
					for (int vn = 0; vn < neuron.gewicht.length; ++vn) {
						double faktorDbp = 1.0;
						if (!mitGradient) {
							// Gradientenverfahren bei Sigmoid sehr ungünstig
							// für Konvergenz, wenn gewichtete Summe betrags-
							// mäßig >> 1. Deshalb dbp gewichtet verteilen
							faktorDbp = neuron.gewicht[vn] / summe;
						}
						double diffGewicht = (LERNRATE * neuron.eingang[vn]
								* faktorDbp * neuron.deltaBackpropagation);
						neuron.gewichtAlt[vn] = neuron.gewicht[vn];
						neuron.gewicht[vn] += diffGewicht;
						if (DEBUGZSCHICHT) {
							final String titel = (s + "/" + n + "/" + vn + ": ");
							System.out.println("Gewicht an " + titel
									+ format.format(neuron.gewichtAlt[vn])
									+ " -> "
									+ format.format(neuron.gewicht[vn]));
						}
						if (Math.abs(neuron.gewicht[vn]) > 1e6)
							throw new RuntimeException("Gewichte divergieren");
					} // Ende vn
				} // Ende n
			} // Ende s

			if (runde % 16 == 15) {
				double metrik = berechneAbstand(runde);
				System.out.println("\tAbstand Runde " + runde + ": " + metrik);
				if (metrik < 0.01)
					break; // Konvergenz aller Trainingsdaten
				if (metrik > 4 * abstandAnfang) {
					throw new RuntimeException("Divergenz");
				}
			}
		}
	}

	// Zeigt fuer alle Ausgabeneuronen Soll und Ist als Tupel an, auf 10 Werte
	// begrenzt
	private void zeigeStatusAusgabeschicht(int nr) {
		Schicht ausgangsschicht = schichten.get(schichten.size() - 1);
		String soll = "";
		String ist = "";
		for (int n = 0; n < ausgangsschicht.anzNeuronen && n < 10; ++n) {
			Neuron neuron = ausgangsschicht.neuronen.get(n);
			double sollAusgangsWert = bestimmeSollwert(neuron, n, nr);
			soll += format.format(sollAusgangsWert) + " ";
			ist += format.format(neuron.ausgang) + " ";
		}
		System.out.println("Soll/Ist #" + nr + " = (" + soll + ") und (" + ist
				+ ")");
	}

	// belegt mit der Eingabe, rechnet das Netz und gibt den Index des
	// wahrscheinl. Muster
	// zurueck. KLASSEN werden nicht beachtet
	int rateMuster(double[] quadrat) {
		if (quadrat.length != GROESSE_EINGABESCHICHT)
			throw new RuntimeException(
					"Ungültige Eingabeparameter beim Ermitteln der Ausgabe");
		belegeEingangsSchichtMitBeispiel(quadrat);
		berechneKomplettesNetz();
		// Ausgabe der Ausgangsschicht, bestimme Maximum
		double wert = Double.NEGATIVE_INFINITY;
		int indMax = -1;
		Schicht sch = schichten.get(schichten.size() - 1);
		for (int i = 0; i < sch.anzNeuronen; ++i) {
			Neuron letztes = sch.neuronen.get(i);
			if (letztes.ausgang > wert) {
				wert = letztes.ausgang;
				indMax = i;
			}
		}
		System.out.println("Wahrsch. Muster " + indMax + " mit Ist=" + wert);

		return indMax;
	}

}

class Daten {
	public int von;
	public int bis;
	public Schicht schicht;
	public Schicht vorgaenger;
}

// rechnet die Ausgabe eine Folge von Knoten einer Schicht
class RechnenAusgabe extends Thread {
	// private final static Object mutex = new Object();
	private Vector<Daten> eingabe;
	private Vector<Daten> ausgabe;
	private Thread haupt;
	public boolean fertig = false;

	public RechnenAusgabe(Vector<Daten> e, Vector<Daten> a, Thread h) {
		eingabe = e;
		ausgabe = a;
		haupt = h;
	}

	@Override
	public void run() {
		while (true) {
			if (fertig)
				return;
			Daten aktuell = null;
			synchronized (eingabe) {
				if (!eingabe.isEmpty()) {
					aktuell = eingabe.firstElement();
					eingabe.remove(0);
				}
			}
			if (aktuell == null) {
				try {
					Thread.sleep(100);
				} catch (InterruptedException ex) {
				}
			} else {
				int von = aktuell.von;
				int bis = aktuell.bis;
				Schicht schicht = aktuell.schicht;
				Schicht vorgaengerSchicht = aktuell.vorgaenger;
				for (int n = von; n <= bis; ++n) {
					Neuron neuron = schicht.neuronen.get(n);
					if (vorgaengerSchicht != null) {
						for (int vn = 0; vn < vorgaengerSchicht.anzNeuronen; ++vn) {
							Neuron vorgNeuron = vorgaengerSchicht.neuronen
									.get(vn);
							neuron.eingang[vn] = vorgNeuron.ausgang;
						}
					}
					neuron.berechneAusgang();
				}
				ausgabe.add(aktuell);
				haupt.interrupt();
			}
		}
	}
}

// Eine Schicht enthält eine Liste von Neuronen und die Anzahl der Neuronen
// der Vorgängerschicht, welche dann als Eingabe dienen können.
class Schicht implements Serializable {
	public int anzNeuronen; // , anzNeuronenVorgSchicht;
	public List<Neuron> neuronen;

	public Schicht(int anzNeuronen, int anzNeuronenVorgSchicht) {
		this.anzNeuronen = anzNeuronen;
		neuronen = new ArrayList<Neuron>(anzNeuronen);
		for (int i = 0; i < anzNeuronen; ++i) {
			neuronen.add(new Neuron(anzNeuronenVorgSchicht));
		}
	}

}

// Ein Neuron enthält für jeden Eingang einen Wert und ein Gewicht
// es kann anhand dessen den einen Ausgangswerte berechnen (n:1 Eingang:Ausgang)
class Neuron implements Serializable {
	private int anzEingangsverbindungen;
	public double[] gewicht;
	public double[] gewichtAlt; // wg. Backpropagation
	public double[] eingang;
	public double gewichteteEingaenge, deltaBackpropagation;
	public double ausgang;

	public Neuron(int anzEingangsverbindungen) {
		this.anzEingangsverbindungen = anzEingangsverbindungen;
		if (anzEingangsverbindungen > 0) {
			gewicht = new double[anzEingangsverbindungen];
			gewichtAlt = new double[anzEingangsverbindungen];
			eingang = new double[anzEingangsverbindungen];
		} else {
			gewicht = null;
			gewichtAlt = null;
			eingang = new double[1];
		}
		deltaBackpropagation = 0.0;
	}

	public void berechneAusgang() {
		gewichteteEingaenge = 0.0;
		if (anzEingangsverbindungen > 0) {
			for (int i = 0; i < anzEingangsverbindungen; ++i) {
				gewichteteEingaenge += gewicht[i] * eingang[i];
			}
			gewichteteEingaenge /= anzEingangsverbindungen;
		} else
			gewichteteEingaenge = eingang[0];
		ausgang = funktion(gewichteteEingaenge);
	}

	static double funktion(double wert) {
		if (Netz.mitFunktion == Netz.Aktivierung.SIGMOID) {
			return 1.0 / (1.0 + Math.exp(-wert));
		} else if (Netz.mitFunktion == Netz.Aktivierung.TANH) {
			return (Math.exp(wert) - Math.exp(-wert))
					/ (Math.exp(wert) + Math.exp(-wert));
		} else { // identitaet
			return wert;
		}
	}

	static double ableitung(double wert) {
		if (Netz.mitFunktion == Netz.Aktivierung.SIGMOID) {
			return Math.exp(-wert)
					/ ((1.0 + Math.exp(-wert)) * (1.0 + Math.exp(-wert)));
		} else if (Netz.mitFunktion == Netz.Aktivierung.TANH) {
			return 1 - funktion(wert) * funktion(wert);
		} else {
			return 1;
		}
	}
}