IntegerLib.java 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. // Copyright (C) 2014 by Xiao Shaun Wang <wangxiao@cs.umd.edu>, Yan Huang <yhuang@cs.umd.edu> and Kartik Nayak <kartik@cs.umd.edu>
  2. package com.oblivm.backend.circuits.arithmetic;
  3. import java.util.Arrays;
  4. import com.oblivm.backend.circuits.CircuitLib;
  5. import com.oblivm.backend.flexsc.CompEnv;
  6. import com.oblivm.backend.util.Utils;
  7. public class IntegerLib<T> extends CircuitLib<T> implements ArithmeticLib<T> {
  8. public int width;
  9. public IntegerLib(CompEnv<T> e) {
  10. super(e);
  11. width = 32;
  12. }
  13. public IntegerLib(CompEnv<T> e, int width) {
  14. super(e);
  15. this.width = width;
  16. }
  17. static final int S = 0;
  18. static final int COUT = 1;
  19. public T[] publicValue(double v) {
  20. int intv = (int) v;
  21. return toSignals(intv, width);
  22. }
  23. // full 1-bit adder
  24. protected T[] add(T x, T y, T cin) {
  25. T[] res = env.newTArray(2);
  26. T t1 = xor(x, cin);
  27. T t2 = xor(y, cin);
  28. res[S] = xor(x, t2);
  29. t1 = and(t1, t2);
  30. res[COUT] = xor(cin, t1);
  31. return res;
  32. }
  33. // full n-bit adder
  34. public T[] addFull(T[] x, T[] y, boolean cin) {
  35. assert (x != null && y != null && x.length == y.length) : "add: bad inputs.";
  36. T[] res = env.newTArray(x.length + 1);
  37. T[] t = add(x[0], y[0], env.newT(cin));
  38. res[0] = t[S];
  39. for (int i = 0; i < x.length - 1; i++) {
  40. t = add(x[i + 1], y[i + 1], t[COUT]);
  41. res[i + 1] = t[S];
  42. }
  43. res[res.length - 1] = t[COUT];
  44. return res;
  45. }
  46. public T[] addFull(T[] x, T[] y) {
  47. return addFull(x, y, false);
  48. }
  49. public T[] add(T[] x, T[] y, boolean cin) {
  50. return Arrays.copyOf(addFull(x, y, cin), x.length);
  51. }
  52. public T[] add(T[] x, T[] y) {
  53. return add(x, y, false);
  54. }
  55. public T[] sub(T x, T y) throws Exception {
  56. T[] ax = env.newTArray(2);
  57. ax[1] = SIGNAL_ZERO;
  58. ax[0] = x;
  59. T[] ay = env.newTArray(2);
  60. ay[1] = SIGNAL_ZERO;
  61. ay[0] = y;
  62. return sub(x, y);
  63. }
  64. public T[] sub(T[] x, T[] y) {
  65. assert (x != null && y != null && x.length == y.length) : "sub: bad inputs.";
  66. return add(x, not(y), true);
  67. }
  68. public T[] incrementByOne(T[] x) {
  69. T[] one = zeros(x.length);
  70. one[0] = SIGNAL_ONE;
  71. return add(x, one);
  72. }
  73. public T[] decrementByOne(T[] x) {
  74. T[] one = zeros(x.length);
  75. one[0] = SIGNAL_ONE;
  76. return sub(x, one);
  77. }
  78. public T[] conditionalIncreament(T[] x, T flag) {
  79. T[] one = zeros(x.length);
  80. one[0] = mux(SIGNAL_ZERO, SIGNAL_ONE, flag);
  81. return add(x, one);
  82. }
  83. public T[] conditionalDecrement(T[] x, T flag) {
  84. T[] one = zeros(x.length);
  85. one[0] = mux(SIGNAL_ZERO, SIGNAL_ONE, flag);
  86. return sub(x, one);
  87. }
  88. public T less(T[] x, T[] y) {
  89. assert (x.length == y.length) : "bad input";
  90. T[] result = sub(x, y);
  91. return result[result.length - 1];
  92. }
  93. public T greater(T[] x, T[] y) {
  94. return less(y, x);
  95. }
  96. public T geq(T[] x, T[] y) {
  97. assert (x.length == y.length) : "bad input";
  98. T[] result = sub(x, y);
  99. return not(result[result.length - 1]);
  100. }
  101. public T leq(T[] x, T[] y) {
  102. return geq(y, x);
  103. }
  104. public T[] multiply(T[] x, T[] y) {
  105. return Arrays.copyOf(multiplyInternal(x, y), x.length);// res;
  106. }
  107. // This multiplication does not truncate the length of x and y
  108. public T[] multiplyFull(T[] x, T[] y) {
  109. return multiplyInternal(x, y);
  110. }
  111. private T[] multiplyInternal(T[] x, T[] y) {
  112. // return karatsubaMultiply(x,y);
  113. assert (x != null && y != null) : "multiply: bad inputs";
  114. T[] res = zeros(x.length + y.length);
  115. T[] zero = zeros(x.length);
  116. T[] toAdd = mux(zero, x, y[0]);
  117. System.arraycopy(toAdd, 0, res, 0, toAdd.length);
  118. for (int i = 1; i < y.length; ++i) {
  119. toAdd = Arrays.copyOfRange(res, i, i + x.length);
  120. toAdd = add(toAdd, mux(zero, x, y[i]), false);
  121. System.arraycopy(toAdd, 0, res, i, toAdd.length);
  122. }
  123. return res;
  124. }
  125. public T[] absolute(T[] x) {
  126. T reachedOneSignal = SIGNAL_ZERO;
  127. T[] result = zeros(x.length);
  128. for (int i = 0; i < x.length; ++i) {
  129. T comp = eq(SIGNAL_ONE, x[i]);
  130. result[i] = xor(x[i], reachedOneSignal);
  131. reachedOneSignal = or(reachedOneSignal, comp);
  132. }
  133. return mux(x, result, x[x.length - 1]);
  134. }
  135. public T[] div(T[] x, T[] y) {
  136. T[] absoluteX = absolute(x);
  137. T[] absoluteY = absolute(y);
  138. T[] PA = divInternal(absoluteX, absoluteY);
  139. return addSign(Arrays.copyOf(PA, x.length), xor(x[x.length - 1], y[y.length - 1]));
  140. }
  141. // Restoring Division Algorithm
  142. public T[] divInternal(T[] x, T[] y) {
  143. T[] PA = zeros(x.length + y.length);
  144. T[] B = y;
  145. System.arraycopy(x, 0, PA, 0, x.length);
  146. for (int i = 0; i < x.length; ++i) {
  147. PA = leftShift(PA);
  148. T[] tempP = sub(Arrays.copyOfRange(PA, x.length, PA.length), B);
  149. PA[0] = not(tempP[tempP.length - 1]);
  150. System.arraycopy(mux(tempP, Arrays.copyOfRange(PA, x.length, PA.length), tempP[tempP.length - 1]), 0, PA,
  151. x.length, y.length);
  152. }
  153. return PA;
  154. }
  155. public T[] mod(T[] x, T[] y) {
  156. T Xneg = x[x.length - 1];
  157. T[] absoluteX = absolute(x);
  158. T[] absoluteY = absolute(y);
  159. T[] PA = divInternal(absoluteX, absoluteY);
  160. T[] res = Arrays.copyOfRange(PA, y.length, PA.length);
  161. return mux(res, sub(toSignals(0, res.length), res), Xneg);
  162. }
  163. public T[] addSign(T[] x, T sign) {
  164. T[] reachedOneSignal = zeros(x.length);
  165. T[] result = env.newTArray(x.length);
  166. for (int i = 0; i < x.length - 1; ++i) {
  167. reachedOneSignal[i + 1] = or(reachedOneSignal[i], x[i]);
  168. result[i] = xor(x[i], reachedOneSignal[i]);
  169. }
  170. result[x.length - 1] = xor(x[x.length - 1], reachedOneSignal[x.length - 1]);
  171. return mux(x, result, sign);
  172. }
  173. public T[] commonPrefix(T[] x, T[] y) {
  174. assert (x != null && y != null) : "multiply: bad inputs";
  175. T[] result = xor(x, y);
  176. for (int i = x.length - 2; i >= 0; --i) {
  177. result[i] = or(result[i], result[i + 1]);
  178. }
  179. return result;
  180. }
  181. public T[] leadingZeros(T[] x) {
  182. assert (x != null) : "leading zeros: bad inputs";
  183. T[] result = Arrays.copyOf(x, x.length);
  184. for (int i = result.length - 2; i >= 0; --i) {
  185. result[i] = or(result[i], result[i + 1]);
  186. }
  187. return numberOfOnes(not(result));
  188. }
  189. public T[] lengthOfCommenPrefix(T[] x, T[] y) {
  190. assert (x != null) : "lengthOfCommenPrefix : bad inputs";
  191. return leadingZeros(xor(x, y));
  192. }
  193. /*
  194. * Integer manipulation
  195. */
  196. public T[] leftShift(T[] x) {
  197. assert (x != null) : "leftShift: bad inputs";
  198. return leftPublicShift(x, 1);
  199. }
  200. public T[] rightShift(T[] x) {
  201. assert (x != null) : "rightShift: bad inputs";
  202. return rightPublicShift(x, 1);
  203. }
  204. public T[] leftPublicShift(T[] x, int s) {
  205. assert (x != null && s < x.length) : "leftshift: bad inputs";
  206. T res[] = env.newTArray(x.length);
  207. System.arraycopy(zeros(s), 0, res, 0, s);
  208. System.arraycopy(x, 0, res, s, x.length - s);
  209. return res;
  210. }
  211. public T[] rightPublicShift(T[] x, int s) {
  212. assert (x != null && s < x.length) : "rightshift: bad inputs";
  213. T[] res = env.newTArray(x.length);
  214. System.arraycopy(x, s, res, 0, x.length - s);
  215. System.arraycopy(zeros(s), 0, res, x.length - s, s);
  216. return res;
  217. }
  218. public T[] conditionalLeftPublicShift(T[] x, int s, T sign) {
  219. assert (x != null && s < x.length) : "leftshift: bad inputs";
  220. T[] res = env.newTArray(x.length);
  221. System.arraycopy(mux(Arrays.copyOfRange(x, 0, s), zeros(s), sign), 0, res, 0, s);
  222. System.arraycopy(mux(Arrays.copyOfRange(x, s, x.length), Arrays.copyOfRange(x, 0, x.length), sign), 0, res, s,
  223. x.length - s);
  224. return res;
  225. }
  226. public T[] conditionalRightPublicShift(T[] x, int s, T sign) {
  227. assert (x != null && s < x.length) : "rightshift: bad inputs";
  228. T res[] = env.newTArray(x.length);
  229. System.arraycopy(mux(Arrays.copyOfRange(x, 0, x.length - s), Arrays.copyOfRange(x, s, x.length), sign), 0, res,
  230. 0, x.length - s);
  231. System.arraycopy(mux(Arrays.copyOfRange(x, x.length - s, x.length), zeros(s), sign), 0, res, x.length - s, s);
  232. return res;
  233. }
  234. public T[] leftPrivateShift(T[] x, T[] lengthToShift) {
  235. T[] res = Arrays.copyOf(x, x.length);
  236. for (int i = 0; ((1 << i) < x.length) && i < lengthToShift.length; ++i)
  237. res = conditionalLeftPublicShift(res, (1 << i), lengthToShift[i]);
  238. T clear = SIGNAL_ZERO;
  239. for (int i = 0; i < lengthToShift.length; ++i) {
  240. if ((1 << i) >= x.length)
  241. clear = or(clear, lengthToShift[i]);
  242. }
  243. return mux(res, zeros(x.length), clear);
  244. }
  245. public T[] rightPrivateShift(T[] x, T[] lengthToShift) {
  246. T[] res = Arrays.copyOf(x, x.length);
  247. for (int i = 0; ((1 << i) < x.length) && i < lengthToShift.length; ++i)
  248. res = conditionalRightPublicShift(res, (1 << i), lengthToShift[i]);
  249. T clear = SIGNAL_ZERO;
  250. for (int i = 0; i < lengthToShift.length; ++i) {
  251. if ((1 << i) >= x.length)
  252. clear = or(clear, lengthToShift[i]);
  253. }
  254. return mux(res, zeros(x.length), clear);
  255. }
  256. T compare(T x, T y, T cin) {
  257. T t1 = xor(x, cin);
  258. T t2 = xor(y, cin);
  259. t1 = and(t1, t2);
  260. return xor(x, t1);
  261. }
  262. public T compare(T[] x, T[] y) {
  263. assert (x != null && y != null && x.length == y.length) : "compare: bad inputs.";
  264. T t = env.newT(false);
  265. for (int i = 0; i < x.length; i++) {
  266. t = compare(x[i], y[i], t);
  267. }
  268. return t;
  269. }
  270. public T eq(T x, T y) {
  271. assert (x != null && y != null) : "CircuitLib.eq: bad inputs";
  272. return not(xor(x, y));
  273. }
  274. public T eq(T[] x, T[] y) {
  275. assert (x != null && y != null && x.length == y.length) : "CircuitLib.eq[]: bad inputs.";
  276. T res = env.newT(true);
  277. for (int i = 0; i < x.length; i++) {
  278. T t = eq(x[i], y[i]);
  279. res = env.and(res, t);
  280. }
  281. return res;
  282. }
  283. public T[] twosComplement(T[] x) {
  284. T reachOne = SIGNAL_ZERO;
  285. T[] result = env.newTArray(x.length);
  286. for (int i = 0; i < x.length; ++i) {
  287. result[i] = xor(x[i], reachOne);
  288. reachOne = or(reachOne, x[i]);
  289. }
  290. return result;
  291. }
  292. public T[] hammingDistance(T[] x, T[] y) {
  293. T[] a = xor(x, y);
  294. return numberOfOnes(a);
  295. }
  296. public T[] numberOfOnes(T[] t) {
  297. if (t.length == 0) {
  298. // T[] res = env.newTArray(1);
  299. // res[0] = SIGNAL_ZERO;
  300. // return res;
  301. return zeros(2);
  302. }
  303. if (t.length == 1) {
  304. // return t;
  305. return padSignal(t, 2);
  306. } else {
  307. int length = 1;
  308. int w = 1;
  309. while (length <= t.length) {
  310. length <<= 1;
  311. w++;
  312. }
  313. length >>= 1;
  314. T[] res1 = numberOfOnesN(Arrays.copyOfRange(t, 0, length));
  315. T[] res2 = numberOfOnes(Arrays.copyOfRange(t, length, t.length));
  316. return add(padSignal(res1, w), padSignal(res2, w));
  317. }
  318. }
  319. public T[] numberOfOnesN(T[] res) {
  320. if (res.length == 1)
  321. return res;
  322. T[] left = numberOfOnesN(Arrays.copyOfRange(res, 0, res.length / 2));
  323. T[] right = numberOfOnesN(Arrays.copyOfRange(res, res.length / 2, res.length));
  324. return unSignedAdd(left, right);
  325. }
  326. public T[] unSignedAdd(T[] x, T[] y) {
  327. assert (x != null && y != null && x.length == y.length) : "add: bad inputs.";
  328. T[] res = env.newTArray(x.length + 1);
  329. T[] t = add(x[0], y[0], env.newT(false));
  330. res[0] = t[S];
  331. for (int i = 0; i < x.length - 1; i++) {
  332. t = add(x[i + 1], y[i + 1], t[COUT]);
  333. res[i + 1] = t[S];
  334. }
  335. res[res.length - 1] = t[COUT];
  336. return res;
  337. }
  338. public T[] unSignedMultiply(T[] x, T[] y) {
  339. assert (x != null && y != null) : "multiply: bad inputs";
  340. T[] res = zeros(x.length + y.length);
  341. T[] zero = zeros(x.length);
  342. T[] toAdd = mux(zero, x, y[0]);
  343. System.arraycopy(toAdd, 0, res, 0, toAdd.length);
  344. for (int i = 1; i < y.length; ++i) {
  345. toAdd = Arrays.copyOfRange(res, i, i + x.length);
  346. toAdd = unSignedAdd(toAdd, mux(zero, x, y[i]));
  347. System.arraycopy(toAdd, 0, res, i, toAdd.length);
  348. }
  349. return res;
  350. }
  351. public T[] karatsubaMultiply(T[] x, T[] y) {
  352. if (x.length <= 18)
  353. return unSignedMultiply(x, y);
  354. int length = (x.length + y.length);
  355. T[] xlo = Arrays.copyOfRange(x, 0, x.length / 2);
  356. T[] xhi = Arrays.copyOfRange(x, x.length / 2, x.length);
  357. T[] ylo = Arrays.copyOfRange(y, 0, y.length / 2);
  358. T[] yhi = Arrays.copyOfRange(y, y.length / 2, y.length);
  359. int nextlength = Math.max(x.length / 2, x.length - x.length / 2);
  360. xlo = padSignal(xlo, nextlength);
  361. xhi = padSignal(xhi, nextlength);
  362. ylo = padSignal(ylo, nextlength);
  363. yhi = padSignal(yhi, nextlength);
  364. T[] z0 = karatsubaMultiply(xlo, ylo);
  365. T[] z2 = karatsubaMultiply(xhi, yhi);
  366. T[] z1 = sub(padSignal(karatsubaMultiply(unSignedAdd(xlo, xhi), unSignedAdd(ylo, yhi)), 2 * nextlength + 2),
  367. padSignal(unSignedAdd(padSignal(z2, 2 * nextlength), padSignal(z0, 2 * nextlength)),
  368. 2 * nextlength + 2));
  369. z1 = padSignal(z1, length);
  370. z1 = leftPublicShift(z1, x.length / 2);
  371. T[] z0Pad = padSignal(z0, length);
  372. T[] z2Pad = padSignal(z2, length);
  373. z2Pad = leftPublicShift(z2Pad, 2 * (x.length / 2));
  374. return add(add(z0Pad, z1), z2Pad);
  375. }
  376. public T[] min(T[] x, T[] y) {
  377. T leq = leq(x, y);
  378. return mux(y, x, leq);
  379. }
  380. public T[] sqrt(T[] a) {
  381. int newLength = a.length;
  382. if (newLength % 2 == 1)
  383. newLength++;
  384. T[] x = padSignal(a, newLength);
  385. T[] rem = zeros(x.length);
  386. T[] root = zeros(x.length);
  387. for (int i = 0; i < x.length / 2; i++) {
  388. root = leftShift(root);
  389. rem = add(leftPublicShift(rem, 2), rightPublicShift(x, x.length - 2));
  390. x = leftPublicShift(x, 2);
  391. T[] oldRoot = root;
  392. root = copy(root);
  393. root[0] = SIGNAL_ONE;
  394. T[] remMinusRoot = sub(rem, root);
  395. T isRootSmaller = not(remMinusRoot[remMinusRoot.length - 1]);
  396. rem = mux(rem, remMinusRoot, isRootSmaller);
  397. root = mux(oldRoot, incrementByOne(root), isRootSmaller);
  398. }
  399. return padSignal(rightShift(root), a.length);
  400. }
  401. public T[] inputOfAlice(double d) {
  402. return env.inputOfAlice(Utils.fromLong((long) d, width));
  403. }
  404. public T[] inputOfBob(double d) {
  405. return env.inputOfBob(Utils.fromLong((long) d, width));
  406. }
  407. @Override
  408. public CompEnv<T> getEnv() {
  409. return env;
  410. }
  411. @Override
  412. public T[] toSecureInt(T[] a, IntegerLib<T> lib) {
  413. return a;
  414. }
  415. // not fully implemented, more cases to consider
  416. @Override
  417. public T[] toSecureFloat(T[] a, FloatLib<T> lib) {
  418. T[] v = padSignal(a, lib.VLength);
  419. T[] p = leadingZeros(v);
  420. v = leftPrivateShift(v, p);
  421. p = padSignal(p, lib.PLength);
  422. p = sub(zeros(p.length), p);
  423. return lib.pack(new FloatLib.Representation<T>(SIGNAL_ZERO, v, p));
  424. }
  425. @Override
  426. public T[] toSecureFixPoint(T[] a, FixedPointLib<T> lib) {
  427. return leftPublicShift(padSignal(a, lib.width), lib.offset);
  428. }
  429. @Override
  430. public double outputToAlice(T[] a) {
  431. return Utils.toInt(env.outputToAlice(a));
  432. }
  433. @Override
  434. public int numBits() {
  435. return width;
  436. }
  437. }