CircuitLib.java 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. // Copyright (C) 2014 by Xiao Shaun Wang <wangxiao@cs.umd.edu>
  2. package com.oblivm.backend.circuits;
  3. import java.util.Arrays;
  4. import com.oblivm.backend.flexsc.CompEnv;
  5. import com.oblivm.backend.flexsc.Mode;
  6. import com.oblivm.backend.flexsc.Party;
  7. import com.oblivm.backend.gc.GCSignal;
  8. public class CircuitLib<T> {
  9. public CompEnv<T> env;
  10. public final T SIGNAL_ZERO;
  11. public final T SIGNAL_ONE;
  12. public CircuitLib(CompEnv<T> e) {
  13. env = e;
  14. SIGNAL_ZERO = e.ZERO();
  15. SIGNAL_ONE = e.ONE();
  16. }
  17. public T[] toSignals(long a, int width) {
  18. T[] result = env.newTArray(width);
  19. for (int i = 0; i < width; ++i) {
  20. if ((a & 1) == 1)
  21. result[i] = SIGNAL_ONE;
  22. else
  23. result[i] = SIGNAL_ZERO;
  24. a >>= 1;
  25. }
  26. return result;
  27. }
  28. public T[] enforceBits(T[] a, int length) {
  29. if (length > a.length)
  30. return padSignal(a, length);
  31. else
  32. return Arrays.copyOfRange(a, 0, length);
  33. }
  34. public T enforceBits(T a) {
  35. if (a == null)
  36. return SIGNAL_ZERO;
  37. else
  38. return a;
  39. }
  40. public T[] enforceBits(T a, int length) {
  41. T[] ret = env.newTArray(length);
  42. if (a == null)
  43. ret[0] = SIGNAL_ZERO;
  44. else
  45. ret[0] = a;
  46. for (int i = 1; i < ret.length; ++i) {
  47. ret[i] = SIGNAL_ZERO;
  48. }
  49. return ret;
  50. }
  51. public T[] randBools(int length) {
  52. if (env.getMode() == Mode.COUNT) {
  53. return zeros(length);
  54. }
  55. boolean[] res = new boolean[length];
  56. for (int i = 0; i < length; ++i)
  57. res[i] = CompEnv.rnd.nextBoolean();
  58. T[] alice = env.inputOfAlice(res);
  59. T[] bob = env.inputOfBob(res);
  60. T[] resSC = xor(alice, bob);
  61. return resSC;
  62. }
  63. public boolean[] declassifyToAlice(T[] x) {
  64. return env.outputToAlice(x);
  65. }
  66. public boolean[] declassifyToBob(T[] x) {
  67. return env.outputToBob(x);
  68. }
  69. public boolean[] declassifyToBoth2(T[] x) {
  70. if (env.getMode() == Mode.COUNT) {
  71. return new boolean[x.length];
  72. }
  73. boolean[] pos = env.outputToBob(x);
  74. if (env.getParty() == Party.Bob) {
  75. byte[] tmp = new byte[pos.length];
  76. for (int i = 0; i < pos.length; ++i)
  77. tmp[i] = (byte) (pos[i] ? 1 : 0);
  78. env.channel.writeByte(tmp, tmp.length);
  79. env.flush();
  80. } else {
  81. byte tmp[] = env.channel.readBytes(x.length);
  82. pos = new boolean[x.length];
  83. for (int k = 0; k < tmp.length; ++k) {
  84. pos[k] = ((tmp[k] - 1) == 0);
  85. }
  86. }
  87. return pos;
  88. }
  89. public boolean[] declassifyToBoth(T[] x) {
  90. if (env.getMode() == Mode.COUNT) {
  91. return new boolean[x.length];
  92. } else if (env.getMode() == Mode.VERIFY) {
  93. return com.oblivm.backend.util.Utils.tobooleanArray((Boolean[]) x);
  94. } else {
  95. GCSignal[] in = (GCSignal[]) x;
  96. boolean[] pos = new boolean[x.length];
  97. GCSignal tmp;
  98. for (int i = 0; i < x.length; ++i) {
  99. if (in[i].isPublic()) {
  100. pos[i] = in[i].v;
  101. } else {
  102. in[i].send(env.channel);
  103. }
  104. }
  105. env.channel.flush();
  106. for (int i = 0; i < x.length; ++i) {
  107. if (!in[i].isPublic()) {
  108. tmp = GCSignal.receive(env.channel);
  109. if (tmp.equals(in[i]))
  110. pos[i] = false;
  111. else
  112. pos[i] = true;
  113. }
  114. }
  115. return pos;
  116. }
  117. }
  118. // Defaults to 32 bit constants.
  119. public T[] toSignals(int value) {
  120. return toSignals(value, 32);
  121. }
  122. public GCSignal[] toSignals(GCSignal[] value) {
  123. return value;
  124. }
  125. public T[] zeros(int length) {
  126. T[] result = env.newTArray(length);
  127. for (int i = 0; i < length; ++i) {
  128. result[i] = SIGNAL_ZERO;
  129. }
  130. return result;
  131. }
  132. public T[] ones(int length) {
  133. T[] result = env.newTArray(length);
  134. for (int i = 0; i < length; ++i) {
  135. result[i] = SIGNAL_ONE;
  136. }
  137. return result;
  138. }
  139. /*
  140. * Basic logical operations on Signal and Signal[]
  141. */
  142. public T and(T x, T y) {
  143. assert (x != null && y != null) : "CircuitLib.and: bad inputs";
  144. return env.and(x, y);
  145. }
  146. public T[] and(T[] x, T[] y) {
  147. assert (x != null && y != null && x.length == y.length) : "CircuitLib.and[]: bad inputs";
  148. T[] result = env.newTArray(x.length);
  149. for (int i = 0; i < x.length; ++i) {
  150. result[i] = and(x[i], y[i]);
  151. }
  152. return result;
  153. }
  154. public T xor(T x, T y) {
  155. assert (x != null && y != null) : "CircuitLib.xor: bad inputs";
  156. return env.xor(x, y);
  157. }
  158. public T[] xor(T[] x, T[] y) {
  159. assert (x != null && y != null && x.length == y.length) : "CircuitLib.xor[]: bad inputs";
  160. T[] result = env.newTArray(x.length);
  161. for (int i = 0; i < x.length; ++i) {
  162. result[i] = xor(x[i], y[i]);
  163. }
  164. return result;
  165. }
  166. public T not(T x) {
  167. assert (x != null) : "CircuitLib.not: bad input";
  168. return env.xor(x, SIGNAL_ONE);
  169. }
  170. public T[] not(T[] x) {
  171. assert (x != null) : "CircuitLib.not[]: bad input";
  172. T[] result = env.newTArray(x.length);
  173. for (int i = 0; i < x.length; ++i) {
  174. result[i] = not(x[i]);
  175. }
  176. return result;
  177. }
  178. public T or(T x, T y) {
  179. assert (x != null && y != null) : "CircuitLib.or: bad inputs";
  180. return xor(xor(x, y), and(x, y)); // http://stackoverflow.com/a/2443029
  181. }
  182. public T[] or(T[] x, T[] y) {
  183. assert (x != null && y != null && x.length == y.length) : "CircuitLib.or[]: bad inputs";
  184. T[] result = env.newTArray(x.length);
  185. for (int i = 0; i < x.length; ++i) {
  186. result[i] = or(x[i], y[i]);
  187. }
  188. return result;
  189. }
  190. /*
  191. * Output x when c == 0; Otherwise output y.
  192. */
  193. public T mux(T x, T y, T c) {
  194. assert (x != null && y != null && c != null) : "CircuitLib.mux: bad inputs";
  195. T t = xor(x, y);
  196. t = and(t, c);
  197. T ret = xor(t, x);
  198. return ret;
  199. }
  200. public T[] mux(T[] x, T[] y, T c) {
  201. assert (x != null && y != null && x.length == y.length) : "CircuitLib.mux[]: bad inputs";
  202. T[] ret = env.newTArray(x.length);
  203. for (int i = 0; i < x.length; i++)
  204. ret[i] = mux(x[i], y[i], c);
  205. return ret;
  206. }
  207. public T[][] mux(T[][] x, T[][] y, T c) {
  208. assert (x != null && y != null && x.length == y.length) : "CircuitLib.mux[][]: bad inputs";
  209. T[][] ret = env.newTArray(x.length, 1);
  210. for (int i = 0; i < x.length; i++)
  211. ret[i] = mux(x[i], y[i], c);
  212. return ret;
  213. }
  214. public T[][][] mux(T[][][] x, T[][][] y, T c) {
  215. assert (x != null && y != null && x.length == y.length) : "CircuitLib.mux[]: bad inputs";
  216. T[][][] ret = env.newTArray(x.length, 1, 1);
  217. for (int i = 0; i < x.length; i++)
  218. ret[i] = mux(x[i], y[i], c);
  219. return ret;
  220. }
  221. public T[] padSignal(T[] a, int length) {
  222. T[] res = zeros(length);
  223. for (int i = 0; i < a.length && i < length; ++i)
  224. res[i] = a[i];
  225. return res;
  226. }
  227. public T[] padSignedSignal(T[] a, int length) {
  228. T[] res = env.newTArray(length);
  229. for (int i = 0; i < a.length && i < length; ++i)
  230. res[i] = a[i];
  231. for (int i = a.length; i < length; ++i)
  232. res[i] = a[a.length - 1];
  233. return res;
  234. }
  235. public T[] copy(T[] x) {
  236. return Arrays.copyOf(x, x.length);
  237. }
  238. }