FloatLib.java 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. package com.oblivm.backend.circuits.arithmetic;
  2. import java.util.Arrays;
  3. import com.oblivm.backend.flexsc.CompEnv;
  4. import com.oblivm.backend.util.Utils;
  5. public class FloatLib<T> implements ArithmeticLib<T> {
  6. CompEnv<T> env;
  7. IntegerLib<T> lib;
  8. public int VLength;
  9. public int PLength;
  10. public FloatLib(CompEnv<T> e, int VLength, int PLength) {
  11. this.env = e;
  12. lib = new IntegerLib<>(e);
  13. this.VLength = VLength;
  14. this.PLength = PLength;
  15. }
  16. public T[] inputOfAlice(double d) {
  17. return env.inputOfAlice(Utils.fromFloat(d, VLength, PLength));
  18. }
  19. public T[] inputOfBob(double d) {
  20. return env.inputOfBob(Utils.fromFloat(d, VLength, PLength));
  21. }
  22. public T[] pack(Representation<T> f) {
  23. assert (f.v.length == VLength && f.p.length == PLength) : "pack: not compatiable";
  24. T[] res = env.newTArray(1 + f.v.length + f.p.length);
  25. res[0] = f.s;
  26. System.arraycopy(f.v, 0, res, 1, f.v.length);
  27. System.arraycopy(f.p, 0, res, 1 + f.v.length, f.p.length);
  28. return res;
  29. }
  30. public Representation<T> unpack(T[] data) {
  31. assert (data.length == VLength + PLength + 1) : "unpack: not compatiable";
  32. T[] v = Arrays.copyOfRange(data, 1, 1 + VLength);
  33. T[] p = Arrays.copyOfRange(data, 1 + VLength, data.length);
  34. return new Representation<T>(data[0], v, p);
  35. }
  36. public static class Representation<T> {
  37. public T s;
  38. public T[] v;
  39. public T[] p;
  40. public Representation(T sign, T[] v, T[] p) {
  41. this.s = sign;
  42. this.p = p;
  43. this.v = v;
  44. }
  45. }
  46. public T[] multiply(T[] fa, T[] fb) {
  47. Representation<T> a = unpack(fa);
  48. Representation<T> b = unpack(fb);
  49. T new_s = lib.xor(a.s, b.s);
  50. T[] a_multi_b = lib.karatsubaMultiply(a.v, b.v);// length 2*v.length
  51. T[] a_add_b = lib.add(a.p, b.p);
  52. T toShift = lib.not(a_multi_b[a_multi_b.length - 1]);
  53. T[] Shifted = lib.conditionalLeftPublicShift(a_multi_b, 1, toShift);
  54. T[] new_v = Arrays.copyOfRange(Shifted, a.v.length, a.v.length * 2);
  55. T[] new_p = lib.add(a_add_b, lib.toSignals(a.v.length, a_add_b.length));
  56. T[] decrement = lib.zeros(new_p.length);
  57. decrement[0] = toShift;
  58. new_p = lib.sub(new_p, decrement);
  59. Representation<T> res = new Representation<T>(new_s, new_v, new_p);
  60. return pack(res);
  61. }
  62. public T[] div(T[] fa, T[] fb) {
  63. Representation<T> a = unpack(fa);
  64. Representation<T> b = unpack(fb);
  65. T new_s = lib.xor(a.s, b.s);
  66. int length = a.v.length;
  67. int newLength = a.v.length * 2;
  68. T[] padded_av = lib.padSignal(a.v, newLength);
  69. T[] padded_bv = lib.padSignal(b.v, b.v.length + 1);
  70. T[] shifted_av = lib.leftPublicShift(padded_av, newLength - length - 1);
  71. // must be postive number div. so avoid div(shifted_av, padded_bv);
  72. T[] a_div_b = Arrays.copyOf(lib.divInternal(shifted_av, padded_bv), shifted_av.length);
  73. T[] leadingzero = lib.leadingZeros(a_div_b);
  74. T[] sh = lib.leftPrivateShift(a_div_b, leadingzero);
  75. sh = lib.rightPublicShift(sh, newLength - length);
  76. T[] new_v = Arrays.copyOf(sh, length);
  77. T[] new_p = lib.add(lib.sub(a.p, b.p), lib.toSignals(1, a.p.length));
  78. new_p = lib.sub(lib.padSignal(new_p, leadingzero.length), leadingzero);
  79. new_p = lib.padSignedSignal(new_p, a.p.length);
  80. Representation<T> res = new Representation<T>(new_s, new_v, new_p);
  81. return pack(res);
  82. }
  83. public T[] publicValue(double d) {
  84. boolean[] b = Utils.fromFloat(d, VLength, PLength);
  85. T[] res = env.newTArray(PLength + VLength + 1);
  86. for (int i = 0; i < b.length; ++i)
  87. res[i] = b[i] ? lib.SIGNAL_ONE : lib.SIGNAL_ZERO;
  88. return res;
  89. }
  90. // assuming na = va*2^p, nb = vb*2^(p+pDiff)
  91. private T[] addInternal(T sa, T sb, T[] va, T[] vb, T[] p, T[] pDiff) {
  92. int temp_length = 2 * VLength + 1;
  93. T[] signedVa = lib.padSignal(va, temp_length);
  94. T[] signedVb = lib.padSignal(vb, temp_length);
  95. signedVb = lib.leftPrivateShift(signedVb, pDiff);
  96. signedVa = lib.addSign(signedVa, sa);
  97. signedVb = lib.addSign(signedVb, sb);
  98. T[] new_v = lib.add(signedVa, signedVb);
  99. T new_s = new_v[new_v.length - 1];
  100. new_v = lib.absolute(new_v);
  101. T[] leadingzero = lib.leadingZeros(new_v);
  102. T[] sh = lib.leftPrivateShift(new_v, leadingzero);
  103. sh = lib.rightPublicShift(sh, temp_length - VLength);
  104. new_v = Arrays.copyOf(sh, VLength);
  105. T[] new_p = lib.sub(lib.padSignal(p, leadingzero.length), leadingzero);
  106. new_p = lib.add(new_p, lib.toSignals(temp_length - VLength, new_p.length));
  107. new_p = lib.padSignedSignal(new_p, PLength);
  108. Representation<T> res = new Representation<T>(new_s, new_v, new_p);
  109. return pack(res);
  110. }
  111. public T[] add(T[] fa, T[] fb) {
  112. T[] va = Arrays.copyOfRange(fa, 1, 1 + VLength);
  113. T[] vb = Arrays.copyOfRange(fb, 1, 1 + VLength);
  114. T[] pa = Arrays.copyOfRange(fa, 1 + VLength, fa.length);
  115. T[] pb = Arrays.copyOfRange(fb, 1 + VLength, fb.length);
  116. T[] pDifference = lib.sub(pa, pb);
  117. T[] pDiffAbs = lib.absolute(pDifference);
  118. T paGreater = lib.not(pDifference[pDifference.length - 1]);
  119. T[] pToUse = lib.mux(pa, pb, paGreater);
  120. T[] normalCase = addInternal(lib.mux(fa[0], fb[0], paGreater), lib.mux(fb[0], fa[0], paGreater),
  121. lib.mux(va, vb, paGreater), lib.mux(vb, va, paGreater), pToUse, pDiffAbs);
  122. T underFlowHappen = lib.not(lib.leq(pDiffAbs, lib.toSignals(VLength, pDiffAbs.length)));
  123. T[] underFlowResult = lib.mux(fb, fa, paGreater);
  124. return lib.mux(normalCase, underFlowResult, underFlowHappen);
  125. }
  126. // (v*s^p)^(1/2) =
  127. public T[] sqrt(T[] fa) {
  128. int newLength = VLength + 2 + 1;
  129. T[] va = Arrays.copyOfRange(fa, 1, 1 + VLength);
  130. T[] pa = Arrays.copyOfRange(fa, 1 + VLength, fa.length);
  131. va = lib.padSignal(va, newLength);
  132. va = lib.leftPublicShift(va, 1);
  133. pa = lib.sub(pa, lib.toSignals(1, PLength));
  134. va = lib.conditionalLeftPublicShift(va, 1, pa[0]);
  135. va = lib.sqrt(va);
  136. pa = lib.rightPublicShift(pa, 1);
  137. pa[pa.length - 1] = pa[pa.length - 2];
  138. T[] leadingzero = lib.leadingZeros(va);
  139. T[] sh = lib.leftPrivateShift(va, leadingzero);
  140. sh = lib.rightPublicShift(sh, newLength - VLength);
  141. T[] new_v = Arrays.copyOf(sh, VLength);
  142. pa = lib.sub(pa, lib.padSignal(leadingzero, pa.length));
  143. pa = lib.add(pa, lib.toSignals(newLength - VLength, PLength));
  144. return pack(new Representation<T>(lib.SIGNAL_ZERO, new_v, pa));
  145. }
  146. public T[] sub(T[] a, T[] b) {
  147. T[] negB = Arrays.copyOf(b, b.length);
  148. negB[0] = lib.not(negB[0]);
  149. return add(a, negB);
  150. }
  151. public T leq(T[] a, T[] b) {
  152. T[] res = sub(a, b);
  153. return lib.not(res[0]);
  154. }
  155. public T eq(T[] a, T[] b) {
  156. return lib.eq(a, b);
  157. }
  158. @Override
  159. public CompEnv<T> getEnv() {
  160. return env;
  161. }
  162. @Override
  163. public T[] toSecureInt(T[] a, IntegerLib<T> lib) {
  164. // TODO Auto-generated method stub
  165. return null;
  166. }
  167. @Override
  168. public T[] toSecureFloat(T[] a, FloatLib<T> lib) {
  169. // TODO Auto-generated method stub
  170. return null;
  171. }
  172. @Override
  173. public T[] toSecureFixPoint(T[] a, FixedPointLib<T> lib) {
  174. // TODO Auto-generated method stub
  175. return null;
  176. }
  177. @Override
  178. public double outputToAlice(T[] a) {
  179. return Utils.toFloat(env.outputToAlice(a), VLength, PLength);
  180. }
  181. @Override
  182. public int numBits() {
  183. return VLength + PLength + 1;
  184. }
  185. }