UpdateRoot.java 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. package protocols;
  2. import java.math.BigInteger;
  3. import org.apache.commons.lang3.ArrayUtils;
  4. import com.oblivm.backend.flexsc.CompEnv;
  5. import com.oblivm.backend.gc.GCSignal;
  6. import com.oblivm.backend.gc.regular.GCEva;
  7. import com.oblivm.backend.gc.regular.GCGen;
  8. import com.oblivm.backend.network.Network;
  9. import communication.Communication;
  10. import crypto.Crypto;
  11. import exceptions.NoSuchPartyException;
  12. import gc.GCUpdateRoot;
  13. import gc.GCUtil;
  14. import oram.Forest;
  15. import oram.Metadata;
  16. import oram.Tuple;
  17. import struct.Party;
  18. import subprotocols.SSXOT;
  19. import util.M;
  20. import util.P;
  21. import util.Util;
  22. public class UpdateRoot extends Protocol {
  23. int pid = P.UR;
  24. public UpdateRoot(Communication con1, Communication con2) {
  25. super(con1, con2);
  26. online_band = all.online_band[pid];
  27. offline_band = all.offline_band[pid];
  28. timer = all.timer[pid];
  29. }
  30. public Tuple[] runE(boolean firstTree, int sw, int lBits, int[] tupleParam, byte[] Li, Tuple[] R, Tuple Ti) {
  31. if (firstTree)
  32. return R;
  33. timer.start(M.offline_comp);
  34. int sLogW = (int) Math.ceil(Math.log(sw) / Math.log(2));
  35. GCSignal[][] j1KeyPairs = GCUtil.genKeyPairs(sLogW);
  36. GCSignal[][] LiKeyPairs = GCUtil.genKeyPairs(lBits);
  37. GCSignal[][] E_feKeyPairs = GCUtil.genKeyPairs(sw);
  38. GCSignal[][] C_feKeyPairs = GCUtil.genKeyPairs(sw);
  39. GCSignal[] j1ZeroKeys = GCUtil.getZeroKeys(j1KeyPairs);
  40. GCSignal[] LiZeroKeys = GCUtil.getZeroKeys(LiKeyPairs);
  41. GCSignal[] E_feZeroKeys = GCUtil.getZeroKeys(E_feKeyPairs);
  42. GCSignal[] C_feZeroKeys = GCUtil.getZeroKeys(C_feKeyPairs);
  43. GCSignal[][][] E_labelKeyPairs = new GCSignal[sw][][];
  44. GCSignal[][][] C_labelKeyPairs = new GCSignal[sw][][];
  45. GCSignal[][] E_labelZeroKeys = new GCSignal[sw][];
  46. GCSignal[][] C_labelZeroKeys = new GCSignal[sw][];
  47. for (int i = 0; i < sw; i++) {
  48. E_labelKeyPairs[i] = GCUtil.genKeyPairs(lBits);
  49. C_labelKeyPairs[i] = GCUtil.genKeyPairs(lBits);
  50. E_labelZeroKeys[i] = GCUtil.getZeroKeys(E_labelKeyPairs[i]);
  51. C_labelZeroKeys[i] = GCUtil.getZeroKeys(C_labelKeyPairs[i]);
  52. }
  53. Network channel = new Network(null, con1);
  54. CompEnv<GCSignal> gen = new GCGen(channel, timer, offline_band, M.offline_write);
  55. GCSignal[][] outZeroKeys = new GCUpdateRoot<GCSignal>(gen, lBits + 1, sw).rootFindDeepestAndEmpty(j1ZeroKeys,
  56. LiZeroKeys, E_feZeroKeys, C_feZeroKeys, E_labelZeroKeys, C_labelZeroKeys);
  57. ((GCGen) gen).sendLastSetGTT();
  58. byte[][][] outKeyHashes = new byte[outZeroKeys.length][][];
  59. for (int i = 0; i < outZeroKeys.length; i++)
  60. outKeyHashes[i] = GCUtil.genOutKeyHashes(outZeroKeys[i]);
  61. timer.start(M.offline_write);
  62. con2.write(offline_band, C_feKeyPairs);
  63. con2.write(offline_band, C_labelKeyPairs);
  64. con1.write(offline_band, outKeyHashes);
  65. timer.stop(M.offline_write);
  66. timer.stop(M.offline_comp);
  67. //////////////////////////////////////////////////////////////////////////////////
  68. timer.start(M.online_comp);
  69. // step 1
  70. int j1 = Crypto.sr.nextInt(R.length);
  71. GCSignal[] j1InputKeys = GCUtil.revSelectKeys(j1KeyPairs, BigInteger.valueOf(j1).toByteArray());
  72. GCSignal[] LiInputKeys = GCUtil.revSelectKeys(LiKeyPairs, Li);
  73. GCSignal[] E_feInputKeys = GCUtil.selectFeKeys(E_feKeyPairs, R);
  74. GCSignal[][] E_labelInputKeys = GCUtil.selectLabelKeys(E_labelKeyPairs, R);
  75. timer.start(M.online_write);
  76. con1.write(online_band, j1InputKeys);
  77. con1.write(online_band, LiInputKeys);
  78. con1.write(online_band, E_feInputKeys);
  79. con1.write(online_band, E_labelInputKeys);
  80. timer.stop(M.online_write);
  81. // step 4
  82. R = ArrayUtils.addAll(R, new Tuple[] { Ti });
  83. SSXOT ssxot = new SSXOT(con1, con2);
  84. R = ssxot.runE(R, tupleParam);
  85. timer.stop(M.online_comp);
  86. return R;
  87. }
  88. public void runD(boolean firstTree, int sw, int lBits, int[] tupleParam, byte[] Li, int w) {
  89. if (firstTree)
  90. return;
  91. timer.start(M.offline_comp);
  92. int logSW = (int) Math.ceil(Math.log(sw) / Math.log(2));
  93. GCSignal[] j1ZeroKeys = GCUtil.genEmptyKeys(logSW);
  94. GCSignal[] LiZeroKeys = GCUtil.genEmptyKeys(lBits);
  95. GCSignal[] E_feZeroKeys = GCUtil.genEmptyKeys(sw);
  96. GCSignal[] C_feZeroKeys = GCUtil.genEmptyKeys(sw);
  97. GCSignal[][] E_labelZeroKeys = new GCSignal[sw][];
  98. GCSignal[][] C_labelZeroKeys = new GCSignal[sw][];
  99. for (int i = 0; i < sw; i++) {
  100. E_labelZeroKeys[i] = GCUtil.genEmptyKeys(lBits);
  101. C_labelZeroKeys[i] = GCUtil.genEmptyKeys(lBits);
  102. }
  103. Network channel = new Network(con1, null);
  104. CompEnv<GCSignal> eva = new GCEva(channel, timer, M.offline_read);
  105. GCUpdateRoot<GCSignal> gcur = new GCUpdateRoot<GCSignal>(eva, lBits + 1, sw);
  106. gcur.rootFindDeepestAndEmpty(j1ZeroKeys, LiZeroKeys, E_feZeroKeys, C_feZeroKeys, E_labelZeroKeys,
  107. C_labelZeroKeys);
  108. ((GCEva) eva).receiveLastSetGTT();
  109. eva.setEvaluate();
  110. timer.start(M.offline_read);
  111. byte[][][] outKeyHashes = con1.readTripleByteArrayAndDec();
  112. timer.stop(M.offline_read);
  113. timer.stop(M.offline_comp);
  114. ///////////////////////////////////////////////////////////////////////////////
  115. timer.start(M.online_comp);
  116. // step 1
  117. timer.start(M.online_read);
  118. GCSignal[] j1InputKeys = con1.readGCSignalArrayAndDec();
  119. GCSignal[] LiInputKeys = con1.readGCSignalArrayAndDec();
  120. GCSignal[] E_feInputKeys = con1.readGCSignalArrayAndDec();
  121. GCSignal[][] E_labelInputKeys = con1.readDoubleGCSignalArrayAndDec();
  122. GCSignal[] C_feInputKeys = con2.readGCSignalArrayAndDec();
  123. GCSignal[][] C_labelInputKeys = con2.readDoubleGCSignalArrayAndDec();
  124. timer.stop(M.online_read);
  125. // step 2
  126. GCSignal[][] outKeys = gcur.rootFindDeepestAndEmpty(j1InputKeys, LiInputKeys, E_feInputKeys, C_feInputKeys,
  127. E_labelInputKeys, C_labelInputKeys);
  128. int j1 = GCUtil.evaOutKeys(outKeys[0], outKeyHashes[0]).intValue();
  129. int j2 = GCUtil.evaOutKeys(outKeys[1], outKeyHashes[1]).intValue();
  130. // step 3
  131. int r = Crypto.sr.nextInt(w);
  132. int[] I = new int[E_feInputKeys.length];
  133. for (int i = 0; i < I.length; i++)
  134. I[i] = i;
  135. I[j2] = I.length;
  136. int tmp = I[r];
  137. I[r] = I[j1];
  138. I[j1] = tmp;
  139. // step 4
  140. SSXOT ssxot = new SSXOT(con1, con2);
  141. ssxot.runD(sw + 1, sw, tupleParam, I);
  142. timer.stop(M.online_comp);
  143. }
  144. public Tuple[] runC(boolean firstTree, int[] tupleParam, Tuple[] R, Tuple Ti) {
  145. if (firstTree)
  146. return R;
  147. timer.start(M.offline_comp);
  148. timer.start(M.offline_read);
  149. GCSignal[][] C_feKeyPairs = con1.readDoubleGCSignalArrayAndDec();
  150. GCSignal[][][] C_labelKeyPairs = con1.readTripleGCSignalArrayAndDec();
  151. timer.stop(M.offline_read);
  152. timer.stop(M.offline_comp);
  153. ////////////////////////////////////////////////////////////////////////////
  154. timer.start(M.online_comp);
  155. // step 1
  156. GCSignal[] C_feInputKeys = GCUtil.selectFeKeys(C_feKeyPairs, R);
  157. GCSignal[][] C_labelInputKeys = GCUtil.selectLabelKeys(C_labelKeyPairs, R);
  158. timer.start(M.online_write);
  159. con2.write(online_band, C_feInputKeys);
  160. con2.write(online_band, C_labelInputKeys);
  161. timer.stop(M.online_write);
  162. // step 4
  163. R = ArrayUtils.addAll(R, new Tuple[] { Ti });
  164. SSXOT ssxot = new SSXOT(con1, con2);
  165. R = ssxot.runC(R, tupleParam);
  166. timer.stop(M.online_comp);
  167. return R;
  168. }
  169. @Override
  170. public void run(Party party, Metadata md, Forest[] forest) {
  171. for (int i = 0; i < 100; i++) {
  172. System.out.println("i=" + i);
  173. if (party == Party.Eddie) {
  174. int sw = Crypto.sr.nextInt(15) + 10;
  175. int lBits = Crypto.sr.nextInt(20) + 5;
  176. byte[] Li = Util.nextBytes((lBits + 7) / 8, Crypto.sr);
  177. Tuple[] R = new Tuple[sw];
  178. for (int j = 0; j < sw; j++)
  179. R[j] = new Tuple(1, 2, (lBits + 7) / 8, 3, Crypto.sr);
  180. Tuple Ti = new Tuple(1, 2, (lBits + 7) / 8, 3, Crypto.sr);
  181. int[] tupleParam = new int[] { 1, 2, (lBits + 7) / 8, 3 };
  182. con1.write(sw);
  183. con1.write(lBits);
  184. con1.write(Li);
  185. con1.write(tupleParam);
  186. con2.write(sw);
  187. con2.write(lBits);
  188. con2.write(tupleParam);
  189. Tuple[] newR = runE(false, sw, lBits, tupleParam, Li, R, Ti);
  190. Tuple[] R_C = con2.readTupleArray();
  191. int cnt = 0;
  192. int[] index = new int[3];
  193. for (int j = 0; j < sw; j++) {
  194. newR[j].setXor(R_C[j]);
  195. if (!R[j].equals(newR[j])) {
  196. index[cnt] = j;
  197. cnt++;
  198. }
  199. }
  200. if (cnt == 1) {
  201. if (newR[index[0]].equals(Ti) && (R[index[0]].getF()[0] & 1) == 0)
  202. System.out.println("UpdateRoot test passed");
  203. else
  204. System.err.println("UpdateRoot test failed 1");
  205. } else if (cnt == 2) {
  206. int u = -1;
  207. for (int k = 0; k < cnt; k++) {
  208. if (newR[index[k]].equals(Ti)) {
  209. u = k;
  210. break;
  211. }
  212. }
  213. if (u == -1)
  214. System.err.println("UpdateRoot test failed 2");
  215. else {
  216. int a1 = index[u];
  217. int a2 = index[1 - u];
  218. if (!R[a1].equals(newR[a2]) || (R[u].getF()[0] & 1) == 1)
  219. System.err.println("UpdateRoot test failed 3");
  220. else
  221. System.out.println("UpdateRoot test passed");
  222. }
  223. } else if (cnt == 3) {
  224. int u = -1;
  225. for (int k = 0; k < cnt; k++) {
  226. if (newR[index[k]].equals(Ti)) {
  227. u = k;
  228. break;
  229. }
  230. }
  231. if (u == -1)
  232. System.err.println("UpdateRoot test failed 4");
  233. else {
  234. int a1, a2;
  235. if (u == 0) {
  236. a1 = 1;
  237. a2 = 2;
  238. } else if (u == 1) {
  239. a1 = 0;
  240. a2 = 2;
  241. } else {
  242. a1 = 0;
  243. a2 = 1;
  244. }
  245. u = index[u];
  246. a1 = index[a1];
  247. a2 = index[a2];
  248. if ((R[u].getF()[0] & 1) == 1)
  249. System.err.println("UpdateRoot test failed 5");
  250. else if (!R[a1].equals(newR[a2]))
  251. System.err.println("UpdateRoot test failed 6");
  252. else if (!R[a1].equals(newR[a2]) || !R[a2].equals(newR[a1]))
  253. System.err.println("UpdateRoot test failed 7");
  254. else
  255. System.out.println("UpdateRoot test passed");
  256. }
  257. } else {
  258. System.err.println("UpdateRoot test failed 8");
  259. }
  260. } else if (party == Party.Debbie) {
  261. int sw = con1.readInt();
  262. int lBits = con1.readInt();
  263. byte[] Li = con1.read();
  264. int[] tupleParam = con1.readIntArray();
  265. runD(false, sw, lBits, tupleParam, Li, md.getW());
  266. } else if (party == Party.Charlie) {
  267. int sw = con1.readInt();
  268. int lBits = con1.readInt();
  269. int[] tupleParam = con1.readIntArray();
  270. Tuple[] R = new Tuple[sw];
  271. for (int j = 0; j < sw; j++)
  272. R[j] = new Tuple(1, 2, (lBits + 7) / 8, 3, null);
  273. Tuple Ti = new Tuple(1, 2, (lBits + 7) / 8, 3, null);
  274. R = runC(false, tupleParam, R, Ti);
  275. con1.write(R);
  276. } else {
  277. throw new NoSuchPartyException(party + "");
  278. }
  279. }
  280. }
  281. @Override
  282. public void run(Party party, Metadata md, Forest forest) {
  283. }
  284. }