Forest.java 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. package oram;
  import java.io.FileInputStream;
  import java.io.FileOutputStream;
  import java.io.IOException;
  import java.io.ObjectInputStream;
  import java.io.ObjectOutputStream;
  import java.io.Serializable;
  import java.math.BigInteger;
  import java.util.HashMap;
  import java.util.HashSet;
  import java.util.Random;

  import crypto.OramCrypto;
  import exceptions.LengthNotMatchException;
  import util.Util;
  14. public class Forest implements Serializable {
  15. /**
  16. *
  17. */
  18. private static final long serialVersionUID = 1L;
  19. private static String folderName = "data/";
  20. private String defaultFileName;
  21. private int tau;
  22. private int addrBits;
  23. private long numInsertRecords;
  24. private long numBytes;
  25. private Tree[] trees;
  26. // build empty forest and insert records according to config file
  27. public Forest() {
  28. Metadata md = new Metadata();
  29. init(md);
  30. initTrees(md, null);
  31. insertRecords(md);
  32. }
  33. public Forest(Metadata md) {
  34. init(md);
  35. initTrees(md, null);
  36. insertRecords(md);
  37. }
  38. // build empty/random content forest
  39. public Forest(Random rand) {
  40. Metadata md = new Metadata();
  41. init(md);
  42. initTrees(md, rand);
  43. }
  44. // only used in xor operation
  45. // does not shallow/deep copy trees
  46. private Forest(Forest f) {
  47. defaultFileName = f.getDefaultFileName();
  48. tau = f.getTau();
  49. addrBits = f.getAddrBits();
  50. numInsertRecords = f.getNumInsertRecords();
  51. numBytes = f.getNumBytes();
  52. trees = new Tree[f.getNumTrees()];
  53. }
  54. private void init(Metadata md) {
  55. defaultFileName = md.getDefaultForestFileName();
  56. tau = md.getTau();
  57. addrBits = md.getAddrBits();
  58. numInsertRecords = md.getNumInsertRecords();
  59. numBytes = md.getForestBytes();
  60. }
  61. public String getDefaultFileName() {
  62. return defaultFileName;
  63. }
  64. public int getTau() {
  65. return tau;
  66. }
  67. public int getAddrBits() {
  68. return addrBits;
  69. }
  70. public long getNumInsertRecords() {
  71. return numInsertRecords;
  72. }
  73. public long getNumBytes() {
  74. return numBytes;
  75. }
  76. public int getNumTrees() {
  77. return trees.length;
  78. }
  79. public Tree getTree(int i) {
  80. return trees[i];
  81. }
  82. public void setTree(int i, Tree tree) {
  83. if (!trees[i].sameLength(tree))
  84. throw new LengthNotMatchException(trees[i].getNumBytes() + " != " + tree.getNumBytes());
  85. trees[i] = tree;
  86. }
  87. // init trees
  88. private void initTrees(Metadata md, Random rand) {
  89. trees = new Tree[md.getNumTrees()];
  90. for (int i = 0; i < trees.length; i++)
  91. trees[i] = new Tree(i, md, rand);
  92. }
  93. // insert records into ORAM forest
  94. private void insertRecords(Metadata md) {
  95. int numTrees = trees.length;
  96. int tau = md.getTau();
  97. int w = md.getW();
  98. int lastNBits = md.getAddrBits() - tau * (numTrees - 2);
  99. // mapping between address N and leaf tuple
  100. @SuppressWarnings("unchecked")
  101. HashMap<Long, Long>[] addrToTuple = new HashMap[numTrees];
  102. for (int i = 1; i < numTrees; i++)
  103. addrToTuple[i] = new HashMap<Long, Long>();
  104. // keep track of each(current) inserted tuple's N and L for each tree
  105. long[] N = new long[numTrees];
  106. long[] L = new long[numTrees];
  107. // start inserting records with address addr
  108. for (long addr = 0; addr < md.getNumInsertRecords(); addr++) {
  109. // for each tree (from last to first)
  110. for (int i = numTrees - 1; i >= 0; i--) {
  111. long numBuckets = trees[i].getNumBuckets();
  112. // index of bucket that contains the tuple we will insert/update
  113. long bucketIndex = 0;
  114. // index of tuple within the bucket
  115. int tupleIndex = 0;
  116. // set correct bucket/tuple index on trees after the first one
  117. if (i > 0) {
  118. // set N of the tuple
  119. if (i == numTrees - 1)
  120. N[i] = addr;
  121. else if (i == numTrees - 2)
  122. N[i] = N[i + 1] >> lastNBits;
  123. else
  124. N[i] = N[i + 1] >> tau;
  125. // find the corresponding leaf tuple index using N
  126. Long leafTupleIndex = addrToTuple[i].get(N[i]);
  127. // if N is a new address, then find an unused leaf tuple
  128. if (leafTupleIndex == null) {
  129. do {
  130. leafTupleIndex = Util.nextLong(OramCrypto.sr, (numBuckets / 2 + 1) * w);
  131. } while (addrToTuple[i].containsValue(leafTupleIndex));
  132. addrToTuple[i].put(N[i], leafTupleIndex);
  133. }
  134. // get leaf tuple label and set bucket/tuple index
  135. L[i] = leafTupleIndex / (long) w;
  136. bucketIndex = L[i] + numBuckets / 2;
  137. tupleIndex = (int) (leafTupleIndex % w);
  138. }
  139. // retrieve the tuple that needs to be updated
  140. Tuple targetTuple = trees[i].getBucket(bucketIndex).getTuple(tupleIndex);
  141. // for all trees except the last one,
  142. // update only one label bits in the A field of the target tuple
  143. if (i < numTrees - 1) {
  144. int indexN;
  145. if (i == numTrees - 2)
  146. indexN = (int) Util.getSubBits(N[i + 1], lastNBits, 0);
  147. else
  148. indexN = (int) Util.getSubBits(N[i + 1], tau, 0);
  149. int start = indexN * trees[i].getAlBytes();
  150. int end = start + trees[i].getAlBytes();
  151. targetTuple.setSubA(start, end, Util.rmSignBit(BigInteger.valueOf(L[i + 1]).toByteArray()));
  152. }
  153. // for the last tree, update the whole A field of the target
  154. // tuple
  155. else
  156. targetTuple.setA(Util.rmSignBit(BigInteger.valueOf(addr).toByteArray()));
  157. // for all trees except the first one,
  158. // also update F, N, L fields
  159. // no need to update F, N, L for the first tree
  160. if (i > 0) {
  161. targetTuple.setF(new byte[] { 1 });
  162. targetTuple.setN(Util.rmSignBit(BigInteger.valueOf(N[i]).toByteArray()));
  163. targetTuple.setL(Util.rmSignBit(BigInteger.valueOf(L[i]).toByteArray()));
  164. }
  165. }
  166. }
  167. }
  168. public Forest xor(Forest f) {
  169. if (!this.sameLength(f))
  170. throw new LengthNotMatchException(numBytes + " != " + f.getNumBytes());
  171. Forest newForest = new Forest(f);
  172. for (int i = 0; i < trees.length; i++)
  173. newForest.trees[i] = trees[i].xor(f.getTree(i));
  174. return newForest;
  175. }
  176. public void setXor(Forest f) {
  177. if (!this.sameLength(f))
  178. throw new LengthNotMatchException(numBytes + " != " + f.getNumBytes());
  179. for (int i = 0; i < trees.length; i++)
  180. trees[i].setXor(f.getTree(i));
  181. }
  182. public boolean sameLength(Forest f) {
  183. return numBytes == f.getNumBytes();
  184. }
  185. public void print() {
  186. System.out.println("===== ORAM Forest =====");
  187. System.out.println();
  188. for (int i = 0; i < trees.length; i++) {
  189. System.out.println("***** Tree " + i + " *****");
  190. for (int j = 0; j < trees[i].getNumBuckets(); j++)
  191. System.out.println(trees[i].getBucket(j));
  192. System.out.println();
  193. }
  194. System.out.println("===== End of Forest =====");
  195. System.out.println();
  196. }
  197. public void writeToFile() {
  198. writeToFile(defaultFileName);
  199. }
  200. public void writeToFile(String filename) {
  201. FileOutputStream fos = null;
  202. ObjectOutputStream oos = null;
  203. try {
  204. fos = new FileOutputStream(folderName + filename);
  205. oos = new ObjectOutputStream(fos);
  206. oos.writeObject(this);
  207. } catch (IOException e) {
  208. e.printStackTrace();
  209. } finally {
  210. if (oos != null)
  211. try {
  212. oos.close();
  213. } catch (IOException e) {
  214. e.printStackTrace();
  215. }
  216. }
  217. }
  218. public static Forest readFromFile(String filename) {
  219. FileInputStream fis = null;
  220. ObjectInputStream ois = null;
  221. Forest forest = null;
  222. try {
  223. fis = new FileInputStream(folderName + filename);
  224. ois = new ObjectInputStream(fis);
  225. forest = (Forest) ois.readObject();
  226. } catch (IOException | ClassNotFoundException e) {
  227. e.printStackTrace();
  228. } finally {
  229. if (ois != null)
  230. try {
  231. ois.close();
  232. } catch (IOException e) {
  233. e.printStackTrace();
  234. }
  235. }
  236. return forest;
  237. }
  238. }