PIREviction.java 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. package pir;
  2. import java.math.BigInteger;
  3. import java.util.Arrays;
  4. import com.oblivm.backend.gc.GCSignal;
  5. import communication.Communication;
  6. import gc.GCUtil;
  7. import oram.Bucket;
  8. import oram.Forest;
  9. import oram.Metadata;
  10. import oram.Tree;
  11. import oram.Tuple;
  12. import protocols.PermuteIndex;
  13. import protocols.PermuteTarget;
  14. import protocols.Protocol;
  15. import protocols.SSXOT;
  16. import protocols.struct.Party;
  17. import protocols.struct.PreData;
  18. import util.M;
  19. import util.P;
  20. import util.Timer;
  21. import util.Util;
  22. public class PIREviction extends Protocol {
  23. private int pid = P.EVI;
  24. public PIREviction(Communication con1, Communication con2) {
  25. super(con1, con2);
  26. }
  27. private int[] prepareEviction(int target[], int[] ti, int W) {
  28. int d = ti.length;
  29. int[] evict = new int[W * d];
  30. for (int r = 0; r < d; r++) {
  31. int tupleIndex = r * W + ti[r];
  32. for (int c = 0; c < W; c++) {
  33. int currIndex = r * W + c;
  34. if (currIndex == tupleIndex) {
  35. int targetIndex = target[r] * W + ti[target[r]];
  36. evict[targetIndex] = currIndex;
  37. } else
  38. evict[currIndex] = currIndex;
  39. }
  40. }
  41. return evict;
  42. }
  43. public void runE(PreData predata, boolean firstTree, byte[] Li, Tuple[] originalPath, Tree OTi, Timer timer) {
  44. timer.start(pid, M.online_comp);
  45. if (firstTree) {
  46. OTi.setBucketsOnPath(new BigInteger(1, Li).longValue(), new Bucket[] { new Bucket(originalPath) });
  47. timer.stop(pid, M.online_comp);
  48. return;
  49. }
  50. int d = OTi.getD();
  51. int sw = OTi.getStashSize();
  52. int w = OTi.getW();
  53. Tuple[] pathTuples = new Tuple[d * w];
  54. System.arraycopy(originalPath, 0, pathTuples, 0, w);
  55. System.arraycopy(originalPath, sw, pathTuples, w, (d - 1) * w);
  56. Bucket[] pathBuckets = Bucket.tuplesToBuckets(pathTuples, d, w, w);
  57. GCSignal[] LiInputKeys = GCUtil.revSelectKeys(predata.evict_LiKeyPairs, Li);
  58. GCSignal[][] E_feInputKeys = new GCSignal[d][];
  59. GCSignal[][][] E_labelInputKeys = new GCSignal[d][][];
  60. GCSignal[][] deltaInputKeys = new GCSignal[d][];
  61. for (int i = 0; i < d; i++) {
  62. E_feInputKeys[i] = GCUtil.selectFeKeys(predata.evict_E_feKeyPairs[i], pathBuckets[i].getTuples());
  63. E_labelInputKeys[i] = GCUtil.selectLabelKeys(predata.evict_E_labelKeyPairs[i], pathBuckets[i].getTuples());
  64. deltaInputKeys[i] = GCUtil.revSelectKeys(predata.evict_deltaKeyPairs[i], predata.evict_delta[i]);
  65. }
  66. timer.start(pid, M.online_write);
  67. con1.write(pid, LiInputKeys);
  68. con1.write(pid, E_feInputKeys);
  69. con1.write(pid, E_labelInputKeys);
  70. con1.write(pid, deltaInputKeys);
  71. timer.stop(pid, M.online_write);
  72. PermuteTarget permutetarget = new PermuteTarget(con1, con2);
  73. permutetarget.runE();
  74. PermuteIndex permuteindex = new PermuteIndex(con1, con2);
  75. permuteindex.runE();
  76. int logW = (int) Math.ceil(Math.log(w + 1) / Math.log(2));
  77. int W = (int) Math.pow(2, logW);
  78. for (int i = 0; i < d; i++) {
  79. pathBuckets[i].expand(W);
  80. pathBuckets[i].permute(predata.evict_delta_p[i]);
  81. }
  82. pathBuckets = Util.permute(pathBuckets, predata.evict_pi);
  83. for (int i = 0; i < d; i++) {
  84. pathBuckets[i].permute(predata.evict_rho_p[i]);
  85. }
  86. pathTuples = Bucket.bucketsToTuples(pathBuckets);
  87. SSXOT ssxot = new SSXOT(con1, con2, 1);
  88. pathTuples = ssxot.runE(predata, pathTuples, timer);
  89. pathBuckets = Bucket.tuplesToBuckets(pathTuples, d, W, W);
  90. for (int i = 0; i < d; i++) {
  91. int[] rho_ivs = Util.inversePermutation(predata.evict_rho_p[i]);
  92. pathBuckets[i].permute(rho_ivs);
  93. }
  94. int[] pi_ivs = Util.inversePermutation(predata.evict_pi);
  95. pathBuckets = Util.permute(pathBuckets, pi_ivs);
  96. for (int i = 0; i < d; i++) {
  97. int[] delta_ivs = Util.inversePermutation(predata.evict_delta_p[i]);
  98. pathBuckets[i].permute(delta_ivs);
  99. pathBuckets[i].shrink(w);
  100. }
  101. pathBuckets[0].expand(Arrays.copyOfRange(originalPath, w, sw));
  102. timer.start(pid, M.online_write);
  103. con2.write(pid, pathBuckets);
  104. timer.stop(pid, M.online_write);
  105. timer.start(pid, M.online_read);
  106. con2.readBucketArray(pid);
  107. timer.stop(pid, M.online_read);
  108. // OTi.setBucketsOnPath(new BigInteger(1, Li).longValue(), pathBuckets);
  109. timer.stop(pid, M.online_comp);
  110. }
  111. public void runD(PreData predata, boolean firstTree, byte[] Li, Tree OTi, Timer timer) {
  112. timer.start(pid, M.online_comp);
  113. if (firstTree) {
  114. timer.start(pid, M.online_read);
  115. Tuple[] originalPath = con2.readTupleArray(pid);
  116. timer.stop(pid, M.online_read);
  117. OTi.setBucketsOnPath(new BigInteger(1, Li).longValue(), new Bucket[] { new Bucket(originalPath) });
  118. timer.stop(pid, M.online_comp);
  119. return;
  120. }
  121. timer.start(pid, M.online_read);
  122. GCSignal[] LiInputKeys = con1.readGCSignalArray(pid);
  123. GCSignal[][] E_feInputKeys = con1.readDoubleGCSignalArray(pid);
  124. GCSignal[][][] E_labelInputKeys = con1.readTripleGCSignalArray(pid);
  125. GCSignal[][] deltaInputKeys = con1.readDoubleGCSignalArray(pid);
  126. GCSignal[][] C_feInputKeys = con2.readDoubleGCSignalArray(pid);
  127. GCSignal[][][] C_labelInputKeys = con2.readTripleGCSignalArray(pid);
  128. timer.stop(pid, M.online_read);
  129. int w = OTi.getW();
  130. int logW = (int) Math.ceil(Math.log(w + 1) / Math.log(2));
  131. GCSignal[][][] outKeys = predata.evict_gcroute.routing(LiInputKeys, E_feInputKeys, C_feInputKeys,
  132. E_labelInputKeys, C_labelInputKeys, deltaInputKeys);
  133. byte[][] ti_p = new byte[deltaInputKeys.length][];
  134. for (int i = 0; i < ti_p.length; i++) {
  135. ti_p[i] = Util.padArray(GCUtil.evaOutKeys(outKeys[1][i], predata.evict_tiOutKeyHashes[i]).toByteArray(),
  136. (logW + 7) / 8);
  137. }
  138. PermuteTarget permutetarget = new PermuteTarget(con1, con2);
  139. int[] target_pp = permutetarget.runD(predata, firstTree, outKeys[0], timer);
  140. PermuteIndex permuteindex = new PermuteIndex(con1, con2);
  141. int[] ti_pp = permuteindex.runD(predata, firstTree, ti_p, w, timer);
  142. int W = (int) Math.pow(2, logW);
  143. int[] evict = prepareEviction(target_pp, ti_pp, W);
  144. SSXOT ssxot = new SSXOT(con1, con2, 1);
  145. ssxot.runD(predata, evict, timer);
  146. // timer.start(pid, M.online_read);
  147. // Bucket[] pathBuckets = con2.readBucketArray(pid);
  148. // timer.stop(pid, M.online_read);
  149. // OTi.setBucketsOnPath(new BigInteger(1, Li).longValue(), pathBuckets);
  150. timer.stop(pid, M.online_comp);
  151. }
  152. public void runC(PreData predata, boolean firstTree, Tuple[] originalPath, int d, int sw, int w, Timer timer) {
  153. if (firstTree) {
  154. timer.start(pid, M.online_write);
  155. con2.write(pid, originalPath);
  156. timer.stop(pid, M.online_write);
  157. return;
  158. }
  159. timer.start(pid, M.online_comp);
  160. Tuple[] pathTuples = new Tuple[d * w];
  161. System.arraycopy(originalPath, 0, pathTuples, 0, w);
  162. System.arraycopy(originalPath, sw, pathTuples, w, (d - 1) * w);
  163. Bucket[] pathBuckets = Bucket.tuplesToBuckets(pathTuples, d, w, w);
  164. GCSignal[][] C_feInputKeys = new GCSignal[d][];
  165. GCSignal[][][] C_labelInputKeys = new GCSignal[d][][];
  166. for (int i = 0; i < d; i++) {
  167. C_feInputKeys[i] = GCUtil.selectFeKeys(predata.evict_C_feKeyPairs[i], pathBuckets[i].getTuples());
  168. C_labelInputKeys[i] = GCUtil.selectLabelKeys(predata.evict_C_labelKeyPairs[i], pathBuckets[i].getTuples());
  169. }
  170. timer.start(pid, M.online_write);
  171. con2.write(pid, C_feInputKeys);
  172. con2.write(pid, C_labelInputKeys);
  173. timer.stop(pid, M.online_write);
  174. PermuteTarget permutetarget = new PermuteTarget(con1, con2);
  175. permutetarget.runC(predata, firstTree, timer);
  176. PermuteIndex permuteindex = new PermuteIndex(con1, con2);
  177. permuteindex.runC(predata, firstTree, timer);
  178. int logW = (int) Math.ceil(Math.log(w + 1) / Math.log(2));
  179. int W = (int) Math.pow(2, logW);
  180. for (int i = 0; i < d; i++) {
  181. pathBuckets[i].expand(W);
  182. pathBuckets[i].permute(predata.evict_delta_p[i]);
  183. }
  184. pathBuckets = Util.permute(pathBuckets, predata.evict_pi);
  185. for (int i = 0; i < d; i++) {
  186. pathBuckets[i].permute(predata.evict_rho_p[i]);
  187. }
  188. pathTuples = Bucket.bucketsToTuples(pathBuckets);
  189. SSXOT ssxot = new SSXOT(con1, con2, 1);
  190. pathTuples = ssxot.runC(predata, pathTuples, timer);
  191. pathBuckets = Bucket.tuplesToBuckets(pathTuples, d, W, W);
  192. for (int i = 0; i < d; i++) {
  193. int[] rho_ivs = Util.inversePermutation(predata.evict_rho_p[i]);
  194. pathBuckets[i].permute(rho_ivs);
  195. }
  196. int[] pi_ivs = Util.inversePermutation(predata.evict_pi);
  197. pathBuckets = Util.permute(pathBuckets, pi_ivs);
  198. for (int i = 0; i < d; i++) {
  199. int[] delta_ivs = Util.inversePermutation(predata.evict_delta_p[i]);
  200. pathBuckets[i].permute(delta_ivs);
  201. pathBuckets[i].shrink(w);
  202. }
  203. pathBuckets[0].expand(Arrays.copyOfRange(originalPath, w, sw));
  204. timer.start(pid, M.online_write);
  205. con1.write(pid, pathBuckets);
  206. timer.stop(pid, M.online_write);
  207. timer.start(pid, M.online_read);
  208. con1.readBucketArray(pid);
  209. timer.stop(pid, M.online_read);
  210. timer.stop(pid, M.online_comp);
  211. }
  212. @Override
  213. public void run(Party party, Metadata md, Forest[] forest) {
  214. System.out.println("Use PIRRetrieve to test PIREviction");
  215. }
  216. @Override
  217. public void run(Party party, Metadata md, Forest forest) {
  218. System.out.println("Use Retrieve to test Eviction");
  219. }
  220. }