PostProcessT.java 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. package protocols;
  2. import java.math.BigInteger;
  3. import communication.Communication;
  4. import crypto.Crypto;
  5. import exceptions.AccessException;
  6. import exceptions.NoSuchPartyException;
  7. import oram.Forest;
  8. import oram.Metadata;
  9. import oram.Tree;
  10. import oram.Tuple;
  11. import protocols.precomputation.PreAccess;
  12. import protocols.precomputation.PrePostProcessT;
  13. import protocols.struct.OutAccess;
  14. import protocols.struct.Party;
  15. import protocols.struct.PreData;
  16. import util.M;
  17. import util.P;
  18. import util.Timer;
  19. import util.Util;
  20. public class PostProcessT extends Protocol {
  21. private int pid = P.PPT;
  22. public PostProcessT(Communication con1, Communication con2) {
  23. super(con1, con2);
  24. }
  25. public Tuple runE(PreData predata, Tuple Ti, boolean lastTree, Timer timer) {
  26. timer.start(pid, M.online_comp);
  27. if (lastTree) {
  28. Tuple out = new Tuple(Ti);
  29. Util.setXor(out.getL(), predata.ppt_Li);
  30. timer.stop(pid, M.online_comp);
  31. return out;
  32. }
  33. // step 1
  34. timer.start(pid, M.online_read);
  35. int delta = con2.readInt();
  36. timer.stop(pid, M.online_read);
  37. // step 3
  38. int twoTauPow = predata.ppt_s.length;
  39. byte[][] e = new byte[twoTauPow][];
  40. for (int i = 0; i < twoTauPow; i++)
  41. e[i] = predata.ppt_s[(i + delta) % twoTauPow];
  42. byte[] e_all = new byte[twoTauPow * e[0].length];
  43. for (int i = 0; i < twoTauPow; i++)
  44. System.arraycopy(e[i], 0, e_all, i * e[0].length, e[0].length);
  45. Tuple out = new Tuple(Ti);
  46. Util.setXor(out.getL(), predata.ppt_Li);
  47. Util.setXor(out.getA(), e_all);
  48. timer.stop(pid, M.online_comp);
  49. return out;
  50. }
  51. public void runD() {
  52. }
  53. public Tuple runC(PreData predata, Tuple Ti, byte[] Li, byte[] Lip1, int j2, boolean lastTree, Timer timer) {
  54. timer.start(pid, M.online_comp);
  55. if (lastTree) {
  56. Tuple out = new Tuple(Ti);
  57. Util.setXor(out.getL(), Util.xor(Li, predata.ppt_Li));
  58. timer.stop(pid, M.online_comp);
  59. return out;
  60. }
  61. // step 1
  62. int twoTauPow = predata.ppt_r.length;
  63. int delta = (predata.ppt_alpha - j2 + twoTauPow) % twoTauPow;
  64. timer.start(pid, M.online_write);
  65. con1.write(pid, delta);
  66. timer.stop(pid, M.online_write);
  67. // step 2
  68. byte[][] c = new byte[twoTauPow][];
  69. for (int i = 0; i < twoTauPow; i++)
  70. c[i] = predata.ppt_r[(i + delta) % twoTauPow];
  71. c[j2] = Util.xor(Util.xor(c[j2], Lip1), predata.ppt_Lip1);
  72. byte[] c_all = new byte[twoTauPow * Lip1.length];
  73. for (int i = 0; i < twoTauPow; i++)
  74. System.arraycopy(c[i], 0, c_all, i * Lip1.length, Lip1.length);
  75. Tuple out = new Tuple(Ti);
  76. Util.setXor(out.getL(), Util.xor(Li, predata.ppt_Li));
  77. Util.setXor(out.getA(), c_all);
  78. timer.stop(pid, M.online_comp);
  79. return out;
  80. }
  81. // for testing correctness
  82. @Override
  83. public void run(Party party, Metadata md, Forest forest) {
  84. int records = 5;
  85. int repeat = 5;
  86. int tau = md.getTau();
  87. int numTrees = md.getNumTrees();
  88. long numInsert = md.getNumInsertRecords();
  89. int addrBits = md.getAddrBits();
  90. Timer timer = new Timer();
  91. sanityCheck();
  92. System.out.println();
  93. for (int i = 0; i < records; i++) {
  94. long N = Metadata.cheat ? 0 : Util.nextLong(numInsert, Crypto.sr);
  95. for (int j = 0; j < repeat; j++) {
  96. System.out.println("Test: " + i + " " + j);
  97. System.out.println("N=" + BigInteger.valueOf(N).toString(2));
  98. byte[] Li = new byte[0];
  99. PreData prev = null;
  100. for (int ti = 0; ti < numTrees; ti++) {
  101. long Ni_value = Util.getSubBits(N, addrBits, addrBits - md.getNBitsOfTree(ti));
  102. long Nip1_pr_value = Util.getSubBits(N, addrBits - md.getNBitsOfTree(ti),
  103. Math.max(addrBits - md.getNBitsOfTree(ti) - tau, 0));
  104. byte[] Ni = Util.longToBytes(Ni_value, md.getNBytesOfTree(ti));
  105. byte[] Nip1_pr = Util.longToBytes(Nip1_pr_value, (tau + 7) / 8);
  106. PreData predata = new PreData();
  107. PreAccess preaccess = new PreAccess(con1, con2);
  108. Access access = new Access(con1, con2);
  109. PrePostProcessT prepostprocesst = new PrePostProcessT(con1, con2);
  110. if (party == Party.Eddie) {
  111. Tree OTi = forest.getTree(ti);
  112. int numTuples = (OTi.getD() - 1) * OTi.getW() + OTi.getStashSize();
  113. int[] tupleParam = new int[] { ti == 0 ? 0 : 1, md.getNBytesOfTree(ti), md.getLBytesOfTree(ti),
  114. md.getABytesOfTree(ti) };
  115. preaccess.runE(predata, md.getTwoTauPow(), numTuples, tupleParam, timer);
  116. byte[] sE_Ni = Util.nextBytes(Ni.length, Crypto.sr);
  117. byte[] sD_Ni = Util.xor(Ni, sE_Ni);
  118. con1.write(sD_Ni);
  119. byte[] sE_Nip1_pr = Util.nextBytes(Nip1_pr.length, Crypto.sr);
  120. byte[] sD_Nip1_pr = Util.xor(Nip1_pr, sE_Nip1_pr);
  121. con1.write(sD_Nip1_pr);
  122. OutAccess outaccess = access.runE(predata, OTi, sE_Ni, sE_Nip1_pr, timer);
  123. if (ti == numTrees - 1)
  124. con2.write(N);
  125. prepostprocesst.runE(predata, timer);
  126. Tuple Ti_prime = runE(predata, outaccess.E_Ti, ti == numTrees - 1, timer);
  127. Ti_prime.setXor(con2.readTuple());
  128. byte[] Li_prime = Util.xor(predata.ppt_Li, con2.read());
  129. byte[] Lip1_prime = Util.xor(predata.ppt_Lip1, con2.read());
  130. int j2 = con2.readInt();
  131. Tuple Ti = outaccess.E_Ti.xor(con2.readTuple());
  132. if (!Util.equal(Ti.getF(), Ti_prime.getF()))
  133. System.err.println("PPT test failed");
  134. else if (!Util.equal(Ti.getN(), Ti_prime.getN()))
  135. System.err.println("PPT test failed");
  136. else if (!Util.equal(Li_prime, Ti_prime.getL()))
  137. System.err.println("PPT test failed");
  138. else if (!Util.equal(Lip1_prime,
  139. Ti_prime.getSubA(j2 * Lip1_prime.length, (j2 + 1) * Lip1_prime.length)))
  140. System.err.println("PPT test failed");
  141. else
  142. System.out.println("PPT test passed");
  143. } else if (party == Party.Debbie) {
  144. Tree OTi = forest.getTree(ti);
  145. preaccess.runD(predata, timer);
  146. byte[] sD_Ni = con1.read();
  147. byte[] sD_Nip1_pr = con1.read();
  148. access.runD(predata, OTi, sD_Ni, sD_Nip1_pr, timer);
  149. prepostprocesst.runD(predata, prev, md.getLBytesOfTree(ti), md.getAlBytesOfTree(ti), tau,
  150. timer);
  151. runD();
  152. } else if (party == Party.Charlie) {
  153. preaccess.runC(timer);
  154. System.out.println("L" + ti + "=" + new BigInteger(1, Li).toString(2));
  155. OutAccess outaccess = access.runC(md, ti, Li, timer);
  156. prepostprocesst.runC(predata, prev, md.getLBytesOfTree(ti), md.getAlBytesOfTree(ti), timer);
  157. Tuple Ti_prime = runC(predata, outaccess.C_Ti, Li, outaccess.C_Lip1, outaccess.C_j2,
  158. ti == numTrees - 1, timer);
  159. Li = outaccess.C_Lip1;
  160. if (ti == numTrees - 1) {
  161. N = con1.readLong();
  162. long data = new BigInteger(1, outaccess.C_Ti.getA()).longValue();
  163. if (N == data) {
  164. System.out.println("Access passed");
  165. System.out.println();
  166. } else {
  167. throw new AccessException("Access failed");
  168. }
  169. }
  170. con1.write(Ti_prime);
  171. con1.write(predata.ppt_Li);
  172. con1.write(predata.ppt_Lip1);
  173. con1.write(outaccess.C_j2);
  174. con1.write(outaccess.C_Ti);
  175. } else {
  176. throw new NoSuchPartyException(party + "");
  177. }
  178. prev = predata;
  179. }
  180. }
  181. }
  182. // timer.print();
  183. }
  184. }