Access.java 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. package protocols;
  2. import java.math.BigInteger;
  3. import java.util.Arrays;
  4. import org.apache.commons.lang3.ArrayUtils;
  5. import communication.Communication;
  6. import crypto.Crypto;
  7. import exceptions.AccessException;
  8. import exceptions.NoSuchPartyException;
  9. import oram.Bucket;
  10. import oram.Forest;
  11. import oram.Global;
  12. import oram.Metadata;
  13. import oram.Tree;
  14. import oram.Tuple;
  15. import protocols.precomputation.PreAccess;
  16. import protocols.struct.OutAccess;
  17. import protocols.struct.OutSSCOT;
  18. import protocols.struct.OutSSIOT;
  19. import protocols.struct.Party;
  20. import protocols.struct.PreData;
  21. import util.M;
  22. import util.P;
  23. import util.Timer;
  24. import util.Util;
  25. public class Access extends Protocol {
  26. private int pid = P.ACC;
  27. public Access(Communication con1, Communication con2) {
  28. super(con1, con2);
  29. }
  30. public OutAccess runE(PreData predata, Tree OTi, byte[] Ni, byte[] Nip1_pr, Timer timer) {
  31. timer.start(pid, M.online_comp);
  32. // step 0: get Li from C
  33. byte[] Li = new byte[0];
  34. timer.start(pid, M.online_read);
  35. if (OTi.getTreeIndex() > 0)
  36. Li = con2.read(pid);
  37. timer.stop(pid, M.online_read);
  38. // step 1
  39. Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
  40. Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
  41. for (int i = 0; i < pathTuples.length; i++)
  42. pathTuples[i].setXor(predata.access_p[i]);
  43. pathTuples = Util.permute(pathTuples, predata.access_sigma);
  44. // step 3
  45. byte[] y = null;
  46. if (OTi.getTreeIndex() == 0)
  47. y = pathTuples[0].getA();
  48. else if (OTi.getTreeIndex() < OTi.getH() - 1)
  49. y = Util.nextBytes(OTi.getABytes(), Crypto.sr);
  50. else
  51. y = new byte[OTi.getABytes()];
  52. if (OTi.getTreeIndex() > 0) {
  53. byte[][] a = new byte[pathTuples.length][];
  54. byte[][] m = new byte[pathTuples.length][];
  55. for (int i = 0; i < pathTuples.length; i++) {
  56. m[i] = Util.xor(pathTuples[i].getA(), y);
  57. a[i] = ArrayUtils.addAll(pathTuples[i].getF(), pathTuples[i].getN());
  58. for (int j = 0; j < Ni.length; j++)
  59. a[i][a[i].length - 1 - j] ^= Ni[Ni.length - 1 - j];
  60. }
  61. SSCOT sscot = new SSCOT(con1, con2);
  62. sscot.runE(predata, m, a, timer);
  63. }
  64. // step 4
  65. if (OTi.getTreeIndex() < OTi.getH() - 1) {
  66. int ySegBytes = y.length / OTi.getTwoTauPow();
  67. byte[][] y_array = new byte[OTi.getTwoTauPow()][];
  68. for (int i = 0; i < OTi.getTwoTauPow(); i++)
  69. y_array[i] = Arrays.copyOfRange(y, i * ySegBytes, (i + 1) * ySegBytes);
  70. SSIOT ssiot = new SSIOT(con1, con2);
  71. ssiot.runE(predata, y_array, Nip1_pr, timer);
  72. }
  73. // step 5
  74. Tuple Ti = null;
  75. if (OTi.getTreeIndex() == 0)
  76. Ti = pathTuples[0];
  77. else
  78. Ti = new Tuple(new byte[1], Ni, Li, y);
  79. OutAccess outaccess = new OutAccess(Li, null, null, null, null, Ti, pathTuples);
  80. timer.stop(pid, M.online_comp);
  81. return outaccess;
  82. }
  83. public byte[] runD(PreData predata, Tree OTi, byte[] Ni, byte[] Nip1_pr, Timer timer) {
  84. timer.start(pid, M.online_comp);
  85. // step 0: get Li from C
  86. byte[] Li = new byte[0];
  87. timer.start(pid, M.online_read);
  88. if (OTi.getTreeIndex() > 0)
  89. Li = con2.read(pid);
  90. timer.stop(pid, M.online_read);
  91. // step 1
  92. Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
  93. Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
  94. for (int i = 0; i < pathTuples.length; i++)
  95. pathTuples[i].setXor(predata.access_p[i]);
  96. pathTuples = Util.permute(pathTuples, predata.access_sigma);
  97. // step 2
  98. timer.start(pid, M.online_write);
  99. con2.write(pid, pathTuples);
  100. con2.write(pid, Ni);
  101. timer.stop(pid, M.online_write);
  102. // step 3
  103. if (OTi.getTreeIndex() > 0) {
  104. byte[][] b = new byte[pathTuples.length][];
  105. for (int i = 0; i < pathTuples.length; i++) {
  106. b[i] = ArrayUtils.addAll(pathTuples[i].getF(), pathTuples[i].getN());
  107. b[i][0] ^= 1;
  108. for (int j = 0; j < Ni.length; j++)
  109. b[i][b[i].length - 1 - j] ^= Ni[Ni.length - 1 - j];
  110. }
  111. SSCOT sscot = new SSCOT(con1, con2);
  112. sscot.runD(predata, b, timer);
  113. }
  114. // step 4
  115. if (OTi.getTreeIndex() < OTi.getH() - 1) {
  116. SSIOT ssiot = new SSIOT(con1, con2);
  117. ssiot.runD(predata, Nip1_pr, timer);
  118. }
  119. timer.stop(pid, M.online_comp);
  120. return Li;
  121. }
  122. public OutAccess runC(Metadata md, int treeIndex, byte[] Li, Timer timer) {
  123. timer.start(pid, M.online_comp);
  124. // step 0: send Li to E and D
  125. timer.start(pid, M.online_write);
  126. if (treeIndex > 0) {
  127. con1.write(pid, Li);
  128. con2.write(pid, Li);
  129. }
  130. timer.stop(pid, M.online_write);
  131. // step 2
  132. timer.start(pid, M.online_read);
  133. Tuple[] pathTuples = con2.readTupleArray(pid);
  134. byte[] Ni = con2.read(pid);
  135. timer.stop(pid, M.online_read);
  136. // step 3
  137. int j1 = 0;
  138. byte[] z = null;
  139. if (treeIndex == 0) {
  140. z = pathTuples[0].getA();
  141. } else {
  142. SSCOT sscot = new SSCOT(con1, con2);
  143. OutSSCOT je = sscot.runC(timer);
  144. j1 = je.t;
  145. byte[] d = pathTuples[j1].getA();
  146. z = Util.xor(je.m_t, d);
  147. }
  148. // step 4
  149. int j2 = 0;
  150. byte[] Lip1 = null;
  151. if (treeIndex < md.getNumTrees() - 1) {
  152. SSIOT ssiot = new SSIOT(con1, con2);
  153. OutSSIOT jy = ssiot.runC(timer);
  154. // step 5
  155. j2 = jy.t;
  156. int lSegBytes = md.getABytesOfTree(treeIndex) / md.getTwoTauPow();
  157. byte[] z_j2 = Arrays.copyOfRange(z, j2 * lSegBytes, (j2 + 1) * lSegBytes);
  158. Lip1 = Util.xor(jy.m_t, z_j2);
  159. }
  160. Tuple Ti = null;
  161. if (treeIndex == 0) {
  162. Ti = pathTuples[0];
  163. } else {
  164. Ti = new Tuple(new byte[] { 1 }, Ni, new byte[md.getLBytesOfTree(treeIndex)], z);
  165. pathTuples[j1].getF()[0] = (byte) (1 - pathTuples[j1].getF()[0]);
  166. Crypto.sr.nextBytes(pathTuples[j1].getN());
  167. Crypto.sr.nextBytes(pathTuples[j1].getL());
  168. Crypto.sr.nextBytes(pathTuples[j1].getA());
  169. }
  170. OutAccess outaccess = new OutAccess(Li, Lip1, Ti, pathTuples, j2, null, null);
  171. timer.stop(pid, M.online_comp);
  172. return outaccess;
  173. }
  174. public OutAccess runE2(Tree OTi, Timer timer) {
  175. timer.start(pid, M.online_comp);
  176. // step 0: get Li from C
  177. byte[] Li = new byte[0];
  178. timer.start(pid, M.online_read);
  179. if (OTi.getTreeIndex() > 0)
  180. Li = con2.read(pid);
  181. timer.stop(pid, M.online_read);
  182. // step 1
  183. Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
  184. Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
  185. // step 5
  186. Tuple Ti = null;
  187. if (OTi.getTreeIndex() == 0)
  188. Ti = pathTuples[0];
  189. else {
  190. Ti = new Tuple(1, OTi.getNBytes(), OTi.getLBytes(), OTi.getABytes(), Crypto.sr);
  191. Ti.setF(new byte[1]);
  192. }
  193. OutAccess outaccess = new OutAccess(Li, null, null, null, null, Ti, pathTuples);
  194. timer.stop(pid, M.online_comp);
  195. return outaccess;
  196. }
  197. public byte[] runD2(Tree OTi, Timer timer) {
  198. timer.start(pid, M.online_comp);
  199. // step 0: get Li from C
  200. byte[] Li = new byte[0];
  201. timer.start(pid, M.online_read);
  202. if (OTi.getTreeIndex() > 0)
  203. Li = con2.read(pid);
  204. timer.stop(pid, M.online_read);
  205. // step 1
  206. Bucket[] pathBuckets = OTi.getBucketsOnPath(Li);
  207. Tuple[] pathTuples = Bucket.bucketsToTuples(pathBuckets);
  208. // step 2
  209. timer.start(pid, M.online_write);
  210. con2.write(pid, pathTuples);
  211. timer.stop(pid, M.online_write);
  212. timer.stop(pid, M.online_comp);
  213. return Li;
  214. }
  215. public OutAccess runC2(Metadata md, int treeIndex, byte[] Li, Timer timer) {
  216. timer.start(pid, M.online_comp);
  217. // step 0: send Li to E and D
  218. timer.start(pid, M.online_write);
  219. if (treeIndex > 0) {
  220. con1.write(pid, Li);
  221. con2.write(pid, Li);
  222. }
  223. timer.stop(pid, M.online_write);
  224. // step 2
  225. timer.start(pid, M.online_read);
  226. Tuple[] pathTuples = con2.readTupleArray(pid);
  227. timer.stop(pid, M.online_read);
  228. // step 5
  229. Tuple Ti = null;
  230. if (treeIndex == 0) {
  231. Ti = pathTuples[0];
  232. } else {
  233. Ti = new Tuple(1, md.getNBytesOfTree(treeIndex), md.getLBytesOfTree(treeIndex),
  234. md.getABytesOfTree(treeIndex), Crypto.sr);
  235. Ti.setF(new byte[1]);
  236. }
  237. OutAccess outaccess = new OutAccess(Li, null, Ti, pathTuples, null, null, null);
  238. timer.stop(pid, M.online_comp);
  239. return outaccess;
  240. }
  241. // for testing correctness
  242. @Override
  243. public void run(Party party, Metadata md, Forest forest) {
  244. int records = 5;
  245. int repeat = 5;
  246. int tau = md.getTau();
  247. int numTrees = md.getNumTrees();
  248. long numInsert = md.getNumInsertRecords();
  249. int addrBits = md.getAddrBits();
  250. Timer timer = new Timer();
  251. sanityCheck();
  252. System.out.println();
  253. for (int i = 0; i < records; i++) {
  254. long N = Global.cheat ? 0 : Util.nextLong(numInsert, Crypto.sr);
  255. for (int j = 0; j < repeat; j++) {
  256. System.out.println("Test: " + i + " " + j);
  257. System.out.println("N=" + BigInteger.valueOf(N).toString(2));
  258. byte[] Li = new byte[0];
  259. for (int ti = 0; ti < numTrees; ti++) {
  260. long Ni_value = Util.getSubBits(N, addrBits, addrBits - md.getNBitsOfTree(ti));
  261. long Nip1_pr_value = Util.getSubBits(N, addrBits - md.getNBitsOfTree(ti),
  262. Math.max(addrBits - md.getNBitsOfTree(ti) - tau, 0));
  263. byte[] Ni = Util.longToBytes(Ni_value, md.getNBytesOfTree(ti));
  264. byte[] Nip1_pr = Util.longToBytes(Nip1_pr_value, (tau + 7) / 8);
  265. PreData predata = new PreData();
  266. PreAccess preaccess = new PreAccess(con1, con2);
  267. if (party == Party.Eddie) {
  268. Tree OTi = forest.getTree(ti);
  269. int numTuples = (OTi.getD() - 1) * OTi.getW() + OTi.getStashSize();
  270. int[] tupleParam = new int[] { ti == 0 ? 0 : 1, md.getNBytesOfTree(ti), md.getLBytesOfTree(ti),
  271. md.getABytesOfTree(ti) };
  272. preaccess.runE(predata, md.getTwoTauPow(), numTuples, tupleParam, timer);
  273. byte[] sE_Ni = Util.nextBytes(Ni.length, Crypto.sr);
  274. byte[] sD_Ni = Util.xor(Ni, sE_Ni);
  275. con1.write(sD_Ni);
  276. byte[] sE_Nip1_pr = Util.nextBytes(Nip1_pr.length, Crypto.sr);
  277. byte[] sD_Nip1_pr = Util.xor(Nip1_pr, sE_Nip1_pr);
  278. con1.write(sD_Nip1_pr);
  279. runE(predata, OTi, sE_Ni, sE_Nip1_pr, timer);
  280. if (ti == numTrees - 1)
  281. con2.write(N);
  282. } else if (party == Party.Debbie) {
  283. Tree OTi = forest.getTree(ti);
  284. preaccess.runD(predata, timer);
  285. byte[] sD_Ni = con1.read();
  286. byte[] sD_Nip1_pr = con1.read();
  287. runD(predata, OTi, sD_Ni, sD_Nip1_pr, timer);
  288. } else if (party == Party.Charlie) {
  289. preaccess.runC(timer);
  290. System.out.println("L" + ti + "=" + new BigInteger(1, Li).toString(2));
  291. OutAccess outaccess = runC(md, ti, Li, timer);
  292. Li = outaccess.C_Lip1;
  293. if (ti == numTrees - 1) {
  294. N = con1.readLong();
  295. long data = new BigInteger(1, outaccess.C_Ti.getA()).longValue();
  296. if (N == data) {
  297. System.out.println("Access passed");
  298. System.out.println();
  299. } else {
  300. throw new AccessException("Access failed");
  301. }
  302. }
  303. } else {
  304. throw new NoSuchPartyException(party + "");
  305. }
  306. }
  307. }
  308. }
  309. // timer.print();
  310. }
  311. }