direct.hpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  #pragma once

  #include <complex>
  #include <memory>
  #include <sstream>
  #include <string>
  #include <utility> // std::move
  #include <vector>

  #include "rust/cxx.h" // rust::String
  #include "openfhe/pke/scheme/gen-cryptocontext-params.h"
  #include "openfhe/pke/scheme/bfvrns/gen-cryptocontext-bfvrns-params.h"
  #include "openfhe/pke/scheme/bgvrns/gen-cryptocontext-bgvrns-params.h"
  #include "openfhe/pke/scheme/ckksrns/gen-cryptocontext-ckksrns-params.h"
  #include "openfhe/pke/gen-cryptocontext.h"
  #include "openfhe/pke/scheme/bfvrns/gen-cryptocontext-bfvrns.h"
  #include "openfhe/pke/scheme/bgvrns/gen-cryptocontext-bgvrns.h"
  #include "openfhe/pke/scheme/ckksrns/gen-cryptocontext-ckksrns.h"
  #include "openfhe/pke/scheme/scheme-id.h" // enums
  #include "openfhe/core/lattice/constants-lattice.h" // enums
  #include "openfhe/pke/constants-fwd.h" // enums
  #include "openfhe/core/lattice/stdlatticeparms.h" // enums
  #include "openfhe/pke/key/keypair.h"
  #include "openfhe/core/utils/inttypes.h"
  #include "openfhe/pke/key/keypair.h"
  #include "openfhe/pke/key/privatekey.h"
  #include "openfhe/pke/key/publickey.h"
  #include "openfhe/pke/ciphertext.h"
  #include "openfhe/pke/encoding/plaintext.h"
  #include "openfhe/pke/schemebase/decrypt-result.h"
  25. namespace openfhe_rs_dev
  26. {
  27. using ParamsBFVRNS = lbcrypto::CCParams<lbcrypto::CryptoContextBFVRNS>;
  28. using ParamsBGVRNS = lbcrypto::CCParams<lbcrypto::CryptoContextBGVRNS>;
  29. using ParamsCKKSRNS = lbcrypto::CCParams<lbcrypto::CryptoContextCKKSRNS>;
  30. using Params = lbcrypto::Params;
  31. using SCHEME = lbcrypto::SCHEME;
  32. using SecretKeyDist = lbcrypto::SecretKeyDist;
  33. using ProxyReEncryptionMode = lbcrypto::ProxyReEncryptionMode;
  34. using MultipartyMode = lbcrypto::MultipartyMode;
  35. using ExecutionMode = lbcrypto::ExecutionMode;
  36. using DecryptionNoiseMode = lbcrypto::DecryptionNoiseMode;
  37. using KeySwitchTechnique = lbcrypto::KeySwitchTechnique;
  38. using ScalingTechnique = lbcrypto::ScalingTechnique;
  39. using SecurityLevel = lbcrypto::SecurityLevel;
  40. using EncryptionTechnique = lbcrypto::EncryptionTechnique;
  41. using MultiplicationTechnique = lbcrypto::MultiplicationTechnique;
  42. using COMPRESSION_LEVEL = lbcrypto::COMPRESSION_LEVEL;
  43. using PKESchemeFeature = lbcrypto::PKESchemeFeature;
  44. using PublicKeyImpl = lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>;
  45. using PrivateKeyImpl = lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>;
  46. using PlaintextImpl = lbcrypto::PlaintextImpl;
  47. using CiphertextImpl = lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>;
  48. using DecryptResult = lbcrypto::DecryptResult;
  49. using DCRTPolyParams = lbcrypto::DCRTPoly::Params;
  50. struct ComplexPair;
  51. using VectorOfComplexNumbers = std::vector<std::complex<double>>;
  52. class KeyPairDCRTPoly final
  53. {
  54. private:
  55. std::shared_ptr<PublicKeyImpl> m_publicKey;
  56. std::shared_ptr<PrivateKeyImpl> m_privateKey;
  57. public:
  58. // TODO: think about all special functions of class
  59. explicit KeyPairDCRTPoly(lbcrypto::KeyPair<lbcrypto::DCRTPoly> keyPair)
  60. : m_publicKey(keyPair.publicKey)
  61. , m_privateKey(keyPair.secretKey)
  62. { }
  63. std::shared_ptr<PublicKeyImpl> GetPublicKey() const
  64. {
  65. return m_publicKey;
  66. }
  67. std::shared_ptr<PrivateKeyImpl> GetPrivateKey() const
  68. {
  69. return m_privateKey;
  70. }
  71. // TODO: implement necessary member functions
  72. };
  73. class Plaintext final
  74. {
  75. private:
  76. std::shared_ptr<lbcrypto::PlaintextImpl> m_plaintext;
  77. public:
  78. // TODO: think about all special functions of class
  79. explicit Plaintext() = default;
  80. explicit Plaintext(std::shared_ptr<lbcrypto::PlaintextImpl> plaintext)
  81. : m_plaintext(plaintext)
  82. { }
  83. Plaintext& operator=(std::shared_ptr<lbcrypto::PlaintextImpl> plaintext)
  84. {
  85. m_plaintext = plaintext;
  86. return *this;
  87. }
  88. std::shared_ptr<lbcrypto::PlaintextImpl> GetPlainText() const
  89. {
  90. return m_plaintext;
  91. }
  92. void SetLength(const size_t newSize) const
  93. {
  94. if (m_plaintext)
  95. {
  96. m_plaintext->SetLength(newSize);
  97. }
  98. }
  99. double GetLogPrecision() const
  100. {
  101. return m_plaintext->GetLogPrecision();
  102. }
  103. rust::String GetString() const
  104. {
  105. if (m_plaintext)
  106. {
  107. std::stringstream stream;
  108. stream << *m_plaintext;
  109. return rust::String(stream.str());
  110. }
  111. return rust::String();
  112. }
  113. // TODO: implement necessary member functions
  114. };
  115. class CiphertextDCRTPoly final
  116. {
  117. private:
  118. std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> m_ciphertext;
  119. public:
  120. // TODO: think about all special functions of class
  121. explicit CiphertextDCRTPoly(std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> ciphertext)
  122. : m_ciphertext(ciphertext)
  123. { }
  124. std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> GetCipherText() const
  125. {
  126. return m_ciphertext;
  127. }
  128. // TODO: implement necessary member functions
  129. };
  130. class CryptoContextDCRTPoly final
  131. {
  132. private:
  133. std::shared_ptr<lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>> m_cryptoContextImplSharedPtr;
  134. public:
  135. // TODO: think about all special functions of class
  136. explicit CryptoContextDCRTPoly(const ParamsBFVRNS& params)
  137. : m_cryptoContextImplSharedPtr(lbcrypto::GenCryptoContext(params))
  138. { }
  139. explicit CryptoContextDCRTPoly(const ParamsBGVRNS& params)
  140. : m_cryptoContextImplSharedPtr(lbcrypto::GenCryptoContext(params))
  141. { }
  142. explicit CryptoContextDCRTPoly(const ParamsCKKSRNS& params)
  143. : m_cryptoContextImplSharedPtr(lbcrypto::GenCryptoContext(params))
  144. { }
  145. void Enable(const PKESchemeFeature feature) const
  146. {
  147. m_cryptoContextImplSharedPtr->Enable(feature);
  148. }
  149. std::unique_ptr<KeyPairDCRTPoly> KeyGen() const
  150. {
  151. return std::make_unique<KeyPairDCRTPoly>(m_cryptoContextImplSharedPtr->KeyGen());
  152. }
  153. void EvalMultKeyGen(const std::shared_ptr<lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>> key) const
  154. {
  155. m_cryptoContextImplSharedPtr->EvalMultKeyGen(key);
  156. }
  157. void EvalRotateKeyGen(const std::shared_ptr<lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>> privateKey, const std::vector<int32_t>& indexList,
  158. const std::shared_ptr<lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>> publicKey) const // publicKey = nullptr in original. Rust don't support default args.
  159. {
  160. m_cryptoContextImplSharedPtr->EvalRotateKeyGen(privateKey, indexList, publicKey);
  161. }
  162. std::unique_ptr<Plaintext> MakePackedPlaintext(const std::vector<int64_t>& value, const size_t noiseScaleDeg, const uint32_t level) const // noiseScaleDeg = 1, level = 0
  163. {
  164. return std::make_unique<Plaintext>(m_cryptoContextImplSharedPtr->MakePackedPlaintext(value, noiseScaleDeg, level));
  165. }
  166. std::unique_ptr<CiphertextDCRTPoly> Encrypt(const std::shared_ptr<lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>> publicKey, std::shared_ptr<lbcrypto::PlaintextImpl> plaintext) const
  167. {
  168. return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->Encrypt(publicKey, plaintext));
  169. }
  170. std::unique_ptr<CiphertextDCRTPoly> EvalAdd(std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> ciphertext1,
  171. std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> ciphertext2) const
  172. {
  173. return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalAdd(ciphertext1, ciphertext2));
  174. }
  175. std::unique_ptr<CiphertextDCRTPoly> EvalSub(std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> ciphertext1,
  176. std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> ciphertext2) const
  177. {
  178. return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalSub(ciphertext1, ciphertext2));
  179. }
  180. std::unique_ptr<CiphertextDCRTPoly> EvalMult(std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> ciphertext1,
  181. std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> ciphertext2) const
  182. {
  183. return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalMult(ciphertext1, ciphertext2));
  184. }
  185. std::unique_ptr<CiphertextDCRTPoly> EvalMultByConst(std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> ciphertext, const double constant) const
  186. {
  187. return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalMult(ciphertext, constant));
  188. }
  189. std::unique_ptr<CiphertextDCRTPoly> EvalRotate(std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> ciphertext, const int32_t index) const
  190. {
  191. return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalRotate(ciphertext, index));
  192. }
  193. std::unique_ptr<DecryptResult> Decrypt(const std::shared_ptr<lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>> privateKey,
  194. std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> ciphertext, Plaintext& plaintext) const
  195. {
  196. std::shared_ptr<lbcrypto::PlaintextImpl> res;
  197. std::unique_ptr<DecryptResult> result = std::make_unique<DecryptResult>(m_cryptoContextImplSharedPtr->Decrypt(privateKey, ciphertext, &res));
  198. plaintext = res;
  199. return result;
  200. }
  201. uint32_t GetRingDimension() const
  202. {
  203. return m_cryptoContextImplSharedPtr->GetRingDimension();
  204. }
  205. std::unique_ptr<Plaintext> MakeCKKSPackedPlaintext(const std::vector<double>& value, const size_t scaleDeg, const uint32_t level,
  206. const std::shared_ptr<DCRTPolyParams> params, const uint32_t slots) const
  207. // scaleDeg = 1, level = 0, params = nullptr, slots = 0
  208. {
  209. return std::make_unique<Plaintext>(m_cryptoContextImplSharedPtr->MakeCKKSPackedPlaintext(value, scaleDeg, level, params, slots));
  210. }
  211. std::unique_ptr<Plaintext> MakeCKKSPackedPlaintextByVectorOfComplexNumbers(const std::vector<std::complex<double>>& value, const size_t scaleDeg, const uint32_t level,
  212. const std::shared_ptr<DCRTPolyParams> params, const uint32_t slots) const
  213. // scaleDeg = 1, level = 0, params = nullptr, slots = 0
  214. {
  215. return std::make_unique<Plaintext>(m_cryptoContextImplSharedPtr->MakeCKKSPackedPlaintext(value, scaleDeg, level, params, slots));
  216. }
  217. std::unique_ptr<CiphertextDCRTPoly> EvalPoly(std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> ciphertext, const std::vector<double>& coefficients) const
  218. {
  219. return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalPoly(ciphertext, coefficients));
  220. }
  221. };
  222. std::unique_ptr<VectorOfComplexNumbers> GenVectorOfComplexNumbers(const std::vector<ComplexPair>& vals);
  223. std::unique_ptr<Params> GetParamsByScheme(const SCHEME scheme);
  224. std::unique_ptr<Params> GetParamsByVectorOfString(const std::vector<std::string>& vals);
  225. std::unique_ptr<ParamsBFVRNS> GetParamsBFVRNS();
  226. std::unique_ptr<ParamsBFVRNS> GetParamsBFVRNSbyVectorOfString(const std::vector<std::string>& vals);
  227. std::unique_ptr<ParamsBGVRNS> GetParamsBGVRNS();
  228. std::unique_ptr<ParamsBGVRNS> GetParamsBGVRNSbyVectorOfString(const std::vector<std::string>& vals);
  229. std::unique_ptr<ParamsCKKSRNS> GetParamsCKKSRNS();
  230. std::unique_ptr<ParamsCKKSRNS> GetParamsCKKSRNSbyVectorOfString(const std::vector<std::string>& vals);
  231. std::unique_ptr<Plaintext> GenEmptyPlainText();
  232. std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsBFVRNS(const ParamsBFVRNS& params);
  233. std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsBGVRNS(const ParamsBGVRNS& params);
  234. std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsCKKSRNS(const ParamsCKKSRNS& params);
  235. } // openfhe_rs_dev