Tree.java 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. package oram;
  2. import java.io.Serializable;
  3. import java.math.BigInteger;
  4. import java.util.Random;
  5. import exceptions.InvalidPathLabelException;
  6. import exceptions.LengthNotMatchException;
  7. import util.Array64;
  8. public class Tree implements Serializable {
  9. /**
  10. *
  11. */
  12. private static final long serialVersionUID = 1L;
  13. private int treeIndex;
  14. private int w;
  15. private int stashSize;
  16. private int nBits;
  17. private int lBits;
  18. private int alBits;
  19. private int nBytes;
  20. private int lBytes;
  21. private int alBytes;
  22. private int aBytes;
  23. private int tupleBytes;
  24. private long numBuckets;
  25. private long numBytes;
  26. private int d;
  27. private Array64<Bucket> buckets;
  28. public Tree(int index, Metadata md, Random rand) {
  29. treeIndex = index;
  30. w = md.getW();
  31. stashSize = md.getStashSizeOfTree(treeIndex);
  32. nBits = md.getNBitsOfTree(treeIndex);
  33. lBits = md.getLBitsOfTree(treeIndex);
  34. alBits = md.getAlBitsOfTree(treeIndex);
  35. nBytes = md.getNBytesOfTree(treeIndex);
  36. lBytes = md.getLBytesOfTree(treeIndex);
  37. alBytes = md.getAlBytesOfTree(treeIndex);
  38. aBytes = md.getABytesOfTree(treeIndex);
  39. tupleBytes = md.getTupleBytesOfTree(treeIndex);
  40. numBuckets = md.getNumBucketsOfTree(treeIndex);
  41. numBytes = md.getTreeBytesOfTree(treeIndex);
  42. d = lBits + 1;
  43. int fBytes = treeIndex == 0 ? 0 : 1;
  44. int[] tupleParams = new int[] { fBytes, nBytes, lBytes, aBytes };
  45. buckets = new Array64<Bucket>(numBuckets);
  46. buckets.set(0, new Bucket(stashSize, tupleParams, rand));
  47. for (int i = 1; i < numBuckets; i++)
  48. buckets.set(i, new Bucket(w, tupleParams, rand));
  49. }
  50. // only used for xor operation
  51. // does not shallow/deep copy buckets
  52. private Tree(Tree t) {
  53. treeIndex = t.getTreeIndex();
  54. w = t.getW();
  55. stashSize = t.getStashSize();
  56. nBits = t.getNBits();
  57. lBits = t.getLBits();
  58. alBits = t.getAlBits();
  59. nBytes = t.getNBytes();
  60. lBytes = t.getLBytes();
  61. alBytes = t.getAlBytes();
  62. aBytes = t.getABytes();
  63. tupleBytes = t.getTupleBytes();
  64. numBuckets = t.getNumBuckets();
  65. numBytes = t.getNumBytes();
  66. d = t.getD();
  67. buckets = new Array64<Bucket>(numBuckets);
  68. }
  69. public Bucket getBucket(long i) {
  70. return buckets.get(i);
  71. }
  72. public void setBucket(long i, Bucket bucket) {
  73. if (!buckets.get(i).sameLength(bucket))
  74. throw new LengthNotMatchException(buckets.get(i).getNumBytes() + " != " + bucket.getNumBytes());
  75. buckets.set(i, bucket);
  76. }
  77. public Tree xor(Tree t) {
  78. if (!this.sameLength(t))
  79. throw new LengthNotMatchException(numBytes + " != " + t.getNumBytes());
  80. Tree newTree = new Tree(t);
  81. for (long i = 0; i < numBuckets; i++)
  82. // cannot use newTree.setBucket() here
  83. newTree.buckets.set(i, buckets.get(i).xor(t.getBucket(i)));
  84. return newTree;
  85. }
  86. public void setXor(Tree t) {
  87. if (!this.sameLength(t))
  88. throw new LengthNotMatchException(numBytes + " != " + t.getNumBytes());
  89. for (long i = 0; i < numBuckets; i++)
  90. buckets.get(i).setXor(t.getBucket(i));
  91. }
  92. public boolean sameLength(Tree t) {
  93. return numBytes == t.getNumBytes();
  94. }
  95. private long[] getBucketIndicesOnPath(long L) {
  96. if (treeIndex == 0)
  97. return new long[] { 0 };
  98. if (L < 0 || L > numBuckets / 2)
  99. throw new InvalidPathLabelException(BigInteger.valueOf(L).toString(2));
  100. BigInteger biL = BigInteger.valueOf(L);
  101. long[] indices = new long[d];
  102. for (int i = 1; i < d; i++) {
  103. if (biL.testBit(d - i - 1))
  104. indices[i] = indices[i - 1] * 2 + 2;
  105. else
  106. indices[i] = indices[i - 1] * 2 + 1;
  107. }
  108. return indices;
  109. }
  110. public Bucket[] getBucketsOnPath(long L) {
  111. long[] indices = getBucketIndicesOnPath(L);
  112. Bucket[] buckets = new Bucket[indices.length];
  113. for (int i = 0; i < indices.length; i++)
  114. buckets[i] = getBucket(indices[i]);
  115. return buckets;
  116. }
  117. public void setBucketsOnPath(long L, Bucket[] buckets) {
  118. long[] indices = getBucketIndicesOnPath(L);
  119. if (indices.length != buckets.length)
  120. throw new LengthNotMatchException(indices.length + " != " + buckets.length);
  121. for (int i = 0; i < indices.length; i++)
  122. setBucket(indices[i], buckets[i]);
  123. }
  124. public int getTreeIndex() {
  125. return treeIndex;
  126. }
  127. public int getW() {
  128. return w;
  129. }
  130. public int getStashSize() {
  131. return stashSize;
  132. }
  133. public int getFBytes() {
  134. return treeIndex == 0 ? 0 : 1;
  135. }
  136. public int getNBits() {
  137. return nBits;
  138. }
  139. public int getLBits() {
  140. return lBits;
  141. }
  142. public int getAlBits() {
  143. return alBits;
  144. }
  145. public int getNBytes() {
  146. return nBytes;
  147. }
  148. public int getLBytes() {
  149. return lBytes;
  150. }
  151. public int getAlBytes() {
  152. return alBytes;
  153. }
  154. public int getABytes() {
  155. return aBytes;
  156. }
  157. public int getTupleBytes() {
  158. return tupleBytes;
  159. }
  160. public long getNumBuckets() {
  161. return numBuckets;
  162. }
  163. public long getNumBytes() {
  164. return numBytes;
  165. }
  166. public int getD() {
  167. return d;
  168. }
  169. }