Forest.java 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. package oram;
  import java.io.FileInputStream;
  import java.io.FileOutputStream;
  import java.io.IOException;
  import java.io.ObjectInputStream;
  import java.io.ObjectOutputStream;
  import java.io.Serializable;
  import java.math.BigInteger;
  import java.util.HashMap;
  import java.util.HashSet;
  import java.util.Random;

  import crypto.Crypto;
  import exceptions.LengthNotMatchException;
  import util.Util;
  14. public class Forest implements Serializable {
  15. /**
  16. *
  17. */
  18. private static final long serialVersionUID = 1L;
  19. private static String folderName = "data/";
  20. private String forestFileName;
  21. private int tau;
  22. private int addrBits;
  23. private long numInsertRecords;
  24. private long numBytes;
  25. private Tree[] trees;
  26. // build empty forest and insert records according to config file
  27. public Forest() {
  28. Metadata md = new Metadata();
  29. init(md);
  30. initTrees(md, null);
  31. insertRecords(md);
  32. }
  33. public Forest(Metadata md) {
  34. init(md);
  35. initTrees(md, null);
  36. insertRecords(md);
  37. }
  38. // build empty/random content forest
  39. public Forest(Random rand) {
  40. Metadata md = new Metadata();
  41. init(md);
  42. initTrees(md, rand);
  43. }
  44. public Forest(Metadata md, Random rand) {
  45. init(md);
  46. initTrees(md, rand);
  47. }
  48. // only used in xor operation
  49. // does not shallow/deep copy trees
  50. private Forest(Forest f) {
  51. forestFileName = f.getForestFileName();
  52. tau = f.getTau();
  53. addrBits = f.getAddrBits();
  54. numInsertRecords = f.getNumInsertRecords();
  55. numBytes = f.getNumBytes();
  56. trees = new Tree[f.getNumTrees()];
  57. }
  58. private void init(Metadata md) {
  59. forestFileName = md.getDefaultForestFileName();
  60. tau = md.getTau();
  61. addrBits = md.getAddrBits();
  62. numInsertRecords = md.getNumInsertRecords();
  63. numBytes = md.getForestBytes();
  64. }
  65. public String getForestFileName() {
  66. return forestFileName;
  67. }
  68. public int getTau() {
  69. return tau;
  70. }
  71. public int getAddrBits() {
  72. return addrBits;
  73. }
  74. public long getNumInsertRecords() {
  75. return numInsertRecords;
  76. }
  77. public long getNumBytes() {
  78. return numBytes;
  79. }
  80. public int getNumTrees() {
  81. return trees.length;
  82. }
  83. public Tree getTree(int i) {
  84. return trees[i];
  85. }
  86. public void setTree(int i, Tree tree) {
  87. if (!trees[i].sameLength(tree))
  88. throw new LengthNotMatchException(trees[i].getNumBytes() + " != " + tree.getNumBytes());
  89. trees[i] = tree;
  90. }
  91. // init trees
  92. private void initTrees(Metadata md, Random rand) {
  93. trees = new Tree[md.getNumTrees()];
  94. for (int i = 0; i < trees.length; i++)
  95. trees[i] = new Tree(i, md, rand);
  96. }
  97. // insert records into ORAM forest
  98. private void insertRecords(Metadata md) {
  99. int numTrees = trees.length;
  100. int tau = md.getTau();
  101. int w = md.getW();
  102. int lastNBits = md.getAddrBits() - tau * (numTrees - 2);
  103. // mapping between address N and leaf tuple
  104. @SuppressWarnings("unchecked")
  105. HashMap<Long, Long>[] addrToTuple = new HashMap[numTrees];
  106. for (int i = 1; i < numTrees; i++)
  107. addrToTuple[i] = new HashMap<Long, Long>();
  108. // keep track of each(current) inserted tuple's N and L for each tree
  109. long[] N = new long[numTrees];
  110. long[] L = new long[numTrees];
  111. // start inserting records with address addr
  112. for (long addr = 0; addr < md.getNumInsertRecords(); addr++) {
  113. // for each tree (from last to first)
  114. for (int i = numTrees - 1; i >= 0; i--) {
  115. long numBuckets = trees[i].getNumBuckets();
  116. // index of bucket that contains the tuple we will insert/update
  117. long bucketIndex = 0;
  118. // index of tuple within the bucket
  119. int tupleIndex = 0;
  120. // set correct bucket/tuple index on trees after the first one
  121. if (i > 0) {
  122. // set N of the tuple
  123. if (i == numTrees - 1)
  124. N[i] = addr;
  125. else if (i == numTrees - 2)
  126. N[i] = N[i + 1] >> lastNBits;
  127. else
  128. N[i] = N[i + 1] >> tau;
  129. // find the corresponding leaf tuple index using N
  130. Long leafTupleIndex = addrToTuple[i].get(N[i]);
  131. // if N is a new address, then find an unused leaf tuple
  132. if (leafTupleIndex == null) {
  133. do {
  134. leafTupleIndex = Util.nextLong(Crypto.sr, (numBuckets / 2 + 1) * w);
  135. } while (addrToTuple[i].containsValue(leafTupleIndex));
  136. addrToTuple[i].put(N[i], leafTupleIndex);
  137. }
  138. // get leaf tuple label and set bucket/tuple index
  139. L[i] = leafTupleIndex / (long) w;
  140. bucketIndex = L[i] + numBuckets / 2;
  141. tupleIndex = (int) (leafTupleIndex % w);
  142. }
  143. // retrieve the tuple that needs to be updated
  144. Tuple targetTuple = trees[i].getBucket(bucketIndex).getTuple(tupleIndex);
  145. // for all trees except the last one,
  146. // update only one label bits in the A field of the target tuple
  147. if (i < numTrees - 1) {
  148. int indexN;
  149. if (i == numTrees - 2)
  150. indexN = (int) Util.getSubBits(N[i + 1], lastNBits, 0);
  151. else
  152. indexN = (int) Util.getSubBits(N[i + 1], tau, 0);
  153. int start = indexN * trees[i].getAlBytes();
  154. int end = start + trees[i].getAlBytes();
  155. targetTuple.setSubA(start, end, Util.rmSignBit(BigInteger.valueOf(L[i + 1]).toByteArray()));
  156. }
  157. // for the last tree, update the whole A field of the target
  158. // tuple
  159. else
  160. targetTuple.setA(Util.rmSignBit(BigInteger.valueOf(addr).toByteArray()));
  161. // for all trees except the first one,
  162. // also update F, N, L fields
  163. // no need to update F, N, L for the first tree
  164. if (i > 0) {
  165. targetTuple.setF(new byte[] { 1 });
  166. targetTuple.setN(Util.rmSignBit(BigInteger.valueOf(N[i]).toByteArray()));
  167. targetTuple.setL(Util.rmSignBit(BigInteger.valueOf(L[i]).toByteArray()));
  168. }
  169. }
  170. }
  171. }
  172. public Forest xor(Forest f) {
  173. if (!this.sameLength(f))
  174. throw new LengthNotMatchException(numBytes + " != " + f.getNumBytes());
  175. Forest newForest = new Forest(f);
  176. for (int i = 0; i < trees.length; i++)
  177. newForest.trees[i] = trees[i].xor(f.getTree(i));
  178. return newForest;
  179. }
  180. public void setXor(Forest f) {
  181. if (!this.sameLength(f))
  182. throw new LengthNotMatchException(numBytes + " != " + f.getNumBytes());
  183. for (int i = 0; i < trees.length; i++)
  184. trees[i].setXor(f.getTree(i));
  185. }
  186. public boolean sameLength(Forest f) {
  187. return numBytes == f.getNumBytes();
  188. }
  189. public void print() {
  190. System.out.println("===== ORAM Forest =====");
  191. System.out.println();
  192. for (int i = 0; i < trees.length; i++) {
  193. System.out.println("***** Tree " + i + " *****");
  194. for (int j = 0; j < trees[i].getNumBuckets(); j++)
  195. System.out.println(trees[i].getBucket(j));
  196. System.out.println();
  197. }
  198. System.out.println("===== End of Forest =====");
  199. System.out.println();
  200. }
  201. public void writeToFile() {
  202. writeToFile(forestFileName);
  203. }
  204. public void writeToFile(String filename) {
  205. FileOutputStream fos = null;
  206. ObjectOutputStream oos = null;
  207. try {
  208. fos = new FileOutputStream(folderName + filename);
  209. oos = new ObjectOutputStream(fos);
  210. oos.writeObject(this);
  211. } catch (IOException e) {
  212. e.printStackTrace();
  213. } finally {
  214. if (oos != null)
  215. try {
  216. oos.close();
  217. } catch (IOException e) {
  218. e.printStackTrace();
  219. }
  220. }
  221. }
  222. public static Forest readFromFile(String filename) {
  223. FileInputStream fis = null;
  224. ObjectInputStream ois = null;
  225. Forest forest = null;
  226. try {
  227. fis = new FileInputStream(folderName + filename);
  228. ois = new ObjectInputStream(fis);
  229. forest = (Forest) ois.readObject();
  230. } catch (IOException | ClassNotFoundException e) {
  231. e.printStackTrace();
  232. } finally {
  233. if (ois != null)
  234. try {
  235. ois.close();
  236. } catch (IOException e) {
  237. e.printStackTrace();
  238. }
  239. }
  240. forest.forestFileName = filename;
  241. return forest;
  242. }
  243. }