Forest.java 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package oram;
  2. import java.math.BigInteger;
  3. import java.util.HashMap;
  4. import crypto.OramCrypto;
  5. import util.Util;
  6. public class Forest {
  7. private Tree[] trees;
  8. public Forest() {
  9. Metadata md = new Metadata();
  10. init(md);
  11. insertTuples(md);
  12. }
  13. public Forest(Metadata md) {
  14. init(md);
  15. insertTuples(md);
  16. }
  17. // init an empty forest
  18. private void init(Metadata md) {
  19. trees = new Tree[md.getNumTrees()];
  20. // for each tree
  21. for (int treeIndex = 0; treeIndex < md.getNumTrees(); treeIndex++) {
  22. // init the tree
  23. trees[treeIndex] = new Tree(treeIndex, md);
  24. // get bytes of tuple in this tree
  25. int fBytes = treeIndex == 0 ? 0 : 1;
  26. int nBytes = trees[treeIndex].getNBytes();
  27. int lBytes = trees[treeIndex].getLBytes();
  28. int aBytes = trees[treeIndex].getABytes();
  29. // for each level of the tree
  30. long numBuckets = 1;
  31. for (int i = 0; i < trees[treeIndex].getD(); i++) {
  32. // for each bucket
  33. for (int j = 0; j < numBuckets; j++) {
  34. // calculate bucket index
  35. long bucketIndex = j + numBuckets - 1;
  36. // get the bucket
  37. Bucket bucket = trees[treeIndex].getBucket(bucketIndex);
  38. // for each tuple within the bucket
  39. for (int k = 0; k < bucket.getNumTuples(); k++) {
  40. // create a empty tuple
  41. Tuple tuple = new Tuple(fBytes, nBytes, lBytes, aBytes);
  42. // add to the bucket
  43. bucket.setTuple(k, tuple);
  44. }
  45. }
  46. // numBuckets doubled for next level
  47. numBuckets *= 2;
  48. }
  49. }
  50. }
  51. // insert records into ORAM forest
  52. private void insertTuples(Metadata md) {
  53. int numTrees = trees.length;
  54. int tau = md.getTau();
  55. int w = md.getW();
  56. int lastNBits = md.getAddrBits() - tau * (numTrees - 2);
  57. // mapping between address N and leaf tuple
  58. @SuppressWarnings("unchecked")
  59. HashMap<Long, Long>[] addrToTuple = new HashMap[numTrees];
  60. for (int i = 1; i < numTrees; i++)
  61. addrToTuple[i] = new HashMap<Long, Long>();
  62. // keep track of each(current) inserted tuple's N and L for each tree
  63. long[] N = new long[numTrees];
  64. long[] L = new long[numTrees];
  65. // start inserting records with address addr
  66. for (long addr = 0; addr < md.getNumInsertRecords(); addr++) {
  67. // for each tree (from last to first)
  68. for (int i = numTrees - 1; i >= 0; i--) {
  69. long numBuckets = trees[i].getNumBucket();
  70. // index of bucket that contains the tuple we will insert/update
  71. long bucketIndex = 0;
  72. // index of tuple within the bucket
  73. int tupleIndex = 0;
  74. // set correct bucket/tuple index on trees after the first one
  75. if (i > 0) {
  76. // set N of the tuple
  77. if (i == numTrees - 1)
  78. N[i] = addr;
  79. else if (i == numTrees - 2)
  80. N[i] = N[i + 1] >> lastNBits;
  81. else
  82. N[i] = N[i + 1] >> tau;
  83. // find the corresponding leaf tuple index using N
  84. Long leafTupleIndex = addrToTuple[i].get(N[i]);
  85. // if N is a new address, then find an unused leaf tuple
  86. if (leafTupleIndex == null) {
  87. do {
  88. leafTupleIndex = Util.nextLong(OramCrypto.sr, (numBuckets / 2 + 1) * w);
  89. } while (addrToTuple[i].containsValue(leafTupleIndex));
  90. addrToTuple[i].put(N[i], leafTupleIndex);
  91. }
  92. // get leaf tuple label and set bucket/tuple index
  93. L[i] = leafTupleIndex / (long) w;
  94. bucketIndex = L[i] + numBuckets / 2;
  95. tupleIndex = (int) (leafTupleIndex % w);
  96. }
  97. // retrieve the tuple that needs to be updated
  98. Tuple targetTuple = trees[i].getBucket(bucketIndex).getTuple(tupleIndex);
  99. // for all trees except the last one,
  100. // update only one label bits in the A field of the target tuple
  101. if (i < numTrees - 1) {
  102. int indexN;
  103. if (i == numTrees - 2)
  104. indexN = (int) Util.getSubBits(N[i + 1], lastNBits, 0);
  105. else
  106. indexN = (int) Util.getSubBits(N[i + 1], tau, 0);
  107. int start = indexN * trees[i].getAlBytes();
  108. int end = start + trees[i].getAlBytes();
  109. targetTuple.setALabel(start, end, Util.rmSignBit(BigInteger.valueOf(L[i + 1]).toByteArray()));
  110. }
  111. // for the last tree, update the whole A field of the target
  112. // tuple
  113. else
  114. targetTuple.setA(Util.rmSignBit(BigInteger.valueOf(addr).toByteArray()));
  115. // for all trees except the first one,
  116. // also update F, N, L fields
  117. // no need to update F, N, L for the first tree
  118. if (i > 0) {
  119. targetTuple.setF(new byte[] { 1 });
  120. targetTuple.setN(Util.rmSignBit(BigInteger.valueOf(N[i]).toByteArray()));
  121. targetTuple.setL(Util.rmSignBit(BigInteger.valueOf(L[i]).toByteArray()));
  122. }
  123. }
  124. }
  125. }
  126. public void print() {
  127. System.out.println("===== ORAM Forest =====");
  128. System.out.println();
  129. for (int i = 0; i < trees.length; i++) {
  130. System.out.println("***** Tree " + i + " *****");
  131. for (int j = 0; j < trees[i].getNumBucket(); j++)
  132. System.out.println(trees[i].getBucket(j));
  133. System.out.println();
  134. }
  135. System.out.println("===== End of Forest =====");
  136. System.out.println();
  137. }
  138. }