瀏覽代碼

porting circuit oram's example to our 3-party setting; implementing coram3p circuit algorithms

Boyoung- 9 年之前
父節點
當前提交
3f5c30e927

+ 8 - 0
ObliVMGC/com/oblivm/backend/circuits/arithmetic/IntegerLib.java

@@ -117,6 +117,14 @@ public class IntegerLib<T> extends CircuitLib<T> implements ArithmeticLib<T> {
 		return geq(y, x);
 	}
 
+	public T greater(T[] x, T[] y) {
+		return not(leq(x, y));
+	}
+
+	public T less(T[] x, T[] y) {
+		return not(geq(x, y));
+	}
+
 	public T[] multiply(T[] x, T[] y) {
 		return Arrays.copyOf(multiplyInternal(x, y), x.length);// res;
 	}

+ 1 - 1
ObliVMGC/com/oblivm/backend/example/HammingDistance.java

@@ -84,7 +84,7 @@ public class HammingDistance {
 
 		private T[] compute(CompEnv<T> gen, T[] inputA, T[] inputB) {
 			IntegerLib<T> il = new IntegerLib<T>(gen);
-			il.hammingDistance(inputA, inputB);
+			il.hammingDistance(inputB, inputB);
 			gen.setEvaluate();
 			return il.hammingDistance(inputA, inputB);
 		}

+ 8 - 6
ObliVMGC/com/oblivm/backend/gc/GCGenComp.java

@@ -27,12 +27,14 @@ public abstract class GCGenComp extends GCCompEnv {
 	public GCGenComp(Network channel, Mode mode) {
 		super(channel, Party.Alice, mode);
 
-		if (Flag.FakeOT)
-			snd = new FakeOTSender(80, channel);
-		else if (Flag.ProprocessOT)
-			snd = new OTPreprocessSender(80, channel);
-		else
-			snd = new OTExtSender(80, channel);
+		if (channel.sender == null && channel.receiver == null) {
+			if (Flag.FakeOT)
+				snd = new FakeOTSender(80, channel);
+			else if (Flag.ProprocessOT)
+				snd = new OTPreprocessSender(80, channel);
+			else
+				snd = new OTExtSender(80, channel);
+		}
 	}
 
 	public static GCSignal[] genPairForLabel(Mode mode) {

+ 17 - 5
ObliVMGC/com/oblivm/backend/gc/GCSignal.java

@@ -5,12 +5,18 @@ package com.oblivm.backend.gc;
 
 import java.io.IOException;
 import java.io.OutputStream;
+import java.io.Serializable;
 import java.security.SecureRandom;
 import java.util.Arrays;
 
 import com.oblivm.backend.network.Network;
 
-public class GCSignal {
+public class GCSignal implements Serializable {
+	/**
+	 * 
+	 */
+	private static final long serialVersionUID = 1L;
+
 	public static final int len = 10;
 	public byte[] bytes;
 	public boolean v;
@@ -71,7 +77,10 @@ public class GCSignal {
 
 	// 'send' and 'receive' are supposed to be used only for secret signals
 	public void send(Network channel) {
-		channel.writeByte(bytes, len);
+		if (channel.receiver == null)
+			channel.writeByte(bytes, len);
+		else
+			channel.receiver.write(bytes);
 	}
 
 	// 'send' and 'receive' are supposed to be used only for secret signals
@@ -91,9 +100,12 @@ public class GCSignal {
 	}
 
 	public static void receive(Network channel, GCSignal s) {
-		if (s.bytes == null)
-			s.bytes = new byte[len];
-		channel.readBytes(s.bytes);
+		if (channel.sender == null) {
+			if (s.bytes == null)
+				s.bytes = new byte[len];
+			channel.readBytes(s.bytes);
+		} else
+			s.bytes = channel.sender.read();
 	}
 
 	@Override

+ 10 - 0
ObliVMGC/com/oblivm/backend/network/Network.java

@@ -12,6 +12,8 @@ import com.oblivm.backend.flexsc.CompEnv;
 import com.oblivm.backend.flexsc.Mode;
 import com.oblivm.backend.gc.GCSignal;
 
+import communication.Communication;
+
 public class Network {
 	protected Socket sock;
 	protected ServerSocket serverSock;
@@ -23,6 +25,9 @@ public class Network {
 	boolean THREADEDIO = true;
 	static int NetworkThreadedQueueSize = 1024 * 256;
 
+	public Communication sender;
+	public Communication receiver;
+
 	public void setUpThread() {
 		if (THREADEDIO) {
 			queue = new CustomizedConcurrentQueue(NetworkThreadedQueueSize);
@@ -36,6 +41,11 @@ public class Network {
 
 	}
 
+	public Network(Communication s, Communication r) {
+		sender = s;
+		receiver = r;
+	}
+
 	public Network(InputStream is, OutputStream os, Socket sock) {
 		this.is = is;
 		this.os = os;

+ 55 - 0
src/gc/GarbledCircuitLib.java

@@ -0,0 +1,55 @@
+package gc;
+
+import com.oblivm.backend.circuits.arithmetic.IntegerLib;
+import com.oblivm.backend.flexsc.CompEnv;
+
+import crypto.Crypto;
+import util.Util;
+
+public class GarbledCircuitLib<T> extends IntegerLib<T> {
+
+	private int d;
+	private int w;
+	private int logD;
+	private int logW;
+
+	public GarbledCircuitLib(CompEnv<T> e, int d, int w) {
+		super(e);
+		this.d = d;
+		this.w = w;
+		logD = (int) Math.ceil(Math.log(d) / Math.log(2));
+		logW = (int) Math.ceil(Math.log(w) / Math.log(2));
+	}
+
+	private void zerosFollowedByOnes(T[] input) {
+		for (int i = input.length - 2; i >= 0; i--) {
+			input[i] = or(input[i], input[i + 1]);
+		}
+	}
+
+	public T[][] deepestAndEmptyTuples(long i, T[] pathLabel, T[] feBits, T[][] tupleLabels) {
+		T[] l = toSignals(1L << (d - 1 - i) - 1L, d - 1);
+		T[] j1 = toSignals(Util.nextLong(w, Crypto.sr), logW);
+		T[] j2 = toSignals(Util.nextLong(w, Crypto.sr), logW);
+		T[] et = zeros(1);
+		for (int j = 0; j < w; j++) {
+			T[] tupleIndex = toSignals(j, logW);
+			T[] lz = xor(pathLabel, tupleLabels[j]);
+			zerosFollowedByOnes(lz);
+			T firstIf = and(feBits[j], less(lz, l));
+			l = mux(l, lz, firstIf);
+			j1 = mux(j1, tupleIndex, firstIf);
+			et = mux(ones(1), et, feBits[j]);
+			j2 = mux(tupleIndex, j2, feBits[j]);
+		}
+		T[] l_p = numberOfOnes(not(l)); // TODO: set length to logD?
+
+		T[][] output = env.newTArray(4, 0);
+		output[0] = l_p;
+		output[1] = j1;
+		output[2] = j2;
+		output[3] = et;
+		return output;
+	}
+
+}

+ 137 - 0
src/protocols/GarbledCircuit.java

@@ -0,0 +1,137 @@
+package protocols;
+
+import java.math.BigInteger;
+
+import com.oblivm.backend.circuits.arithmetic.IntegerLib;
+import com.oblivm.backend.flexsc.CompEnv;
+import com.oblivm.backend.gc.GCGenComp;
+import com.oblivm.backend.gc.GCSignal;
+import com.oblivm.backend.gc.regular.GCEva;
+import com.oblivm.backend.gc.regular.GCGen;
+import com.oblivm.backend.network.Network;
+
+import communication.Communication;
+import crypto.Crypto;
+import exceptions.AccessException;
+import exceptions.NoSuchPartyException;
+import oram.Forest;
+import oram.Metadata;
+import oram.Tree;
+import util.Util;
+
+public class GarbledCircuit extends Protocol {
+
+	private int totalLength = 1000;
+
+	public GarbledCircuit(Communication con1, Communication con2) {
+		super(con1, con2);
+	}
+
+	private int countOnes(boolean[] in) {
+		int cnt = 0;
+		for (int i = 0; i < in.length; i++)
+			if (in[i])
+				cnt++;
+		return cnt;
+	}
+
+	private int booleansToInt(boolean[] arr) {
+		int n = 0;
+		for (int i = arr.length - 1; i >= 0; i--)
+			n = (n << 1) | (arr[i] ? 1 : 0);
+		return n;
+	}
+
+	public void runE() {
+		Network w = new Network(null, con1);
+		CompEnv<GCSignal> gen = new GCGen(w);
+
+		boolean[] input1 = new boolean[totalLength];
+		boolean[] input2 = new boolean[totalLength];
+		boolean[] input = new boolean[totalLength];
+		for (int i = 0; i < input1.length; ++i) {
+			input1[i] = CompEnv.rnd.nextBoolean();
+			input2[i] = CompEnv.rnd.nextBoolean();
+			input[i] = input1[i] ^ input2[i];
+		}
+
+		GCSignal[][] inputKeyPairs1 = new GCSignal[input1.length][];
+		GCSignal[] localInputKeys1 = new GCSignal[input1.length];
+		GCSignal[][] inputKeyPairs2 = new GCSignal[input1.length][];
+		GCSignal[] localInputKeys2 = new GCSignal[input1.length];
+		GCSignal[] inputE = new GCSignal[input1.length];
+		GCSignal[] inputD = new GCSignal[input1.length];
+		for (int i = 0; i < input1.length; i++) {
+			inputKeyPairs1[i] = GCGenComp.genPair();
+			localInputKeys1[i] = inputKeyPairs1[i][0];
+			inputKeyPairs2[i] = GCGenComp.genPair();
+			localInputKeys2[i] = inputKeyPairs2[i][0];
+
+			inputE[i] = input1[i] ? inputKeyPairs1[i][1] : inputKeyPairs1[i][0];
+			inputD[i] = input2[i] ? inputKeyPairs2[i][1] : inputKeyPairs2[i][0];
+		}
+
+		GCSignal[] outputE = new IntegerLib<GCSignal>(gen).hammingDistance(localInputKeys1, localInputKeys2);
+
+		con1.write(inputE);
+		con1.write(inputD);
+
+		GCSignal[] outputD = con1.readObject();
+
+		boolean[] output = new boolean[totalLength];
+		for (int i = 0; i < outputE.length; i++) {
+			if (outputE[i].isPublic())
+				output[i] = outputE[i].v;
+			else if (outputE[i].equals(outputD[i]))
+				output[i] = false;
+			else if (outputD[i].equals(GCGenComp.R.xor(outputE[i])))
+				output[i] = true;
+			else
+				System.err.println("ERROR on GC output!");
+		}
+
+		int inCnt = countOnes(input);
+		int outCnt = booleansToInt(output);
+		System.out.println((inCnt == outCnt) + " " + inCnt + " " + outCnt);
+	}
+
+	public void runD() {
+		Network w = new Network(con1, null);
+		CompEnv<GCSignal> gen = new GCEva(w);
+
+		GCSignal[] randomInput = new GCSignal[totalLength];
+		for (int i = 0; i < randomInput.length; i++)
+			randomInput[i] = GCSignal.freshLabel(Crypto.sr);
+		IntegerLib<GCSignal> il = new IntegerLib<GCSignal>(gen);
+		il.hammingDistance(randomInput, randomInput);
+
+		GCSignal[] inputE = con1.readObject();
+		GCSignal[] inputD = con1.readObject();
+
+		gen.setEvaluate();
+		GCSignal[] outputD = il.hammingDistance(inputE, inputD);
+
+		con1.write(outputD);
+	}
+
+	public void runC() {
+
+	}
+
+	// for testing correctness
+	@Override
+	public void run(Party party, Metadata md, Forest forest) {
+		if (party == Party.Eddie) {
+			runE();
+
+		} else if (party == Party.Debbie) {
+			runD();
+
+		} else if (party == Party.Charlie) {
+			runC();
+
+		} else {
+			throw new NoSuchPartyException(party + "");
+		}
+	}
+}

+ 3 - 0
src/ui/CLI.java

@@ -20,6 +20,7 @@ import protocols.Reshuffle;
 import protocols.PostProcessT;
 import protocols.SSXOT;
 import protocols.Access;
+import protocols.GarbledCircuit;
 
 public class CLI {
 	public static final int DEFAULT_PORT = 8000;
@@ -82,6 +83,8 @@ public class CLI {
 			operation = SSXOT.class;
 		} else if (protocol.equals("access")) {
 			operation = Access.class;
+		} else if (protocol.equals("gc")) {
+			operation = GarbledCircuit.class;
 		} else {
 			System.out.println("Protocol " + protocol + " not supported");
 			System.exit(-1);