Browse Source

trying to find out why access is slow

Boyoung- 8 years ago
parent
commit
68c600f052

+ 26 - 3
src/communication/Communication.java

@@ -17,6 +17,7 @@ import java.util.concurrent.LinkedBlockingQueue;
 
 import org.apache.commons.lang3.SerializationUtils;
 
+import oram.Tuple;
 import util.Util;
 
 /**
@@ -346,15 +347,27 @@ public class Communication {
 			write(out[i]);
 	}
 
-	public <T> void write(T out) {
+	private <T> void write(T out) {
 		write(SerializationUtils.serialize((Serializable) out));
 	}
-
+	
+	/*
 	public <T> void write(T[] out) {
 		write(out.length);
 		for (int i = 0; i < out.length; i++)
 			write(out[i]);
 	}
+	*/
+	
+	public void write(Tuple[] tuples) {
+		write(tuples.length);
+		for (int i=0; i<tuples.length; i++) {
+			write(tuples[i].getF());
+			write(tuples[i].getN());
+			write(tuples[i].getL());
+			write(tuples[i].getA());
+		}
+	}
 
 	public static final Charset defaultCharset = Charset.forName("ASCII");
 
@@ -456,11 +469,12 @@ public class Communication {
 		return data;
 	}
 
-	public <T> T readObject() {
+	private <T> T readObject() {
 		T object = SerializationUtils.deserialize(read());
 		return object;
 	}
 
+	/*
 	public <T> T[] readObjectArray() {
 		int len = readInt();
 		@SuppressWarnings("unchecked")
@@ -469,6 +483,15 @@ public class Communication {
 			data[i] = readObject();
 		return data;
 	}
+	*/
+	
+	public Tuple[] readTupleArray() {
+		int len = readInt();
+		Tuple[] tuples = new Tuple[len];
+		for (int i=0; i<len; i++)
+			tuples[i] = new Tuple(read(), read(), read(), read());
+		return tuples;
+	}
 
 	/**
 	 * This thread runs while listening for incoming connections. It behaves

+ 16 - 0
src/exceptions/AccessException.java

@@ -0,0 +1,16 @@
+package exceptions;
+
+public class AccessException extends RuntimeException {
+	/**
+	 * 
+	 */
+	private static final long serialVersionUID = 1L;
+
+	public AccessException() {
+		super();
+	}
+
+	public AccessException(String message) {
+		super(message);
+	}
+}

+ 16 - 0
src/exceptions/StopWatchException.java

@@ -0,0 +1,16 @@
+package exceptions;
+
+public class StopWatchException extends Exception {
+	/**
+	 * 
+	 */
+	private static final long serialVersionUID = 1L;
+
+	public StopWatchException() {
+		super();
+	}
+
+	public StopWatchException(String message) {
+		super(message);
+	}
+}

+ 93 - 40
src/protocols/Access.java

@@ -7,26 +7,45 @@ import org.apache.commons.lang3.ArrayUtils;
 
 import communication.Communication;
 import crypto.Crypto;
+import exceptions.AccessException;
 import exceptions.NoSuchPartyException;
 import oram.Bucket;
 import oram.Forest;
 import oram.Metadata;
 import oram.Tree;
 import oram.Tuple;
+import util.StopWatch;
 import util.Util;
 
 public class Access extends Protocol {
+	
+	private StopWatch step0;
+	private StopWatch step1;
+	private StopWatch step2;
+	private StopWatch step3;
+	private StopWatch step4;
+	private StopWatch step5;
 
 	public Access(Communication con1, Communication con2) {
 		super(con1, con2);
+		
+		step0 = new StopWatch();
+		step1 = new StopWatch();
+		step2 = new StopWatch();
+		step3 = new StopWatch();
+		step4 = new StopWatch();
+		step5 = new StopWatch();
 	}
 
 	public OutAccess runE(PreData predata, Tree OTi, byte[] Ni, byte[] Nip1_pr) {
+		step0.start();
 		// step 0: get Li from C
 		byte[] Li = new byte[0];
 		if (OTi.getTreeIndex() > 0)
 			Li = con2.read();
+		step0.stop();
 
+		step1.start();
 		// step 1
 		Bucket[] pathBuckets = OTi.getBucketsOnPath(new BigInteger(1, Li).longValue());
 		Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
@@ -34,7 +53,9 @@ public class Access extends Protocol {
 			pathTuples[i].setXor(predata.access_p[i]);
 		Object[] objArray = Util.permute(pathTuples, predata.access_sigma);
 		pathTuples = Arrays.copyOf(objArray, objArray.length, Tuple[].class);
+		step1.stop();
 
+		step3.start();
 		// step 3
 		byte[] y = null;
 		if (OTi.getTreeIndex() == 0)
@@ -57,7 +78,9 @@ public class Access extends Protocol {
 			SSCOT sscot = new SSCOT(con1, con2);
 			sscot.runE(predata, m, a);
 		}
+		step3.stop();
 
+		step4.start();
 		// step 4
 		if (OTi.getTreeIndex() < OTi.getH() - 1) {
 			int ySegBytes = y.length / OTi.getTwoTauPow();
@@ -68,24 +91,30 @@ public class Access extends Protocol {
 			SSIOT ssiot = new SSIOT(con1, con2);
 			ssiot.runE(predata, y_array, Nip1_pr);
 		}
+		step4.stop();
 
+		step5.start();
 		// step 5
 		Tuple Ti = null;
 		if (OTi.getTreeIndex() == 0)
 			Ti = pathTuples[0];
 		else
 			Ti = new Tuple(new byte[0], Ni, Li, y);
+		step5.stop();
 
 		OutAccess outaccess = new OutAccess(null, null, null, Ti, pathTuples);
 		return outaccess;
 	}
 
 	public void runD(PreData predata, Tree OTi, byte[] Ni, byte[] Nip1_pr) {
+		step0.start();
 		// step 0: get Li from C
 		byte[] Li = new byte[0];
 		if (OTi.getTreeIndex() > 0)
 			Li = con2.read();
+		step0.stop();
 
+		step1.start();
 		// step 1
 		Bucket[] pathBuckets = OTi.getBucketsOnPath(new BigInteger(1, Li).longValue());
 		Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
@@ -93,11 +122,15 @@ public class Access extends Protocol {
 			pathTuples[i].setXor(predata.access_p[i]);
 		Object[] objArray = Util.permute(pathTuples, predata.access_sigma);
 		pathTuples = Arrays.copyOf(objArray, objArray.length, Tuple[].class);
+		step1.stop();
 
+		step2.start();
 		// step 2
 		con2.write(pathTuples);
 		con2.write(Ni);
+		step2.stop();
 
+		step3.start();
 		// step 3
 		if (OTi.getTreeIndex() > 0) {
 			byte[][] b = new byte[pathTuples.length][];
@@ -111,26 +144,35 @@ public class Access extends Protocol {
 			SSCOT sscot = new SSCOT(con1, con2);
 			sscot.runD(predata, b);
 		}
+		step3.stop();
 
+		step4.start();
 		// step 4
 		if (OTi.getTreeIndex() < OTi.getH() - 1) {
 			SSIOT ssiot = new SSIOT(con1, con2);
 			ssiot.runD(predata, Nip1_pr);
 		}
+		step4.stop();
 	}
 
 	public OutAccess runC(Metadata md, int treeIndex, byte[] Li) {
+		step0.start();
 		// step 0: send Li to E and D
 		if (treeIndex > 0) {
 			con1.write(Li);
 			con2.write(Li);
 		}
+		step0.stop();
 
+		step2.start();
 		// step 2
-		Object[] objArray = con2.readObjectArray();
-		Tuple[] pathTuples = Arrays.copyOf(objArray, objArray.length, Tuple[].class);
+		//Object[] objArray = con2.readObjectArray();
+		//Tuple[] pathTuples = Arrays.copyOf(objArray, objArray.length, Tuple[].class);
+		Tuple[] pathTuples = con2.readTupleArray();
 		byte[] Ni = con2.read();
+		step2.stop();
 
+		step3.start();
 		// step 3
 		int j1 = 0;
 		byte[] z = null;
@@ -143,7 +185,9 @@ public class Access extends Protocol {
 			byte[] d = pathTuples[j1].getA();
 			z = Util.xor(je.m_t, d);
 		}
-
+		step3.stop();
+		
+		step4.start();
 		// step 4
 		int j2 = 0;
 		byte[] Lip1 = null;
@@ -157,7 +201,9 @@ public class Access extends Protocol {
 			byte[] z_j2 = Arrays.copyOfRange(z, j2 * lSegBytes, (j2 + 1) * lSegBytes);
 			Lip1 = Util.xor(jy.m_t, z_j2);
 		}
+		step4.stop();
 
+		step5.start();
 		Tuple Ti = null;
 		if (treeIndex == 0) {
 			Ti = pathTuples[0];
@@ -169,6 +215,7 @@ public class Access extends Protocol {
 			Crypto.sr.nextBytes(pathTuples[j1].getL());
 			Crypto.sr.nextBytes(pathTuples[j1].getA());
 		}
+		step5.stop();
 
 		OutAccess outaccess = new OutAccess(Lip1, Ti, pathTuples, null, null);
 		return outaccess;
@@ -176,54 +223,38 @@ public class Access extends Protocol {
 
 	@Override
 	public void run(Party party, Metadata md, Forest forest) {
-		/*
-		 * PreData predata = new PreData(); PreAccess preaccess = new
-		 * PreAccess(con1, con2); int treeIndex = 1; Tree tree = null; int
-		 * numTuples = 0; if (forest != null) { tree =
-		 * forest.getTree(treeIndex); numTuples = (tree.getD() - 1) *
-		 * tree.getW() + tree.getStashSize(); } byte[] Li = new BigInteger("11",
-		 * 2).toByteArray(); byte[] Ni = new byte[] { 0 }; byte[] Nip1_pr = new
-		 * byte[] { 0 }; if (party == Party.Eddie) { preaccess.runE(predata,
-		 * tree, numTuples); runE(predata, tree, Ni, Nip1_pr);
-		 * 
-		 * } else if (party == Party.Debbie) { preaccess.runD(predata);
-		 * runD(predata, tree, Ni, Nip1_pr);
-		 * 
-		 * } else if (party == Party.Charlie) { preaccess.runC(); runC(md,
-		 * treeIndex, Li);
-		 * 
-		 * } else { throw new NoSuchPartyException(party + ""); }
-		 */
-
-		int records = 10;
-		int repeart = 5;
+		int records = 5;
+		int repeat = 5;
 
 		int tau = md.getTau();
 		int numTrees = md.getNumTrees();
 		long numInsert = md.getNumInsertRecords();
 		int addrBits = md.getAddrBits();
 
+		StopWatch stopwatch = new StopWatch();
+
+		sanityCheck();
+
+		System.out.println();
+
 		for (int i = 0; i < records; i++) {
 			long N = Util.nextLong(numInsert, Crypto.sr);
-			// System.out.println("N=" + BigInteger.valueOf(N).toString(2));
-			for (int j = 0; j < repeart; j++) {
+
+			for (int j = 0; j < repeat; j++) {
+				System.out.println("Test: " + i + " " + j);
+				System.out.println("N=" + BigInteger.valueOf(N).toString(2));
+
 				byte[] Li = new byte[0];
-				for (int ti = 0; ti < numTrees; ti++) {
-					// System.out.println(i + " " + j + " " + ti);
 
+				for (int ti = 0; ti < numTrees; ti++) {
 					long Ni_value = Util.getSubBits(N, addrBits, addrBits - md.getNBitsOfTree(ti));
 					long Nip1_pr_value = Util.getSubBits(N, addrBits - md.getNBitsOfTree(ti),
 							Math.max(addrBits - md.getNBitsOfTree(ti) - tau, 0));
 					byte[] Ni = Util.longToBytes(Ni_value, md.getNBytesOfTree(ti));
 					byte[] Nip1_pr = Util.longToBytes(Nip1_pr_value, (tau + 7) / 8);
-					// System.out.println("Ni=" +
-					// BigInteger.valueOf(Ni_value).toString(2));
-					// System.out.println("Nip1_pr=" +
-					// BigInteger.valueOf(Nip1_pr_value).toString(2));
 
 					PreData predata = new PreData();
 					PreAccess preaccess = new PreAccess(con1, con2);
-					Access access = new Access(con1, con2);
 
 					if (party == Party.Eddie) {
 						Tree OTi = forest.getTree(ti);
@@ -238,10 +269,12 @@ public class Access extends Protocol {
 						byte[] sD_Nip1_pr = Util.xor(Nip1_pr, sE_Nip1_pr);
 						con1.write(sD_Nip1_pr);
 
-						access.runE(predata, OTi, sE_Ni, sE_Nip1_pr);
+						stopwatch.start();
+						runE(predata, OTi, sE_Ni, sE_Nip1_pr);
+						stopwatch.stop();
 
 						if (ti == numTrees - 1)
-							con2.write(N);
+							con2.write(BigInteger.valueOf(N).toByteArray());
 
 					} else if (party == Party.Debbie) {
 						Tree OTi = forest.getTree(ti);
@@ -251,19 +284,30 @@ public class Access extends Protocol {
 
 						byte[] sD_Nip1_pr = con1.read();
 
-						access.runD(predata, OTi, sD_Ni, sD_Nip1_pr);
+						stopwatch.start();
+						runD(predata, OTi, sD_Ni, sD_Nip1_pr);
+						stopwatch.stop();
 
 					} else if (party == Party.Charlie) {
 						preaccess.runC();
 
-						OutAccess outaccess = access.runC(md, ti, Li);
+						System.out.println("L" + ti + "=" + new BigInteger(1, Li).toString(2));
+
+						stopwatch.start();
+						OutAccess outaccess = runC(md, ti, Li);
+						stopwatch.stop();
+
 						Li = outaccess.C_Lip1;
 
 						if (ti == numTrees - 1) {
-							N = con1.readObject();
+							N = new BigInteger(con1.read()).longValue();
 							long data = new BigInteger(1, outaccess.C_Ti.getA()).longValue();
-							System.out.println(N);
-							System.out.println(data);
+							if (N == data) {
+								System.out.println("Access passed");
+								System.out.println();
+							} else {
+								throw new AccessException("Access failed");
+							}
 						}
 
 					} else {
@@ -272,5 +316,14 @@ public class Access extends Protocol {
 				}
 			}
 		}
+
+		System.out.println(stopwatch.toMS());
+		
+		System.out.println("step0\n" + step0.toMS());
+		System.out.println("step1\n" + step1.toMS());
+		System.out.println("step2\n" + step2.toMS());
+		System.out.println("step3\n" + step3.toMS());
+		System.out.println("step4\n" + step4.toMS());
+		System.out.println("step5\n" + step5.toMS());
 	}
 }

+ 3 - 2
src/protocols/PreAccess.java

@@ -45,8 +45,9 @@ public class PreAccess extends Protocol {
 
 		// Access
 		predata.access_sigma = con1.readIntArray();
-		Object[] objArray = con1.readObjectArray();
-		predata.access_p = Arrays.copyOf(objArray, objArray.length, Tuple[].class);
+		//Object[] objArray = con1.readObjectArray();
+		//predata.access_p = Arrays.copyOf(objArray, objArray.length, Tuple[].class);
+		predata.access_p = con1.readTupleArray();
 	}
 
 	public void runC() {

+ 77 - 0
src/util/StopWatch.java

@@ -0,0 +1,77 @@
+package util;
+
+import java.lang.management.ManagementFactory;
+import java.lang.management.ThreadMXBean;
+
+import exceptions.StopWatchException;
+
+public class StopWatch {
+
+	private long startWC;
+	private long startCPU;
+	public long elapsedWC;
+	public long elapsedCPU;
+	private boolean isOn;
+
+	public StopWatch() {
+		startWC = 0;
+		startCPU = 0;
+		elapsedWC = 0;
+		elapsedCPU = 0;
+		isOn = false;
+	}
+
+	public void reset() {
+		if (isOn) {
+			try {
+				throw new StopWatchException("StopWatch is still running");
+			} catch (StopWatchException e) {
+				e.printStackTrace();
+			}
+		}
+
+		startWC = 0;
+		startCPU = 0;
+		elapsedWC = 0;
+		elapsedCPU = 0;
+	}
+
+	public void start() {
+		if (isOn) {
+			try {
+				throw new StopWatchException("StopWatch is already running");
+			} catch (StopWatchException e) {
+				e.printStackTrace();
+			}
+		}
+
+		isOn = true;
+		startWC = System.nanoTime();
+		startCPU = getCPUTime();
+	}
+
+	public void stop() {
+		if (!isOn) {
+			try {
+				throw new StopWatchException("StopWatch is not running");
+			} catch (StopWatchException e) {
+				e.printStackTrace();
+			}
+		}
+
+		isOn = false;
+		elapsedCPU += getCPUTime() - startCPU;
+		elapsedWC += System.nanoTime() - startWC;
+	}
+
+	private long getCPUTime() {
+		ThreadMXBean bean = ManagementFactory.getThreadMXBean();
+		return bean.isCurrentThreadCpuTimeSupported() ? bean.getCurrentThreadCpuTime() : 0L;
+	}
+
+	public String toMS() {
+		String out = "WallClock(ms): " + elapsedWC / 1000000;
+		out += "\nCPUClock(ms): " + elapsedCPU / 1000000;
+		return out;
+	}
+}

+ 26 - 0
test/util/TestStopWatch.java

@@ -0,0 +1,26 @@
+package util;
+
+public class TestStopWatch {
+
+	public static void main(String[] args) {
+		StopWatch sw = new StopWatch();
+		sw.start();
+		try {
+			Thread.sleep(1000);
+		} catch (InterruptedException e) {
+			e.printStackTrace();
+		}
+		sw.stop();
+		System.out.println(sw.toMS());
+		
+		sw.start();
+		try {
+			Thread.sleep(1000);
+		} catch (InterruptedException e) {
+			e.printStackTrace();
+		}
+		sw.stop();
+		System.out.println(sw.toMS());
+	}
+
+}