// Source file: PIRRetrieve.java (13 KB)
  package protocols;

  import java.math.BigInteger;
  import java.util.Arrays;

  import communication.Communication;
  import crypto.Crypto;
  import exceptions.NoSuchPartyException;
  import oram.Forest;
  import oram.Global;
  import oram.Metadata;
  import oram.Tree;
  import oram.Tuple;
  import struct.OutFF;
  import struct.OutPIRAccess;
  import struct.OutULiT;
  import struct.Party;
  import struct.TwoThreeXorByte;
  import struct.TwoThreeXorInt;
  import util.Bandwidth;
  import util.M;
  import util.P;
  import util.StopWatch;
  import util.Util;
  23. // TODO: really FlipFlag on path, and update path in Eviction
  24. // TODO: fix simulation
  25. public class PIRRetrieve extends Protocol {
  26. Communication[] cons1;
  27. Communication[] cons2;
  28. int pid = P.RTV;
  29. public PIRRetrieve(Communication con1, Communication con2) {
  30. super(con1, con2);
  31. online_band = all.online_band[pid];
  32. offline_band = all.offline_band[pid];
  33. timer = all.timer[pid];
  34. }
  35. public void setCons(Communication[] a, Communication[] b) {
  36. cons1 = a;
  37. cons2 = b;
  38. }
  39. public OutPIRAccess runE(Metadata md, Tree tree_DE, Tree tree_CE, byte[] Li, TwoThreeXorByte L, TwoThreeXorByte N,
  40. TwoThreeXorInt dN) {
  41. timer.start(M.online_comp);
  42. int treeIndex = tree_DE.getTreeIndex();
  43. boolean isLastTree = treeIndex == md.getNumTrees() - 1;
  44. boolean isFirstTree = treeIndex == 0;
  45. PIRAccess piracc = new PIRAccess(con1, con2);
  46. OutPIRAccess outpiracc = piracc.runE(md, tree_DE, tree_CE, Li, L, N, dN);
  47. OutULiT T = new OutULiT();
  48. if (!isLastTree) {
  49. TwoThreeXorByte Lp = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex));
  50. TwoThreeXorByte Lpi = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex + 1));
  51. ULiT ulit = new ULiT(con1, con2, Crypto.sr_DE, Crypto.sr_CE);
  52. T = ulit.runE(outpiracc.X, N, dN, Lp, Lpi, outpiracc.nextL, md.getTwoTauPow());
  53. } else {
  54. T.DE = outpiracc.pathTuples_DE[0];
  55. T.CE = outpiracc.pathTuples_CE[0];
  56. }
  57. int pathTuples = outpiracc.pathTuples_CE.length;
  58. if (!isFirstTree) {
  59. byte[][] fb_DE = new byte[pathTuples][];
  60. byte[][] fb_CE = new byte[pathTuples][];
  61. for (int i = 0; i < pathTuples; i++) {
  62. fb_DE[i] = outpiracc.pathTuples_DE[i].getF();
  63. fb_CE[i] = outpiracc.pathTuples_CE[i].getF();
  64. }
  65. FlipFlag ff = new FlipFlag(con1, con2);
  66. OutFF outff = ff.runE(fb_DE, fb_CE, outpiracc.j.s_DE);
  67. for (int i = 0; i < pathTuples; i++) {
  68. // outpiracc.pathTuples_DE[i].setF(outff.fb_DE[i]);
  69. // outpiracc.pathTuples_CE[i].setF(outff.fb_CE[i]);
  70. }
  71. }
  72. int stashSize = tree_DE.getStashSize();
  73. int[] tupleParam = new int[] { treeIndex == 0 ? 0 : 1, md.getNBytesOfTree(treeIndex),
  74. md.getLBytesOfTree(treeIndex), md.getABytesOfTree(treeIndex) };
  75. Tuple[] path = new Tuple[pathTuples];
  76. for (int i = 0; i < pathTuples; i++) {
  77. path[i] = outpiracc.pathTuples_DE[i].xor(outpiracc.pathTuples_CE[i]);
  78. }
  79. Tuple[] R = Arrays.copyOfRange(path, 0, stashSize);
  80. UpdateRoot updateroot = new UpdateRoot(con1, con2);
  81. R = updateroot.runE(isFirstTree, stashSize, md.getLBitsOfTree(treeIndex), tupleParam, Li, R, T.DE.xor(T.CE));
  82. System.arraycopy(R, 0, path, 0, R.length);
  83. Eviction eviction = new Eviction(con1, con2);
  84. eviction.runE(isFirstTree, tupleParam, Li, path, tree_DE);
  85. // simulation of Reshare
  86. timer.start(M.online_write);
  87. con2.write(online_band, path);
  88. timer.stop(M.online_write);
  89. timer.start(M.online_read);
  90. con2.readTupleArrayAndDec();
  91. timer.stop(M.online_read);
  92. // second eviction sim
  93. for (int i = 0; i < pathTuples; i++) {
  94. path[i] = outpiracc.pathTuples_DE[i].xor(outpiracc.pathTuples_CE[i]);
  95. }
  96. R = Arrays.copyOfRange(path, 0, stashSize);
  97. R = updateroot.runE(isFirstTree, stashSize, md.getLBitsOfTree(treeIndex), tupleParam, Li, R, T.DE.xor(T.CE));
  98. System.arraycopy(R, 0, path, 0, R.length);
  99. eviction.runE(isFirstTree, tupleParam, Li, path, tree_DE);
  100. // simulation of Reshare
  101. timer.start(M.online_write);
  102. con2.write(online_band, path);
  103. timer.stop(M.online_write);
  104. timer.start(M.online_read);
  105. con2.readTupleArrayAndDec();
  106. timer.stop(M.online_read);
  107. timer.stop(M.online_comp);
  108. return outpiracc;
  109. }
  110. public OutPIRAccess runD(Metadata md, Tree tree_DE, Tree tree_CD, byte[] Li, TwoThreeXorByte L, TwoThreeXorByte N,
  111. TwoThreeXorInt dN) {
  112. timer.start(M.online_comp);
  113. int treeIndex = tree_DE.getTreeIndex();
  114. boolean isLastTree = treeIndex == md.getNumTrees() - 1;
  115. boolean isFirstTree = treeIndex == 0;
  116. PIRAccess piracc = new PIRAccess(con1, con2);
  117. OutPIRAccess outpiracc = piracc.runD(md, tree_DE, tree_CD, Li, L, N, dN);
  118. OutULiT T = new OutULiT();
  119. if (!isLastTree) {
  120. TwoThreeXorByte Lp = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex));
  121. TwoThreeXorByte Lpi = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex + 1));
  122. ULiT ulit = new ULiT(con1, con2, Crypto.sr_DE, Crypto.sr_CD);
  123. T = ulit.runD(outpiracc.X, N, dN, Lp, Lpi, outpiracc.nextL, md.getTwoTauPow());
  124. } else {
  125. T.CD = outpiracc.pathTuples_CD[0];
  126. T.DE = outpiracc.pathTuples_DE[0];
  127. }
  128. int pathTuples = outpiracc.pathTuples_CD.length;
  129. if (!isFirstTree) {
  130. byte[][] fb_DE = new byte[pathTuples][];
  131. byte[][] fb_CD = new byte[pathTuples][];
  132. for (int i = 0; i < pathTuples; i++) {
  133. fb_DE[i] = outpiracc.pathTuples_DE[i].getF();
  134. fb_CD[i] = outpiracc.pathTuples_CD[i].getF();
  135. }
  136. FlipFlag ff = new FlipFlag(con1, con2);
  137. OutFF outff = ff.runD(fb_DE, fb_CD, outpiracc.j.s_DE);
  138. for (int i = 0; i < pathTuples; i++) {
  139. // outpiracc.pathTuples_DE[i].setF(outff.fb_DE[i]);
  140. // outpiracc.pathTuples_CD[i].setF(outff.fb_CD[i]);
  141. }
  142. }
  143. int stashSize = tree_DE.getStashSize();
  144. int[] tupleParam = new int[] { treeIndex == 0 ? 0 : 1, md.getNBytesOfTree(treeIndex),
  145. md.getLBytesOfTree(treeIndex), md.getABytesOfTree(treeIndex) };
  146. UpdateRoot updateroot = new UpdateRoot(con1, con2);
  147. updateroot.runD(isFirstTree, stashSize, md.getLBitsOfTree(treeIndex), tupleParam, Li, tree_DE.getW());
  148. Eviction eviction = new Eviction(con1, con2);
  149. eviction.runD(isFirstTree, tupleParam, Li, tree_DE);
  150. // second eviction sim
  151. updateroot.runD(isFirstTree, stashSize, md.getLBitsOfTree(treeIndex), tupleParam, Li, tree_DE.getW());
  152. eviction.runD(isFirstTree, tupleParam, Li, tree_DE);
  153. timer.stop(M.online_comp);
  154. return outpiracc;
  155. }
  156. public OutPIRAccess runC(Metadata md, Tree tree_CD, Tree tree_CE, byte[] Li, TwoThreeXorByte L, TwoThreeXorByte N,
  157. TwoThreeXorInt dN) {
  158. timer.start(M.online_comp);
  159. int treeIndex = tree_CE.getTreeIndex();
  160. boolean isLastTree = treeIndex == md.getNumTrees() - 1;
  161. boolean isFirstTree = treeIndex == 0;
  162. PIRAccess piracc = new PIRAccess(con1, con2);
  163. OutPIRAccess outpiracc = piracc.runC(md, tree_CD, tree_CE, Li, L, N, dN);
  164. OutULiT T = new OutULiT();
  165. if (!isLastTree) {
  166. TwoThreeXorByte Lp = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex));
  167. TwoThreeXorByte Lpi = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex + 1));
  168. ULiT ulit = new ULiT(con1, con2, Crypto.sr_CE, Crypto.sr_CD);
  169. T = ulit.runC(outpiracc.X, N, dN, Lp, Lpi, outpiracc.nextL, md.getTwoTauPow());
  170. } else {
  171. T.CD = outpiracc.pathTuples_CD[0];
  172. T.CE = outpiracc.pathTuples_CE[0];
  173. }
  174. int pathTuples = outpiracc.pathTuples_CD.length;
  175. if (!isFirstTree) {
  176. byte[][] fb_CE = new byte[pathTuples][];
  177. byte[][] fb_CD = new byte[pathTuples][];
  178. for (int i = 0; i < pathTuples; i++) {
  179. fb_CE[i] = outpiracc.pathTuples_CE[i].getF();
  180. fb_CD[i] = outpiracc.pathTuples_CD[i].getF();
  181. }
  182. FlipFlag ff = new FlipFlag(con1, con2);
  183. OutFF outff = ff.runC(fb_CD, fb_CE, outpiracc.j.t_C);
  184. for (int i = 0; i < pathTuples; i++) {
  185. // outpiracc.pathTuples_CD[i].setF(outff.fb_CD[i]);
  186. // outpiracc.pathTuples_CE[i].setF(outff.fb_CE[i]);
  187. }
  188. }
  189. int stashSize = tree_CE.getStashSize();
  190. int[] tupleParam = new int[] { treeIndex == 0 ? 0 : 1, md.getNBytesOfTree(treeIndex),
  191. md.getLBytesOfTree(treeIndex), md.getABytesOfTree(treeIndex) };
  192. Tuple[] path = outpiracc.pathTuples_CD;
  193. Tuple[] R = Arrays.copyOfRange(path, 0, stashSize);
  194. UpdateRoot updateroot = new UpdateRoot(con1, con2);
  195. R = updateroot.runC(isFirstTree, tupleParam, R, T.CD);
  196. System.arraycopy(R, 0, path, 0, R.length);
  197. Eviction eviction = new Eviction(con1, con2);
  198. eviction.runC(isFirstTree, tupleParam, path, tree_CD.getD(), stashSize, tree_CD.getW());
  199. // simulation of Reshare
  200. timer.start(M.online_write);
  201. con1.write(online_band, path);
  202. timer.stop(M.online_write);
  203. timer.start(M.online_read);
  204. con1.readTupleArrayAndDec();
  205. timer.stop(M.online_read);
  206. // second eviction sim
  207. R = Arrays.copyOfRange(path, 0, stashSize);
  208. R = updateroot.runC(isFirstTree, tupleParam, R, T.CD);
  209. System.arraycopy(R, 0, path, 0, R.length);
  210. eviction.runC(isFirstTree, tupleParam, path, tree_CD.getD(), stashSize, tree_CD.getW());
  211. // simulation of Reshare
  212. timer.start(M.online_write);
  213. con1.write(online_band, path);
  214. timer.stop(M.online_write);
  215. timer.start(M.online_read);
  216. con1.readTupleArrayAndDec();
  217. timer.stop(M.online_read);
  218. timer.stop(M.online_comp);
  219. return outpiracc;
  220. }
  221. @Override
  222. public void run(Party party, Metadata md, Forest[] forest) {
  223. StopWatch ete = new StopWatch("ETE");
  224. Tree tree_CD = null;
  225. Tree tree_DE = null;
  226. Tree tree_CE = null;
  227. int iterations = md.getIters()+2;
  228. int reset = 2;
  229. for (int test = 0; test < iterations; test++) {
  230. if (test == reset) {
  231. all.resetTime();
  232. ete.reset();
  233. }
  234. if (test == 1) {
  235. Global.bandSwitch = false;
  236. }
  237. for (int treeIndex = 0; treeIndex < md.getNumTrees(); treeIndex++) {
  238. if (party == Party.Eddie) {
  239. tree_DE = forest[0].getTree(treeIndex);
  240. tree_CE = forest[1].getTree(treeIndex);
  241. } else if (party == Party.Debbie) {
  242. tree_DE = forest[0].getTree(treeIndex);
  243. tree_CD = forest[1].getTree(treeIndex);
  244. } else if (party == Party.Charlie) {
  245. tree_CE = forest[0].getTree(treeIndex);
  246. tree_CD = forest[1].getTree(treeIndex);
  247. } else {
  248. throw new NoSuchPartyException(party + "");
  249. }
  250. int Llen = md.getLBytesOfTree(treeIndex);
  251. int Nlen = md.getNBytesOfTree(treeIndex);
  252. TwoThreeXorInt dN = new TwoThreeXorInt();
  253. TwoThreeXorByte N = new TwoThreeXorByte();
  254. N.CD = new byte[Nlen];
  255. N.DE = new byte[Nlen];
  256. N.CE = new byte[Nlen];
  257. TwoThreeXorByte L = new TwoThreeXorByte();
  258. L.CD = new byte[Llen];
  259. L.DE = new byte[Llen];
  260. L.CE = new byte[Llen];
  261. byte[] Li = new byte[Llen];
  262. if (party == Party.Eddie) {
  263. ete.start();
  264. OutPIRAccess out = this.runE(md, tree_DE, tree_CE, Li, L, N, dN);
  265. ete.stop();
  266. out.j.t_D = con1.readInt();
  267. out.j.t_C = con2.readInt();
  268. out.X.CD = con1.read();
  269. int pathTuples = out.pathTuples_CE.length;
  270. int index = (out.j.t_D + out.j.s_CE) % pathTuples;
  271. byte[] X = Util.xor(Util.xor(out.X.DE, out.X.CE), out.X.CD);
  272. boolean fail = false;
  273. if (index != 0) {
  274. System.err.println(test + " " + treeIndex + ": PIRAcc test failed on KSearch index");
  275. fail = true;
  276. }
  277. if (new BigInteger(1, X).intValue() != 0) {
  278. System.err.println(test + " " + treeIndex + ": PIRAcc test failed on 3ShiftPIR X");
  279. fail = true;
  280. }
  281. if (treeIndex < md.getNumTrees() - 1 && new BigInteger(1, out.Lip1).intValue() != 0) {
  282. System.err.println(test + " " + treeIndex + ": PIRAcc test failed on 3ShiftXorPIR Lip1");
  283. fail = true;
  284. }
  285. if (!fail)
  286. System.out.println(test + " " + treeIndex + ": PIRAcc test passed");
  287. } else if (party == Party.Debbie) {
  288. ete.start();
  289. OutPIRAccess out = this.runD(md, tree_DE, tree_CD, Li, L, N, dN);
  290. ete.stop();
  291. con1.write(out.j.t_D);
  292. con1.write(out.X.CD);
  293. } else if (party == Party.Charlie) {
  294. ete.start();
  295. OutPIRAccess out = this.runC(md, tree_CD, tree_CE, Li, L, N, dN);
  296. ete.stop();
  297. con1.write(out.j.t_C);
  298. } else {
  299. throw new NoSuchPartyException(party + "");
  300. }
  301. }
  302. }
  303. System.out.println();
  304. Bandwidth band_on = new Bandwidth("Total Online Band");
  305. Bandwidth band_off = new Bandwidth("Total Offline Band");
  306. for (int i = 0; i < P.size; i++) {
  307. System.out.println(all.offline_band[i].noPreToString());
  308. System.out.println(all.online_band[i].noPreToString());
  309. band_on.add(all.online_band[i]);
  310. band_off.add(all.offline_band[i]);
  311. }
  312. System.out.println(band_off.noPreToString());
  313. System.out.println(band_on.noPreToString());
  314. System.out.println();
  315. StopWatch time_on = new StopWatch("Total Online Time");
  316. StopWatch time_off = new StopWatch("Total Offline Time");
  317. for (int i = 0; i < P.size; i++) {
  318. for (int j = 0; j < 3; j++)
  319. time_on.add(all.timer[i].watches[j]);
  320. for (int j = 3; j < M.size; j++)
  321. time_off.add(all.timer[i].watches[j]);
  322. all.timer[i].noPrePrint();
  323. }
  324. System.out.println(time_off.noPreToMS());
  325. System.out.println(time_on.noPreToMS());
  326. System.out.println();
  327. System.out.println(ete.noPreToMS());
  328. System.out.println();
  329. sanityCheck();
  330. }
  331. @Override
  332. public void run(Party party, Metadata md, Forest forest) {
  333. }
  334. }