ソースを参照

PIRAccess revised to work for all levels

Boyang Wei 6 年 前
コミット
8103b9dfe0
1 ファイル変更168 行追加124 行削除
  1. 168 124
      src/pir/PIRAccess.java

+ 168 - 124
src/pir/PIRAccess.java

@@ -46,18 +46,9 @@ public class PIRAccess extends Protocol {
 
 		int pathTuples = pathTuples_CE.length;
 		int ttp = md.getTwoTauPow();
-
-		PrePIRCOT preksearch = new PrePIRCOT(con1, con2);
-		preksearch.runE(predata, pathTuples, timer);
-
-		byte[][] u = new byte[pathTuples][];
-		for (int i = 0; i < pathTuples; i++) {
-			u[i] = ArrayUtils.addAll(pathTuples_CE[i].getF(), pathTuples_CE[i].getN());
-		}
-		byte[] v = ArrayUtils.addAll(new byte[] { 1 }, N.CE);
-
-		PIRCOT ksearch = new PIRCOT(con1, con2);
-		OutPIRCOT j = ksearch.runE(predata, u, v, timer);
+		int treeIndex = tree_DE.getTreeIndex();
+		boolean isLastTree = treeIndex == md.getNumTrees() - 1;
+		boolean isFirstTree = treeIndex == 0;
 
 		byte[][] x_DE = new byte[pathTuples][];
 		byte[][] x_CE = new byte[pathTuples][];
@@ -65,18 +56,41 @@ public class PIRAccess extends Protocol {
 			x_DE[i] = pathTuples_DE[i].getA();
 			x_CE[i] = pathTuples_CE[i].getA();
 		}
+		OutPIRCOT j = new OutPIRCOT();
+		TwoOneXor dN21 = new TwoOneXor();
+		TwoThreeXorByte X = new TwoThreeXorByte();
 
-		ThreeShiftPIR threeshiftpir = new ThreeShiftPIR(con1, con2);
-		TwoThreeXorByte X = threeshiftpir.runE(predata, x_DE, x_CE, j, timer);
+		if (isFirstTree) {
+			X.DE = x_DE[0];
+			X.CE = x_CE[0];
+		} else {
+			PrePIRCOT preksearch = new PrePIRCOT(con1, con2);
+			preksearch.runE(predata, pathTuples, timer);
 
-		TwoOneXor dN21 = new TwoOneXor();
-		dN21.t_E = dN.CE ^ dN.DE;
-		dN21.s_CE = dN.CE;
-		dN21.s_DE = dN.DE;
+			byte[][] u = new byte[pathTuples][];
+			for (int i = 0; i < pathTuples; i++) {
+				u[i] = ArrayUtils.addAll(pathTuples_CE[i].getF(), pathTuples_CE[i].getN());
+			}
+			byte[] v = ArrayUtils.addAll(new byte[] { 1 }, N.CE);
 
-		ThreeShiftXorPIR threeshiftxorpir = new ThreeShiftXorPIR(con1, con2);
-		TwoThreeXorByte nextL = threeshiftxorpir.runE(predata, x_DE, x_CE, j, dN21, ttp, timer);
-		byte[] Lip1 = Util.xor(Util.xor(nextL.DE, nextL.CE), nextL.CD);
+			PIRCOT ksearch = new PIRCOT(con1, con2);
+			j = ksearch.runE(predata, u, v, timer);
+
+			ThreeShiftPIR threeshiftpir = new ThreeShiftPIR(con1, con2);
+			X = threeshiftpir.runE(predata, x_DE, x_CE, j, timer);
+
+			dN21.t_E = dN.CE ^ dN.DE;
+			dN21.s_CE = dN.CE;
+			dN21.s_DE = dN.DE;
+		}
+
+		TwoThreeXorByte nextL = null;
+		byte[] Lip1 = null;
+		if (!isLastTree) {
+			ThreeShiftXorPIR threeshiftxorpir = new ThreeShiftXorPIR(con1, con2);
+			nextL = threeshiftxorpir.runE(predata, x_DE, x_CE, j, dN21, ttp, timer);
+			Lip1 = Util.xor(Util.xor(nextL.DE, nextL.CE), nextL.CD);
+		}
 
 		OutPIRAccess out = new OutPIRAccess(null, pathTuples_CE, pathTuples_DE, j, X, nextL, Lip1);
 
@@ -95,20 +109,9 @@ public class PIRAccess extends Protocol {
 
 		int pathTuples = pathTuples_CD.length;
 		int ttp = md.getTwoTauPow();
-
-		PrePIRCOT preksearch = new PrePIRCOT(con1, con2);
-		preksearch.runD(predata, pathTuples, timer);
-
-		byte[][] u = new byte[pathTuples][];
-		for (int i = 0; i < pathTuples; i++) {
-			u[i] = ArrayUtils.addAll(pathTuples_DE[i].getF(), pathTuples_DE[i].getN());
-			Util.setXor(u[i], ArrayUtils.addAll(pathTuples_CD[i].getF(), pathTuples_CD[i].getN()));
-		}
-		byte[] v = ArrayUtils.addAll(new byte[] { 1 }, N.CE);
-		Util.setXor(v, ArrayUtils.addAll(new byte[] { 1 }, N.CD));
-
-		PIRCOT ksearch = new PIRCOT(con1, con2);
-		OutPIRCOT j = ksearch.runD(predata, u, v, timer);
+		int treeIndex = tree_DE.getTreeIndex();
+		boolean isLastTree = treeIndex == md.getNumTrees() - 1;
+		boolean isFirstTree = treeIndex == 0;
 
 		byte[][] x_DE = new byte[pathTuples][];
 		byte[][] x_CD = new byte[pathTuples][];
@@ -116,18 +119,43 @@ public class PIRAccess extends Protocol {
 			x_DE[i] = pathTuples_DE[i].getA();
 			x_CD[i] = pathTuples_CD[i].getA();
 		}
+		OutPIRCOT j = new OutPIRCOT();
+		TwoOneXor dN21 = new TwoOneXor();
+		TwoThreeXorByte X = new TwoThreeXorByte();
 
-		ThreeShiftPIR threeshiftpir = new ThreeShiftPIR(con1, con2);
-		TwoThreeXorByte X = threeshiftpir.runD(predata, x_DE, x_CD, j, timer);
+		if (isFirstTree) {
+			X.DE = x_DE[0];
+			X.CD = x_CD[0];
+		} else {
+			PrePIRCOT preksearch = new PrePIRCOT(con1, con2);
+			preksearch.runD(predata, pathTuples, timer);
 
-		TwoOneXor dN21 = new TwoOneXor();
-		dN21.t_D = dN.CD ^ dN.DE;
-		dN21.s_CD = dN.CD;
-		dN21.s_DE = dN.DE;
+			byte[][] u = new byte[pathTuples][];
+			for (int i = 0; i < pathTuples; i++) {
+				u[i] = ArrayUtils.addAll(pathTuples_DE[i].getF(), pathTuples_DE[i].getN());
+				Util.setXor(u[i], ArrayUtils.addAll(pathTuples_CD[i].getF(), pathTuples_CD[i].getN()));
+			}
+			byte[] v = ArrayUtils.addAll(new byte[] { 1 }, N.CE);
+			Util.setXor(v, ArrayUtils.addAll(new byte[] { 1 }, N.CD));
 
-		ThreeShiftXorPIR threeshiftxorpir = new ThreeShiftXorPIR(con1, con2);
-		TwoThreeXorByte nextL = threeshiftxorpir.runD(predata, x_DE, x_CD, j, dN21, ttp, timer);
-		byte[] Lip1 = Util.xor(Util.xor(nextL.DE, nextL.CE), nextL.CD);
+			PIRCOT ksearch = new PIRCOT(con1, con2);
+			j = ksearch.runD(predata, u, v, timer);
+
+			ThreeShiftPIR threeshiftpir = new ThreeShiftPIR(con1, con2);
+			X = threeshiftpir.runD(predata, x_DE, x_CD, j, timer);
+
+			dN21.t_D = dN.CD ^ dN.DE;
+			dN21.s_CD = dN.CD;
+			dN21.s_DE = dN.DE;
+		}
+
+		TwoThreeXorByte nextL = null;
+		byte[] Lip1 = null;
+		if (!isLastTree) {
+			ThreeShiftXorPIR threeshiftxorpir = new ThreeShiftXorPIR(con1, con2);
+			nextL = threeshiftxorpir.runD(predata, x_DE, x_CD, j, dN21, ttp, timer);
+			Lip1 = Util.xor(Util.xor(nextL.DE, nextL.CE), nextL.CD);
+		}
 
 		OutPIRAccess out = new OutPIRAccess(pathTuples_CD, null, pathTuples_DE, j, X, nextL, Lip1);
 
@@ -146,12 +174,9 @@ public class PIRAccess extends Protocol {
 
 		int pathTuples = pathTuples_CE.length;
 		int ttp = md.getTwoTauPow();
-
-		PrePIRCOT preksearch = new PrePIRCOT(con1, con2);
-		preksearch.runC(predata, timer);
-
-		PIRCOT ksearch = new PIRCOT(con1, con2);
-		OutPIRCOT j = ksearch.runC(predata, timer);
+		int treeIndex = tree_CE.getTreeIndex();
+		boolean isLastTree = treeIndex == md.getNumTrees() - 1;
+		boolean isFirstTree = treeIndex == 0;
 
 		byte[][] x_CE = new byte[pathTuples][];
 		byte[][] x_CD = new byte[pathTuples][];
@@ -159,18 +184,35 @@ public class PIRAccess extends Protocol {
 			x_CE[i] = pathTuples_CE[i].getA();
 			x_CD[i] = pathTuples_CD[i].getA();
 		}
+		OutPIRCOT j = new OutPIRCOT();
+		TwoOneXor dN21 = new TwoOneXor();
+		TwoThreeXorByte X = new TwoThreeXorByte();
 
-		ThreeShiftPIR threeshiftpir = new ThreeShiftPIR(con1, con2);
-		TwoThreeXorByte X = threeshiftpir.runC(predata, x_CD, x_CE, j, timer);
+		if (isFirstTree) {
+			X.CE = x_CE[0];
+			X.CD = x_CD[0];
+		} else {
+			PrePIRCOT preksearch = new PrePIRCOT(con1, con2);
+			preksearch.runC(predata, timer);
 
-		TwoOneXor dN21 = new TwoOneXor();
-		dN21.t_C = dN.CD ^ dN.CE;
-		dN21.s_CD = dN.CD;
-		dN21.s_CE = dN.CE;
+			PIRCOT ksearch = new PIRCOT(con1, con2);
+			j = ksearch.runC(predata, timer);
+
+			ThreeShiftPIR threeshiftpir = new ThreeShiftPIR(con1, con2);
+			X = threeshiftpir.runC(predata, x_CD, x_CE, j, timer);
 
-		ThreeShiftXorPIR threeshiftxorpir = new ThreeShiftXorPIR(con1, con2);
-		TwoThreeXorByte nextL = threeshiftxorpir.runC(predata, x_CD, x_CE, j, dN21, ttp, timer);
-		byte[] Lip1 = Util.xor(Util.xor(nextL.DE, nextL.CE), nextL.CD);
+			dN21.t_C = dN.CD ^ dN.CE;
+			dN21.s_CD = dN.CD;
+			dN21.s_CE = dN.CE;
+		}
+
+		TwoThreeXorByte nextL = null;
+		byte[] Lip1 = null;
+		if (!isLastTree) {
+			ThreeShiftXorPIR threeshiftxorpir = new ThreeShiftXorPIR(con1, con2);
+			nextL = threeshiftxorpir.runC(predata, x_CD, x_CE, j, dN21, ttp, timer);
+			Lip1 = Util.xor(Util.xor(nextL.DE, nextL.CE), nextL.CD);
+		}
 
 		OutPIRAccess out = new OutPIRAccess(pathTuples_CD, pathTuples_CE, null, j, X, nextL, Lip1);
 
@@ -188,74 +230,76 @@ public class PIRAccess extends Protocol {
 		Tree tree_DE = null;
 		Tree tree_CE = null;
 
-		int treeIndex = 2;
-
-		for (int j = 0; j < 100; j++) {
-			if (party == Party.Eddie) {
-				tree_DE = forest[0].getTree(treeIndex);
-				tree_CE = forest[1].getTree(treeIndex);
-			} else if (party == Party.Debbie) {
-				tree_DE = forest[0].getTree(treeIndex);
-				tree_CD = forest[1].getTree(treeIndex);
-			} else if (party == Party.Charlie) {
-				tree_CE = forest[0].getTree(treeIndex);
-				tree_CD = forest[1].getTree(treeIndex);
-			} else {
-				assert false;
-			}
-
-			int Llen = md.getLBytesOfTree(treeIndex);
-			int Nlen = md.getNBytesOfTree(treeIndex);
-
-			TwoThreeXorInt dN = new TwoThreeXorInt();
-
-			TwoThreeXorByte N = new TwoThreeXorByte();
-			N.CD = new byte[Nlen];
-			N.DE = new byte[Nlen];
-			N.CE = new byte[Nlen];
-			TwoThreeXorByte L = new TwoThreeXorByte();
-			L.CD = new byte[Llen];
-			L.DE = new byte[Llen];
-			L.CE = new byte[Llen];
-			byte[] Li = new byte[Llen];
-
-			if (party == Party.Eddie) {
-				OutPIRAccess out = this.runE(md, predata, tree_DE, tree_CE, Li, L, N, dN, timer);
-				out.j.t_D = con1.readInt();
-				out.j.t_C = con2.readInt();
-				out.X.CD = con1.read();
-				int pathTuples = out.pathTuples_CE.length;
-				int index = (out.j.t_D + out.j.s_CE) % pathTuples;
-				byte[] X = Util.xor(Util.xor(out.X.DE, out.X.CE), out.X.CD);
-
-				boolean fail = false;
-				if (index != 0) {
-					System.err.println(j + ": PIRAcc test failed on KSearch index");
-					fail = true;
-				}
-				if (new BigInteger(1, X).intValue() != 0) {
-					System.err.println(j + ": PIRAcc test failed on 3ShiftPIR X");
-					fail = true;
+		for (int test = 0; test < 100; test++) {
+
+			for (int treeIndex = 0; treeIndex < md.getNumTrees(); treeIndex++) {
+				if (party == Party.Eddie) {
+					tree_DE = forest[0].getTree(treeIndex);
+					tree_CE = forest[1].getTree(treeIndex);
+				} else if (party == Party.Debbie) {
+					tree_DE = forest[0].getTree(treeIndex);
+					tree_CD = forest[1].getTree(treeIndex);
+				} else if (party == Party.Charlie) {
+					tree_CE = forest[0].getTree(treeIndex);
+					tree_CD = forest[1].getTree(treeIndex);
+				} else {
+					throw new NoSuchPartyException(party + "");
 				}
-				if (new BigInteger(1, out.Lip1).intValue() != 0) {
-					System.err.println(j + ": PIRAcc test failed on 3ShiftXorPIR Lip1");
-					fail = true;
-				}
-				if (!fail)
-					System.out.println(j + ": PIRAcc test passed");
-
-			} else if (party == Party.Debbie) {
-				OutPIRAccess out = this.runD(md, predata, tree_DE, tree_CD, Li, L, N, dN, timer);
-				con1.write(out.j.t_D);
-				con1.write(out.X.CD);
-
-			} else if (party == Party.Charlie) {
-				OutPIRAccess out = this.runC(md, predata, tree_CD, tree_CE, Li, L, N, dN, timer);
-				con1.write(out.j.t_C);
 
-			} else {
-				throw new NoSuchPartyException(party + "");
+				int Llen = md.getLBytesOfTree(treeIndex);
+				int Nlen = md.getNBytesOfTree(treeIndex);
+
+				TwoThreeXorInt dN = new TwoThreeXorInt();
+
+				TwoThreeXorByte N = new TwoThreeXorByte();
+				N.CD = new byte[Nlen];
+				N.DE = new byte[Nlen];
+				N.CE = new byte[Nlen];
+				TwoThreeXorByte L = new TwoThreeXorByte();
+				L.CD = new byte[Llen];
+				L.DE = new byte[Llen];
+				L.CE = new byte[Llen];
+				byte[] Li = new byte[Llen];
+
+				if (party == Party.Eddie) {
+					OutPIRAccess out = this.runE(md, predata, tree_DE, tree_CE, Li, L, N, dN, timer);
+					out.j.t_D = con1.readInt();
+					out.j.t_C = con2.readInt();
+					out.X.CD = con1.read();
+					int pathTuples = out.pathTuples_CE.length;
+					int index = (out.j.t_D + out.j.s_CE) % pathTuples;
+					byte[] X = Util.xor(Util.xor(out.X.DE, out.X.CE), out.X.CD);
+
+					boolean fail = false;
+					if (index != 0) {
+						System.err.println(test + " " + treeIndex + ": PIRAcc test failed on KSearch index");
+						fail = true;
+					}
+					if (new BigInteger(1, X).intValue() != 0) {
+						System.err.println(test + " " + treeIndex + ": PIRAcc test failed on 3ShiftPIR X");
+						fail = true;
+					}
+					if (treeIndex < md.getNumTrees() - 1 && new BigInteger(1, out.Lip1).intValue() != 0) {
+						System.err.println(test + " " + treeIndex + ": PIRAcc test failed on 3ShiftXorPIR Lip1");
+						fail = true;
+					}
+					if (!fail)
+						System.out.println(test + " " + treeIndex + ": PIRAcc test passed");
+
+				} else if (party == Party.Debbie) {
+					OutPIRAccess out = this.runD(md, predata, tree_DE, tree_CD, Li, L, N, dN, timer);
+					con1.write(out.j.t_D);
+					con1.write(out.X.CD);
+
+				} else if (party == Party.Charlie) {
+					OutPIRAccess out = this.runC(md, predata, tree_CD, tree_CE, Li, L, N, dN, timer);
+					con1.write(out.j.t_C);
+
+				} else {
+					throw new NoSuchPartyException(party + "");
+				}
 			}
+
 		}
 	}