Browse Source

fix missing timing in ppt. add reshuffle

Boyoung- 8 years ago
parent
commit
05c500c611

+ 4 - 3
src/measure/P.java

@@ -1,13 +1,14 @@
 package measure;
 
 public class P {
-	public static final int size = 5;
+	public static final int size = 6;
 
 	public static final int ACC = 0;
 	public static final int COT = 1;
 	public static final int IOT = 2;
 	public static final int PPT = 3;
-	public static final int XOT = 4;
+	public static final int RSF = 4;
+	public static final int XOT = 5;
 
-	public static final String[] names = { "ACC", "COT", "IOT", "PPT", "XOT" };
+	public static final String[] names = { "ACC", "COT", "IOT", "PPT", "RSF", "XOT" };
 }

+ 6 - 0
src/oram/Tuple.java

@@ -157,6 +157,12 @@ public class Tuple implements Serializable {
 		return numBytes == t.getNumBytes();
 	}
 
+	public boolean equals(Tuple t) {
+		if (!sameLength(t))
+			return false;
+		return Util.equal(toByteArray(), t.toByteArray());
+	}
+
 	public byte[] toByteArray() {
 		byte[] tuple = new byte[numBytes];
 		int offset = 0;

+ 2 - 4
src/protocols/Access.java

@@ -41,8 +41,7 @@ public class Access extends Protocol {
 		Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
 		for (int i = 0; i < pathTuples.length; i++)
 			pathTuples[i].setXor(predata.access_p[i]);
-		Object[] objArray = Util.permute(pathTuples, predata.access_sigma);
-		pathTuples = Arrays.copyOf(objArray, objArray.length, Tuple[].class);
+		pathTuples = Util.permute(pathTuples, predata.access_sigma);
 
 		// step 3
 		byte[] y = null;
@@ -106,8 +105,7 @@ public class Access extends Protocol {
 		Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
 		for (int i = 0; i < pathTuples.length; i++)
 			pathTuples[i].setXor(predata.access_p[i]);
-		Object[] objArray = Util.permute(pathTuples, predata.access_sigma);
-		pathTuples = Arrays.copyOf(objArray, objArray.length, Tuple[].class);
+		pathTuples = Util.permute(pathTuples, predata.access_sigma);
 
 		// step 2
 		timer.start(P.ACC, M.online_write);

+ 4 - 0
src/protocols/PostProcessT.java

@@ -33,7 +33,9 @@ public class PostProcessT extends Protocol {
 		}
 
 		// step 1
+		timer.start(P.PPT, M.online_read);
 		int delta = con2.readObject();
+		timer.stop(P.PPT, M.online_read);
 
 		// step 3
 		int twoTauPow = predata.ppt_s.length;
@@ -70,7 +72,9 @@ public class PostProcessT extends Protocol {
 		int twoTauPow = predata.ppt_r.length;
 		int delta = (predata.ppt_alpha - j2 + twoTauPow) % twoTauPow;
 
+		timer.start(P.PPT, M.online_write);
 		con1.write(delta);
+		timer.stop(P.PPT, M.online_write);
 
 		// step 2
 		byte[][] c = new byte[twoTauPow][];

+ 5 - 0
src/protocols/PreData.java

@@ -32,4 +32,9 @@ public class PreData {
 	public int ppt_alpha;
 	public byte[][] ppt_r;
 	public byte[][] ppt_s;
+
+	public int[] reshuffle_pi;
+	public Tuple[] reshuffle_p;
+	public Tuple[] reshuffle_r;
+	public Tuple[] reshuffle_a_prime;
 }

+ 6 - 4
src/protocols/PrePostProcessT.java

@@ -15,14 +15,12 @@ public class PrePostProcessT extends Protocol {
 	}
 
 	public void runE(PreData predata, Timer timer) {
-		timer.start(P.PPT, M.offline_comp);
-
+		timer.start(P.PPT, M.offline_read);
 		predata.ppt_Li = con1.read();
 		predata.ppt_Lip1 = con1.read();
 
 		predata.ppt_s = con1.readObject();
-
-		timer.stop(P.PPT, M.offline_comp);
+		timer.stop(P.PPT, M.offline_read);
 	}
 
 	public void runD(PreData predata, PreData prev, int LiBytes, int Lip1Bytes, int tau, Timer timer) {
@@ -44,12 +42,14 @@ public class PrePostProcessT extends Protocol {
 		}
 		predata.ppt_s[predata.ppt_alpha] = Util.xor(predata.ppt_r[predata.ppt_alpha], predata.ppt_Lip1);
 
+		timer.start(P.PPT, M.offline_write);
 		con1.write(predata.ppt_Li);
 		con1.write(predata.ppt_Lip1);
 
 		con2.write(predata.ppt_alpha);
 		con2.write(predata.ppt_r);
 		con1.write(predata.ppt_s);
+		timer.stop(P.PPT, M.offline_write);
 
 		timer.stop(P.PPT, M.offline_comp);
 	}
@@ -63,8 +63,10 @@ public class PrePostProcessT extends Protocol {
 			predata.ppt_Li = Util.nextBytes(LiBytes, Crypto.sr);
 		predata.ppt_Lip1 = Util.nextBytes(Lip1Bytes, Crypto.sr);
 
+		timer.start(P.PPT, M.offline_read);
 		predata.ppt_alpha = con2.readObject();
 		predata.ppt_r = con2.readObject();
+		timer.stop(P.PPT, M.offline_read);
 
 		timer.stop(P.PPT, M.offline_comp);
 	}

+ 64 - 0
src/protocols/PreReshuffle.java

@@ -0,0 +1,64 @@
+package protocols;
+
+import communication.Communication;
+import crypto.Crypto;
+import measure.M;
+import measure.P;
+import measure.Timer;
+import oram.Forest;
+import oram.Metadata;
+import oram.Tuple;
+import util.Util;
+
+public class PreReshuffle extends Protocol {
+	public PreReshuffle(Communication con1, Communication con2) {
+		super(con1, con2);
+	}
+
+	public void runE(PreData predata, Timer timer) {
+		timer.start(P.RSF, M.offline_comp);
+
+		predata.reshuffle_pi = Util.inversePermutation(predata.access_sigma);
+
+		timer.start(P.RSF, M.offline_read);
+		predata.reshuffle_r = con1.readObject();
+		timer.stop(P.RSF, M.offline_read);
+
+		timer.stop(P.RSF, M.offline_comp);
+	}
+
+	public void runD(PreData predata, int[] tupleParam, Timer timer) {
+		timer.start(P.RSF, M.offline_comp);
+
+		predata.reshuffle_pi = Util.inversePermutation(predata.access_sigma);
+		int numTuples = predata.reshuffle_pi.length;
+		predata.reshuffle_p = new Tuple[numTuples];
+		predata.reshuffle_r = new Tuple[numTuples];
+		Tuple[] a = new Tuple[numTuples];
+		for (int i = 0; i < numTuples; i++) {
+			predata.reshuffle_p[i] = new Tuple(tupleParam[0], tupleParam[1], tupleParam[2], tupleParam[3], Crypto.sr);
+			predata.reshuffle_r[i] = new Tuple(tupleParam[0], tupleParam[1], tupleParam[2], tupleParam[3], Crypto.sr);
+			a[i] = predata.reshuffle_p[i].xor(predata.reshuffle_r[i]);
+		}
+		predata.reshuffle_a_prime = Util.permute(a, predata.reshuffle_pi);
+
+		timer.start(P.RSF, M.offline_write);
+		con2.write(predata.reshuffle_p);
+		con2.write(predata.reshuffle_a_prime);
+		con1.write(predata.reshuffle_r);
+		timer.stop(P.RSF, M.offline_write);
+
+		timer.stop(P.RSF, M.offline_comp);
+	}
+
+	public void runC(PreData predata, Timer timer) {
+		timer.start(P.RSF, M.offline_read);
+		predata.reshuffle_p = con2.readObject();
+		predata.reshuffle_a_prime = con2.readObject();
+		timer.stop(P.RSF, M.offline_read);
+	}
+
+	@Override
+	public void run(Party party, Metadata md, Forest forest) {
+	}
+}

+ 198 - 0
src/protocols/Reshuffle.java

@@ -0,0 +1,198 @@
+package protocols;
+
+import java.math.BigInteger;
+
+import communication.Communication;
+import crypto.Crypto;
+import exceptions.AccessException;
+import exceptions.NoSuchPartyException;
+import measure.M;
+import measure.P;
+import measure.Timer;
+import oram.Forest;
+import oram.Metadata;
+import oram.Tree;
+import oram.Tuple;
+import util.Util;
+
+public class Reshuffle extends Protocol {
+
+	public Reshuffle(Communication con1, Communication con2) {
+		super(con1, con2);
+	}
+
+	public Tuple[] runE(PreData predata, Tuple[] path, boolean firstTree, Timer timer) {
+		if (firstTree)
+			return path;
+
+		timer.start(P.RSF, M.online_comp);
+
+		// step 1
+		timer.start(P.RSF, M.online_read);
+		Tuple[] z = con2.readObject();
+		timer.stop(P.RSF, M.online_read);
+
+		// step 2
+		Tuple[] b = new Tuple[z.length];
+		for (int i = 0; i < b.length; i++)
+			b[i] = path[i].xor(z[i]).xor(predata.reshuffle_r[i]);
+		Tuple[] b_prime = Util.permute(b, predata.reshuffle_pi);
+
+		timer.stop(P.RSF, M.online_comp);
+		return b_prime;
+	}
+
+	public void runD() {
+	}
+
+	public Tuple[] runC(PreData predata, Tuple[] path, boolean firstTree, Timer timer) {
+		if (firstTree)
+			return path;
+
+		timer.start(P.RSF, M.online_comp);
+
+		// step 1
+		Tuple[] z = new Tuple[path.length];
+		for (int i = 0; i < z.length; i++)
+			z[i] = path[i].xor(predata.reshuffle_p[i]);
+
+		timer.start(P.RSF, M.online_write);
+		con1.write(z);
+		timer.stop(P.RSF, M.online_write);
+
+		timer.stop(P.RSF, M.online_comp);
+		return predata.reshuffle_a_prime;
+	}
+
+	// for testing correctness
+	@Override
+	public void run(Party party, Metadata md, Forest forest) {
+		int records = 5;
+		int repeat = 5;
+
+		int tau = md.getTau();
+		int numTrees = md.getNumTrees();
+		long numInsert = md.getNumInsertRecords();
+		int addrBits = md.getAddrBits();
+
+		Timer timer = new Timer();
+
+		sanityCheck();
+
+		System.out.println();
+
+		for (int i = 0; i < records; i++) {
+			long N = Util.nextLong(numInsert, Crypto.sr);
+
+			for (int j = 0; j < repeat; j++) {
+				System.out.println("Test: " + i + " " + j);
+				System.out.println("N=" + BigInteger.valueOf(N).toString(2));
+
+				byte[] Li = new byte[0];
+
+				for (int ti = 0; ti < numTrees; ti++) {
+					long Ni_value = Util.getSubBits(N, addrBits, addrBits - md.getNBitsOfTree(ti));
+					long Nip1_pr_value = Util.getSubBits(N, addrBits - md.getNBitsOfTree(ti),
+							Math.max(addrBits - md.getNBitsOfTree(ti) - tau, 0));
+					byte[] Ni = Util.longToBytes(Ni_value, md.getNBytesOfTree(ti));
+					byte[] Nip1_pr = Util.longToBytes(Nip1_pr_value, (tau + 7) / 8);
+
+					PreData predata = new PreData();
+					PreAccess preaccess = new PreAccess(con1, con2);
+					Access access = new Access(con1, con2);
+					PreReshuffle prereshuffle = new PreReshuffle(con1, con2);
+
+					if (party == Party.Eddie) {
+						Tree OTi = forest.getTree(ti);
+						int numTuples = (OTi.getD() - 1) * OTi.getW() + OTi.getStashSize();
+						preaccess.runE(predata, OTi, numTuples, timer);
+
+						byte[] sE_Ni = Util.nextBytes(Ni.length, Crypto.sr);
+						byte[] sD_Ni = Util.xor(Ni, sE_Ni);
+						con1.write(sD_Ni);
+
+						byte[] sE_Nip1_pr = Util.nextBytes(Nip1_pr.length, Crypto.sr);
+						byte[] sD_Nip1_pr = Util.xor(Nip1_pr, sE_Nip1_pr);
+						con1.write(sD_Nip1_pr);
+
+						OutAccess outaccess = access.runE(predata, OTi, sE_Ni, sE_Nip1_pr, timer);
+
+						if (ti == numTrees - 1)
+							con2.write(N);
+
+						prereshuffle.runE(predata, timer);
+						Tuple[] E_P_prime = runE(predata, outaccess.E_P, ti == 0, timer);
+
+						Tuple[] C_P = con2.readObject();
+						Tuple[] C_P_prime = con2.readObject();
+						Tuple[] oldPath = new Tuple[C_P.length];
+						Tuple[] newPath = new Tuple[C_P.length];
+
+						for (int k = 0; k < C_P.length; k++) {
+							oldPath[k] = outaccess.E_P[k].xor(C_P[k]);
+							newPath[k] = E_P_prime[k].xor(C_P_prime[k]);
+						}
+						oldPath = Util.permute(oldPath, predata.reshuffle_pi);
+
+						boolean pass = true;
+						for (int k = 0; k < newPath.length; k++) {
+							if (!oldPath[k].equals(newPath[k])) {
+								System.err.println("Reshuffle test failed");
+								pass = false;
+								break;
+							}
+						}
+						if (pass)
+							System.out.println("Reshuffle test passed");
+
+					} else if (party == Party.Debbie) {
+						Tree OTi = forest.getTree(ti);
+						preaccess.runD(predata, timer);
+
+						byte[] sD_Ni = con1.read();
+
+						byte[] sD_Nip1_pr = con1.read();
+
+						access.runD(predata, OTi, sD_Ni, sD_Nip1_pr, timer);
+
+						int[] tupleParam = new int[] { ti == 0 ? 0 : 1, md.getNBytesOfTree(ti), md.getLBytesOfTree(ti),
+								md.getABytesOfTree(ti) };
+						prereshuffle.runD(predata, tupleParam, timer);
+						runD();
+
+					} else if (party == Party.Charlie) {
+						preaccess.runC(timer);
+
+						System.out.println("L" + ti + "=" + new BigInteger(1, Li).toString(2));
+
+						OutAccess outaccess = access.runC(md, ti, Li, timer);
+
+						prereshuffle.runC(predata, timer);
+						Tuple[] C_P_prime = runC(predata, outaccess.C_P, ti == 0, timer);
+
+						Li = outaccess.C_Lip1;
+
+						if (ti == numTrees - 1) {
+							N = con1.readObject();
+							long data = new BigInteger(1, outaccess.C_Ti.getA()).longValue();
+							if (N == data) {
+								System.out.println("Access passed");
+								System.out.println();
+							} else {
+								throw new AccessException("Access failed");
+							}
+						}
+
+						con1.write(outaccess.C_P);
+						con1.write(C_P_prime);
+
+					} else {
+						throw new NoSuchPartyException(party + "");
+					}
+				}
+			}
+		}
+
+		// timer.print();
+	}
+}

+ 3 - 0
src/ui/CLI.java

@@ -16,6 +16,7 @@ import protocols.Party;
 import protocols.Protocol;
 import protocols.SSCOT;
 import protocols.SSIOT;
+import protocols.Reshuffle;
 import protocols.PostProcessT;
 import protocols.SSXOT;
 import protocols.Access;
@@ -73,6 +74,8 @@ public class CLI {
 			operation = SSCOT.class;
 		} else if (protocol.equals("ssiot")) {
 			operation = SSIOT.class;
+		} else if (protocol.equals("reshuffle")) {
+			operation = Reshuffle.class;
 		} else if (protocol.equals("ppt")) {
 			operation = PostProcessT.class;
 		} else if (protocol.equals("ssxot")) {

+ 2 - 2
src/util/Util.java

@@ -117,12 +117,12 @@ public class Util {
 		return ip;
 	}
 
+	@SuppressWarnings("unchecked")
 	public static <T> T[] permute(T[] original, int[] p) {
-		@SuppressWarnings("unchecked")
 		T[] permuted = (T[]) new Object[original.length];
 		for (int i = 0; i < original.length; i++)
 			permuted[p[i]] = original[i];
-		return permuted;
+		return (T[]) Arrays.copyOf(permuted, permuted.length, original.getClass());
 	}
 
 	public static byte[] longToBytes(long l, int numBytes) {

+ 25 - 13
test/misc/MiscTests.java

@@ -2,11 +2,13 @@ package misc;
 
 import java.math.BigInteger;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 
 import crypto.Crypto;
 import oram.Forest;
 import oram.Metadata;
+import oram.Tuple;
 import util.StopWatch;
 import util.Util;
 
@@ -40,21 +42,31 @@ public class MiscTests {
 		 * Forest.readFromFile(md.getDefaultForestFileName()); forest.print();
 		 */
 
-		StopWatch sw1 = new StopWatch();
-		StopWatch sw2 = new StopWatch();
-		byte[] arr1 = Util.nextBytes((int) Math.pow(2, 20), Crypto.sr);
-		byte[] arr2 = Util.nextBytes((int) Math.pow(2, 20), Crypto.sr);
-
-		sw1.start();
-		Util.xor(arr1, arr2);
-		sw1.stop();
+		/*
+		 * StopWatch sw1 = new StopWatch(); StopWatch sw2 = new StopWatch();
+		 * byte[] arr1 = Util.nextBytes((int) Math.pow(2, 20), Crypto.sr);
+		 * byte[] arr2 = Util.nextBytes((int) Math.pow(2, 20), Crypto.sr);
+		 * 
+		 * sw1.start(); Util.xor(arr1, arr2); sw1.stop();
+		 * 
+		 * sw2.start(); new BigInteger(1, arr1).xor(new BigInteger(1,
+		 * arr2)).toByteArray(); sw2.stop();
+		 * 
+		 * System.out.println(sw1.toMS()); System.out.println(sw2.toMS());
+		 */
 
-		sw2.start();
-		new BigInteger(1, arr1).xor(new BigInteger(1, arr2)).toByteArray();
-		sw2.stop();
+		int n = 20;
+		Integer[] oldArr = new Integer[n];
+		for (int i = 0; i < n; i++)
+			oldArr[i] = Crypto.sr.nextInt(50);
+		int[] pi = Util.randomPermutation(n, Crypto.sr);
+		int[] pi_ivs = Util.inversePermutation(pi);
+		Integer[] newArr = Util.permute(oldArr, pi);
+		newArr = Util.permute(newArr, pi_ivs);
 
-		System.out.println(sw1.toMS());
-		System.out.println(sw2.toMS());
+		for (int i = 0; i < n; i++) {
+			System.out.println(oldArr[i] + " " + newArr[i]);
+		}
 	}
 
 }

+ 50 - 0
test/protocols/TestReshuffle_C.java

@@ -0,0 +1,50 @@
+package protocols;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.Arrays;
+
+public class TestReshuffle_C {
+
+	public static void main(String[] args) {
+		Runtime runTime = Runtime.getRuntime();
+		Process process = null;
+		String dir = System.getProperty("user.dir");
+		String binDir = dir + "\\bin";
+		String libs = dir + "\\lib\\*";
+		try {
+			process = runTime.exec("java -classpath " + binDir + ";" + libs + " ui.CLI -protocol reshuffle charlie");
+
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+		InputStream inputStream = process.getInputStream();
+		InputStreamReader isr = new InputStreamReader(inputStream);
+		InputStream errorStream = process.getErrorStream();
+		InputStreamReader esr = new InputStreamReader(errorStream);
+
+		System.out.println("STANDARD OUTPUT:");
+		int n1;
+		char[] c1 = new char[1024];
+		try {
+			while ((n1 = isr.read(c1)) > 0) {
+				System.out.print(new String(Arrays.copyOfRange(c1, 0, n1)));
+			}
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+
+		System.out.println("STANDARD ERROR:");
+		int n2;
+		char[] c2 = new char[1024];
+		try {
+			while ((n2 = esr.read(c2)) > 0) {
+				System.err.print(new String(Arrays.copyOfRange(c2, 0, n2)));
+			}
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+	}
+
+}

+ 50 - 0
test/protocols/TestReshuffle_D.java

@@ -0,0 +1,50 @@
+package protocols;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.Arrays;
+
+public class TestReshuffle_D {
+
+	public static void main(String[] args) {
+		Runtime runTime = Runtime.getRuntime();
+		Process process = null;
+		String dir = System.getProperty("user.dir");
+		String binDir = dir + "\\bin";
+		String libs = dir + "\\lib\\*";
+		try {
+			process = runTime.exec("java -classpath " + binDir + ";" + libs + " ui.CLI -protocol reshuffle debbie");
+
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+		InputStream inputStream = process.getInputStream();
+		InputStreamReader isr = new InputStreamReader(inputStream);
+		InputStream errorStream = process.getErrorStream();
+		InputStreamReader esr = new InputStreamReader(errorStream);
+
+		System.out.println("STANDARD OUTPUT:");
+		int n1;
+		char[] c1 = new char[1024];
+		try {
+			while ((n1 = isr.read(c1)) > 0) {
+				System.out.print(new String(Arrays.copyOfRange(c1, 0, n1)));
+			}
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+
+		System.out.println("STANDARD ERROR:");
+		int n2;
+		char[] c2 = new char[1024];
+		try {
+			while ((n2 = esr.read(c2)) > 0) {
+				System.err.print(new String(Arrays.copyOfRange(c2, 0, n2)));
+			}
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+	}
+
+}

+ 50 - 0
test/protocols/TestReshuffle_E.java

@@ -0,0 +1,50 @@
+package protocols;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.Arrays;
+
+public class TestReshuffle_E {
+
+	public static void main(String[] args) {
+		Runtime runTime = Runtime.getRuntime();
+		Process process = null;
+		String dir = System.getProperty("user.dir");
+		String binDir = dir + "\\bin";
+		String libs = dir + "\\lib\\*";
+		try {
+			process = runTime.exec("java -classpath " + binDir + ";" + libs + " ui.CLI -protocol reshuffle eddie");
+
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+		InputStream inputStream = process.getInputStream();
+		InputStreamReader isr = new InputStreamReader(inputStream);
+		InputStream errorStream = process.getErrorStream();
+		InputStreamReader esr = new InputStreamReader(errorStream);
+
+		System.out.println("STANDARD OUTPUT:");
+		int n1;
+		char[] c1 = new char[1024];
+		try {
+			while ((n1 = isr.read(c1)) > 0) {
+				System.out.print(new String(Arrays.copyOfRange(c1, 0, n1)));
+			}
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+
+		System.out.println("STANDARD ERROR:");
+		int n2;
+		char[] c2 = new char[1024];
+		try {
+			while ((n2 = esr.read(c2)) > 0) {
+				System.err.print(new String(Arrays.copyOfRange(c2, 0, n2)));
+			}
+		} catch (IOException e) {
+			e.printStackTrace();
+		}
+	}
+
+}