package prover;

import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import meerkat.crypto.concrete.ECElGamalEncryption;
import meerkat.crypto.mixnet.Mix2ZeroKnowledgeProver;
import meerkat.crypto.mixnet.Mix2ZeroKnowledgeVerifier;
import meerkat.protobuf.ConcreteCrypto;
import meerkat.protobuf.ConcreteCrypto.ElGamalCiphertext;
import meerkat.protobuf.Crypto;
import meerkat.protobuf.Mixing;
import org.bouncycastle.math.ec.ECPoint;
import qilin.primitives.RandomOracle;
import qilin.primitives.concrete.ECGroup;

import java.math.BigInteger;
import java.util.List;
import java.util.Random;

/**
 * Non-interactive (Fiat–Shamir) prover for a single 2x2 mixing step.
 *
 * <p>For inputs {@code in1, in2} and outputs {@code out1, out2}, {@link #prove} emits four
 * OR-proofs built from the inputs prepared by {@code ProofOrganizer}. Each OR-proof shows that
 * one of two ciphertext couples is a rerandomization under a known witness {@code x}
 * (i.e. {@code new.c1 - old.c1 = x*g} and {@code new.c2 - old.c2 = x*h}) without revealing which.
 * The true couple gets a real Chaum–Pedersen-style transcript; the other gets a simulated one.
 * The elliptic-curve group is written additively.
 */
public class Prover implements Mix2ZeroKnowledgeProver {

    ECGroup group;
    RandomOracle randomOracle;
    Random rand;
    ECElGamalEncryption ecElGamalEncryption;
    // g: group generator; h: ElGamal public-key point (both set in the constructor).
    ECPoint g,h;

    /**
     * @param rand         source of randomness for commitments and simulated transcripts
     * @param encryptor    ElGamal scheme providing the group and the public key
     * @param randomOracle oracle used for the Fiat–Shamir challenge
     */
    public Prover(Random rand,ECElGamalEncryption encryptor,RandomOracle randomOracle) {
        this.rand = rand;
        this.ecElGamalEncryption = encryptor;
        this.randomOracle = randomOracle;
        this.group = ecElGamalEncryption.getGroup();
        this.g = group.getGenerator();
        this.h = ecElGamalEncryption.getElGamalPK().getPK();
    }

    /**
     * Builds the zero-knowledge proof for one mixing step.
     *
     * @param in1   first input ciphertext
     * @param in2   second input ciphertext
     * @param out1  first output ciphertext
     * @param out2  second output ciphertext
     * @param sw    whether the step swapped its inputs
     * @param i     location index stored in the proof
     * @param j     location index stored in the proof
     * @param layer location layer stored in the proof
     * @param r1    rerandomization randomness used for the step
     * @param r2    rerandomization randomness used for the step
     * @return the four OR-proofs together with the step's location
     * @throws InvalidProtocolBufferException if a ciphertext cannot be parsed as ElGamal
     */
    public Mixing.ZeroKnowledgeProof prove(Crypto.RerandomizableEncryptedMessage in1,
                                           Crypto.RerandomizableEncryptedMessage in2,
                                           Crypto.RerandomizableEncryptedMessage out1,
                                           Crypto.RerandomizableEncryptedMessage out2,
                                           boolean sw,int i,int j, int layer,
                                           Crypto.EncryptionRandomness r1,
                                           Crypto.EncryptionRandomness r2) throws InvalidProtocolBufferException {

        ProofOrganizer organizer = new ProofOrganizer(in1,in2,out1,out2,r1,r2,sw);

        // Order matters: each OR-proof consumes randomness from `rand`.
        Mixing.ZeroKnowledgeProof.OrProof first =
                createOrProof(organizer.getOrProofInput(ProofOrganizer.OrProofOrder.first));
        Mixing.ZeroKnowledgeProof.OrProof second =
                createOrProof(organizer.getOrProofInput(ProofOrganizer.OrProofOrder.second));
        Mixing.ZeroKnowledgeProof.OrProof third =
                createOrProof(organizer.getOrProofInput(ProofOrganizer.OrProofOrder.third));
        Mixing.ZeroKnowledgeProof.OrProof fourth =
                createOrProof(organizer.getOrProofInput(ProofOrganizer.OrProofOrder.fourth));

        Mixing.ZeroKnowledgeProof.Location location = Mixing.ZeroKnowledgeProof.Location.newBuilder()
                .setI(i)
                .setJ(j)
                .setLayer(layer)
                .build();

        return Mixing.ZeroKnowledgeProof.newBuilder()
                .setFirst(first)
                .setSecond(second)
                .setThird(third)
                .setFourth(fourth)
                .setLocation(location)
                .build();
    }

    /** Unpacks an OR-proof input into ElGamal ciphertexts and proves it. */
    private Mixing.ZeroKnowledgeProof.OrProof createOrProof(ProofOrganizer.OrProofInput orProofInput)
            throws InvalidProtocolBufferException {

        ElGamalCiphertext e1ElGamal = ECElGamalEncryption.rerandomizableEncryptedMessage2ElGamalCiphertext(orProofInput.e1);
        ElGamalCiphertext e2ElGamal = ECElGamalEncryption.rerandomizableEncryptedMessage2ElGamalCiphertext(orProofInput.e2);
        ElGamalCiphertext e1TagElGamal = ECElGamalEncryption.rerandomizableEncryptedMessage2ElGamalCiphertext(orProofInput.e1New);
        ElGamalCiphertext e2TagElGamal = ECElGamalEncryption.rerandomizableEncryptedMessage2ElGamalCiphertext(orProofInput.e2New);

        return createOrProofElGamal(e1ElGamal,e2ElGamal,e1TagElGamal,e2TagElGamal,orProofInput.x,orProofInput.flag);
    }

    /** Decodes a serialized group element. */
    private ECPoint convert2ECPoint(ByteString bs){
        return group.decode(bs.toByteArray());
    }

    /**
     * Fiat–Shamir challenge: hashes the serialized transcript with the random oracle.
     *
     * <p>The digest is read as a signed two's-complement big-endian integer, so the result may be
     * negative; callers in this class reduce it modulo the group order.
     */
    public BigInteger hash(Mixing.ZeroKnowledgeProof.OrProof.ForRandomOracle input) {
        byte[] arr = input.toByteArray();
        return new BigInteger(this.randomOracle.hash(arr,arr.length));
    }

    /**
     * Proves that (e1, e1New) OR (e2, e2New) is a rerandomization under witness {@code x}.
     *
     * <p>The statement for the left couple is {@code h1 = x*g1, h2 = x*g2} with {@code g1 = g},
     * {@code g2 = h}, {@code h1 = e1New.c1 - e1.c1}, {@code h2 = e1New.c2 - e1.c2}; the right
     * ("Tag") couple is analogous. {@code flag} selects which couple is real; the other branch is
     * simulated. The challenges satisfy {@code c1 + c2 == hash(transcript) (mod q)}.
     *
     * @throws IllegalArgumentException if {@code flag} is neither {@code left} nor {@code right}
     */
    private Mixing.ZeroKnowledgeProof.OrProof createOrProofElGamal(ElGamalCiphertext e1,
                                                                   ElGamalCiphertext e2,
                                                                   ElGamalCiphertext e1New,
                                                                   ElGamalCiphertext e2New,
                                                                   Crypto.EncryptionRandomness x,
                                                                   ProofOrganizer.TrueCouple flag) {
        final boolean leftIsTrue;
        switch (flag) {
            case left:
                leftIsTrue = true;
                break;
            case right:
                leftIsTrue = false;
                break;
            default:
                // Previously returned null, which surfaced later as an opaque NPE in a protobuf setter.
                throw new IllegalArgumentException("Unsupported TrueCouple flag: " + flag);
        }

        BigInteger q = group.orderUpperBound();

        ECPoint g1 = g;
        ECPoint h1 = difference(e1New.getC1(), e1.getC1());
        ECPoint g2 = h;
        ECPoint h2 = difference(e1New.getC2(), e1.getC2());

        ECPoint g1Tag = g;
        ECPoint h1Tag = difference(e2New.getC1(), e2.getC1());
        ECPoint g2Tag = h;
        ECPoint h2Tag = difference(e2New.getC2(), e2.getC2());

        // Draw order (commitment nonce, simulated challenge, simulated response) is the same
        // for both branches, so the consumption of `rand` is unchanged.
        BigInteger r = randomScalar();
        BigInteger simulatedChallenge = randomScalar();
        BigInteger simulatedResponse = randomScalar();

        // Step 1: commitments (real for the true couple, simulated for the other).
        ECPoint u, v, uTag, vTag;
        if (leftIsTrue) {
            u = group.multiply(g1, r);
            v = group.multiply(g2, r);
            uTag = simulatedCommitment(g1Tag, h1Tag, simulatedResponse, simulatedChallenge);
            vTag = simulatedCommitment(g2Tag, h2Tag, simulatedResponse, simulatedChallenge);
        } else {
            uTag = group.multiply(g1Tag, r);
            vTag = group.multiply(g2Tag, r);
            u = simulatedCommitment(g1, h1, simulatedResponse, simulatedChallenge);
            v = simulatedCommitment(g2, h2, simulatedResponse, simulatedChallenge);
        }

        Mixing.ZeroKnowledgeProof.OrProof.ForRandomOracle forRandomOracle =
                Mixing.ZeroKnowledgeProof.OrProof.ForRandomOracle.newBuilder()
                        .setG1(encode(g1))
                        .setH1(encode(h1))
                        .setG2(encode(g2))
                        .setH2(encode(h2))
                        .setG1Tag(encode(g1Tag))
                        .setH1Tag(encode(h1Tag))
                        .setG2Tag(encode(g2Tag))
                        .setH2Tag(encode(h2Tag))
                        .setU(encode(u))
                        .setV(encode(v))
                        .setUTag(encode(uTag))
                        .setVTag(encode(vTag))
                        .build();

        // Step 2: real challenge = (hash + q - simulated challenge) mod q.
        BigInteger realChallenge = hash(forRandomOracle).add(q.subtract(simulatedChallenge)).mod(q);
        // Step 3: real response = (r + realChallenge * x) mod q.
        BigInteger realResponse =
                r.add(realChallenge.multiply(new BigInteger(x.getData().toByteArray()))).mod(q);

        BigInteger c1, c2, z, zTag;
        if (leftIsTrue) {
            c1 = realChallenge;
            z = realResponse;
            c2 = simulatedChallenge;
            zTag = simulatedResponse;
        } else {
            c1 = simulatedChallenge;
            z = simulatedResponse;
            c2 = realChallenge;
            zTag = realResponse;
        }

        // Sanity checks of the verification equations (active only with -ea).
        assert (group.multiply(g1, z).equals(group.add(u, group.multiply(h1,c1))));
        assert (group.multiply(g2, z).equals(group.add(v, group.multiply(h2,c1))));
        assert (group.multiply(g1Tag, zTag).equals(group.add(uTag, group.multiply(h1Tag,c2))));
        assert (group.multiply(g2Tag, zTag).equals(group.add(vTag, group.multiply(h2Tag,c2))));

        // Reuse the exact bytes that were hashed so the emitted proof matches the transcript.
        return Mixing.ZeroKnowledgeProof.OrProof.newBuilder()
                .setG1(forRandomOracle.getG1())
                .setH1(forRandomOracle.getH1())
                .setG2(forRandomOracle.getG2())
                .setH2(forRandomOracle.getH2())
                .setG1Tag(forRandomOracle.getG1Tag())
                .setH1Tag(forRandomOracle.getH1Tag())
                .setG2Tag(forRandomOracle.getG2Tag())
                .setH2Tag(forRandomOracle.getH2Tag())
                .setU(forRandomOracle.getU())
                .setV(forRandomOracle.getV())
                .setUTag(forRandomOracle.getUTag())
                .setVTag(forRandomOracle.getVTag())
                .setC1(ByteString.copyFrom(c1.toByteArray()))
                .setC2(ByteString.copyFrom(c2.toByteArray()))
                .setZ(ByteString.copyFrom(z.toByteArray()))
                .setZTag(ByteString.copyFrom(zTag.toByteArray()))
                .build();
    }

    /** Draws fresh randomness from the encryption scheme and reduces it modulo the group order. */
    private BigInteger randomScalar() {
        return new BigInteger(ecElGamalEncryption.generateRandomness(rand).getData().toByteArray())
                .mod(group.orderUpperBound());
    }

    /** Serializes a group element into a protobuf ByteString. */
    private ByteString encode(ECPoint point) {
        return ByteString.copyFrom(group.encode(point));
    }

    /** Returns {@code decode(minuend) - decode(subtrahend)} in the group. */
    private ECPoint difference(ByteString minuend, ByteString subtrahend) {
        return group.add(convert2ECPoint(minuend), group.negate(convert2ECPoint(subtrahend)));
    }

    /** Simulated commitment {@code response*base - challenge*target} for the non-true branch. */
    private ECPoint simulatedCommitment(ECPoint base, ECPoint target,
                                        BigInteger response, BigInteger challenge) {
        return group.add(group.multiply(base, response), group.negate(group.multiply(target, challenge)));
    }


}