Retrieve.java 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. package protocols;
  2. import java.math.BigInteger;
  3. import java.util.Arrays;
  4. import communication.Communication;
  5. import crypto.Crypto;
  6. import exceptions.AccessException;
  7. import exceptions.NoSuchPartyException;
  8. import oram.Forest;
  9. import oram.Global;
  10. import oram.Metadata;
  11. import oram.Tree;
  12. import oram.Tuple;
  13. import protocols.precomputation.PreRetrieve;
  14. import protocols.struct.OutAccess;
  15. import protocols.struct.OutRetrieve;
  16. import protocols.struct.Party;
  17. import protocols.struct.PreData;
  18. import util.Bandwidth;
  19. import util.P;
  20. import util.StopWatch;
  21. import util.Timer;
  22. import util.Util;
  23. public class Retrieve extends Protocol {
  24. Communication[] cons1;
  25. Communication[] cons2;
  26. public Retrieve(Communication con1, Communication con2) {
  27. super(con1, con2);
  28. }
  29. public void setCons(Communication[] a, Communication[] b) {
  30. cons1 = a;
  31. cons2 = b;
  32. }
  33. public void runE(PreData[] predata, Tree OTi, byte[] Ni, byte[] Nip1_pr, int h, Timer timer) {
  34. // 1st eviction
  35. Access access = new Access(con1, con2);
  36. Reshuffle reshuffle = new Reshuffle(con1, con2);
  37. PostProcessT postprocesst = new PostProcessT(con1, con2);
  38. UpdateRoot updateroot = new UpdateRoot(con1, con2);
  39. Eviction eviction = new Eviction(con1, con2);
  40. OutAccess outaccess = access.runE(predata[0], OTi, Ni, Nip1_pr, timer);
  41. Tuple[] path = reshuffle.runE(predata[0], outaccess.E_P, OTi.getTreeIndex() == 0, timer);
  42. Tuple Ti = postprocesst.runE(predata[0], outaccess.E_Ti, OTi.getTreeIndex() == h - 1, timer);
  43. Tuple[] root = Arrays.copyOfRange(path, 0, OTi.getStashSize());
  44. root = updateroot.runE(predata[0], OTi.getTreeIndex() == 0, outaccess.Li, root, Ti, timer);
  45. System.arraycopy(root, 0, path, 0, root.length);
  46. eviction.runE(predata[0], OTi.getTreeIndex() == 0, outaccess.Li,
  47. OTi.getTreeIndex() == 0 ? new Tuple[] { Ti } : path, OTi, timer);
  48. // 2nd eviction
  49. OutAccess outaccess2 = access.runE2(OTi, timer);
  50. Tuple[] path2 = outaccess2.E_P;
  51. Tuple Ti2 = outaccess2.E_Ti;
  52. Tuple[] root2 = Arrays.copyOfRange(path2, 0, OTi.getStashSize());
  53. root2 = updateroot.runE(predata[1], OTi.getTreeIndex() == 0, outaccess2.Li, root2, Ti2, timer);
  54. System.arraycopy(root2, 0, path2, 0, root2.length);
  55. eviction.runE(predata[1], OTi.getTreeIndex() == 0, outaccess2.Li,
  56. OTi.getTreeIndex() == 0 ? new Tuple[] { Ti2 } : path2, OTi, timer);
  57. }
  58. public void runD(PreData predata[], Tree OTi, byte[] Ni, byte[] Nip1_pr, Timer timer) {
  59. // 1st eviction
  60. Access access = new Access(con1, con2);
  61. Reshuffle reshuffle = new Reshuffle(con1, con2);
  62. PostProcessT postprocesst = new PostProcessT(con1, con2);
  63. UpdateRoot updateroot = new UpdateRoot(con1, con2);
  64. Eviction eviction = new Eviction(con1, con2);
  65. byte[] Li = access.runD(predata[0], OTi, Ni, Nip1_pr, timer);
  66. reshuffle.runD();
  67. postprocesst.runD();
  68. updateroot.runD(predata[0], OTi.getTreeIndex() == 0, Li, OTi.getW(), timer);
  69. eviction.runD(predata[0], OTi.getTreeIndex() == 0, Li, OTi, timer);
  70. // 2nd eviction
  71. byte[] Li2 = access.runD2(OTi, timer);
  72. updateroot.runD(predata[1], OTi.getTreeIndex() == 0, Li2, OTi.getW(), timer);
  73. eviction.runD(predata[1], OTi.getTreeIndex() == 0, Li2, OTi, timer);
  74. }
  75. public OutAccess runC(PreData[] predata, Metadata md, int ti, byte[] Li, Timer timer) {
  76. // 1st eviction
  77. Access access = new Access(con1, con2);
  78. Reshuffle reshuffle = new Reshuffle(con1, con2);
  79. PostProcessT postprocesst = new PostProcessT(con1, con2);
  80. UpdateRoot updateroot = new UpdateRoot(con1, con2);
  81. Eviction eviction = new Eviction(con1, con2);
  82. OutAccess outaccess = access.runC(md, ti, Li, timer);
  83. Tuple[] path = reshuffle.runC(predata[0], outaccess.C_P, ti == 0, timer);
  84. Tuple Ti = postprocesst.runC(predata[0], outaccess.C_Ti, Li, outaccess.C_Lip1, outaccess.C_j2,
  85. ti == md.getNumTrees() - 1, timer);
  86. Tuple[] root = Arrays.copyOfRange(path, 0, md.getStashSizeOfTree(ti));
  87. root = updateroot.runC(predata[0], ti == 0, root, Ti, timer);
  88. System.arraycopy(root, 0, path, 0, root.length);
  89. eviction.runC(predata[0], ti == 0, ti == 0 ? new Tuple[] { Ti } : path, md.getLBitsOfTree(ti) + 1,
  90. md.getStashSizeOfTree(ti), md.getW(), timer);
  91. // 2nd eviction
  92. byte[] Li2 = Util.nextBytes(md.getLBytesOfTree(ti), Crypto.sr);
  93. OutAccess outaccess2 = access.runC2(md, ti, Li2, timer);
  94. Tuple[] path2 = outaccess2.C_P;
  95. Tuple Ti2 = outaccess2.C_Ti;
  96. Tuple[] root2 = Arrays.copyOfRange(path2, 0, md.getStashSizeOfTree(ti));
  97. root2 = updateroot.runC(predata[1], ti == 0, root2, Ti2, timer);
  98. System.arraycopy(root2, 0, path2, 0, root2.length);
  99. eviction.runC(predata[1], ti == 0, ti == 0 ? new Tuple[] { Ti2 } : path2, md.getLBitsOfTree(ti) + 1,
  100. md.getStashSizeOfTree(ti), md.getW(), timer);
  101. return outaccess;
  102. }
  103. public Pipeline pipelineE(PreData[] predata, Tree OTi, byte[] Ni, byte[] Nip1_pr, int h, Timer[] timer) {
  104. Access access = new Access(con1, con2);
  105. OutAccess outaccess = access.runE(predata[0], OTi, Ni, Nip1_pr, timer[0]);
  106. int ti = OTi.getTreeIndex();
  107. Pipeline pipeline = new Pipeline(cons1[ti + 1], cons2[ti + 1], Party.Eddie, predata, OTi, h, timer[ti + 1],
  108. null, ti, outaccess.Li, outaccess);
  109. pipeline.start();
  110. return pipeline;
  111. }
  112. public Pipeline pipelineD(PreData predata[], Tree OTi, byte[] Ni, byte[] Nip1_pr, Timer[] timer) {
  113. Access access = new Access(con1, con2);
  114. byte[] Li = access.runD(predata[0], OTi, Ni, Nip1_pr, timer[0]);
  115. int ti = OTi.getTreeIndex();
  116. Pipeline pipeline = new Pipeline(cons1[ti + 1], cons2[ti + 1], Party.Debbie, predata, OTi, 0, timer[ti + 1],
  117. null, ti, Li, null);
  118. pipeline.start();
  119. return pipeline;
  120. }
  121. public OutRetrieve pipelineC(PreData[] predata, Metadata md, int ti, byte[] Li, Timer[] timer) {
  122. Access access = new Access(con1, con2);
  123. OutAccess outaccess = access.runC(md, ti, Li, timer[0]);
  124. Pipeline pipeline = new Pipeline(cons1[ti + 1], cons2[ti + 1], Party.Charlie, predata, null, 0, timer[ti + 1],
  125. md, ti, Li, outaccess);
  126. pipeline.start();
  127. return new OutRetrieve(outaccess, pipeline);
  128. }
  129. // for testing correctness
  130. @Override
  131. public void run(Party party, Metadata md, Forest forest) {
  132. if (Global.pipeline)
  133. System.out.println("Pipeline Mode is On");
  134. if (Global.cheat)
  135. System.out.println("Cheat Mode is On");
  136. int records = 30;
  137. int reset = 5;
  138. int repeat = 10;
  139. int tau = md.getTau();
  140. int numTrees = md.getNumTrees();
  141. long numInsert = md.getNumInsertRecords();
  142. int addrBits = md.getAddrBits();
  143. int numTimer = Global.pipeline ? numTrees + 1 : 1;
  144. Timer[] timer = new Timer[numTimer];
  145. for (int i = 0; i < numTimer; i++)
  146. timer[i] = new Timer();
  147. StopWatch ete_off = new StopWatch("ETE_offline");
  148. StopWatch ete_on = new StopWatch("ETE_online");
  149. long[] gates = new long[2];
  150. Pipeline[] threads = new Pipeline[numTrees];
  151. sanityCheck();
  152. System.out.println();
  153. for (int i = 0; i < records; i++) {
  154. long N = Global.cheat ? 0 : Util.nextLong(numInsert, Crypto.sr);
  155. for (int j = 0; j < repeat; j++) {
  156. int cycleIndex = i * repeat + j;
  157. if (cycleIndex == reset * repeat) {
  158. for (int k = 0; k < timer.length; k++)
  159. timer[k].reset();
  160. ete_on.reset();
  161. ete_off.reset();
  162. }
  163. if (cycleIndex == 1) {
  164. for (int k = 0; k < cons1.length; k++) {
  165. cons1[k].bandSwitch = false;
  166. cons2[k].bandSwitch = false;
  167. }
  168. }
  169. System.out.println("Test: " + i + " " + j);
  170. System.out.println("N=" + BigInteger.valueOf(N).toString(2));
  171. System.out.print("Precomputation... ");
  172. PreData[][] predata = new PreData[numTrees][2];
  173. PreRetrieve preretrieve = new PreRetrieve(con1, con2);
  174. for (int ti = 0; ti < numTrees; ti++) {
  175. predata[ti][0] = new PreData();
  176. predata[ti][1] = new PreData();
  177. if (party == Party.Eddie) {
  178. ete_off.start();
  179. preretrieve.runE(predata[ti], md, ti, timer[0]);
  180. ete_off.stop();
  181. } else if (party == Party.Debbie) {
  182. ete_off.start();
  183. long[] cnt = preretrieve.runD(predata[ti], md, ti, ti == 0 ? null : predata[ti - 1][0],
  184. timer[0]);
  185. ete_off.stop();
  186. if (cycleIndex == 0) {
  187. gates[0] += cnt[0];
  188. gates[1] += cnt[1];
  189. }
  190. } else if (party == Party.Charlie) {
  191. ete_off.start();
  192. preretrieve.runC(predata[ti], md, ti, ti == 0 ? null : predata[ti - 1][0], timer[0]);
  193. ete_off.stop();
  194. } else {
  195. throw new NoSuchPartyException(party + "");
  196. }
  197. }
  198. sanityCheck();
  199. System.out.println("done!");
  200. byte[] Li = new byte[0];
  201. for (int ti = 0; ti < numTrees; ti++) {
  202. long Ni_value = Util.getSubBits(N, addrBits, addrBits - md.getNBitsOfTree(ti));
  203. long Nip1_pr_value = Util.getSubBits(N, addrBits - md.getNBitsOfTree(ti),
  204. Math.max(addrBits - md.getNBitsOfTree(ti) - tau, 0));
  205. byte[] Ni = Util.longToBytes(Ni_value, md.getNBytesOfTree(ti));
  206. byte[] Nip1_pr = Util.longToBytes(Nip1_pr_value, (tau + 7) / 8);
  207. if (party == Party.Eddie) {
  208. Tree OTi = forest.getTree(ti);
  209. byte[] sE_Ni = Util.nextBytes(Ni.length, Crypto.sr);
  210. byte[] sD_Ni = Util.xor(Ni, sE_Ni);
  211. con1.write(sD_Ni);
  212. byte[] sE_Nip1_pr = Util.nextBytes(Nip1_pr.length, Crypto.sr);
  213. byte[] sD_Nip1_pr = Util.xor(Nip1_pr, sE_Nip1_pr);
  214. con1.write(sD_Nip1_pr);
  215. if (!Global.pipeline) {
  216. ete_on.start();
  217. runE(predata[ti], OTi, sE_Ni, sE_Nip1_pr, numTrees, timer[0]);
  218. ete_on.stop();
  219. } else {
  220. if (ti == 0)
  221. ete_on.start();
  222. threads[ti] = pipelineE(predata[ti], OTi, sE_Ni, sE_Nip1_pr, numTrees, timer);
  223. }
  224. if (ti == numTrees - 1)
  225. con2.write(N);
  226. } else if (party == Party.Debbie) {
  227. Tree OTi = forest.getTree(ti);
  228. byte[] sD_Ni = con1.read();
  229. byte[] sD_Nip1_pr = con1.read();
  230. if (!Global.pipeline) {
  231. ete_on.start();
  232. runD(predata[ti], OTi, sD_Ni, sD_Nip1_pr, timer[0]);
  233. ete_on.stop();
  234. } else {
  235. if (ti == 0)
  236. ete_on.start();
  237. threads[ti] = pipelineD(predata[ti], OTi, sD_Ni, sD_Nip1_pr, timer);
  238. }
  239. } else if (party == Party.Charlie) {
  240. int lBits = md.getLBitsOfTree(ti);
  241. System.out.println("L" + ti + "="
  242. + Util.addZeros(Util.getSubBits(new BigInteger(1, Li), lBits, 0).toString(2), lBits));
  243. OutAccess outaccess = null;
  244. if (!Global.pipeline) {
  245. ete_on.start();
  246. outaccess = runC(predata[ti], md, ti, Li, timer[0]);
  247. ete_on.stop();
  248. } else {
  249. if (ti == 0)
  250. ete_on.start();
  251. OutRetrieve outretrieve = pipelineC(predata[ti], md, ti, Li, timer);
  252. outaccess = outretrieve.outaccess;
  253. threads[ti] = outretrieve.pipeline;
  254. }
  255. Li = outaccess.C_Lip1;
  256. if (ti == numTrees - 1) {
  257. N = con1.readLong();
  258. long data = new BigInteger(1, outaccess.C_Ti.getA()).longValue();
  259. if (N == data) {
  260. System.out.println("Access passed");
  261. System.out.println();
  262. } else {
  263. throw new AccessException("Access failed");
  264. }
  265. }
  266. } else {
  267. throw new NoSuchPartyException(party + "");
  268. }
  269. }
  270. if (Global.pipeline) {
  271. for (int ti = 0; ti < numTrees; ti++) {
  272. try {
  273. threads[ti].join();
  274. } catch (InterruptedException e) {
  275. e.printStackTrace();
  276. }
  277. }
  278. ete_on.stop();
  279. }
  280. }
  281. }
  282. System.out.println();
  283. Timer sum = new Timer();
  284. for (int i = 0; i < timer.length; i++)
  285. sum = sum.add(timer[i]);
  286. sum.noPrePrint();
  287. System.out.println();
  288. StopWatch comEnc = new StopWatch("CE_online_comp");
  289. for (int i = 0; i < cons1.length; i++)
  290. comEnc = comEnc.add(cons1[i].comEnc.add(cons2[i].comEnc));
  291. System.out.println(comEnc.noPreToMS());
  292. System.out.println();
  293. if (Global.pipeline)
  294. ete_on.elapsedCPU = 0;
  295. System.out.println(ete_on.noPreToMS());
  296. System.out.println(ete_off.noPreToMS());
  297. System.out.println();
  298. Bandwidth[] bandwidth = new Bandwidth[P.size];
  299. for (int i = 0; i < P.size; i++) {
  300. bandwidth[i] = new Bandwidth(P.names[i]);
  301. for (int j = 0; j < cons1.length; j++)
  302. bandwidth[i] = bandwidth[i].add(cons1[j].bandwidth[i].add(cons2[j].bandwidth[i]));
  303. System.out.println(bandwidth[i].noPreToString());
  304. }
  305. System.out.println();
  306. System.out.println(gates[0]);
  307. System.out.println(gates[1]);
  308. System.out.println();
  309. sanityCheck();
  310. }
  311. }