UpdateRoot.java 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. package protocols;
  2. import java.math.BigInteger;
  3. import org.apache.commons.lang3.ArrayUtils;
  4. import com.oblivm.backend.gc.GCSignal;
  5. import communication.Communication;
  6. import crypto.Crypto;
  7. import exceptions.NoSuchPartyException;
  8. import gc.GCUtil;
  9. import oram.Forest;
  10. import oram.Metadata;
  11. import oram.Tuple;
  12. import protocols.precomputation.PreUpdateRoot;
  13. import protocols.struct.Party;
  14. import protocols.struct.PreData;
  15. import util.M;
  16. import util.P;
  17. import util.Timer;
  18. import util.Util;
  19. public class UpdateRoot extends Protocol {
  20. private int pid = P.UR;
  21. public UpdateRoot(Communication con1, Communication con2) {
  22. super(con1, con2);
  23. }
  24. public Tuple[] runE(PreData predata, boolean firstTree, byte[] Li, Tuple[] R, Tuple Ti, Timer timer) {
  25. if (firstTree)
  26. return R;
  27. timer.start(pid, M.online_comp);
  28. // step 1
  29. int j1 = Crypto.sr.nextInt(R.length);
  30. GCSignal[] j1InputKeys = GCUtil.revSelectKeys(predata.ur_j1KeyPairs, BigInteger.valueOf(j1).toByteArray());
  31. GCSignal[] LiInputKeys = GCUtil.revSelectKeys(predata.ur_LiKeyPairs, Li);
  32. GCSignal[] E_feInputKeys = GCUtil.selectFeKeys(predata.ur_E_feKeyPairs, R);
  33. GCSignal[][] E_labelInputKeys = GCUtil.selectLabelKeys(predata.ur_E_labelKeyPairs, R);
  34. timer.start(pid, M.online_write);
  35. con1.write(pid, j1InputKeys);
  36. con1.write(pid, LiInputKeys);
  37. con1.write(pid, E_feInputKeys);
  38. con1.write(pid, E_labelInputKeys);
  39. timer.stop(pid, M.online_write);
  40. // step 4
  41. R = ArrayUtils.addAll(R, new Tuple[] { Ti });
  42. SSXOT ssxot = new SSXOT(con1, con2, 0);
  43. R = ssxot.runE(predata, R, timer);
  44. timer.stop(pid, M.online_comp);
  45. return R;
  46. }
  47. public void runD(PreData predata, boolean firstTree, byte[] Li, int w, Timer timer) {
  48. if (firstTree)
  49. return;
  50. timer.start(pid, M.online_comp);
  51. // step 1
  52. timer.start(pid, M.online_read);
  53. GCSignal[] j1InputKeys = con1.readGCSignalArray();
  54. GCSignal[] LiInputKeys = con1.readGCSignalArray();
  55. GCSignal[] E_feInputKeys = con1.readGCSignalArray();
  56. GCSignal[][] E_labelInputKeys = con1.readDoubleGCSignalArray();
  57. GCSignal[] C_feInputKeys = con2.readGCSignalArray();
  58. GCSignal[][] C_labelInputKeys = con2.readDoubleGCSignalArray();
  59. timer.stop(pid, M.online_read);
  60. // step 2
  61. GCSignal[][] outKeys = predata.ur_gcur.rootFindDeepestAndEmpty(j1InputKeys, LiInputKeys, E_feInputKeys,
  62. C_feInputKeys, E_labelInputKeys, C_labelInputKeys);
  63. int j1 = GCUtil.evaOutKeys(outKeys[0], predata.ur_outKeyHashes[0]).intValue();
  64. int j2 = GCUtil.evaOutKeys(outKeys[1], predata.ur_outKeyHashes[1]).intValue();
  65. // step 3
  66. int r = Crypto.sr.nextInt(w);
  67. int[] I = new int[E_feInputKeys.length];
  68. for (int i = 0; i < I.length; i++)
  69. I[i] = i;
  70. I[j2] = I.length;
  71. int tmp = I[r];
  72. I[r] = I[j1];
  73. I[j1] = tmp;
  74. // step 4
  75. SSXOT ssxot = new SSXOT(con1, con2, 0);
  76. ssxot.runD(predata, I, timer);
  77. timer.stop(pid, M.online_comp);
  78. }
  79. public Tuple[] runC(PreData predata, boolean firstTree, Tuple[] R, Tuple Ti, Timer timer) {
  80. if (firstTree)
  81. return R;
  82. timer.start(pid, M.online_comp);
  83. // step 1
  84. GCSignal[] C_feInputKeys = GCUtil.selectFeKeys(predata.ur_C_feKeyPairs, R);
  85. GCSignal[][] C_labelInputKeys = GCUtil.selectLabelKeys(predata.ur_C_labelKeyPairs, R);
  86. timer.start(pid, M.online_write);
  87. con2.write(pid, C_feInputKeys);
  88. con2.write(pid, C_labelInputKeys);
  89. timer.stop(pid, M.online_write);
  90. // step 4
  91. R = ArrayUtils.addAll(R, new Tuple[] { Ti });
  92. SSXOT ssxot = new SSXOT(con1, con2, 0);
  93. R = ssxot.runC(predata, R, timer);
  94. timer.stop(pid, M.online_comp);
  95. return R;
  96. }
  97. // for testing correctness
  98. @Override
  99. public void run(Party party, Metadata md, Forest forest) {
  100. Timer timer = new Timer();
  101. for (int i = 0; i < 100; i++) {
  102. System.out.println("i=" + i);
  103. PreData predata = new PreData();
  104. PreUpdateRoot preupdateroot = new PreUpdateRoot(con1, con2);
  105. if (party == Party.Eddie) {
  106. int sw = Crypto.sr.nextInt(15) + 10;
  107. int lBits = Crypto.sr.nextInt(20) + 5;
  108. byte[] Li = Util.nextBytes((lBits + 7) / 8, Crypto.sr);
  109. Tuple[] R = new Tuple[sw];
  110. for (int j = 0; j < sw; j++)
  111. R[j] = new Tuple(1, 2, (lBits + 7) / 8, 3, Crypto.sr);
  112. Tuple Ti = new Tuple(1, 2, (lBits + 7) / 8, 3, Crypto.sr);
  113. con1.write(sw);
  114. con1.write(lBits);
  115. con1.write(Li);
  116. con2.write(sw);
  117. con2.write(lBits);
  118. preupdateroot.runE(predata, false, sw, lBits, timer);
  119. Tuple[] newR = runE(predata, false, Li, R, Ti, timer);
  120. Tuple[] R_C = con2.readTupleArray();
  121. int cnt = 0;
  122. int[] index = new int[3];
  123. for (int j = 0; j < sw; j++) {
  124. newR[j].setXor(R_C[j]);
  125. if (!R[j].equals(newR[j])) {
  126. index[cnt] = j;
  127. cnt++;
  128. }
  129. }
  130. if (cnt == 1) {
  131. if (newR[index[0]].equals(Ti) && (R[index[0]].getF()[0] & 1) == 0)
  132. System.out.println("UpdateRoot test passed");
  133. else
  134. System.err.println("UpdateRoot test failed 1");
  135. } else if (cnt == 2) {
  136. int u = -1;
  137. for (int k = 0; k < cnt; k++) {
  138. if (newR[index[k]].equals(Ti)) {
  139. u = k;
  140. break;
  141. }
  142. }
  143. if (u == -1)
  144. System.err.println("UpdateRoot test failed 2");
  145. else {
  146. int a1 = index[u];
  147. int a2 = index[1 - u];
  148. if (!R[a1].equals(newR[a2]) || (R[u].getF()[0] & 1) == 1)
  149. System.err.println("UpdateRoot test failed 3");
  150. else
  151. System.out.println("UpdateRoot test passed");
  152. }
  153. } else if (cnt == 3) {
  154. int u = -1;
  155. for (int k = 0; k < cnt; k++) {
  156. if (newR[index[k]].equals(Ti)) {
  157. u = k;
  158. break;
  159. }
  160. }
  161. if (u == -1)
  162. System.err.println("UpdateRoot test failed 4");
  163. else {
  164. int a1, a2;
  165. if (u == 0) {
  166. a1 = 1;
  167. a2 = 2;
  168. } else if (u == 1) {
  169. a1 = 0;
  170. a2 = 2;
  171. } else {
  172. a1 = 0;
  173. a2 = 1;
  174. }
  175. u = index[u];
  176. a1 = index[a1];
  177. a2 = index[a2];
  178. if ((R[u].getF()[0] & 1) == 1)
  179. System.err.println("UpdateRoot test failed 5");
  180. else if (!R[a1].equals(newR[a2]))
  181. System.err.println("UpdateRoot test failed 6");
  182. else if (!R[a1].equals(newR[a2]) || !R[a2].equals(newR[a1]))
  183. System.err.println("UpdateRoot test failed 7");
  184. else
  185. System.out.println("UpdateRoot test passed");
  186. }
  187. } else {
  188. System.err.println("UpdateRoot test failed 8");
  189. }
  190. System.out.println();
  191. } else if (party == Party.Debbie) {
  192. int sw = con1.readInt();
  193. int lBits = con1.readInt();
  194. byte[] Li = con1.read();
  195. int[] tupleParam = new int[] { 1, 2, (lBits + 7) / 8, 3 };
  196. preupdateroot.runD(predata, false, sw, lBits, tupleParam, timer);
  197. runD(predata, false, Li, md.getW(), timer);
  198. } else if (party == Party.Charlie) {
  199. int sw = con1.readInt();
  200. int lBits = con1.readInt();
  201. Tuple[] R = new Tuple[sw];
  202. for (int j = 0; j < sw; j++)
  203. R[j] = new Tuple(1, 2, (lBits + 7) / 8, 3, null);
  204. Tuple Ti = new Tuple(1, 2, (lBits + 7) / 8, 3, null);
  205. preupdateroot.runC(predata, false, timer);
  206. R = runC(predata, false, R, Ti, timer);
  207. con1.write(R);
  208. } else {
  209. throw new NoSuchPartyException(party + "");
  210. }
  211. }
  212. // timer.print();
  213. }
  214. }