Browse Source

Start to assign Timer and Bandwidth to each Protocol; First tested on PIRCOT

Boyang Wei 6 years ago
parent
commit
caf5d65536

+ 76 - 83
src/communication/Communication.java

@@ -18,9 +18,9 @@ import com.oblivm.backend.gc.GCSignal;
 
 import crypto.SimpleAES;
 import oram.Bucket;
+import oram.Global;
 import oram.Tuple;
 import util.Bandwidth;
-import util.P;
 import util.StopWatch;
 import util.Util;
 
@@ -71,19 +71,12 @@ public class Communication {
 	protected int mState;
 	protected InetSocketAddress mAddress;
 
-	public Bandwidth[] bandwidth;
-	public boolean bandSwitch = true; // TODO: change this to static (or to
-										// Global)
-
+	// TODO: enable link encryption and remove manual AES Enc
 	private static SimpleAES aes = new SimpleAES();
-	public StopWatch comEnc = new StopWatch("CE_online_comp");
+	public StopWatch comEnc = new StopWatch("ComEnc_comp");
 
 	public Communication() {
 		mState = STATE_NONE;
-
-		bandwidth = new Bandwidth[P.size];
-		for (int i = 0; i < P.size; i++)
-			bandwidth[i] = new Bandwidth(P.names[i]);
 	}
 
 	public void setTcpNoDelay(boolean on) {
@@ -336,15 +329,15 @@ public class Communication {
 		r.write(out);
 	}
 
-	public void write(int pid, byte[] out) {
+	public void write(Bandwidth bandwidth, byte[] out) {
 		comEnc.start();
 		out = aes.encrypt(out);
 		comEnc.stop();
 
 		write(out);
 
-		if (bandSwitch)
-			bandwidth[pid].add(out.length);
+		if (Global.bandSwitch)
+			bandwidth.add(out.length);
 	}
 
 	/**
@@ -361,7 +354,7 @@ public class Communication {
 	 * public <T> void write(T out) {
 	 * write(SerializationUtils.serialize((Serializable) out)); }
 	 * 
-	 * public <T> void write(int pid, T out) { write(pid,
+	 * public <T> void write(Bandwidth bandwidth, T out) { write(pid,
 	 * SerializationUtils.serialize((Serializable) out)); }
 	 */
 
@@ -369,136 +362,136 @@ public class Communication {
 		write(b.toByteArray());
 	}
 
-	public void write(int pid, BigInteger b) {
-		write(pid, b.toByteArray());
+	public void write(Bandwidth bandwidth, BigInteger b) {
+		write(bandwidth, b.toByteArray());
 	}
 
 	public void write(int n) {
 		write(BigInteger.valueOf(n));
 	}
 
-	public void write(int pid, int n) {
-		write(pid, BigInteger.valueOf(n));
+	public void write(Bandwidth bandwidth, int n) {
+		write(bandwidth, BigInteger.valueOf(n));
 	}
 
 	public void write(long n) {
 		write(BigInteger.valueOf(n));
 	}
 
-	public void write(int pid, long n) {
-		write(pid, BigInteger.valueOf(n));
+	public void write(Bandwidth bandwidth, long n) {
+		write(bandwidth, BigInteger.valueOf(n));
 	}
 
 	public void write(byte[][] arr) {
 		write(ComUtil.serialize(arr));
 	}
 
-	public void write(int pid, byte[][] arr) {
-		write(pid, ComUtil.serialize(arr));
+	public void write(Bandwidth bandwidth, byte[][] arr) {
+		write(bandwidth, ComUtil.serialize(arr));
 	}
 
 	public void write(byte[][][] arr) {
 		write(ComUtil.serialize(arr));
 	}
 
-	public void write(int pid, byte[][][] arr) {
-		write(pid, ComUtil.serialize(arr));
+	public void write(Bandwidth bandwidth, byte[][][] arr) {
+		write(bandwidth, ComUtil.serialize(arr));
 	}
 
 	public void write(int[] arr) {
 		write(ComUtil.serialize(arr));
 	}
 
-	public void write(int pid, int[] arr) {
-		write(pid, ComUtil.serialize(arr));
+	public void write(Bandwidth bandwidth, int[] arr) {
+		write(bandwidth, ComUtil.serialize(arr));
 	}
 
 	public void write(int[][] arr) {
 		write(ComUtil.serialize(arr));
 	}
 
-	public void write(int pid, int[][] arr) {
-		write(pid, ComUtil.serialize(arr));
+	public void write(Bandwidth bandwidth, int[][] arr) {
+		write(bandwidth, ComUtil.serialize(arr));
 	}
 
 	public void write(Tuple t) {
 		write(ComUtil.serialize(t));
 	}
 
-	public void write(int pid, Tuple t) {
-		write(pid, ComUtil.serialize(t));
+	public void write(Bandwidth bandwidth, Tuple t) {
+		write(bandwidth, ComUtil.serialize(t));
 	}
 
 	public void write(Tuple[] arr) {
 		write(ComUtil.serialize(arr));
 	}
 
-	public void write(int pid, Tuple[] arr) {
-		write(pid, ComUtil.serialize(arr));
+	public void write(Bandwidth bandwidth, Tuple[] arr) {
+		write(bandwidth, ComUtil.serialize(arr));
 	}
 
 	public void write(Bucket b) {
 		write(b.getTuples());
 	}
 
-	public void write(int pid, Bucket b) {
-		write(pid, b.getTuples());
+	public void write(Bandwidth bandwidth, Bucket b) {
+		write(bandwidth, b.getTuples());
 	}
 
 	public void write(Bucket[] arr) {
 		write(ComUtil.serialize(arr));
 	}
 
-	public void write(int pid, Bucket[] arr) {
-		write(pid, ComUtil.serialize(arr));
+	public void write(Bandwidth bandwidth, Bucket[] arr) {
+		write(bandwidth, ComUtil.serialize(arr));
 	}
 
 	public void write(GCSignal key) {
 		write(key.bytes);
 	}
 
-	public void write(int pid, GCSignal key) {
-		write(pid, key.bytes);
+	public void write(Bandwidth bandwidth, GCSignal key) {
+		write(bandwidth, key.bytes);
 	}
 
 	public void write(GCSignal[] arr) {
 		write(ComUtil.serialize(arr));
 	}
 
-	public void write(int pid, GCSignal[] arr) {
-		write(pid, ComUtil.serialize(arr));
+	public void write(Bandwidth bandwidth, GCSignal[] arr) {
+		write(bandwidth, ComUtil.serialize(arr));
 	}
 
 	public void write(GCSignal[][] arr) {
 		write(ComUtil.serialize(arr));
 	}
 
-	public void write(int pid, GCSignal[][] arr) {
-		write(pid, ComUtil.serialize(arr));
+	public void write(Bandwidth bandwidth, GCSignal[][] arr) {
+		write(bandwidth, ComUtil.serialize(arr));
 	}
 
 	public void write(GCSignal[][][] arr) {
 		write(ComUtil.serialize(arr));
 	}
 
-	public void write(int pid, GCSignal[][][] arr) {
-		write(pid, ComUtil.serialize(arr));
+	public void write(Bandwidth bandwidth, GCSignal[][][] arr) {
+		write(bandwidth, ComUtil.serialize(arr));
 	}
 
 	public void write(GCSignal[][][][] arr) {
 		write(ComUtil.serialize(arr));
 	}
 
-	public void write(int pid, GCSignal[][][][] arr) {
-		write(pid, ComUtil.serialize(arr));
+	public void write(Bandwidth bandwidth, GCSignal[][][][] arr) {
+		write(bandwidth, ComUtil.serialize(arr));
 	}
 
 	public void write(ArrayList<byte[]> arr) {
 		write(ComUtil.serialize(arr));
 	}
 
-	public void write(int pid, ArrayList<byte[]> arr) {
-		write(pid, ComUtil.serialize(arr));
+	public void write(Bandwidth bandwidth, ArrayList<byte[]> arr) {
+		write(bandwidth, ComUtil.serialize(arr));
 	}
 
 	public static final Charset defaultCharset = Charset.forName("ASCII");
@@ -554,7 +547,7 @@ public class Communication {
 		return readMessage;
 	}
 
-	public byte[] read(int pid) {
+	public byte[] readAndDec() {
 		byte[] msg = read();
 		comEnc.start();
 		msg = aes.decrypt(msg);
@@ -598,136 +591,136 @@ public class Communication {
 		return new BigInteger(read());
 	}
 
-	public BigInteger readBigInteger(int pid) {
-		return new BigInteger(read(pid));
+	public BigInteger readBigIntegerAndDec() {
+		return new BigInteger(readAndDec());
 	}
 
 	public int readInt() {
 		return readBigInteger().intValue();
 	}
 
-	public int readInt(int pid) {
-		return readBigInteger(pid).intValue();
+	public int readIntAndDec() {
+		return readBigIntegerAndDec().intValue();
 	}
 
 	public long readLong() {
 		return readBigInteger().longValue();
 	}
 
-	public long readLong(int pid) {
-		return readBigInteger(pid).longValue();
+	public long readLongAndDec() {
+		return readBigIntegerAndDec().longValue();
 	}
 
 	public byte[][] readDoubleByteArray() {
 		return ComUtil.toDoubleByteArray(read());
 	}
 
-	public byte[][] readDoubleByteArray(int pid) {
-		return ComUtil.toDoubleByteArray(read(pid));
+	public byte[][] readDoubleByteArrayAndDec() {
+		return ComUtil.toDoubleByteArray(readAndDec());
 	}
 
 	public byte[][][] readTripleByteArray() {
 		return ComUtil.toTripleByteArray(read());
 	}
 
-	public byte[][][] readTripleByteArray(int pid) {
-		return ComUtil.toTripleByteArray(read(pid));
+	public byte[][][] readTripleByteArrayAndDec() {
+		return ComUtil.toTripleByteArray(readAndDec());
 	}
 
 	public int[] readIntArray() {
 		return ComUtil.toIntArray(read());
 	}
 
-	public int[] readIntArray(int pid) {
-		return ComUtil.toIntArray(read(pid));
+	public int[] readIntArrayAndDec() {
+		return ComUtil.toIntArray(readAndDec());
 	}
 
 	public int[][] readDoubleIntArray() {
 		return ComUtil.toDoubleIntArray(read());
 	}
 
-	public int[][] readDoubleIntArray(int pid) {
-		return ComUtil.toDoubleIntArray(read(pid));
+	public int[][] readDoubleIntArrayAndDec() {
+		return ComUtil.toDoubleIntArray(readAndDec());
 	}
 
 	public Tuple readTuple() {
 		return ComUtil.toTuple(read());
 	}
 
-	public Tuple readTuple(int pid) {
-		return ComUtil.toTuple(read(pid));
+	public Tuple readTupleAndDec() {
+		return ComUtil.toTuple(readAndDec());
 	}
 
 	public Tuple[] readTupleArray() {
 		return ComUtil.toTupleArray(read());
 	}
 
-	public Tuple[] readTupleArray(int pid) {
-		return ComUtil.toTupleArray(read(pid));
+	public Tuple[] readTupleArrayAndDec() {
+		return ComUtil.toTupleArray(readAndDec());
 	}
 
 	public Bucket readBucket() {
 		return new Bucket(readTupleArray());
 	}
 
-	public Bucket readBucket(int pid) {
-		return new Bucket(readTupleArray(pid));
+	public Bucket readBucketAndDec() {
+		return new Bucket(readTupleArrayAndDec());
 	}
 
 	public Bucket[] readBucketArray() {
 		return ComUtil.toBucketArray(read());
 	}
 
-	public Bucket[] readBucketArray(int pid) {
-		return ComUtil.toBucketArray(read(pid));
+	public Bucket[] readBucketArrayAndDec() {
+		return ComUtil.toBucketArray(readAndDec());
 	}
 
 	public GCSignal readGCSignal() {
 		return new GCSignal(read());
 	}
 
-	public GCSignal readGCSignal(int pid) {
-		return new GCSignal(read(pid));
+	public GCSignal readGCSignalAndDec() {
+		return new GCSignal(readAndDec());
 	}
 
 	public GCSignal[] readGCSignalArray() {
 		return ComUtil.toGCSignalArray(read());
 	}
 
-	public GCSignal[] readGCSignalArray(int pid) {
-		return ComUtil.toGCSignalArray(read(pid));
+	public GCSignal[] readGCSignalArrayAndDec() {
+		return ComUtil.toGCSignalArray(readAndDec());
 	}
 
 	public GCSignal[][] readDoubleGCSignalArray() {
 		return ComUtil.toDoubleGCSignalArray(read());
 	}
 
-	public GCSignal[][] readDoubleGCSignalArray(int pid) {
-		return ComUtil.toDoubleGCSignalArray(read(pid));
+	public GCSignal[][] readDoubleGCSignalArrayAndDec() {
+		return ComUtil.toDoubleGCSignalArray(readAndDec());
 	}
 
 	public GCSignal[][][] readTripleGCSignalArray() {
 		return ComUtil.toTripleGCSignalArray(read());
 	}
 
-	public GCSignal[][][] readTripleGCSignalArray(int pid) {
-		return ComUtil.toTripleGCSignalArray(read(pid));
+	public GCSignal[][][] readTripleGCSignalArrayAndDec() {
+		return ComUtil.toTripleGCSignalArray(readAndDec());
 	}
 
 	public GCSignal[][][][] readQuadGCSignalArray() {
 		return ComUtil.toQuadGCSignalArray(read());
 	}
 
-	public GCSignal[][][][] readQuadGCSignalArray(int pid) {
-		return ComUtil.toQuadGCSignalArray(read(pid));
+	public GCSignal[][][][] readQuadGCSignalArrayAndDec() {
+		return ComUtil.toQuadGCSignalArray(readAndDec());
 	}
 
 	public ArrayList<byte[]> readArrayList() {
 		return ComUtil.toArrayList(read());
 	}
 
-	public ArrayList<byte[]> readArrayList(int pid) {
-		return ComUtil.toArrayList(read(pid));
+	public ArrayList<byte[]> readArrayListAndDec() {
+		return ComUtil.toArrayList(readAndDec());
 	}
 
 	/**

+ 1 - 0
src/oram/Global.java

@@ -5,5 +5,6 @@ public class Global {
 	public static boolean cheat = true;
 	public static boolean pipeline = false;
 	public static boolean usePIR = true;
+	public static boolean bandSwitch = true;
 
 }

+ 35 - 116
src/pir/PIRCOT.java

@@ -9,24 +9,19 @@ import oram.Metadata;
 import pir.precomputation.PrePIRCOT;
 import protocols.Protocol;
 import protocols.struct.OutPIRCOT;
-import protocols.struct.OutSSCOT;
 import protocols.struct.Party;
 import protocols.struct.PreData;
 import util.M;
-import util.P;
-import util.Timer;
 import util.Util;
 
 public class PIRCOT extends Protocol {
 
-	private int pid = P.COT;
-
 	public PIRCOT(Communication con1, Communication con2) {
 		super(con1, con2);
 	}
 
-	public OutPIRCOT runE(PreData predata, byte[][] u, byte[] v, Timer timer) {
-		timer.start(pid, M.online_comp);
+	public OutPIRCOT runE(PreData predata, byte[][] u, byte[] v) {
+		timer.start(M.online_comp);
 
 		int l = u.length;
 		byte[][] a = new byte[l][];
@@ -37,13 +32,13 @@ public class PIRCOT extends Protocol {
 			a[j] = predata.sscot_F_k.compute(a[j]);
 		}
 
-		timer.start(pid, M.online_write);
-		con2.write(pid, a);
-		timer.stop(pid, M.online_write);
+		timer.start(M.online_write);
+		con2.write(online_band, a);
+		timer.stop(M.online_write);
 
-		timer.start(pid, M.online_read);
-		int delta = con2.readInt(pid);
-		timer.stop(pid, M.online_read);
+		timer.start(M.online_read);
+		int delta = con2.readIntAndDec();
+		timer.stop(M.online_read);
 
 		int t_E = (predata.sscot_s_DE + delta) % l;
 
@@ -52,12 +47,12 @@ public class PIRCOT extends Protocol {
 		out.s_DE = predata.sscot_s_DE;
 		out.s_CE = predata.sscot_s_CE;
 
-		timer.stop(pid, M.online_comp);
+		timer.stop(M.online_comp);
 		return out;
 	}
 
-	public OutPIRCOT runD(PreData predata, byte[][] u, byte[] v, Timer timer) {
-		timer.start(pid, M.online_comp);
+	public OutPIRCOT runD(PreData predata, byte[][] u, byte[] v) {
+		timer.start(M.online_comp);
 
 		int l = u.length;
 		byte[][] a = new byte[l][];
@@ -68,13 +63,13 @@ public class PIRCOT extends Protocol {
 			a[j] = predata.sscot_F_k.compute(a[j]);
 		}
 
-		timer.start(pid, M.online_write);
-		con2.write(pid, a);
-		timer.stop(pid, M.online_write);
+		timer.start(M.online_write);
+		con2.write(online_band, a);
+		timer.stop(M.online_write);
 
-		timer.start(pid, M.online_read);
-		int delta = con2.readInt(pid);
-		timer.stop(pid, M.online_read);
+		timer.start(M.online_read);
+		int delta = con2.readIntAndDec();
+		timer.stop(M.online_read);
 
 		int t_D = (predata.sscot_s_DE + delta) % l;
 
@@ -83,17 +78,17 @@ public class PIRCOT extends Protocol {
 		out.s_DE = predata.sscot_s_DE;
 		out.s_CD = predata.sscot_s_CD;
 
-		timer.stop(pid, M.online_comp);
+		timer.stop(M.online_comp);
 		return out;
 	}
 
-	public OutPIRCOT runC(PreData predata, Timer timer) {
-		timer.start(pid, M.online_comp);
+	public OutPIRCOT runC(PreData predata) {
+		timer.start(M.online_comp);
 
-		timer.start(pid, M.online_read);
-		byte[][] x = con1.readDoubleByteArray(pid);
-		byte[][] y = con2.readDoubleByteArray(pid);
-		timer.stop(pid, M.online_read);
+		timer.start(M.online_read);
+		byte[][] x = con1.readDoubleByteArrayAndDec();
+		byte[][] y = con2.readDoubleByteArrayAndDec();
+		timer.stop(M.online_read);
 
 		int l = x.length;
 		int count = 0;
@@ -112,25 +107,23 @@ public class PIRCOT extends Protocol {
 		int delta_D = (t_C - predata.sscot_s_CE + l) % l;
 		int delta_E = (t_C - predata.sscot_s_CD + l) % l;
 
-		timer.start(pid, M.online_write);
-		con2.write(pid, delta_D);
-		con1.write(pid, delta_E);
-		timer.stop(pid, M.online_write);
+		timer.start(M.online_write);
+		con2.write(online_band, delta_D);
+		con1.write(online_band, delta_E);
+		timer.stop(M.online_write);
 
 		OutPIRCOT out = new OutPIRCOT();
 		out.t_C = t_C;
 		out.s_CE = predata.sscot_s_CE;
 		out.s_CD = predata.sscot_s_CD;
 
-		timer.stop(pid, M.online_comp);
+		timer.stop(M.online_comp);
 		return out;
 	}
 
 	@Override
 	public void run(Party party, Metadata md, Forest[] forest) {
 
-		Timer timer = new Timer();
-
 		for (int j = 0; j < 100; j++) {
 			int n = 100;
 			int FN = 5;
@@ -148,16 +141,16 @@ public class PIRCOT extends Protocol {
 
 			if (party == Party.Eddie) {
 				con2.write(index);
-				presscot.runE(predata, n, timer);
-				output = runE(predata, a, v, timer);
+				presscot.runE(predata, n);
+				output = runE(predata, a, v);
 
 				con2.write(output.t_E);
 				con2.write(output.s_CE);
 				con2.write(output.s_DE);
 
 			} else if (party == Party.Debbie) {
-				presscot.runD(predata, n, timer);
-				output = runD(predata, b, new byte[FN], timer);
+				presscot.runD(predata, n);
+				output = runD(predata, b, new byte[FN]);
 
 				con2.write(output.t_D);
 				con2.write(output.s_DE);
@@ -165,8 +158,8 @@ public class PIRCOT extends Protocol {
 
 			} else if (party == Party.Charlie) {
 				index = con1.readInt();
-				presscot.runC(predata, timer);
-				output = runC(predata, timer);
+				presscot.runC(predata);
+				output = runC(predata);
 
 				int t_E = con1.readInt();
 				int s_CE = con1.readInt();
@@ -198,80 +191,6 @@ public class PIRCOT extends Protocol {
 		}
 	}
 
-	public void runE(PreData predata, byte[][] a, Timer timer) {
-		timer.start(pid, M.online_comp);
-
-		// step 1
-		int n = a.length;
-		byte[][] x = predata.sscot_r;
-		byte[][] v = new byte[n][];
-
-		for (int i = 0; i < n; i++) {
-			for (int j = 0; j < a[i].length; j++)
-				x[i][j] = (byte) (predata.sscot_r[i][j] ^ a[i][j]);
-
-			v[i] = predata.sscot_F_kprime.compute(x[i]);
-		}
-
-		timer.start(pid, M.online_write);
-		con2.write(pid, v);
-		timer.stop(pid, M.online_write);
-
-		timer.stop(pid, M.online_comp);
-	}
-
-	public void runD(PreData predata, byte[][] b, Timer timer) {
-		timer.start(pid, M.online_comp);
-
-		// step 2
-		int n = b.length;
-		byte[][] y = predata.sscot_r;
-		byte[][] w = new byte[n][];
-
-		for (int i = 0; i < n; i++) {
-			for (int j = 0; j < b[i].length; j++)
-				y[i][j] = (byte) (predata.sscot_r[i][j] ^ b[i][j]);
-
-			w[i] = predata.sscot_F_kprime.compute(y[i]);
-		}
-
-		timer.start(pid, M.online_write);
-		con2.write(pid, w);
-		timer.stop(pid, M.online_write);
-
-		timer.stop(pid, M.online_comp);
-	}
-
-	public OutSSCOT runC(Timer timer) {
-		timer.start(pid, M.online_comp);
-
-		// step 1
-		timer.start(pid, M.online_read);
-		byte[][] v = con1.readDoubleByteArray(pid);
-
-		// step 2
-		byte[][] w = con2.readDoubleByteArray(pid);
-		timer.stop(pid, M.online_read);
-
-		// step 3
-		int n = v.length;
-		OutSSCOT output = null;
-		int invariant = 0;
-
-		for (int i = 0; i < n; i++) {
-			if (Util.equal(v[i], w[i])) {
-				output = new OutSSCOT(i, null);
-				invariant++;
-			}
-		}
-
-		if (invariant != 1)
-			throw new SSCOTException("Invariant error: " + invariant);
-
-		timer.stop(pid, M.online_comp);
-		return output;
-	}
-
 	// for testing correctness
 	@Override
 	public void run(Party party, Metadata md, Forest forest) {

+ 4 - 4
src/pir/PIRRetrieve.java

@@ -357,7 +357,7 @@ public class PIRRetrieve extends Protocol {
 					ete.start();
 					OutPIRAccess out = this.runE(md, predata, tree_DE, tree_CE, Li, L, N, dN, timer);
 					ete.stop();
-					
+
 					out.j.t_D = con1.readInt();
 					out.j.t_C = con2.readInt();
 					out.X.CD = con1.read();
@@ -385,7 +385,7 @@ public class PIRRetrieve extends Protocol {
 					ete.start();
 					OutPIRAccess out = this.runD(md, predata, tree_DE, tree_CD, Li, L, N, dN, timer);
 					ete.stop();
-					
+
 					con1.write(out.j.t_D);
 					con1.write(out.X.CD);
 
@@ -393,7 +393,7 @@ public class PIRRetrieve extends Protocol {
 					ete.start();
 					OutPIRAccess out = this.runC(md, predata, tree_CD, tree_CE, Li, L, N, dN, timer);
 					ete.stop();
-					
+
 					con1.write(out.j.t_C);
 
 				} else {
@@ -412,7 +412,7 @@ public class PIRRetrieve extends Protocol {
 
 		// timer.divideBy(iterations - reset);
 		// timer.print();
-		
+
 		System.out.println(ete.toMS());
 
 		sanityCheck();

+ 27 - 33
src/pir/precomputation/PrePIRCOT.java

@@ -9,21 +9,17 @@ import protocols.Protocol;
 import protocols.struct.Party;
 import protocols.struct.PreData;
 import util.M;
-import util.P;
-import util.Timer;
 
 public class PrePIRCOT extends Protocol {
 
-	private int pid = P.COT;
-
 	public PrePIRCOT(Communication con1, Communication con2) {
 		super(con1, con2);
 	}
 
 	// TODO: change PRF output bits to max(32, N)
 
-	public void runE(PreData predata, int l, Timer timer) {
-		timer.start(pid, M.offline_comp);
+	public void runE(PreData predata, int l) {
+		timer.start(M.offline_comp);
 
 		predata.sscot_k = PRF.generateKey(Crypto.sr);
 		predata.sscot_r = new byte[l][];
@@ -34,49 +30,49 @@ public class PrePIRCOT extends Protocol {
 		predata.sscot_s_DE = Crypto.sr.nextInt(l);
 		predata.sscot_s_CE = Crypto.sr.nextInt(l);
 
-		timer.start(pid, M.offline_write);
-		con1.write(predata.sscot_k);
-		con1.write(predata.sscot_r);
-		con1.write(predata.sscot_s_DE);
-		con2.write(predata.sscot_s_CE);
-		timer.stop(pid, M.offline_write);
+		timer.start(M.offline_write);
+		con1.write(offline_band, predata.sscot_k);
+		con1.write(offline_band, predata.sscot_r);
+		con1.write(offline_band, predata.sscot_s_DE);
+		con2.write(offline_band, predata.sscot_s_CE);
+		timer.stop(M.offline_write);
 
 		predata.sscot_F_k = new PRF(Crypto.secParam);
 		predata.sscot_F_k.init(predata.sscot_k);
 
-		timer.stop(pid, M.offline_comp);
+		timer.stop(M.offline_comp);
 	}
 
-	public void runD(PreData predata, int l, Timer timer) {
-		timer.start(pid, M.offline_comp);
+	public void runD(PreData predata, int l) {
+		timer.start(M.offline_comp);
 
 		predata.sscot_s_CD = Crypto.sr.nextInt(l);
 
-		timer.start(pid, M.offline_write);
-		con2.write(predata.sscot_s_CD);
-		timer.stop(pid, M.offline_write);
+		timer.start(M.offline_write);
+		con2.write(offline_band, predata.sscot_s_CD);
+		timer.stop(M.offline_write);
 
-		timer.start(pid, M.offline_read);
-		predata.sscot_k = con1.read();
-		predata.sscot_r = con1.readDoubleByteArray();
-		predata.sscot_s_DE = con1.readInt();
-		timer.stop(pid, M.offline_read);
+		timer.start(M.offline_read);
+		predata.sscot_k = con1.readAndDec();
+		predata.sscot_r = con1.readDoubleByteArrayAndDec();
+		predata.sscot_s_DE = con1.readIntAndDec();
+		timer.stop(M.offline_read);
 
 		predata.sscot_F_k = new PRF(Crypto.secParam);
 		predata.sscot_F_k.init(predata.sscot_k);
 
-		timer.stop(pid, M.offline_comp);
+		timer.stop(M.offline_comp);
 	}
 
-	public void runC(PreData predata, Timer timer) {
-		timer.start(pid, M.offline_comp);
+	public void runC(PreData predata) {
+		timer.start(M.offline_comp);
 
-		timer.start(pid, M.offline_read);
-		predata.sscot_s_CE = con1.readInt();
-		predata.sscot_s_CD = con2.readInt();
-		timer.stop(pid, M.offline_read);
+		timer.start(M.offline_read);
+		predata.sscot_s_CE = con1.readIntAndDec();
+		predata.sscot_s_CD = con2.readIntAndDec();
+		timer.stop(M.offline_read);
 
-		timer.stop(pid, M.offline_comp);
+		timer.stop(M.offline_comp);
 	}
 
 	@Override
@@ -85,7 +81,5 @@ public class PrePIRCOT extends Protocol {
 
 	@Override
 	public void run(Party party, Metadata md, Forest[] forest) {
-		// TODO Auto-generated method stub
-
 	}
 }

+ 8 - 0
src/protocols/Protocol.java

@@ -7,10 +7,15 @@ import oram.Forest;
 import oram.Global;
 import oram.Metadata;
 import protocols.struct.Party;
+import util.Bandwidth;
+import util.Timer;
 
 public abstract class Protocol {
 	protected Communication con1;
 	protected Communication con2;
+	public Timer timer;
+	public Bandwidth online_band;
+	public Bandwidth offline_band;
 
 	/*
 	 * Connections are alphabetized so:
@@ -24,6 +29,9 @@ public abstract class Protocol {
 	public Protocol(Communication con1, Communication con2) {
 		this.con1 = con1;
 		this.con2 = con2;
+		timer = new Timer();
+		online_band = new Bandwidth("Online");
+		offline_band = new Bandwidth("Offline");
 	}
 
 	private static final boolean ENSURE_SANITY = true;

+ 2 - 2
src/util/Bandwidth.java

@@ -5,7 +5,7 @@ import exceptions.BandwidthException;
 public class Bandwidth {
 
 	public String task;
-	public int bandwidth;
+	public long bandwidth;
 
 	public Bandwidth(String t) {
 		task = t;
@@ -21,7 +21,7 @@ public class Bandwidth {
 		bandwidth = 0;
 	}
 
-	public void add(int n) {
+	public void add(long n) {
 		bandwidth += n;
 	}
 

+ 19 - 24
src/util/Timer.java

@@ -5,72 +5,68 @@ import java.util.Stack;
 import exceptions.TimerException;
 
 public class Timer {
-	StopWatch[][] watches;
+	StopWatch[] watches;
 	Stack<StopWatch> stack;
 
 	public Timer() {
-		watches = new StopWatch[P.size][M.size];
-		for (int i = 0; i < P.size; i++)
-			for (int j = 0; j < M.size; j++)
-				watches[i][j] = new StopWatch(P.names[i] + "_" + M.names[j]);
+		watches = new StopWatch[M.size];
+		for (int j = 0; j < M.size; j++)
+			watches[j] = new StopWatch(M.names[j]);
 		stack = new Stack<StopWatch>();
 	}
 
-	public Timer(StopWatch[][] sws) {
+	public Timer(StopWatch[] sws) {
 		watches = sws;
 		stack = new Stack<StopWatch>();
 	}
 
-	public void start(int p, int m) {
+	public void start(int m) {
 		if (!stack.empty()) {
-			if (stack.peek() == watches[p][m])
+			if (stack.peek() == watches[m])
 				throw new TimerException("Stopwatch already added to stack");
 			stack.peek().stop();
 		}
-		stack.push(watches[p][m]).start();
+		stack.push(watches[m]).start();
 
 	}
 
-	public void stop(int p, int m) {
+	public void stop(int m) {
 		if (stack.empty())
-			throw new TimerException("No stopwatch found");
+			throw new TimerException("No stopwatch running");
+		if (stack.peek() != watches[m])
+			throw new TimerException("Wrong Stopwatch to stop");
 		stack.pop().stop();
 		if (!stack.empty())
 			stack.peek().start();
-
 	}
 
 	public void reset() {
 		if (!stack.empty())
 			throw new TimerException("Stack not empty");
 		for (int i = 0; i < watches.length; i++)
-			for (int j = 0; j < watches[i].length; j++)
-				watches[i][j].reset();
+			watches[i].reset();
 	}
 
 	public void print() {
 		if (!stack.empty())
 			throw new TimerException("Stack not empty");
 		for (int i = 0; i < watches.length; i++)
-			for (int j = 0; j < watches[i].length; j++)
-				System.out.println(watches[i][j].toMS());
+			System.out.println(watches[i].toMS());
 	}
 
 	public void noPrePrint() {
 		if (!stack.empty())
 			throw new TimerException("Stack not empty");
 		for (int i = 0; i < watches.length; i++)
-			for (int j = 0; j < watches[i].length; j++)
-				System.out.println(watches[i][j].noPreToMS());
+			System.out.println(watches[i].noPreToMS());
 	}
 
 	public Timer divideBy(int n) {
 		if (!stack.empty())
 			throw new TimerException("Stack not empty");
-		StopWatch[][] sws = new StopWatch[P.size][M.size];
+		StopWatch[] sws = new StopWatch[M.size];
 		for (int i = 0; i < watches.length; i++)
-			for (int j = 0; j < watches[i].length; j++)
-				sws[i][j] = watches[i][j].divideBy(n);
+			sws[i] = watches[i].divideBy(n);
 		return new Timer(sws);
 	}
 
@@ -78,10 +74,9 @@ public class Timer {
 		if (!stack.empty() || !t.stack.empty())
 			throw new TimerException("Stack not empty");
 
-		StopWatch[][] sws = new StopWatch[P.size][M.size];
+		StopWatch[] sws = new StopWatch[M.size];
 		for (int i = 0; i < watches.length; i++)
-			for (int j = 0; j < watches[i].length; j++)
-				sws[i][j] = watches[i][j].add(t.watches[i][j]);
+			sws[i] = watches[i].add(t.watches[i]);
 		return new Timer(sws);
 	}
 }