Browse Source

add new PIRAccess impl; fix bug in PIRCOT

Boyang Wei 6 years ago
parent
commit
ee95f24b04

+ 281 - 337
src/pir/PIRAccess.java

@@ -1,27 +1,27 @@
 package pir;
 
 import java.math.BigInteger;
-import java.util.Arrays;
 
 import org.apache.commons.lang3.ArrayUtils;
 
 import communication.Communication;
 import crypto.Crypto;
-import exceptions.AccessException;
 import exceptions.NoSuchPartyException;
 import oram.Bucket;
 import oram.Forest;
-import oram.Global;
 import oram.Metadata;
 import oram.Tree;
 import oram.Tuple;
+import pir.precomputation.PrePIRCOT;
 import protocols.Protocol;
-import protocols.precomputation.PreAccess;
 import protocols.struct.OutAccess;
-import protocols.struct.OutSSCOT;
-import protocols.struct.OutSSIOT;
+import protocols.struct.OutPIRAccess;
+import protocols.struct.OutPIRCOT;
 import protocols.struct.Party;
 import protocols.struct.PreData;
+import protocols.struct.TwoOneXor;
+import protocols.struct.TwoThreeXorByte;
+import protocols.struct.TwoThreeXorInt;
 import util.M;
 import util.P;
 import util.Timer;
@@ -35,284 +35,231 @@ public class PIRAccess extends Protocol {
 		super(con1, con2);
 	}
 
-	public OutAccess runE(PreData predata, Tree OTi, byte[] Ni, byte[] Nip1_pr, Timer timer) {
+	public OutPIRAccess runE(Metadata md, PreData predata, Tree tree_DE, Tree tree_CE, byte[] Li, TwoThreeXorByte L,
+			TwoThreeXorByte N, TwoThreeXorInt dN, Timer timer) {
 		timer.start(pid, M.online_comp);
 
-		// step 0: get Li from C
-		byte[] Li = new byte[0];
-		timer.start(pid, M.online_read);
-		if (OTi.getTreeIndex() > 0)
-			Li = con2.read();
-		timer.stop(pid, M.online_read);
+		Bucket[] pathBuckets_DE = tree_DE.getBucketsOnPath(Li);
+		Tuple[] pathTuples_DE = Bucket.bucketsToTuples(pathBuckets_DE);
+		Bucket[] pathBuckets_CE = tree_CE.getBucketsOnPath(Li);
+		Tuple[] pathTuples_CE = Bucket.bucketsToTuples(pathBuckets_CE);
 
-		// step 1
-		Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
-		Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
-		// for (int i = 0; i < pathTuples.length; i++)
-		// pathTuples[i].setXor(predata.access_p[i]);
-		pathTuples = Util.permute(pathTuples, predata.access_sigma);
-
-		// step 3
-		// byte[] y = null;
-		// if (OTi.getTreeIndex() == 0)
-		// y = pathTuples[0].getA();
-		// else if (OTi.getTreeIndex() < OTi.getH() - 1)
-		// y = Util.nextBytes(OTi.getABytes(), Crypto.sr);
-		// else
-		// y = new byte[OTi.getABytes()];
-
-		if (OTi.getTreeIndex() > 0) {
-			byte[][] a = new byte[pathTuples.length][];
-			// byte[][] m = new byte[pathTuples.length][];
-			for (int i = 0; i < pathTuples.length; i++) {
-				// m[i] = Util.xor(pathTuples[i].getA(), y);
-				a[i] = ArrayUtils.addAll(pathTuples[i].getF(), pathTuples[i].getN());
-				for (int j = 0; j < Ni.length; j++)
-					a[i][a[i].length - 1 - j] ^= Ni[Ni.length - 1 - j];
-			}
+		int pathTuples = pathTuples_CE.length;
+		int ttp = md.getTwoTauPow();
 
-			PIRCOT sscot = new PIRCOT(con1, con2);
-			sscot.runE(predata, a, timer);
+		PrePIRCOT preksearch = new PrePIRCOT(con1, con2);
+		preksearch.runE(predata, pathTuples, timer);
 
-			// int j1 = con2.readInt();
-			// y = pathTuples[j1].getA();
+		byte[][] u = new byte[pathTuples][];
+		for (int i = 0; i < pathTuples; i++) {
+			u[i] = ArrayUtils.addAll(pathTuples_CE[i].getF(), pathTuples_CE[i].getN());
 		}
+		byte[] v = ArrayUtils.addAll(new byte[] { 1 }, N.CE);
 
-		// con2.write(y);
-
-		// step 4
-		if (OTi.getTreeIndex() < OTi.getH() - 1) {
-			// int ySegBytes = y.length / OTi.getTwoTauPow();
-			// byte[][] y_array = new byte[OTi.getTwoTauPow()][];
-			// for (int i = 0; i < OTi.getTwoTauPow(); i++)
-			// y_array[i] = Arrays.copyOfRange(y, i * ySegBytes, (i + 1) *
-			// ySegBytes);
+		PIRCOT ksearch = new PIRCOT(con1, con2);
+		OutPIRCOT j = ksearch.runE(predata, u, v, timer);
 
-			PIRIOT ssiot = new PIRIOT(con1, con2);
-			ssiot.runE(predata, OTi.getTwoTauPow(), Nip1_pr, timer);
+		byte[][] x_DE = new byte[pathTuples][];
+		byte[][] x_CE = new byte[pathTuples][];
+		for (int i = 0; i < pathTuples; i++) {
+			x_DE[i] = pathTuples_DE[i].getA();
+			x_CE[i] = pathTuples_CE[i].getA();
 		}
 
-		// PIR
-		int[] j = new int[] { 0, 0 };
-		timer.start(pid, M.online_write);
-		con1.write(pid, j);
-		con2.write(pid, j);
-		timer.stop(pid, M.online_write);
-		timer.start(pid, M.online_read);
-		con1.readIntArray(pid);
-		j = con2.readIntArray(pid);
-		timer.stop(pid, M.online_read);
-
-		byte[] y = null;
-		if (OTi.getTreeIndex() == 0)
-			y = pathTuples[0].getA();
-		else
-			y = pathTuples[j[0]].getA();
+		ThreeShiftPIR threeshiftpir = new ThreeShiftPIR(con1, con2);
+		TwoThreeXorByte X = threeshiftpir.runE(predata, x_DE, x_CE, j, timer);
 
-		timer.start(pid, M.online_write);
-		con1.write(pid, y);
-		con2.write(pid, y);
-		timer.stop(pid, M.online_write);
-		timer.start(pid, M.online_read);
-		con1.read(pid);
-		con2.read(pid);
-		timer.stop(pid, M.online_read);
+		TwoOneXor dN21 = new TwoOneXor();
+		dN21.t_E = dN.CE ^ dN.DE;
+		dN21.s_CE = dN.CE;
+		dN21.s_DE = dN.DE;
 
-		// step 5
-		Tuple Ti = null;
-		if (OTi.getTreeIndex() == 0)
-			Ti = pathTuples[0];
-		else
-			Ti = new Tuple(new byte[1], Ni, Li, y);
+		ThreeShiftXorPIR threeshiftxorpir = new ThreeShiftXorPIR(con1, con2);
+		TwoThreeXorByte nextL = threeshiftxorpir.runE(predata, x_DE, x_CE, j, dN21, ttp, timer);
+		byte[] Lip1 = Util.xor(Util.xor(nextL.DE, nextL.CE), nextL.CD);
 
-		OutAccess outaccess = new OutAccess(Li, null, null, null, null, Ti, pathTuples);
+		OutPIRAccess out = new OutPIRAccess(null, pathTuples_CE, pathTuples_DE, j, X, nextL, Lip1);
 
 		timer.stop(pid, M.online_comp);
-		return outaccess;
+		return out;
 	}
 
-	public byte[] runD(PreData predata, Tree OTi, byte[] Ni, byte[] Nip1_pr, Timer timer) {
+	public OutPIRAccess runD(Metadata md, PreData predata, Tree tree_DE, Tree tree_CD, byte[] Li, TwoThreeXorByte L,
+			TwoThreeXorByte N, TwoThreeXorInt dN, Timer timer) {
 		timer.start(pid, M.online_comp);
 
-		// step 0: get Li from C
-		byte[] Li = new byte[0];
-		timer.start(pid, M.online_read);
-		if (OTi.getTreeIndex() > 0)
-			Li = con2.read();
-		timer.stop(pid, M.online_read);
-
-		// step 1
-		Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
-		Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
-		// for (int i = 0; i < pathTuples.length; i++)
-		// pathTuples[i].setXor(predata.access_p[i]);
-		pathTuples = Util.permute(pathTuples, predata.access_sigma);
+		Bucket[] pathBuckets_DE = tree_DE.getBucketsOnPath(Li);
+		Tuple[] pathTuples_DE = Bucket.bucketsToTuples(pathBuckets_DE);
+		Bucket[] pathBuckets_CD = tree_CD.getBucketsOnPath(Li);
+		Tuple[] pathTuples_CD = Bucket.bucketsToTuples(pathBuckets_CD);
 
-		// step 2
-		timer.start(pid, M.online_write);
-		// con2.write(pid, pathTuples);
-		con2.write(pid, Ni);
-		timer.stop(pid, M.online_write);
+		int pathTuples = pathTuples_CD.length;
+		int ttp = md.getTwoTauPow();
 
-		// step 3
-		if (OTi.getTreeIndex() > 0) {
-			byte[][] b = new byte[pathTuples.length][];
-			for (int i = 0; i < pathTuples.length; i++) {
-				b[i] = ArrayUtils.addAll(pathTuples[i].getF(), pathTuples[i].getN());
-				b[i][0] ^= 1;
-				for (int j = 0; j < Ni.length; j++)
-					b[i][b[i].length - 1 - j] ^= Ni[Ni.length - 1 - j];
-			}
+		PrePIRCOT preksearch = new PrePIRCOT(con1, con2);
+		preksearch.runD(predata, pathTuples, timer);
 
-			PIRCOT sscot = new PIRCOT(con1, con2);
-			sscot.runD(predata, b, timer);
+		byte[][] u = new byte[pathTuples][];
+		for (int i = 0; i < pathTuples; i++) {
+			u[i] = ArrayUtils.addAll(pathTuples_DE[i].getF(), pathTuples_DE[i].getN());
+			Util.setXor(u[i], ArrayUtils.addAll(pathTuples_CD[i].getF(), pathTuples_CD[i].getN()));
 		}
+		byte[] v = ArrayUtils.addAll(new byte[] { 1 }, N.CE);
+		Util.setXor(v, ArrayUtils.addAll(new byte[] { 1 }, N.CD));
+
+		PIRCOT ksearch = new PIRCOT(con1, con2);
+		OutPIRCOT j = ksearch.runD(predata, u, v, timer);
 
-		// step 4
-		if (OTi.getTreeIndex() < OTi.getH() - 1) {
-			PIRIOT ssiot = new PIRIOT(con1, con2);
-			ssiot.runD(predata, Nip1_pr, timer);
+		byte[][] x_DE = new byte[pathTuples][];
+		byte[][] x_CD = new byte[pathTuples][];
+		for (int i = 0; i < pathTuples; i++) {
+			x_DE[i] = pathTuples_DE[i].getA();
+			x_CD[i] = pathTuples_CD[i].getA();
 		}
 
-		// PIR
-		int[] j = new int[] { 0, 0 };
-		timer.start(pid, M.online_write);
-		con1.write(pid, j);
-		con2.write(pid, j);
-		timer.stop(pid, M.online_write);
-		timer.start(pid, M.online_read);
-		con1.readIntArray(pid);
-		j = con2.readIntArray(pid);
-		timer.stop(pid, M.online_read);
+		ThreeShiftPIR threeshiftpir = new ThreeShiftPIR(con1, con2);
+		TwoThreeXorByte X = threeshiftpir.runD(predata, x_DE, x_CD, j, timer);
 
-		byte[] A = null;
-		if (OTi.getTreeIndex() == 0)
-			A = pathTuples[0].getA();
-		else
-			A = pathTuples[j[0]].getA();
-		timer.start(pid, M.online_write);
-		con1.write(pid, A);
-		con2.write(pid, A);
-		timer.stop(pid, M.online_write);
-		timer.start(pid, M.online_read);
-		con1.read(pid);
-		con2.read(pid);
-		timer.stop(pid, M.online_read);
+		TwoOneXor dN21 = new TwoOneXor();
+		dN21.t_D = dN.CD ^ dN.DE;
+		dN21.s_CD = dN.CD;
+		dN21.s_DE = dN.DE;
 
-		timer.stop(pid, M.online_comp);
+		ThreeShiftXorPIR threeshiftxorpir = new ThreeShiftXorPIR(con1, con2);
+		TwoThreeXorByte nextL = threeshiftxorpir.runD(predata, x_DE, x_CD, j, dN21, ttp, timer);
+		byte[] Lip1 = Util.xor(Util.xor(nextL.DE, nextL.CE), nextL.CD);
 
-		return Li;
+		OutPIRAccess out = new OutPIRAccess(pathTuples_CD, null, pathTuples_DE, j, X, nextL, Lip1);
+
+		timer.stop(pid, M.online_comp);
+		return out;
 	}
 
-	public OutAccess runC(Metadata md, Tree OTi, int treeIndex, byte[] Li, Timer timer) {
+	public OutPIRAccess runC(Metadata md, PreData predata, Tree tree_CD, Tree tree_CE, byte[] Li, TwoThreeXorByte L,
+			TwoThreeXorByte N, TwoThreeXorInt dN, Timer timer) {
 		timer.start(pid, M.online_comp);
 
-		// step 0: send Li to E and D
-		timer.start(pid, M.online_write);
-		if (treeIndex > 0) {
-			con1.write(Li);
-			con2.write(Li);
-		}
-		timer.stop(pid, M.online_write);
+		Bucket[] pathBuckets_CD = tree_CD.getBucketsOnPath(Li);
+		Tuple[] pathTuples_CD = Bucket.bucketsToTuples(pathBuckets_CD);
+		Bucket[] pathBuckets_CE = tree_CE.getBucketsOnPath(Li);
+		Tuple[] pathTuples_CE = Bucket.bucketsToTuples(pathBuckets_CE);
 
-		// step 1
-		Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
-		Tuple[] originalTuples = Bucket.bucketsToTuples(pathBuckets);
-		Tuple[] pathTuples = new Tuple[originalTuples.length];
-		for (int i = 0; i < pathTuples.length; i++)
-			pathTuples[i] = new Tuple(originalTuples[i]);
+		int pathTuples = pathTuples_CE.length;
+		int ttp = md.getTwoTauPow();
 
-		// step 2
-		timer.start(pid, M.online_read);
-		// Tuple[] pathTuples = con2.readTupleArray(pid);
-		byte[] Ni = con2.read(pid);
-		timer.stop(pid, M.online_read);
+		PrePIRCOT preksearch = new PrePIRCOT(con1, con2);
+		preksearch.runC(predata, timer);
 
-		// for (int i = 0; i < pathTuples.length; i++)
-		// System.out.println("AAAAA " + pathTuples[i]);
+		PIRCOT ksearch = new PIRCOT(con1, con2);
+		OutPIRCOT j = ksearch.runC(predata, timer);
 
-		// step 3
-		int j1 = 0;
-		byte[] z = null;
-		if (treeIndex == 0) {
-			z = pathTuples[0].getA().clone();
-		} else {
-			PIRCOT sscot = new PIRCOT(con1, con2);
-			OutSSCOT je = sscot.runC(timer);
-			j1 = je.t;
-			byte[] d = pathTuples[j1].getA().clone();
-			// z = Util.xor(je.m_t, d);
-			z = d;
-
-			// con1.write(j1);
+		byte[][] x_CE = new byte[pathTuples][];
+		byte[][] x_CD = new byte[pathTuples][];
+		for (int i = 0; i < pathTuples; i++) {
+			x_CE[i] = pathTuples_CE[i].getA();
+			x_CD[i] = pathTuples_CD[i].getA();
 		}
 
-		// byte[] y = con1.read();
-
-		// step 4
-		int j2 = 0;
-		byte[] Lip1 = null;
-		if (treeIndex < md.getNumTrees() - 1) {
-			PIRIOT ssiot = new PIRIOT(con1, con2);
-			OutSSIOT jy = ssiot.runC(timer);
-
-			// step 5
-			j2 = jy.t;
-			// int lSegBytes = md.getABytesOfTree(treeIndex) /
-			// md.getTwoTauPow();
-			// byte[] z_j2 = Arrays.copyOfRange(z, j2 * lSegBytes, (j2 + 1) *
-			// lSegBytes);
-			// Lip1 = Util.xor(jy.m_t, z_j2);
-
-			// Lip1 = Arrays.copyOfRange(Util.xor(z, y), j2 * lSegBytes, (j2 +
-			// 1) * lSegBytes);
-		}
+		ThreeShiftPIR threeshiftpir = new ThreeShiftPIR(con1, con2);
+		TwoThreeXorByte X = threeshiftpir.runC(predata, x_CD, x_CE, j, timer);
 
-		// PIR
-		int[] j = new int[] { j1, j2 };
-		timer.start(pid, M.online_write);
-		con1.write(pid, j);
-		con2.write(pid, j);
-		timer.stop(pid, M.online_write);
-		timer.start(pid, M.online_read);
-		con1.readIntArray(pid);
-		con2.readIntArray(pid);
-		timer.stop(pid, M.online_read);
+		TwoOneXor dN21 = new TwoOneXor();
+		dN21.t_C = dN.CD ^ dN.CE;
+		dN21.s_CD = dN.CD;
+		dN21.s_CE = dN.CE;
 
-		int ABytes = md.getABytesOfTree(treeIndex);
-		byte[] A = new byte[ABytes];
-		timer.start(pid, M.online_write);
-		con1.write(pid, A);
-		con2.write(pid, A);
-		timer.stop(pid, M.online_write);
-		timer.start(pid, M.online_read);
-		byte[] y = con1.read(pid);
-		con2.read(pid);
-		timer.stop(pid, M.online_read);
+		ThreeShiftXorPIR threeshiftxorpir = new ThreeShiftXorPIR(con1, con2);
+		TwoThreeXorByte nextL = threeshiftxorpir.runC(predata, x_CD, x_CE, j, dN21, ttp, timer);
+		byte[] Lip1 = Util.xor(Util.xor(nextL.DE, nextL.CE), nextL.CD);
 
-		if (treeIndex < md.getNumTrees() - 1) {
-			int lSegBytes = ABytes / md.getTwoTauPow();
-			Lip1 = Arrays.copyOfRange(Util.xor(z, y), j2 * lSegBytes, (j2 + 1) * lSegBytes);
-		}
+		OutPIRAccess out = new OutPIRAccess(pathTuples_CD, pathTuples_CE, null, j, X, nextL, Lip1);
 
-		Tuple Ti = null;
-		if (treeIndex == 0) {
-			Ti = pathTuples[0];
-		} else {
-			Ti = new Tuple(new byte[] { 1 }, Ni, new byte[md.getLBytesOfTree(treeIndex)], z);
+		timer.stop(pid, M.online_comp);
+		return out;
+	}
 
-			pathTuples[j1].getF()[0] = (byte) (1 - pathTuples[j1].getF()[0]);
-			Crypto.sr.nextBytes(pathTuples[j1].getN());
-			Crypto.sr.nextBytes(pathTuples[j1].getL());
-			Crypto.sr.nextBytes(pathTuples[j1].getA());
-		}
+	@Override
+	public void run(Party party, Metadata md, Forest[] forest) {
 
-		OutAccess outaccess = new OutAccess(Li, Lip1, Ti, pathTuples, j2, null, null);
+		Timer timer = new Timer();
+		PreData predata = new PreData();
+
+		Tree tree_CD = null;
+		Tree tree_DE = null;
+		Tree tree_CE = null;
+
+		int treeIndex = 2;
+
+		for (int j = 0; j < 100; j++) {
+			if (party == Party.Eddie) {
+				tree_DE = forest[0].getTree(treeIndex);
+				tree_CE = forest[1].getTree(treeIndex);
+			} else if (party == Party.Debbie) {
+				tree_DE = forest[0].getTree(treeIndex);
+				tree_CD = forest[1].getTree(treeIndex);
+			} else if (party == Party.Charlie) {
+				tree_CE = forest[0].getTree(treeIndex);
+				tree_CD = forest[1].getTree(treeIndex);
+			} else {
+				assert false;
+			}
 
-		timer.stop(pid, M.online_comp);
-		return outaccess;
+			int Llen = md.getLBytesOfTree(treeIndex);
+			int Nlen = md.getNBytesOfTree(treeIndex);
+
+			TwoThreeXorInt dN = new TwoThreeXorInt();
+
+			TwoThreeXorByte N = new TwoThreeXorByte();
+			N.CD = new byte[Nlen];
+			N.DE = new byte[Nlen];
+			N.CE = new byte[Nlen];
+			TwoThreeXorByte L = new TwoThreeXorByte();
+			L.CD = new byte[Llen];
+			L.DE = new byte[Llen];
+			L.CE = new byte[Llen];
+			byte[] Li = new byte[Llen];
+
+			if (party == Party.Eddie) {
+				OutPIRAccess out = this.runE(md, predata, tree_DE, tree_CE, Li, L, N, dN, timer);
+				out.j.t_D = con1.readInt();
+				out.j.t_C = con2.readInt();
+				out.X.CD = con1.read();
+				int pathTuples = out.pathTuples_CE.length;
+				int index = (out.j.t_D + out.j.s_CE) % pathTuples;
+				byte[] X = Util.xor(Util.xor(out.X.DE, out.X.CE), out.X.CD);
+
+				boolean fail = false;
+				if (index != 0) {
+					System.err.println(j + ": PIRAcc test failed on KSearch index");
+					fail = true;
+				}
+				if (new BigInteger(1, X).intValue() != 0) {
+					System.err.println(j + ": PIRAcc test failed on 3ShiftPIR X");
+					fail = true;
+				}
+				if (new BigInteger(1, out.Lip1).intValue() != 0) {
+					System.err.println(j + ": PIRAcc test failed on 3ShiftXorPIR Lip1");
+					fail = true;
+				}
+				if (!fail)
+					System.out.println(j + ": PIRAcc test passed");
+
+			} else if (party == Party.Debbie) {
+				OutPIRAccess out = this.runD(md, predata, tree_DE, tree_CD, Li, L, N, dN, timer);
+				con1.write(out.j.t_D);
+				con1.write(out.X.CD);
+
+			} else if (party == Party.Charlie) {
+				OutPIRAccess out = this.runC(md, predata, tree_CD, tree_CE, Li, L, N, dN, timer);
+				con1.write(out.j.t_C);
+
+			} else {
+				throw new NoSuchPartyException(party + "");
+			}
+		}
 	}
 
+	// on second path
 	public OutAccess runE2(Tree OTi, Timer timer) {
 		timer.start(pid, M.online_comp);
 
@@ -404,106 +351,103 @@ public class PIRAccess extends Protocol {
 	// for testing correctness
 	@Override
 	public void run(Party party, Metadata md, Forest forest) {
-		int records = 5;
-		int repeat = 5;
-
-		int tau = md.getTau();
-		int numTrees = md.getNumTrees();
-		long numInsert = md.getNumInsertRecords();
-		int addrBits = md.getAddrBits();
-
-		Timer timer = new Timer();
-
-		sanityCheck();
-
-		System.out.println();
-
-		for (int i = 0; i < records; i++) {
-			long N = Global.cheat ? 0 : Util.nextLong(numInsert, Crypto.sr);
-
-			for (int j = 0; j < repeat; j++) {
-				System.out.println("Test: " + i + " " + j);
-				System.out.println("N=" + BigInteger.valueOf(N).toString(2));
-
-				byte[] Li = new byte[0];
-
-				for (int ti = 0; ti < numTrees; ti++) {
-					long Ni_value = Util.getSubBits(N, addrBits, addrBits - md.getNBitsOfTree(ti));
-					long Nip1_pr_value = Util.getSubBits(N, addrBits - md.getNBitsOfTree(ti),
-							Math.max(addrBits - md.getNBitsOfTree(ti) - tau, 0));
-					byte[] Ni = Util.longToBytes(Ni_value, md.getNBytesOfTree(ti));
-					byte[] Nip1_pr = Util.longToBytes(Nip1_pr_value, (tau + 7) / 8);
-
-					PreData predata = new PreData();
-					PreAccess preaccess = new PreAccess(con1, con2);
-
-					if (party == Party.Eddie) {
-						Tree OTi = forest.getTree(ti);
-						int numTuples = (OTi.getD() - 1) * OTi.getW() + OTi.getStashSize();
-						int[] tupleParam = new int[] { ti == 0 ? 0 : 1, md.getNBytesOfTree(ti), md.getLBytesOfTree(ti),
-								md.getABytesOfTree(ti) };
-						preaccess.runE(predata, md.getTwoTauPow(), numTuples, tupleParam, timer);
-
-						byte[] sE_Ni = Util.nextBytes(Ni.length, Crypto.sr);
-						byte[] sD_Ni = Util.xor(Ni, sE_Ni);
-						con1.write(sD_Ni);
-
-						byte[] sE_Nip1_pr = Util.nextBytes(Nip1_pr.length, Crypto.sr);
-						byte[] sD_Nip1_pr = Util.xor(Nip1_pr, sE_Nip1_pr);
-						con1.write(sD_Nip1_pr);
-
-						OutAccess outaccess = runE(predata, OTi, sE_Ni, sE_Nip1_pr, timer);
-
-						if (ti == numTrees - 1) {
-							con2.write(N);
-							con2.write(outaccess.E_Ti);
-						}
-
-					} else if (party == Party.Debbie) {
-						Tree OTi = forest.getTree(ti);
-						preaccess.runD(predata, timer);
-
-						byte[] sD_Ni = con1.read();
-
-						byte[] sD_Nip1_pr = con1.read();
-
-						runD(predata, OTi, sD_Ni, sD_Nip1_pr, timer);
-
-					} else if (party == Party.Charlie) {
-						Tree OTi = forest.getTree(ti);
-						preaccess.runC(timer);
-
-						System.out.println("L" + ti + "=" + new BigInteger(1, Li).toString(2));
-
-						OutAccess outaccess = runC(md, OTi, ti, Li, timer);
-
-						Li = outaccess.C_Lip1;
-
-						if (ti == numTrees - 1) {
-							N = con1.readLong();
-							Tuple E_Ti = con1.readTuple();
-							long data = new BigInteger(1, Util.xor(outaccess.C_Ti.getA(), E_Ti.getA())).longValue();
-							if (N == data) {
-								System.out.println("PIR Access passed");
-								System.out.println();
-							} else {
-								throw new AccessException("PIR Access failed: " + N + " != " + data);
-							}
-						}
-
-					} else {
-						throw new NoSuchPartyException(party + "");
-					}
-				}
-			}
-		}
-
-		// timer.print();
-	}
-
-	@Override
-	public void run(Party party, Metadata md, Forest[] forest) {
-		// TODO Auto-generated method stub
-
+		// int records = 5;
+		// int repeat = 5;
+		//
+		// int tau = md.getTau();
+		// int numTrees = md.getNumTrees();
+		// long numInsert = md.getNumInsertRecords();
+		// int addrBits = md.getAddrBits();
+		//
+		// Timer timer = new Timer();
+		//
+		// sanityCheck();
+		//
+		// System.out.println();
+		//
+		// for (int i = 0; i < records; i++) {
+		// long N = Global.cheat ? 0 : Util.nextLong(numInsert, Crypto.sr);
+		//
+		// for (int j = 0; j < repeat; j++) {
+		// System.out.println("Test: " + i + " " + j);
+		// System.out.println("N=" + BigInteger.valueOf(N).toString(2));
+		//
+		// byte[] Li = new byte[0];
+		//
+		// for (int ti = 0; ti < numTrees; ti++) {
+		// long Ni_value = Util.getSubBits(N, addrBits, addrBits -
+		// md.getNBitsOfTree(ti));
+		// long Nip1_pr_value = Util.getSubBits(N, addrBits - md.getNBitsOfTree(ti),
+		// Math.max(addrBits - md.getNBitsOfTree(ti) - tau, 0));
+		// byte[] Ni = Util.longToBytes(Ni_value, md.getNBytesOfTree(ti));
+		// byte[] Nip1_pr = Util.longToBytes(Nip1_pr_value, (tau + 7) / 8);
+		//
+		// PreData predata = new PreData();
+		// PreAccess preaccess = new PreAccess(con1, con2);
+		//
+		// if (party == Party.Eddie) {
+		// Tree OTi = forest.getTree(ti);
+		// int numTuples = (OTi.getD() - 1) * OTi.getW() + OTi.getStashSize();
+		// int[] tupleParam = new int[] { ti == 0 ? 0 : 1, md.getNBytesOfTree(ti),
+		// md.getLBytesOfTree(ti),
+		// md.getABytesOfTree(ti) };
+		// preaccess.runE(predata, md.getTwoTauPow(), numTuples, tupleParam, timer);
+		//
+		// byte[] sE_Ni = Util.nextBytes(Ni.length, Crypto.sr);
+		// byte[] sD_Ni = Util.xor(Ni, sE_Ni);
+		// con1.write(sD_Ni);
+		//
+		// byte[] sE_Nip1_pr = Util.nextBytes(Nip1_pr.length, Crypto.sr);
+		// byte[] sD_Nip1_pr = Util.xor(Nip1_pr, sE_Nip1_pr);
+		// con1.write(sD_Nip1_pr);
+		//
+		// OutAccess outaccess = runE(predata, OTi, sE_Ni, sE_Nip1_pr, timer);
+		//
+		// if (ti == numTrees - 1) {
+		// con2.write(N);
+		// con2.write(outaccess.E_Ti);
+		// }
+		//
+		// } else if (party == Party.Debbie) {
+		// Tree OTi = forest.getTree(ti);
+		// preaccess.runD(predata, timer);
+		//
+		// byte[] sD_Ni = con1.read();
+		//
+		// byte[] sD_Nip1_pr = con1.read();
+		//
+		// runD(predata, OTi, sD_Ni, sD_Nip1_pr, timer);
+		//
+		// } else if (party == Party.Charlie) {
+		// Tree OTi = forest.getTree(ti);
+		// preaccess.runC(timer);
+		//
+		// System.out.println("L" + ti + "=" + new BigInteger(1, Li).toString(2));
+		//
+		// OutAccess outaccess = runC(md, OTi, ti, Li, timer);
+		//
+		// Li = outaccess.C_Lip1;
+		//
+		// if (ti == numTrees - 1) {
+		// N = con1.readLong();
+		// Tuple E_Ti = con1.readTuple();
+		// long data = new BigInteger(1, Util.xor(outaccess.C_Ti.getA(),
+		// E_Ti.getA())).longValue();
+		// if (N == data) {
+		// System.out.println("PIR Access passed");
+		// System.out.println();
+		// } else {
+		// throw new AccessException("PIR Access failed: " + N + " != " + data);
+		// }
+		// }
+		//
+		// } else {
+		// throw new NoSuchPartyException(party + "");
+		// }
+		// }
+		// }
+		// }
+		//
+		// // timer.print();
 	}
 }

+ 1 - 1
src/pir/PIRCOT.java

@@ -62,7 +62,7 @@ public class PIRCOT extends Protocol {
 		int l = u.length;
 		byte[][] a = new byte[l][];
 		for (int j = 0; j < l; j++) {
-			a[j] = Util.xor(u[(j + l - predata.sscot_s_DE) % l], v);
+			a[j] = Util.xor(u[(j + predata.sscot_s_DE) % l], v);
 			a[j] = Util.padArray(a[j], predata.sscot_r[j].length);
 			Util.setXor(a[j], predata.sscot_r[j]);
 			a[j] = predata.sscot_F_k.compute(a[j]);

+ 50 - 15
src/pir/ThreeShiftPIR.java

@@ -9,6 +9,7 @@ import protocols.Protocol;
 import protocols.struct.OutPIRCOT;
 import protocols.struct.Party;
 import protocols.struct.PreData;
+import protocols.struct.TwoThreeXorByte;
 import util.M;
 import util.P;
 import util.Timer;
@@ -22,7 +23,7 @@ public class ThreeShiftPIR extends Protocol {
 		super(con1, con2);
 	}
 
-	public byte[] runE(PreData predata, byte[][] x_DE, byte[][] x_CE, OutPIRCOT i, Timer timer) {
+	public TwoThreeXorByte runE(PreData predata, byte[][] x_DE, byte[][] x_CE, OutPIRCOT i, Timer timer) {
 		timer.start(pid, M.online_comp);
 
 		ShiftPIR sftpir = new ShiftPIR(con1, con2);
@@ -33,11 +34,22 @@ public class ThreeShiftPIR extends Protocol {
 		sftpir.runP3(predata, i.t_E, timer);
 		Util.setXor(e1, e2);
 
+		TwoThreeXorByte X = new TwoThreeXorByte();
+		X.DE = e1;
+
+		timer.start(pid, M.online_write);
+		con1.write(pid, X.DE);
+		timer.stop(pid, M.online_write);
+
+		timer.start(pid, M.online_read);
+		X.CE = con2.read(pid);
+		timer.stop(pid, M.online_read);
+
 		timer.stop(pid, M.online_comp);
-		return e1;
+		return X;
 	}
 
-	public byte[] runD(PreData predata, byte[][] x_DE, byte[][] x_CD, OutPIRCOT i, Timer timer) {
+	public TwoThreeXorByte runD(PreData predata, byte[][] x_DE, byte[][] x_CD, OutPIRCOT i, Timer timer) {
 		timer.start(pid, M.online_comp);
 
 		ShiftPIR sftpir = new ShiftPIR(con1, con2);
@@ -48,11 +60,22 @@ public class ThreeShiftPIR extends Protocol {
 		byte[] d2 = sftpir.runP1(predata, x_CD, i.s_CD, timer);
 		Util.setXor(d1, d2);
 
+		TwoThreeXorByte X = new TwoThreeXorByte();
+		X.CD = d1;
+
+		timer.start(pid, M.online_write);
+		con2.write(pid, X.CD);
+		timer.stop(pid, M.online_write);
+
+		timer.start(pid, M.online_read);
+		X.DE = con1.read(pid);
+		timer.stop(pid, M.online_read);
+
 		timer.stop(pid, M.online_comp);
-		return d1;
+		return X;
 	}
 
-	public byte[] runC(PreData predata, byte[][] x_CD, byte[][] x_CE, OutPIRCOT i, Timer timer) {
+	public TwoThreeXorByte runC(PreData predata, byte[][] x_CD, byte[][] x_CE, OutPIRCOT i, Timer timer) {
 		timer.start(pid, M.online_comp);
 
 		ShiftPIR sftpir = new ShiftPIR(con1, con2);
@@ -63,8 +86,19 @@ public class ThreeShiftPIR extends Protocol {
 		byte[] c2 = sftpir.runP2(predata, x_CD, i.s_CD, timer);
 		Util.setXor(c1, c2);
 
+		TwoThreeXorByte X = new TwoThreeXorByte();
+		X.CE = c1;
+
+		timer.start(pid, M.online_write);
+		con1.write(pid, X.CE);
+		timer.stop(pid, M.online_write);
+
+		timer.start(pid, M.online_read);
+		X.CD = con2.read(pid);
+		timer.stop(pid, M.online_read);
+
 		timer.stop(pid, M.online_comp);
-		return c1;
+		return X;
 	}
 
 	@Override
@@ -93,6 +127,8 @@ public class ThreeShiftPIR extends Protocol {
 			ks.s_CE = (index - ks.t_D + l) % l;
 			ks.s_CD = (index - ks.t_E + l) % l;
 
+			TwoThreeXorByte X = new TwoThreeXorByte();
+
 			if (party == Party.Eddie) {
 				con1.write(x_CD);
 				con1.write(x_DE);
@@ -105,11 +141,11 @@ public class ThreeShiftPIR extends Protocol {
 				con2.write(ks.s_CE);
 				con2.write(ks.s_CD);
 
-				byte[] e = this.runE(predata, x_DE, x_CE, ks, timer);
-				byte[] d = con1.read();
-				byte[] c = con2.read();
-				Util.setXor(e, d);
-				Util.setXor(e, c);
+				X = this.runE(predata, x_DE, x_CE, ks, timer);
+				X.CD = con1.read();
+				byte[] e = X.CE;
+				Util.setXor(e, X.CD);
+				Util.setXor(e, X.DE);
 				byte[] x = x_DE[index];
 				Util.setXor(x, x_CE[index]);
 				Util.setXor(x, x_CD[index]);
@@ -126,8 +162,8 @@ public class ThreeShiftPIR extends Protocol {
 				ks.s_DE = con1.readInt();
 				ks.s_CD = con1.readInt();
 
-				byte[] d = this.runD(predata, x_DE, x_CD, ks, timer);
-				con1.write(d);
+				X = this.runD(predata, x_DE, x_CD, ks, timer);
+				con1.write(X.CD);
 
 			} else if (party == Party.Charlie) {
 				x_CD = con1.readDoubleByteArray();
@@ -136,8 +172,7 @@ public class ThreeShiftPIR extends Protocol {
 				ks.s_CE = con1.readInt();
 				ks.s_CD = con1.readInt();
 
-				byte[] c = this.runC(predata, x_CD, x_CE, ks, timer);
-				con1.write(c);
+				this.runC(predata, x_CD, x_CE, ks, timer);
 
 			} else {
 				throw new NoSuchPartyException(party + "");

+ 37 - 24
src/pir/ThreeShiftXorPIR.java

@@ -12,6 +12,7 @@ import protocols.struct.OutPIRCOT;
 import protocols.struct.Party;
 import protocols.struct.PreData;
 import protocols.struct.TwoOneXor;
+import protocols.struct.TwoThreeXorByte;
 import util.M;
 import util.P;
 import util.Timer;
@@ -25,15 +26,16 @@ public class ThreeShiftXorPIR extends Protocol {
 		super(con1, con2);
 	}
 
-	public byte[] runE(PreData predata, byte[][] x_DE, byte[][] x_CE, OutPIRCOT i, TwoOneXor j, int m, Timer timer) {
+	public TwoThreeXorByte runE(PreData predata, byte[][] x_DE, byte[][] x_CE, OutPIRCOT i, TwoOneXor dN, int ttp,
+			Timer timer) {
 		timer.start(pid, M.online_comp);
 
 		ShiftXorPIR sftpir = new ShiftXorPIR(con1, con2);
-		byte[] e1 = sftpir.runP1(predata, x_DE, i.s_DE, j.s_DE, m, timer);
+		byte[] e1 = sftpir.runP1(predata, x_DE, i.s_DE, dN.s_DE, ttp, timer);
 		sftpir = new ShiftXorPIR(con2, con1);
-		byte[] e2 = sftpir.runP2(predata, x_CE, i.s_CE, j.s_CE, m, timer);
+		byte[] e2 = sftpir.runP2(predata, x_CE, i.s_CE, dN.s_CE, ttp, timer);
 		sftpir = new ShiftXorPIR(con1, con2);
-		sftpir.runP3(predata, i.t_E, j.t_E, m, timer);
+		sftpir.runP3(predata, i.t_E, dN.t_E, ttp, timer);
 		Util.setXor(e1, e2);
 
 		timer.start(pid, M.online_write);
@@ -46,22 +48,25 @@ public class ThreeShiftXorPIR extends Protocol {
 		byte[] c = con2.read(pid);
 		timer.stop(pid, M.online_read);
 
-		Util.setXor(e1, d);
-		Util.setXor(e1, c);
+		TwoThreeXorByte nextL = new TwoThreeXorByte();
+		nextL.DE = e1;
+		nextL.CD = d;
+		nextL.CE = c;
 
 		timer.stop(pid, M.online_comp);
-		return e1;
+		return nextL;
 	}
 
-	public byte[] runD(PreData predata, byte[][] x_DE, byte[][] x_CD, OutPIRCOT i, TwoOneXor j, int m, Timer timer) {
+	public TwoThreeXorByte runD(PreData predata, byte[][] x_DE, byte[][] x_CD, OutPIRCOT i, TwoOneXor dN, int ttp,
+			Timer timer) {
 		timer.start(pid, M.online_comp);
 
 		ShiftXorPIR sftpir = new ShiftXorPIR(con1, con2);
-		byte[] d1 = sftpir.runP2(predata, x_DE, i.s_DE, j.s_DE, m, timer);
+		byte[] d1 = sftpir.runP2(predata, x_DE, i.s_DE, dN.s_DE, ttp, timer);
 		sftpir = new ShiftXorPIR(con2, con1);
-		sftpir.runP3(predata, i.t_D, j.t_D, m, timer);
+		sftpir.runP3(predata, i.t_D, dN.t_D, ttp, timer);
 		sftpir = new ShiftXorPIR(con2, con1);
-		byte[] d2 = sftpir.runP1(predata, x_CD, i.s_CD, j.s_CD, m, timer);
+		byte[] d2 = sftpir.runP1(predata, x_CD, i.s_CD, dN.s_CD, ttp, timer);
 		Util.setXor(d1, d2);
 
 		timer.start(pid, M.online_write);
@@ -74,22 +79,25 @@ public class ThreeShiftXorPIR extends Protocol {
 		byte[] c = con2.read(pid);
 		timer.stop(pid, M.online_read);
 
-		Util.setXor(d1, e);
-		Util.setXor(d1, c);
+		TwoThreeXorByte nextL = new TwoThreeXorByte();
+		nextL.DE = e;
+		nextL.CD = d1;
+		nextL.CE = c;
 
 		timer.stop(pid, M.online_comp);
-		return d1;
+		return nextL;
 	}
 
-	public byte[] runC(PreData predata, byte[][] x_CD, byte[][] x_CE, OutPIRCOT i, TwoOneXor j, int m, Timer timer) {
+	public TwoThreeXorByte runC(PreData predata, byte[][] x_CD, byte[][] x_CE, OutPIRCOT i, TwoOneXor dN, int ttp,
+			Timer timer) {
 		timer.start(pid, M.online_comp);
 
 		ShiftXorPIR sftpir = new ShiftXorPIR(con1, con2);
-		sftpir.runP3(predata, i.t_C, j.t_C, m, timer);
+		sftpir.runP3(predata, i.t_C, dN.t_C, ttp, timer);
 		sftpir = new ShiftXorPIR(con1, con2);
-		byte[] c1 = sftpir.runP1(predata, x_CE, i.s_CE, j.s_CE, m, timer);
+		byte[] c1 = sftpir.runP1(predata, x_CE, i.s_CE, dN.s_CE, ttp, timer);
 		sftpir = new ShiftXorPIR(con2, con1);
-		byte[] c2 = sftpir.runP2(predata, x_CD, i.s_CD, j.s_CD, m, timer);
+		byte[] c2 = sftpir.runP2(predata, x_CD, i.s_CD, dN.s_CD, ttp, timer);
 		Util.setXor(c1, c2);
 
 		timer.start(pid, M.online_write);
@@ -102,11 +110,13 @@ public class ThreeShiftXorPIR extends Protocol {
 		byte[] d = con2.read(pid);
 		timer.stop(pid, M.online_read);
 
-		Util.setXor(c1, e);
-		Util.setXor(c1, d);
+		TwoThreeXorByte nextL = new TwoThreeXorByte();
+		nextL.DE = e;
+		nextL.CD = d;
+		nextL.CE = c1;
 
 		timer.stop(pid, M.online_comp);
-		return c1;
+		return nextL;
 	}
 
 	@Override
@@ -163,7 +173,8 @@ public class ThreeShiftXorPIR extends Protocol {
 				con2.write(tox.s_CE);
 				con2.write(tox.s_CD);
 
-				byte[] e = this.runE(predata, x_DE, x_CE, ks, tox, m, timer);
+				TwoThreeXorByte nextL = this.runE(predata, x_DE, x_CE, ks, tox, m, timer);
+				byte[] e = Util.xor(Util.xor(nextL.DE, nextL.CE), nextL.CD);
 				byte[] d = con1.read();
 				byte[] c = con2.read();
 
@@ -187,7 +198,8 @@ public class ThreeShiftXorPIR extends Protocol {
 				tox.s_DE = con1.readInt();
 				tox.s_CD = con1.readInt();
 
-				byte[] d = this.runD(predata, x_DE, x_CD, ks, tox, m, timer);
+				TwoThreeXorByte nextL = this.runD(predata, x_DE, x_CD, ks, tox, m, timer);
+				byte[] d = Util.xor(Util.xor(nextL.DE, nextL.CE), nextL.CD);
 				con1.write(d);
 
 			} else if (party == Party.Charlie) {
@@ -200,7 +212,8 @@ public class ThreeShiftXorPIR extends Protocol {
 				tox.s_CE = con1.readInt();
 				tox.s_CD = con1.readInt();
 
-				byte[] c = this.runC(predata, x_CD, x_CE, ks, tox, m, timer);
+				TwoThreeXorByte nextL = this.runC(predata, x_CD, x_CE, ks, tox, m, timer);
+				byte[] c = Util.xor(Util.xor(nextL.DE, nextL.CE), nextL.CD);
 				con1.write(c);
 
 			} else {

+ 24 - 0
src/protocols/struct/OutPIRAccess.java

@@ -0,0 +1,24 @@
+package protocols.struct;
+
+import oram.Tuple;
+
+public class OutPIRAccess {
+	public Tuple[] pathTuples_CD;
+	public Tuple[] pathTuples_CE;
+	public Tuple[] pathTuples_DE;
+	public OutPIRCOT j;
+	public TwoThreeXorByte X;
+	public TwoThreeXorByte nextL;
+	public byte[] Lip1;
+
+	public OutPIRAccess(Tuple[] pathTuples_CD, Tuple[] pathTuples_CE, Tuple[] pathTuples_DE, OutPIRCOT j,
+			TwoThreeXorByte X, TwoThreeXorByte nextL, byte[] Lip1) {
+		this.pathTuples_CD = pathTuples_CD;
+		this.pathTuples_CE = pathTuples_CE;
+		this.pathTuples_DE = pathTuples_DE;
+		this.j = j;
+		this.X = X;
+		this.nextL = nextL;
+		this.Lip1 = Lip1;
+	}
+}