PIRRetrieve.java 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. package pir;
  2. import java.math.BigInteger;
  3. import java.util.Arrays;
  4. import communication.Communication;
  5. import crypto.Crypto;
  6. import exceptions.NoSuchPartyException;
  7. import oram.Forest;
  8. import oram.Global;
  9. import oram.Metadata;
  10. import oram.Tree;
  11. import oram.Tuple;
  12. import protocols.Eviction;
  13. import protocols.Protocol;
  14. import protocols.UpdateRoot;
  15. import protocols.struct.OutFF;
  16. import protocols.struct.OutPIRAccess;
  17. import protocols.struct.OutULiT;
  18. import protocols.struct.Party;
  19. import protocols.struct.TwoThreeXorByte;
  20. import protocols.struct.TwoThreeXorInt;
  21. import util.M;
  22. import util.StopWatch;
  23. import util.Util;
  24. // TODO: really FlipFlag on path, and update path in Eviction
  25. // TODO: fix simulation
  26. public class PIRRetrieve extends Protocol {
  27. Communication[] cons1;
  28. Communication[] cons2;
  29. public PIRRetrieve(Communication con1, Communication con2) {
  30. super(con1, con2);
  31. }
  32. public void setCons(Communication[] a, Communication[] b) {
  33. cons1 = a;
  34. cons2 = b;
  35. }
  36. public OutPIRAccess runE(Metadata md, Tree tree_DE, Tree tree_CE, byte[] Li, TwoThreeXorByte L, TwoThreeXorByte N,
  37. TwoThreeXorInt dN) {
  38. timer.start(M.online_comp);
  39. int treeIndex = tree_DE.getTreeIndex();
  40. boolean isLastTree = treeIndex == md.getNumTrees() - 1;
  41. boolean isFirstTree = treeIndex == 0;
  42. PIRAccess piracc = new PIRAccess(con1, con2);
  43. OutPIRAccess outpiracc = piracc.runE(md, tree_DE, tree_CE, Li, L, N, dN);
  44. OutULiT T = new OutULiT();
  45. if (!isLastTree) {
  46. TwoThreeXorByte Lp = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex));
  47. TwoThreeXorByte Lpi = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex + 1));
  48. ULiT ulit = new ULiT(con1, con2, Crypto.sr_DE, Crypto.sr_CE);
  49. T = ulit.runE(outpiracc.X, N, dN, Lp, Lpi, outpiracc.nextL, md.getTwoTauPow());
  50. } else {
  51. T.DE = outpiracc.pathTuples_DE[0];
  52. T.CE = outpiracc.pathTuples_CE[0];
  53. }
  54. int pathTuples = outpiracc.pathTuples_CE.length;
  55. if (!isFirstTree) {
  56. byte[][] fb_DE = new byte[pathTuples][];
  57. byte[][] fb_CE = new byte[pathTuples][];
  58. for (int i = 0; i < pathTuples; i++) {
  59. fb_DE[i] = outpiracc.pathTuples_DE[i].getF();
  60. fb_CE[i] = outpiracc.pathTuples_CE[i].getF();
  61. }
  62. FlipFlag ff = new FlipFlag(con1, con2);
  63. OutFF outff = ff.runE(fb_DE, fb_CE, outpiracc.j.s_DE);
  64. for (int i = 0; i < pathTuples; i++) {
  65. // outpiracc.pathTuples_DE[i].setF(outff.fb_DE[i]);
  66. // outpiracc.pathTuples_CE[i].setF(outff.fb_CE[i]);
  67. }
  68. }
  69. int stashSize = tree_DE.getStashSize();
  70. int[] tupleParam = new int[] { treeIndex == 0 ? 0 : 1, md.getNBytesOfTree(treeIndex),
  71. md.getLBytesOfTree(treeIndex), md.getABytesOfTree(treeIndex) };
  72. Tuple[] path = new Tuple[pathTuples];
  73. for (int i = 0; i < pathTuples; i++) {
  74. path[i] = outpiracc.pathTuples_DE[i].xor(outpiracc.pathTuples_CE[i]);
  75. }
  76. Tuple[] R = Arrays.copyOfRange(path, 0, stashSize);
  77. UpdateRoot updateroot = new UpdateRoot(con1, con2);
  78. R = updateroot.runE(isFirstTree, stashSize, md.getLBitsOfTree(treeIndex), tupleParam, Li, R, T.DE.xor(T.CE));
  79. System.arraycopy(R, 0, path, 0, R.length);
  80. Eviction eviction = new Eviction(con1, con2);
  81. eviction.runE(isFirstTree, tupleParam, Li, path, tree_DE);
  82. // simulation of Reshare
  83. timer.start(M.online_write);
  84. con2.write(online_band, path);
  85. timer.stop(M.online_write);
  86. timer.start(M.online_read);
  87. con2.readTupleArrayAndDec();
  88. timer.stop(M.online_read);
  89. // second eviction sim
  90. for (int i = 0; i < pathTuples; i++) {
  91. path[i] = outpiracc.pathTuples_DE[i].xor(outpiracc.pathTuples_CE[i]);
  92. }
  93. R = Arrays.copyOfRange(path, 0, stashSize);
  94. R = updateroot.runE(isFirstTree, stashSize, md.getLBitsOfTree(treeIndex), tupleParam, Li, R, T.DE.xor(T.CE));
  95. System.arraycopy(R, 0, path, 0, R.length);
  96. eviction.runE(isFirstTree, tupleParam, Li, path, tree_DE);
  97. // simulation of Reshare
  98. timer.start(M.online_write);
  99. con2.write(online_band, path);
  100. timer.stop(M.online_write);
  101. timer.start(M.online_read);
  102. con2.readTupleArrayAndDec();
  103. timer.stop(M.online_read);
  104. timer.stop(M.online_comp);
  105. return outpiracc;
  106. }
  107. public OutPIRAccess runD(Metadata md, Tree tree_DE, Tree tree_CD, byte[] Li, TwoThreeXorByte L, TwoThreeXorByte N,
  108. TwoThreeXorInt dN) {
  109. timer.start(M.online_comp);
  110. int treeIndex = tree_DE.getTreeIndex();
  111. boolean isLastTree = treeIndex == md.getNumTrees() - 1;
  112. boolean isFirstTree = treeIndex == 0;
  113. PIRAccess piracc = new PIRAccess(con1, con2);
  114. OutPIRAccess outpiracc = piracc.runD(md, tree_DE, tree_CD, Li, L, N, dN);
  115. OutULiT T = new OutULiT();
  116. if (!isLastTree) {
  117. TwoThreeXorByte Lp = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex));
  118. TwoThreeXorByte Lpi = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex + 1));
  119. ULiT ulit = new ULiT(con1, con2, Crypto.sr_DE, Crypto.sr_CD);
  120. T = ulit.runD(outpiracc.X, N, dN, Lp, Lpi, outpiracc.nextL, md.getTwoTauPow());
  121. } else {
  122. T.CD = outpiracc.pathTuples_CD[0];
  123. T.DE = outpiracc.pathTuples_DE[0];
  124. }
  125. int pathTuples = outpiracc.pathTuples_CD.length;
  126. if (!isFirstTree) {
  127. byte[][] fb_DE = new byte[pathTuples][];
  128. byte[][] fb_CD = new byte[pathTuples][];
  129. for (int i = 0; i < pathTuples; i++) {
  130. fb_DE[i] = outpiracc.pathTuples_DE[i].getF();
  131. fb_CD[i] = outpiracc.pathTuples_CD[i].getF();
  132. }
  133. FlipFlag ff = new FlipFlag(con1, con2);
  134. OutFF outff = ff.runD(fb_DE, fb_CD, outpiracc.j.s_DE);
  135. for (int i = 0; i < pathTuples; i++) {
  136. // outpiracc.pathTuples_DE[i].setF(outff.fb_DE[i]);
  137. // outpiracc.pathTuples_CD[i].setF(outff.fb_CD[i]);
  138. }
  139. }
  140. int stashSize = tree_DE.getStashSize();
  141. int[] tupleParam = new int[] { treeIndex == 0 ? 0 : 1, md.getNBytesOfTree(treeIndex),
  142. md.getLBytesOfTree(treeIndex), md.getABytesOfTree(treeIndex) };
  143. UpdateRoot updateroot = new UpdateRoot(con1, con2);
  144. updateroot.runD(isFirstTree, stashSize, md.getLBitsOfTree(treeIndex), tupleParam, Li, tree_DE.getW());
  145. Eviction eviction = new Eviction(con1, con2);
  146. eviction.runD(isFirstTree, tupleParam, Li, tree_DE);
  147. // second eviction sim
  148. updateroot.runD(isFirstTree, stashSize, md.getLBitsOfTree(treeIndex), tupleParam, Li, tree_DE.getW());
  149. eviction.runD(isFirstTree, tupleParam, Li, tree_DE);
  150. timer.stop(M.online_comp);
  151. return outpiracc;
  152. }
  153. public OutPIRAccess runC(Metadata md, Tree tree_CD, Tree tree_CE, byte[] Li, TwoThreeXorByte L, TwoThreeXorByte N,
  154. TwoThreeXorInt dN) {
  155. timer.start(M.online_comp);
  156. int treeIndex = tree_CE.getTreeIndex();
  157. boolean isLastTree = treeIndex == md.getNumTrees() - 1;
  158. boolean isFirstTree = treeIndex == 0;
  159. PIRAccess piracc = new PIRAccess(con1, con2);
  160. OutPIRAccess outpiracc = piracc.runC(md, tree_CD, tree_CE, Li, L, N, dN);
  161. OutULiT T = new OutULiT();
  162. if (!isLastTree) {
  163. TwoThreeXorByte Lp = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex));
  164. TwoThreeXorByte Lpi = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex + 1));
  165. ULiT ulit = new ULiT(con1, con2, Crypto.sr_CE, Crypto.sr_CD);
  166. T = ulit.runC(outpiracc.X, N, dN, Lp, Lpi, outpiracc.nextL, md.getTwoTauPow());
  167. } else {
  168. T.CD = outpiracc.pathTuples_CD[0];
  169. T.CE = outpiracc.pathTuples_CE[0];
  170. }
  171. int pathTuples = outpiracc.pathTuples_CD.length;
  172. if (!isFirstTree) {
  173. byte[][] fb_CE = new byte[pathTuples][];
  174. byte[][] fb_CD = new byte[pathTuples][];
  175. for (int i = 0; i < pathTuples; i++) {
  176. fb_CE[i] = outpiracc.pathTuples_CE[i].getF();
  177. fb_CD[i] = outpiracc.pathTuples_CD[i].getF();
  178. }
  179. FlipFlag ff = new FlipFlag(con1, con2);
  180. OutFF outff = ff.runC(fb_CD, fb_CE, outpiracc.j.t_C);
  181. for (int i = 0; i < pathTuples; i++) {
  182. // outpiracc.pathTuples_CD[i].setF(outff.fb_CD[i]);
  183. // outpiracc.pathTuples_CE[i].setF(outff.fb_CE[i]);
  184. }
  185. }
  186. int stashSize = tree_CE.getStashSize();
  187. int[] tupleParam = new int[] { treeIndex == 0 ? 0 : 1, md.getNBytesOfTree(treeIndex),
  188. md.getLBytesOfTree(treeIndex), md.getABytesOfTree(treeIndex) };
  189. Tuple[] path = outpiracc.pathTuples_CD;
  190. Tuple[] R = Arrays.copyOfRange(path, 0, stashSize);
  191. UpdateRoot updateroot = new UpdateRoot(con1, con2);
  192. R = updateroot.runC(isFirstTree, tupleParam, R, T.CD);
  193. System.arraycopy(R, 0, path, 0, R.length);
  194. Eviction eviction = new Eviction(con1, con2);
  195. eviction.runC(isFirstTree, tupleParam, path, tree_CD.getD(), stashSize, tree_CD.getW());
  196. // simulation of Reshare
  197. timer.start(M.online_write);
  198. con1.write(online_band, path);
  199. timer.stop(M.online_write);
  200. timer.start(M.online_read);
  201. con1.readTupleArrayAndDec();
  202. timer.stop(M.online_read);
  203. // second eviction sim
  204. R = Arrays.copyOfRange(path, 0, stashSize);
  205. R = updateroot.runC(isFirstTree, tupleParam, R, T.CD);
  206. System.arraycopy(R, 0, path, 0, R.length);
  207. eviction.runC(isFirstTree, tupleParam, path, tree_CD.getD(), stashSize, tree_CD.getW());
  208. // simulation of Reshare
  209. timer.start(M.online_write);
  210. con1.write(online_band, path);
  211. timer.stop(M.online_write);
  212. timer.start(M.online_read);
  213. con1.readTupleArrayAndDec();
  214. timer.stop(M.online_read);
  215. timer.stop(M.online_comp);
  216. return outpiracc;
  217. }
  218. @Override
  219. public void run(Party party, Metadata md, Forest[] forest) {
  220. StopWatch ete = new StopWatch("ETE");
  221. Tree tree_CD = null;
  222. Tree tree_DE = null;
  223. Tree tree_CE = null;
  224. int iterations = 100;
  225. int reset = 20;
  226. for (int test = 0; test < iterations; test++) {
  227. if (test == reset) {
  228. timer.reset();
  229. ete.reset();
  230. }
  231. if (test == 1) {
  232. Global.bandSwitch = false;
  233. }
  234. for (int treeIndex = 0; treeIndex < md.getNumTrees(); treeIndex++) {
  235. if (party == Party.Eddie) {
  236. tree_DE = forest[0].getTree(treeIndex);
  237. tree_CE = forest[1].getTree(treeIndex);
  238. } else if (party == Party.Debbie) {
  239. tree_DE = forest[0].getTree(treeIndex);
  240. tree_CD = forest[1].getTree(treeIndex);
  241. } else if (party == Party.Charlie) {
  242. tree_CE = forest[0].getTree(treeIndex);
  243. tree_CD = forest[1].getTree(treeIndex);
  244. } else {
  245. throw new NoSuchPartyException(party + "");
  246. }
  247. int Llen = md.getLBytesOfTree(treeIndex);
  248. int Nlen = md.getNBytesOfTree(treeIndex);
  249. TwoThreeXorInt dN = new TwoThreeXorInt();
  250. TwoThreeXorByte N = new TwoThreeXorByte();
  251. N.CD = new byte[Nlen];
  252. N.DE = new byte[Nlen];
  253. N.CE = new byte[Nlen];
  254. TwoThreeXorByte L = new TwoThreeXorByte();
  255. L.CD = new byte[Llen];
  256. L.DE = new byte[Llen];
  257. L.CE = new byte[Llen];
  258. byte[] Li = new byte[Llen];
  259. if (party == Party.Eddie) {
  260. ete.start();
  261. OutPIRAccess out = this.runE(md, tree_DE, tree_CE, Li, L, N, dN);
  262. ete.stop();
  263. out.j.t_D = con1.readInt();
  264. out.j.t_C = con2.readInt();
  265. out.X.CD = con1.read();
  266. int pathTuples = out.pathTuples_CE.length;
  267. int index = (out.j.t_D + out.j.s_CE) % pathTuples;
  268. byte[] X = Util.xor(Util.xor(out.X.DE, out.X.CE), out.X.CD);
  269. boolean fail = false;
  270. if (index != 0) {
  271. System.err.println(test + " " + treeIndex + ": PIRAcc test failed on KSearch index");
  272. fail = true;
  273. }
  274. if (new BigInteger(1, X).intValue() != 0) {
  275. System.err.println(test + " " + treeIndex + ": PIRAcc test failed on 3ShiftPIR X");
  276. fail = true;
  277. }
  278. if (treeIndex < md.getNumTrees() - 1 && new BigInteger(1, out.Lip1).intValue() != 0) {
  279. System.err.println(test + " " + treeIndex + ": PIRAcc test failed on 3ShiftXorPIR Lip1");
  280. fail = true;
  281. }
  282. if (!fail)
  283. System.out.println(test + " " + treeIndex + ": PIRAcc test passed");
  284. } else if (party == Party.Debbie) {
  285. ete.start();
  286. OutPIRAccess out = this.runD(md, tree_DE, tree_CD, Li, L, N, dN);
  287. ete.stop();
  288. con1.write(out.j.t_D);
  289. con1.write(out.X.CD);
  290. } else if (party == Party.Charlie) {
  291. ete.start();
  292. OutPIRAccess out = this.runC(md, tree_CD, tree_CE, Li, L, N, dN);
  293. ete.stop();
  294. con1.write(out.j.t_C);
  295. } else {
  296. throw new NoSuchPartyException(party + "");
  297. }
  298. }
  299. }
  300. // Bandwidth total = new Bandwidth("Total Online");
  301. // for (int i = 0; i < P.size; i++) {
  302. // for (int j = 0; j < cons1.length; j++)
  303. // total.add(cons1[j].bandwidth[i].add(cons2[j].bandwidth[i]).bandwidth);
  304. // }
  305. // System.out.println(total.toString());
  306. // timer.divideBy(iterations - reset);
  307. // timer.print();
  308. System.out.println(ete.toMS());
  309. sanityCheck();
  310. }
  311. // for testing correctness
  312. @Override
  313. public void run(Party party, Metadata md, Forest forest) {
  314. }
  315. }