Pārlūkot izejas kodu

UpdateRoot implemented

Boyoung- 9 gadi atpakaļ
vecāks
revīzija
82aef0cb84

+ 3 - 6
src/gc/GCLib.java

@@ -30,17 +30,14 @@ public class GCLib<T> extends IntegerLib<T> {
 		}
 	}
 
-	public T[][] rootFindDeepestAndEmpty(byte[] Li, T[] j1, T[] E_feBits, T[] C_feBits, T[][] E_tupleLabels,
+	public T[][] rootFindDeepestAndEmpty(T[] j1, T[] pathLabel, T[] E_feBits, T[] C_feBits, T[][] E_tupleLabels,
 			T[][] C_tupleLabels) {
 		int sLogW = (int) Math.ceil(Math.log(w) / Math.log(2));
-		T[] pathLabel = toSignals(new BigInteger(1, Li).longValue(), d - 1); // no
-																				// sign
-																				// bit
 		T[] feBits = xor(E_feBits, C_feBits);
 		T[][] tupleLabels = env.newTArray(w, 0);
 		for (int j = 0; j < w; j++)
 			tupleLabels[j] = xor(E_tupleLabels[j], C_tupleLabels[j]);
-		
+
 		T[] l = padSignal(ones(d - 1), d); // has sign bit
 		T[] j2 = zeros(sLogW); // no sign bit
 
@@ -56,7 +53,7 @@ public class GCLib<T> extends IntegerLib<T> {
 
 			j2 = mux(tupleIndex, j2, feBits[j]);
 		}
-		
+
 		T[][] output = env.newTArray(2, 0);
 		output[0] = j1;
 		output[1] = j2;

+ 16 - 19
src/gc/GCUtil.java

@@ -9,14 +9,14 @@ import crypto.Crypto;
 import oram.Tuple;
 
 public class GCUtil {
-	
+
 	public static GCSignal[] genEmptyKeys(int n) {
 		GCSignal[] keys = new GCSignal[n];
-		for (int i=0; i<n; i++)
+		for (int i = 0; i < n; i++)
 			keys[i] = new GCSignal(new byte[GCSignal.len]);
 		return keys;
 	}
-	
+
 	public static GCSignal[][] genKeyPairs(int n) {
 		GCSignal[][] pairs = new GCSignal[n][];
 		for (int i = 0; i < n; i++)
@@ -40,37 +40,34 @@ public class GCUtil {
 	}
 
 	/*
-	public static GCSignal[] selectKeys(GCSignal[][] pairs, byte[] input) {
-		BigInteger in = new BigInteger(1, input);
-		GCSignal[] out = new GCSignal[pairs.length];
-		for (int i = 0; i < pairs.length; i++)
-			out[i] = pairs[i][in.testBit(pairs.length - 1 - i) ? 1 : 0];
-		return out;
-	}
-	*/
-	
+	 * public static GCSignal[] selectKeys(GCSignal[][] pairs, byte[] input) {
+	 * BigInteger in = new BigInteger(1, input); GCSignal[] out = new
+	 * GCSignal[pairs.length]; for (int i = 0; i < pairs.length; i++) out[i] =
+	 * pairs[i][in.testBit(pairs.length - 1 - i) ? 1 : 0]; return out; }
+	 */
+
 	public static GCSignal[][] selectLabelKeys(GCSignal[][][] labelPairs, Tuple[] tuples) {
 		GCSignal[][] out = new GCSignal[tuples.length][];
-		for (int i=0; i<tuples.length; i++) 
+		for (int i = 0; i < tuples.length; i++)
 			out[i] = revSelectKeys(labelPairs[i], tuples[i].getL());
 		return out;
 	}
-	
+
 	public static GCSignal[] selectFeKeys(GCSignal[][] pairs, Tuple[] tuples) {
 		GCSignal[] out = new GCSignal[pairs.length];
 		for (int i = 0; i < pairs.length; i++)
-			out[i] = pairs[i][new BigInteger(tuples[i].getF()).testBit(0)?1:0];
+			out[i] = pairs[i][new BigInteger(tuples[i].getF()).testBit(0) ? 1 : 0];
 		return out;
 	}
-	
+
 	public static BigInteger[] genOutKeyHashes(GCSignal[] outZeroKeys) {
 		BigInteger[] hashes = new BigInteger[outZeroKeys.length];
-		for (int i=0; i<outZeroKeys.length; i++) {
+		for (int i = 0; i < outZeroKeys.length; i++) {
 			hashes[i] = new BigInteger(Crypto.sha1.digest(outZeroKeys[i].bytes));
 		}
 		return hashes;
 	}
-	
+
 	public static BigInteger evaOutKeys(GCSignal[] outKeys, BigInteger[] genHashes) {
 		BigInteger[] evaHashes = genOutKeyHashes(outKeys);
 		BigInteger output = BigInteger.ZERO;
@@ -80,7 +77,7 @@ public class GCUtil {
 					output = output.setBit(i);
 			} else if (genHashes[i].compareTo(evaHashes[i]) != 0) {
 				output = output.setBit(i);
-			} 
+			}
 		}
 		return output;
 	}

+ 5 - 3
src/protocols/Access.java

@@ -84,13 +84,13 @@ public class Access extends Protocol {
 		else
 			Ti = new Tuple(new byte[1], Ni, Li, y);
 
-		OutAccess outaccess = new OutAccess(null, null, null, null, Ti, pathTuples);
+		OutAccess outaccess = new OutAccess(Li, null, null, null, null, Ti, pathTuples);
 
 		timer.stop(P.ACC, M.online_comp);
 		return outaccess;
 	}
 
-	public void runD(PreData predata, Tree OTi, byte[] Ni, byte[] Nip1_pr, Timer timer) {
+	public byte[] runD(PreData predata, Tree OTi, byte[] Ni, byte[] Nip1_pr, Timer timer) {
 		timer.start(P.ACC, M.online_comp);
 
 		// step 0: get Li from C
@@ -134,6 +134,8 @@ public class Access extends Protocol {
 		}
 
 		timer.stop(P.ACC, M.online_comp);
+
+		return Li;
 	}
 
 	public OutAccess runC(Metadata md, int treeIndex, byte[] Li, Timer timer) {
@@ -192,7 +194,7 @@ public class Access extends Protocol {
 			Crypto.sr.nextBytes(pathTuples[j1].getA());
 		}
 
-		OutAccess outaccess = new OutAccess(Lip1, Ti, pathTuples, j2, null, null);
+		OutAccess outaccess = new OutAccess(Li, Lip1, Ti, pathTuples, j2, null, null);
 
 		timer.stop(P.ACC, M.online_comp);
 		return outaccess;

+ 3 - 1
src/protocols/OutAccess.java

@@ -3,6 +3,7 @@ package protocols;
 import oram.Tuple;
 
 public class OutAccess {
+	public byte[] Li;
 	public byte[] C_Lip1;
 	public Tuple E_Ti;
 	public Tuple C_Ti;
@@ -10,7 +11,8 @@ public class OutAccess {
 	public Tuple[] C_P;
 	public Integer C_j2;
 
-	public OutAccess(byte[] C_Lip1, Tuple C_Ti, Tuple[] C_P, Integer C_j2, Tuple E_Ti, Tuple[] E_P) {
+	public OutAccess(byte[] Li, byte[] C_Lip1, Tuple C_Ti, Tuple[] C_P, Integer C_j2, Tuple E_Ti, Tuple[] E_P) {
+		this.Li = Li;
 		this.C_Lip1 = C_Lip1;
 		this.E_Ti = E_Ti;
 		this.C_Ti = C_Ti;

+ 8 - 7
src/protocols/PreData.java

@@ -24,13 +24,13 @@ public class PreData {
 	public int[] access_sigma;
 	public Tuple[] access_p;
 
-	public Tuple[] ssxot_delta;
-	public int[] ssxot_E_pi;
-	public int[] ssxot_C_pi;
-	public int[] ssxot_E_pi_ivs;
-	public int[] ssxot_C_pi_ivs;
-	public Tuple[] ssxot_E_r;
-	public Tuple[] ssxot_C_r;
+	public Tuple[][] ssxot_delta = new Tuple[2][];
+	public int[][] ssxot_E_pi = new int[2][];
+	public int[][] ssxot_C_pi = new int[2][];
+	public int[][] ssxot_E_pi_ivs = new int[2][];
+	public int[][] ssxot_C_pi_ivs = new int[2][];
+	public Tuple[][] ssxot_E_r = new Tuple[2][];
+	public Tuple[][] ssxot_C_r = new Tuple[2][];
 
 	public byte[] ppt_Li;
 	public byte[] ppt_Lip1;
@@ -44,6 +44,7 @@ public class PreData {
 	public Tuple[] reshuffle_a_prime;
 
 	public GCSignal[][] ur_j1KeyPairs;
+	public GCSignal[][] ur_LiKeyPairs;
 	public GCSignal[][] ur_E_feKeyPairs;
 	public GCSignal[][] ur_C_feKeyPairs;
 	public GCSignal[][][] ur_E_labelKeyPairs;

+ 8 - 0
src/protocols/PreRetrieve.java

@@ -10,10 +10,13 @@ public class PreRetrieve extends Protocol {
 		super(con1, con2);
 	}
 
+	// TODO: not all protocols run on all trees (remove unnecessary precomp)
+
 	public void runE(PreData predata, Metadata md, int ti, Timer timer) {
 		PreAccess preaccess = new PreAccess(con1, con2);
 		PreReshuffle prereshuffle = new PreReshuffle(con1, con2);
 		PrePostProcessT prepostprocesst = new PrePostProcessT(con1, con2);
+		PreUpdateRoot preupdateroot = new PreUpdateRoot(con1, con2);
 
 		int numTuples = md.getStashSizeOfTree(ti) + md.getLBitsOfTree(ti) * md.getW();
 		int[] tupleParam = new int[] { ti == 0 ? 0 : 1, md.getNBytesOfTree(ti), md.getLBytesOfTree(ti),
@@ -22,12 +25,14 @@ public class PreRetrieve extends Protocol {
 		preaccess.runE(predata, md.getTwoTauPow(), numTuples, tupleParam, timer);
 		prereshuffle.runE(predata, timer);
 		prepostprocesst.runE(predata, timer);
+		preupdateroot.runE(predata, md.getStashSizeOfTree(ti), md.getLBitsOfTree(ti), timer);
 	}
 
 	public void runD(PreData predata, Metadata md, int ti, PreData prev, Timer timer) {
 		PreAccess preaccess = new PreAccess(con1, con2);
 		PreReshuffle prereshuffle = new PreReshuffle(con1, con2);
 		PrePostProcessT prepostprocesst = new PrePostProcessT(con1, con2);
+		PreUpdateRoot preupdateroot = new PreUpdateRoot(con1, con2);
 
 		int[] tupleParam = new int[] { ti == 0 ? 0 : 1, md.getNBytesOfTree(ti), md.getLBytesOfTree(ti),
 				md.getABytesOfTree(ti) };
@@ -35,16 +40,19 @@ public class PreRetrieve extends Protocol {
 		preaccess.runD(predata, timer);
 		prereshuffle.runD(predata, tupleParam, timer);
 		prepostprocesst.runD(predata, prev, md.getLBytesOfTree(ti), md.getAlBytesOfTree(ti), md.getTau(), timer);
+		preupdateroot.runD(predata, md.getStashSizeOfTree(ti), md.getLBitsOfTree(ti), tupleParam, timer);
 	}
 
 	public void runC(PreData predata, Metadata md, int ti, PreData prev, Timer timer) {
 		PreAccess preaccess = new PreAccess(con1, con2);
 		PreReshuffle prereshuffle = new PreReshuffle(con1, con2);
 		PrePostProcessT prepostprocesst = new PrePostProcessT(con1, con2);
+		PreUpdateRoot preupdateroot = new PreUpdateRoot(con1, con2);
 
 		preaccess.runC(timer);
 		prereshuffle.runC(predata, timer);
 		prepostprocesst.runC(predata, prev, md.getLBytesOfTree(ti), md.getAlBytesOfTree(ti), timer);
+		preupdateroot.runC(predata, timer);
 	}
 
 	@Override

+ 24 - 19
src/protocols/PreSSXOT.java

@@ -11,41 +11,46 @@ import oram.Tuple;
 import util.Util;
 
 public class PreSSXOT extends Protocol {
-	public PreSSXOT(Communication con1, Communication con2) {
+
+	private int id;
+
+	public PreSSXOT(Communication con1, Communication con2, int id) {
 		super(con1, con2);
+		this.id = id;
 	}
 
 	public void runE(PreData predata, Timer timer) {
 		timer.start(P.XOT, M.offline_read);
-		predata.ssxot_E_pi = con1.readObject();
-		predata.ssxot_E_r = con1.readObject();
+		predata.ssxot_E_pi[id] = con1.readObject();
+		predata.ssxot_E_r[id] = con1.readObject();
 		timer.stop(P.XOT, M.offline_read);
 	}
 
 	public void runD(PreData predata, int n, int k, int[] tupleParam, Timer timer) {
 		timer.start(P.XOT, M.offline_comp);
 
-		predata.ssxot_delta = new Tuple[k];
+		predata.ssxot_delta[id] = new Tuple[k];
 		for (int i = 0; i < k; i++)
-			predata.ssxot_delta[i] = new Tuple(tupleParam[0], tupleParam[1], tupleParam[2], tupleParam[3], Crypto.sr);
+			predata.ssxot_delta[id][i] = new Tuple(tupleParam[0], tupleParam[1], tupleParam[2], tupleParam[3],
+					Crypto.sr);
 
-		predata.ssxot_E_pi = Util.randomPermutation(n, Crypto.sr);
-		predata.ssxot_C_pi = Util.randomPermutation(n, Crypto.sr);
-		predata.ssxot_E_pi_ivs = Util.inversePermutation(predata.ssxot_E_pi);
-		predata.ssxot_C_pi_ivs = Util.inversePermutation(predata.ssxot_C_pi);
+		predata.ssxot_E_pi[id] = Util.randomPermutation(n, Crypto.sr);
+		predata.ssxot_C_pi[id] = Util.randomPermutation(n, Crypto.sr);
+		predata.ssxot_E_pi_ivs[id] = Util.inversePermutation(predata.ssxot_E_pi[id]);
+		predata.ssxot_C_pi_ivs[id] = Util.inversePermutation(predata.ssxot_C_pi[id]);
 
-		predata.ssxot_E_r = new Tuple[n];
-		predata.ssxot_C_r = new Tuple[n];
+		predata.ssxot_E_r[id] = new Tuple[n];
+		predata.ssxot_C_r[id] = new Tuple[n];
 		for (int i = 0; i < n; i++) {
-			predata.ssxot_E_r[i] = new Tuple(tupleParam[0], tupleParam[1], tupleParam[2], tupleParam[3], Crypto.sr);
-			predata.ssxot_C_r[i] = new Tuple(tupleParam[0], tupleParam[1], tupleParam[2], tupleParam[3], Crypto.sr);
+			predata.ssxot_E_r[id][i] = new Tuple(tupleParam[0], tupleParam[1], tupleParam[2], tupleParam[3], Crypto.sr);
+			predata.ssxot_C_r[id][i] = new Tuple(tupleParam[0], tupleParam[1], tupleParam[2], tupleParam[3], Crypto.sr);
 		}
 
 		timer.start(P.XOT, M.offline_write);
-		con1.write(predata.ssxot_E_pi);
-		con1.write(predata.ssxot_E_r);
-		con2.write(predata.ssxot_C_pi);
-		con2.write(predata.ssxot_C_r);
+		con1.write(predata.ssxot_E_pi[id]);
+		con1.write(predata.ssxot_E_r[id]);
+		con2.write(predata.ssxot_C_pi[id]);
+		con2.write(predata.ssxot_C_r[id]);
 		timer.stop(P.XOT, M.offline_write);
 
 		timer.stop(P.XOT, M.offline_comp);
@@ -53,8 +58,8 @@ public class PreSSXOT extends Protocol {
 
 	public void runC(PreData predata, Timer timer) {
 		timer.start(P.XOT, M.offline_read);
-		predata.ssxot_C_pi = con2.readObject();
-		predata.ssxot_C_r = con2.readObject();
+		predata.ssxot_C_pi[id] = con2.readObject();
+		predata.ssxot_C_r[id] = con2.readObject();
 		timer.stop(P.XOT, M.offline_read);
 	}
 

+ 17 - 5
src/protocols/PreUpdateRoot.java

@@ -20,12 +20,14 @@ public class PreUpdateRoot extends Protocol {
 		super(con1, con2);
 	}
 
-	public void runE(PreData predata, int sw, int lBits, byte[] Li, Timer timer) {
+	public void runE(PreData predata, int sw, int lBits, Timer timer) {
 		int sLogW = (int) Math.ceil(Math.log(sw) / Math.log(2));
 		predata.ur_j1KeyPairs = GCUtil.genKeyPairs(sLogW);
+		predata.ur_LiKeyPairs = GCUtil.genKeyPairs(lBits);
 		predata.ur_E_feKeyPairs = GCUtil.genKeyPairs(sw);
 		predata.ur_C_feKeyPairs = GCUtil.genKeyPairs(sw);
 		GCSignal[] j1ZeroKeys = GCUtil.getZeroKeys(predata.ur_j1KeyPairs);
+		GCSignal[] LiZeroKeys = GCUtil.getZeroKeys(predata.ur_LiKeyPairs);
 		GCSignal[] E_feZeroKeys = GCUtil.getZeroKeys(predata.ur_E_feKeyPairs);
 		GCSignal[] C_feZeroKeys = GCUtil.getZeroKeys(predata.ur_C_feKeyPairs);
 		predata.ur_E_labelKeyPairs = new GCSignal[sw][][];
@@ -41,8 +43,8 @@ public class PreUpdateRoot extends Protocol {
 
 		Network channel = new Network(null, con1);
 		CompEnv<GCSignal> gen = new GCGen(channel);
-		GCSignal[][] outZeroKeys = new GCLib<GCSignal>(gen, lBits + 1, sw).rootFindDeepestAndEmpty(Li, j1ZeroKeys,
-				E_feZeroKeys, C_feZeroKeys, E_labelZeroKeys, C_labelZeroKeys);
+		GCSignal[][] outZeroKeys = new GCLib<GCSignal>(gen, lBits + 1, sw).rootFindDeepestAndEmpty(j1ZeroKeys,
+				LiZeroKeys, E_feZeroKeys, C_feZeroKeys, E_labelZeroKeys, C_labelZeroKeys);
 
 		predata.ur_outKeyHashes = new BigInteger[outZeroKeys.length][];
 		for (int i = 0; i < outZeroKeys.length; i++)
@@ -51,11 +53,15 @@ public class PreUpdateRoot extends Protocol {
 		con2.write(predata.ur_C_feKeyPairs);
 		con2.write(predata.ur_C_labelKeyPairs);
 		con1.write(predata.ur_outKeyHashes);
+
+		PreSSXOT pressxot = new PreSSXOT(con1, con2, 0);
+		pressxot.runE(predata, timer);
 	}
 
-	public void runD(PreData predata, int sw, int lBits, byte[] Li, Timer timer) {
+	public void runD(PreData predata, int sw, int lBits, int[] tupleParam, Timer timer) {
 		int sLogW = (int) Math.ceil(Math.log(sw) / Math.log(2));
 		GCSignal[] j1ZeroKeys = GCUtil.genEmptyKeys(sLogW);
+		GCSignal[] LiZeroKeys = GCUtil.genEmptyKeys(lBits);
 		GCSignal[] E_feZeroKeys = GCUtil.genEmptyKeys(sw);
 		GCSignal[] C_feZeroKeys = GCUtil.genEmptyKeys(sw);
 		GCSignal[][] E_labelZeroKeys = new GCSignal[sw][];
@@ -68,16 +74,22 @@ public class PreUpdateRoot extends Protocol {
 		Network channel = new Network(con1, null);
 		CompEnv<GCSignal> eva = new GCEva(channel);
 		predata.ur_gc = new GCLib<GCSignal>(eva, lBits + 1, sw);
-		predata.ur_gc.rootFindDeepestAndEmpty(Li, j1ZeroKeys, E_feZeroKeys, C_feZeroKeys, E_labelZeroKeys,
+		predata.ur_gc.rootFindDeepestAndEmpty(j1ZeroKeys, LiZeroKeys, E_feZeroKeys, C_feZeroKeys, E_labelZeroKeys,
 				C_labelZeroKeys);
 		eva.setEvaluate();
 
 		predata.ur_outKeyHashes = con1.readObject();
+
+		PreSSXOT pressxot = new PreSSXOT(con1, con2, 0);
+		pressxot.runD(predata, sw + 1, sw, tupleParam, timer);
 	}
 
 	public void runC(PreData predata, Timer timer) {
 		predata.ur_C_feKeyPairs = con1.readObject();
 		predata.ur_C_labelKeyPairs = con1.readObject();
+
+		PreSSXOT pressxot = new PreSSXOT(con1, con2, 0);
+		pressxot.runC(predata, timer);
 	}
 
 	@Override

+ 10 - 1
src/protocols/Retrieve.java

@@ -1,6 +1,7 @@
 package protocols;
 
 import java.math.BigInteger;
+import java.util.Arrays;
 
 import communication.Communication;
 import crypto.Crypto;
@@ -24,10 +25,13 @@ public class Retrieve extends Protocol {
 		Access access = new Access(con1, con2);
 		Reshuffle reshuffle = new Reshuffle(con1, con2);
 		PostProcessT postprocesst = new PostProcessT(con1, con2);
+		UpdateRoot updateroot = new UpdateRoot(con1, con2);
 
 		OutAccess outaccess = access.runE(predata, OTi, Ni, Nip1_pr, timer);
 		Tuple[] path = reshuffle.runE(predata, outaccess.E_P, OTi.getTreeIndex() == 0, timer);
 		Tuple Ti = postprocesst.runE(predata, outaccess.E_Ti, OTi.getTreeIndex() == h - 1, timer);
+		Tuple[] root = Arrays.copyOfRange(path, 0, OTi.getStashSize());
+		root = updateroot.runE(predata, OTi.getTreeIndex() == 0, outaccess.Li, root, Ti, timer);
 
 		return outaccess;
 	}
@@ -36,20 +40,25 @@ public class Retrieve extends Protocol {
 		Access access = new Access(con1, con2);
 		Reshuffle reshuffle = new Reshuffle(con1, con2);
 		PostProcessT postprocesst = new PostProcessT(con1, con2);
+		UpdateRoot updateroot = new UpdateRoot(con1, con2);
 
-		access.runD(predata, OTi, Ni, Nip1_pr, timer);
+		byte[] Li = access.runD(predata, OTi, Ni, Nip1_pr, timer);
 		reshuffle.runD();
 		postprocesst.runD();
+		updateroot.runD(predata, OTi.getTreeIndex() == 0, Li, OTi.getW(), timer);
 	}
 
 	public OutAccess runC(PreData predata, Metadata md, int ti, byte[] Li, int h, Timer timer) {
 		Access access = new Access(con1, con2);
 		Reshuffle reshuffle = new Reshuffle(con1, con2);
 		PostProcessT postprocesst = new PostProcessT(con1, con2);
+		UpdateRoot updateroot = new UpdateRoot(con1, con2);
 
 		OutAccess outaccess = access.runC(md, ti, Li, timer);
 		Tuple[] path = reshuffle.runC(predata, outaccess.C_P, ti == 0, timer);
 		Tuple Ti = postprocesst.runC(predata, outaccess.C_Ti, Li, outaccess.C_Lip1, outaccess.C_j2, ti == h - 1, timer);
+		Tuple[] root = Arrays.copyOfRange(path, 0, md.getStashSizeOfTree(ti));
+		root = updateroot.runC(predata, ti == 0, root, Ti, timer);
 
 		return outaccess;
 	}

+ 15 - 7
src/protocols/SSXOT.java

@@ -14,8 +14,16 @@ import util.Util;
 
 public class SSXOT extends Protocol {
 
+	private int id;
+
 	public SSXOT(Communication con1, Communication con2) {
 		super(con1, con2);
+		this.id = 0;
+	}
+
+	public SSXOT(Communication con1, Communication con2, int id) {
+		super(con1, con2);
+		this.id = id;
 	}
 
 	public Tuple[] runE(PreData predata, Tuple[] m, Timer timer) {
@@ -24,7 +32,7 @@ public class SSXOT extends Protocol {
 		// step 1
 		Tuple[] a = new Tuple[m.length];
 		for (int i = 0; i < m.length; i++)
-			a[i] = m[predata.ssxot_E_pi[i]].xor(predata.ssxot_E_r[i]);
+			a[i] = m[predata.ssxot_E_pi[id][i]].xor(predata.ssxot_E_r[id][i]);
 
 		timer.start(P.XOT, M.online_write);
 		con2.write(a);
@@ -57,10 +65,10 @@ public class SSXOT extends Protocol {
 		Tuple[] E_p = new Tuple[k];
 		Tuple[] C_p = new Tuple[k];
 		for (int i = 0; i < k; i++) {
-			E_j[i] = predata.ssxot_E_pi_ivs[index[i]];
-			C_j[i] = predata.ssxot_C_pi_ivs[index[i]];
-			E_p[i] = predata.ssxot_E_r[E_j[i]].xor(predata.ssxot_delta[i]);
-			C_p[i] = predata.ssxot_C_r[C_j[i]].xor(predata.ssxot_delta[i]);
+			E_j[i] = predata.ssxot_E_pi_ivs[id][index[i]];
+			C_j[i] = predata.ssxot_C_pi_ivs[id][index[i]];
+			E_p[i] = predata.ssxot_E_r[id][E_j[i]].xor(predata.ssxot_delta[id][i]);
+			C_p[i] = predata.ssxot_C_r[id][C_j[i]].xor(predata.ssxot_delta[id][i]);
 		}
 
 		timer.start(P.XOT, M.online_write);
@@ -79,7 +87,7 @@ public class SSXOT extends Protocol {
 		// step 1
 		Tuple[] a = new Tuple[m.length];
 		for (int i = 0; i < m.length; i++)
-			a[i] = m[predata.ssxot_C_pi[i]].xor(predata.ssxot_C_r[i]);
+			a[i] = m[predata.ssxot_C_pi[id][i]].xor(predata.ssxot_C_r[id][i]);
 
 		timer.start(P.XOT, M.online_write);
 		con1.write(a);
@@ -120,7 +128,7 @@ public class SSXOT extends Protocol {
 			}
 
 			PreData predata = new PreData();
-			PreSSXOT pressxot = new PreSSXOT(con1, con2);
+			PreSSXOT pressxot = new PreSSXOT(con1, con2, 0);
 
 			if (party == Party.Eddie) {
 				pressxot.runE(predata, timer);

+ 81 - 36
src/protocols/UpdateRoot.java

@@ -2,6 +2,8 @@ package protocols;
 
 import java.math.BigInteger;
 
+import org.apache.commons.lang3.ArrayUtils;
+
 import com.oblivm.backend.gc.GCSignal;
 
 import communication.Communication;
@@ -19,40 +21,86 @@ public class UpdateRoot extends Protocol {
 		super(con1, con2);
 	}
 
-	public void runE(PreData predata, byte[] Li, Tuple[] R, Timer timer) {
+	public Tuple[] runE(PreData predata, boolean firstTree, byte[] Li, Tuple[] R, Tuple Ti, Timer timer) {
+		if (firstTree)
+			return R;
+
 		// step 1
 		int j1 = Crypto.sr.nextInt(R.length);
 		GCSignal[] j1InputKeys = GCUtil.revSelectKeys(predata.ur_j1KeyPairs, BigInteger.valueOf(j1).toByteArray());
+		GCSignal[] LiInputKeys = GCUtil.revSelectKeys(predata.ur_LiKeyPairs, Li);
 		GCSignal[] E_feInputKeys = GCUtil.selectFeKeys(predata.ur_E_feKeyPairs, R);
 		GCSignal[][] E_labelInputKeys = GCUtil.selectLabelKeys(predata.ur_E_labelKeyPairs, R);
 
 		con1.write(j1InputKeys);
+		con1.write(LiInputKeys);
 		con1.write(E_feInputKeys);
 		con1.write(E_labelInputKeys);
+
+		// step 4
+		R = ArrayUtils.addAll(R, new Tuple[] { Ti });
+		SSXOT ssxot = new SSXOT(con1, con2, 0);
+		R = ssxot.runE(predata, R, timer);
+
+		return R;
 	}
 
-	public void runD(PreData predata, byte[] Li, Timer timer) {
+	public void runD(PreData predata, boolean firstTree, byte[] Li, int w, Timer timer) {
+		if (firstTree)
+			return;
+
+		// step 1
 		GCSignal[] j1InputKeys = con1.readObject();
+		GCSignal[] LiInputKeys = con1.readObject();
 		GCSignal[] E_feInputKeys = con1.readObject();
 		GCSignal[][] E_labelInputKeys = con1.readObject();
 		GCSignal[] C_feInputKeys = con2.readObject();
 		GCSignal[][] C_labelInputKeys = con2.readObject();
 
-		GCSignal[][] outKeys = predata.ur_gc.rootFindDeepestAndEmpty(Li, j1InputKeys, E_feInputKeys, C_feInputKeys,
-				E_labelInputKeys, C_labelInputKeys);
+		// step 2
+		GCSignal[][] outKeys = predata.ur_gc.rootFindDeepestAndEmpty(j1InputKeys, LiInputKeys, E_feInputKeys,
+				C_feInputKeys, E_labelInputKeys, C_labelInputKeys);
 		int j1 = GCUtil.evaOutKeys(outKeys[0], predata.ur_outKeyHashes[0]).intValue();
 		int j2 = GCUtil.evaOutKeys(outKeys[1], predata.ur_outKeyHashes[1]).intValue();
 
-		System.out.println(j1 + " " + j2);
+		// System.out.println(j1 + " " + j2);
+
+		// step 3
+		int r = Crypto.sr.nextInt(w);
+		int[] I = new int[E_feInputKeys.length];
+		for (int i = 0; i < I.length; i++)
+			I[i] = i;
+		I[j2] = I.length;
+		int tmp = I[r];
+		I[r] = I[j1];
+		I[j1] = tmp;
+
+		// step 4
+		SSXOT ssxot = new SSXOT(con1, con2, 0);
+		ssxot.runD(predata, I, timer);
+
+		// for (int i=0; i<I.length; i++)
+		// System.out.print(I[i] + " ");
+		// System.out.println();
 	}
 
-	public void runC(PreData predata, Tuple[] R, Timer timer) {
+	public Tuple[] runC(PreData predata, boolean firstTree, Tuple[] R, Tuple Ti, Timer timer) {
+		if (firstTree)
+			return R;
+
 		// step 1
 		GCSignal[] C_feInputKeys = GCUtil.selectFeKeys(predata.ur_C_feKeyPairs, R);
 		GCSignal[][] C_labelInputKeys = GCUtil.selectLabelKeys(predata.ur_C_labelKeyPairs, R);
 
 		con2.write(C_feInputKeys);
 		con2.write(C_labelInputKeys);
+
+		// step 4
+		R = ArrayUtils.addAll(R, new Tuple[] { Ti });
+		SSXOT ssxot = new SSXOT(con1, con2, 0);
+		R = ssxot.runC(predata, R, timer);
+
+		return R;
 	}
 
 	// for testing correctness
@@ -61,38 +109,40 @@ public class UpdateRoot extends Protocol {
 		Timer timer = new Timer();
 
 		for (int i = 0; i < 10; i++) {
-			
+
 			System.out.println("i=" + i);
-			
+
 			PreData predata = new PreData();
 			PreUpdateRoot preupdateroot = new PreUpdateRoot(con1, con2);
-			
+
 			if (party == Party.Eddie) {
 				int sw = Crypto.sr.nextInt(25) + 10;
 				int lBits = Crypto.sr.nextInt(30) + 5;
-				byte[] Li = Util.nextBytes((lBits+7)/8, Crypto.sr);
+				byte[] Li = Util.nextBytes((lBits + 7) / 8, Crypto.sr);
 				Tuple[] R = new Tuple[sw];
-				for (int j=0; j<sw; j++)
-					R[j] = new Tuple(1, 2, (lBits+7)/8, 3, Crypto.sr);
-				
+				for (int j = 0; j < sw; j++)
+					R[j] = new Tuple(1, 2, (lBits + 7) / 8, 3, Crypto.sr);
+				Tuple Ti = new Tuple(1, 2, (lBits + 7) / 8, 3, Crypto.sr);
+
 				System.out.println("sw,lBits: " + sw + " " + lBits);
-				
+
 				con1.write(sw);
 				con1.write(lBits);
 				con1.write(Li);
 				con2.write(sw);
 				con2.write(lBits);
-				
-				preupdateroot.runE(predata, sw, lBits, Li, timer);
-				runE(predata, Li, R, timer);
-				
+
+				preupdateroot.runE(predata, sw, lBits, timer);
+				runE(predata, false, Li, R, Ti, timer);
+
 				int emptyIndex = 0;
-				for (int j=0; j<sw; j++) {
+				for (int j = 0; j < sw; j++) {
 					if (new BigInteger(R[j].getF()).testBit(0)) {
-						String l = Util.addZeros(Util.getSubBits(new BigInteger(1, Util.xor(R[j].getL(), Li)), lBits, 0).toString(2), lBits);
+						String l = Util.addZeros(
+								Util.getSubBits(new BigInteger(1, Util.xor(R[j].getL(), Li)), lBits, 0).toString(2),
+								lBits);
 						System.out.println(j + ":\t" + l);
-					}
-					else {
+					} else {
 						emptyIndex = j;
 					}
 				}
@@ -102,26 +152,21 @@ public class UpdateRoot extends Protocol {
 				int sw = con1.readObject();
 				int lBits = con1.readObject();
 				byte[] Li = con1.read();
-				
-				preupdateroot.runD(predata, sw, lBits, Li, timer);
-				runD(predata, Li, timer);
+				int[] tupleParam = new int[] { 1, 2, (lBits + 7) / 8, 3 };
+
+				preupdateroot.runD(predata, sw, lBits, tupleParam, timer);
+				runD(predata, false, Li, md.getW(), timer);
 
 			} else if (party == Party.Charlie) {
 				int sw = con1.readObject();
 				int lBits = con1.readObject();
 				Tuple[] R = new Tuple[sw];
-				for (int j=0; j<sw; j++)
-					R[j] = new Tuple(1, 2, lBits, 3, null);
-				
+				for (int j = 0; j < sw; j++)
+					R[j] = new Tuple(1, 2, (lBits + 7) / 8, 3, null);
+				Tuple Ti = new Tuple(1, 2, (lBits + 7) / 8, 3, null);
+
 				preupdateroot.runC(predata, timer);
-				runC(predata, R, timer);
-				
-				/*
-				if (output.t == index && Util.equal(output.m_t, m[index]))
-					System.out.println("SSCOT test passed");
-				else
-					System.err.println("SSCOT test failed");
-					*/
+				runC(predata, false, R, Ti, timer);
 
 			} else {
 				throw new NoSuchPartyException(party + "");

+ 1 - 1
src/util/Util.java

@@ -137,7 +137,7 @@ public class Util {
 			return out;
 		}
 	}
-	
+
 	public static String addZeros(String a, int n) {
 		String out = a;
 		for (int i = 0; i < n - a.length(); i++)

+ 50 - 0
test/protocols/TestUpdateRoot_C.java

@@ -0,0 +1,50 @@
+package protocols;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.Arrays;
+
+public class TestUpdateRoot_C {
+
+	public static void main(String[] args) {
+		Runtime runTime = Runtime.getRuntime();
+		Process process = null;
+		String dir = System.getProperty("user.dir");
+		String binDir = dir + "\\bin";
+		String libs = dir + "\\lib\\*";
+		try {
+			process = runTime.exec("java -classpath " + binDir + ";" + libs + " ui.CLI -protocol update charlie");
+
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+		InputStream inputStream = process.getInputStream();
+		InputStreamReader isr = new InputStreamReader(inputStream);
+		InputStream errorStream = process.getErrorStream();
+		InputStreamReader esr = new InputStreamReader(errorStream);
+
+		System.out.println("STANDARD OUTPUT:");
+		int n1;
+		char[] c1 = new char[1024];
+		try {
+			while ((n1 = isr.read(c1)) > 0) {
+				System.out.print(new String(Arrays.copyOfRange(c1, 0, n1)));
+			}
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+
+		System.out.println("STANDARD ERROR:");
+		int n2;
+		char[] c2 = new char[1024];
+		try {
+			while ((n2 = esr.read(c2)) > 0) {
+				System.err.print(new String(Arrays.copyOfRange(c2, 0, n2)));
+			}
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+	}
+
+}

+ 50 - 0
test/protocols/TestUpdateRoot_D.java

@@ -0,0 +1,50 @@
+package protocols;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.Arrays;
+
+public class TestUpdateRoot_D {
+
+	public static void main(String[] args) {
+		Runtime runTime = Runtime.getRuntime();
+		Process process = null;
+		String dir = System.getProperty("user.dir");
+		String binDir = dir + "\\bin";
+		String libs = dir + "\\lib\\*";
+		try {
+			process = runTime.exec("java -classpath " + binDir + ";" + libs + " ui.CLI -protocol update debbie");
+
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+		InputStream inputStream = process.getInputStream();
+		InputStreamReader isr = new InputStreamReader(inputStream);
+		InputStream errorStream = process.getErrorStream();
+		InputStreamReader esr = new InputStreamReader(errorStream);
+
+		System.out.println("STANDARD OUTPUT:");
+		int n1;
+		char[] c1 = new char[1024];
+		try {
+			while ((n1 = isr.read(c1)) > 0) {
+				System.out.print(new String(Arrays.copyOfRange(c1, 0, n1)));
+			}
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+
+		System.out.println("STANDARD ERROR:");
+		int n2;
+		char[] c2 = new char[1024];
+		try {
+			while ((n2 = esr.read(c2)) > 0) {
+				System.err.print(new String(Arrays.copyOfRange(c2, 0, n2)));
+			}
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+	}
+
+}

+ 50 - 0
test/protocols/TestUpdateRoot_E.java

@@ -0,0 +1,50 @@
+package protocols;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.Arrays;
+
+public class TestUpdateRoot_E {
+
+	public static void main(String[] args) {
+		Runtime runTime = Runtime.getRuntime();
+		Process process = null;
+		String dir = System.getProperty("user.dir");
+		String binDir = dir + "\\bin";
+		String libs = dir + "\\lib\\*";
+		try {
+			process = runTime.exec("java -classpath " + binDir + ";" + libs + " ui.CLI -protocol update eddie");
+
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+		InputStream inputStream = process.getInputStream();
+		InputStreamReader isr = new InputStreamReader(inputStream);
+		InputStream errorStream = process.getErrorStream();
+		InputStreamReader esr = new InputStreamReader(errorStream);
+
+		System.out.println("STANDARD OUTPUT:");
+		int n1;
+		char[] c1 = new char[1024];
+		try {
+			while ((n1 = isr.read(c1)) > 0) {
+				System.out.print(new String(Arrays.copyOfRange(c1, 0, n1)));
+			}
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+
+		System.out.println("STANDARD ERROR:");
+		int n2;
+		char[] c2 = new char[1024];
+		try {
+			while ((n2 = esr.read(c2)) > 0) {
+				System.err.print(new String(Arrays.copyOfRange(c2, 0, n2)));
+			}
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+	}
+
+}