// bindings.cpp (12 KB): pybind11 bindings exposing OpenFHE to Python as the "openfhe" module.

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/iostream.h>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include "openfhe.h"
#include "key/key-ser.h"
#include "bindings.h"
#include "cryptocontext_wrapper.h"
using namespace lbcrypto;
namespace py = pybind11;
  11. void bind_parameters(py::module &m)
  12. {
  13. py::class_<Params>(m, "Params");
  14. py::class_<CCParams<CryptoContextBFVRNS>, Params>(m, "CCParamsBFVRNS")
  15. .def(py::init<>())
  16. // setters
  17. .def("SetPlaintextModulus", &CCParams<CryptoContextBFVRNS>::SetPlaintextModulus)
  18. .def("SetMultiplicativeDepth", &CCParams<CryptoContextBFVRNS>::SetMultiplicativeDepth)
  19. // getters
  20. .def("GetPlaintextModulus", &CCParams<CryptoContextBFVRNS>::GetPlaintextModulus)
  21. .def("GetMultiplicativeDepth", &CCParams<CryptoContextBFVRNS>::GetMultiplicativeDepth);
  22. py::class_<CCParams<CryptoContextBGVRNS>, Params>(m, "CCParamsBGVRNS")
  23. .def(py::init<>())
  24. // setters
  25. .def("SetPlaintextModulus", &CCParams<CryptoContextBGVRNS>::SetPlaintextModulus)
  26. .def("SetMultiplicativeDepth", &CCParams<CryptoContextBGVRNS>::SetMultiplicativeDepth)
  27. // getters
  28. .def("GetPlaintextModulus", &CCParams<CryptoContextBGVRNS>::GetPlaintextModulus)
  29. .def("GetMultiplicativeDepth", &CCParams<CryptoContextBGVRNS>::GetMultiplicativeDepth);
  30. // bind ckks rns params
  31. py::class_<CCParams<CryptoContextCKKSRNS>, Params>(m, "CCParamsCKKSRNS")
  32. .def(py::init<>())
  33. // setters
  34. .def("SetPlaintextModulus", &CCParams<CryptoContextCKKSRNS>::SetPlaintextModulus)
  35. .def("SetMultiplicativeDepth", &CCParams<CryptoContextCKKSRNS>::SetMultiplicativeDepth)
  36. .def("SetScalingModSize", &CCParams<CryptoContextCKKSRNS>::SetScalingModSize)
  37. .def("SetBatchSize", &CCParams<CryptoContextCKKSRNS>::SetBatchSize)
  38. // getters
  39. .def("GetPlaintextModulus", &CCParams<CryptoContextCKKSRNS>::GetPlaintextModulus)
  40. .def("GetMultiplicativeDepth", &CCParams<CryptoContextCKKSRNS>::GetMultiplicativeDepth)
  41. .def("GetScalingModSize", &CCParams<CryptoContextCKKSRNS>::GetScalingModSize)
  42. .def("GetBatchSize", &CCParams<CryptoContextCKKSRNS>::GetBatchSize);
  43. }
  44. void bind_crypto_context(py::module &m)
  45. {
  46. py::class_<CryptoContextImpl<DCRTPoly>, std::shared_ptr<CryptoContextImpl<DCRTPoly>>>(m, "CryptoContext")
  47. .def(py::init<>())
  48. .def("GetKeyGenLevel", &CryptoContextImpl<DCRTPoly>::GetKeyGenLevel)
  49. .def("SetKeyGenLevel", &CryptoContextImpl<DCRTPoly>::SetKeyGenLevel)
  50. .def("GetRingDimension", &CryptoContextImpl<DCRTPoly>::GetRingDimension)
  51. .def("Enable", static_cast<void (CryptoContextImpl<DCRTPoly>::*)(PKESchemeFeature)>(&CryptoContextImpl<DCRTPoly>::Enable), "Enable a feature for the CryptoContext")
  52. .def("KeyGen", &CryptoContextImpl<DCRTPoly>::KeyGen, "Generate a key pair with public and private keys")
  53. .def("EvalMultKeyGen", &CryptoContextImpl<DCRTPoly>::EvalMultKeyGen, "Generate the evaluation key for multiplication")
  54. .def("EvalRotateKeyGen", &CryptoContextImpl<DCRTPoly>::EvalRotateKeyGen, "Generate the evaluation key for rotation",
  55. py::arg("privateKey"), py::arg("indexList"), py::arg("publicKey") = nullptr)
  56. .def("MakePackedPlaintext", &CryptoContextImpl<DCRTPoly>::MakePackedPlaintext, "Make a plaintext from a vector of integers",
  57. py::arg("value"), py::arg("depth") = 1, py::arg("level") = 0)
  58. .def("MakeCKKSPackedPlaintext",&MakeCKKSPackedPlaintextWrapper, "Make a CKKS plaintext from a vector of floats",
  59. py::arg("value"),
  60. py::arg("depth") = static_cast<size_t>(1),
  61. py::arg("level") = static_cast<uint32_t>(0),
  62. py::arg("params") = py::none(),
  63. py::arg("slots") = 0)
  64. .def("EvalRotate", &CryptoContextImpl<DCRTPoly>::EvalRotate, "Rotate a ciphertext")
  65. .def("Encrypt", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(const PublicKey<DCRTPoly>, Plaintext) const>(&CryptoContextImpl<DCRTPoly>::Encrypt),
  66. "Encrypt a plaintext using public key")
  67. .def("EvalAdd", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalAdd), "Add two ciphertexts")
  68. .def("EvalSub", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalSub), "Subtract two ciphertexts")
  69. .def("EvalMult", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalMult), "Multiply two ciphertexts")
  70. .def("EvalMult", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, double) const>(&CryptoContextImpl<DCRTPoly>::EvalMult), "Multiply a ciphertext with a scalar")
  71. .def_static(
  72. "ClearEvalMultKeys", []()
  73. { CryptoContextImpl<DCRTPoly>::ClearEvalMultKeys(); },
  74. "Clear the evaluation keys for multiplication")
  75. .def_static(
  76. "ClearEvalAutomorphismKeys", []()
  77. { CryptoContextImpl<DCRTPoly>::ClearEvalAutomorphismKeys(); },
  78. "Clear the evaluation keys for rotation")
  79. .def_static(
  80. "SerializeEvalMultKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string id = "")
  81. {
  82. std::ofstream outfile(filename,std::ios::out | std::ios::binary);
  83. bool res;
  84. res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey<SerType::SERBINARY>(outfile, sertype, id);
  85. outfile.close();
  86. return res; },
  87. py::arg("filename"), py::arg("sertype"), py::arg("id") = "",
  88. "Serialize an evaluation key for multiplication")
  89. .def_static(
  90. "SerializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string id = "")
  91. {
  92. std::ofstream outfile(filename,std::ios::out | std::ios::binary);
  93. bool res;
  94. res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey<SerType::SERBINARY>(outfile, sertype, id);
  95. outfile.close();
  96. return res; },
  97. py::arg("filename"), py::arg("sertype"), py::arg("id") = "", "Serialize evaluation keys for rotation")
  98. .def_static("DeserializeEvalMultKey", [](std::shared_ptr<CryptoContextImpl<DCRTPoly>> &self,const std::string &filename, const SerType::SERBINARY &sertype)
  99. {
  100. std::ifstream emkeys(filename, std::ios::in | std::ios::binary);
  101. if (!emkeys.is_open()) {
  102. std::cerr << "I cannot read serialization from " << filename << std::endl;
  103. }
  104. bool res;
  105. res = self->DeserializeEvalMultKey<SerType::SERBINARY>(emkeys, sertype);
  106. return res; })
  107. .def_static("DeserializeEvalAutomorphismKey", [](std::shared_ptr<CryptoContextImpl<DCRTPoly>> &self,const std::string &filename, const SerType::SERBINARY &sertype)
  108. {
  109. std::ifstream erkeys(filename, std::ios::in | std::ios::binary);
  110. if (!erkeys.is_open()) {
  111. std::cerr << "I cannot read serialization from " << filename << std::endl;
  112. }
  113. bool res;
  114. res = self->DeserializeEvalAutomorphismKey<SerType::SERBINARY>(erkeys, sertype);
  115. return res; });
  116. // Generator Functions
  117. m.def("GenCryptoContext", &GenCryptoContext<CryptoContextBFVRNS>);
  118. m.def("GenCryptoContext", &GenCryptoContext<CryptoContextBGVRNS>);
  119. m.def("GenCryptoContext", &GenCryptoContext<CryptoContextCKKSRNS>);
  120. m.def("ReleaseAllContexts", &CryptoContextFactory<DCRTPoly>::ReleaseAllContexts);
  121. }
  122. void bind_enums_and_constants(py::module &m)
  123. {
  124. // Scheme Types
  125. py::enum_<SCHEME>(m, "SCHEME")
  126. .value("INVALID_SCHEME", SCHEME::INVALID_SCHEME)
  127. .value("CKKSRNS_SCHEME", SCHEME::CKKSRNS_SCHEME)
  128. .value("BFVRNS_SCHEME", SCHEME::BFVRNS_SCHEME)
  129. .value("BGVRNS_SCHEME", SCHEME::BGVRNS_SCHEME);
  130. // PKE Features
  131. py::enum_<PKESchemeFeature>(m, "PKESchemeFeature")
  132. .value("PKE", PKESchemeFeature::PKE)
  133. .value("KEYSWITCH", PKESchemeFeature::KEYSWITCH)
  134. .value("PRE", PKESchemeFeature::PRE)
  135. .value("LEVELEDSHE", PKESchemeFeature::LEVELEDSHE)
  136. .value("ADVANCEDSHE", PKESchemeFeature::ADVANCEDSHE)
  137. .value("MULTIPARTY", PKESchemeFeature::MULTIPARTY)
  138. .value("FHE", PKESchemeFeature::FHE);
  139. // Serialization Types
  140. py::class_<SerType::SERJSON>(m, "SERJSON");
  141. py::class_<SerType::SERBINARY>(m, "SERBINARY");
  142. m.attr("JSON") = py::cast(SerType::JSON);
  143. m.attr("BINARY") = py::cast(SerType::BINARY);
  144. //Parameters Type
  145. using ParmType = typename DCRTPoly::Params;
  146. py::class_<ParmType, std::shared_ptr<ParmType>>(m, "ParmType");
  147. }
  148. void bind_keys(py::module &m)
  149. {
  150. py::class_<PublicKeyImpl<DCRTPoly>, std::shared_ptr<PublicKeyImpl<DCRTPoly>>>(m, "PublicKey")
  151. .def(py::init<>());
  152. py::class_<PrivateKeyImpl<DCRTPoly>, std::shared_ptr<PrivateKeyImpl<DCRTPoly>>>(m, "PrivateKey");
  153. py::class_<KeyPair<DCRTPoly>>(m, "KeyPair")
  154. .def_readwrite("publicKey", &KeyPair<DCRTPoly>::publicKey)
  155. .def_readwrite("secretKey", &KeyPair<DCRTPoly>::secretKey);
  156. }
  157. void bind_encodings(py::module &m)
  158. {
  159. py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>>(m, "Plaintext")
  160. .def("GetScalingFactor", &PlaintextImpl::GetScalingFactor)
  161. .def("SetScalingFactor", &PlaintextImpl::SetScalingFactor)
  162. .def("GetLength", &PlaintextImpl::GetLength)
  163. .def("GetSchemeID", &PlaintextImpl::GetSchemeID)
  164. .def("SetLength", &PlaintextImpl::SetLength)
  165. .def("IsEncoded", &PlaintextImpl::IsEncoded)
  166. .def("GetLogPrecision", &PlaintextImpl::GetLogPrecision)
  167. //.def("GetEncondingParams", &PlaintextImpl::GetEncondingParams)
  168. .def("Encode", &PlaintextImpl::Encode)
  169. .def("Decode", &PlaintextImpl::Decode)
  170. .def("__repr__", [](const PlaintextImpl &p)
  171. {
  172. std::stringstream ss;
  173. ss << "<Plaintext Object: ";
  174. p.PrintValue(ss);
  175. ss << ">";
  176. return ss.str(); })
  177. .def("__str__", [](const PlaintextImpl &p)
  178. {
  179. std::stringstream ss;
  180. p.PrintValue(ss);
  181. return ss.str(); });
  182. }
  183. void bind_ciphertext(py::module &m)
  184. {
  185. py::class_<CiphertextImpl<DCRTPoly>, std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext")
  186. .def(py::init<>());
  187. // .def("GetDepth", &CiphertextImpl<DCRTPoly>::GetDepth)
  188. // .def("SetDepth", &CiphertextImpl<DCRTPoly>::SetDepth)
  189. // .def("GetLevel", &CiphertextImpl<DCRTPoly>::GetLevel)
  190. // .def("SetLevel", &CiphertextImpl<DCRTPoly>::SetLevel)
  191. // .def("GetHopLevel", &CiphertextImpl<DCRTPoly>::GetHopLevel)
  192. // .def("SetHopLevel", &CiphertextImpl<DCRTPoly>::SetHopLevel)
  193. // .def("GetScalingFactor", &CiphertextImpl<DCRTPoly>::GetScalingFactor)
  194. // .def("SetScalingFactor", &CiphertextImpl<DCRTPoly>::SetScalingFactor)
  195. // .def("GetSlots", &CiphertextImpl<DCRTPoly>::GetSlots)
  196. // .def("SetSlots", &CiphertextImpl<DCRTPoly>::SetSlots);
  197. }
  198. PYBIND11_MODULE(openfhe, m)
  199. {
  200. m.doc() = "Open-Source Fully Homomorphic Encryption Library";
  201. bind_parameters(m);
  202. bind_crypto_context(m);
  203. bind_enums_and_constants(m);
  204. bind_keys(m);
  205. bind_encodings(m);
  206. bind_ciphertext(m);
  207. bind_decryption(m);
  208. bind_serialization(m);
  209. }