IntegerLib.java 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. // Copyright (C) 2014 by Xiao Shaun Wang <wangxiao@cs.umd.edu>, Yan Huang <yhuang@cs.umd.edu> and Kartik Nayak <kartik@cs.umd.edu>
  2. package com.oblivm.backend.circuits.arithmetic;
  3. import java.util.Arrays;
  4. import com.oblivm.backend.circuits.CircuitLib;
  5. import com.oblivm.backend.flexsc.CompEnv;
  6. import com.oblivm.backend.util.Utils;
  7. public class IntegerLib<T> extends CircuitLib<T> implements ArithmeticLib<T> {
  8. public int width;
  9. public IntegerLib(CompEnv<T> e) {
  10. super(e);
  11. width = 32;
  12. }
  13. public IntegerLib(CompEnv<T> e, int width) {
  14. super(e);
  15. this.width = width;
  16. }
  17. static final int S = 0;
  18. static final int COUT = 1;
  19. public T[] publicValue(double v) {
  20. int intv = (int) v;
  21. return toSignals(intv, width);
  22. }
  23. // full 1-bit adder
  24. protected T[] add(T x, T y, T cin) {
  25. T[] res = env.newTArray(2);
  26. T t1 = xor(x, cin);
  27. T t2 = xor(y, cin);
  28. res[S] = xor(x, t2);
  29. t1 = and(t1, t2);
  30. res[COUT] = xor(cin, t1);
  31. return res;
  32. }
  33. // full n-bit adder
  34. public T[] addFull(T[] x, T[] y, boolean cin) {
  35. assert (x != null && y != null && x.length == y.length) : "add: bad inputs.";
  36. T[] res = env.newTArray(x.length + 1);
  37. T[] t = add(x[0], y[0], env.newT(cin));
  38. res[0] = t[S];
  39. for (int i = 0; i < x.length - 1; i++) {
  40. t = add(x[i + 1], y[i + 1], t[COUT]);
  41. res[i + 1] = t[S];
  42. }
  43. res[res.length - 1] = t[COUT];
  44. return res;
  45. }
  46. public T[] add(T[] x, T[] y, boolean cin) {
  47. return Arrays.copyOf(addFull(x, y, cin), x.length);
  48. }
  49. public T[] add(T[] x, T[] y) {
  50. return add(x, y, false);
  51. }
  52. public T[] sub(T x, T y) throws Exception {
  53. T[] ax = env.newTArray(2);
  54. ax[1] = SIGNAL_ZERO;
  55. ax[0] = x;
  56. T[] ay = env.newTArray(2);
  57. ay[1] = SIGNAL_ZERO;
  58. ay[0] = y;
  59. return sub(x, y);
  60. }
  61. public T[] sub(T[] x, T[] y) {
  62. assert (x != null && y != null && x.length == y.length) : "sub: bad inputs.";
  63. return add(x, not(y), true);
  64. }
  65. public T[] incrementByOne(T[] x) {
  66. T[] one = zeros(x.length);
  67. one[0] = SIGNAL_ONE;
  68. return add(x, one);
  69. }
  70. public T[] decrementByOne(T[] x) {
  71. T[] one = zeros(x.length);
  72. one[0] = SIGNAL_ONE;
  73. return sub(x, one);
  74. }
  75. public T[] conditionalIncreament(T[] x, T flag) {
  76. T[] one = zeros(x.length);
  77. one[0] = mux(SIGNAL_ZERO, SIGNAL_ONE, flag);
  78. return add(x, one);
  79. }
  80. public T[] conditionalDecrement(T[] x, T flag) {
  81. T[] one = zeros(x.length);
  82. one[0] = mux(SIGNAL_ZERO, SIGNAL_ONE, flag);
  83. return sub(x, one);
  84. }
  85. public T geq(T[] x, T[] y) {
  86. assert (x.length == y.length) : "bad input";
  87. T[] result = sub(x, y);
  88. return not(result[result.length - 1]);
  89. }
  90. public T leq(T[] x, T[] y) {
  91. return geq(y, x);
  92. }
  93. public T[] multiply(T[] x, T[] y) {
  94. return Arrays.copyOf(multiplyInternal(x, y), x.length);// res;
  95. }
  96. // This multiplication does not truncate the length of x and y
  97. public T[] multiplyFull(T[] x, T[] y) {
  98. return multiplyInternal(x, y);
  99. }
  100. private T[] multiplyInternal(T[] x, T[] y) {
  101. // return karatsubaMultiply(x,y);
  102. assert (x != null && y != null) : "multiply: bad inputs";
  103. T[] res = zeros(x.length + y.length);
  104. T[] zero = zeros(x.length);
  105. T[] toAdd = mux(zero, x, y[0]);
  106. System.arraycopy(toAdd, 0, res, 0, toAdd.length);
  107. for (int i = 1; i < y.length; ++i) {
  108. toAdd = Arrays.copyOfRange(res, i, i + x.length);
  109. toAdd = add(toAdd, mux(zero, x, y[i]), false);
  110. System.arraycopy(toAdd, 0, res, i, toAdd.length);
  111. }
  112. return res;
  113. }
  114. public T[] absolute(T[] x) {
  115. T reachedOneSignal = SIGNAL_ZERO;
  116. T[] result = zeros(x.length);
  117. for (int i = 0; i < x.length; ++i) {
  118. T comp = eq(SIGNAL_ONE, x[i]);
  119. result[i] = xor(x[i], reachedOneSignal);
  120. reachedOneSignal = or(reachedOneSignal, comp);
  121. }
  122. return mux(x, result, x[x.length - 1]);
  123. }
  124. public T[] div(T[] x, T[] y) {
  125. T[] absoluteX = absolute(x);
  126. T[] absoluteY = absolute(y);
  127. T[] PA = divInternal(absoluteX, absoluteY);
  128. return addSign(Arrays.copyOf(PA, x.length), xor(x[x.length - 1], y[y.length - 1]));
  129. }
  130. // Restoring Division Algorithm
  131. public T[] divInternal(T[] x, T[] y) {
  132. T[] PA = zeros(x.length + y.length);
  133. T[] B = y;
  134. System.arraycopy(x, 0, PA, 0, x.length);
  135. for (int i = 0; i < x.length; ++i) {
  136. PA = leftShift(PA);
  137. T[] tempP = sub(Arrays.copyOfRange(PA, x.length, PA.length), B);
  138. PA[0] = not(tempP[tempP.length - 1]);
  139. System.arraycopy(mux(tempP, Arrays.copyOfRange(PA, x.length, PA.length), tempP[tempP.length - 1]), 0, PA,
  140. x.length, y.length);
  141. }
  142. return PA;
  143. }
  144. public T[] mod(T[] x, T[] y) {
  145. T Xneg = x[x.length - 1];
  146. T[] absoluteX = absolute(x);
  147. T[] absoluteY = absolute(y);
  148. T[] PA = divInternal(absoluteX, absoluteY);
  149. T[] res = Arrays.copyOfRange(PA, y.length, PA.length);
  150. return mux(res, sub(toSignals(0, res.length), res), Xneg);
  151. }
  152. public T[] addSign(T[] x, T sign) {
  153. T[] reachedOneSignal = zeros(x.length);
  154. T[] result = env.newTArray(x.length);
  155. for (int i = 0; i < x.length - 1; ++i) {
  156. reachedOneSignal[i + 1] = or(reachedOneSignal[i], x[i]);
  157. result[i] = xor(x[i], reachedOneSignal[i]);
  158. }
  159. result[x.length - 1] = xor(x[x.length - 1], reachedOneSignal[x.length - 1]);
  160. return mux(x, result, sign);
  161. }
  162. public T[] commonPrefix(T[] x, T[] y) {
  163. assert (x != null && y != null) : "multiply: bad inputs";
  164. T[] result = xor(x, y);
  165. for (int i = x.length - 2; i >= 0; --i) {
  166. result[i] = or(result[i], result[i + 1]);
  167. }
  168. return result;
  169. }
  170. public T[] leadingZeros(T[] x) {
  171. assert (x != null) : "leading zeros: bad inputs";
  172. T[] result = Arrays.copyOf(x, x.length);
  173. for (int i = result.length - 2; i >= 0; --i) {
  174. result[i] = or(result[i], result[i + 1]);
  175. }
  176. return numberOfOnes(not(result));
  177. }
  178. public T[] lengthOfCommenPrefix(T[] x, T[] y) {
  179. assert (x != null) : "lengthOfCommenPrefix : bad inputs";
  180. return leadingZeros(xor(x, y));
  181. }
  182. /*
  183. * Integer manipulation
  184. */
  185. public T[] leftShift(T[] x) {
  186. assert (x != null) : "leftShift: bad inputs";
  187. return leftPublicShift(x, 1);
  188. }
  189. public T[] rightShift(T[] x) {
  190. assert (x != null) : "rightShift: bad inputs";
  191. return rightPublicShift(x, 1);
  192. }
  193. public T[] leftPublicShift(T[] x, int s) {
  194. assert (x != null && s < x.length) : "leftshift: bad inputs";
  195. T res[] = env.newTArray(x.length);
  196. System.arraycopy(zeros(s), 0, res, 0, s);
  197. System.arraycopy(x, 0, res, s, x.length - s);
  198. return res;
  199. }
  200. public T[] rightPublicShift(T[] x, int s) {
  201. assert (x != null && s < x.length) : "rightshift: bad inputs";
  202. T[] res = env.newTArray(x.length);
  203. System.arraycopy(x, s, res, 0, x.length - s);
  204. System.arraycopy(zeros(s), 0, res, x.length - s, s);
  205. return res;
  206. }
  207. public T[] conditionalLeftPublicShift(T[] x, int s, T sign) {
  208. assert (x != null && s < x.length) : "leftshift: bad inputs";
  209. T[] res = env.newTArray(x.length);
  210. System.arraycopy(mux(Arrays.copyOfRange(x, 0, s), zeros(s), sign), 0, res, 0, s);
  211. System.arraycopy(mux(Arrays.copyOfRange(x, s, x.length), Arrays.copyOfRange(x, 0, x.length), sign), 0, res, s,
  212. x.length - s);
  213. return res;
  214. }
  215. public T[] conditionalRightPublicShift(T[] x, int s, T sign) {
  216. assert (x != null && s < x.length) : "rightshift: bad inputs";
  217. T res[] = env.newTArray(x.length);
  218. System.arraycopy(mux(Arrays.copyOfRange(x, 0, x.length - s), Arrays.copyOfRange(x, s, x.length), sign), 0, res,
  219. 0, x.length - s);
  220. System.arraycopy(mux(Arrays.copyOfRange(x, x.length - s, x.length), zeros(s), sign), 0, res, x.length - s, s);
  221. return res;
  222. }
  223. public T[] leftPrivateShift(T[] x, T[] lengthToShift) {
  224. T[] res = Arrays.copyOf(x, x.length);
  225. for (int i = 0; ((1 << i) < x.length) && i < lengthToShift.length; ++i)
  226. res = conditionalLeftPublicShift(res, (1 << i), lengthToShift[i]);
  227. T clear = SIGNAL_ZERO;
  228. for (int i = 0; i < lengthToShift.length; ++i) {
  229. if ((1 << i) >= x.length)
  230. clear = or(clear, lengthToShift[i]);
  231. }
  232. return mux(res, zeros(x.length), clear);
  233. }
  234. public T[] rightPrivateShift(T[] x, T[] lengthToShift) {
  235. T[] res = Arrays.copyOf(x, x.length);
  236. for (int i = 0; ((1 << i) < x.length) && i < lengthToShift.length; ++i)
  237. res = conditionalRightPublicShift(res, (1 << i), lengthToShift[i]);
  238. T clear = SIGNAL_ZERO;
  239. for (int i = 0; i < lengthToShift.length; ++i) {
  240. if ((1 << i) >= x.length)
  241. clear = or(clear, lengthToShift[i]);
  242. }
  243. return mux(res, zeros(x.length), clear);
  244. }
  245. T compare(T x, T y, T cin) {
  246. T t1 = xor(x, cin);
  247. T t2 = xor(y, cin);
  248. t1 = and(t1, t2);
  249. return xor(x, t1);
  250. }
  251. public T compare(T[] x, T[] y) {
  252. assert (x != null && y != null && x.length == y.length) : "compare: bad inputs.";
  253. T t = env.newT(false);
  254. for (int i = 0; i < x.length; i++) {
  255. t = compare(x[i], y[i], t);
  256. }
  257. return t;
  258. }
  259. public T eq(T x, T y) {
  260. assert (x != null && y != null) : "CircuitLib.eq: bad inputs";
  261. return not(xor(x, y));
  262. }
  263. public T eq(T[] x, T[] y) {
  264. assert (x != null && y != null && x.length == y.length) : "CircuitLib.eq[]: bad inputs.";
  265. T res = env.newT(true);
  266. for (int i = 0; i < x.length; i++) {
  267. T t = eq(x[i], y[i]);
  268. res = env.and(res, t);
  269. }
  270. return res;
  271. }
  272. public T[] twosComplement(T[] x) {
  273. T reachOne = SIGNAL_ZERO;
  274. T[] result = env.newTArray(x.length);
  275. for (int i = 0; i < x.length; ++i) {
  276. result[i] = xor(x[i], reachOne);
  277. reachOne = or(reachOne, x[i]);
  278. }
  279. return result;
  280. }
  281. public T[] hammingDistance(T[] x, T[] y) {
  282. T[] a = xor(x, y);
  283. return numberOfOnes(a);
  284. }
  285. public T[] numberOfOnes(T[] t) {
  286. if (t.length == 0) {
  287. T[] res = env.newTArray(1);
  288. res[0] = SIGNAL_ZERO;
  289. return res;
  290. }
  291. if (t.length == 1) {
  292. return t;
  293. } else {
  294. int length = 1;
  295. int w = 1;
  296. while (length <= t.length) {
  297. length <<= 1;
  298. w++;
  299. }
  300. length >>= 1;
  301. T[] res1 = numberOfOnesN(Arrays.copyOfRange(t, 0, length));
  302. T[] res2 = numberOfOnes(Arrays.copyOfRange(t, length, t.length));
  303. return add(padSignal(res1, w), padSignal(res2, w));
  304. }
  305. }
  306. public T[] numberOfOnesN(T[] res) {
  307. if (res.length == 1)
  308. return res;
  309. T[] left = numberOfOnesN(Arrays.copyOfRange(res, 0, res.length / 2));
  310. T[] right = numberOfOnesN(Arrays.copyOfRange(res, res.length / 2, res.length));
  311. return unSignedAdd(left, right);
  312. }
  313. public T[] unSignedAdd(T[] x, T[] y) {
  314. assert (x != null && y != null && x.length == y.length) : "add: bad inputs.";
  315. T[] res = env.newTArray(x.length + 1);
  316. T[] t = add(x[0], y[0], env.newT(false));
  317. res[0] = t[S];
  318. for (int i = 0; i < x.length - 1; i++) {
  319. t = add(x[i + 1], y[i + 1], t[COUT]);
  320. res[i + 1] = t[S];
  321. }
  322. res[res.length - 1] = t[COUT];
  323. return res;
  324. }
  325. public T[] unSignedMultiply(T[] x, T[] y) {
  326. assert (x != null && y != null) : "multiply: bad inputs";
  327. T[] res = zeros(x.length + y.length);
  328. T[] zero = zeros(x.length);
  329. T[] toAdd = mux(zero, x, y[0]);
  330. System.arraycopy(toAdd, 0, res, 0, toAdd.length);
  331. for (int i = 1; i < y.length; ++i) {
  332. toAdd = Arrays.copyOfRange(res, i, i + x.length);
  333. toAdd = unSignedAdd(toAdd, mux(zero, x, y[i]));
  334. System.arraycopy(toAdd, 0, res, i, toAdd.length);
  335. }
  336. return res;
  337. }
  338. public T[] karatsubaMultiply(T[] x, T[] y) {
  339. if (x.length <= 18)
  340. return unSignedMultiply(x, y);
  341. int length = (x.length + y.length);
  342. T[] xlo = Arrays.copyOfRange(x, 0, x.length / 2);
  343. T[] xhi = Arrays.copyOfRange(x, x.length / 2, x.length);
  344. T[] ylo = Arrays.copyOfRange(y, 0, y.length / 2);
  345. T[] yhi = Arrays.copyOfRange(y, y.length / 2, y.length);
  346. int nextlength = Math.max(x.length / 2, x.length - x.length / 2);
  347. xlo = padSignal(xlo, nextlength);
  348. xhi = padSignal(xhi, nextlength);
  349. ylo = padSignal(ylo, nextlength);
  350. yhi = padSignal(yhi, nextlength);
  351. T[] z0 = karatsubaMultiply(xlo, ylo);
  352. T[] z2 = karatsubaMultiply(xhi, yhi);
  353. T[] z1 = sub(padSignal(karatsubaMultiply(unSignedAdd(xlo, xhi), unSignedAdd(ylo, yhi)), 2 * nextlength + 2),
  354. padSignal(unSignedAdd(padSignal(z2, 2 * nextlength), padSignal(z0, 2 * nextlength)),
  355. 2 * nextlength + 2));
  356. z1 = padSignal(z1, length);
  357. z1 = leftPublicShift(z1, x.length / 2);
  358. T[] z0Pad = padSignal(z0, length);
  359. T[] z2Pad = padSignal(z2, length);
  360. z2Pad = leftPublicShift(z2Pad, 2 * (x.length / 2));
  361. return add(add(z0Pad, z1), z2Pad);
  362. }
  363. public T[] min(T[] x, T[] y) {
  364. T leq = leq(x, y);
  365. return mux(y, x, leq);
  366. }
  367. public T[] sqrt(T[] a) {
  368. int newLength = a.length;
  369. if (newLength % 2 == 1)
  370. newLength++;
  371. T[] x = padSignal(a, newLength);
  372. T[] rem = zeros(x.length);
  373. T[] root = zeros(x.length);
  374. for (int i = 0; i < x.length / 2; i++) {
  375. root = leftShift(root);
  376. rem = add(leftPublicShift(rem, 2), rightPublicShift(x, x.length - 2));
  377. x = leftPublicShift(x, 2);
  378. T[] oldRoot = root;
  379. root = copy(root);
  380. root[0] = SIGNAL_ONE;
  381. T[] remMinusRoot = sub(rem, root);
  382. T isRootSmaller = not(remMinusRoot[remMinusRoot.length - 1]);
  383. rem = mux(rem, remMinusRoot, isRootSmaller);
  384. root = mux(oldRoot, incrementByOne(root), isRootSmaller);
  385. }
  386. return padSignal(rightShift(root), a.length);
  387. }
  388. public T[] inputOfAlice(double d) {
  389. return env.inputOfAlice(Utils.fromLong((long) d, width));
  390. }
  391. public T[] inputOfBob(double d) {
  392. return env.inputOfBob(Utils.fromLong((long) d, width));
  393. }
  394. @Override
  395. public CompEnv<T> getEnv() {
  396. return env;
  397. }
  398. @Override
  399. public T[] toSecureInt(T[] a, IntegerLib<T> lib) {
  400. return a;
  401. }
  402. // not fully implemented, more cases to consider
  403. @Override
  404. public T[] toSecureFloat(T[] a, FloatLib<T> lib) {
  405. T[] v = padSignal(a, lib.VLength);
  406. T[] p = leadingZeros(v);
  407. v = leftPrivateShift(v, p);
  408. p = padSignal(p, lib.PLength);
  409. p = sub(zeros(p.length), p);
  410. return lib.pack(new FloatLib.Representation<T>(SIGNAL_ZERO, v, p));
  411. }
  412. @Override
  413. public T[] toSecureFixPoint(T[] a, FixedPointLib<T> lib) {
  414. return leftPublicShift(padSignal(a, lib.width), lib.offset);
  415. }
  416. @Override
  417. public double outputToAlice(T[] a) {
  418. return Utils.toInt(env.outputToAlice(a));
  419. }
  420. @Override
  421. public int numBits() {
  422. return width;
  423. }
  424. }