UpdateRoot.java 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. package protocols;
  2. import java.math.BigInteger;
  3. import org.apache.commons.lang3.ArrayUtils;
  4. import com.oblivm.backend.flexsc.CompEnv;
  5. import com.oblivm.backend.gc.GCSignal;
  6. import com.oblivm.backend.gc.regular.GCEva;
  7. import com.oblivm.backend.gc.regular.GCGen;
  8. import com.oblivm.backend.network.Network;
  9. import communication.Communication;
  10. import crypto.Crypto;
  11. import exceptions.NoSuchPartyException;
  12. import gc.GCUpdateRoot;
  13. import gc.GCUtil;
  14. import oram.Forest;
  15. import oram.Metadata;
  16. import oram.Tuple;
  17. import struct.Party;
  18. import subprotocols.SSXOT;
  19. import util.M;
  20. import util.Util;
  21. public class UpdateRoot extends Protocol {
  22. public UpdateRoot(Communication con1, Communication con2) {
  23. super(con1, con2);
  24. }
  25. public Tuple[] runE(boolean firstTree, int sw, int lBits, int[] tupleParam, byte[] Li, Tuple[] R, Tuple Ti) {
  26. if (firstTree)
  27. return R;
  28. timer.start(M.offline_comp);
  29. int sLogW = (int) Math.ceil(Math.log(sw) / Math.log(2));
  30. GCSignal[][] j1KeyPairs = GCUtil.genKeyPairs(sLogW);
  31. GCSignal[][] LiKeyPairs = GCUtil.genKeyPairs(lBits);
  32. GCSignal[][] E_feKeyPairs = GCUtil.genKeyPairs(sw);
  33. GCSignal[][] C_feKeyPairs = GCUtil.genKeyPairs(sw);
  34. GCSignal[] j1ZeroKeys = GCUtil.getZeroKeys(j1KeyPairs);
  35. GCSignal[] LiZeroKeys = GCUtil.getZeroKeys(LiKeyPairs);
  36. GCSignal[] E_feZeroKeys = GCUtil.getZeroKeys(E_feKeyPairs);
  37. GCSignal[] C_feZeroKeys = GCUtil.getZeroKeys(C_feKeyPairs);
  38. GCSignal[][][] E_labelKeyPairs = new GCSignal[sw][][];
  39. GCSignal[][][] C_labelKeyPairs = new GCSignal[sw][][];
  40. GCSignal[][] E_labelZeroKeys = new GCSignal[sw][];
  41. GCSignal[][] C_labelZeroKeys = new GCSignal[sw][];
  42. for (int i = 0; i < sw; i++) {
  43. E_labelKeyPairs[i] = GCUtil.genKeyPairs(lBits);
  44. C_labelKeyPairs[i] = GCUtil.genKeyPairs(lBits);
  45. E_labelZeroKeys[i] = GCUtil.getZeroKeys(E_labelKeyPairs[i]);
  46. C_labelZeroKeys[i] = GCUtil.getZeroKeys(C_labelKeyPairs[i]);
  47. }
  48. Network channel = new Network(null, con1);
  49. CompEnv<GCSignal> gen = new GCGen(channel, timer, offline_band, M.offline_write);
  50. GCSignal[][] outZeroKeys = new GCUpdateRoot<GCSignal>(gen, lBits + 1, sw).rootFindDeepestAndEmpty(j1ZeroKeys,
  51. LiZeroKeys, E_feZeroKeys, C_feZeroKeys, E_labelZeroKeys, C_labelZeroKeys);
  52. ((GCGen) gen).sendLastSetGTT();
  53. byte[][][] outKeyHashes = new byte[outZeroKeys.length][][];
  54. for (int i = 0; i < outZeroKeys.length; i++)
  55. outKeyHashes[i] = GCUtil.genOutKeyHashes(outZeroKeys[i]);
  56. timer.start(M.offline_write);
  57. con2.write(offline_band, C_feKeyPairs);
  58. con2.write(offline_band, C_labelKeyPairs);
  59. con1.write(offline_band, outKeyHashes);
  60. timer.stop(M.offline_write);
  61. timer.stop(M.offline_comp);
  62. //////////////////////////////////////////////////////////////////////////////////
  63. timer.start(M.online_comp);
  64. // step 1
  65. int j1 = Crypto.sr.nextInt(R.length);
  66. GCSignal[] j1InputKeys = GCUtil.revSelectKeys(j1KeyPairs, BigInteger.valueOf(j1).toByteArray());
  67. GCSignal[] LiInputKeys = GCUtil.revSelectKeys(LiKeyPairs, Li);
  68. GCSignal[] E_feInputKeys = GCUtil.selectFeKeys(E_feKeyPairs, R);
  69. GCSignal[][] E_labelInputKeys = GCUtil.selectLabelKeys(E_labelKeyPairs, R);
  70. timer.start(M.online_write);
  71. con1.write(online_band, j1InputKeys);
  72. con1.write(online_band, LiInputKeys);
  73. con1.write(online_band, E_feInputKeys);
  74. con1.write(online_band, E_labelInputKeys);
  75. timer.stop(M.online_write);
  76. // step 4
  77. R = ArrayUtils.addAll(R, new Tuple[] { Ti });
  78. SSXOT ssxot = new SSXOT(con1, con2);
  79. R = ssxot.runE(R, tupleParam);
  80. timer.stop(M.online_comp);
  81. return R;
  82. }
  83. public void runD(boolean firstTree, int sw, int lBits, int[] tupleParam, byte[] Li, int w) {
  84. if (firstTree)
  85. return;
  86. timer.start(M.offline_comp);
  87. int logSW = (int) Math.ceil(Math.log(sw) / Math.log(2));
  88. GCSignal[] j1ZeroKeys = GCUtil.genEmptyKeys(logSW);
  89. GCSignal[] LiZeroKeys = GCUtil.genEmptyKeys(lBits);
  90. GCSignal[] E_feZeroKeys = GCUtil.genEmptyKeys(sw);
  91. GCSignal[] C_feZeroKeys = GCUtil.genEmptyKeys(sw);
  92. GCSignal[][] E_labelZeroKeys = new GCSignal[sw][];
  93. GCSignal[][] C_labelZeroKeys = new GCSignal[sw][];
  94. for (int i = 0; i < sw; i++) {
  95. E_labelZeroKeys[i] = GCUtil.genEmptyKeys(lBits);
  96. C_labelZeroKeys[i] = GCUtil.genEmptyKeys(lBits);
  97. }
  98. Network channel = new Network(con1, null);
  99. CompEnv<GCSignal> eva = new GCEva(channel, timer, M.offline_read);
  100. GCUpdateRoot<GCSignal> gcur = new GCUpdateRoot<GCSignal>(eva, lBits + 1, sw);
  101. gcur.rootFindDeepestAndEmpty(j1ZeroKeys, LiZeroKeys, E_feZeroKeys, C_feZeroKeys, E_labelZeroKeys,
  102. C_labelZeroKeys);
  103. ((GCEva) eva).receiveLastSetGTT();
  104. eva.setEvaluate();
  105. timer.start(M.offline_read);
  106. byte[][][] outKeyHashes = con1.readTripleByteArrayAndDec();
  107. timer.stop(M.offline_read);
  108. timer.stop(M.offline_comp);
  109. ///////////////////////////////////////////////////////////////////////////////
  110. timer.start(M.online_comp);
  111. // step 1
  112. timer.start(M.online_read);
  113. GCSignal[] j1InputKeys = con1.readGCSignalArrayAndDec();
  114. GCSignal[] LiInputKeys = con1.readGCSignalArrayAndDec();
  115. GCSignal[] E_feInputKeys = con1.readGCSignalArrayAndDec();
  116. GCSignal[][] E_labelInputKeys = con1.readDoubleGCSignalArrayAndDec();
  117. GCSignal[] C_feInputKeys = con2.readGCSignalArrayAndDec();
  118. GCSignal[][] C_labelInputKeys = con2.readDoubleGCSignalArrayAndDec();
  119. timer.stop(M.online_read);
  120. // step 2
  121. GCSignal[][] outKeys = gcur.rootFindDeepestAndEmpty(j1InputKeys, LiInputKeys, E_feInputKeys, C_feInputKeys,
  122. E_labelInputKeys, C_labelInputKeys);
  123. int j1 = GCUtil.evaOutKeys(outKeys[0], outKeyHashes[0]).intValue();
  124. int j2 = GCUtil.evaOutKeys(outKeys[1], outKeyHashes[1]).intValue();
  125. // step 3
  126. int r = Crypto.sr.nextInt(w);
  127. int[] I = new int[E_feInputKeys.length];
  128. for (int i = 0; i < I.length; i++)
  129. I[i] = i;
  130. I[j2] = I.length;
  131. int tmp = I[r];
  132. I[r] = I[j1];
  133. I[j1] = tmp;
  134. // step 4
  135. SSXOT ssxot = new SSXOT(con1, con2);
  136. ssxot.runD(sw + 1, sw, tupleParam, I);
  137. timer.stop(M.online_comp);
  138. }
  139. public Tuple[] runC(boolean firstTree, int[] tupleParam, Tuple[] R, Tuple Ti) {
  140. if (firstTree)
  141. return R;
  142. timer.start(M.offline_comp);
  143. timer.start(M.offline_read);
  144. GCSignal[][] C_feKeyPairs = con1.readDoubleGCSignalArrayAndDec();
  145. GCSignal[][][] C_labelKeyPairs = con1.readTripleGCSignalArrayAndDec();
  146. timer.stop(M.offline_read);
  147. timer.stop(M.offline_comp);
  148. ////////////////////////////////////////////////////////////////////////////
  149. timer.start(M.online_comp);
  150. // step 1
  151. GCSignal[] C_feInputKeys = GCUtil.selectFeKeys(C_feKeyPairs, R);
  152. GCSignal[][] C_labelInputKeys = GCUtil.selectLabelKeys(C_labelKeyPairs, R);
  153. timer.start(M.online_write);
  154. con2.write(online_band, C_feInputKeys);
  155. con2.write(online_band, C_labelInputKeys);
  156. timer.stop(M.online_write);
  157. // step 4
  158. R = ArrayUtils.addAll(R, new Tuple[] { Ti });
  159. SSXOT ssxot = new SSXOT(con1, con2);
  160. R = ssxot.runC(R, tupleParam);
  161. timer.stop(M.online_comp);
  162. return R;
  163. }
  164. @Override
  165. public void run(Party party, Metadata md, Forest[] forest) {
  166. for (int i = 0; i < 100; i++) {
  167. System.out.println("i=" + i);
  168. if (party == Party.Eddie) {
  169. int sw = Crypto.sr.nextInt(15) + 10;
  170. int lBits = Crypto.sr.nextInt(20) + 5;
  171. byte[] Li = Util.nextBytes((lBits + 7) / 8, Crypto.sr);
  172. Tuple[] R = new Tuple[sw];
  173. for (int j = 0; j < sw; j++)
  174. R[j] = new Tuple(1, 2, (lBits + 7) / 8, 3, Crypto.sr);
  175. Tuple Ti = new Tuple(1, 2, (lBits + 7) / 8, 3, Crypto.sr);
  176. int[] tupleParam = new int[] { 1, 2, (lBits + 7) / 8, 3 };
  177. con1.write(sw);
  178. con1.write(lBits);
  179. con1.write(Li);
  180. con1.write(tupleParam);
  181. con2.write(sw);
  182. con2.write(lBits);
  183. con2.write(tupleParam);
  184. Tuple[] newR = runE(false, sw, lBits, tupleParam, Li, R, Ti);
  185. Tuple[] R_C = con2.readTupleArray();
  186. int cnt = 0;
  187. int[] index = new int[3];
  188. for (int j = 0; j < sw; j++) {
  189. newR[j].setXor(R_C[j]);
  190. if (!R[j].equals(newR[j])) {
  191. index[cnt] = j;
  192. cnt++;
  193. }
  194. }
  195. if (cnt == 1) {
  196. if (newR[index[0]].equals(Ti) && (R[index[0]].getF()[0] & 1) == 0)
  197. System.out.println("UpdateRoot test passed");
  198. else
  199. System.err.println("UpdateRoot test failed 1");
  200. } else if (cnt == 2) {
  201. int u = -1;
  202. for (int k = 0; k < cnt; k++) {
  203. if (newR[index[k]].equals(Ti)) {
  204. u = k;
  205. break;
  206. }
  207. }
  208. if (u == -1)
  209. System.err.println("UpdateRoot test failed 2");
  210. else {
  211. int a1 = index[u];
  212. int a2 = index[1 - u];
  213. if (!R[a1].equals(newR[a2]) || (R[u].getF()[0] & 1) == 1)
  214. System.err.println("UpdateRoot test failed 3");
  215. else
  216. System.out.println("UpdateRoot test passed");
  217. }
  218. } else if (cnt == 3) {
  219. int u = -1;
  220. for (int k = 0; k < cnt; k++) {
  221. if (newR[index[k]].equals(Ti)) {
  222. u = k;
  223. break;
  224. }
  225. }
  226. if (u == -1)
  227. System.err.println("UpdateRoot test failed 4");
  228. else {
  229. int a1, a2;
  230. if (u == 0) {
  231. a1 = 1;
  232. a2 = 2;
  233. } else if (u == 1) {
  234. a1 = 0;
  235. a2 = 2;
  236. } else {
  237. a1 = 0;
  238. a2 = 1;
  239. }
  240. u = index[u];
  241. a1 = index[a1];
  242. a2 = index[a2];
  243. if ((R[u].getF()[0] & 1) == 1)
  244. System.err.println("UpdateRoot test failed 5");
  245. else if (!R[a1].equals(newR[a2]))
  246. System.err.println("UpdateRoot test failed 6");
  247. else if (!R[a1].equals(newR[a2]) || !R[a2].equals(newR[a1]))
  248. System.err.println("UpdateRoot test failed 7");
  249. else
  250. System.out.println("UpdateRoot test passed");
  251. }
  252. } else {
  253. System.err.println("UpdateRoot test failed 8");
  254. }
  255. } else if (party == Party.Debbie) {
  256. int sw = con1.readInt();
  257. int lBits = con1.readInt();
  258. byte[] Li = con1.read();
  259. int[] tupleParam = con1.readIntArray();
  260. runD(false, sw, lBits, tupleParam, Li, md.getW());
  261. } else if (party == Party.Charlie) {
  262. int sw = con1.readInt();
  263. int lBits = con1.readInt();
  264. int[] tupleParam = con1.readIntArray();
  265. Tuple[] R = new Tuple[sw];
  266. for (int j = 0; j < sw; j++)
  267. R[j] = new Tuple(1, 2, (lBits + 7) / 8, 3, null);
  268. Tuple Ti = new Tuple(1, 2, (lBits + 7) / 8, 3, null);
  269. R = runC(false, tupleParam, R, Ti);
  270. con1.write(R);
  271. } else {
  272. throw new NoSuchPartyException(party + "");
  273. }
  274. }
  275. }
  276. @Override
  277. public void run(Party party, Metadata md, Forest forest) {
  278. }
  279. }