types.hpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627
  1. #ifndef __OBLIVDS_TYPES_HPP__
  2. #define __OBLIVDS_TYPES_HPP__
  3. #include <tuple>
  4. #include <vector>
  5. #include <array>
  6. #include <cstdint>
  7. #include <x86intrin.h> // SSE and AVX intrinsics
  8. #include <bsd/stdlib.h> // arc4random_buf
  9. // The number of bits in an MPC secret-shared memory word
  10. #ifndef VALUE_BITS
  11. #define VALUE_BITS 64
  12. #endif
  13. // Values in MPC secret-shared memory are of this type.
  14. // This is the type of the underlying shared value, not the types of the
  15. // shares themselves.
  16. #if VALUE_BITS == 64
  17. using value_t = uint64_t;
  18. #elif VALUE_BITS == 32
  19. using value_t = uint32_t;
  20. #else
  21. #error "Unsupported value of VALUE_BITS"
  22. #endif
  23. // Secret-shared bits are of this type. Note that it is standards
  24. // compliant to treat a bool as an unsigned integer type with values 0
  25. // and 1.
  26. using bit_t = bool;
  27. // Counts of the number of bits in a value are of this type, which must
  28. // be large enough to store the _value_ VALUE_BITS
  29. using nbits_t = uint8_t;
  30. // Convert a number of bits to the number of bytes required to store (or
  31. // more to the point, send) them.
  32. #define BITBYTES(nbits) (((nbits)+7)>>3)
  33. // A mask of this many bits; the test is to prevent 1<<nbits from
  34. // overflowing if nbits == VALUE_BITS
  35. #define MASKBITS(nbits) (((nbits) < VALUE_BITS) ? (value_t(1)<<(nbits))-1 : ~0)
  36. // The type of a register holding an additive share of a value
  37. struct RegAS {
  38. value_t ashare;
  39. RegAS() : ashare(0) {}
  40. inline value_t share() const { return ashare; }
  41. inline void set(value_t s) { ashare = s; }
  42. // Set each side's share to a random value nbits bits long
  43. inline void randomize(size_t nbits = VALUE_BITS) {
  44. value_t mask = MASKBITS(nbits);
  45. arc4random_buf(&ashare, sizeof(ashare));
  46. ashare &= mask;
  47. }
  48. inline RegAS &operator+=(const RegAS &rhs) {
  49. this->ashare += rhs.ashare;
  50. return *this;
  51. }
  52. inline RegAS operator+(const RegAS &rhs) const {
  53. RegAS res = *this;
  54. res += rhs;
  55. return res;
  56. }
  57. inline RegAS &operator-=(const RegAS &rhs) {
  58. this->ashare -= rhs.ashare;
  59. return *this;
  60. }
  61. inline RegAS operator-(const RegAS &rhs) const {
  62. RegAS res = *this;
  63. res -= rhs;
  64. return res;
  65. }
  66. inline RegAS operator-() const {
  67. RegAS res = *this;
  68. res.ashare = -res.ashare;
  69. return res;
  70. }
  71. inline RegAS &operator*=(value_t rhs) {
  72. this->ashare *= rhs;
  73. return *this;
  74. }
  75. inline RegAS operator*(value_t rhs) const {
  76. RegAS res = *this;
  77. res *= rhs;
  78. return res;
  79. }
  80. inline RegAS &operator&=(value_t mask) {
  81. this->ashare &= mask;
  82. return *this;
  83. }
  84. inline RegAS operator&(value_t mask) const {
  85. RegAS res = *this;
  86. res &= mask;
  87. return res;
  88. }
  89. };
  90. inline value_t combine(const RegAS &A, const RegAS &B,
  91. nbits_t nbits = VALUE_BITS) {
  92. value_t mask = ~0;
  93. if (nbits < VALUE_BITS) {
  94. mask = (value_t(1)<<nbits)-1;
  95. }
  96. return (A.ashare + B.ashare) & mask;
  97. }
  98. // The type of a register holding a bit share
  99. struct RegBS {
  100. bit_t bshare;
  101. RegBS() : bshare(0) {}
  102. inline bit_t share() const { return bshare; }
  103. inline void set(bit_t s) { bshare = s; }
  104. // Set each side's share to a random bit
  105. inline void randomize() {
  106. unsigned char randb;
  107. arc4random_buf(&randb, sizeof(randb));
  108. bshare = randb & 1;
  109. }
  110. inline RegBS &operator^=(const RegBS &rhs) {
  111. this->bshare ^= rhs.bshare;
  112. return *this;
  113. }
  114. inline RegBS operator^(const RegBS &rhs) const {
  115. RegBS res = *this;
  116. res ^= rhs;
  117. return res;
  118. }
  119. inline RegBS &operator^=(const bit_t &rhs) {
  120. this->bshare ^= rhs;
  121. return *this;
  122. }
  123. inline RegBS operator^(const bit_t &rhs) const {
  124. RegBS res = *this;
  125. res ^= rhs;
  126. return res;
  127. }
  128. };
  129. // The type of a register holding an XOR share of a value
  130. struct RegXS {
  131. value_t xshare;
  132. RegXS() : xshare(0) {}
  133. RegXS(const RegBS &b) { xshare = b.bshare ? ~0 : 0; }
  134. inline value_t share() const { return xshare; }
  135. inline void set(value_t s) { xshare = s; }
  136. // Set each side's share to a random value nbits bits long
  137. inline void randomize(size_t nbits = VALUE_BITS) {
  138. value_t mask = MASKBITS(nbits);
  139. arc4random_buf(&xshare, sizeof(xshare));
  140. xshare &= mask;
  141. }
  142. // For RegXS, + and * should be interpreted bitwise; that is, + is
  143. // really XOR and * is really AND. - is also XOR (the same as +).
  144. // We also include actual XOR operators for convenience
  145. inline RegXS &operator+=(const RegXS &rhs) {
  146. this->xshare ^= rhs.xshare;
  147. return *this;
  148. }
  149. inline RegXS operator+(const RegXS &rhs) const {
  150. RegXS res = *this;
  151. res += rhs;
  152. return res;
  153. }
  154. inline RegXS &operator-=(const RegXS &rhs) {
  155. this->xshare ^= rhs.xshare;
  156. return *this;
  157. }
  158. inline RegXS operator-(const RegXS &rhs) const {
  159. RegXS res = *this;
  160. res -= rhs;
  161. return res;
  162. }
  163. inline RegXS operator-() const {
  164. RegXS res = *this;
  165. return res;
  166. }
  167. inline RegXS &operator*=(value_t rhs) {
  168. this->xshare &= rhs;
  169. return *this;
  170. }
  171. inline RegXS operator*(value_t rhs) const {
  172. RegXS res = *this;
  173. res *= rhs;
  174. return res;
  175. }
  176. inline RegXS &operator^=(const RegXS &rhs) {
  177. this->xshare ^= rhs.xshare;
  178. return *this;
  179. }
  180. inline RegXS operator^(const RegXS &rhs) const {
  181. RegXS res = *this;
  182. res ^= rhs;
  183. return res;
  184. }
  185. inline RegXS &operator&=(value_t mask) {
  186. this->xshare &= mask;
  187. return *this;
  188. }
  189. inline RegXS operator&(value_t mask) const {
  190. RegXS res = *this;
  191. res &= mask;
  192. return res;
  193. }
  194. // Extract a bit share of bit bitnum of the XOR-shared register
  195. inline RegBS bit(nbits_t bitnum) const {
  196. RegBS bs;
  197. bs.bshare = !!(xshare & (value_t(1)<<bitnum));
  198. return bs;
  199. }
  200. };
  201. inline value_t combine(const RegXS &A, const RegXS &B,
  202. nbits_t nbits = VALUE_BITS) {
  203. value_t mask = ~0;
  204. if (nbits < VALUE_BITS) {
  205. mask = (value_t(1)<<nbits)-1;
  206. }
  207. return (A.xshare ^ B.xshare) & mask;
  208. }
  209. // Some useful operations on tuples, vectors, and arrays of the above
  210. // types
  211. template <typename T>
  212. std::tuple<T,T> operator+=(std::tuple<T,T> &A,
  213. const std::tuple<T,T> &B)
  214. {
  215. std::get<0>(A) += std::get<0>(B);
  216. std::get<1>(A) += std::get<1>(B);
  217. return A;
  218. }
  219. template <typename T>
  220. std::tuple<T,T> operator+=(const std::tuple<T&,T&> &A,
  221. const std::tuple<T,T> &B)
  222. {
  223. std::get<0>(A) += std::get<0>(B);
  224. std::get<1>(A) += std::get<1>(B);
  225. return A;
  226. }
  227. template <typename T>
  228. std::tuple<T,T> operator+(const std::tuple<T,T> &A,
  229. const std::tuple<T,T> &B)
  230. {
  231. auto res = A;
  232. res += B;
  233. return res;
  234. }
  235. template <typename T>
  236. std::tuple<T,T> operator-=(const std::tuple<T&,T&> &A,
  237. const std::tuple<T,T> &B)
  238. {
  239. std::get<0>(A) -= std::get<0>(B);
  240. std::get<1>(A) -= std::get<1>(B);
  241. return A;
  242. }
  243. template <typename T>
  244. std::tuple<T,T> operator-=(std::tuple<T,T> &A,
  245. const std::tuple<T,T> &B)
  246. {
  247. std::get<0>(A) -= std::get<0>(B);
  248. std::get<1>(A) -= std::get<1>(B);
  249. return A;
  250. }
  251. template <typename T>
  252. std::tuple<T,T> operator-(const std::tuple<T,T> &A,
  253. const std::tuple<T,T> &B)
  254. {
  255. auto res = A;
  256. res -= B;
  257. return res;
  258. }
  259. template <typename T>
  260. std::tuple<T,T> operator*=(const std::tuple<T&,T&> &A,
  261. const std::tuple<value_t,value_t> &B)
  262. {
  263. std::get<0>(A) *= std::get<0>(B);
  264. std::get<1>(A) *= std::get<1>(B);
  265. return A;
  266. }
  267. template <typename T>
  268. std::tuple<T,T> operator*=(std::tuple<T,T> &A,
  269. const std::tuple<value_t,value_t> &B)
  270. {
  271. std::get<0>(A) *= std::get<0>(B);
  272. std::get<1>(A) *= std::get<1>(B);
  273. return A;
  274. }
  275. template <typename T>
  276. std::tuple<T,T> operator*(const std::tuple<T,T> &A,
  277. const std::tuple<value_t,value_t> &B)
  278. {
  279. auto res = A;
  280. res *= B;
  281. return res;
  282. }
  283. template <typename T>
  284. inline std::tuple<value_t,value_t> combine(
  285. const std::tuple<T,T> &A, const std::tuple<T,T> &B,
  286. nbits_t nbits = VALUE_BITS) {
  287. return std::make_tuple(
  288. combine(std::get<0>(A), std::get<0>(B), nbits),
  289. combine(std::get<1>(A), std::get<1>(B), nbits));
  290. }
  291. template <typename T>
  292. std::tuple<T,T,T> operator+=(const std::tuple<T&,T&,T&> &A,
  293. const std::tuple<T,T,T> &B)
  294. {
  295. std::get<0>(A) += std::get<0>(B);
  296. std::get<1>(A) += std::get<1>(B);
  297. std::get<2>(A) += std::get<2>(B);
  298. return A;
  299. }
  300. template <typename T>
  301. std::tuple<T,T,T> operator+=(std::tuple<T,T,T> &A,
  302. const std::tuple<T,T,T> &B)
  303. {
  304. std::get<0>(A) += std::get<0>(B);
  305. std::get<1>(A) += std::get<1>(B);
  306. std::get<2>(A) += std::get<2>(B);
  307. return A;
  308. }
  309. template <typename T>
  310. std::tuple<T,T,T> operator+(const std::tuple<T,T,T> &A,
  311. const std::tuple<T,T,T> &B)
  312. {
  313. auto res = A;
  314. res += B;
  315. return res;
  316. }
  317. template <typename T>
  318. std::tuple<T,T,T> operator-=(const std::tuple<T&,T&,T&> &A,
  319. const std::tuple<T,T,T> &B)
  320. {
  321. std::get<0>(A) -= std::get<0>(B);
  322. std::get<1>(A) -= std::get<1>(B);
  323. std::get<2>(A) -= std::get<2>(B);
  324. return A;
  325. }
  326. template <typename T>
  327. std::tuple<T,T,T> operator-=(std::tuple<T,T,T> &A,
  328. const std::tuple<T,T,T> &B)
  329. {
  330. std::get<0>(A) -= std::get<0>(B);
  331. std::get<1>(A) -= std::get<1>(B);
  332. std::get<2>(A) -= std::get<2>(B);
  333. return A;
  334. }
  335. template <typename T>
  336. std::tuple<T,T,T> operator-(const std::tuple<T,T,T> &A,
  337. const std::tuple<T,T,T> &B)
  338. {
  339. auto res = A;
  340. res -= B;
  341. return res;
  342. }
  343. template <typename T>
  344. std::tuple<T,T,T> operator*=(const std::tuple<T&,T&,T&> &A,
  345. const std::tuple<value_t,value_t,value_t> &B)
  346. {
  347. std::get<0>(A) *= std::get<0>(B);
  348. std::get<1>(A) *= std::get<1>(B);
  349. std::get<2>(A) *= std::get<2>(B);
  350. return A;
  351. }
  352. template <typename T>
  353. std::tuple<T,T,T> operator*=(std::tuple<T,T,T> &A,
  354. const std::tuple<value_t,value_t,value_t> &B)
  355. {
  356. std::get<0>(A) *= std::get<0>(B);
  357. std::get<1>(A) *= std::get<1>(B);
  358. std::get<2>(A) *= std::get<2>(B);
  359. return A;
  360. }
  361. template <typename T>
  362. std::tuple<T,T,T> operator*(const std::tuple<T,T,T> &A,
  363. const std::tuple<value_t,value_t,value_t> &B)
  364. {
  365. auto res = A;
  366. res *= B;
  367. return res;
  368. }
  369. inline std::vector<RegAS> operator-(const std::vector<RegAS> &A)
  370. {
  371. std::vector<RegAS> res;
  372. for (const auto &v : A) {
  373. res.push_back(-v);
  374. }
  375. return res;
  376. }
  377. inline std::vector<RegXS> operator-(const std::vector<RegXS> &A)
  378. {
  379. return A;
  380. }
  381. inline std::vector<RegBS> operator-(const std::vector<RegBS> &A)
  382. {
  383. return A;
  384. }
  385. template <size_t N>
  386. inline std::vector<RegAS> operator-(const std::array<RegAS,N> &A)
  387. {
  388. std::vector<RegAS> res;
  389. for (const auto &v : A) {
  390. res.push_back(-v);
  391. }
  392. return res;
  393. }
  394. template <size_t N>
  395. inline std::array<RegXS,N> operator-(const std::array<RegXS,N> &A)
  396. {
  397. return A;
  398. }
  399. template <size_t N>
  400. inline std::array<RegBS,N> operator-(const std::array<RegBS,N> &A)
  401. {
  402. return A;
  403. }
  404. template <typename T>
  405. inline std::tuple<value_t,value_t,value_t> combine(
  406. const std::tuple<T,T,T> &A, const std::tuple<T,T,T> &B,
  407. nbits_t nbits = VALUE_BITS) {
  408. return std::make_tuple(
  409. combine(std::get<0>(A), std::get<0>(B), nbits),
  410. combine(std::get<1>(A), std::get<1>(B), nbits),
  411. combine(std::get<2>(A), std::get<2>(B), nbits));
  412. }
  413. // The _maximum_ number of bits in an MPC address; the actual size of
  414. // the memory will typically be set at runtime, but it cannot exceed
  415. // this value. It is more efficient (in terms of communication) in some
  416. // places for this value to be at most 32.
  417. #ifndef ADDRESS_MAX_BITS
  418. #define ADDRESS_MAX_BITS 32
  419. #endif
  420. // Addresses of MPC secret-shared memory are of this type
  421. #if ADDRESS_MAX_BITS <= 32
  422. using address_t = uint32_t;
  423. #elif ADDRESS_MAX_BITS <= 64
  424. using address_t = uint64_t;
  425. #else
  426. #error "Unsupported value of ADDRESS_MAX_BITS"
  427. #endif
  428. #if ADDRESS_MAX_BITS > VALUE_BITS
  429. #error "VALUE_BITS must be at least as large as ADDRESS_MAX_BITS"
  430. #endif
  431. // A multiplication triple is a triple (X0,Y0,Z0) held by P0 (and
  432. // correspondingly (X1,Y1,Z1) held by P1), with all values random,
  433. // but subject to the relation that X0*Y1 + Y0*X1 = Z0+Z1
  434. using MultTriple = std::tuple<value_t, value_t, value_t>;
  435. // The *Name structs are a way to get strings representing the names of
  436. // the types as would be given to preprocessing to create them in
  437. // advance.
  438. struct MultTripleName { static constexpr const char *name = "t"; };
  439. // A half-triple is (X0,Z0) held by P0 (and correspondingly (Y1,Z1) held
  440. // by P1), with all values random, but subject to the relation that
  441. // X0*Y1 = Z0+Z1
  442. using HalfTriple = std::tuple<value_t, value_t>;
  443. struct HalfTripleName { static constexpr const char *name = "h"; };
  444. // The type of nodes in a DPF. This must be at least as many bits as
  445. // the security parameter, and at least twice as many bits as value_t.
  446. using DPFnode = __m128i;
  447. // A Select triple is a triple of (X0,Y0,Z0) where X0 is a bit and Y0
  448. // and Z0 are DPFnodes held by P0 (and correspondingly (X1,Y1,Z1) held
  449. // by P1), with all values random, but subject to the relation that
  450. // (X0*Y1) ^ (Y0*X1) = Z0^Z1. These are only used while creating RDPFs
  451. // in the preprocessing phase, so we never need to store them. This is
  452. // a struct instead of a tuple for alignment reasons.
  453. struct SelectTriple {
  454. bit_t X;
  455. DPFnode Y, Z;
  456. };
  457. // These are defined in rdpf.hpp, but declared here to avoid cyclic
  458. // header dependencies.
  459. struct RDPFPair;
  460. struct RDPFPairName { static constexpr const char *name = "r"; };
  461. struct RDPFTriple;
  462. struct RDPFTripleName { static constexpr const char *name = "r"; };
  463. struct CDPF;
  464. struct CDPFName { static constexpr const char *name = "c"; };
  465. // We want the I/O (using << and >>) for many classes
  466. // to just be a common thing: write out the bytes
  467. // straight from memory
  468. #define DEFAULT_IO(CLASSNAME) \
  469. template <typename T> \
  470. T& operator>>(T& is, CLASSNAME &x) \
  471. { \
  472. is.read((char *)&x, sizeof(x)); \
  473. return is; \
  474. } \
  475. \
  476. template <typename T> \
  477. T& operator<<(T& os, const CLASSNAME &x) \
  478. { \
  479. os.write((const char *)&x, sizeof(x)); \
  480. return os; \
  481. }
  482. // Default I/O for various types
  483. DEFAULT_IO(DPFnode)
  484. DEFAULT_IO(RegBS)
  485. DEFAULT_IO(RegAS)
  486. DEFAULT_IO(RegXS)
  487. DEFAULT_IO(MultTriple)
  488. DEFAULT_IO(HalfTriple)
  489. // And for pairs and triples
  490. #define DEFAULT_TUPLE_IO(CLASSNAME) \
  491. template <typename T> \
  492. T& operator>>(T& is, std::tuple<CLASSNAME, CLASSNAME> &x) \
  493. { \
  494. is >> std::get<0>(x) >> std::get<1>(x); \
  495. return is; \
  496. } \
  497. \
  498. template <typename T> \
  499. T& operator<<(T& os, const std::tuple<CLASSNAME, CLASSNAME> &x) \
  500. { \
  501. os << std::get<0>(x) << std::get<1>(x); \
  502. return os; \
  503. } \
  504. \
  505. template <typename T> \
  506. T& operator>>(T& is, std::tuple<CLASSNAME, CLASSNAME, CLASSNAME> &x) \
  507. { \
  508. is >> std::get<0>(x) >> std::get<1>(x) >> std::get<2>(x); \
  509. return is; \
  510. } \
  511. \
  512. template <typename T> \
  513. T& operator<<(T& os, const std::tuple<CLASSNAME, CLASSNAME, CLASSNAME> &x) \
  514. { \
  515. os << std::get<0>(x) << std::get<1>(x) << std::get<2>(x); \
  516. return os; \
  517. }
  518. DEFAULT_TUPLE_IO(RegAS)
  519. DEFAULT_TUPLE_IO(RegXS)
  520. enum ProcessingMode {
  521. MODE_ONLINE, // Online mode, after preprocessing has been done
  522. MODE_PREPROCESSING, // Preprocessing mode
  523. MODE_ONLINEONLY // Online-only mode, where all computations are
  524. }; // done online
  525. #endif