PIRRetrieve.java 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. package protocols;
  2. import java.math.BigInteger;
  3. import java.util.Arrays;
  4. import communication.Communication;
  5. import crypto.Crypto;
  6. import exceptions.NoSuchPartyException;
  7. import oram.Forest;
  8. import oram.Global;
  9. import oram.Metadata;
  10. import oram.Tree;
  11. import oram.Tuple;
  12. import struct.OutFF;
  13. import struct.OutPIRAccess;
  14. import struct.OutULiT;
  15. import struct.Party;
  16. import struct.TwoThreeXorByte;
  17. import struct.TwoThreeXorInt;
  18. import util.Bandwidth;
  19. import util.M;
  20. import util.StopWatch;
  21. import util.Timer;
  22. import util.Util;
  23. // TODO: really FlipFlag on path, and update path in Eviction
  24. // TODO: fix simulation
  25. public class PIRRetrieve extends Protocol {
  26. Communication[] cons1;
  27. Communication[] cons2;
  28. public PIRRetrieve(Communication con1, Communication con2) {
  29. super(con1, con2);
  30. online_band = new Bandwidth();
  31. offline_band = new Bandwidth();
  32. timer = new Timer();
  33. }
  34. public void setCons(Communication[] a, Communication[] b) {
  35. cons1 = a;
  36. cons2 = b;
  37. }
  38. public OutPIRAccess runE(Metadata md, Tree tree_DE, Tree tree_CE, byte[] Li, TwoThreeXorByte L, TwoThreeXorByte N,
  39. TwoThreeXorInt dN) {
  40. timer.start(M.online_comp);
  41. int treeIndex = tree_DE.getTreeIndex();
  42. boolean isLastTree = treeIndex == md.getNumTrees() - 1;
  43. boolean isFirstTree = treeIndex == 0;
  44. PIRAccess piracc = new PIRAccess(con1, con2);
  45. OutPIRAccess outpiracc = piracc.runE(md, tree_DE, tree_CE, Li, L, N, dN);
  46. OutULiT T = new OutULiT();
  47. if (!isLastTree) {
  48. TwoThreeXorByte Lp = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex));
  49. TwoThreeXorByte Lpi = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex + 1));
  50. ULiT ulit = new ULiT(con1, con2, Crypto.sr_DE, Crypto.sr_CE);
  51. T = ulit.runE(outpiracc.X, N, dN, Lp, Lpi, outpiracc.nextL, md.getTwoTauPow());
  52. } else {
  53. T.DE = outpiracc.pathTuples_DE[0];
  54. T.CE = outpiracc.pathTuples_CE[0];
  55. }
  56. int pathTuples = outpiracc.pathTuples_CE.length;
  57. if (!isFirstTree) {
  58. byte[][] fb_DE = new byte[pathTuples][];
  59. byte[][] fb_CE = new byte[pathTuples][];
  60. for (int i = 0; i < pathTuples; i++) {
  61. fb_DE[i] = outpiracc.pathTuples_DE[i].getF();
  62. fb_CE[i] = outpiracc.pathTuples_CE[i].getF();
  63. }
  64. FlipFlag ff = new FlipFlag(con1, con2);
  65. OutFF outff = ff.runE(fb_DE, fb_CE, outpiracc.j.s_DE);
  66. for (int i = 0; i < pathTuples; i++) {
  67. // outpiracc.pathTuples_DE[i].setF(outff.fb_DE[i]);
  68. // outpiracc.pathTuples_CE[i].setF(outff.fb_CE[i]);
  69. }
  70. }
  71. int stashSize = tree_DE.getStashSize();
  72. int[] tupleParam = new int[] { treeIndex == 0 ? 0 : 1, md.getNBytesOfTree(treeIndex),
  73. md.getLBytesOfTree(treeIndex), md.getABytesOfTree(treeIndex) };
  74. Tuple[] path = new Tuple[pathTuples];
  75. for (int i = 0; i < pathTuples; i++) {
  76. path[i] = outpiracc.pathTuples_DE[i].xor(outpiracc.pathTuples_CE[i]);
  77. }
  78. Tuple[] R = Arrays.copyOfRange(path, 0, stashSize);
  79. UpdateRoot updateroot = new UpdateRoot(con1, con2);
  80. R = updateroot.runE(isFirstTree, stashSize, md.getLBitsOfTree(treeIndex), tupleParam, Li, R, T.DE.xor(T.CE));
  81. System.arraycopy(R, 0, path, 0, R.length);
  82. Eviction eviction = new Eviction(con1, con2);
  83. eviction.runE(isFirstTree, tupleParam, Li, path, tree_DE);
  84. // simulation of Reshare
  85. timer.start(M.online_write);
  86. con2.write(online_band, path);
  87. timer.stop(M.online_write);
  88. timer.start(M.online_read);
  89. con2.readTupleArrayAndDec();
  90. timer.stop(M.online_read);
  91. // second eviction sim
  92. for (int i = 0; i < pathTuples; i++) {
  93. path[i] = outpiracc.pathTuples_DE[i].xor(outpiracc.pathTuples_CE[i]);
  94. }
  95. R = Arrays.copyOfRange(path, 0, stashSize);
  96. R = updateroot.runE(isFirstTree, stashSize, md.getLBitsOfTree(treeIndex), tupleParam, Li, R, T.DE.xor(T.CE));
  97. System.arraycopy(R, 0, path, 0, R.length);
  98. eviction.runE(isFirstTree, tupleParam, Li, path, tree_DE);
  99. // simulation of Reshare
  100. timer.start(M.online_write);
  101. con2.write(online_band, path);
  102. timer.stop(M.online_write);
  103. timer.start(M.online_read);
  104. con2.readTupleArrayAndDec();
  105. timer.stop(M.online_read);
  106. timer.stop(M.online_comp);
  107. return outpiracc;
  108. }
  109. public OutPIRAccess runD(Metadata md, Tree tree_DE, Tree tree_CD, byte[] Li, TwoThreeXorByte L, TwoThreeXorByte N,
  110. TwoThreeXorInt dN) {
  111. timer.start(M.online_comp);
  112. int treeIndex = tree_DE.getTreeIndex();
  113. boolean isLastTree = treeIndex == md.getNumTrees() - 1;
  114. boolean isFirstTree = treeIndex == 0;
  115. PIRAccess piracc = new PIRAccess(con1, con2);
  116. OutPIRAccess outpiracc = piracc.runD(md, tree_DE, tree_CD, Li, L, N, dN);
  117. OutULiT T = new OutULiT();
  118. if (!isLastTree) {
  119. TwoThreeXorByte Lp = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex));
  120. TwoThreeXorByte Lpi = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex + 1));
  121. ULiT ulit = new ULiT(con1, con2, Crypto.sr_DE, Crypto.sr_CD);
  122. T = ulit.runD(outpiracc.X, N, dN, Lp, Lpi, outpiracc.nextL, md.getTwoTauPow());
  123. } else {
  124. T.CD = outpiracc.pathTuples_CD[0];
  125. T.DE = outpiracc.pathTuples_DE[0];
  126. }
  127. int pathTuples = outpiracc.pathTuples_CD.length;
  128. if (!isFirstTree) {
  129. byte[][] fb_DE = new byte[pathTuples][];
  130. byte[][] fb_CD = new byte[pathTuples][];
  131. for (int i = 0; i < pathTuples; i++) {
  132. fb_DE[i] = outpiracc.pathTuples_DE[i].getF();
  133. fb_CD[i] = outpiracc.pathTuples_CD[i].getF();
  134. }
  135. FlipFlag ff = new FlipFlag(con1, con2);
  136. OutFF outff = ff.runD(fb_DE, fb_CD, outpiracc.j.s_DE);
  137. for (int i = 0; i < pathTuples; i++) {
  138. // outpiracc.pathTuples_DE[i].setF(outff.fb_DE[i]);
  139. // outpiracc.pathTuples_CD[i].setF(outff.fb_CD[i]);
  140. }
  141. }
  142. int stashSize = tree_DE.getStashSize();
  143. int[] tupleParam = new int[] { treeIndex == 0 ? 0 : 1, md.getNBytesOfTree(treeIndex),
  144. md.getLBytesOfTree(treeIndex), md.getABytesOfTree(treeIndex) };
  145. UpdateRoot updateroot = new UpdateRoot(con1, con2);
  146. updateroot.runD(isFirstTree, stashSize, md.getLBitsOfTree(treeIndex), tupleParam, Li, tree_DE.getW());
  147. Eviction eviction = new Eviction(con1, con2);
  148. eviction.runD(isFirstTree, tupleParam, Li, tree_DE);
  149. // second eviction sim
  150. updateroot.runD(isFirstTree, stashSize, md.getLBitsOfTree(treeIndex), tupleParam, Li, tree_DE.getW());
  151. eviction.runD(isFirstTree, tupleParam, Li, tree_DE);
  152. timer.stop(M.online_comp);
  153. return outpiracc;
  154. }
  155. public OutPIRAccess runC(Metadata md, Tree tree_CD, Tree tree_CE, byte[] Li, TwoThreeXorByte L, TwoThreeXorByte N,
  156. TwoThreeXorInt dN) {
  157. timer.start(M.online_comp);
  158. int treeIndex = tree_CE.getTreeIndex();
  159. boolean isLastTree = treeIndex == md.getNumTrees() - 1;
  160. boolean isFirstTree = treeIndex == 0;
  161. PIRAccess piracc = new PIRAccess(con1, con2);
  162. OutPIRAccess outpiracc = piracc.runC(md, tree_CD, tree_CE, Li, L, N, dN);
  163. OutULiT T = new OutULiT();
  164. if (!isLastTree) {
  165. TwoThreeXorByte Lp = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex));
  166. TwoThreeXorByte Lpi = new TwoThreeXorByte(md.getLBytesOfTree(treeIndex + 1));
  167. ULiT ulit = new ULiT(con1, con2, Crypto.sr_CE, Crypto.sr_CD);
  168. T = ulit.runC(outpiracc.X, N, dN, Lp, Lpi, outpiracc.nextL, md.getTwoTauPow());
  169. } else {
  170. T.CD = outpiracc.pathTuples_CD[0];
  171. T.CE = outpiracc.pathTuples_CE[0];
  172. }
  173. int pathTuples = outpiracc.pathTuples_CD.length;
  174. if (!isFirstTree) {
  175. byte[][] fb_CE = new byte[pathTuples][];
  176. byte[][] fb_CD = new byte[pathTuples][];
  177. for (int i = 0; i < pathTuples; i++) {
  178. fb_CE[i] = outpiracc.pathTuples_CE[i].getF();
  179. fb_CD[i] = outpiracc.pathTuples_CD[i].getF();
  180. }
  181. FlipFlag ff = new FlipFlag(con1, con2);
  182. OutFF outff = ff.runC(fb_CD, fb_CE, outpiracc.j.t_C);
  183. for (int i = 0; i < pathTuples; i++) {
  184. // outpiracc.pathTuples_CD[i].setF(outff.fb_CD[i]);
  185. // outpiracc.pathTuples_CE[i].setF(outff.fb_CE[i]);
  186. }
  187. }
  188. int stashSize = tree_CE.getStashSize();
  189. int[] tupleParam = new int[] { treeIndex == 0 ? 0 : 1, md.getNBytesOfTree(treeIndex),
  190. md.getLBytesOfTree(treeIndex), md.getABytesOfTree(treeIndex) };
  191. Tuple[] path = outpiracc.pathTuples_CD;
  192. Tuple[] R = Arrays.copyOfRange(path, 0, stashSize);
  193. UpdateRoot updateroot = new UpdateRoot(con1, con2);
  194. R = updateroot.runC(isFirstTree, tupleParam, R, T.CD);
  195. System.arraycopy(R, 0, path, 0, R.length);
  196. Eviction eviction = new Eviction(con1, con2);
  197. eviction.runC(isFirstTree, tupleParam, path, tree_CD.getD(), stashSize, tree_CD.getW());
  198. // simulation of Reshare
  199. timer.start(M.online_write);
  200. con1.write(online_band, path);
  201. timer.stop(M.online_write);
  202. timer.start(M.online_read);
  203. con1.readTupleArrayAndDec();
  204. timer.stop(M.online_read);
  205. // second eviction sim
  206. R = Arrays.copyOfRange(path, 0, stashSize);
  207. R = updateroot.runC(isFirstTree, tupleParam, R, T.CD);
  208. System.arraycopy(R, 0, path, 0, R.length);
  209. eviction.runC(isFirstTree, tupleParam, path, tree_CD.getD(), stashSize, tree_CD.getW());
  210. // simulation of Reshare
  211. timer.start(M.online_write);
  212. con1.write(online_band, path);
  213. timer.stop(M.online_write);
  214. timer.start(M.online_read);
  215. con1.readTupleArrayAndDec();
  216. timer.stop(M.online_read);
  217. timer.stop(M.online_comp);
  218. return outpiracc;
  219. }
  220. @Override
  221. public void run(Party party, Metadata md, Forest[] forest) {
  222. StopWatch ete = new StopWatch("ETE");
  223. Tree tree_CD = null;
  224. Tree tree_DE = null;
  225. Tree tree_CE = null;
  226. int iterations = 100;
  227. int reset = 20;
  228. for (int test = 0; test < iterations; test++) {
  229. if (test == reset) {
  230. timer.reset();
  231. ete.reset();
  232. }
  233. if (test == 1) {
  234. Global.bandSwitch = false;
  235. }
  236. for (int treeIndex = 0; treeIndex < md.getNumTrees(); treeIndex++) {
  237. if (party == Party.Eddie) {
  238. tree_DE = forest[0].getTree(treeIndex);
  239. tree_CE = forest[1].getTree(treeIndex);
  240. } else if (party == Party.Debbie) {
  241. tree_DE = forest[0].getTree(treeIndex);
  242. tree_CD = forest[1].getTree(treeIndex);
  243. } else if (party == Party.Charlie) {
  244. tree_CE = forest[0].getTree(treeIndex);
  245. tree_CD = forest[1].getTree(treeIndex);
  246. } else {
  247. throw new NoSuchPartyException(party + "");
  248. }
  249. int Llen = md.getLBytesOfTree(treeIndex);
  250. int Nlen = md.getNBytesOfTree(treeIndex);
  251. TwoThreeXorInt dN = new TwoThreeXorInt();
  252. TwoThreeXorByte N = new TwoThreeXorByte();
  253. N.CD = new byte[Nlen];
  254. N.DE = new byte[Nlen];
  255. N.CE = new byte[Nlen];
  256. TwoThreeXorByte L = new TwoThreeXorByte();
  257. L.CD = new byte[Llen];
  258. L.DE = new byte[Llen];
  259. L.CE = new byte[Llen];
  260. byte[] Li = new byte[Llen];
  261. if (party == Party.Eddie) {
  262. ete.start();
  263. OutPIRAccess out = this.runE(md, tree_DE, tree_CE, Li, L, N, dN);
  264. ete.stop();
  265. out.j.t_D = con1.readInt();
  266. out.j.t_C = con2.readInt();
  267. out.X.CD = con1.read();
  268. int pathTuples = out.pathTuples_CE.length;
  269. int index = (out.j.t_D + out.j.s_CE) % pathTuples;
  270. byte[] X = Util.xor(Util.xor(out.X.DE, out.X.CE), out.X.CD);
  271. boolean fail = false;
  272. if (index != 0) {
  273. System.err.println(test + " " + treeIndex + ": PIRAcc test failed on KSearch index");
  274. fail = true;
  275. }
  276. if (new BigInteger(1, X).intValue() != 0) {
  277. System.err.println(test + " " + treeIndex + ": PIRAcc test failed on 3ShiftPIR X");
  278. fail = true;
  279. }
  280. if (treeIndex < md.getNumTrees() - 1 && new BigInteger(1, out.Lip1).intValue() != 0) {
  281. System.err.println(test + " " + treeIndex + ": PIRAcc test failed on 3ShiftXorPIR Lip1");
  282. fail = true;
  283. }
  284. if (!fail)
  285. System.out.println(test + " " + treeIndex + ": PIRAcc test passed");
  286. } else if (party == Party.Debbie) {
  287. ete.start();
  288. OutPIRAccess out = this.runD(md, tree_DE, tree_CD, Li, L, N, dN);
  289. ete.stop();
  290. con1.write(out.j.t_D);
  291. con1.write(out.X.CD);
  292. } else if (party == Party.Charlie) {
  293. ete.start();
  294. OutPIRAccess out = this.runC(md, tree_CD, tree_CE, Li, L, N, dN);
  295. ete.stop();
  296. con1.write(out.j.t_C);
  297. } else {
  298. throw new NoSuchPartyException(party + "");
  299. }
  300. }
  301. }
  302. // Bandwidth total = new Bandwidth("Total Online");
  303. // for (int i = 0; i < P.size; i++) {
  304. // for (int j = 0; j < cons1.length; j++)
  305. // total.add(cons1[j].bandwidth[i].add(cons2[j].bandwidth[i]).bandwidth);
  306. // }
  307. // System.out.println(total.toString());
  308. // timer.divideBy(iterations - reset);
  309. // timer.print();
  310. System.out.println(ete.toMS());
  311. sanityCheck();
  312. }
  313. // for testing correctness
  314. @Override
  315. public void run(Party party, Metadata md, Forest forest) {
  316. }
  317. }