123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542 |
- // Copyright (C) 2014 by Xiao Shaun Wang <wangxiao@cs.umd.edu>, Yan Huang <yhuang@cs.umd.edu> and Kartik Nayak <kartik@cs.umd.edu>
- package com.oblivm.backend.circuits.arithmetic;
- import java.util.Arrays;
- import com.oblivm.backend.circuits.CircuitLib;
- import com.oblivm.backend.flexsc.CompEnv;
- import com.oblivm.backend.util.Utils;
- public class IntegerLib<T> extends CircuitLib<T> implements ArithmeticLib<T> {
- public int width;
- public IntegerLib(CompEnv<T> e) {
- super(e);
- width = 32;
- }
- public IntegerLib(CompEnv<T> e, int width) {
- super(e);
- this.width = width;
- }
- static final int S = 0;
- static final int COUT = 1;
- public T[] publicValue(double v) {
- int intv = (int) v;
- return toSignals(intv, width);
- }
- // full 1-bit adder
- protected T[] add(T x, T y, T cin) {
- T[] res = env.newTArray(2);
- T t1 = xor(x, cin);
- T t2 = xor(y, cin);
- res[S] = xor(x, t2);
- t1 = and(t1, t2);
- res[COUT] = xor(cin, t1);
- return res;
- }
- // full n-bit adder
- public T[] addFull(T[] x, T[] y, boolean cin) {
- assert (x != null && y != null && x.length == y.length) : "add: bad inputs.";
- T[] res = env.newTArray(x.length + 1);
- T[] t = add(x[0], y[0], env.newT(cin));
- res[0] = t[S];
- for (int i = 0; i < x.length - 1; i++) {
- t = add(x[i + 1], y[i + 1], t[COUT]);
- res[i + 1] = t[S];
- }
- res[res.length - 1] = t[COUT];
- return res;
- }
- public T[] addFull(T[] x, T[] y) {
- return addFull(x, y, false);
- }
- public T[] add(T[] x, T[] y, boolean cin) {
- return Arrays.copyOf(addFull(x, y, cin), x.length);
- }
- public T[] add(T[] x, T[] y) {
- return add(x, y, false);
- }
- public T[] sub(T x, T y) throws Exception {
- T[] ax = env.newTArray(2);
- ax[1] = SIGNAL_ZERO;
- ax[0] = x;
- T[] ay = env.newTArray(2);
- ay[1] = SIGNAL_ZERO;
- ay[0] = y;
- return sub(x, y);
- }
- public T[] sub(T[] x, T[] y) {
- assert (x != null && y != null && x.length == y.length) : "sub: bad inputs.";
- return add(x, not(y), true);
- }
- public T[] incrementByOne(T[] x) {
- T[] one = zeros(x.length);
- one[0] = SIGNAL_ONE;
- return add(x, one);
- }
- public T[] decrementByOne(T[] x) {
- T[] one = zeros(x.length);
- one[0] = SIGNAL_ONE;
- return sub(x, one);
- }
- public T[] conditionalIncreament(T[] x, T flag) {
- T[] one = zeros(x.length);
- one[0] = mux(SIGNAL_ZERO, SIGNAL_ONE, flag);
- return add(x, one);
- }
- public T[] conditionalDecrement(T[] x, T flag) {
- T[] one = zeros(x.length);
- one[0] = mux(SIGNAL_ZERO, SIGNAL_ONE, flag);
- return sub(x, one);
- }
- public T less(T[] x, T[] y) {
- assert (x.length == y.length) : "bad input";
- T[] result = sub(x, y);
- return result[result.length - 1];
- }
- public T greater(T[] x, T[] y) {
- return less(y, x);
- }
- public T geq(T[] x, T[] y) {
- assert (x.length == y.length) : "bad input";
- T[] result = sub(x, y);
- return not(result[result.length - 1]);
- }
- public T leq(T[] x, T[] y) {
- return geq(y, x);
- }
- public T[] multiply(T[] x, T[] y) {
- return Arrays.copyOf(multiplyInternal(x, y), x.length);// res;
- }
- // This multiplication does not truncate the length of x and y
- public T[] multiplyFull(T[] x, T[] y) {
- return multiplyInternal(x, y);
- }
- private T[] multiplyInternal(T[] x, T[] y) {
- // return karatsubaMultiply(x,y);
- assert (x != null && y != null) : "multiply: bad inputs";
- T[] res = zeros(x.length + y.length);
- T[] zero = zeros(x.length);
- T[] toAdd = mux(zero, x, y[0]);
- System.arraycopy(toAdd, 0, res, 0, toAdd.length);
- for (int i = 1; i < y.length; ++i) {
- toAdd = Arrays.copyOfRange(res, i, i + x.length);
- toAdd = add(toAdd, mux(zero, x, y[i]), false);
- System.arraycopy(toAdd, 0, res, i, toAdd.length);
- }
- return res;
- }
- public T[] absolute(T[] x) {
- T reachedOneSignal = SIGNAL_ZERO;
- T[] result = zeros(x.length);
- for (int i = 0; i < x.length; ++i) {
- T comp = eq(SIGNAL_ONE, x[i]);
- result[i] = xor(x[i], reachedOneSignal);
- reachedOneSignal = or(reachedOneSignal, comp);
- }
- return mux(x, result, x[x.length - 1]);
- }
- public T[] div(T[] x, T[] y) {
- T[] absoluteX = absolute(x);
- T[] absoluteY = absolute(y);
- T[] PA = divInternal(absoluteX, absoluteY);
- return addSign(Arrays.copyOf(PA, x.length), xor(x[x.length - 1], y[y.length - 1]));
- }
- // Restoring Division Algorithm
- public T[] divInternal(T[] x, T[] y) {
- T[] PA = zeros(x.length + y.length);
- T[] B = y;
- System.arraycopy(x, 0, PA, 0, x.length);
- for (int i = 0; i < x.length; ++i) {
- PA = leftShift(PA);
- T[] tempP = sub(Arrays.copyOfRange(PA, x.length, PA.length), B);
- PA[0] = not(tempP[tempP.length - 1]);
- System.arraycopy(mux(tempP, Arrays.copyOfRange(PA, x.length, PA.length), tempP[tempP.length - 1]), 0, PA,
- x.length, y.length);
- }
- return PA;
- }
- public T[] mod(T[] x, T[] y) {
- T Xneg = x[x.length - 1];
- T[] absoluteX = absolute(x);
- T[] absoluteY = absolute(y);
- T[] PA = divInternal(absoluteX, absoluteY);
- T[] res = Arrays.copyOfRange(PA, y.length, PA.length);
- return mux(res, sub(toSignals(0, res.length), res), Xneg);
- }
- public T[] addSign(T[] x, T sign) {
- T[] reachedOneSignal = zeros(x.length);
- T[] result = env.newTArray(x.length);
- for (int i = 0; i < x.length - 1; ++i) {
- reachedOneSignal[i + 1] = or(reachedOneSignal[i], x[i]);
- result[i] = xor(x[i], reachedOneSignal[i]);
- }
- result[x.length - 1] = xor(x[x.length - 1], reachedOneSignal[x.length - 1]);
- return mux(x, result, sign);
- }
- public T[] commonPrefix(T[] x, T[] y) {
- assert (x != null && y != null) : "multiply: bad inputs";
- T[] result = xor(x, y);
- for (int i = x.length - 2; i >= 0; --i) {
- result[i] = or(result[i], result[i + 1]);
- }
- return result;
- }
- public T[] leadingZeros(T[] x) {
- assert (x != null) : "leading zeros: bad inputs";
- T[] result = Arrays.copyOf(x, x.length);
- for (int i = result.length - 2; i >= 0; --i) {
- result[i] = or(result[i], result[i + 1]);
- }
- return numberOfOnes(not(result));
- }
- public T[] lengthOfCommenPrefix(T[] x, T[] y) {
- assert (x != null) : "lengthOfCommenPrefix : bad inputs";
- return leadingZeros(xor(x, y));
- }
- /*
- * Integer manipulation
- */
- public T[] leftShift(T[] x) {
- assert (x != null) : "leftShift: bad inputs";
- return leftPublicShift(x, 1);
- }
- public T[] rightShift(T[] x) {
- assert (x != null) : "rightShift: bad inputs";
- return rightPublicShift(x, 1);
- }
- public T[] leftPublicShift(T[] x, int s) {
- assert (x != null && s < x.length) : "leftshift: bad inputs";
- T res[] = env.newTArray(x.length);
- System.arraycopy(zeros(s), 0, res, 0, s);
- System.arraycopy(x, 0, res, s, x.length - s);
- return res;
- }
- public T[] rightPublicShift(T[] x, int s) {
- assert (x != null && s < x.length) : "rightshift: bad inputs";
- T[] res = env.newTArray(x.length);
- System.arraycopy(x, s, res, 0, x.length - s);
- System.arraycopy(zeros(s), 0, res, x.length - s, s);
- return res;
- }
- public T[] conditionalLeftPublicShift(T[] x, int s, T sign) {
- assert (x != null && s < x.length) : "leftshift: bad inputs";
- T[] res = env.newTArray(x.length);
- System.arraycopy(mux(Arrays.copyOfRange(x, 0, s), zeros(s), sign), 0, res, 0, s);
- System.arraycopy(mux(Arrays.copyOfRange(x, s, x.length), Arrays.copyOfRange(x, 0, x.length), sign), 0, res, s,
- x.length - s);
- return res;
- }
- public T[] conditionalRightPublicShift(T[] x, int s, T sign) {
- assert (x != null && s < x.length) : "rightshift: bad inputs";
- T res[] = env.newTArray(x.length);
- System.arraycopy(mux(Arrays.copyOfRange(x, 0, x.length - s), Arrays.copyOfRange(x, s, x.length), sign), 0, res,
- 0, x.length - s);
- System.arraycopy(mux(Arrays.copyOfRange(x, x.length - s, x.length), zeros(s), sign), 0, res, x.length - s, s);
- return res;
- }
- public T[] leftPrivateShift(T[] x, T[] lengthToShift) {
- T[] res = Arrays.copyOf(x, x.length);
- for (int i = 0; ((1 << i) < x.length) && i < lengthToShift.length; ++i)
- res = conditionalLeftPublicShift(res, (1 << i), lengthToShift[i]);
- T clear = SIGNAL_ZERO;
- for (int i = 0; i < lengthToShift.length; ++i) {
- if ((1 << i) >= x.length)
- clear = or(clear, lengthToShift[i]);
- }
- return mux(res, zeros(x.length), clear);
- }
- public T[] rightPrivateShift(T[] x, T[] lengthToShift) {
- T[] res = Arrays.copyOf(x, x.length);
- for (int i = 0; ((1 << i) < x.length) && i < lengthToShift.length; ++i)
- res = conditionalRightPublicShift(res, (1 << i), lengthToShift[i]);
- T clear = SIGNAL_ZERO;
- for (int i = 0; i < lengthToShift.length; ++i) {
- if ((1 << i) >= x.length)
- clear = or(clear, lengthToShift[i]);
- }
- return mux(res, zeros(x.length), clear);
- }
- T compare(T x, T y, T cin) {
- T t1 = xor(x, cin);
- T t2 = xor(y, cin);
- t1 = and(t1, t2);
- return xor(x, t1);
- }
- public T compare(T[] x, T[] y) {
- assert (x != null && y != null && x.length == y.length) : "compare: bad inputs.";
- T t = env.newT(false);
- for (int i = 0; i < x.length; i++) {
- t = compare(x[i], y[i], t);
- }
- return t;
- }
- public T eq(T x, T y) {
- assert (x != null && y != null) : "CircuitLib.eq: bad inputs";
- return not(xor(x, y));
- }
- public T eq(T[] x, T[] y) {
- assert (x != null && y != null && x.length == y.length) : "CircuitLib.eq[]: bad inputs.";
- T res = env.newT(true);
- for (int i = 0; i < x.length; i++) {
- T t = eq(x[i], y[i]);
- res = env.and(res, t);
- }
- return res;
- }
- public T[] twosComplement(T[] x) {
- T reachOne = SIGNAL_ZERO;
- T[] result = env.newTArray(x.length);
- for (int i = 0; i < x.length; ++i) {
- result[i] = xor(x[i], reachOne);
- reachOne = or(reachOne, x[i]);
- }
- return result;
- }
- public T[] hammingDistance(T[] x, T[] y) {
- T[] a = xor(x, y);
- return numberOfOnes(a);
- }
- public T[] numberOfOnes(T[] t) {
- if (t.length == 0) {
- // T[] res = env.newTArray(1);
- // res[0] = SIGNAL_ZERO;
- // return res;
- return zeros(2);
- }
- if (t.length == 1) {
- // return t;
- return padSignal(t, 2);
- } else {
- int length = 1;
- int w = 1;
- while (length <= t.length) {
- length <<= 1;
- w++;
- }
- length >>= 1;
- T[] res1 = numberOfOnesN(Arrays.copyOfRange(t, 0, length));
- T[] res2 = numberOfOnes(Arrays.copyOfRange(t, length, t.length));
- return add(padSignal(res1, w), padSignal(res2, w));
- }
- }
- public T[] numberOfOnesN(T[] res) {
- if (res.length == 1)
- return res;
- T[] left = numberOfOnesN(Arrays.copyOfRange(res, 0, res.length / 2));
- T[] right = numberOfOnesN(Arrays.copyOfRange(res, res.length / 2, res.length));
- return unSignedAdd(left, right);
- }
- public T[] unSignedAdd(T[] x, T[] y) {
- assert (x != null && y != null && x.length == y.length) : "add: bad inputs.";
- T[] res = env.newTArray(x.length + 1);
- T[] t = add(x[0], y[0], env.newT(false));
- res[0] = t[S];
- for (int i = 0; i < x.length - 1; i++) {
- t = add(x[i + 1], y[i + 1], t[COUT]);
- res[i + 1] = t[S];
- }
- res[res.length - 1] = t[COUT];
- return res;
- }
- public T[] unSignedMultiply(T[] x, T[] y) {
- assert (x != null && y != null) : "multiply: bad inputs";
- T[] res = zeros(x.length + y.length);
- T[] zero = zeros(x.length);
- T[] toAdd = mux(zero, x, y[0]);
- System.arraycopy(toAdd, 0, res, 0, toAdd.length);
- for (int i = 1; i < y.length; ++i) {
- toAdd = Arrays.copyOfRange(res, i, i + x.length);
- toAdd = unSignedAdd(toAdd, mux(zero, x, y[i]));
- System.arraycopy(toAdd, 0, res, i, toAdd.length);
- }
- return res;
- }
- public T[] karatsubaMultiply(T[] x, T[] y) {
- if (x.length <= 18)
- return unSignedMultiply(x, y);
- int length = (x.length + y.length);
- T[] xlo = Arrays.copyOfRange(x, 0, x.length / 2);
- T[] xhi = Arrays.copyOfRange(x, x.length / 2, x.length);
- T[] ylo = Arrays.copyOfRange(y, 0, y.length / 2);
- T[] yhi = Arrays.copyOfRange(y, y.length / 2, y.length);
- int nextlength = Math.max(x.length / 2, x.length - x.length / 2);
- xlo = padSignal(xlo, nextlength);
- xhi = padSignal(xhi, nextlength);
- ylo = padSignal(ylo, nextlength);
- yhi = padSignal(yhi, nextlength);
- T[] z0 = karatsubaMultiply(xlo, ylo);
- T[] z2 = karatsubaMultiply(xhi, yhi);
- T[] z1 = sub(padSignal(karatsubaMultiply(unSignedAdd(xlo, xhi), unSignedAdd(ylo, yhi)), 2 * nextlength + 2),
- padSignal(unSignedAdd(padSignal(z2, 2 * nextlength), padSignal(z0, 2 * nextlength)),
- 2 * nextlength + 2));
- z1 = padSignal(z1, length);
- z1 = leftPublicShift(z1, x.length / 2);
- T[] z0Pad = padSignal(z0, length);
- T[] z2Pad = padSignal(z2, length);
- z2Pad = leftPublicShift(z2Pad, 2 * (x.length / 2));
- return add(add(z0Pad, z1), z2Pad);
- }
- public T[] min(T[] x, T[] y) {
- T leq = leq(x, y);
- return mux(y, x, leq);
- }
- public T[] sqrt(T[] a) {
- int newLength = a.length;
- if (newLength % 2 == 1)
- newLength++;
- T[] x = padSignal(a, newLength);
- T[] rem = zeros(x.length);
- T[] root = zeros(x.length);
- for (int i = 0; i < x.length / 2; i++) {
- root = leftShift(root);
- rem = add(leftPublicShift(rem, 2), rightPublicShift(x, x.length - 2));
- x = leftPublicShift(x, 2);
- T[] oldRoot = root;
- root = copy(root);
- root[0] = SIGNAL_ONE;
- T[] remMinusRoot = sub(rem, root);
- T isRootSmaller = not(remMinusRoot[remMinusRoot.length - 1]);
- rem = mux(rem, remMinusRoot, isRootSmaller);
- root = mux(oldRoot, incrementByOne(root), isRootSmaller);
- }
- return padSignal(rightShift(root), a.length);
- }
- public T[] inputOfAlice(double d) {
- return env.inputOfAlice(Utils.fromLong((long) d, width));
- }
- public T[] inputOfBob(double d) {
- return env.inputOfBob(Utils.fromLong((long) d, width));
- }
- @Override
- public CompEnv<T> getEnv() {
- return env;
- }
- @Override
- public T[] toSecureInt(T[] a, IntegerLib<T> lib) {
- return a;
- }
- // not fully implemented, more cases to consider
- @Override
- public T[] toSecureFloat(T[] a, FloatLib<T> lib) {
- T[] v = padSignal(a, lib.VLength);
- T[] p = leadingZeros(v);
- v = leftPrivateShift(v, p);
- p = padSignal(p, lib.PLength);
- p = sub(zeros(p.length), p);
- return lib.pack(new FloatLib.Representation<T>(SIGNAL_ZERO, v, p));
- }
- @Override
- public T[] toSecureFixPoint(T[] a, FixedPointLib<T> lib) {
- return leftPublicShift(padSignal(a, lib.width), lib.offset);
- }
- @Override
- public double outputToAlice(T[] a) {
- return Utils.toInt(env.outputToAlice(a));
- }
- @Override
- public int numBits() {
- return width;
- }
- }
|