Browse Source

add forest init

Boyoung- 9 years ago
parent
commit
23ee58b8b9
5 changed files with 138 additions and 181 deletions
  1. 91 176
      src/oram/Forest.java
  2. 1 1
      src/oram/Metadata.java
  3. 16 0
      src/oram/Tuple.java
  4. 20 4
      src/util/Util.java
  5. 10 0
      test/oram/TestForest.java

+ 91 - 176
src/oram/Forest.java

@@ -7,87 +7,25 @@ import crypto.OramCrypto;
 import util.Util;
 
 public class Forest {
-	//private String defaultFileName;
+	// private String defaultFileName;
 
 	private Tree[] trees;
 
+	public Forest() {
+		Metadata md = new Metadata();
+		init(md);
+		insertTuples(md);
+	}
+
 	public Forest(Metadata md) {
 		// init an empty forest
 		init(md);
-
-		int numTrees = trees.length;
-		@SuppressWarnings("unchecked")
-		HashMap<Long, Long>[] addrToTuple = new HashMap[numTrees];
-		for (int i=1; i<numTrees; i++)
-			addrToTuple[i] = new HashMap<Long, Long>();
-		
-		int tau = md.getTau();
-		int w = md.getW();
-		int lastNBits = md.getAddrBits() - tau * (numTrees-2);
-		
-		long[] N = new long[numTrees];
-		long[] L = new long[numTrees];
-		
-		for (long addr=0; addr<md.getNumInsertRecords(); addr++) {
-			for (int i=numTrees-1; i>=0; i--) {
-				long numBuckets = trees[i].getNumBucket();
-				long bucketIndex = 0;
-				int tupleIndex = 0;
-				if (i > 0) {
-					if (i == numTrees-1)
-						N[i] = addr;
-					else if (i == numTrees-2)
-						N[i] = N[i + 1] >> lastNBits;
-					else
-						N[i] = N[i + 1] >> tau;
-					Long tuple = addrToTuple[i].get(N[i]);
-					if (tuple == null) {
-						do {
-							tuple = Util.nextLong(OramCrypto.sr, (numBuckets/2+1)*w);
-						} while (addrToTuple[i].containsValue(tuple));
-						addrToTuple[i].put(N[i], tuple);
-					}
-					L[i] = tuple / (long) w;
-					bucketIndex = L[i] + numBuckets/2;
-					tupleIndex = (int) (tuple % w);
-				}
-				Tuple targetTuple = trees[i].getBucket(bucketIndex).getTuple(tupleIndex);
-
-				
-				if (i < numTrees-1) {
-					int indexN;
-					if (i == numTrees - 2)
-						indexN = (int) Util.getSubBits(N[i + 1], lastNBits, 0);
-					else
-						indexN = (int) Util.getSubBits(N[i + 1], tau, 0);
-					int start = (md.getTwoTauPow() - indexN - 1) * md.getLBitsOfTree(i + 1);
-					A = Util.setSubBits(new BigInteger(1, old.getA()), L[i + 1], start,
-							start + ForestMetadata.getLBits(i + 1));
-				}
-				else
-					// A = new BigInteger(ForestMetadata.getABits(i), rnd); //
-					// generate random record content
-					A = BigInteger.valueOf(address); // for testing: record
-														// content is the same
-														// as its N
-				
-				if (i == 0)
-					tuple = A;
-				else {
-					tuple = FB.shiftLeft(ForestMetadata.getTupleBits(i) - 1)
-							.or(N[i].shiftLeft(ForestMetadata.getLBits(i) + ForestMetadata.getABits(i)))
-							.or(L[i].shiftLeft(ForestMetadata.getABits(i))).or(A);
-				}
-
-				Tuple newTuple = new Tuple(i, Util.rmSignBit(tuple.toByteArray()));
-				bucket.setTuple(newTuple, tupleIndex);
-				Util.disp("Tree-" + i + " writing " + newTuple);
-				trees.get(i).setBucket(bucket, bucketIndex);
-			}
-		}
+		// insert records into ORAM forest
+		insertTuples(md);
 	}
-	
+
 	private void init(Metadata md) {
+		trees = new Tree[md.getNumTrees()];
 		// for each tree
 		for (int treeIndex = 0; treeIndex < md.getNumTrees(); treeIndex++) {
 			// init the tree
@@ -98,9 +36,8 @@ public class Forest {
 			int lBytes = trees[treeIndex].getLBytes();
 			int aBytes = trees[treeIndex].getABytes();
 			// for each level of the tree
+			long numBuckets = 1;
 			for (int i = 0; i < trees[treeIndex].getD(); i++) {
-				// get numBuckets on this level
-				long numBuckets = (long) Math.pow(2, i);
 				// for each bucket
 				for (int j = 0; j < numBuckets; j++) {
 					// calculate bucket index
@@ -115,126 +52,104 @@ public class Forest {
 						bucket.setTuple(k, tuple);
 					}
 				}
+				// numBuckets doubled for next level
+				numBuckets *= 2;
 			}
 		}
 	}
 
-	@SuppressWarnings("unchecked")
-	private void initForest(String filename1, String filename2) throws Exception {
-
-		// following used to hold new tuple content
-		BigInteger FB = null;
-		BigInteger[] N = new BigInteger[levels];
-		BigInteger[] L = new BigInteger[levels];
-		BigInteger A;
-		BigInteger tuple; // new tuple to be inserted
-		Bucket bucket; // bucket to be updated
-		long bucketIndex; // bucket index in the tree
-		int tupleIndex; // tuple index in the bucket
-
-		// this one is for loadpathcheat
-		BigInteger[] firstL = new BigInteger[levels];
-
-		HashMap<Long, Long>[] nToSlot = new HashMap[levels];
-		for (int i = 1; i < levels; i++)
-			nToSlot[i] = new HashMap<Long, Long>();
-
-		int shiftN = ForestMetadata.getLastNBits() % tau;
-		if (shiftN == 0)
-			shiftN = tau;
-
-		System.out.println("===== Forest Generation =====");
-		for (long address = 0L; address < numInsert; address++) {
-			System.out.println("record: " + address);
-			for (int i = h; i >= 0; i--) {
-				if (i == 0) {
-					// FB = BigInteger.ONE;
-					N[i] = BigInteger.ZERO;
-					// L[i] = BigInteger.ZERO;
-					bucketIndex = 0;
-					tupleIndex = 0;
+	private void insertTuples(Metadata md) {
+		int numTrees = trees.length;
+		int tau = md.getTau();
+		int w = md.getW();
+		int lastNBits = md.getAddrBits() - tau * (numTrees - 2);
+		// mapping between address N and leaf tuple
+		@SuppressWarnings("unchecked")
+		HashMap<Long, Long>[] addrToTuple = new HashMap[numTrees];
+		for (int i = 1; i < numTrees; i++)
+			addrToTuple[i] = new HashMap<Long, Long>();
+		// keep track of each(current) inserted tuple's N and L for each tree
+		long[] N = new long[numTrees];
+		long[] L = new long[numTrees];
 
-					if (address == 0)
-						firstL[i] = null;
-				} else {
-					FB = BigInteger.ONE;
-					if (i == h)
-						N[i] = BigInteger.valueOf(address);
-					else if (i == h - 1)
-						N[i] = N[i + 1].shiftRight(shiftN);
+		// start inserting records with address addr
+		for (long addr = 0; addr < md.getNumInsertRecords(); addr++) {
+			// for each tree (from last to first)
+			for (int i = numTrees - 1; i >= 0; i--) {
+				long numBuckets = trees[i].getNumBucket();
+				// index of bucket that contains the tuple we will insert/update
+				long bucketIndex = 0;
+				// index of tuple within the bucket
+				int tupleIndex = 0;
+				// set correct bucket/tuple index on trees after the first one
+				if (i > 0) {
+					// set N of the tuple
+					if (i == numTrees - 1)
+						N[i] = addr;
+					else if (i == numTrees - 2)
+						N[i] = N[i + 1] >> lastNBits;
 					else
-						N[i] = N[i + 1].shiftRight(tau);
-					// N[i] = BigInteger.valueOf(address >> ((h-i)*tau));
-					Long slot = nToSlot[i].get(N[i].longValue());
-					if (slot == null) {
+						N[i] = N[i + 1] >> tau;
+					// find the corresponding leaf tuple index using N
+					Long leafTupleIndex = addrToTuple[i].get(N[i]);
+					// if N is a new address, then find an unused leaf tuple
+					if (leafTupleIndex == null) {
 						do {
-							slot = Util.nextLong(ForestMetadata.getNumLeafTuples(i));
-						} while (nToSlot[i].containsValue(slot));
-						nToSlot[i].put(N[i].longValue(), slot);
+							leafTupleIndex = Util.nextLong(OramCrypto.sr, (numBuckets / 2 + 1) * w);
+						} while (addrToTuple[i].containsValue(leafTupleIndex));
+						addrToTuple[i].put(N[i], leafTupleIndex);
 					}
-					L[i] = BigInteger.valueOf(slot / (w * e));
-					bucketIndex = slot / w + ForestMetadata.getNumLeaves(i) - 1;
-					tupleIndex = (int) (slot % w);
-
-					if (address == 0)
-						firstL[i] = L[i];
+					// get leaf tuple label and set bucket/tuple index
+					L[i] = leafTupleIndex / (long) w;
+					bucketIndex = L[i] + numBuckets / 2;
+					tupleIndex = (int) (leafTupleIndex % w);
 				}
+				// retrieve the tuple that needs to be updated
+				Tuple targetTuple = trees[i].getBucket(bucketIndex).getTuple(tupleIndex);
 
-				bucket = trees.get(i).getBucket(bucketIndex);
-				bucket.setIndex(i);
-
-				if (i == h)
-					// A = new BigInteger(ForestMetadata.getABits(i), rnd); //
-					// generate random record content
-					A = BigInteger.valueOf(address); // for testing: record
-														// content is the same
-														// as its N
-				else {
-					BigInteger indexN = null;
-					if (i == h - 1)
-						indexN = Util.getSubBits(N[i + 1], 0, shiftN);
+				// for all trees except the last one,
+				// update only one label bits in the A field of the target tuple
+				if (i < numTrees - 1) {
+					int indexN;
+					if (i == numTrees - 2)
+						indexN = (int) Util.getSubBits(N[i + 1], lastNBits, 0);
 					else
-						indexN = Util.getSubBits(N[i + 1], 0, tau);
-					int start = (ForestMetadata.getTwoTauPow() - indexN.intValue() - 1)
-							* ForestMetadata.getLBits(i + 1);
-					Tuple old = bucket.getTuple(tupleIndex);
-					A = Util.setSubBits(new BigInteger(1, old.getA()), L[i + 1], start,
-							start + ForestMetadata.getLBits(i + 1));
+						indexN = (int) Util.getSubBits(N[i + 1], tau, 0);
+					int start = (md.getTwoTauPow() - indexN - 1) * md.getLBitsOfTree(i + 1);
+					int end = start + md.getLBitsOfTree(i + 1);
+					BigInteger newA = Util.setSubBits(new BigInteger(1, targetTuple.getA()),
+							BigInteger.valueOf(L[i + 1]), end, start);
+					targetTuple.setA(Util.rmSignBit(newA.toByteArray()));
 				}
+				// for the last tree, update the whole A field of the target
+				// tuple
+				else
+					targetTuple.setA(Util.rmSignBit(BigInteger.valueOf(addr).toByteArray()));
 
-				if (i == 0)
-					tuple = A;
-				else {
-					tuple = FB.shiftLeft(ForestMetadata.getTupleBits(i) - 1)
-							.or(N[i].shiftLeft(ForestMetadata.getLBits(i) + ForestMetadata.getABits(i)))
-							.or(L[i].shiftLeft(ForestMetadata.getABits(i))).or(A);
+				// for all trees except the first one,
+				// also update F, N, L fields
+				// no need to update F, N, L for the first tree
+				if (i > 0) {
+					targetTuple.setF(new byte[] { 1 });
+					targetTuple.setN(Util.rmSignBit(BigInteger.valueOf(N[i]).toByteArray()));
+					targetTuple.setL(Util.rmSignBit(BigInteger.valueOf(L[i]).toByteArray()));
 				}
-
-				Tuple newTuple = new Tuple(i, Util.rmSignBit(tuple.toByteArray()));
-				bucket.setTuple(newTuple, tupleIndex);
-				Util.disp("Tree-" + i + " writing " + newTuple);
-				trees.get(i).setBucket(bucket, bucketIndex);
 			}
-			System.out.println("--------------------");
 		}
+	}
 
-		Util.disp("");
+	public void print() {
+		System.out.println("===== ORAM Forest =====");
+		System.out.println();
 
-		if (noForest) {
-			noForestInitPaths(firstL);
-			return;
+		for (int i = 0; i < trees.length; i++) {
+			System.out.println("***** Tree " + i + " *****");
+			for (int j = 0; j < trees[i].getNumBucket(); j++)
+				System.out.println(trees[i].getBucket(j));
+			System.out.println();
 		}
 
-		// these two lines are real xors
-		// data2 = new ByteArray64(ForestMetadata.getForestBytes(), "random");
-		// data1.setXOR(data2);
-
-		// this line is for testing
-		data2 = new ByteArray64(ForestMetadata.getForestBytes(), "empty");
-
-		writeToFile(filename1, filename2);
-
-		if (loadPathCheat)
-			initPaths(firstL);
+		System.out.println("===== End of Forest =====");
+		System.out.println();
 	}
 }

+ 1 - 1
src/oram/Metadata.java

@@ -133,7 +133,7 @@ public class Metadata {
 		}
 	}
 
-	public void printInfo() {
+	public void print() {
 		System.out.println("===== ORAM Forest Metadata =====");
 		System.out.println();
 		System.out.println("tau:				" + tau);

+ 16 - 0
src/oram/Tuple.java

@@ -33,18 +33,34 @@ public class Tuple {
 		return F;
 	}
 
+	public void setF(byte[] f) {
+		F = f.clone();
+	}
+
 	public byte[] getN() {
 		return N;
 	}
 
+	public void setN(byte[] n) {
+		N = n.clone();
+	}
+
 	public byte[] getL() {
 		return L;
 	}
 
+	public void setL(byte[] l) {
+		L = l.clone();
+	}
+
 	public byte[] getA() {
 		return A;
 	}
 
+	public void setA(byte[] a) {
+		A = a.clone();
+	}
+
 	public byte[] toByteArray() {
 		byte[] tuple = new byte[numBytes];
 		int offset = 0;

+ 20 - 4
src/util/Util.java

@@ -1,5 +1,6 @@
 package util;
 
+import java.math.BigInteger;
 import java.util.Arrays;
 import java.util.Random;
 
@@ -13,7 +14,7 @@ public class Util {
 		return val;
 	}
 
-	public static long getSubBits(long l, long end, long start) {
+	public static long getSubBits(long l, int end, int start) {
 		if (start < 0)
 			throw new IllegalArgumentException(start + " < 0");
 		if (start > end)
@@ -22,13 +23,28 @@ public class Util {
 		return (l >>> start) & mask;
 	}
 
-	public static long setSubBits(long target, long input, long end, long start) {
-		long len = end - start;
-		input = getSubBits(input, len, 0);
+	public static BigInteger getSubBits(BigInteger bi, int end, int start) {
+		if (start < 0)
+			throw new IllegalArgumentException(start + " < 0");
+		if (start > end)
+			throw new IllegalArgumentException(start + " > " + end);
+		BigInteger mask = BigInteger.ONE.shiftLeft(end - start).subtract(BigInteger.ONE);
+		return bi.shiftRight(start).and(mask);
+	}
+
+	public static long setSubBits(long target, long input, int end, int start) {
+		input = getSubBits(input, end - start, 0);
 		long trash = getSubBits(target, end, start);
 		return ((trash ^ input) << start) ^ target;
 	}
 
+	public static BigInteger setSubBits(BigInteger target, BigInteger input, int end, int start) {
+		if (input.bitLength() > end - start)
+			input = getSubBits(input, end - start, 0);
+		BigInteger trash = getSubBits(target, end, start);
+		return trash.xor(input).shiftLeft(start).xor(target);
+	}
+
 	public static byte[] rmSignBit(byte[] arr) {
 		if (arr[0] == 0)
 			return Arrays.copyOfRange(arr, 1, arr.length);

+ 10 - 0
test/oram/TestForest.java

@@ -0,0 +1,10 @@
+package oram;
+
+public class TestForest {
+
+	public static void main(String[] args) {
+		Forest forest = new Forest();
+		forest.print();
+	}
+
+}