Reshuffle.java 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. package protocols;
  2. import java.math.BigInteger;
  3. import communication.Communication;
  4. import crypto.Crypto;
  5. import exceptions.AccessException;
  6. import exceptions.NoSuchPartyException;
  7. import oram.Forest;
  8. import oram.Global;
  9. import oram.Metadata;
  10. import oram.Tree;
  11. import oram.Tuple;
  12. import protocols.precomputation.PreAccess;
  13. import protocols.precomputation.PreReshuffle;
  14. import protocols.struct.OutAccess;
  15. import protocols.struct.Party;
  16. import protocols.struct.PreData;
  17. import util.M;
  18. import util.P;
  19. import util.Timer;
  20. import util.Util;
  21. public class Reshuffle extends Protocol {
  22. private int pid = P.RSF;
  23. public Reshuffle(Communication con1, Communication con2) {
  24. super(con1, con2);
  25. }
  26. public Tuple[] runE(PreData predata, Tuple[] path, boolean firstTree, Timer timer) {
  27. if (firstTree)
  28. return path;
  29. timer.start(pid, M.online_comp);
  30. // step 1
  31. timer.start(pid, M.online_read);
  32. Tuple[] z = con2.readTupleArray(pid);
  33. timer.stop(pid, M.online_read);
  34. // step 2
  35. Tuple[] b = new Tuple[z.length];
  36. for (int i = 0; i < b.length; i++)
  37. b[i] = path[i].xor(z[i]).xor(predata.reshuffle_r[i]);
  38. Tuple[] b_prime = Util.permute(b, predata.reshuffle_pi);
  39. timer.stop(pid, M.online_comp);
  40. return b_prime;
  41. }
  42. public void runD() {
  43. }
  44. public Tuple[] runC(PreData predata, Tuple[] path, boolean firstTree, Timer timer) {
  45. if (firstTree)
  46. return path;
  47. timer.start(pid, M.online_comp);
  48. // step 1
  49. Tuple[] z = new Tuple[path.length];
  50. for (int i = 0; i < z.length; i++)
  51. z[i] = path[i].xor(predata.reshuffle_p[i]);
  52. timer.start(pid, M.online_write);
  53. con1.write(pid, z);
  54. timer.stop(pid, M.online_write);
  55. timer.stop(pid, M.online_comp);
  56. return predata.reshuffle_a_prime;
  57. }
  58. // for testing correctness
  59. @Override
  60. public void run(Party party, Metadata md, Forest forest) {
  61. int records = 5;
  62. int repeat = 5;
  63. int tau = md.getTau();
  64. int numTrees = md.getNumTrees();
  65. long numInsert = md.getNumInsertRecords();
  66. int addrBits = md.getAddrBits();
  67. Timer timer = new Timer();
  68. sanityCheck();
  69. System.out.println();
  70. for (int i = 0; i < records; i++) {
  71. long N = Global.cheat ? 0 : Util.nextLong(numInsert, Crypto.sr);
  72. for (int j = 0; j < repeat; j++) {
  73. System.out.println("Test: " + i + " " + j);
  74. System.out.println("N=" + BigInteger.valueOf(N).toString(2));
  75. byte[] Li = new byte[0];
  76. for (int ti = 0; ti < numTrees; ti++) {
  77. long Ni_value = Util.getSubBits(N, addrBits, addrBits - md.getNBitsOfTree(ti));
  78. long Nip1_pr_value = Util.getSubBits(N, addrBits - md.getNBitsOfTree(ti),
  79. Math.max(addrBits - md.getNBitsOfTree(ti) - tau, 0));
  80. byte[] Ni = Util.longToBytes(Ni_value, md.getNBytesOfTree(ti));
  81. byte[] Nip1_pr = Util.longToBytes(Nip1_pr_value, (tau + 7) / 8);
  82. PreData predata = new PreData();
  83. PreAccess preaccess = new PreAccess(con1, con2);
  84. Access access = new Access(con1, con2);
  85. PreReshuffle prereshuffle = new PreReshuffle(con1, con2);
  86. if (party == Party.Eddie) {
  87. Tree OTi = forest.getTree(ti);
  88. int numTuples = (OTi.getD() - 1) * OTi.getW() + OTi.getStashSize();
  89. int[] tupleParam = new int[] { ti == 0 ? 0 : 1, md.getNBytesOfTree(ti), md.getLBytesOfTree(ti),
  90. md.getABytesOfTree(ti) };
  91. preaccess.runE(predata, md.getTwoTauPow(), numTuples, tupleParam, timer);
  92. byte[] sE_Ni = Util.nextBytes(Ni.length, Crypto.sr);
  93. byte[] sD_Ni = Util.xor(Ni, sE_Ni);
  94. con1.write(sD_Ni);
  95. byte[] sE_Nip1_pr = Util.nextBytes(Nip1_pr.length, Crypto.sr);
  96. byte[] sD_Nip1_pr = Util.xor(Nip1_pr, sE_Nip1_pr);
  97. con1.write(sD_Nip1_pr);
  98. OutAccess outaccess = access.runE(predata, OTi, sE_Ni, sE_Nip1_pr, timer);
  99. if (ti == numTrees - 1)
  100. con2.write(N);
  101. prereshuffle.runE(predata, timer);
  102. Tuple[] E_P_prime = runE(predata, outaccess.E_P, ti == 0, timer);
  103. Tuple[] C_P = con2.readTupleArray();
  104. Tuple[] C_P_prime = con2.readTupleArray();
  105. Tuple[] oldPath = new Tuple[C_P.length];
  106. Tuple[] newPath = new Tuple[C_P.length];
  107. for (int k = 0; k < C_P.length; k++) {
  108. oldPath[k] = outaccess.E_P[k].xor(C_P[k]);
  109. newPath[k] = E_P_prime[k].xor(C_P_prime[k]);
  110. }
  111. oldPath = Util.permute(oldPath, predata.reshuffle_pi);
  112. boolean pass = true;
  113. for (int k = 0; k < newPath.length; k++) {
  114. if (!oldPath[k].equals(newPath[k])) {
  115. System.err.println("Reshuffle test failed");
  116. pass = false;
  117. break;
  118. }
  119. }
  120. if (pass)
  121. System.out.println("Reshuffle test passed");
  122. } else if (party == Party.Debbie) {
  123. Tree OTi = forest.getTree(ti);
  124. preaccess.runD(predata, timer);
  125. byte[] sD_Ni = con1.read();
  126. byte[] sD_Nip1_pr = con1.read();
  127. access.runD(predata, OTi, sD_Ni, sD_Nip1_pr, timer);
  128. int[] tupleParam = new int[] { ti == 0 ? 0 : 1, md.getNBytesOfTree(ti), md.getLBytesOfTree(ti),
  129. md.getABytesOfTree(ti) };
  130. prereshuffle.runD(predata, tupleParam, timer);
  131. runD();
  132. } else if (party == Party.Charlie) {
  133. preaccess.runC(timer);
  134. System.out.println("L" + ti + "=" + new BigInteger(1, Li).toString(2));
  135. OutAccess outaccess = access.runC(md, ti, Li, timer);
  136. prereshuffle.runC(predata, timer);
  137. Tuple[] C_P_prime = runC(predata, outaccess.C_P, ti == 0, timer);
  138. Li = outaccess.C_Lip1;
  139. if (ti == numTrees - 1) {
  140. N = con1.readLong();
  141. long data = new BigInteger(1, outaccess.C_Ti.getA()).longValue();
  142. if (N == data) {
  143. System.out.println("Access passed");
  144. System.out.println();
  145. } else {
  146. throw new AccessException("Access failed");
  147. }
  148. }
  149. con1.write(outaccess.C_P);
  150. con1.write(C_P_prime);
  151. } else {
  152. throw new NoSuchPartyException(party + "");
  153. }
  154. }
  155. }
  156. }
  157. // timer.print();
  158. }
  159. }