Browse Source

added simple AES encryption on online communication; change scripts to run sequential program by default; fixed few bandwidth counting bugs

Boyoung- 7 years ago
parent
commit
8d159d1353

BIN
key/aes.key


+ 1 - 1
scripts/aws_charlie.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
-java -cp "${DIR}/../bin:${DIR}/../lib/*" ui.CLI -protocol rtv -eddie_ip 52.42.255.187 -debbie_ip 52.38.255.37 charlie -pipeline
+java -cp "${DIR}/../bin:${DIR}/../lib/*" ui.CLI -protocol rtv -eddie_ip 52.42.255.187 -debbie_ip 52.38.255.37 charlie

+ 1 - 1
scripts/aws_debbie.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
-java -cp "${DIR}/../bin:${DIR}/../lib/*" ui.CLI -protocol rtv -eddie_ip 52.42.255.187 debbie -pipeline
+java -cp "${DIR}/../bin:${DIR}/../lib/*" ui.CLI -protocol rtv -eddie_ip 52.42.255.187 debbie

+ 1 - 1
scripts/aws_eddie.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
 DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
-java -cp "${DIR}/../bin:${DIR}/../lib/*" ui.CLI -protocol rtv eddie -pipeline
+java -cp "${DIR}/../bin:${DIR}/../lib/*" ui.CLI -protocol rtv eddie

+ 89 - 2
src/communication/Communication.java

@@ -16,10 +16,12 @@ import java.util.concurrent.LinkedBlockingQueue;
 
 import com.oblivm.backend.gc.GCSignal;
 
+import crypto.SimpleAES;
 import oram.Bucket;
 import oram.Tuple;
 import util.Bandwidth;
 import util.P;
+import util.StopWatch;
 import util.Util;
 
 /**
@@ -70,7 +72,11 @@ public class Communication {
 	protected InetSocketAddress mAddress;
 
 	public Bandwidth[] bandwidth;
-	public boolean bandSwitch = true;
+	public boolean bandSwitch = true; // TODO: change this to static (or to
+										// Global)
+
+	private static SimpleAES aes = new SimpleAES();
+	public StopWatch comEnc = new StopWatch("CE_online_comp");
 
 	public Communication() {
 		mState = STATE_NONE;
@@ -332,7 +338,12 @@ public class Communication {
 	}
 
 	public void write(int pid, byte[] out) {
+		comEnc.start();
+		out = aes.encrypt(out);
+		comEnc.stop();
+
 		write(out);
+
 		if (bandSwitch)
 			bandwidth[pid].add(out.length);
 	}
@@ -440,7 +451,7 @@ public class Communication {
 	}
 
 	public void write(int pid, Bucket[] arr) {
-		write(ComUtil.serialize(arr));
+		write(pid, ComUtil.serialize(arr));
 	}
 
 	public void write(GCSignal key) {
@@ -544,6 +555,14 @@ public class Communication {
 		return readMessage;
 	}
 
+	public byte[] read(int pid) {
+		byte[] msg = read();
+		comEnc.start();
+		msg = aes.decrypt(msg);
+		comEnc.stop();
+		return msg;
+	}
+
 	/**
 	 * Read a specific number of bytes from the ConnectedThread in an
 	 * unsynchronized manner Note, this is a blocking call
@@ -580,70 +599,138 @@ public class Communication {
 		return new BigInteger(read());
 	}
 
+	public BigInteger readBigInteger(int pid) {
+		return new BigInteger(read(pid));
+	}
+
 	public int readInt() {
 		return readBigInteger().intValue();
 	}
 
+	public int readInt(int pid) {
+		return readBigInteger(pid).intValue();
+	}
+
 	public long readLong() {
 		return readBigInteger().longValue();
 	}
 
+	public long readLong(int pid) {
+		return readBigInteger(pid).longValue();
+	}
+
 	public byte[][] readDoubleByteArray() {
 		return ComUtil.toDoubleByteArray(read());
 	}
 
+	public byte[][] readDoubleByteArray(int pid) {
+		return ComUtil.toDoubleByteArray(read(pid));
+	}
+
 	public byte[][][] readTripleByteArray() {
 		return ComUtil.toTripleByteArray(read());
 	}
 
+	public byte[][][] readTripleByteArray(int pid) {
+		return ComUtil.toTripleByteArray(read(pid));
+	}
+
 	public int[] readIntArray() {
 		return ComUtil.toIntArray(read());
 	}
 
+	public int[] readIntArray(int pid) {
+		return ComUtil.toIntArray(read(pid));
+	}
+
 	public int[][] readDoubleIntArray() {
 		return ComUtil.toDoubleIntArray(read());
 	}
 
+	public int[][] readDoubleIntArray(int pid) {
+		return ComUtil.toDoubleIntArray(read(pid));
+	}
+
 	public Tuple readTuple() {
 		return ComUtil.toTuple(read());
 	}
 
+	public Tuple readTuple(int pid) {
+		return ComUtil.toTuple(read(pid));
+	}
+
 	public Tuple[] readTupleArray() {
 		return ComUtil.toTupleArray(read());
 	}
 
+	public Tuple[] readTupleArray(int pid) {
+		return ComUtil.toTupleArray(read(pid));
+	}
+
 	public Bucket readBucket() {
 		return new Bucket(readTupleArray());
 	}
 
+	public Bucket readBucket(int pid) {
+		return new Bucket(readTupleArray(pid));
+	}
+
 	public Bucket[] readBucketArray() {
 		return ComUtil.toBucketArray(read());
 	}
 
+	public Bucket[] readBucketArray(int pid) {
+		return ComUtil.toBucketArray(read(pid));
+	}
+
 	public GCSignal readGCSignal() {
 		return new GCSignal(read());
 	}
 
+	public GCSignal readGCSignal(int pid) {
+		return new GCSignal(read(pid));
+	}
+
 	public GCSignal[] readGCSignalArray() {
 		return ComUtil.toGCSignalArray(read());
 	}
 
+	public GCSignal[] readGCSignalArray(int pid) {
+		return ComUtil.toGCSignalArray(read(pid));
+	}
+
 	public GCSignal[][] readDoubleGCSignalArray() {
 		return ComUtil.toDoubleGCSignalArray(read());
 	}
 
+	public GCSignal[][] readDoubleGCSignalArray(int pid) {
+		return ComUtil.toDoubleGCSignalArray(read(pid));
+	}
+
 	public GCSignal[][][] readTripleGCSignalArray() {
 		return ComUtil.toTripleGCSignalArray(read());
 	}
 
+	public GCSignal[][][] readTripleGCSignalArray(int pid) {
+		return ComUtil.toTripleGCSignalArray(read(pid));
+	}
+
 	public GCSignal[][][][] readQuadGCSignalArray() {
 		return ComUtil.toQuadGCSignalArray(read());
 	}
 
+	public GCSignal[][][][] readQuadGCSignalArray(int pid) {
+		return ComUtil.toQuadGCSignalArray(read(pid));
+	}
+
 	public ArrayList<byte[]> readArrayList() {
 		return ComUtil.toArrayList(read());
 	}
 
+	public ArrayList<byte[]> readArrayList(int pid) {
+		return ComUtil.toArrayList(read(pid));
+	}
+
 	/**
 	 * This thread runs while listening for incoming connections. It behaves
 	 * like a server-side client. It runs until a connection is accepted (or

+ 117 - 0
src/crypto/SimpleAES.java

@@ -0,0 +1,117 @@
+package crypto;
+
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.math.BigInteger;
+import java.security.InvalidKeyException;
+import java.security.NoSuchAlgorithmException;
+import java.util.Random;
+
+import javax.crypto.BadPaddingException;
+import javax.crypto.Cipher;
+import javax.crypto.IllegalBlockSizeException;
+import javax.crypto.NoSuchPaddingException;
+import javax.crypto.spec.SecretKeySpec;
+
+public class SimpleAES {
+	private Cipher cipherEnc;
+	private Cipher cipherDec;
+
+	public SimpleAES() {
+		SecretKeySpec skey = readKey();
+		try {
+			cipherEnc = Cipher.getInstance("AES/ECB/PKCS5Padding");
+			cipherDec = Cipher.getInstance("AES/ECB/PKCS5Padding");
+			cipherEnc.init(Cipher.ENCRYPT_MODE, skey);
+			cipherDec.init(Cipher.DECRYPT_MODE, skey);
+		} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException e) {
+			e.printStackTrace();
+		}
+	}
+
+	public synchronized byte[] encrypt(byte[] in) {
+		byte[] out = null;
+		try {
+			out = cipherEnc.doFinal(in);
+		} catch (IllegalBlockSizeException | BadPaddingException e) {
+			e.printStackTrace();
+		}
+		return out;
+	}
+
+	public synchronized byte[] decrypt(byte[] in) {
+		byte[] out = null;
+		try {
+			out = cipherDec.doFinal(in);
+		} catch (IllegalBlockSizeException | BadPaddingException e) {
+			e.printStackTrace();
+		}
+		return out;
+	}
+
+	public static void generateKey(Random rand) {
+		byte[] key = new byte[16];
+		rand.nextBytes(key);
+		SecretKeySpec skey = new SecretKeySpec(key, "AES");
+		FileOutputStream fos = null;
+		ObjectOutputStream oos = null;
+		try {
+			fos = new FileOutputStream("key/aes.key");
+			oos = new ObjectOutputStream(fos);
+			oos.writeObject(skey);
+		} catch (IOException e) {
+			e.printStackTrace();
+		} finally {
+			if (oos != null) {
+				try {
+					oos.close();
+				} catch (IOException e) {
+					e.printStackTrace();
+				}
+			}
+		}
+	}
+
+	public SecretKeySpec readKey() {
+		FileInputStream fis = null;
+		ObjectInputStream ois = null;
+		SecretKeySpec skey = null;
+		try {
+			fis = new FileInputStream("key/aes.key");
+			ois = new ObjectInputStream(fis);
+			skey = (SecretKeySpec) ois.readObject();
+		} catch (IOException | ClassNotFoundException e) {
+			e.printStackTrace();
+		} finally {
+			if (ois != null) {
+				try {
+					ois.close();
+				} catch (IOException e) {
+					e.printStackTrace();
+				}
+			}
+		}
+		return skey;
+	}
+
+	// test
+	public static void main(String[] args) {
+		SimpleAES aes = new SimpleAES();
+		byte[] plain = new byte[10240 * 8];
+		Crypto.sr.nextBytes(plain);
+		byte[] enc = aes.encrypt(plain);
+		byte[] tmp = new byte[plain.length];
+		Crypto.sr.nextBytes(tmp);
+		aes.decrypt(aes.encrypt(tmp));
+		byte[] dec = aes.decrypt(enc);
+		long in = new BigInteger(plain).longValue();
+		long cipher = new BigInteger(enc).longValue();
+		long out = new BigInteger(dec).longValue();
+		System.out.println(in != cipher);
+		System.out.println(in == out);
+		System.out.println(plain.length == dec.length);
+	}
+}

+ 1 - 0
src/oram/Global.java

@@ -4,4 +4,5 @@ public class Global {
 
 	public static boolean cheat = true;
 	public static boolean pipeline = false;
+
 }

+ 11 - 11
src/protocols/Access.java

@@ -41,7 +41,7 @@ public class Access extends Protocol {
 		byte[] Li = new byte[0];
 		timer.start(pid, M.online_read);
 		if (OTi.getTreeIndex() > 0)
-			Li = con2.read();
+			Li = con2.read(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 1
@@ -105,7 +105,7 @@ public class Access extends Protocol {
 		byte[] Li = new byte[0];
 		timer.start(pid, M.online_read);
 		if (OTi.getTreeIndex() > 0)
-			Li = con2.read();
+			Li = con2.read(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 1
@@ -152,15 +152,15 @@ public class Access extends Protocol {
 		// step 0: send Li to E and D
 		timer.start(pid, M.online_write);
 		if (treeIndex > 0) {
-			con1.write(Li);
-			con2.write(Li);
+			con1.write(pid, Li);
+			con2.write(pid, Li);
 		}
 		timer.stop(pid, M.online_write);
 
 		// step 2
 		timer.start(pid, M.online_read);
-		Tuple[] pathTuples = con2.readTupleArray();
-		byte[] Ni = con2.read();
+		Tuple[] pathTuples = con2.readTupleArray(pid);
+		byte[] Ni = con2.read(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 3
@@ -215,7 +215,7 @@ public class Access extends Protocol {
 		byte[] Li = new byte[0];
 		timer.start(pid, M.online_read);
 		if (OTi.getTreeIndex() > 0)
-			Li = con2.read();
+			Li = con2.read(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 1
@@ -244,7 +244,7 @@ public class Access extends Protocol {
 		byte[] Li = new byte[0];
 		timer.start(pid, M.online_read);
 		if (OTi.getTreeIndex() > 0)
-			Li = con2.read();
+			Li = con2.read(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 1
@@ -266,14 +266,14 @@ public class Access extends Protocol {
 		// step 0: send Li to E and D
 		timer.start(pid, M.online_write);
 		if (treeIndex > 0) {
-			con1.write(Li);
-			con2.write(Li);
+			con1.write(pid, Li);
+			con2.write(pid, Li);
 		}
 		timer.stop(pid, M.online_write);
 
 		// step 2
 		timer.start(pid, M.online_read);
-		Tuple[] pathTuples = con2.readTupleArray();
+		Tuple[] pathTuples = con2.readTupleArray(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 5

+ 8 - 8
src/protocols/Eviction.java

@@ -124,7 +124,7 @@ public class Eviction extends Protocol {
 
 		if (firstTree) {
 			timer.start(pid, M.online_read);
-			Tuple[] originalPath = con2.readTupleArray();
+			Tuple[] originalPath = con2.readTupleArray(pid);
 			timer.stop(pid, M.online_read);
 
 			OTi.setBucketsOnPath(new BigInteger(1, Li).longValue(), new Bucket[] { new Bucket(originalPath) });
@@ -134,13 +134,13 @@ public class Eviction extends Protocol {
 		}
 
 		timer.start(pid, M.online_read);
-		GCSignal[] LiInputKeys = con1.readGCSignalArray();
-		GCSignal[][] E_feInputKeys = con1.readDoubleGCSignalArray();
-		GCSignal[][][] E_labelInputKeys = con1.readTripleGCSignalArray();
-		GCSignal[][] deltaInputKeys = con1.readDoubleGCSignalArray();
+		GCSignal[] LiInputKeys = con1.readGCSignalArray(pid);
+		GCSignal[][] E_feInputKeys = con1.readDoubleGCSignalArray(pid);
+		GCSignal[][][] E_labelInputKeys = con1.readTripleGCSignalArray(pid);
+		GCSignal[][] deltaInputKeys = con1.readDoubleGCSignalArray(pid);
 
-		GCSignal[][] C_feInputKeys = con2.readDoubleGCSignalArray();
-		GCSignal[][][] C_labelInputKeys = con2.readTripleGCSignalArray();
+		GCSignal[][] C_feInputKeys = con2.readDoubleGCSignalArray(pid);
+		GCSignal[][][] C_labelInputKeys = con2.readTripleGCSignalArray(pid);
 		timer.stop(pid, M.online_read);
 
 		int w = OTi.getW();
@@ -168,7 +168,7 @@ public class Eviction extends Protocol {
 		ssxot.runD(predata, evict, timer);
 
 		timer.start(pid, M.online_read);
-		Bucket[] pathBuckets = con2.readBucketArray();
+		Bucket[] pathBuckets = con2.readBucketArray(pid);
 		timer.stop(pid, M.online_read);
 
 		OTi.setBucketsOnPath(new BigInteger(1, Li).longValue(), pathBuckets);

+ 2 - 2
src/protocols/PermuteIndex.java

@@ -39,7 +39,7 @@ public class PermuteIndex extends Protocol {
 		timer.stop(pid, M.online_write);
 
 		timer.start(pid, M.online_read);
-		byte[][] g = con2.readDoubleByteArray();
+		byte[][] g = con2.readDoubleByteArray(pid);
 		timer.stop(pid, M.online_read);
 
 		ti = Util.xor(predata.pi_a, g);
@@ -60,7 +60,7 @@ public class PermuteIndex extends Protocol {
 		timer.start(pid, M.online_comp);
 
 		timer.start(pid, M.online_read);
-		byte[][] z = con2.readDoubleByteArray();
+		byte[][] z = con2.readDoubleByteArray(pid);
 		timer.stop(pid, M.online_read);
 
 		z = Util.xor(z, predata.pi_r);

+ 3 - 3
src/protocols/PermuteTarget.java

@@ -61,7 +61,7 @@ public class PermuteTarget extends Protocol {
 		timer.stop(pid, M.online_write);
 
 		timer.start(pid, M.online_read);
-		byte[][] g = con2.readDoubleByteArray();
+		byte[][] g = con2.readDoubleByteArray(pid);
 		timer.stop(pid, M.online_read);
 
 		target = Util.xor(predata.pt_a, g);
@@ -82,8 +82,8 @@ public class PermuteTarget extends Protocol {
 
 		// PermuteTargetII
 		timer.start(pid, M.online_read);
-		byte[][] z = con2.readDoubleByteArray();
-		int[] I = con2.readIntArray();
+		byte[][] z = con2.readDoubleByteArray(pid);
+		int[] I = con2.readIntArray(pid);
 		timer.stop(pid, M.online_read);
 
 		byte[][] mk = new byte[z.length][];

+ 1 - 1
src/protocols/PostProcessT.java

@@ -42,7 +42,7 @@ public class PostProcessT extends Protocol {
 
 		// step 1
 		timer.start(pid, M.online_read);
-		int delta = con2.readInt();
+		int delta = con2.readInt(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 3

+ 1 - 1
src/protocols/Reshuffle.java

@@ -37,7 +37,7 @@ public class Reshuffle extends Protocol {
 
 		// step 1
 		timer.start(pid, M.online_read);
-		Tuple[] z = con2.readTupleArray();
+		Tuple[] z = con2.readTupleArray(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 2

+ 7 - 0
src/protocols/Retrieve.java

@@ -206,6 +206,7 @@ public class Retrieve extends Protocol {
 				System.out.println("N=" + BigInteger.valueOf(N).toString(2));
 
 				System.out.print("Precomputation... ");
+
 				PreData[][] predata = new PreData[numTrees][2];
 				PreRetrieve preretrieve = new PreRetrieve(con1, con2);
 				for (int ti = 0; ti < numTrees; ti++) {
@@ -345,6 +346,12 @@ public class Retrieve extends Protocol {
 		sum.noPrePrint();
 		System.out.println();
 
+		StopWatch comEnc = new StopWatch("CE_online_comp");
+		for (int i = 0; i < cons1.length; i++)
+			comEnc = comEnc.add(cons1[i].comEnc.add(cons2[i].comEnc));
+		System.out.println(comEnc.noPreToMS());
+		System.out.println();
+
 		if (Global.pipeline)
 			ete_on.elapsedCPU = 0;
 		System.out.println(ete_on.noPreToMS());

+ 4 - 4
src/protocols/SSCOT.java

@@ -81,12 +81,12 @@ public class SSCOT extends Protocol {
 
 		// step 1
 		timer.start(pid, M.online_read);
-		byte[][] e = con1.readDoubleByteArray();
-		byte[][] v = con1.readDoubleByteArray();
+		byte[][] e = con1.readDoubleByteArray(pid);
+		byte[][] v = con1.readDoubleByteArray(pid);
 
 		// step 2
-		byte[][] p = con2.readDoubleByteArray();
-		byte[][] w = con2.readDoubleByteArray();
+		byte[][] p = con2.readDoubleByteArray(pid);
+		byte[][] w = con2.readDoubleByteArray(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 3

+ 4 - 4
src/protocols/SSIOT.java

@@ -76,12 +76,12 @@ public class SSIOT extends Protocol {
 
 		// step 1
 		timer.start(pid, M.online_read);
-		byte[][] e = con1.readDoubleByteArray();
-		byte[][] v = con1.readDoubleByteArray();
+		byte[][] e = con1.readDoubleByteArray(pid);
+		byte[][] v = con1.readDoubleByteArray(pid);
 
 		// step 2
-		byte[] p = con2.read();
-		byte[] w = con2.read();
+		byte[] p = con2.read(pid);
+		byte[] w = con2.read(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 3

+ 6 - 6
src/protocols/SSXOT.java

@@ -45,11 +45,11 @@ public class SSXOT extends Protocol {
 		timer.stop(pid, M.online_write);
 
 		timer.start(pid, M.online_read);
-		a = con2.readTupleArray();
+		a = con2.readTupleArray(pid);
 
 		// step 2
-		int[] j = con1.readIntArray();
-		Tuple[] p = con1.readTupleArray();
+		int[] j = con1.readIntArray(pid);
+		Tuple[] p = con1.readTupleArray(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 3
@@ -100,11 +100,11 @@ public class SSXOT extends Protocol {
 		timer.stop(pid, M.online_write);
 
 		timer.start(pid, M.online_read);
-		a = con1.readTupleArray();
+		a = con1.readTupleArray(pid);
 
 		// step 2
-		int[] j = con2.readIntArray();
-		Tuple[] p = con2.readTupleArray();
+		int[] j = con2.readIntArray(pid);
+		Tuple[] p = con2.readTupleArray(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 3

+ 6 - 6
src/protocols/UpdateRoot.java

@@ -66,12 +66,12 @@ public class UpdateRoot extends Protocol {
 
 		// step 1
 		timer.start(pid, M.online_read);
-		GCSignal[] j1InputKeys = con1.readGCSignalArray();
-		GCSignal[] LiInputKeys = con1.readGCSignalArray();
-		GCSignal[] E_feInputKeys = con1.readGCSignalArray();
-		GCSignal[][] E_labelInputKeys = con1.readDoubleGCSignalArray();
-		GCSignal[] C_feInputKeys = con2.readGCSignalArray();
-		GCSignal[][] C_labelInputKeys = con2.readDoubleGCSignalArray();
+		GCSignal[] j1InputKeys = con1.readGCSignalArray(pid);
+		GCSignal[] LiInputKeys = con1.readGCSignalArray(pid);
+		GCSignal[] E_feInputKeys = con1.readGCSignalArray(pid);
+		GCSignal[][] E_labelInputKeys = con1.readDoubleGCSignalArray(pid);
+		GCSignal[] C_feInputKeys = con2.readGCSignalArray(pid);
+		GCSignal[][] C_labelInputKeys = con2.readDoubleGCSignalArray(pid);
 		timer.stop(pid, M.online_read);
 
 		// step 2