SSCOT.java 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. package protocols;
  2. import communication.Communication;
  3. import crypto.Crypto;
  4. import crypto.PRG;
  5. import exceptions.NoSuchPartyException;
  6. import exceptions.SSCOTException;
  7. import oram.Forest;
  8. import oram.Metadata;
  9. import protocols.precomputation.PreSSCOT;
  10. import protocols.struct.OutSSCOT;
  11. import protocols.struct.Party;
  12. import protocols.struct.PreData;
  13. import util.M;
  14. import util.P;
  15. import util.Timer;
  16. import util.Util;
  17. public class SSCOT extends Protocol {
  18. private int pid = P.COT;
  19. public SSCOT(Communication con1, Communication con2) {
  20. super(con1, con2);
  21. }
  22. public void runE(PreData predata, byte[][] m, byte[][] a, Timer timer) {
  23. timer.start(pid, M.online_comp);
  24. // step 1
  25. int n = m.length;
  26. int l = m[0].length * 8;
  27. byte[][] x = predata.sscot_r;
  28. byte[][] e = new byte[n][];
  29. byte[][] v = new byte[n][];
  30. PRG G = new PRG(l);
  31. for (int i = 0; i < n; i++) {
  32. for (int j = 0; j < a[i].length; j++)
  33. x[i][j] = (byte) (predata.sscot_r[i][j] ^ a[i][j]);
  34. e[i] = Util.xor(G.compute(predata.sscot_F_k.compute(x[i])), m[i]);
  35. v[i] = predata.sscot_F_kprime.compute(x[i]);
  36. }
  37. timer.start(pid, M.online_write);
  38. con2.write(pid, e);
  39. con2.write(pid, v);
  40. timer.stop(pid, M.online_write);
  41. timer.stop(pid, M.online_comp);
  42. }
  43. public void runD(PreData predata, byte[][] b, Timer timer) {
  44. timer.start(pid, M.online_comp);
  45. // step 2
  46. int n = b.length;
  47. byte[][] y = predata.sscot_r;
  48. byte[][] p = new byte[n][];
  49. byte[][] w = new byte[n][];
  50. for (int i = 0; i < n; i++) {
  51. for (int j = 0; j < b[i].length; j++)
  52. y[i][j] = (byte) (predata.sscot_r[i][j] ^ b[i][j]);
  53. p[i] = predata.sscot_F_k.compute(y[i]);
  54. w[i] = predata.sscot_F_kprime.compute(y[i]);
  55. }
  56. timer.start(pid, M.online_write);
  57. con2.write(pid, p);
  58. con2.write(pid, w);
  59. timer.stop(pid, M.online_write);
  60. timer.stop(pid, M.online_comp);
  61. }
  62. public OutSSCOT runC(Timer timer) {
  63. timer.start(pid, M.online_comp);
  64. // step 1
  65. timer.start(pid, M.online_read);
  66. byte[][] e = con1.readDoubleByteArray();
  67. byte[][] v = con1.readDoubleByteArray();
  68. // step 2
  69. byte[][] p = con2.readDoubleByteArray();
  70. byte[][] w = con2.readDoubleByteArray();
  71. timer.stop(pid, M.online_read);
  72. // step 3
  73. int n = e.length;
  74. int l = e[0].length * 8;
  75. PRG G = new PRG(l);
  76. OutSSCOT output = null;
  77. int invariant = 0;
  78. for (int i = 0; i < n; i++) {
  79. if (Util.equal(v[i], w[i])) {
  80. byte[] m = Util.xor(e[i], G.compute(p[i]));
  81. output = new OutSSCOT(i, m);
  82. invariant++;
  83. }
  84. }
  85. if (invariant != 1)
  86. throw new SSCOTException("Invariant error: " + invariant);
  87. timer.stop(pid, M.online_comp);
  88. return output;
  89. }
  90. // for testing correctness
  91. @Override
  92. public void run(Party party, Metadata md, Forest forest) {
  93. Timer timer = new Timer();
  94. for (int j = 0; j < 100; j++) {
  95. int n = 100;
  96. int A = 32;
  97. int FN = 5;
  98. byte[][] m = new byte[n][A];
  99. byte[][] a = new byte[n][FN];
  100. byte[][] b = new byte[n][FN];
  101. for (int i = 0; i < n; i++) {
  102. Crypto.sr.nextBytes(m[i]);
  103. Crypto.sr.nextBytes(a[i]);
  104. Crypto.sr.nextBytes(b[i]);
  105. while (Util.equal(a[i], b[i]))
  106. Crypto.sr.nextBytes(b[i]);
  107. }
  108. int index = Crypto.sr.nextInt(n);
  109. b[index] = a[index].clone();
  110. PreData predata = new PreData();
  111. PreSSCOT presscot = new PreSSCOT(con1, con2);
  112. if (party == Party.Eddie) {
  113. con1.write(b);
  114. con2.write(m);
  115. con2.write(index);
  116. presscot.runE(predata, n, timer);
  117. runE(predata, m, a, timer);
  118. } else if (party == Party.Debbie) {
  119. b = con1.readDoubleByteArray();
  120. presscot.runD(predata, timer);
  121. runD(predata, b, timer);
  122. } else if (party == Party.Charlie) {
  123. m = con1.readDoubleByteArray();
  124. index = con1.readInt();
  125. presscot.runC();
  126. OutSSCOT output = runC(timer);
  127. if (output.t == index && Util.equal(output.m_t, m[index]))
  128. System.out.println("SSCOT test passed");
  129. else
  130. System.err.println("SSCOT test failed");
  131. } else {
  132. throw new NoSuchPartyException(party + "");
  133. }
  134. }
  135. // timer.print();
  136. }
  137. }