package com.oblivm.backend.circuits.arithmetic; import java.util.Arrays; import com.oblivm.backend.flexsc.CompEnv; import com.oblivm.backend.util.Utils; public class FloatLib<T> implements ArithmeticLib<T> { CompEnv<T> env; IntegerLib<T> lib; public int VLength; public int PLength; public FloatLib(CompEnv<T> e, int VLength, int PLength) { this.env = e; lib = new IntegerLib<>(e); this.VLength = VLength; this.PLength = PLength; } public T[] inputOfAlice(double d) { return env.inputOfAlice(Utils.fromFloat(d, VLength, PLength)); } public T[] inputOfBob(double d) { return env.inputOfBob(Utils.fromFloat(d, VLength, PLength)); } public T[] pack(Representation<T> f) { assert (f.v.length == VLength && f.p.length == PLength) : "pack: not compatiable"; T[] res = env.newTArray(1 + f.v.length + f.p.length); res[0] = f.s; System.arraycopy(f.v, 0, res, 1, f.v.length); System.arraycopy(f.p, 0, res, 1 + f.v.length, f.p.length); return res; } public Representation<T> unpack(T[] data) { assert (data.length == VLength + PLength + 1) : "unpack: not compatiable"; T[] v = Arrays.copyOfRange(data, 1, 1 + VLength); T[] p = Arrays.copyOfRange(data, 1 + VLength, data.length); return new Representation<T>(data[0], v, p); } public static class Representation<T> { public T s; public T[] v; public T[] p; public Representation(T sign, T[] v, T[] p) { this.s = sign; this.p = p; this.v = v; } } public T[] multiply(T[] fa, T[] fb) { Representation<T> a = unpack(fa); Representation<T> b = unpack(fb); T new_s = lib.xor(a.s, b.s); T[] a_multi_b = lib.karatsubaMultiply(a.v, b.v);// length 2*v.length T[] a_add_b = lib.add(a.p, b.p); T toShift = lib.not(a_multi_b[a_multi_b.length - 1]); T[] Shifted = lib.conditionalLeftPublicShift(a_multi_b, 1, toShift); T[] new_v = Arrays.copyOfRange(Shifted, a.v.length, a.v.length * 2); T[] new_p = lib.add(a_add_b, lib.toSignals(a.v.length, a_add_b.length)); T[] decrement = lib.zeros(new_p.length); decrement[0] = toShift; new_p = lib.sub(new_p, decrement); Representation<T> res = new Representation<T>(new_s, new_v, new_p); return pack(res); } public T[] div(T[] fa, T[] fb) { Representation<T> a = unpack(fa); Representation<T> b = unpack(fb); T new_s = lib.xor(a.s, b.s); int length = a.v.length; int newLength = a.v.length * 2; T[] padded_av = lib.padSignal(a.v, newLength); T[] padded_bv = lib.padSignal(b.v, b.v.length + 1); T[] shifted_av = lib.leftPublicShift(padded_av, newLength - length - 1); // must be postive number div. so avoid div(shifted_av, padded_bv); T[] a_div_b = Arrays.copyOf(lib.divInternal(shifted_av, padded_bv), shifted_av.length); T[] leadingzero = lib.leadingZeros(a_div_b); T[] sh = lib.leftPrivateShift(a_div_b, leadingzero); sh = lib.rightPublicShift(sh, newLength - length); T[] new_v = Arrays.copyOf(sh, length); T[] new_p = lib.add(lib.sub(a.p, b.p), lib.toSignals(1, a.p.length)); new_p = lib.sub(lib.padSignal(new_p, leadingzero.length), leadingzero); new_p = lib.padSignedSignal(new_p, a.p.length); Representation<T> res = new Representation<T>(new_s, new_v, new_p); return pack(res); } public T[] publicValue(double d) { boolean[] b = Utils.fromFloat(d, VLength, PLength); T[] res = env.newTArray(PLength + VLength + 1); for (int i = 0; i < b.length; ++i) res[i] = b[i] ? lib.SIGNAL_ONE : lib.SIGNAL_ZERO; return res; } // assuming na = va*2^p, nb = vb*2^(p+pDiff) private T[] addInternal(T sa, T sb, T[] va, T[] vb, T[] p, T[] pDiff) { int temp_length = 2 * VLength + 1; T[] signedVa = lib.padSignal(va, temp_length); T[] signedVb = lib.padSignal(vb, temp_length); signedVb = lib.leftPrivateShift(signedVb, pDiff); signedVa = lib.addSign(signedVa, sa); signedVb = lib.addSign(signedVb, sb); T[] new_v = lib.add(signedVa, signedVb); T new_s = new_v[new_v.length - 1]; new_v = lib.absolute(new_v); T[] leadingzero = lib.leadingZeros(new_v); T[] sh = lib.leftPrivateShift(new_v, leadingzero); sh = lib.rightPublicShift(sh, temp_length - VLength); new_v = Arrays.copyOf(sh, VLength); T[] new_p = lib.sub(lib.padSignal(p, leadingzero.length), leadingzero); new_p = lib.add(new_p, lib.toSignals(temp_length - VLength, new_p.length)); new_p = lib.padSignedSignal(new_p, PLength); Representation<T> res = new Representation<T>(new_s, new_v, new_p); return pack(res); } public T[] add(T[] fa, T[] fb) { T[] va = Arrays.copyOfRange(fa, 1, 1 + VLength); T[] vb = Arrays.copyOfRange(fb, 1, 1 + VLength); T[] pa = Arrays.copyOfRange(fa, 1 + VLength, fa.length); T[] pb = Arrays.copyOfRange(fb, 1 + VLength, fb.length); T[] pDifference = lib.sub(pa, pb); T[] pDiffAbs = lib.absolute(pDifference); T paGreater = lib.not(pDifference[pDifference.length - 1]); T[] pToUse = lib.mux(pa, pb, paGreater); T[] normalCase = addInternal(lib.mux(fa[0], fb[0], paGreater), lib.mux(fb[0], fa[0], paGreater), lib.mux(va, vb, paGreater), lib.mux(vb, va, paGreater), pToUse, pDiffAbs); T underFlowHappen = lib.not(lib.leq(pDiffAbs, lib.toSignals(VLength, pDiffAbs.length))); T[] underFlowResult = lib.mux(fb, fa, paGreater); return lib.mux(normalCase, underFlowResult, underFlowHappen); } // (v*s^p)^(1/2) = public T[] sqrt(T[] fa) { int newLength = VLength + 2 + 1; T[] va = Arrays.copyOfRange(fa, 1, 1 + VLength); T[] pa = Arrays.copyOfRange(fa, 1 + VLength, fa.length); va = lib.padSignal(va, newLength); va = lib.leftPublicShift(va, 1); pa = lib.sub(pa, lib.toSignals(1, PLength)); va = lib.conditionalLeftPublicShift(va, 1, pa[0]); va = lib.sqrt(va); pa = lib.rightPublicShift(pa, 1); pa[pa.length - 1] = pa[pa.length - 2]; T[] leadingzero = lib.leadingZeros(va); T[] sh = lib.leftPrivateShift(va, leadingzero); sh = lib.rightPublicShift(sh, newLength - VLength); T[] new_v = Arrays.copyOf(sh, VLength); pa = lib.sub(pa, lib.padSignal(leadingzero, pa.length)); pa = lib.add(pa, lib.toSignals(newLength - VLength, PLength)); return pack(new Representation<T>(lib.SIGNAL_ZERO, new_v, pa)); } public T[] sub(T[] a, T[] b) { T[] negB = Arrays.copyOf(b, b.length); negB[0] = lib.not(negB[0]); return add(a, negB); } public T leq(T[] a, T[] b) { T[] res = sub(a, b); return lib.not(res[0]); } public T eq(T[] a, T[] b) { return lib.eq(a, b); } @Override public CompEnv<T> getEnv() { return env; } @Override public T[] toSecureInt(T[] a, IntegerLib<T> lib) { // TODO Auto-generated method stub return null; } @Override public T[] toSecureFloat(T[] a, FloatLib<T> lib) { // TODO Auto-generated method stub return null; } @Override public T[] toSecureFixPoint(T[] a, FixedPointLib<T> lib) { // TODO Auto-generated method stub return null; } @Override public double outputToAlice(T[] a) { return Utils.toFloat(env.outputToAlice(a), VLength, PLength); } @Override public int numBits() { return VLength + PLength + 1; } }