PIRCOT.java 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package pir;
  2. import communication.Communication;
  3. import crypto.Crypto;
  4. import exceptions.NoSuchPartyException;
  5. import exceptions.SSCOTException;
  6. import oram.Forest;
  7. import oram.Metadata;
  8. import precomputation.PrePIRCOT;
  9. import protocols.Protocol;
  10. import protocols.struct.OutSSCOT;
  11. import protocols.struct.Party;
  12. import protocols.struct.PreData;
  13. import util.M;
  14. import util.P;
  15. import util.Timer;
  16. import util.Util;
  17. public class PIRCOT extends Protocol {
  18. private int pid = P.COT;
  19. public PIRCOT(Communication con1, Communication con2) {
  20. super(con1, con2);
  21. }
  22. public void runE(PreData predata, byte[][] a, Timer timer) {
  23. timer.start(pid, M.online_comp);
  24. // step 1
  25. int n = a.length;
  26. byte[][] x = predata.sscot_r;
  27. byte[][] v = new byte[n][];
  28. for (int i = 0; i < n; i++) {
  29. for (int j = 0; j < a[i].length; j++)
  30. x[i][j] = (byte) (predata.sscot_r[i][j] ^ a[i][j]);
  31. v[i] = predata.sscot_F_kprime.compute(x[i]);
  32. }
  33. timer.start(pid, M.online_write);
  34. con2.write(pid, v);
  35. timer.stop(pid, M.online_write);
  36. timer.stop(pid, M.online_comp);
  37. }
  38. public void runD(PreData predata, byte[][] b, Timer timer) {
  39. timer.start(pid, M.online_comp);
  40. // step 2
  41. int n = b.length;
  42. byte[][] y = predata.sscot_r;
  43. byte[][] w = new byte[n][];
  44. for (int i = 0; i < n; i++) {
  45. for (int j = 0; j < b[i].length; j++)
  46. y[i][j] = (byte) (predata.sscot_r[i][j] ^ b[i][j]);
  47. w[i] = predata.sscot_F_kprime.compute(y[i]);
  48. }
  49. timer.start(pid, M.online_write);
  50. con2.write(pid, w);
  51. timer.stop(pid, M.online_write);
  52. timer.stop(pid, M.online_comp);
  53. }
  54. public OutSSCOT runC(Timer timer) {
  55. timer.start(pid, M.online_comp);
  56. // step 1
  57. timer.start(pid, M.online_read);
  58. byte[][] v = con1.readDoubleByteArray(pid);
  59. // step 2
  60. byte[][] w = con2.readDoubleByteArray(pid);
  61. timer.stop(pid, M.online_read);
  62. // step 3
  63. int n = v.length;
  64. OutSSCOT output = null;
  65. int invariant = 0;
  66. for (int i = 0; i < n; i++) {
  67. if (Util.equal(v[i], w[i])) {
  68. output = new OutSSCOT(i, null);
  69. invariant++;
  70. }
  71. }
  72. if (invariant != 1)
  73. throw new SSCOTException("Invariant error: " + invariant);
  74. timer.stop(pid, M.online_comp);
  75. return output;
  76. }
  77. // for testing correctness
  78. @Override
  79. public void run(Party party, Metadata md, Forest forest) {
  80. Timer timer = new Timer();
  81. for (int j = 0; j < 100; j++) {
  82. int n = 100;
  83. int FN = 5;
  84. byte[][] a = new byte[n][FN];
  85. byte[][] b = new byte[n][FN];
  86. for (int i = 0; i < n; i++) {
  87. Crypto.sr.nextBytes(a[i]);
  88. Crypto.sr.nextBytes(b[i]);
  89. while (Util.equal(a[i], b[i]))
  90. Crypto.sr.nextBytes(b[i]);
  91. }
  92. int index = Crypto.sr.nextInt(n);
  93. b[index] = a[index].clone();
  94. PreData predata = new PreData();
  95. PrePIRCOT presscot = new PrePIRCOT(con1, con2);
  96. if (party == Party.Eddie) {
  97. con1.write(b);
  98. con2.write(index);
  99. presscot.runE(predata, n, timer);
  100. runE(predata, a, timer);
  101. } else if (party == Party.Debbie) {
  102. b = con1.readDoubleByteArray();
  103. presscot.runD(predata, timer);
  104. runD(predata, b, timer);
  105. } else if (party == Party.Charlie) {
  106. index = con1.readInt();
  107. presscot.runC();
  108. OutSSCOT output = runC(timer);
  109. if (output.t == index)
  110. System.out.println("PIRCOT test passed");
  111. else
  112. System.err.println("PIRCOT test failed");
  113. } else {
  114. throw new NoSuchPartyException(party + "");
  115. }
  116. }
  117. // timer.print();
  118. }
  119. }