Browse Source

add forest secret-shares(xor) generation

Boyoung- 8 years ago
parent
commit
2c4d82f061
7 changed files with 243 additions and 60 deletions
  1. 33 11
      src/oram/Bucket.java
  2. 51 0
      src/oram/Forest.java
  3. 95 44
      src/oram/Tree.java
  4. 33 3
      src/oram/Tuple.java
  5. 20 0
      src/util/Util.java
  6. 6 2
      test/oram/TestForest.java
  7. 5 0
      test/util/TestUtil.java

+ 33 - 11
src/oram/Bucket.java

@@ -5,6 +5,7 @@ import java.util.Random;
 import exceptions.LengthNotMatchException;
 
 public class Bucket {
+	private int numBytes;
 	private Tuple[] tuples;
 
 	public Bucket(int numTuples, int[] tupleParams, Random rand) {
@@ -13,44 +14,65 @@ public class Bucket {
 		tuples = new Tuple[numTuples];
 		for (int i = 0; i < numTuples; i++)
 			tuples[i] = new Tuple(tupleParams[0], tupleParams[1], tupleParams[2], tupleParams[3], rand);
+		numBytes = numTuples * tuples[0].getNumBytes();
 	}
 
 	public Bucket(Tuple[] tuples) {
+		if (tuples == null)
+			throw new NullPointerException();
 		this.tuples = tuples;
+		numBytes = tuples.length * tuples[0].getNumBytes();
 	}
 
 	// deep copy
 	public Bucket(Bucket b) {
+		numBytes = b.getNumBytes();
 		tuples = new Tuple[b.getNumTuples()];
 		for (int i = 0; i < tuples.length; i++)
 			tuples[i] = new Tuple(b.getTuple(i));
 	}
 
-	public int getNumTuples() {
-		return tuples.length;
+	public int getNumBytes() {
+		return numBytes;
 	}
 
-	public Tuple[] getTuples() {
-		return tuples;
+	public int getNumTuples() {
+		return tuples.length;
 	}
 
 	public Tuple getTuple(int i) {
 		return tuples[i];
 	}
 
-	public void setTuples(Tuple[] tuples) {
-		if (this.tuples.length != tuples.length)
-			throw new LengthNotMatchException(this.tuples.length + " != " + tuples.length);
-		this.tuples = tuples;
-	}
-
 	public void setTuple(int i, Tuple tuple) {
+		if (!tuples[i].sameLength(tuple))
+			throw new LengthNotMatchException(tuples[i].getNumBytes() + " != " + tuple.getNumBytes());
 		tuples[i] = tuple;
 	}
 
+	public Bucket xor(Bucket b) {
+		if (!this.sameLength(b))
+			throw new LengthNotMatchException(numBytes + " != " + b.getNumBytes());
+		Tuple[] newTuples = new Tuple[tuples.length];
+		for (int i = 0; i < tuples.length; i++)
+			newTuples[i] = tuples[i].xor(b.getTuple(i));
+		return new Bucket(newTuples);
+	}
+
+	public void setXor(Bucket b) {
+		if (!this.sameLength(b))
+			throw new LengthNotMatchException(numBytes + " != " + b.getNumBytes());
+		for (int i = 0; i < tuples.length; i++)
+			tuples[i].setXor(b.getTuple(i));
+	}
+
+	public boolean sameLength(Bucket b) {
+		return numBytes == b.getNumBytes();
+	}
+
 	public byte[] toByteArray() {
 		int tupleBytes = tuples[0].getNumBytes();
-		byte[] bucket = new byte[tupleBytes * tuples.length];
+		byte[] bucket = new byte[numBytes];
 		for (int i = 0; i < tuples.length; i++) {
 			byte[] tuple = tuples[i].toByteArray();
 			System.arraycopy(tuple, 0, bucket, i * tupleBytes, tupleBytes);

+ 51 - 0
src/oram/Forest.java

@@ -5,14 +5,17 @@ import java.util.HashMap;
 import java.util.Random;
 
 import crypto.OramCrypto;
+import exceptions.LengthNotMatchException;
 import util.Util;
 
 public class Forest {
+	private long numBytes;
 	private Tree[] trees;
 
 	// build empty forest and insert records according to config file
 	public Forest() {
 		Metadata md = new Metadata();
+		numBytes = md.getForestBytes();
 		initTrees(md, null);
 		insertRecords(md);
 	}
@@ -20,9 +23,37 @@ public class Forest {
 	// build empty/random content forest
 	public Forest(Random rand) {
 		Metadata md = new Metadata();
+		numBytes = md.getForestBytes();
 		initTrees(md, rand);
 	}
 
+	// only used in xor operation
+	private Forest(Tree[] trees) {
+		if (trees == null)
+			throw new NullPointerException();
+		this.trees = trees;
+		for (int i = 0; i < trees.length; i++)
+			numBytes += trees[i].getNumBytes();
+	}
+
+	public long getNumBytes() {
+		return numBytes;
+	}
+
+	public int getNumTrees() {
+		return trees.length;
+	}
+
+	public Tree getTree(int i) {
+		return trees[i];
+	}
+
+	public void setTree(int i, Tree tree) {
+		if (!trees[i].sameLength(tree))
+			throw new LengthNotMatchException(trees[i].getNumBytes() + " != " + tree.getNumBytes());
+		trees[i] = tree;
+	}
+
 	// init trees
 	private void initTrees(Metadata md, Random rand) {
 		trees = new Tree[md.getNumTrees()];
@@ -109,6 +140,26 @@ public class Forest {
 		}
 	}
 
+	public Forest xor(Forest f) {
+		if (!this.sameLength(f))
+			throw new LengthNotMatchException(numBytes + " != " + f.getNumBytes());
+		Tree[] newTrees = new Tree[trees.length];
+		for (int i = 0; i < trees.length; i++)
+			newTrees[i] = trees[i].xor(f.getTree(i));
+		return new Forest(newTrees);
+	}
+
+	public void setXor(Forest f) {
+		if (!this.sameLength(f))
+			throw new LengthNotMatchException(numBytes + " != " + f.getNumBytes());
+		for (int i = 0; i < trees.length; i++)
+			trees[i].setXor(f.getTree(i));
+	}
+
+	public boolean sameLength(Forest f) {
+		return numBytes == f.getNumBytes();
+	}
+
 	public void print() {
 		System.out.println("===== ORAM Forest =====");
 		System.out.println();

+ 95 - 44
src/oram/Tree.java

@@ -20,6 +20,7 @@ public class Tree {
 	private int aBytes;
 	private int tupleBytes;
 	private long numBuckets;
+	private long numBytes;
 	private int d;
 
 	private Array64<Bucket> buckets;
@@ -37,6 +38,7 @@ public class Tree {
 		aBytes = md.getABytesOfTree(treeIndex);
 		tupleBytes = md.getTupleBytesOfTree(treeIndex);
 		numBuckets = md.getNumBucketsOfTree(treeIndex);
+		numBytes = md.getTreeBytesOfTree(treeIndex);
 		d = lBits + 1;
 
 		int fBytes = treeIndex == 0 ? 0 : 1;
@@ -47,6 +49,95 @@ public class Tree {
 			buckets.set(i, new Bucket(w, tupleParams, rand));
 	}
 
+	// only used for xor operation
+	// does not deep copy buckets
+	private Tree(Tree t) {
+		treeIndex = t.getTreeIndex();
+		w = t.getW();
+		stashSize = t.getStashSize();
+		nBits = t.getNBits();
+		lBits = t.getLBits();
+		alBits = t.getAlBits();
+		nBytes = t.getNBytes();
+		lBytes = t.getLBytes();
+		alBytes = t.getAlBytes();
+		aBytes = t.getABytes();
+		tupleBytes = t.getTupleBytes();
+		numBuckets = t.getNumBuckets();
+		numBytes = t.getNumBytes();
+		d = t.getD();
+
+		buckets = new Array64<Bucket>(numBuckets);
+	}
+
+	// only used for xor operation
+	private Array64<Bucket> getBuckets() {
+		return buckets;
+	}
+
+	public Bucket getBucket(long i) {
+		return buckets.get(i);
+	}
+
+	public void setBucket(long i, Bucket bucket) {
+		if (!buckets.get(i).sameLength(bucket))
+			throw new LengthNotMatchException(buckets.get(i).getNumBytes() + " != " + bucket.getNumBytes());
+		buckets.set(i, bucket);
+	}
+
+	public Tree xor(Tree t) {
+		if (!this.sameLength(t))
+			throw new LengthNotMatchException(numBytes + " != " + t.getNumBytes());
+		Tree newTree = new Tree(t);
+		for (long i = 0; i < numBuckets; i++)
+			// cannot use newTree.setBucket() here
+			newTree.getBuckets().set(i, buckets.get(i).xor(t.getBucket(i)));
+		return newTree;
+	}
+
+	public void setXor(Tree t) {
+		if (!this.sameLength(t))
+			throw new LengthNotMatchException(numBytes + " != " + t.getNumBytes());
+		for (long i = 0; i < numBuckets; i++)
+			buckets.get(i).setXor(t.getBucket(i));
+	}
+
+	public boolean sameLength(Tree t) {
+		return numBytes == t.getNumBytes();
+	}
+
+	private long[] getBucketIndicesOnPath(long L) {
+		if (treeIndex == 0)
+			return new long[] { 0 };
+		if (L < 0 || L > numBuckets / 2)
+			throw new InvalidPathLabelException(BigInteger.valueOf(L).toString(2));
+		BigInteger biL = BigInteger.valueOf(L);
+		long[] indices = new long[d];
+		for (int i = 1; i < d; i++) {
+			if (biL.testBit(d - i - 1))
+				indices[i] = indices[i - 1] * 2 + 2;
+			else
+				indices[i] = indices[i - 1] * 2 + 1;
+		}
+		return indices;
+	}
+
+	public Bucket[] getBucketsOnPath(long L) {
+		long[] indices = getBucketIndicesOnPath(L);
+		Bucket[] buckets = new Bucket[indices.length];
+		for (int i = 0; i < indices.length; i++)
+			buckets[i] = getBucket(indices[i]);
+		return buckets;
+	}
+
+	public void setBucketsOnPath(long L, Bucket[] buckets) {
+		long[] indices = getBucketIndicesOnPath(L);
+		if (indices.length != buckets.length)
+			throw new LengthNotMatchException(indices.length + " != " + buckets.length);
+		for (int i = 0; i < indices.length; i++)
+			setBucket(indices[i], buckets[i]);
+	}
+
 	public int getTreeIndex() {
 		return treeIndex;
 	}
@@ -95,51 +186,11 @@ public class Tree {
 		return numBuckets;
 	}
 
-	public int getD() {
-		return d;
-	}
-
-	public Array64<Bucket> getBuckets() {
-		return buckets;
-	}
-
-	public Bucket getBucket(long bucketIndex) {
-		return buckets.get(bucketIndex);
+	public long getNumBytes() {
+		return numBytes;
 	}
 
-	public void setBucket(long bucketIndex, Bucket bucket) {
-		buckets.set(bucketIndex, bucket);
-	}
-
-	private long[] getBucketIndicesOnPath(long L) {
-		if (treeIndex == 0)
-			return new long[] { 0 };
-		if (L < 0 || L > numBuckets / 2)
-			throw new InvalidPathLabelException(BigInteger.valueOf(L).toString(2));
-		BigInteger biL = BigInteger.valueOf(L);
-		long[] indices = new long[d];
-		for (int i = 1; i < d; i++) {
-			if (biL.testBit(d - i - 1))
-				indices[i] = indices[i - 1] * 2 + 2;
-			else
-				indices[i] = indices[i - 1] * 2 + 1;
-		}
-		return indices;
-	}
-
-	public Bucket[] getBucketsOnPath(long L) {
-		long[] indices = getBucketIndicesOnPath(L);
-		Bucket[] buckets = new Bucket[indices.length];
-		for (int i = 0; i < indices.length; i++)
-			buckets[i] = getBucket(indices[i]);
-		return buckets;
-	}
-
-	public void setBucketsOnPath(long L, Bucket[] buckets) {
-		long[] indices = getBucketIndicesOnPath(L);
-		if (indices.length != buckets.length)
-			throw new LengthNotMatchException(indices.length + " != " + buckets.length);
-		for (int i = 0; i < indices.length; i++)
-			setBucket(indices[i], buckets[i]);
+	public int getD() {
+		return d;
 	}
 }

+ 33 - 3
src/oram/Tuple.java

@@ -5,14 +5,17 @@ import java.util.Arrays;
 import java.util.Random;
 
 import exceptions.LengthNotMatchException;
+import util.Util;
 
 public class Tuple {
+	private int numBytes;
 	private byte[] F;
 	private byte[] N;
 	private byte[] L;
 	private byte[] A;
 
 	public Tuple(int fs, int ns, int ls, int as, Random rand) {
+		numBytes = fs + ns + ls + as;
 		F = new byte[fs];
 		N = new byte[ns];
 		L = new byte[ls];
@@ -26,6 +29,9 @@ public class Tuple {
 	}
 
 	public Tuple(byte[] f, byte[] n, byte[] l, byte[] a) {
+		if (f == null || n == null || l == null || a == null)
+			throw new NullPointerException();
+		numBytes = f.length + n.length + l.length + a.length;
 		F = f;
 		N = n;
 		L = l;
@@ -34,6 +40,7 @@ public class Tuple {
 
 	// deep copy
 	public Tuple(Tuple t) {
+		numBytes = t.getNumBytes();
 		F = t.getF().clone();
 		N = t.getN().clone();
 		L = t.getL().clone();
@@ -41,7 +48,7 @@ public class Tuple {
 	}
 
 	public int getNumBytes() {
-		return F.length + N.length + L.length + A.length;
+		return numBytes;
 	}
 
 	public byte[] getF() {
@@ -120,11 +127,34 @@ public class Tuple {
 			for (int i = 0; i < len - label.length; i++)
 				A[start + i] = 0;
 		}
-		System.arraycopy(label, 0, A, start + len - label.length, label.length);
+		System.arraycopy(label, 0, A, end - label.length, label.length);
+	}
+
+	public Tuple xor(Tuple t) {
+		if (!this.sameLength(t))
+			throw new LengthNotMatchException(this.getNumBytes() + " != " + t.getNumBytes());
+		byte[] newF = Util.xor(F, t.getF());
+		byte[] newN = Util.xor(N, t.getN());
+		byte[] newL = Util.xor(L, t.getL());
+		byte[] newA = Util.xor(A, t.getA());
+		return new Tuple(newF, newN, newL, newA);
+	}
+
+	public void setXor(Tuple t) {
+		if (!this.sameLength(t))
+			throw new LengthNotMatchException(this.getNumBytes() + " != " + t.getNumBytes());
+		Util.setXor(F, t.getF());
+		Util.setXor(N, t.getN());
+		Util.setXor(L, t.getL());
+		Util.setXor(A, t.getA());
+	}
+
+	public boolean sameLength(Tuple t) {
+		return numBytes == t.getNumBytes();
 	}
 
 	public byte[] toByteArray() {
-		byte[] tuple = new byte[F.length + N.length + L.length + A.length];
+		byte[] tuple = new byte[numBytes];
 		int offset = 0;
 		System.arraycopy(F, 0, tuple, offset, F.length);
 		offset += F.length;

+ 20 - 0
src/util/Util.java

@@ -4,6 +4,8 @@ import java.math.BigInteger;
 import java.util.Arrays;
 import java.util.Random;
 
+import exceptions.LengthNotMatchException;
+
 public class Util {
 	public static long nextLong(Random r, long range) {
 		long bits, val;
@@ -50,4 +52,22 @@ public class Util {
 			return Arrays.copyOfRange(arr, 1, arr.length);
 		return arr;
 	}
+
+	// c = a ^ b
+	public static byte[] xor(byte[] a, byte[] b) {
+		if (a.length != b.length)
+			throw new LengthNotMatchException(a.length + " != " + b.length);
+		byte[] c = new byte[a.length];
+		for (int i = 0; i < a.length; i++)
+			c[i] = (byte) (a[i] ^ b[i]);
+		return c;
+	}
+
+	// a = a ^ b to save memory
+	public static void setXor(byte[] a, byte[] b) {
+		if (a.length != b.length)
+			throw new LengthNotMatchException(a.length + " != " + b.length);
+		for (int i = 0; i < a.length; i++)
+			a[i] = (byte) (a[i] ^ b[i]);
+	}
 }

+ 6 - 2
test/oram/TestForest.java

@@ -3,8 +3,12 @@ package oram;
 public class TestForest {
 
 	public static void main(String[] args) {
-		Forest forest = new Forest();
-		forest.print();
+		Forest forest1 = new Forest();
+		Forest forest2 = new Forest();
+		Forest forest3 = forest1.xor(forest2);
+		forest1.print();
+		forest2.print();
+		forest3.print();
 	}
 
 }

+ 5 - 0
test/util/TestUtil.java

@@ -10,6 +10,11 @@ public class TestUtil {
 		System.out.println(BigInteger.valueOf(subBits).toString(2));
 		long b = Util.setSubBits(a, subBits, 5, 2);
 		System.out.println(BigInteger.valueOf(b).toString(2));
+
+		byte[] aa = new byte[] { 0 };
+		byte[] bb = new byte[] { 1 };
+		Util.setXor(aa, bb);
+		System.out.println(aa[0]);
 	}
 
 }