IntegerLib.java 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. // Copyright (C) 2014 by Xiao Shaun Wang <wangxiao@cs.umd.edu>, Yan Huang <yhuang@cs.umd.edu> and Kartik Nayak <kartik@cs.umd.edu>
  2. package com.oblivm.backend.circuits.arithmetic;
  3. import java.util.Arrays;
  4. import com.oblivm.backend.circuits.CircuitLib;
  5. import com.oblivm.backend.flexsc.CompEnv;
  6. import com.oblivm.backend.util.Utils;
  7. public class IntegerLib<T> extends CircuitLib<T> implements ArithmeticLib<T> {
  8. public int width;
  9. public IntegerLib(CompEnv<T> e) {
  10. super(e);
  11. width = 32;
  12. }
  13. public IntegerLib(CompEnv<T> e, int width) {
  14. super(e);
  15. this.width = width;
  16. }
  17. static final int S = 0;
  18. static final int COUT = 1;
  19. public T[] publicValue(double v) {
  20. int intv = (int) v;
  21. return toSignals(intv, width);
  22. }
  23. // full 1-bit adder
  24. protected T[] add(T x, T y, T cin) {
  25. T[] res = env.newTArray(2);
  26. T t1 = xor(x, cin);
  27. T t2 = xor(y, cin);
  28. res[S] = xor(x, t2);
  29. t1 = and(t1, t2);
  30. res[COUT] = xor(cin, t1);
  31. return res;
  32. }
  33. // full n-bit adder
  34. public T[] addFull(T[] x, T[] y, boolean cin) {
  35. assert (x != null && y != null && x.length == y.length) : "add: bad inputs.";
  36. T[] res = env.newTArray(x.length + 1);
  37. T[] t = add(x[0], y[0], env.newT(cin));
  38. res[0] = t[S];
  39. for (int i = 0; i < x.length - 1; i++) {
  40. t = add(x[i + 1], y[i + 1], t[COUT]);
  41. res[i + 1] = t[S];
  42. }
  43. res[res.length - 1] = t[COUT];
  44. return res;
  45. }
  46. public T[] add(T[] x, T[] y, boolean cin) {
  47. return Arrays.copyOf(addFull(x, y, cin), x.length);
  48. }
  49. public T[] add(T[] x, T[] y) {
  50. return add(x, y, false);
  51. }
  52. public T[] sub(T x, T y) throws Exception {
  53. T[] ax = env.newTArray(2);
  54. ax[1] = SIGNAL_ZERO;
  55. ax[0] = x;
  56. T[] ay = env.newTArray(2);
  57. ay[1] = SIGNAL_ZERO;
  58. ay[0] = y;
  59. return sub(x, y);
  60. }
  61. public T[] sub(T[] x, T[] y) {
  62. assert (x != null && y != null && x.length == y.length) : "sub: bad inputs.";
  63. return add(x, not(y), true);
  64. }
  65. public T[] incrementByOne(T[] x) {
  66. T[] one = zeros(x.length);
  67. one[0] = SIGNAL_ONE;
  68. return add(x, one);
  69. }
  70. public T[] decrementByOne(T[] x) {
  71. T[] one = zeros(x.length);
  72. one[0] = SIGNAL_ONE;
  73. return sub(x, one);
  74. }
  75. public T[] conditionalIncreament(T[] x, T flag) {
  76. T[] one = zeros(x.length);
  77. one[0] = mux(SIGNAL_ZERO, SIGNAL_ONE, flag);
  78. return add(x, one);
  79. }
  80. public T[] conditionalDecrement(T[] x, T flag) {
  81. T[] one = zeros(x.length);
  82. one[0] = mux(SIGNAL_ZERO, SIGNAL_ONE, flag);
  83. return sub(x, one);
  84. }
  85. public T geq(T[] x, T[] y) {
  86. assert (x.length == y.length) : "bad input";
  87. T[] result = sub(x, y);
  88. return not(result[result.length - 1]);
  89. }
  90. public T leq(T[] x, T[] y) {
  91. return geq(y, x);
  92. }
  93. public T greater(T[] x, T[] y) {
  94. return not(leq(x, y));
  95. }
  96. public T less(T[] x, T[] y) {
  97. return not(geq(x, y));
  98. }
  99. public T[] multiply(T[] x, T[] y) {
  100. return Arrays.copyOf(multiplyInternal(x, y), x.length);// res;
  101. }
  102. // This multiplication does not truncate the length of x and y
  103. public T[] multiplyFull(T[] x, T[] y) {
  104. return multiplyInternal(x, y);
  105. }
  106. private T[] multiplyInternal(T[] x, T[] y) {
  107. // return karatsubaMultiply(x,y);
  108. assert (x != null && y != null) : "multiply: bad inputs";
  109. T[] res = zeros(x.length + y.length);
  110. T[] zero = zeros(x.length);
  111. T[] toAdd = mux(zero, x, y[0]);
  112. System.arraycopy(toAdd, 0, res, 0, toAdd.length);
  113. for (int i = 1; i < y.length; ++i) {
  114. toAdd = Arrays.copyOfRange(res, i, i + x.length);
  115. toAdd = add(toAdd, mux(zero, x, y[i]), false);
  116. System.arraycopy(toAdd, 0, res, i, toAdd.length);
  117. }
  118. return res;
  119. }
  120. public T[] absolute(T[] x) {
  121. T reachedOneSignal = SIGNAL_ZERO;
  122. T[] result = zeros(x.length);
  123. for (int i = 0; i < x.length; ++i) {
  124. T comp = eq(SIGNAL_ONE, x[i]);
  125. result[i] = xor(x[i], reachedOneSignal);
  126. reachedOneSignal = or(reachedOneSignal, comp);
  127. }
  128. return mux(x, result, x[x.length - 1]);
  129. }
  130. public T[] div(T[] x, T[] y) {
  131. T[] absoluteX = absolute(x);
  132. T[] absoluteY = absolute(y);
  133. T[] PA = divInternal(absoluteX, absoluteY);
  134. return addSign(Arrays.copyOf(PA, x.length), xor(x[x.length - 1], y[y.length - 1]));
  135. }
  136. // Restoring Division Algorithm
  137. public T[] divInternal(T[] x, T[] y) {
  138. T[] PA = zeros(x.length + y.length);
  139. T[] B = y;
  140. System.arraycopy(x, 0, PA, 0, x.length);
  141. for (int i = 0; i < x.length; ++i) {
  142. PA = leftShift(PA);
  143. T[] tempP = sub(Arrays.copyOfRange(PA, x.length, PA.length), B);
  144. PA[0] = not(tempP[tempP.length - 1]);
  145. System.arraycopy(mux(tempP, Arrays.copyOfRange(PA, x.length, PA.length), tempP[tempP.length - 1]), 0, PA,
  146. x.length, y.length);
  147. }
  148. return PA;
  149. }
  150. public T[] mod(T[] x, T[] y) {
  151. T Xneg = x[x.length - 1];
  152. T[] absoluteX = absolute(x);
  153. T[] absoluteY = absolute(y);
  154. T[] PA = divInternal(absoluteX, absoluteY);
  155. T[] res = Arrays.copyOfRange(PA, y.length, PA.length);
  156. return mux(res, sub(toSignals(0, res.length), res), Xneg);
  157. }
  158. public T[] addSign(T[] x, T sign) {
  159. T[] reachedOneSignal = zeros(x.length);
  160. T[] result = env.newTArray(x.length);
  161. for (int i = 0; i < x.length - 1; ++i) {
  162. reachedOneSignal[i + 1] = or(reachedOneSignal[i], x[i]);
  163. result[i] = xor(x[i], reachedOneSignal[i]);
  164. }
  165. result[x.length - 1] = xor(x[x.length - 1], reachedOneSignal[x.length - 1]);
  166. return mux(x, result, sign);
  167. }
  168. public T[] commonPrefix(T[] x, T[] y) {
  169. assert (x != null && y != null) : "multiply: bad inputs";
  170. T[] result = xor(x, y);
  171. for (int i = x.length - 2; i >= 0; --i) {
  172. result[i] = or(result[i], result[i + 1]);
  173. }
  174. return result;
  175. }
  176. public T[] leadingZeros(T[] x) {
  177. assert (x != null) : "leading zeros: bad inputs";
  178. T[] result = Arrays.copyOf(x, x.length);
  179. for (int i = result.length - 2; i >= 0; --i) {
  180. result[i] = or(result[i], result[i + 1]);
  181. }
  182. return numberOfOnes(not(result));
  183. }
  184. public T[] lengthOfCommenPrefix(T[] x, T[] y) {
  185. assert (x != null) : "lengthOfCommenPrefix : bad inputs";
  186. return leadingZeros(xor(x, y));
  187. }
  188. /*
  189. * Integer manipulation
  190. */
  191. public T[] leftShift(T[] x) {
  192. assert (x != null) : "leftShift: bad inputs";
  193. return leftPublicShift(x, 1);
  194. }
  195. public T[] rightShift(T[] x) {
  196. assert (x != null) : "rightShift: bad inputs";
  197. return rightPublicShift(x, 1);
  198. }
  199. public T[] leftPublicShift(T[] x, int s) {
  200. assert (x != null && s < x.length) : "leftshift: bad inputs";
  201. T res[] = env.newTArray(x.length);
  202. System.arraycopy(zeros(s), 0, res, 0, s);
  203. System.arraycopy(x, 0, res, s, x.length - s);
  204. return res;
  205. }
  206. public T[] rightPublicShift(T[] x, int s) {
  207. assert (x != null && s < x.length) : "rightshift: bad inputs";
  208. T[] res = env.newTArray(x.length);
  209. System.arraycopy(x, s, res, 0, x.length - s);
  210. System.arraycopy(zeros(s), 0, res, x.length - s, s);
  211. return res;
  212. }
  213. public T[] conditionalLeftPublicShift(T[] x, int s, T sign) {
  214. assert (x != null && s < x.length) : "leftshift: bad inputs";
  215. T[] res = env.newTArray(x.length);
  216. System.arraycopy(mux(Arrays.copyOfRange(x, 0, s), zeros(s), sign), 0, res, 0, s);
  217. System.arraycopy(mux(Arrays.copyOfRange(x, s, x.length), Arrays.copyOfRange(x, 0, x.length), sign), 0, res, s,
  218. x.length - s);
  219. return res;
  220. }
  221. public T[] conditionalRightPublicShift(T[] x, int s, T sign) {
  222. assert (x != null && s < x.length) : "rightshift: bad inputs";
  223. T res[] = env.newTArray(x.length);
  224. System.arraycopy(mux(Arrays.copyOfRange(x, 0, x.length - s), Arrays.copyOfRange(x, s, x.length), sign), 0, res,
  225. 0, x.length - s);
  226. System.arraycopy(mux(Arrays.copyOfRange(x, x.length - s, x.length), zeros(s), sign), 0, res, x.length - s, s);
  227. return res;
  228. }
  229. public T[] leftPrivateShift(T[] x, T[] lengthToShift) {
  230. T[] res = Arrays.copyOf(x, x.length);
  231. for (int i = 0; ((1 << i) < x.length) && i < lengthToShift.length; ++i)
  232. res = conditionalLeftPublicShift(res, (1 << i), lengthToShift[i]);
  233. T clear = SIGNAL_ZERO;
  234. for (int i = 0; i < lengthToShift.length; ++i) {
  235. if ((1 << i) >= x.length)
  236. clear = or(clear, lengthToShift[i]);
  237. }
  238. return mux(res, zeros(x.length), clear);
  239. }
  240. public T[] rightPrivateShift(T[] x, T[] lengthToShift) {
  241. T[] res = Arrays.copyOf(x, x.length);
  242. for (int i = 0; ((1 << i) < x.length) && i < lengthToShift.length; ++i)
  243. res = conditionalRightPublicShift(res, (1 << i), lengthToShift[i]);
  244. T clear = SIGNAL_ZERO;
  245. for (int i = 0; i < lengthToShift.length; ++i) {
  246. if ((1 << i) >= x.length)
  247. clear = or(clear, lengthToShift[i]);
  248. }
  249. return mux(res, zeros(x.length), clear);
  250. }
  251. T compare(T x, T y, T cin) {
  252. T t1 = xor(x, cin);
  253. T t2 = xor(y, cin);
  254. t1 = and(t1, t2);
  255. return xor(x, t1);
  256. }
  257. public T compare(T[] x, T[] y) {
  258. assert (x != null && y != null && x.length == y.length) : "compare: bad inputs.";
  259. T t = env.newT(false);
  260. for (int i = 0; i < x.length; i++) {
  261. t = compare(x[i], y[i], t);
  262. }
  263. return t;
  264. }
  265. public T eq(T x, T y) {
  266. assert (x != null && y != null) : "CircuitLib.eq: bad inputs";
  267. return not(xor(x, y));
  268. }
  269. public T eq(T[] x, T[] y) {
  270. assert (x != null && y != null && x.length == y.length) : "CircuitLib.eq[]: bad inputs.";
  271. T res = env.newT(true);
  272. for (int i = 0; i < x.length; i++) {
  273. T t = eq(x[i], y[i]);
  274. res = env.and(res, t);
  275. }
  276. return res;
  277. }
  278. public T[] twosComplement(T[] x) {
  279. T reachOne = SIGNAL_ZERO;
  280. T[] result = env.newTArray(x.length);
  281. for (int i = 0; i < x.length; ++i) {
  282. result[i] = xor(x[i], reachOne);
  283. reachOne = or(reachOne, x[i]);
  284. }
  285. return result;
  286. }
  287. public T[] hammingDistance(T[] x, T[] y) {
  288. T[] a = xor(x, y);
  289. return numberOfOnes(a);
  290. }
  291. public T[] numberOfOnes(T[] t) {
  292. if (t.length == 0) {
  293. T[] res = env.newTArray(1);
  294. res[0] = SIGNAL_ZERO;
  295. return res;
  296. }
  297. if (t.length == 1) {
  298. return t;
  299. } else {
  300. int length = 1;
  301. int w = 1;
  302. while (length <= t.length) {
  303. length <<= 1;
  304. w++;
  305. }
  306. length >>= 1;
  307. T[] res1 = numberOfOnesN(Arrays.copyOfRange(t, 0, length));
  308. T[] res2 = numberOfOnes(Arrays.copyOfRange(t, length, t.length));
  309. return add(padSignal(res1, w), padSignal(res2, w));
  310. }
  311. }
  312. public T[] numberOfOnesN(T[] res) {
  313. if (res.length == 1)
  314. return res;
  315. T[] left = numberOfOnesN(Arrays.copyOfRange(res, 0, res.length / 2));
  316. T[] right = numberOfOnesN(Arrays.copyOfRange(res, res.length / 2, res.length));
  317. return unSignedAdd(left, right);
  318. }
  319. public T[] unSignedAdd(T[] x, T[] y) {
  320. assert (x != null && y != null && x.length == y.length) : "add: bad inputs.";
  321. T[] res = env.newTArray(x.length + 1);
  322. T[] t = add(x[0], y[0], env.newT(false));
  323. res[0] = t[S];
  324. for (int i = 0; i < x.length - 1; i++) {
  325. t = add(x[i + 1], y[i + 1], t[COUT]);
  326. res[i + 1] = t[S];
  327. }
  328. res[res.length - 1] = t[COUT];
  329. return res;
  330. }
  331. public T[] unSignedMultiply(T[] x, T[] y) {
  332. assert (x != null && y != null) : "multiply: bad inputs";
  333. T[] res = zeros(x.length + y.length);
  334. T[] zero = zeros(x.length);
  335. T[] toAdd = mux(zero, x, y[0]);
  336. System.arraycopy(toAdd, 0, res, 0, toAdd.length);
  337. for (int i = 1; i < y.length; ++i) {
  338. toAdd = Arrays.copyOfRange(res, i, i + x.length);
  339. toAdd = unSignedAdd(toAdd, mux(zero, x, y[i]));
  340. System.arraycopy(toAdd, 0, res, i, toAdd.length);
  341. }
  342. return res;
  343. }
  344. public T[] karatsubaMultiply(T[] x, T[] y) {
  345. if (x.length <= 18)
  346. return unSignedMultiply(x, y);
  347. int length = (x.length + y.length);
  348. T[] xlo = Arrays.copyOfRange(x, 0, x.length / 2);
  349. T[] xhi = Arrays.copyOfRange(x, x.length / 2, x.length);
  350. T[] ylo = Arrays.copyOfRange(y, 0, y.length / 2);
  351. T[] yhi = Arrays.copyOfRange(y, y.length / 2, y.length);
  352. int nextlength = Math.max(x.length / 2, x.length - x.length / 2);
  353. xlo = padSignal(xlo, nextlength);
  354. xhi = padSignal(xhi, nextlength);
  355. ylo = padSignal(ylo, nextlength);
  356. yhi = padSignal(yhi, nextlength);
  357. T[] z0 = karatsubaMultiply(xlo, ylo);
  358. T[] z2 = karatsubaMultiply(xhi, yhi);
  359. T[] z1 = sub(padSignal(karatsubaMultiply(unSignedAdd(xlo, xhi), unSignedAdd(ylo, yhi)), 2 * nextlength + 2),
  360. padSignal(unSignedAdd(padSignal(z2, 2 * nextlength), padSignal(z0, 2 * nextlength)),
  361. 2 * nextlength + 2));
  362. z1 = padSignal(z1, length);
  363. z1 = leftPublicShift(z1, x.length / 2);
  364. T[] z0Pad = padSignal(z0, length);
  365. T[] z2Pad = padSignal(z2, length);
  366. z2Pad = leftPublicShift(z2Pad, 2 * (x.length / 2));
  367. return add(add(z0Pad, z1), z2Pad);
  368. }
  369. public T[] min(T[] x, T[] y) {
  370. T leq = leq(x, y);
  371. return mux(y, x, leq);
  372. }
  373. public T[] sqrt(T[] a) {
  374. int newLength = a.length;
  375. if (newLength % 2 == 1)
  376. newLength++;
  377. T[] x = padSignal(a, newLength);
  378. T[] rem = zeros(x.length);
  379. T[] root = zeros(x.length);
  380. for (int i = 0; i < x.length / 2; i++) {
  381. root = leftShift(root);
  382. rem = add(leftPublicShift(rem, 2), rightPublicShift(x, x.length - 2));
  383. x = leftPublicShift(x, 2);
  384. T[] oldRoot = root;
  385. root = copy(root);
  386. root[0] = SIGNAL_ONE;
  387. T[] remMinusRoot = sub(rem, root);
  388. T isRootSmaller = not(remMinusRoot[remMinusRoot.length - 1]);
  389. rem = mux(rem, remMinusRoot, isRootSmaller);
  390. root = mux(oldRoot, incrementByOne(root), isRootSmaller);
  391. }
  392. return padSignal(rightShift(root), a.length);
  393. }
  394. public T[] inputOfAlice(double d) {
  395. return env.inputOfAlice(Utils.fromLong((long) d, width));
  396. }
  397. public T[] inputOfBob(double d) {
  398. return env.inputOfBob(Utils.fromLong((long) d, width));
  399. }
  400. @Override
  401. public CompEnv<T> getEnv() {
  402. return env;
  403. }
  404. @Override
  405. public T[] toSecureInt(T[] a, IntegerLib<T> lib) {
  406. return a;
  407. }
  408. // not fully implemented, more cases to consider
  409. @Override
  410. public T[] toSecureFloat(T[] a, FloatLib<T> lib) {
  411. T[] v = padSignal(a, lib.VLength);
  412. T[] p = leadingZeros(v);
  413. v = leftPrivateShift(v, p);
  414. p = padSignal(p, lib.PLength);
  415. p = sub(zeros(p.length), p);
  416. return lib.pack(new FloatLib.Representation<T>(SIGNAL_ZERO, v, p));
  417. }
  418. @Override
  419. public T[] toSecureFixPoint(T[] a, FixedPointLib<T> lib) {
  420. return leftPublicShift(padSignal(a, lib.width), lib.offset);
  421. }
  422. @Override
  423. public double outputToAlice(T[] a) {
  424. return Utils.toInt(env.outputToAlice(a));
  425. }
  426. @Override
  427. public int numBits() {
  428. return width;
  429. }
  430. }