package ShamirSecretSharing;

import org.bouncycastle.util.Arrays;
import org.factcenter.qilin.primitives.concrete.ECGroup;
import org.factcenter.qilin.util.Pair;

import java.math.BigInteger;

/**
 * Polynomial with {@link BigInteger} coefficients, stored lowest power first:
 * {@code coefficients[i]} is the coefficient of {@code x^i}.
 *
 * <p>The coefficient array is copied on the way in (public constructor) and on the
 * way out ({@link #getCoefficients()}), so an instance cannot be changed through an
 * array reference held by a caller.
 *
 * <p>Note: {@code equals}/{@code hashCode} are not overridden, so
 * {@code compareTo(...) == 0} does not imply {@code equals(...)}.
 *
 * Created by Tzlil on 1/27/2016.
 */
public class Polynomial implements Comparable<Polynomial> {
    protected static final Polynomial ZERO = new Polynomial(new BigInteger[]{BigInteger.ZERO}); // neutral for add
    protected static final Polynomial ONE = new Polynomial(new BigInteger[]{BigInteger.ONE});   // neutral for mul
    private final int degree;
    private final BigInteger[] coefficients;

    /**
     * Builds a polynomial from its coefficients, lowest power first.
     *
     * <p>The array is copied, so later writes to the caller's array do not affect
     * this polynomial. The degree is the largest index holding a non-zero
     * coefficient; it is 0 when every coefficient is zero and -1 when the array is
     * empty. Trailing zero coefficients are kept in the stored array.
     *
     * @param coefficients coefficients, {@code coefficients[i]} multiplies {@code x^i}
     * @throws NullPointerException if {@code coefficients} is null, or if an
     *         inspected high-order element is null
     */
    public Polynomial(BigInteger[] coefficients) {
        this(coefficients.clone(), true);
    }

    /**
     * Internal constructor that takes ownership of {@code owned} without copying.
     * Only called with arrays that no other code holds a reference to.
     *
     * @param owned     coefficient array, used as-is
     * @param ownsArray unused; exists only to give this overload a distinct signature
     */
    private Polynomial(BigInteger[] owned, boolean ownsArray) {
        int d = owned.length - 1;
        // Skip trailing zeros, but never go below index 0 so that the zero
        // polynomial {0} still has degree 0.
        while (d > 0 && owned[d].signum() == 0) {
            d--;
        }
        this.degree = d;
        this.coefficients = owned;
    }


    /**
     * Orders polynomials by degree first, then by coefficients from the highest
     * power down to the constant term.
     *
     * @param other polynomial to compare with
     * @return negative, zero or positive as this polynomial is less than, equal to
     *         or greater than {@code other}
     */
    @Override
    public int compareTo(Polynomial other) {
        if (this.degree != other.degree) {
            return Integer.compare(this.degree, other.degree);
        }
        // Degrees are equal here, so both arrays have valid indices 0..degree.
        for (int i = degree; i >= 0; i--) {
            int compare = this.coefficients[i].compareTo(other.coefficients[i]);
            if (compare != 0) {
                return compare;
            }
        }
        return 0;
    }

    @Override
    public String toString() {
        return "Polynomial{" +
                "degree=" + degree +
                ", coefficients=" + java.util.Arrays.toString(coefficients) +
                '}';
    }


    /**
     * Evaluates the polynomial at {@code x}.
     *
     * @param x evaluation point
     * @return sum of {@code coefficients[i] * x^i} for {@code i} in {@code 0..degree}
     */
    public BigInteger image(BigInteger x) {
        // Horner's rule: (((c_d * x) + c_{d-1}) * x + ...) + c_0.
        BigInteger result = BigInteger.ZERO;
        for (int i = degree; i >= 0; i--) {
            result = result.multiply(x).add(coefficients[i]);
        }
        return result;
    }

    /**
     * Interpolates a polynomial through the given points by combining the Lagrange
     * polynomials returned by {@code LagrangePolynomial.lagrangePolynomials(points)}.
     *
     * <p>All arithmetic is done over the integers: every term is scaled to the common
     * denominator (the product of all divisors) and a single {@link BigInteger#divide}
     * is applied at the end, which truncates if the true coefficient is not an integer.
     *
     * <p>NOTE(review): this reads {@code l[0]} unconditionally, so it presumably
     * requires at least one point, and it indexes every {@code l[i].polynomial}
     * up to the degree of {@code l[0].polynomial} - confirm against
     * {@code LagrangePolynomial}.
     *
     * @param points points the result should pass through
     * @return polynomial of minimal degree which goes through all points
     */
    public static Polynomial interpolation(Point[] points) {
        LagrangePolynomial[] l = LagrangePolynomial.lagrangePolynomials(points);

        // product = product of all l[i].divisor (the common denominator)
        BigInteger product = BigInteger.ONE;
        for (int i = 0; i < l.length; i++) {
            product = product.multiply(l[i].divisor);
        }

        // factors[i] = product / l[i].divisor = product of l[k].divisor for k != i
        BigInteger[] factors = new BigInteger[l.length];
        for (int i = 0; i < l.length; i++) {
            factors[i] = product.divide(l[i].divisor);
        }
        int degree = l[0].polynomial.degree;

        // coefficients[j] = (sum over all i of l[i].image * factors[i] * l[i].coefficients[j]) / product
        //                 =  sum over all i of l[i].image * l[i].coefficients[j] / l[i].divisor
        BigInteger[] coefficients = new BigInteger[degree + 1];
        for (int j = 0; j < coefficients.length; j++) {
            BigInteger sum = BigInteger.ZERO;
            for (int i = 0; i < l.length; i++) {
                sum = sum.add(l[i].image.multiply(factors[i]).multiply(l[i].polynomial.coefficients[j]));
            }
            coefficients[j] = sum.divide(product);
        }
        return new Polynomial(coefficients, true);
    }

    /**
     * Adds two polynomials.
     *
     * @param other polynomial to add
     * @return new Polynomial of degree at most max(this degree, other degree) s.t for all x in Z
     *         new.image(x) = this.image(x) + other.image(x)
     */
    public Polynomial add(Polynomial other) {
        Polynomial bigger, smaller;
        if (this.degree < other.degree) {
            bigger = other;
            smaller = this;
        } else {
            bigger = this;
            smaller = other;
        }
        // Start from a private copy of the higher-degree operand and add the
        // lower-degree operand into its low-order coefficients.
        BigInteger[] sum = bigger.coefficients.clone();
        for (int i = 0; i <= smaller.degree; i++) {
            sum[i] = smaller.coefficients[i].add(bigger.coefficients[i]);
        }
        return new Polynomial(sum, true);
    }

    /**
     * Multiplies every coefficient by a constant.
     *
     * @param constant scalar factor
     * @return new Polynomial s.t for all x in Z
     *         new.image(x) = constant * this.image(x)
     */
    public Polynomial mul(BigInteger constant) {
        BigInteger[] scaled = this.coefficients.clone();
        // Entries above 'degree' are zero by construction, so they need no scaling.
        for (int i = 0; i <= this.degree; i++) {
            scaled[i] = constant.multiply(scaled[i]);
        }
        return new Polynomial(scaled, true);
    }

    /**
     * Multiplies two polynomials.
     *
     * @param other polynomial to multiply by
     * @return new Polynomial of degree at most this degree + other degree
     *         (backed by this degree + other degree + 1 coefficients) s.t for all x in Z
     *         new.image(x) = this.image(x) * other.image(x)
     */
    public Polynomial mul(Polynomial other) {
        BigInteger[] product = new BigInteger[this.degree + other.degree + 1];
        java.util.Arrays.fill(product, BigInteger.ZERO);

        // Schoolbook multiplication: x^i * x^j contributes to x^(i+j).
        for (int i = 0; i <= this.degree; i++) {
            for (int j = 0; j <= other.degree; j++) {
                product[i + j] = product[i + j].add(this.coefficients[i].multiply(other.coefficients[j]));
            }
        }
        return new Polynomial(product, true);
    }


    /** getter
     * @return copy of coefficients (same length as the array given at construction)
     */
    public BigInteger[] getCoefficients() {
        return coefficients.clone();
    }

    /** getter
     * @return degree
     */
    public int getDegree() {
        return degree;
    }

    /**
     * inner class
     * container for (x,y) x from range and y from image of polynomial
     */
    public static class Point {
        public final BigInteger x;
        public final BigInteger y;

        /**
         * constructor
         * @param x evaluation point
         * @param polynomial y = polynomial.image(x)
         */
        public Point(BigInteger x, Polynomial polynomial) {
            this.x = x;
            this.y = polynomial.image(x);
        }

        /**
         * copy constructor
         * @param point point whose x and y are copied
         */
        public Point(Point point) {
            this.x = point.x;
            this.y = point.y;
        }
    }

}