Prechádzať zdrojové kódy

editting Access: able to reach step 3 to execute SSCOT

Boyoung- 8 rokov pred
rodič
commit
f8186310da

+ 41 - 0
src/communication/Communication.java

@@ -3,6 +3,7 @@ package communication;
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
+import java.io.Serializable;
 import java.io.StreamCorruptedException;
 import java.math.BigInteger;
 import java.net.InetSocketAddress;
@@ -14,6 +15,8 @@ import java.util.ArrayList;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.LinkedBlockingQueue;
 
+import org.apache.commons.lang3.SerializationUtils;
+
 import util.Util;
 
 /**
@@ -337,6 +340,22 @@ public class Communication {
 			write(out[i]);
 	}
 
+	public void write(int[] out) {
+		write(out.length);
+		for (int i = 0; i < out.length; i++)
+			write(out[i]);
+	}
+
+	public <T> void write(T out) {
+		write(SerializationUtils.serialize((Serializable) out));
+	}
+
+	public <T> void write(T[] out) {
+		write(out.length);
+		for (int i = 0; i < out.length; i++)
+			write(out[i]);
+	}
+
 	public static final Charset defaultCharset = Charset.forName("ASCII");
 
 	// TODO: Rather than having millions of write/read methods can we take
@@ -429,6 +448,28 @@ public class Communication {
 		return data;
 	}
 
+	public int[] readIntArray() {
+		int len = readInt();
+		int[] data = new int[len];
+		for (int i = 0; i < len; i++)
+			data[i] = readInt();
+		return data;
+	}
+
+	public <T> T readObject() {
+		T object = SerializationUtils.deserialize(read());
+		return object;
+	}
+
+	public <T> T[] readObjectArray() {
+		int len = readInt();
+		@SuppressWarnings("unchecked")
+		T[] data = (T[]) new Object[len];
+		for (int i = 0; i < len; i++)
+			data[i] = readObject();
+		return data;
+	}
+
 	/**
 	 * This thread runs while listening for incoming connections. It behaves
 	 * like a server-side client. It runs until a connection is accepted (or

+ 1 - 1
src/oram/Forest.java

@@ -154,7 +154,7 @@ public class Forest implements Serializable {
 					// if N is a new address, then find an unused leaf tuple
 					if (leafTupleIndex == null) {
 						do {
-							leafTupleIndex = Util.nextLong(Crypto.sr, (numBuckets / 2 + 1) * w);
+							leafTupleIndex = Util.nextLong((numBuckets / 2 + 1) * w, Crypto.sr);
 						} while (addrToTuple[i].containsValue(leafTupleIndex));
 						addrToTuple[i].put(N[i], leafTupleIndex);
 					}

+ 4 - 0
src/oram/Tree.java

@@ -150,6 +150,10 @@ public class Tree implements Serializable {
 		return stashSize;
 	}
 
+	public int getFBytes() {
+		return treeIndex == 0 ? 0 : 1;
+	}
+
 	public int getNBits() {
 		return nBits;
 	}

+ 129 - 0
src/protocols/Access.java

@@ -0,0 +1,129 @@
+package protocols;
+
+import java.math.BigInteger;
+import java.util.Arrays;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import communication.Communication;
+import exceptions.NoSuchPartyException;
+import oram.Bucket;
+import oram.Forest;
+import oram.Metadata;
+import oram.Tree;
+import oram.Tuple;
+import util.Util;
+
+public class Access extends Protocol {
+
+	public Access(Communication con1, Communication con2) {
+		super(con1, con2);
+	}
+
+	public void runE(PreData predata, Tree OTi, byte[] Li, byte[] Nip1, byte[] Ni, byte[] Nip1_pr) {
+		// step 1
+		Bucket[] pathBuckets = OTi.getBucketsOnPath(new BigInteger(1, Li).longValue());
+		// Object[] objArray = Util.permute(pathBuckets, predata.access_sigma);
+		// pathBuckets = Arrays.copyOf(objArray, objArray.length,
+		// Bucket[].class);
+		for (int i = 0; i < pathBuckets.length; i++) {
+			pathBuckets[i].setXor(predata.access_p[i]);
+		}
+
+		// step 3
+		int numTuples = OTi.getStashSize() + (pathBuckets.length - 1) * OTi.getW();
+		byte[][] a = new byte[numTuples][];
+		byte[][] m = new byte[numTuples][];
+		int tupleCnt = 0;
+		for (int i = 0; i < pathBuckets.length; i++)
+			for (int j = 0; j < pathBuckets[i].getNumTuples(); j++) {
+				Tuple tuple = pathBuckets[i].getTuple(j);
+				a[tupleCnt] = ArrayUtils.addAll(tuple.getF(), tuple.getN());
+				m[tupleCnt] = tuple.getA();
+				tupleCnt++;
+			}
+		for (int i = 0; i < numTuples; i++) {
+			for (int j = 0; j < Ni.length; j++)
+				a[i][a[i].length - 1 - j] ^= Ni[Ni.length - 1 - j];
+		}
+
+		SSCOT sscot = new SSCOT(con1, con2);
+		sscot.runE(predata, m, a);
+	}
+
+	public void runD(PreData predata, Tree OTi, byte[] Li, byte[] Nip1, byte[] Ni, byte[] Nip1_pr) {
+		// step 1
+		Bucket[] pathBuckets = OTi.getBucketsOnPath(new BigInteger(1, Li).longValue());
+		// Object[] objArray = Util.permute(pathBuckets, predata.access_sigma);
+		// pathBuckets = Arrays.copyOf(objArray, objArray.length,
+		// Bucket[].class);
+		for (int i = 0; i < pathBuckets.length; i++) {
+			pathBuckets[i].setXor(predata.access_p[i]);
+		}
+
+		// step 2
+		con2.write(pathBuckets);
+		con2.write(Nip1);
+
+		// step 3
+		int numTuples = OTi.getStashSize() + (pathBuckets.length - 1) * OTi.getW();
+		byte[][] b = new byte[numTuples][];
+		int tupleCnt = 0;
+		for (int i = 0; i < pathBuckets.length; i++)
+			for (int j = 0; j < pathBuckets[i].getNumTuples(); j++) {
+				Tuple tuple = pathBuckets[i].getTuple(j);
+				b[tupleCnt] = ArrayUtils.addAll(tuple.getF(), tuple.getN());
+				tupleCnt++;
+			}
+		for (int i = 0; i < numTuples; i++) {
+			b[i][0] ^= 1;
+			for (int j = 0; j < Ni.length; j++)
+				b[i][b[i].length - 1 - j] ^= Ni[Ni.length - 1 - j];
+		}
+
+		SSCOT sscot = new SSCOT(con1, con2);
+		sscot.runD(predata, b);
+	}
+
+	public void runC() {
+		// step 2
+		Object[] objArray = con2.readObjectArray();
+		Bucket[] pathBuckets = Arrays.copyOf(objArray, objArray.length, Bucket[].class);
+		byte[] Nip1 = con2.read();
+
+		// step 3
+		SSCOT sscot = new SSCOT(con1, con2);
+		sscot.runC();
+	}
+
+	@Override
+	public void run(Party party, Metadata md, Forest forest) {
+		// for (int j = 0; j < 100; j++) {
+		PreData predata = new PreData();
+		PreAccess preaccess = new PreAccess(con1, con2);
+		int treeIndex = 2;
+		Tree tree = null;
+		int numBuckets = 0;
+		if (forest != null) {
+			tree = forest.getTree(treeIndex);
+			numBuckets = tree.getD();
+		}
+		byte[] Li = new byte[] { 0 };
+		if (party == Party.Eddie) {
+			preaccess.runE(predata, tree, numBuckets);
+			runE(predata, tree, Li, Li, Li, Li);
+
+		} else if (party == Party.Debbie) {
+			preaccess.runD(predata);
+			runD(predata, tree, Li, Li, Li, Li);
+
+		} else if (party == Party.Charlie) {
+			preaccess.runC();
+			runC();
+
+		} else {
+			throw new NoSuchPartyException(party + "");
+		}
+		// }
+	}
+}

+ 52 - 0
src/protocols/PreAccess.java

@@ -0,0 +1,52 @@
+package protocols;
+
+import java.util.Arrays;
+
+import communication.Communication;
+import crypto.Crypto;
+import oram.Bucket;
+import oram.Forest;
+import oram.Metadata;
+import oram.Tree;
+import util.Util;
+
+public class PreAccess extends Protocol {
+	public PreAccess(Communication con1, Communication con2) {
+		super(con1, con2);
+	}
+
+	public void runE(PreData predata, Tree OT, int numBuckets) {
+		int numTuples = OT.getStashSize() + (numBuckets - 1) * OT.getW();
+		PreSSCOT presscot = new PreSSCOT(con1, con2);
+		presscot.runE(predata, numTuples);
+
+		predata.access_sigma = Util.randomPermutation(numBuckets, Crypto.sr);
+
+		int[] tupleParam = new int[] { OT.getFBytes(), OT.getNBytes(), OT.getLBytes(), OT.getABytes() };
+		predata.access_p = new Bucket[numBuckets];
+		predata.access_p[0] = new Bucket(OT.getStashSize(), tupleParam, Crypto.sr);
+		for (int i = 1; i < numBuckets; i++)
+			predata.access_p[i] = new Bucket(OT.getW(), tupleParam, Crypto.sr);
+
+		con1.write(predata.access_sigma);
+		con1.write(predata.access_p);
+	}
+
+	public void runD(PreData predata) {
+		PreSSCOT presscot = new PreSSCOT(con1, con2);
+		presscot.runD(predata);
+
+		predata.access_sigma = con1.readIntArray();
+		Object[] objArray = con1.readObjectArray();
+		predata.access_p = Arrays.copyOf(objArray, objArray.length, Bucket[].class);
+	}
+
+	public void runC() {
+		PreSSCOT presscot = new PreSSCOT(con1, con2);
+		presscot.runC();
+	}
+
+	@Override
+	public void run(Party party, Metadata md, Forest forest) {
+	}
+}

+ 7 - 0
src/protocols/PreData.java

@@ -1,7 +1,14 @@
 package protocols;
 
+import oram.Bucket;
+
 public class PreData {
 	public byte[] sscot_k;
 	public byte[] sscot_kprime;
 	public byte[][] sscot_r;
+
+	public int[] access_sigma;
+	public int[] access_delta;
+	public int[] access_rho;
+	public Bucket[] access_p;
 }

+ 3 - 0
src/ui/CLI.java

@@ -15,6 +15,7 @@ import exceptions.NoSuchPartyException;
 import protocols.Party;
 import protocols.Protocol;
 import protocols.SSCOT;
+import protocols.Access;
 
 public class CLI {
 	public static final int DEFAULT_PORT = 8000;
@@ -67,6 +68,8 @@ public class CLI {
 
 		if (protocol.equals("sscot")) {
 			operation = SSCOT.class;
+		} else if (protocol.equals("access")) {
+			operation = Access.class;
 		} else {
 			System.out.println("Protocol " + protocol + " not supported");
 			System.exit(-1);

+ 29 - 1
src/util/Util.java

@@ -2,7 +2,10 @@ package util;
 
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
 import java.util.Random;
 
 import exceptions.LengthNotMatchException;
@@ -14,7 +17,13 @@ public class Util {
 		return new BigInteger(a).compareTo(new BigInteger(b)) == 0;
 	}
 
-	public static long nextLong(Random r, long range) {
+	public static byte[] nextBytes(int len, Random r) {
+		byte[] data = new byte[len];
+		r.nextBytes(data);
+		return data;
+	}
+
+	public static long nextLong(long range, Random r) {
 		long bits, val;
 		do {
 			bits = (r.nextLong() << 1) >>> 1;
@@ -88,6 +97,25 @@ public class Util {
 		return new BigInteger(b).intValue();
 	}
 
+	public static int[] randomPermutation(int len, Random rand) {
+		List<Integer> list = new ArrayList<Integer>(len);
+		for (int i = 0; i < len; i++)
+			list.add(i);
+		Collections.shuffle(list, rand);
+		int[] array = new int[len];
+		for (int i = 0; i < len; i++)
+			array[i] = list.get(i);
+		return array;
+	}
+
+	public static <T> T[] permute(T[] original, int[] p) {
+		@SuppressWarnings("unchecked")
+		T[] permuted = (T[]) new Object[original.length];
+		for (int i = 0; i < original.length; i++)
+			permuted[p[i]] = original[i];
+		return permuted;
+	}
+
 	public static void debug(String s) {
 		// only to make Communication.java compile
 	}

+ 29 - 20
test/misc/HelloWorld.java

@@ -1,32 +1,41 @@
 package misc;
 
 import java.math.BigInteger;
+import java.util.ArrayList;
+import java.util.List;
 
+import oram.Forest;
+import oram.Metadata;
 import util.Util;
 
 public class HelloWorld {
 
 	public static void main(String[] args) {
-		System.out.println("HelloWorld!");
-
-		byte[] tmp = new byte[3];
-		BigInteger bi = new BigInteger(1, tmp);
-		System.out.println(bi.toByteArray().length);
-
-		// System.out.println(tmp[3]);
-
-		// System.out.println(Arrays.copyOfRange(tmp, 2, 1).length);
-
-		byte[] a = new byte[] { 0 };
-		byte[] b = a.clone();
-		a[0] = 1;
-		System.out.println(a[0] + " " + b[0]);
-		// throw new ArrayIndexOutOfBoundsException("" + 11);
-
-		System.out.println((new long[3])[0]);
-
-		byte[] negInt = Util.intToBytes(-3);
-		System.out.println(new BigInteger(negInt).intValue());
+		/*
+		 * System.out.println("HelloWorld!");
+		 * 
+		 * byte[] tmp = new byte[3]; BigInteger bi = new BigInteger(1, tmp);
+		 * System.out.println(bi.toByteArray().length);
+		 * 
+		 * // System.out.println(tmp[3]);
+		 * 
+		 * // System.out.println(Arrays.copyOfRange(tmp, 2, 1).length);
+		 * 
+		 * byte[] a = new byte[] { 0 }; byte[] b = a.clone(); a[0] = 1;
+		 * System.out.println(a[0] + " " + b[0]); // throw new
+		 * ArrayIndexOutOfBoundsException("" + 11);
+		 * 
+		 * System.out.println((new long[3])[0]);
+		 * 
+		 * byte[] negInt = Util.intToBytes(-3); System.out.println(new
+		 * BigInteger(negInt).intValue());
+		 * 
+		 * byte aa = 1; aa ^= 1; System.out.println(aa);
+		 */
+
+		Metadata md = new Metadata();
+		Forest forest = Forest.readFromFile(md.getDefaultForestFileName());
+		forest.print();
 	}
 
 }

+ 1 - 1
test/oram/TestForest.java

@@ -20,7 +20,7 @@ public class TestForest {
 		// long numTests = numRecords;
 		for (long n = 0; n < numTests; n++) {
 			// address of record we want to test
-			long testAddr = Util.nextLong(Crypto.sr, numRecords);
+			long testAddr = Util.nextLong(numRecords, Crypto.sr);
 			// long testAddr = n;
 			long L = 0;
 			long outRecord = 0;

+ 3 - 3
test/ui/TestCLI3.java → test/protocols/TestAccess_C.java

@@ -1,11 +1,11 @@
-package ui;
+package protocols;
 
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.util.Arrays;
 
-public class TestCLI3 {
+public class TestAccess_C {
 
 	public static void main(String[] args) {
 		Runtime runTime = Runtime.getRuntime();
@@ -14,7 +14,7 @@ public class TestCLI3 {
 		String binDir = dir + "\\bin";
 		String libs = dir + "\\lib\\*";
 		try {
-			process = runTime.exec("java -classpath " + binDir + ";" + libs + " ui.CLI -protocol sscot charlie");
+			process = runTime.exec("java -classpath " + binDir + ";" + libs + " ui.CLI -protocol access charlie");
 
 		} catch (IOException e) {
 			e.printStackTrace();

+ 3 - 3
test/ui/TestCLI2.java → test/protocols/TestAccess_D.java

@@ -1,11 +1,11 @@
-package ui;
+package protocols;
 
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.util.Arrays;
 
-public class TestCLI2 {
+public class TestAccess_D {
 
 	public static void main(String[] args) {
 		Runtime runTime = Runtime.getRuntime();
@@ -14,7 +14,7 @@ public class TestCLI2 {
 		String binDir = dir + "\\bin";
 		String libs = dir + "\\lib\\*";
 		try {
-			process = runTime.exec("java -classpath " + binDir + ";" + libs + " ui.CLI -protocol sscot debbie");
+			process = runTime.exec("java -classpath " + binDir + ";" + libs + " ui.CLI -protocol access debbie");
 
 		} catch (IOException e) {
 			e.printStackTrace();

+ 3 - 3
test/ui/TestCLI.java → test/protocols/TestAccess_E.java

@@ -1,11 +1,11 @@
-package ui;
+package protocols;
 
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.util.Arrays;
 
-public class TestCLI {
+public class TestAccess_E {
 
 	public static void main(String[] args) {
 		Runtime runTime = Runtime.getRuntime();
@@ -14,7 +14,7 @@ public class TestCLI {
 		String binDir = dir + "\\bin";
 		String libs = dir + "\\lib\\*";
 		try {
-			process = runTime.exec("java -classpath " + binDir + ";" + libs + " ui.CLI -protocol sscot eddie");
+			process = runTime.exec("java -classpath " + binDir + ";" + libs + " ui.CLI -protocol access eddie");
 
 		} catch (IOException e) {
 			e.printStackTrace();