// PIRAccess.java
package pir;

import java.math.BigInteger;
import java.util.Arrays;

import org.apache.commons.lang3.ArrayUtils;

import communication.Communication;
import crypto.Crypto;
import exceptions.AccessException;
import exceptions.NoSuchPartyException;
import oram.Bucket;
import oram.Forest;
import oram.Global;
import oram.Metadata;
import oram.Tree;
import oram.Tuple;
import protocols.Protocol;
import protocols.precomputation.PreAccess;
import protocols.struct.OutAccess;
import protocols.struct.OutSSCOT;
import protocols.struct.OutSSIOT;
import protocols.struct.Party;
import protocols.struct.PreData;
import util.M;
import util.P;
import util.Timer;
import util.Util;
  26. public class PIRAccess extends Protocol {
  27. private int pid = P.ACC;
  28. public PIRAccess(Communication con1, Communication con2) {
  29. super(con1, con2);
  30. }
  31. public OutAccess runE(PreData predata, Tree OTi, byte[] Ni, byte[] Nip1_pr, Timer timer) {
  32. timer.start(pid, M.online_comp);
  33. // step 0: get Li from C
  34. byte[] Li = new byte[0];
  35. // timer.start(pid, M.online_read);
  36. if (OTi.getTreeIndex() > 0)
  37. Li = con2.read();
  38. // timer.stop(pid, M.online_read);
  39. // step 1
  40. Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
  41. Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
  42. // for (int i = 0; i < pathTuples.length; i++)
  43. // pathTuples[i].setXor(predata.access_p[i]);
  44. pathTuples = Util.permute(pathTuples, predata.access_sigma);
  45. // step 3
  46. // byte[] y = null;
  47. // if (OTi.getTreeIndex() == 0)
  48. // y = pathTuples[0].getA();
  49. // else if (OTi.getTreeIndex() < OTi.getH() - 1)
  50. // y = Util.nextBytes(OTi.getABytes(), Crypto.sr);
  51. // else
  52. // y = new byte[OTi.getABytes()];
  53. if (OTi.getTreeIndex() > 0) {
  54. byte[][] a = new byte[pathTuples.length][];
  55. // byte[][] m = new byte[pathTuples.length][];
  56. for (int i = 0; i < pathTuples.length; i++) {
  57. // m[i] = Util.xor(pathTuples[i].getA(), y);
  58. a[i] = ArrayUtils.addAll(pathTuples[i].getF(), pathTuples[i].getN());
  59. for (int j = 0; j < Ni.length; j++)
  60. a[i][a[i].length - 1 - j] ^= Ni[Ni.length - 1 - j];
  61. }
  62. PIRCOT sscot = new PIRCOT(con1, con2);
  63. sscot.runE(predata, a, timer);
  64. // int j1 = con2.readInt();
  65. // y = pathTuples[j1].getA();
  66. }
  67. // con2.write(y);
  68. // step 4
  69. if (OTi.getTreeIndex() < OTi.getH() - 1) {
  70. // int ySegBytes = y.length / OTi.getTwoTauPow();
  71. // byte[][] y_array = new byte[OTi.getTwoTauPow()][];
  72. // for (int i = 0; i < OTi.getTwoTauPow(); i++)
  73. // y_array[i] = Arrays.copyOfRange(y, i * ySegBytes, (i + 1) *
  74. // ySegBytes);
  75. PIRIOT ssiot = new PIRIOT(con1, con2);
  76. ssiot.runE(predata, OTi.getTwoTauPow(), Nip1_pr, timer);
  77. }
  78. // PIR
  79. int[] j = new int[] { 0, 0 };
  80. timer.start(pid, M.online_write);
  81. con1.write(pid, j);
  82. con2.write(pid, j);
  83. timer.stop(pid, M.online_write);
  84. timer.start(pid, M.online_read);
  85. con1.readIntArray(pid);
  86. j = con2.readIntArray(pid);
  87. timer.stop(pid, M.online_read);
  88. byte[] y = null;
  89. if (OTi.getTreeIndex() == 0)
  90. y = pathTuples[0].getA();
  91. else
  92. y = pathTuples[j[0]].getA();
  93. timer.start(pid, M.online_write);
  94. con1.write(pid, y);
  95. con2.write(pid, y);
  96. timer.stop(pid, M.online_write);
  97. timer.start(pid, M.online_read);
  98. con1.read(pid);
  99. con2.read(pid);
  100. timer.stop(pid, M.online_read);
  101. // step 5
  102. Tuple Ti = null;
  103. if (OTi.getTreeIndex() == 0)
  104. Ti = pathTuples[0];
  105. else
  106. Ti = new Tuple(new byte[1], Ni, Li, y);
  107. OutAccess outaccess = new OutAccess(Li, null, null, null, null, Ti, pathTuples);
  108. timer.stop(pid, M.online_comp);
  109. return outaccess;
  110. }
  111. public byte[] runD(PreData predata, Tree OTi, byte[] Ni, byte[] Nip1_pr, Timer timer) {
  112. timer.start(pid, M.online_comp);
  113. // step 0: get Li from C
  114. byte[] Li = new byte[0];
  115. // timer.start(pid, M.online_read);
  116. if (OTi.getTreeIndex() > 0)
  117. Li = con2.read();
  118. // timer.stop(pid, M.online_read);
  119. // step 1
  120. Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
  121. Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
  122. // for (int i = 0; i < pathTuples.length; i++)
  123. // pathTuples[i].setXor(predata.access_p[i]);
  124. pathTuples = Util.permute(pathTuples, predata.access_sigma);
  125. // step 2
  126. timer.start(pid, M.online_write);
  127. // con2.write(pid, pathTuples);
  128. con2.write(pid, Ni);
  129. timer.stop(pid, M.online_write);
  130. // step 3
  131. if (OTi.getTreeIndex() > 0) {
  132. byte[][] b = new byte[pathTuples.length][];
  133. for (int i = 0; i < pathTuples.length; i++) {
  134. b[i] = ArrayUtils.addAll(pathTuples[i].getF(), pathTuples[i].getN());
  135. b[i][0] ^= 1;
  136. for (int j = 0; j < Ni.length; j++)
  137. b[i][b[i].length - 1 - j] ^= Ni[Ni.length - 1 - j];
  138. }
  139. PIRCOT sscot = new PIRCOT(con1, con2);
  140. sscot.runD(predata, b, timer);
  141. }
  142. // step 4
  143. if (OTi.getTreeIndex() < OTi.getH() - 1) {
  144. PIRIOT ssiot = new PIRIOT(con1, con2);
  145. ssiot.runD(predata, Nip1_pr, timer);
  146. }
  147. // PIR
  148. int[] j = new int[] { 0, 0 };
  149. timer.start(pid, M.online_write);
  150. con1.write(pid, j);
  151. con2.write(pid, j);
  152. timer.stop(pid, M.online_write);
  153. timer.start(pid, M.online_read);
  154. con1.readIntArray(pid);
  155. j = con2.readIntArray(pid);
  156. timer.stop(pid, M.online_read);
  157. byte[] A = null;
  158. if (OTi.getTreeIndex() == 0)
  159. A = pathTuples[0].getA();
  160. else
  161. A = pathTuples[j[0]].getA();
  162. timer.start(pid, M.online_write);
  163. con1.write(pid, A);
  164. con2.write(pid, A);
  165. timer.stop(pid, M.online_write);
  166. timer.start(pid, M.online_read);
  167. con1.read(pid);
  168. con2.read(pid);
  169. timer.stop(pid, M.online_read);
  170. timer.stop(pid, M.online_comp);
  171. return Li;
  172. }
  173. public OutAccess runC(Metadata md, Tree OTi, int treeIndex, byte[] Li, Timer timer) {
  174. timer.start(pid, M.online_comp);
  175. // step 0: send Li to E and D
  176. // timer.start(pid, M.online_write);
  177. if (treeIndex > 0) {
  178. con1.write(Li);
  179. con2.write(Li);
  180. }
  181. // timer.stop(pid, M.online_write);
  182. // step 1
  183. Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
  184. Tuple[] originalTuples = Bucket.bucketsToTuples(pathBuckets);
  185. Tuple[] pathTuples = new Tuple[originalTuples.length];
  186. for (int i = 0; i < pathTuples.length; i++)
  187. pathTuples[i] = new Tuple(originalTuples[i]);
  188. // step 2
  189. timer.start(pid, M.online_read);
  190. // Tuple[] pathTuples = con2.readTupleArray(pid);
  191. byte[] Ni = con2.read(pid);
  192. timer.stop(pid, M.online_read);
  193. // for (int i = 0; i < pathTuples.length; i++)
  194. // System.out.println("AAAAA " + pathTuples[i]);
  195. // step 3
  196. int j1 = 0;
  197. byte[] z = null;
  198. if (treeIndex == 0) {
  199. z = pathTuples[0].getA().clone();
  200. } else {
  201. PIRCOT sscot = new PIRCOT(con1, con2);
  202. OutSSCOT je = sscot.runC(timer);
  203. j1 = je.t;
  204. byte[] d = pathTuples[j1].getA().clone();
  205. // z = Util.xor(je.m_t, d);
  206. z = d;
  207. // con1.write(j1);
  208. }
  209. // byte[] y = con1.read();
  210. // step 4
  211. int j2 = 0;
  212. byte[] Lip1 = null;
  213. if (treeIndex < md.getNumTrees() - 1) {
  214. PIRIOT ssiot = new PIRIOT(con1, con2);
  215. OutSSIOT jy = ssiot.runC(timer);
  216. // step 5
  217. j2 = jy.t;
  218. // int lSegBytes = md.getABytesOfTree(treeIndex) /
  219. // md.getTwoTauPow();
  220. // byte[] z_j2 = Arrays.copyOfRange(z, j2 * lSegBytes, (j2 + 1) *
  221. // lSegBytes);
  222. // Lip1 = Util.xor(jy.m_t, z_j2);
  223. // Lip1 = Arrays.copyOfRange(Util.xor(z, y), j2 * lSegBytes, (j2 +
  224. // 1) * lSegBytes);
  225. }
  226. // PIR
  227. int[] j = new int[] { j1, j2 };
  228. timer.start(pid, M.online_write);
  229. con1.write(pid, j);
  230. con2.write(pid, j);
  231. timer.stop(pid, M.online_write);
  232. timer.start(pid, M.online_read);
  233. con1.readIntArray(pid);
  234. con2.readIntArray(pid);
  235. timer.stop(pid, M.online_read);
  236. int ABytes = md.getABytesOfTree(treeIndex);
  237. byte[] A = new byte[ABytes];
  238. timer.start(pid, M.online_write);
  239. con1.write(pid, A);
  240. con2.write(pid, A);
  241. timer.stop(pid, M.online_write);
  242. timer.start(pid, M.online_read);
  243. byte[] y = con1.read(pid);
  244. con2.read(pid);
  245. timer.stop(pid, M.online_read);
  246. if (treeIndex < md.getNumTrees() - 1) {
  247. int lSegBytes = ABytes / md.getTwoTauPow();
  248. Lip1 = Arrays.copyOfRange(Util.xor(z, y), j2 * lSegBytes, (j2 + 1) * lSegBytes);
  249. }
  250. Tuple Ti = null;
  251. if (treeIndex == 0) {
  252. Ti = pathTuples[0];
  253. } else {
  254. Ti = new Tuple(new byte[] { 1 }, Ni, new byte[md.getLBytesOfTree(treeIndex)], z);
  255. pathTuples[j1].getF()[0] = (byte) (1 - pathTuples[j1].getF()[0]);
  256. Crypto.sr.nextBytes(pathTuples[j1].getN());
  257. Crypto.sr.nextBytes(pathTuples[j1].getL());
  258. Crypto.sr.nextBytes(pathTuples[j1].getA());
  259. }
  260. OutAccess outaccess = new OutAccess(Li, Lip1, Ti, pathTuples, j2, null, null);
  261. timer.stop(pid, M.online_comp);
  262. return outaccess;
  263. }
  264. public OutAccess runE2(Tree OTi, Timer timer) {
  265. timer.start(pid, M.online_comp);
  266. // step 0: get Li from C
  267. byte[] Li = new byte[0];
  268. // timer.start(pid, M.online_read);
  269. if (OTi.getTreeIndex() > 0)
  270. Li = con2.read();
  271. // timer.stop(pid, M.online_read);
  272. // step 1
  273. Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
  274. Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
  275. // step 5
  276. Tuple Ti = null;
  277. if (OTi.getTreeIndex() == 0)
  278. Ti = pathTuples[0];
  279. else {
  280. Ti = new Tuple(1, OTi.getNBytes(), OTi.getLBytes(), OTi.getABytes(), Crypto.sr);
  281. Ti.setF(new byte[1]);
  282. }
  283. OutAccess outaccess = new OutAccess(Li, null, null, null, null, Ti, pathTuples);
  284. timer.stop(pid, M.online_comp);
  285. return outaccess;
  286. }
  287. public byte[] runD2(Tree OTi, Timer timer) {
  288. timer.start(pid, M.online_comp);
  289. // step 0: get Li from C
  290. byte[] Li = new byte[0];
  291. // timer.start(pid, M.online_read);
  292. if (OTi.getTreeIndex() > 0)
  293. Li = con2.read();
  294. // timer.stop(pid, M.online_read);
  295. // step 1
  296. // Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
  297. // Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
  298. // step 2
  299. // timer.start(pid, M.online_write);
  300. // con2.write(pid, pathTuples);
  301. // timer.stop(pid, M.online_write);
  302. timer.stop(pid, M.online_comp);
  303. return Li;
  304. }
  305. public OutAccess runC2(Metadata md, Tree OTi, int treeIndex, byte[] Li, Timer timer) {
  306. timer.start(pid, M.online_comp);
  307. // step 0: send Li to E and D
  308. // timer.start(pid, M.online_write);
  309. if (treeIndex > 0) {
  310. con1.write(Li);
  311. con2.write(Li);
  312. }
  313. // timer.stop(pid, M.online_write);
  314. // step 1
  315. Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
  316. Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
  317. // step 2
  318. // timer.start(pid, M.online_read);
  319. // Tuple[] pathTuples = con2.readTupleArray(pid);
  320. // timer.stop(pid, M.online_read);
  321. // step 5
  322. Tuple Ti = null;
  323. if (treeIndex == 0) {
  324. Ti = pathTuples[0];
  325. } else {
  326. Ti = new Tuple(1, md.getNBytesOfTree(treeIndex), md.getLBytesOfTree(treeIndex),
  327. md.getABytesOfTree(treeIndex), Crypto.sr);
  328. Ti.setF(new byte[1]);
  329. }
  330. OutAccess outaccess = new OutAccess(Li, null, Ti, pathTuples, null, null, null);
  331. timer.stop(pid, M.online_comp);
  332. return outaccess;
  333. }
  334. // for testing correctness
  335. @Override
  336. public void run(Party party, Metadata md, Forest forest) {
  337. int records = 5;
  338. int repeat = 5;
  339. int tau = md.getTau();
  340. int numTrees = md.getNumTrees();
  341. long numInsert = md.getNumInsertRecords();
  342. int addrBits = md.getAddrBits();
  343. Timer timer = new Timer();
  344. sanityCheck();
  345. System.out.println();
  346. for (int i = 0; i < records; i++) {
  347. long N = Global.cheat ? 0 : Util.nextLong(numInsert, Crypto.sr);
  348. for (int j = 0; j < repeat; j++) {
  349. System.out.println("Test: " + i + " " + j);
  350. System.out.println("N=" + BigInteger.valueOf(N).toString(2));
  351. byte[] Li = new byte[0];
  352. for (int ti = 0; ti < numTrees; ti++) {
  353. long Ni_value = Util.getSubBits(N, addrBits, addrBits - md.getNBitsOfTree(ti));
  354. long Nip1_pr_value = Util.getSubBits(N, addrBits - md.getNBitsOfTree(ti),
  355. Math.max(addrBits - md.getNBitsOfTree(ti) - tau, 0));
  356. byte[] Ni = Util.longToBytes(Ni_value, md.getNBytesOfTree(ti));
  357. byte[] Nip1_pr = Util.longToBytes(Nip1_pr_value, (tau + 7) / 8);
  358. PreData predata = new PreData();
  359. PreAccess preaccess = new PreAccess(con1, con2);
  360. if (party == Party.Eddie) {
  361. Tree OTi = forest.getTree(ti);
  362. int numTuples = (OTi.getD() - 1) * OTi.getW() + OTi.getStashSize();
  363. int[] tupleParam = new int[] { ti == 0 ? 0 : 1, md.getNBytesOfTree(ti), md.getLBytesOfTree(ti),
  364. md.getABytesOfTree(ti) };
  365. preaccess.runE(predata, md.getTwoTauPow(), numTuples, tupleParam, timer);
  366. byte[] sE_Ni = Util.nextBytes(Ni.length, Crypto.sr);
  367. byte[] sD_Ni = Util.xor(Ni, sE_Ni);
  368. con1.write(sD_Ni);
  369. byte[] sE_Nip1_pr = Util.nextBytes(Nip1_pr.length, Crypto.sr);
  370. byte[] sD_Nip1_pr = Util.xor(Nip1_pr, sE_Nip1_pr);
  371. con1.write(sD_Nip1_pr);
  372. OutAccess outaccess = runE(predata, OTi, sE_Ni, sE_Nip1_pr, timer);
  373. if (ti == numTrees - 1) {
  374. con2.write(N);
  375. con2.write(outaccess.E_Ti);
  376. }
  377. } else if (party == Party.Debbie) {
  378. Tree OTi = forest.getTree(ti);
  379. preaccess.runD(predata, timer);
  380. byte[] sD_Ni = con1.read();
  381. byte[] sD_Nip1_pr = con1.read();
  382. runD(predata, OTi, sD_Ni, sD_Nip1_pr, timer);
  383. } else if (party == Party.Charlie) {
  384. Tree OTi = forest.getTree(ti);
  385. preaccess.runC(timer);
  386. System.out.println("L" + ti + "=" + new BigInteger(1, Li).toString(2));
  387. OutAccess outaccess = runC(md, OTi, ti, Li, timer);
  388. Li = outaccess.C_Lip1;
  389. if (ti == numTrees - 1) {
  390. N = con1.readLong();
  391. Tuple E_Ti = con1.readTuple();
  392. long data = new BigInteger(1, Util.xor(outaccess.C_Ti.getA(), E_Ti.getA())).longValue();
  393. if (N == data) {
  394. System.out.println("PIR Access passed");
  395. System.out.println();
  396. } else {
  397. throw new AccessException("PIR Access failed: " + N + " != " + data);
  398. }
  399. }
  400. } else {
  401. throw new NoSuchPartyException(party + "");
  402. }
  403. }
  404. }
  405. }
  406. // timer.print();
  407. }
  408. }