PIRReshuffle.java 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. package pir;
  2. import java.math.BigInteger;
  3. import communication.Communication;
  4. import crypto.Crypto;
  5. import exceptions.NoSuchPartyException;
  6. import oram.Forest;
  7. import oram.Global;
  8. import oram.Metadata;
  9. import oram.Tree;
  10. import pir.precomputation.PrePIRReshuffle;
  11. import protocols.Protocol;
  12. import protocols.precomputation.PreAccess;
  13. import protocols.struct.OutAccess;
  14. import protocols.struct.Party;
  15. import protocols.struct.PreData;
  16. import util.M;
  17. import util.P;
  18. import util.Timer;
  19. import util.Util;
  20. public class PIRReshuffle extends Protocol {
  21. private int pid = P.RSF;
  22. public PIRReshuffle(Communication con1, Communication con2) {
  23. super(con1, con2);
  24. }
  25. public byte[][] runE(PreData predata, byte[][] path, boolean firstTree, Timer timer) {
  26. if (firstTree)
  27. return path;
  28. timer.start(pid, M.online_comp);
  29. // step 1
  30. timer.start(pid, M.online_read);
  31. byte[][] z = con2.readDoubleByteArray(pid);
  32. timer.stop(pid, M.online_read);
  33. // step 2
  34. byte[][] b = new byte[z.length][];
  35. for (int i = 0; i < b.length; i++)
  36. b[i] = Util.xor(Util.xor(path[i], z[i]), predata.pir_reshuffle_r[i]);
  37. byte[][] b_prime = Util.permute(b, predata.reshuffle_pi);
  38. timer.stop(pid, M.online_comp);
  39. return b_prime;
  40. }
  41. public void runD() {
  42. }
  43. public byte[][] runC(PreData predata, byte[][] path, boolean firstTree, Timer timer) {
  44. if (firstTree)
  45. return path;
  46. timer.start(pid, M.online_comp);
  47. // step 1
  48. byte[][] z = new byte[path.length][];
  49. for (int i = 0; i < z.length; i++)
  50. z[i] = Util.xor(path[i], predata.pir_reshuffle_p[i]);
  51. timer.start(pid, M.online_write);
  52. con1.write(pid, z);
  53. timer.stop(pid, M.online_write);
  54. timer.stop(pid, M.online_comp);
  55. return predata.pir_reshuffle_a_prime;
  56. }
  57. // for testing correctness
  58. @Override
  59. public void run(Party party, Metadata md, Forest forest) {
  60. int records = 5;
  61. int repeat = 5;
  62. int tau = md.getTau();
  63. int numTrees = md.getNumTrees();
  64. long numInsert = md.getNumInsertRecords();
  65. int addrBits = md.getAddrBits();
  66. Timer timer = new Timer();
  67. sanityCheck();
  68. System.out.println();
  69. for (int i = 0; i < records; i++) {
  70. long N = Global.cheat ? 0 : Util.nextLong(numInsert, Crypto.sr);
  71. for (int j = 0; j < repeat; j++) {
  72. System.out.println("Test: " + i + " " + j);
  73. System.out.println("N=" + BigInteger.valueOf(N).toString(2));
  74. byte[] Li = new byte[0];
  75. for (int ti = 0; ti < numTrees; ti++) {
  76. long Ni_value = Util.getSubBits(N, addrBits, addrBits - md.getNBitsOfTree(ti));
  77. long Nip1_pr_value = Util.getSubBits(N, addrBits - md.getNBitsOfTree(ti),
  78. Math.max(addrBits - md.getNBitsOfTree(ti) - tau, 0));
  79. byte[] Ni = Util.longToBytes(Ni_value, md.getNBytesOfTree(ti));
  80. byte[] Nip1_pr = Util.longToBytes(Nip1_pr_value, (tau + 7) / 8);
  81. PreData predata = new PreData();
  82. PreAccess preaccess = new PreAccess(con1, con2);
  83. PIRAccess access = new PIRAccess(con1, con2);
  84. PrePIRReshuffle prereshuffle = new PrePIRReshuffle(con1, con2);
  85. if (party == Party.Eddie) {
  86. Tree OTi = forest.getTree(ti);
  87. int numTuples = (OTi.getD() - 1) * OTi.getW() + OTi.getStashSize();
  88. int[] tupleParam = new int[] { ti == 0 ? 0 : 1, md.getNBytesOfTree(ti), md.getLBytesOfTree(ti),
  89. md.getABytesOfTree(ti) };
  90. preaccess.runE(predata, md.getTwoTauPow(), numTuples, tupleParam, timer);
  91. byte[] sE_Ni = Util.nextBytes(Ni.length, Crypto.sr);
  92. byte[] sD_Ni = Util.xor(Ni, sE_Ni);
  93. con1.write(sD_Ni);
  94. byte[] sE_Nip1_pr = Util.nextBytes(Nip1_pr.length, Crypto.sr);
  95. byte[] sD_Nip1_pr = Util.xor(Nip1_pr, sE_Nip1_pr);
  96. con1.write(sD_Nip1_pr);
  97. // TODO: fix commented line below
  98. OutAccess outaccess = null;
  99. // OutAccess outaccess = access.runE(predata, OTi, sE_Ni, sE_Nip1_pr, timer);
  100. if (ti == numTrees - 1)
  101. con2.write(N);
  102. prereshuffle.runE(predata, timer);
  103. byte[][] fbArray = new byte[outaccess.E_P.length][];
  104. for (int i1 = 0; i1 < fbArray.length; i1++)
  105. fbArray[i1] = outaccess.E_P[i1].getF().clone();
  106. byte[][] E_P_prime = runE(predata, fbArray, ti == 0, timer);
  107. byte[][] C_P = con2.readDoubleByteArray();
  108. byte[][] C_P_prime = con2.readDoubleByteArray();
  109. byte[][] oldPath = new byte[C_P.length][];
  110. byte[][] newPath = new byte[C_P.length][];
  111. for (int k = 0; k < C_P.length; k++) {
  112. oldPath[k] = Util.xor(outaccess.E_P[k].getF(), C_P[k]);
  113. newPath[k] = Util.xor(E_P_prime[k], C_P_prime[k]);
  114. }
  115. oldPath = Util.permute(oldPath, predata.reshuffle_pi);
  116. boolean pass = true;
  117. for (int k = 0; k < newPath.length; k++) {
  118. if (!Util.equal(oldPath[k], newPath[k])) {
  119. System.err.println("PIR Reshuffle test failed");
  120. pass = false;
  121. break;
  122. }
  123. }
  124. if (pass)
  125. System.out.println("PIR Reshuffle test passed");
  126. } else if (party == Party.Debbie) {
  127. Tree OTi = forest.getTree(ti);
  128. preaccess.runD(predata, timer);
  129. byte[] sD_Ni = con1.read();
  130. byte[] sD_Nip1_pr = con1.read();
  131. // TODO: fix commented line below
  132. // access.runD(predata, OTi, sD_Ni, sD_Nip1_pr, timer);
  133. int[] tupleParam = new int[] { ti == 0 ? 0 : 1, md.getNBytesOfTree(ti), md.getLBytesOfTree(ti),
  134. md.getABytesOfTree(ti) };
  135. prereshuffle.runD(predata, tupleParam, timer);
  136. runD();
  137. } else if (party == Party.Charlie) {
  138. Tree OTi = forest.getTree(ti);
  139. preaccess.runC(timer);
  140. System.out.println("L" + ti + "=" + new BigInteger(1, Li).toString(2));
  141. // TODO: fix commented line below
  142. OutAccess outaccess = null;
  143. // OutAccess outaccess = access.runC(md, OTi, ti, Li, timer);
  144. prereshuffle.runC(predata, timer);
  145. byte[][] fbArray = new byte[outaccess.C_P.length][];
  146. for (int i1 = 0; i1 < fbArray.length; i1++)
  147. fbArray[i1] = outaccess.C_P[i1].getF().clone();
  148. byte[][] C_P_prime = runC(predata, fbArray, ti == 0, timer);
  149. Li = outaccess.C_Lip1;
  150. if (ti == numTrees - 1) {
  151. N = con1.readLong();
  152. // long data = new BigInteger(1,
  153. // outaccess.C_Ti.getA()).longValue();
  154. // if (N == data) {
  155. // System.out.println("Access passed");
  156. // System.out.println();
  157. // } else {
  158. // throw new AccessException("Access failed");
  159. // }
  160. }
  161. con1.write(fbArray);
  162. con1.write(C_P_prime);
  163. } else {
  164. throw new NoSuchPartyException(party + "");
  165. }
  166. }
  167. }
  168. }
  169. // timer.print();
  170. }
  171. @Override
  172. public void run(Party party, Metadata md, Forest[] forest) {
  173. // TODO Auto-generated method stub
  174. }
  175. }