package com.oblivm.backend.gc;

import java.io.IOException;

import com.oblivm.backend.flexsc.CompEnv;
import com.oblivm.backend.flexsc.Flag;
import com.oblivm.backend.flexsc.Mode;
import com.oblivm.backend.flexsc.Party;
import com.oblivm.backend.network.Network;
import com.oblivm.backend.ot.FakeOTSender;
import com.oblivm.backend.ot.OTExtSender;
import com.oblivm.backend.ot.OTPreprocessSender;
import com.oblivm.backend.ot.OTSender;

/**
 * Garbled-circuit computation environment for the party constructed as
 * {@link Party#Alice} (see the constructor's {@code super} call).
 *
 * <p>
 * Every non-public wire is represented by a pair of labels
 * {@code (label[0], label[1])} with {@code label[1] = R xor label[0]}; the
 * methods below always keep {@code label[0]} locally and send either one label
 * (Alice's own input) or the pair through the OT sender (Bob's input).
 */
public abstract class GCGenComp extends GCCompEnv {

	/**
	 * Global offset shared by all label pairs: the second label of a pair is the
	 * first one XORed with {@code R}. Initialised once per JVM with its least
	 * significant bit set. Deliberately left non-final because code outside
	 * this file may replace it.
	 */
	public static GCSignal R = null;

	static {
		R = GCSignal.freshLabel(CompEnv.rnd);
		R.setLSB();
	}

	/**
	 * OT sender used for Bob's inputs. Only created by the constructor when the
	 * channel has neither {@code sender} nor {@code receiver} set; otherwise it
	 * stays {@code null} here.
	 */
	OTSender snd;

	protected long gid = 0;

	/**
	 * When {@code true}, {@link #outputToAlice(GCSignal)} flushes the channel
	 * before reading the result. This class only reads and clears the flag; it
	 * is expected to be set by subclasses.
	 */
	protected boolean gatesRemain = false;

	/**
	 * Creates the environment on top of {@code channel} and selects the OT
	 * sender implementation according to {@link Flag#FakeOT} and
	 * {@link Flag#ProprocessOT}.
	 *
	 * @param channel network channel to the other party
	 * @param mode    execution mode forwarded to the super class
	 */
	public GCGenComp(Network channel, Mode mode) {
		super(channel, Party.Alice, mode);

		if (channel.sender == null && channel.receiver == null) {
			// NOTE(review): 80 presumably matches the 10-byte label length used
			// in genPairForLabel - confirm against the OT sender constructors.
			if (Flag.FakeOT) {
				snd = new FakeOTSender(80, channel);
			} else if (Flag.ProprocessOT) {
				snd = new OTPreprocessSender(80, channel);
			} else {
				snd = new OTExtSender(80, channel);
			}
		}
	}

	/**
	 * Produces a label pair, honouring the offline mode.
	 *
	 * <ul>
	 * <li>{@code mode == OFFLINE} and {@link Flag#offline} set: the first label
	 * is read (10 bytes) from the offline generator's input stream.</li>
	 * <li>{@code mode == OFFLINE} and {@link Flag#offline} not set: a fresh
	 * label is drawn and also written to the offline generator's output.</li>
	 * <li>any other mode: a fresh label is drawn.</li>
	 * </ul>
	 *
	 * @param mode current execution mode
	 * @return two-element array {@code {l, R xor l}}
	 */
	public static GCSignal[] genPairForLabel(Mode mode) {
		GCSignal[] label = new GCSignal[2];
		if (mode == Mode.OFFLINE && Flag.offline) {
			label[0] = new GCSignal(com.oblivm.backend.gc.offline.GCGen.fread.read(10));
		} else {
			label[0] = GCSignal.freshLabel(rnd);
			if (mode == Mode.OFFLINE) {
				label[0].send(com.oblivm.backend.gc.offline.GCGen.fout);
			}
		}
		label[1] = R.xor(label[0]);
		return label;
	}

	/**
	 * Produces a fresh label pair without any offline-mode handling.
	 *
	 * @return two-element array {@code {l, R xor l}}
	 */
	public static GCSignal[] genPair() {
		GCSignal[] label = new GCSignal[2];
		label[0] = GCSignal.freshLabel(rnd);
		label[1] = R.xor(label[0]);
		return label;
	}

	/**
	 * Feeds one of Alice's input bits into the circuit: generates a pair, sends
	 * the label selected by {@code in} over the channel and flushes.
	 *
	 * @param in Alice's input bit
	 * @return the first label of the pair (the one sent when {@code in} is false)
	 */
	public GCSignal inputOfAlice(boolean in) {
		Flag.sw.startOT();
		GCSignal[] label = genPairForLabel(mode);
		Flag.sw.startOTIO();
		label[in ? 1 : 0].send(channel);
		flush();
		Flag.sw.stopOTIO();
		Flag.sw.stopOT();
		return label[0];
	}

	/**
	 * Feeds one of Bob's input bits into the circuit by handing a fresh label
	 * pair to the OT sender.
	 *
	 * @param in ignored on this side; the pair is sent regardless of its value
	 * @return the first label of the pair
	 * @throws IllegalStateException if the OT sender reports an I/O failure;
	 *                               the original {@link IOException} is the cause
	 */
	public GCSignal inputOfBob(boolean in) {
		Flag.sw.startOT();
		try {
			GCSignal[] label = genPairForLabel(mode);
			snd.send(label);
			return label[0];
		} catch (IOException e) {
			// Continuing after a failed transfer would return a label the other
			// side never obtained, so surface the failure instead of hiding it.
			throw new IllegalStateException("oblivious transfer of Bob's input failed", e);
		} finally {
			Flag.sw.stopOT();
		}
	}

	/**
	 * Array variant of {@link #inputOfAlice(boolean)}: generates all pairs
	 * first, then sends the selected labels and flushes once.
	 *
	 * @param x Alice's input bits
	 * @return the first label of each pair, in input order
	 */
	public GCSignal[] inputOfAlice(boolean[] x) {
		Flag.sw.startOT();
		// Inner arrays come from genPairForLabel, so only the outer array is allocated.
		GCSignal[][] pairs = new GCSignal[x.length][];
		GCSignal[] result = new GCSignal[x.length];
		for (int i = 0; i < x.length; ++i) {
			pairs[i] = genPairForLabel(mode);
			result[i] = pairs[i][0];
		}
		Flag.sw.startOTIO();
		for (int i = 0; i < x.length; ++i) {
			pairs[i][x[i] ? 1 : 0].send(channel);
		}
		flush();
		Flag.sw.stopOTIO();
		Flag.sw.stopOT();
		return result;
	}

	/**
	 * Array variant of {@link #inputOfBob(boolean)}: generates one pair per
	 * element of {@code x} and hands all pairs to the OT sender in one call.
	 *
	 * @param x only its length is used on this side
	 * @return the first label of each pair, in input order
	 * @throws IllegalStateException if the OT sender reports an I/O failure;
	 *                               the original {@link IOException} is the cause
	 */
	public GCSignal[] inputOfBob(boolean[] x) {
		Flag.sw.startOT();
		try {
			// Inner arrays come from genPairForLabel, so only the outer array is allocated.
			GCSignal[][] pair = new GCSignal[x.length][];
			for (int i = 0; i < x.length; ++i) {
				pair[i] = genPairForLabel(mode);
			}
			snd.send(pair);
			GCSignal[] result = new GCSignal[x.length];
			for (int i = 0; i < x.length; ++i) {
				result[i] = pair[i][0];
			}
			return result;
		} catch (IOException e) {
			// See inputOfBob(boolean): a failed transfer must not be reported as success.
			throw new IllegalStateException("oblivious transfer of Bob's inputs failed", e);
		} finally {
			Flag.sw.stopOT();
		}
	}

	/**
	 * Reveals a wire value to Alice. Public signals are returned directly;
	 * otherwise a label is received from the channel and compared against
	 * {@code out} (value false) and {@code R xor out} (value true).
	 *
	 * @param out the wire to reveal
	 * @return the revealed bit
	 * @throws IllegalStateException if the received label matches neither
	 *                               candidate
	 */
	public boolean outputToAlice(GCSignal out) {
		if (gatesRemain) {
			gatesRemain = false;
			flush();
		}

		if (out.isPublic()) {
			return out.v;
		}

		GCSignal lb = GCSignal.receive(channel);
		if (lb.equals(out)) {
			return false;
		}
		if (lb.equals(R.xor(out))) {
			return true;
		}
		// Neither label matched: returning a default bit here would present a
		// protocol failure as a valid result.
		throw new IllegalStateException("bad label at final output.");
	}

	/**
	 * Sends the local label of a non-public wire over the channel so the other
	 * party can decode it.
	 *
	 * <p>
	 * NOTE(review): unlike {@link #outputToBob(GCSignal[])} this variant does
	 * not flush the channel; callers presumably flush afterwards - confirm.
	 *
	 * @param out the wire to reveal
	 * @return always {@code false}
	 */
	public boolean outputToBob(GCSignal out) {
		if (!out.isPublic()) {
			out.send(channel);
		}
		return false;
	}

	/**
	 * Array variant of {@link #outputToBob(GCSignal)}: sends every non-public
	 * label, then flushes the channel once.
	 *
	 * @param out the wires to reveal
	 * @return an all-{@code false} array of the same length as {@code out}
	 */
	public boolean[] outputToBob(GCSignal[] out) {
		// A new boolean[] is already all false, so no explicit fill is needed.
		boolean[] result = new boolean[out.length];
		for (int i = 0; i < out.length; ++i) {
			if (!out[i].isPublic()) {
				out[i].send(channel);
			}
		}
		flush();
		return result;
	}

	/**
	 * Array variant of {@link #outputToAlice(GCSignal)}, applied element by
	 * element in order.
	 *
	 * @param out the wires to reveal
	 * @return the revealed bits
	 * @throws IllegalStateException if any received label is invalid
	 */
	public boolean[] outputToAlice(GCSignal[] out) {
		boolean[] result = new boolean[out.length];
		for (int i = 0; i < result.length; ++i) {
			result[i] = outputToAlice(out[i]);
		}
		return result;
	}

	/**
	 * XOR of two wires. Public operands are folded locally: two public values
	 * yield a public value, a public {@code true} negates the other operand and
	 * a public {@code false} yields a copy of it. Two non-public wires are
	 * combined by XORing their labels.
	 *
	 * @param a first operand
	 * @param b second operand
	 * @return the resulting wire
	 */
	public GCSignal xor(GCSignal a, GCSignal b) {
		if (a.isPublic() && b.isPublic()) {
			return new GCSignal(a.v ^ b.v);
		}
		if (a.isPublic()) {
			return a.v ? not(b) : new GCSignal(b);
		}
		if (b.isPublic()) {
			return b.v ? not(a) : new GCSignal(a);
		}
		return a.xor(b);
	}

	/**
	 * Negation of a wire: flips a public value, or XORs a non-public label with
	 * {@link #R}.
	 *
	 * @param a operand
	 * @return the negated wire
	 */
	public GCSignal not(GCSignal a) {
		if (a.isPublic()) {
			return new GCSignal(!a.v);
		}
		return R.xor(a);
	}
}