// bindings.hpp -- C++ declarations for the OpenFHE bindings consumed from Rust via cxx.
  #pragma once

  #include <complex>
  #include <cstddef>
  #include <cstdint>
  #include <memory>
  #include <string>
  #include <vector>

  #include "openfhe/pke/scheme/ckksrns/gen-cryptocontext-ckksrns.h"
  #include "openfhe/pke/scheme/bfvrns/gen-cryptocontext-bfvrns.h"
  #include "openfhe/pke/scheme/bgvrns/gen-cryptocontext-bgvrns.h"

  #include "rust/cxx.h" // rust::String
  // Output format selector taken by the *ToFile / *FromFile serialization functions declared in
  // namespace openfhe. It is declared at global scope and re-exported into that namespace with
  // `using ::SerialMode;`. The discriminants are spelled out explicitly so they stay stable for
  // anything that mirrors this enum (presumably the Rust side of the bridge -- TODO confirm).
  enum SerialMode
  {
      BINARY = 0, // binary serialization
      JSON = 1    // JSON serialization
  };
  11. namespace openfhe
  12. {
  13. using ParamsBFVRNS = lbcrypto::CCParams<lbcrypto::CryptoContextBFVRNS>;
  14. using ParamsBGVRNS = lbcrypto::CCParams<lbcrypto::CryptoContextBGVRNS>;
  15. using ParamsCKKSRNS = lbcrypto::CCParams<lbcrypto::CryptoContextCKKSRNS>;
  16. using Params = lbcrypto::Params;
  17. using SCHEME = lbcrypto::SCHEME;
  18. using SecretKeyDist = lbcrypto::SecretKeyDist;
  19. using ProxyReEncryptionMode = lbcrypto::ProxyReEncryptionMode;
  20. using MultipartyMode = lbcrypto::MultipartyMode;
  21. using ExecutionMode = lbcrypto::ExecutionMode;
  22. using DecryptionNoiseMode = lbcrypto::DecryptionNoiseMode;
  23. using KeySwitchTechnique = lbcrypto::KeySwitchTechnique;
  24. using ScalingTechnique = lbcrypto::ScalingTechnique;
  25. using SecurityLevel = lbcrypto::SecurityLevel;
  26. using EncryptionTechnique = lbcrypto::EncryptionTechnique;
  27. using MultiplicationTechnique = lbcrypto::MultiplicationTechnique;
  28. using COMPRESSION_LEVEL = lbcrypto::COMPRESSION_LEVEL;
  29. using PKESchemeFeature = lbcrypto::PKESchemeFeature;
  30. using PublicKeyImpl = lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>;
  31. using PrivateKeyImpl = lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>;
  32. using DecryptResult = lbcrypto::DecryptResult;
  33. using DCRTPolyParams = lbcrypto::DCRTPoly::Params;
  34. using ::SerialMode;
  35. struct ComplexPair;
  36. using Complex = std::complex<double>;
  37. struct SharedComplex;
  38. // not used in the Rust side
  39. using PlaintextImpl = lbcrypto::PlaintextImpl;
  40. // not used in the Rust side
  41. using CiphertextImpl = lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>;
  42. // not used in the Rust side
  43. using CryptoContextImpl = lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>;
  44. // not used in the Rust side
  45. using KeyPair = lbcrypto::KeyPair<lbcrypto::DCRTPoly>;
  46. ///////////////////////////////////////////////////////////////////////////////////////////////////
  47. class PublicKeyDCRTPoly final
  48. {
  49. std::shared_ptr<PublicKeyImpl> m_publicKey;
  50. public:
  51. friend bool SerializePublicKeyToFile(const std::string& publicKeyLocation,
  52. const PublicKeyDCRTPoly& publicKey, const SerialMode serialMode);
  53. friend bool DeserializePublicKeyFromFile(const std::string& publicKeyLocation,
  54. PublicKeyDCRTPoly& publicKey, const SerialMode serialMode);
  55. explicit PublicKeyDCRTPoly();
  56. PublicKeyDCRTPoly(const PublicKeyDCRTPoly&) = delete;
  57. PublicKeyDCRTPoly(PublicKeyDCRTPoly&&) = delete;
  58. PublicKeyDCRTPoly& operator=(const PublicKeyDCRTPoly&) = delete;
  59. PublicKeyDCRTPoly& operator=(PublicKeyDCRTPoly&&) = delete;
  60. [[nodiscard]] std::shared_ptr<PublicKeyImpl> GetInternal() const;
  61. };
  62. ///////////////////////////////////////////////////////////////////////////////////////////////////
  63. class KeyPairDCRTPoly final
  64. {
  65. std::shared_ptr<PublicKeyImpl> m_publicKey;
  66. std::shared_ptr<PrivateKeyImpl> m_privateKey;
  67. public:
  68. explicit KeyPairDCRTPoly(KeyPair keyPair);
  69. KeyPairDCRTPoly(const KeyPairDCRTPoly&) = delete;
  70. KeyPairDCRTPoly(KeyPairDCRTPoly&&) = delete;
  71. KeyPairDCRTPoly& operator=(const KeyPairDCRTPoly&) = delete;
  72. KeyPairDCRTPoly& operator=(KeyPairDCRTPoly&&) = delete;
  73. [[nodiscard]] std::shared_ptr<PublicKeyImpl> GetPublicKey() const;
  74. [[nodiscard]] std::shared_ptr<PrivateKeyImpl> GetPrivateKey() const;
  75. };
  76. ///////////////////////////////////////////////////////////////////////////////////////////////////
  77. class Plaintext final
  78. {
  79. std::shared_ptr<PlaintextImpl> m_plaintext;
  80. public:
  81. explicit Plaintext() = default;
  82. explicit Plaintext(std::shared_ptr<PlaintextImpl> plaintext);
  83. Plaintext(const Plaintext&) = delete;
  84. Plaintext(Plaintext&&) = delete;
  85. Plaintext& operator=(const Plaintext&) = delete;
  86. Plaintext& operator=(Plaintext&&) = delete;
  87. Plaintext& operator=(std::shared_ptr<PlaintextImpl> plaintext);
  88. [[nodiscard]] std::shared_ptr<PlaintextImpl> GetInternal() const;
  89. void SetLength(const size_t newSize) const;
  90. [[nodiscard]] double GetLogPrecision() const;
  91. [[nodiscard]] rust::String GetString() const;
  92. };
  93. ///////////////////////////////////////////////////////////////////////////////////////////////////
  94. class CiphertextDCRTPoly final
  95. {
  96. std::shared_ptr<CiphertextImpl> m_ciphertext;
  97. public:
  98. friend bool SerializeCiphertextToFile(const std::string& ciphertextLocation,
  99. const CiphertextDCRTPoly& ciphertext, const SerialMode serialMode);
  100. friend bool DeserializeCiphertextFromFile(const std::string& ciphertextLocation,
  101. CiphertextDCRTPoly& ciphertext, const SerialMode serialMode);
  102. explicit CiphertextDCRTPoly();
  103. explicit CiphertextDCRTPoly(std::shared_ptr<CiphertextImpl> ciphertext);
  104. CiphertextDCRTPoly(const CiphertextDCRTPoly&) = delete;
  105. CiphertextDCRTPoly(CiphertextDCRTPoly&&) = delete;
  106. CiphertextDCRTPoly& operator=(const CiphertextDCRTPoly&) = delete;
  107. CiphertextDCRTPoly& operator=(CiphertextDCRTPoly&&) = delete;
  108. [[nodiscard]] std::shared_ptr<CiphertextImpl> GetInternal() const;
  109. };
  110. ///////////////////////////////////////////////////////////////////////////////////////////////////
  111. class CryptoContextDCRTPoly final
  112. {
  113. std::shared_ptr<CryptoContextImpl> m_cryptoContextImplSharedPtr;
  114. public:
  115. friend bool SerializeCryptoContextToFile(const std::string& ccLocation,
  116. const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
  117. friend bool DeserializeCryptoContextFromFile(const std::string& ccLocation,
  118. CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
  119. friend bool SerializeEvalMultKeyToFile(const std::string& multKeyLocation,
  120. const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
  121. friend bool SerializeEvalSumKeyToFile(const std::string& sumKeyLocation,
  122. const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
  123. friend bool SerializeEvalAutomorphismKeyToFile(const std::string& automorphismKeyLocation,
  124. const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
  125. explicit CryptoContextDCRTPoly() = default;
  126. explicit CryptoContextDCRTPoly(const ParamsBFVRNS& params);
  127. explicit CryptoContextDCRTPoly(const ParamsBGVRNS& params);
  128. explicit CryptoContextDCRTPoly(const ParamsCKKSRNS& params);
  129. CryptoContextDCRTPoly(const CryptoContextDCRTPoly&) = delete;
  130. CryptoContextDCRTPoly(CryptoContextDCRTPoly&&) = delete;
  131. CryptoContextDCRTPoly& operator=(const CryptoContextDCRTPoly&) = delete;
  132. CryptoContextDCRTPoly& operator=(CryptoContextDCRTPoly&&) = delete;
  133. void Enable(const PKESchemeFeature feature) const;
  134. [[nodiscard]] std::unique_ptr<KeyPairDCRTPoly> KeyGen() const;
  135. void EvalMultKeyGen(const std::shared_ptr<PrivateKeyImpl> key) const;
  136. void EvalMultKeysGen(const std::shared_ptr<PrivateKeyImpl> key) const;
  137. void EvalRotateKeyGen(
  138. const std::shared_ptr<PrivateKeyImpl> privateKey, const std::vector<int32_t>& indexList,
  139. const std::shared_ptr<PublicKeyImpl> publicKey /* nullptr */) const;
  140. void EvalCKKStoFHEWPrecompute(const double scale /* 1.0 */) const;
  141. [[nodiscard]] uint32_t GetRingDimension() const;
  142. [[nodiscard]] uint32_t GetCyclotomicOrder() const;
  143. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> Encrypt(
  144. const std::shared_ptr<PublicKeyImpl> publicKey, const Plaintext& plaintext) const;
  145. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalAdd(
  146. const CiphertextDCRTPoly& ciphertext1, const CiphertextDCRTPoly& ciphertext2) const;
  147. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalSub(
  148. const CiphertextDCRTPoly& ciphertext1, const CiphertextDCRTPoly& ciphertext2) const;
  149. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalMult(
  150. const CiphertextDCRTPoly& ciphertext1, const CiphertextDCRTPoly& ciphertext2) const;
  151. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalMultNoRelin(
  152. const CiphertextDCRTPoly& ciphertext1, const CiphertextDCRTPoly& ciphertext2) const;
  153. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalMultAndRelinearize(
  154. const CiphertextDCRTPoly& ciphertext1, const CiphertextDCRTPoly& ciphertext2) const;
  155. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalMultByConst(
  156. const CiphertextDCRTPoly& ciphertext, const double constant) const;
  157. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalRotate(
  158. const CiphertextDCRTPoly& ciphertext, const int32_t index) const;
  159. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalPoly(
  160. const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients) const;
  161. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalChebyshevSeries(
  162. const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients,
  163. const double a, const double b) const;
  164. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalBootstrap(
  165. const CiphertextDCRTPoly& ciphertext, const uint32_t numIterations /* 1 */,
  166. const uint32_t precision /* 0 */) const;
  167. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> Rescale(
  168. const CiphertextDCRTPoly& ciphertext) const;
  169. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> ModReduce(
  170. const CiphertextDCRTPoly& ciphertext) const;
  171. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalSum(const CiphertextDCRTPoly& ciphertext,
  172. const uint32_t batchSize) const;
  173. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> IntMPBootAdjustScale(
  174. const CiphertextDCRTPoly& ciphertext) const;
  175. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> IntMPBootRandomElementGen(
  176. const PublicKeyDCRTPoly& publicKey) const;
  177. void EvalBootstrapSetup(const std::vector<uint32_t>& levelBudget /* {5, 4} */,
  178. const std::vector<uint32_t>& dim1 /* {0, 0} */, const uint32_t slots /* 0 */,
  179. const uint32_t correctionFactor /* 0 */, const bool precompute /* true */) const;
  180. void EvalBootstrapKeyGen(const std::shared_ptr<PrivateKeyImpl> privateKey,
  181. const uint32_t slots) const;
  182. [[nodiscard]] std::unique_ptr<DecryptResult> Decrypt(
  183. const std::shared_ptr<PrivateKeyImpl> privateKey, const CiphertextDCRTPoly& ciphertext,
  184. Plaintext& plaintext) const;
  185. [[nodiscard]] std::unique_ptr<Plaintext> MakePackedPlaintext(
  186. const std::vector<int64_t>& value, const size_t noiseScaleDeg /* 1 */,
  187. const uint32_t level /* 0 */) const;
  188. [[nodiscard]] std::unique_ptr<Plaintext> MakeCKKSPackedPlaintext(
  189. const std::vector<double>& value, const size_t scaleDeg /* 1 */,
  190. const uint32_t level /* 0 */, const std::shared_ptr<DCRTPolyParams> params /* nullptr */,
  191. const uint32_t slots /* 0 */) const;
  192. [[nodiscard]] std::unique_ptr<Plaintext> MakeCKKSPackedPlaintextByVectorOfComplex(
  193. const std::vector<SharedComplex>& value, const size_t scaleDeg /* 1 */,
  194. const uint32_t level /* 0 */, const std::shared_ptr<DCRTPolyParams> params /* nullptr */,
  195. const uint32_t slots /* 0 */) const;
  196. };
  197. ///////////////////////////////////////////////////////////////////////////////////////////////////
  198. [[nodiscard]] std::unique_ptr<std::vector<SharedComplex>> GenVectorOfComplex(
  199. const std::vector<ComplexPair>& vals);
  200. [[nodiscard]] std::unique_ptr<Params> GetParamsByScheme(const SCHEME scheme);
  201. [[nodiscard]] std::unique_ptr<Params> GetParamsByVectorOfString(
  202. const std::vector<std::string>& vals);
  203. [[nodiscard]] std::unique_ptr<ParamsBFVRNS> GetParamsBFVRNS();
  204. [[nodiscard]] std::unique_ptr<ParamsBFVRNS> GetParamsBFVRNSbyVectorOfString(
  205. const std::vector<std::string>& vals);
  206. [[nodiscard]] std::unique_ptr<ParamsBGVRNS> GetParamsBGVRNS();
  207. [[nodiscard]] std::unique_ptr<ParamsBGVRNS> GetParamsBGVRNSbyVectorOfString(
  208. const std::vector<std::string>& vals);
  209. [[nodiscard]] std::unique_ptr<ParamsCKKSRNS> GetParamsCKKSRNS();
  210. [[nodiscard]] std::unique_ptr<ParamsCKKSRNS> GetParamsCKKSRNSbyVectorOfString(
  211. const std::vector<std::string>& vals);
  212. [[nodiscard]] std::unique_ptr<Plaintext> GenEmptyPlainText();
  213. [[nodiscard]] std::unique_ptr<CryptoContextDCRTPoly> GenEmptyCryptoContext();
  214. [[nodiscard]] std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsBFVRNS(
  215. const ParamsBFVRNS& params);
  216. [[nodiscard]] std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsBGVRNS(
  217. const ParamsBGVRNS& params);
  218. [[nodiscard]] std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsCKKSRNS(
  219. const ParamsCKKSRNS& params);
  220. [[nodiscard]] std::unique_ptr<PublicKeyDCRTPoly> GenDefaultConstructedPublicKey();
  221. [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> GenDefaultConstructedCiphertext();
  222. ///////////////////////////////////////////////////////////////////////////////////////////////////
  223. bool SerializeCryptoContextToFile(const std::string& ccLocation,
  224. const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
  225. bool DeserializeCryptoContextFromFile(const std::string& ccLocation,
  226. CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
  227. bool SerializeEvalMultKeyToFile(const std::string& multKeyLocation,
  228. const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
  229. bool SerializeEvalMultKeyByIdToFile(const std::string& multKeyLocation,
  230. const SerialMode serialMode, const std::string& id);
  231. bool DeserializeEvalMultKeyFromFile(const std::string& multKeyLocation,
  232. const SerialMode serialMode);
  233. bool SerializeEvalSumKeyToFile(const std::string& sumKeyLocation,
  234. const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
  235. bool SerializeEvalSumKeyByIdToFile(const std::string& sumKeyLocation,
  236. const SerialMode serialMode, const std::string& id);
  237. bool DeserializeEvalSumKeyFromFile(const std::string& sumKeyLocation, const SerialMode serialMode);
  238. bool SerializeEvalAutomorphismKeyToFile(const std::string& automorphismKeyLocation,
  239. const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
  240. bool SerializeEvalAutomorphismKeyByIdToFile(const std::string& automorphismKeyLocation,
  241. const SerialMode serialMode, const std::string& id);
  242. bool DeserializeEvalAutomorphismKeyFromFile(const std::string& automorphismKeyLocation,
  243. const SerialMode serialMode);
  244. bool SerializeCiphertextToFile(const std::string& ciphertextLocation,
  245. const CiphertextDCRTPoly& ciphertext, const SerialMode serialMode);
  246. bool DeserializeCiphertextFromFile(const std::string& ciphertextLocation,
  247. CiphertextDCRTPoly& ciphertext, const SerialMode serialMode);
  248. bool SerializePublicKeyToFile(const std::string& publicKeyLocation,
  249. const PublicKeyDCRTPoly& publicKey, const SerialMode serialMode);
  250. bool DeserializePublicKeyFromFile(const std::string& publicKeyLocation,
  251. PublicKeyDCRTPoly& publicKey, const SerialMode serialMode);
  252. } // openfhe