// ECM_Basis - Basisklasse zur Arbeit mit elliptischen Kurven
// y^2 = x^3 + a*x + b
// Diskriminante 4*a^3 + 27*b^2 != 0, sonst singulär
// Der Anteil der Ordnungen N im Intervall (n+1-sqrt(n),n+1+sqrt(n)) aus Primfaktoren < y
// besteht, ist u^-u mit u = log(n) / log(y)

import java.math.*;
import java.util.*;

public class ECM_Basis
{
  protected BigInteger a, b;  // über Konstruktor setzen: Parameter der Kurve
  public static BigInteger n;  // von aussen: zu faktorisierende Zahl
  public static int rest8;  // f. Speziallfall
  protected final static BigInteger ZWEI = BigInteger.valueOf(2);
  protected final static BigInteger DREI = BigInteger.valueOf(3);
  protected final static BigInteger ZAHL27 = BigInteger.valueOf(27);
  protected final static int GRENZE_MAX = 100000;  // alle Zahlen bis zu dieser Größe
  protected static boolean GGTAUSGABE = true;  // weil Zielinfo
  protected static double grenz_faktor = 1.0;  // initial
  protected static double faktor_schritt = 0.1;

  protected static boolean DEBUG = false;
  protected static boolean DEBUGZWERT = false;  // Ausgabe Zwischenwerte
  protected static boolean DEBUGPLUS = false;
  protected BigInteger ergGGT = null;  // zu setzen des letzten ggT


  // Konstrukturen
  public ECM_Basis (BigInteger para_a, BigInteger para_b) throws Logik_Fehler
  {
    // Singuläre Kurven z.B. x^3 - 3x + 2 mit (2|2) sind zulässig und haben relativ zu anderen
    // Parametern oft kleinere Ordnungen. Auch b == 0 haben kleinere Ordnungen als b!= 0, jedoch
    // kann man sich n i c h t auf diese allein beschränken, weil abhängig von den Teilern von n+1.
    a = para_a; b = para_b;
  }

  public ECM_Basis (int para_a, int para_b) throws Logik_Fehler
  {
    this(BigInteger.valueOf(para_a), BigInteger.valueOf(para_b));
  }


  // Funktionen zum Addieren, Multiplizieren und Gruppenordnung ...


  // addieren zweier Punkte - konnte ein Wert ermittelt werden, wird dieser zurückgegeben.
  // Bei einem ggT wird dieser gespeichert und eine ArithmeticException geworfen
  public final Paar_ECM addieren (Paar_ECM p1, Paar_ECM p2) throws Logik_Fehler, ArithmeticException
  {
    // Sonderfälle im Zusammenhang mit neutralem Element
    if (p1.istNeutral())
      return p2;
    else if (p2.istNeutral())
      return p1;

    // Wegen der darauffolgenden Vergleiche ist es wichtig, dass die BigInts immer positiv oder 0 sind
    if (DEBUG)
    {
      if (p1.x.signum() < 0 || p1.y.signum() < 0 || p2.x.signum() < 0 || p2.y.signum() < 0)
      {
        System.out.println ("Negative Eingabeparameter bei addieren"); System.exit(1);
      }
      if (p1.x.compareTo(n) >= 0 || p1.y.compareTo(n) >= 0 || p2.x.compareTo(n) >= 0 || p2.y.compareTo(n) >= 0)
      {
        System.out.println ("Eingabeparameter >= n addieren"); System.exit(1);
      }
    }

    // Lambda-bestimmen (sollte ungleich 1 sein)
    BigInteger zaehler = null, nenner = null; boolean gleich = false;
    if (p1.x.equals(p2.x))
    {
      // gleich oder additiv invers; y müssen bei invers gerade und ungerade sein, wenn n ungerade
      if (p1.y.equals(p2.y))
      {
        // Sonderfall: y ist 0, laut Prüfungen oben nicht neutral !
        if (p1.y.equals(BigInteger.ZERO))
        {
          // auch (-x, 0) ist eine Lösung; aber schlecht weitermachen ! x ist Wurzel von -a !
          throw new Logik_Fehler("Punkt " + p1 + " nicht mal 2");
        }
        // lambda = (3*x1^2 + a) / (2 * y1)
        BigInteger temp = p1.x.modPow(ZWEI, n);
        zaehler = temp.shiftLeft(1).add(temp).add(a);  // remainder hier nicht; 2.9% CPU
        nenner = p1.y /* .shiftLeft(1) s.u. */; gleich = true;
      }
      else if ((p1.y.testBit(0) != p2.y.testBit(0)) && p1.y.equals(n.subtract(p2.y)))  // 1.2% CPU
      {
        return Paar_ECM.NEUTRAL;
      }
      else
      {
        // x gleich, y verschieden: es ist eine Wurzel ziehbar !
        for (short runde = 0; runde < 2; ++runde)
        {
          BigInteger vers = ((runde == 0) ? p1.y.subtract(p2.y) : p1.y.add(p2.y));
          BigInteger test = vers.gcd(n);
          if (! test.equals(BigInteger.ONE)) {
            System.out.println ("Teiler " + test); ergGGT = test;
            throw new ArithmeticException ("ggT2 ist " + ergGGT);
          }
        }
        throw new Logik_Fehler ("kein ggT gefunden");
      }
    }
    else
    {
      // y gleich, x verschieden nicht gesondert betrachtet(lambda=0)
      // lambda = (y2 - y1) / (x2 - x1)
      zaehler = p2.y.subtract(p1.y);
      nenner = p2.x.subtract(p1.x);
    }
    // Jetzt lambda ermitteln: bei nenner-ggT > 1 hier Fehler abfangen
    BigInteger inv = null; ergGGT = null;
    try
    {
      // Für Performanz: gemeinsame 2er raus
      int weg = zaehler.getLowestSetBit(); int weg2 = nenner.getLowestSetBit(); if (gleich) ++weg2;
      if (weg2 < weg) weg = weg2;
      zaehler = zaehler.shiftRight(weg);
      if (gleich && weg == 0) nenner = nenner.shiftLeft(1); /* doch dazu */
      else if (gleich) --weg;
      nenner = nenner.shiftRight(weg);

      inv = nenner.modInverse(n);  // 34.2% CPU
    }
    catch (ArithmeticException ex)
    {
      ergGGT = nenner.gcd(n); if (GGTAUSGABE) System.out.println("ggT3 ist " + ergGGT);
      throw ex;
    }
    BigInteger lambda = zaehler.multiply(inv).remainder(n);  // 8.8 % CPU
    if (lambda.signum() < 0) lambda = lambda.add(n);

    Paar_ECM rueck = new Paar_ECM(/*egal*/0,0);
    // x3 = lm^2 - x1 - x2
    rueck.x = lambda.modPow(ZWEI, n);  // 13.05% CPU
    if (rueck.x.signum() < 0) rueck.x = n.add(rueck.x);
    rueck.x = rueck.x.subtract(p1.x);  // 1.66 % CPU
    if (rueck.x.signum() < 0) rueck.x = n.add(rueck.x);
    rueck.x = rueck.x.subtract(p2.x);  // remainder ist teuer, vermeiden
    // y3 = -y1 + lm * (x1 - x3)
    rueck.y = lambda.multiply(p1.x.subtract(rueck.x)).remainder(n);  // 10.1% CPU
    if (rueck.y.signum() < 0) rueck.y = n.add(rueck.y);
    rueck.y = rueck.y.subtract(p1.y);  // remainder teuer; 1.6 % CPU

    // Arbeite mit positiven Werten !
    if (rueck.x.signum() < 0) rueck.x = n.add(rueck.x);
    if (rueck.y.signum() < 0) rueck.y = n.add(rueck.y);

    // Test:
    if (DEBUGPLUS)
    {
      teste_punkt(rueck);
    }

    return rueck;
  }


  // Bestimmt die Ordnung  b == 0: |ord - (n+1)/2| < sqrt(n)
  // Für n gilt: n = a^2 + b^2, dann
  // n prim: Ordnung = n + 1 +- 2 * a
  public final BigInteger holeOrdnung (Paar_ECM p) throws Logik_Fehler
  {
    if (! n.isProbablePrime(3)) { System.out.println ("Ordnung nicht best. da keine Primzahl"); return null; } // keine Primzahl

    teste_punkt(p); boolean DEBUGJAC=false; Paar_ECM erg = p;
    if (DEBUGZWERT)
    {
      System.out.print(p + "\t"); // debug
      System.out.println (punkt_drei_loesung(p));  // wahr --> drei Lösungen
    }

    int quadrat = 0, nichtQuadrat = 0; long exponent;
    for (exponent = 2; exponent < Long.MAX_VALUE; ++exponent)
    {
      Paar_ECM weiter = null;
      try
      {
        weiter = addieren(erg, p);  // erstmal einen Schritt, im catch dann ggf. zweiten
        if (DEBUGZWERT) System.out.print(weiter + "\t"); // debug
      }
      catch (ArithmeticException ex)
      {
         System.out.println ("Ordnung nicht bestimmbar"); return null;
      }
      if (weiter.istNeutral()) { /* System.out.println ("Abbruch mit neutralem Element");*/ break; }
      if (DEBUGZWERT) System.out.println (punkt_drei_loesung(weiter));  // wahr --> drei Lösungen
      if (DEBUGJAC)
      {
        int jac = PrimGenerator.jacobiSymbol(weiter.y, n);
        if (jac == 1) ++quadrat; else if (jac == -1) ++nichtQuadrat;
        System.out.println (weiter + " und " + jac);
      }

      // Mittelelement erkannt ?
      if (weiter.x.equals(erg.x)) // 1)
      {
        System.out.println("Fall 1");
        exponent = 2 * exponent - 1; break;
      }
      else if (weiter.y.equals(BigInteger.ZERO)) // 2)
      {
        System.out.println("Fall 2");
        exponent = 2 * exponent; break;
      }
      // folgendes geht nicht immer, wahrsch. bei b == 0
      //else if (weiter.y.modPow(ZWEI,n).equals (weiter.x.modPow(DREI,n).shiftLeft(1).remainder(n) ))  // 3) y^2 = 2*x^3, obwohl x nicht wiederholt
      //{
      //  System.out.println("Fall 3");
      //  exponent = 2 * exponent; break;
      //}
      // 4 Fall:
      erg = weiter;
    }

    return BigInteger.valueOf(exponent);
  }

  // für y^2 = F(x) und Primzahl p gilt: E(Fp) =) p + 1 + Summe(x=0; x < p) Jac(F(x) / p), also unabh. von spez. Punkt
  public final BigInteger holeOrdnung2 () throws Logik_Fehler
  {
    long summe = n.longValue() + 1;
    for (BigInteger lauf = BigInteger.ZERO; lauf.compareTo(n) < 0; lauf = lauf.add(BigInteger.ONE))
    {
      BigInteger f = lauf.multiply(lauf).multiply(lauf).remainder(n);
      f = f.add(lauf.multiply(a)).add(b).remainder(n);
      int wert = PrimGenerator.jacobiSymbol(f, n);
      summe += wert;
    }
    return BigInteger.valueOf(summe);
  }


  // Sucht eine Ordnung per giant-step-baby-step
  // p*(g * m + b) = 0, m ist Wurzel von einer Schranke, die polynomial sein sollte
  public final BigInteger sucheBabyStepGiantStep (Paar_ECM p, BigInteger maxExponent) throws Logik_Fehler
  {
    teste_punkt(p); Paar_ECM erg = p;

    // bestimmte Giant-Step
    Map<Paar_ECM, BigInteger> map_gst = new HashMap<Paar_ECM, BigInteger>();
    Paar_ECM gstb = multSkalar(maxExponent, p); Paar_ECM aktuell = gstb;
    if (gstb == null) { System.out.println("ggt bei GS1"); return null; } // ggT
    map_gst.put(gstb, BigInteger.ONE);
    for (BigInteger lauf = BigInteger.valueOf(2); lauf.compareTo(maxExponent) < 0; lauf  = lauf.add(BigInteger.ONE))
    {
      aktuell = addieren(aktuell, gstb); if (aktuell == null) { System.out.println("ggt bei GS2"); return null; } // ggT
      map_gst.put(aktuell, lauf);
    }

    // bestimme Baby-Step
    Paar_ECM bstb = new Paar_ECM(p.x, p.y); bstb.y = n.subtract(bstb.y); aktuell = bstb;
    for (BigInteger lauf = BigInteger.ONE; lauf.compareTo(maxExponent) < 0; lauf = lauf.add(BigInteger.ONE))
    {
      if (map_gst.containsKey(aktuell)) return map_gst.get(aktuell).multiply(maxExponent).add(lauf);  // Treffer
      aktuell = addieren(aktuell, bstb); if (aktuell == null) { System.out.println("ggt bei BS"); return null; } // ggT
    }

    return null;
  }

  // Multiplikation mit (pos.) Skalar. Ergebnis ist null (falls nicht komplett berechnet werden konnte)
  // oder Wert
  public final Paar_ECM multSkalar (BigInteger faktor, Paar_ECM p) throws Logik_Fehler
  {
    return multSkalar(faktor, p, false);
  }

  // nur ungerade: Quadrierungen am Ende nicht mehr machen
  public final Paar_ECM multSkalar (BigInteger faktor, Paar_ECM p, boolean nurUngerade) throws Logik_Fehler
  {
    if (faktor.signum() < 0) throw new Logik_Fehler ("Skalar ist nicht >= 0");
    if (p.istNeutral()) return p;
    teste_punkt(p);

    Paar_ECM rueck = Paar_ECM.NEUTRAL;
    final int durchl = faktor.bitLength(); int weg = faktor.getLowestSetBit();
    for (int lauf = 0; lauf < (durchl-weg); ++lauf)
    {
      if (lauf > 0)  // Quadrieren
      {
        try
        {
          p = addieren (p, p); // = verdoppeln; das neutrale Element kann hier nicht entstehen !
        }
        catch (ArithmeticException ex)
        {
          return null;
        }
      }

      if (faktor.testBit(lauf+weg))
      {
        try
        {
          rueck = addieren (rueck, p);  // das neutrale Element kann entstehen
        }
        catch (ArithmeticException ex)
        {
          // System.out.println ("Addieren nicht möglich");
          return null;
        } // Ende try-catch
        if (rueck.istNeutral())
        {
          BigInteger tmp = BigInteger.ZERO;
          for (int i = 0; i <= lauf; ++i) if (faktor.testBit(i+weg)) tmp = tmp.setBit(i);
          System.out.println ("Zwischenergebnis: Ordnung ist " + tmp);
        }
      }
    }

    if (!nurUngerade)
    {
      // Noch Quadrierungen machen. ein ggT sollte hier nicht mehr zu erzielen sein
      for (int i = 0; i < weg && !rueck.istNeutral(); ++i)
      {
        try
        {
          rueck = addieren (rueck, rueck); // = verdoppeln; das neutrale Element kann hier nicht entstehen !
        }
        catch (ArithmeticException ex)
        {
          return null;
        }
      }
    }

    return rueck;  // Wert ermittelbar
  }


  // Multiplikation mit (pos.) Skalar. Ergebnis ist null (falls nicht komplett berechnet werden konnte)
  // oder Wert
  public final Paar_ECM multSkalar2 (BigInteger faktor, Paar_ECM p) throws Logik_Fehler
  {
    return multSkalar2(faktor, p, false);
  }

  // nur ungerade: Quadrierungen am Ende nicht mehr machen
  public final Paar_ECM multSkalar2 (BigInteger faktor, Paar_ECM p, boolean nurUngerade) throws Logik_Fehler
  {
    if (faktor.signum() < 0) throw new Logik_Fehler ("Skalar ist nicht >= 0");
    if (p.istNeutral()) return p;
    teste_punkt(p);

    // x^9 = x^16 - x^7  ; x^15 = x^16 - x^1  => Muster 7 und 1 ermitteln

    Paar_ECM rueck = Paar_ECM.NEUTRAL;
    int durchl = faktor.bitLength(); int weg = faktor.getLowestSetBit(); BigInteger muster = faktor.shiftRight(weg); durchl -= weg;
    BigInteger drueber = BigInteger.ONE.shiftLeft(durchl); muster = drueber.subtract(muster);
    p.y = n.subtract(p.y);
    //System.out.println("Faktor " + faktor + ", drüber " + drueber + ", Muster " + muster);

    for (int lauf = 0; lauf < durchl; ++lauf)
    {
      if (muster.testBit(lauf))
      {
        try
        {
          rueck = addieren (rueck, p);  // das neutrale Element kann entstehen
        }
        catch (ArithmeticException ex)
        {
          // System.out.println ("Addieren nicht möglich");
          return null;
        } // Ende try-catch
        //if (rueck.istNeutral())
        //{
        //  BigInteger tmp = BigInteger.ZERO;
        //  for (int i = 0; i <= lauf; ++i) if (faktor.testBit(i+weg)) tmp = tmp.setBit(i);
        //  System.out.println ("Zwischenergebnis: Ordnung ist " + tmp);
        //}
      }

      try
      {
        p = addieren (p, p); // = verdoppeln; das neutrale Element kann hier nicht entstehen !
      }
      catch (ArithmeticException ex)
      {
        return null;
      }
    }
    p.y = n.subtract(p.y); // obiges wieder zurück
    // x^9 = x^16 - x^7  ; x^15 = x^16 - x^1  => auf 15 und 9 gehen
    try
    {
      rueck = addieren(rueck, p);
    }
    catch (ArithmeticException ex)
    {
      return null;
    }

    if (!nurUngerade)
    {
      // Noch Quadrierungen machen. ein ggT sollte hier nicht mehr zu erzielen sein
      for (int i = 0; i < weg && !rueck.istNeutral(); ++i)
      {
        try
        {
          rueck = addieren (rueck, rueck); // = verdoppeln; das neutrale Element kann hier nicht entstehen !
        }
        catch (ArithmeticException ex)
        {
          return null;
        }
      }
    }

    return rueck;  // Wert ermittelbar
  }


  // Testet, ob der Punkt auf der Kurve liegt, und gibt ggf. Fehlermeldung aus
  protected void teste_punkt (Paar_ECM p) throws Logik_Fehler
  {
    if (p.istNeutral()) return;  // neutrales Element liegt immer auf der Kurve

    BigInteger vergl = (p.x.modPow(DREI, n)).add(a.multiply(p.x)).add(b).remainder(n);
    if (vergl.signum() < 0) vergl = vergl.add(n);
    BigInteger poty = p.y.modPow(ZWEI, n);

    if (! vergl.equals(poty)) throw new Logik_Fehler(p + " liegt nicht auf der Kurve");
  }

  // Testet, ob es zu y des Punktes drei Lösungen f. x gibt
  protected boolean punkt_drei_loesung(Paar_ECM p) throws Logik_Fehler
  {
    teste_punkt(p);  // muss drauf sein
    // x^3 + a*x + (b - y^2) = 0 wie bei Horner abdividieren
    // 1 0 a (b-y^2) --> 1 x (x^2+a) ((x^2+a)*x + (b-y^2)), letztes soll 0 sein
    BigInteger rechts = b.subtract(p.y.modPow(ZWEI,n));
    //BigInteger k1 = BigInteger.ONE;
    BigInteger k2 = p.x;
    BigInteger k3 = p.x.modPow(ZWEI,n).add(a);
    //BigInteger k4 = k3.multiply(p.x).add(rechts).remainder(n);
    //if (!k4.equals(BigInteger.ZERO)) { System.out.println("Fehler bei abdividieren"); System.exit(1); }
    // jetzt hat man quadratische Gleichung mit k1, k2, k3
    BigInteger diskr = k2.modPow(ZWEI,n).subtract(k3.shiftLeft(2)).remainder(n);
    // System.out.println("Diskriminante Punkt " + diskr);
    return (PrimGenerator.jacobiSymbol(diskr, n) == +1);
  }

  // Berechnet das kgV bis zur Zahl ende. Wird für die Schranke B des Skalarfaktors gebraucht
  // start ist die BigInteger-Starzahl, also initial 1, start_grenzz und ende_grenzz
  // sind die Laufwerte, die dazukommen
  public final static BigInteger kgv_bis(BigInteger initialzahl, int start_grenzz, int ende_grenzz)
  {
    if ((initialzahl == null) || (start_grenzz < 1) || (ende_grenzz < start_grenzz))
    {
      System.out.println("kgv_bis: ungueltige Eingabeparameter"); System.exit(1);
    }
    // kgV aller ungerader Zahlen zwischen den Parametern
    BigInteger rueck = initialzahl; // keine Prüfung, dass konsistent!
    start_grenzz |= 1; ende_grenzz |= 1;
    BigInteger aktuell = BigInteger.valueOf(start_grenzz);
    for (;start_grenzz <= ende_grenzz; start_grenzz += 2)
    {
      BigInteger test = rueck.gcd(aktuell);
      rueck = rueck.multiply(aktuell);
      if (! test.equals(BigInteger.ONE))
      {
        rueck = rueck.divide(test);
      }
      aktuell = aktuell.add(ZWEI);
    }

    // noch die Potenzen der 2 dazu
    if (!rueck.testBit(0)) rueck = rueck.shiftRight(rueck.getLowestSetBit()); // ab zweitem Aufruf
    int zusatz = (int) (Math.log(ende_grenzz) / Math.log(2));
    rueck = rueck.shiftLeft(zusatz);

    // System.out.println ("Schranke hat Bitgröße " + rueck.bitLength());

    return rueck;
  }


  // schätzt ab, bis zu welcher Zahlengrenze y gegangen werden muss (in Lit. Schranke B genannt)
  // Formel wächst exponentiell: u = ln(n) / ln(y), p = u ^ -u  --> wähle festen wert
  // p=0.01  --> u == 3.6;  p=0.10  --> u == 2.51; p=0.25  --> u == 2
  // d.h. die Bitänge von y ist ein konstanter Bruchteil der Bitlänge von n.
  // ist modus nicht in [0;2], wird die Grenze polynomiell vom log berechnet
  public final static int grenze_berechnen(short modus)
  {
    if (n == null) return -1;  // noch nicht initialisiert

    double logn = (n.bitLength()- 0.5) * Math.log(2);  // nat. Logar.

    double u, y;
    switch (modus)
    {
      case 0: u = 3.6;   // 1% Wahrscheinlichkeit
      case 1: u = 2.51;  // 10% Wahrscheinlichkeit
      case 2: u = 2;     // 25% Wahrscheinlichkeit
        y = Math.exp(logn / u);  // y exponentiell mit n ansteigend
        break;
      case 3:
        // ist p der kleinste Teiler, ist B ca exp(Wurzel(0.5 * log p * log(log p)))
        double logn2 = (n.bitLength() / 2) * Math.log(2);  // nat. Logar.
        y = (logn2 > 1.0 ? Math.exp(Math.sqrt(0.5 * logn2 * Math.log(logn2))) : 3);
        break;
      default:
        y = (int) (grenz_faktor * logn * logn);    // polynomial ansteigend
    }

    if (y > (int) GRENZE_MAX) y = GRENZE_MAX;  // Bremse, falls exponentiell zunimmt

    System.out.println ("Zahlengrenze ist " + y);

    return (int) y;
  }



}


class Logik_Fehler extends Exception
{
  public Logik_Fehler(String wert) { super(wert); }
}
