UpdateRoot.java 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. package protocols;
  2. import java.math.BigInteger;
  3. import org.apache.commons.lang3.ArrayUtils;
  4. import com.oblivm.backend.flexsc.CompEnv;
  5. import com.oblivm.backend.gc.GCSignal;
  6. import com.oblivm.backend.gc.regular.GCEva;
  7. import com.oblivm.backend.gc.regular.GCGen;
  8. import com.oblivm.backend.network.Network;
  9. import communication.Communication;
  10. import crypto.Crypto;
  11. import exceptions.NoSuchPartyException;
  12. import gc.GCUpdateRoot;
  13. import gc.GCUtil;
  14. import oram.Forest;
  15. import oram.Metadata;
  16. import oram.Tuple;
  17. import protocols.struct.Party;
  18. import util.M;
  19. import util.Util;
  20. public class UpdateRoot extends Protocol {
  21. public UpdateRoot(Communication con1, Communication con2) {
  22. super(con1, con2);
  23. }
  24. public Tuple[] runE(boolean firstTree, int sw, int lBits, int[] tupleParam, byte[] Li, Tuple[] R, Tuple Ti) {
  25. if (firstTree)
  26. return R;
  27. timer.start(M.offline_comp);
  28. int sLogW = (int) Math.ceil(Math.log(sw) / Math.log(2));
  29. GCSignal[][] j1KeyPairs = GCUtil.genKeyPairs(sLogW);
  30. GCSignal[][] LiKeyPairs = GCUtil.genKeyPairs(lBits);
  31. GCSignal[][] E_feKeyPairs = GCUtil.genKeyPairs(sw);
  32. GCSignal[][] C_feKeyPairs = GCUtil.genKeyPairs(sw);
  33. GCSignal[] j1ZeroKeys = GCUtil.getZeroKeys(j1KeyPairs);
  34. GCSignal[] LiZeroKeys = GCUtil.getZeroKeys(LiKeyPairs);
  35. GCSignal[] E_feZeroKeys = GCUtil.getZeroKeys(E_feKeyPairs);
  36. GCSignal[] C_feZeroKeys = GCUtil.getZeroKeys(C_feKeyPairs);
  37. GCSignal[][][] E_labelKeyPairs = new GCSignal[sw][][];
  38. GCSignal[][][] C_labelKeyPairs = new GCSignal[sw][][];
  39. GCSignal[][] E_labelZeroKeys = new GCSignal[sw][];
  40. GCSignal[][] C_labelZeroKeys = new GCSignal[sw][];
  41. for (int i = 0; i < sw; i++) {
  42. E_labelKeyPairs[i] = GCUtil.genKeyPairs(lBits);
  43. C_labelKeyPairs[i] = GCUtil.genKeyPairs(lBits);
  44. E_labelZeroKeys[i] = GCUtil.getZeroKeys(E_labelKeyPairs[i]);
  45. C_labelZeroKeys[i] = GCUtil.getZeroKeys(C_labelKeyPairs[i]);
  46. }
  47. Network channel = new Network(null, con1);
  48. CompEnv<GCSignal> gen = new GCGen(channel, timer, offline_band, M.offline_write);
  49. GCSignal[][] outZeroKeys = new GCUpdateRoot<GCSignal>(gen, lBits + 1, sw).rootFindDeepestAndEmpty(j1ZeroKeys,
  50. LiZeroKeys, E_feZeroKeys, C_feZeroKeys, E_labelZeroKeys, C_labelZeroKeys);
  51. ((GCGen) gen).sendLastSetGTT();
  52. byte[][][] outKeyHashes = new byte[outZeroKeys.length][][];
  53. for (int i = 0; i < outZeroKeys.length; i++)
  54. outKeyHashes[i] = GCUtil.genOutKeyHashes(outZeroKeys[i]);
  55. timer.start(M.offline_write);
  56. con2.write(offline_band, C_feKeyPairs);
  57. con2.write(offline_band, C_labelKeyPairs);
  58. con1.write(offline_band, outKeyHashes);
  59. timer.stop(M.offline_write);
  60. timer.stop(M.offline_comp);
  61. //////////////////////////////////////////////////////////////////////////////////
  62. timer.start(M.online_comp);
  63. // step 1
  64. int j1 = Crypto.sr.nextInt(R.length);
  65. GCSignal[] j1InputKeys = GCUtil.revSelectKeys(j1KeyPairs, BigInteger.valueOf(j1).toByteArray());
  66. GCSignal[] LiInputKeys = GCUtil.revSelectKeys(LiKeyPairs, Li);
  67. GCSignal[] E_feInputKeys = GCUtil.selectFeKeys(E_feKeyPairs, R);
  68. GCSignal[][] E_labelInputKeys = GCUtil.selectLabelKeys(E_labelKeyPairs, R);
  69. timer.start(M.online_write);
  70. con1.write(online_band, j1InputKeys);
  71. con1.write(online_band, LiInputKeys);
  72. con1.write(online_band, E_feInputKeys);
  73. con1.write(online_band, E_labelInputKeys);
  74. timer.stop(M.online_write);
  75. // step 4
  76. R = ArrayUtils.addAll(R, new Tuple[] { Ti });
  77. SSXOT ssxot = new SSXOT(con1, con2);
  78. R = ssxot.runE(R, tupleParam);
  79. timer.stop(M.online_comp);
  80. return R;
  81. }
  82. public void runD(boolean firstTree, int sw, int lBits, int[] tupleParam, byte[] Li, int w) {
  83. if (firstTree)
  84. return;
  85. timer.start(M.offline_comp);
  86. int logSW = (int) Math.ceil(Math.log(sw) / Math.log(2));
  87. GCSignal[] j1ZeroKeys = GCUtil.genEmptyKeys(logSW);
  88. GCSignal[] LiZeroKeys = GCUtil.genEmptyKeys(lBits);
  89. GCSignal[] E_feZeroKeys = GCUtil.genEmptyKeys(sw);
  90. GCSignal[] C_feZeroKeys = GCUtil.genEmptyKeys(sw);
  91. GCSignal[][] E_labelZeroKeys = new GCSignal[sw][];
  92. GCSignal[][] C_labelZeroKeys = new GCSignal[sw][];
  93. for (int i = 0; i < sw; i++) {
  94. E_labelZeroKeys[i] = GCUtil.genEmptyKeys(lBits);
  95. C_labelZeroKeys[i] = GCUtil.genEmptyKeys(lBits);
  96. }
  97. Network channel = new Network(con1, null);
  98. CompEnv<GCSignal> eva = new GCEva(channel, timer, M.offline_read);
  99. GCUpdateRoot<GCSignal> gcur = new GCUpdateRoot<GCSignal>(eva, lBits + 1, sw);
  100. gcur.rootFindDeepestAndEmpty(j1ZeroKeys, LiZeroKeys, E_feZeroKeys, C_feZeroKeys, E_labelZeroKeys,
  101. C_labelZeroKeys);
  102. ((GCEva) eva).receiveLastSetGTT();
  103. eva.setEvaluate();
  104. timer.start(M.offline_read);
  105. byte[][][] outKeyHashes = con1.readTripleByteArrayAndDec();
  106. timer.stop(M.offline_read);
  107. timer.stop(M.offline_comp);
  108. ///////////////////////////////////////////////////////////////////////////////
  109. timer.start(M.online_comp);
  110. // step 1
  111. timer.start(M.online_read);
  112. GCSignal[] j1InputKeys = con1.readGCSignalArrayAndDec();
  113. GCSignal[] LiInputKeys = con1.readGCSignalArrayAndDec();
  114. GCSignal[] E_feInputKeys = con1.readGCSignalArrayAndDec();
  115. GCSignal[][] E_labelInputKeys = con1.readDoubleGCSignalArrayAndDec();
  116. GCSignal[] C_feInputKeys = con2.readGCSignalArrayAndDec();
  117. GCSignal[][] C_labelInputKeys = con2.readDoubleGCSignalArrayAndDec();
  118. timer.stop(M.online_read);
  119. // step 2
  120. GCSignal[][] outKeys = gcur.rootFindDeepestAndEmpty(j1InputKeys, LiInputKeys, E_feInputKeys, C_feInputKeys,
  121. E_labelInputKeys, C_labelInputKeys);
  122. int j1 = GCUtil.evaOutKeys(outKeys[0], outKeyHashes[0]).intValue();
  123. int j2 = GCUtil.evaOutKeys(outKeys[1], outKeyHashes[1]).intValue();
  124. // step 3
  125. int r = Crypto.sr.nextInt(w);
  126. int[] I = new int[E_feInputKeys.length];
  127. for (int i = 0; i < I.length; i++)
  128. I[i] = i;
  129. I[j2] = I.length;
  130. int tmp = I[r];
  131. I[r] = I[j1];
  132. I[j1] = tmp;
  133. // step 4
  134. SSXOT ssxot = new SSXOT(con1, con2);
  135. ssxot.runD(sw + 1, sw, tupleParam, I);
  136. timer.stop(M.online_comp);
  137. }
  138. public Tuple[] runC(boolean firstTree, int[] tupleParam, Tuple[] R, Tuple Ti) {
  139. if (firstTree)
  140. return R;
  141. timer.start(M.offline_comp);
  142. timer.start(M.offline_read);
  143. GCSignal[][] C_feKeyPairs = con1.readDoubleGCSignalArrayAndDec();
  144. GCSignal[][][] C_labelKeyPairs = con1.readTripleGCSignalArrayAndDec();
  145. timer.stop(M.offline_read);
  146. timer.stop(M.offline_comp);
  147. ////////////////////////////////////////////////////////////////////////////
  148. timer.start(M.online_comp);
  149. // step 1
  150. GCSignal[] C_feInputKeys = GCUtil.selectFeKeys(C_feKeyPairs, R);
  151. GCSignal[][] C_labelInputKeys = GCUtil.selectLabelKeys(C_labelKeyPairs, R);
  152. timer.start(M.online_write);
  153. con2.write(online_band, C_feInputKeys);
  154. con2.write(online_band, C_labelInputKeys);
  155. timer.stop(M.online_write);
  156. // step 4
  157. R = ArrayUtils.addAll(R, new Tuple[] { Ti });
  158. SSXOT ssxot = new SSXOT(con1, con2);
  159. R = ssxot.runC(R, tupleParam);
  160. timer.stop(M.online_comp);
  161. return R;
  162. }
  163. @Override
  164. public void run(Party party, Metadata md, Forest[] forest) {
  165. for (int i = 0; i < 100; i++) {
  166. System.out.println("i=" + i);
  167. if (party == Party.Eddie) {
  168. int sw = Crypto.sr.nextInt(15) + 10;
  169. int lBits = Crypto.sr.nextInt(20) + 5;
  170. byte[] Li = Util.nextBytes((lBits + 7) / 8, Crypto.sr);
  171. Tuple[] R = new Tuple[sw];
  172. for (int j = 0; j < sw; j++)
  173. R[j] = new Tuple(1, 2, (lBits + 7) / 8, 3, Crypto.sr);
  174. Tuple Ti = new Tuple(1, 2, (lBits + 7) / 8, 3, Crypto.sr);
  175. int[] tupleParam = new int[] { 1, 2, (lBits + 7) / 8, 3 };
  176. con1.write(sw);
  177. con1.write(lBits);
  178. con1.write(Li);
  179. con1.write(tupleParam);
  180. con2.write(sw);
  181. con2.write(lBits);
  182. con2.write(tupleParam);
  183. Tuple[] newR = runE(false, sw, lBits, tupleParam, Li, R, Ti);
  184. Tuple[] R_C = con2.readTupleArray();
  185. int cnt = 0;
  186. int[] index = new int[3];
  187. for (int j = 0; j < sw; j++) {
  188. newR[j].setXor(R_C[j]);
  189. if (!R[j].equals(newR[j])) {
  190. index[cnt] = j;
  191. cnt++;
  192. }
  193. }
  194. if (cnt == 1) {
  195. if (newR[index[0]].equals(Ti) && (R[index[0]].getF()[0] & 1) == 0)
  196. System.out.println("UpdateRoot test passed");
  197. else
  198. System.err.println("UpdateRoot test failed 1");
  199. } else if (cnt == 2) {
  200. int u = -1;
  201. for (int k = 0; k < cnt; k++) {
  202. if (newR[index[k]].equals(Ti)) {
  203. u = k;
  204. break;
  205. }
  206. }
  207. if (u == -1)
  208. System.err.println("UpdateRoot test failed 2");
  209. else {
  210. int a1 = index[u];
  211. int a2 = index[1 - u];
  212. if (!R[a1].equals(newR[a2]) || (R[u].getF()[0] & 1) == 1)
  213. System.err.println("UpdateRoot test failed 3");
  214. else
  215. System.out.println("UpdateRoot test passed");
  216. }
  217. } else if (cnt == 3) {
  218. int u = -1;
  219. for (int k = 0; k < cnt; k++) {
  220. if (newR[index[k]].equals(Ti)) {
  221. u = k;
  222. break;
  223. }
  224. }
  225. if (u == -1)
  226. System.err.println("UpdateRoot test failed 4");
  227. else {
  228. int a1, a2;
  229. if (u == 0) {
  230. a1 = 1;
  231. a2 = 2;
  232. } else if (u == 1) {
  233. a1 = 0;
  234. a2 = 2;
  235. } else {
  236. a1 = 0;
  237. a2 = 1;
  238. }
  239. u = index[u];
  240. a1 = index[a1];
  241. a2 = index[a2];
  242. if ((R[u].getF()[0] & 1) == 1)
  243. System.err.println("UpdateRoot test failed 5");
  244. else if (!R[a1].equals(newR[a2]))
  245. System.err.println("UpdateRoot test failed 6");
  246. else if (!R[a1].equals(newR[a2]) || !R[a2].equals(newR[a1]))
  247. System.err.println("UpdateRoot test failed 7");
  248. else
  249. System.out.println("UpdateRoot test passed");
  250. }
  251. } else {
  252. System.err.println("UpdateRoot test failed 8");
  253. }
  254. } else if (party == Party.Debbie) {
  255. int sw = con1.readInt();
  256. int lBits = con1.readInt();
  257. byte[] Li = con1.read();
  258. int[] tupleParam = con1.readIntArray();
  259. runD(false, sw, lBits, tupleParam, Li, md.getW());
  260. } else if (party == Party.Charlie) {
  261. int sw = con1.readInt();
  262. int lBits = con1.readInt();
  263. int[] tupleParam = con1.readIntArray();
  264. Tuple[] R = new Tuple[sw];
  265. for (int j = 0; j < sw; j++)
  266. R[j] = new Tuple(1, 2, (lBits + 7) / 8, 3, null);
  267. Tuple Ti = new Tuple(1, 2, (lBits + 7) / 8, 3, null);
  268. R = runC(false, tupleParam, R, Ti);
  269. con1.write(R);
  270. } else {
  271. throw new NoSuchPartyException(party + "");
  272. }
  273. }
  274. }
  275. @Override
  276. public void run(Party party, Metadata md, Forest forest) {
  277. }
  278. }