Преглед изворни кода

Added bindings for new APIs and constants defined in OpenFHE v1.3.0 (#221)

Co-authored-by: Dmitriy Suponitskiy <dsuponitskiy@dualitytech.com>
dsuponitskiy пре 9 месеци
родитељ
комит
8686a7e425

+ 41 - 0
src/include/docstrings/cryptocontext_docs.h

@@ -1071,6 +1071,47 @@ const char* cc_MultiAddEvalMultKeys_docs = R"pbdoc(
     :rtype: EvalKey
 )pbdoc";
 
+const char* cc_IntBootDecrypt_docs = R"pbdoc(
+    Performs masked decryption for interactive bootstrapping (2-party protocol).
+
+    :param privateKey: Secret key share
+    :type privateKey: PrivateKey
+    :param ciphertext: Input Ciphertext
+    :type ciphertext: Ciphertext
+    :return: Resulting ciphertext
+    :rtype: Ciphertext
+)pbdoc";
+
+const char* cc_IntBootEncrypt_docs = R"pbdoc(
+    Encrypts Client's masked decryption for interactive bootstrapping. Increases ciphertext modulus to allow further computation. Done by Client.
+
+    :param publicKey: Joined public key (Threshold FHE)
+    :type publicKey: PublicKey
+    :param ciphertext: Input Ciphertext
+    :type ciphertext: Ciphertext
+    :return: Resulting ciphertext
+    :rtype: Ciphertext
+)pbdoc";
+
+const char* cc_IntBootAdd_docs = R"pbdoc(
+    Combines encrypted and unencrypted masked decryptions in 2-party interactive bootstrapping. It is the last step in the boostrapping.
+
+    :param ciphertext1: Encrypted masked decryption
+    :type ciphertext1: Ciphertext
+    :param ciphertext2: Unencrypted masked decryption
+    :type ciphertext2: Ciphertext
+    :return: Refreshed ciphertext
+    :rtype: Ciphertext
+)pbdoc";
+
+const char* cc_IntBootAdjustScale_docs = R"pbdoc(
+    Prepares a ciphertext for interactive bootstrapping.
+
+    :param ciphertext: Input ciphertext
+    :type ciphertext: Ciphertext
+    :return: Adjusted ciphertext
+    :rtype: Ciphertext
+)pbdoc";
 
 const char* cc_IntMPBootAdjustScale_docs = R"pbdoc(
     Threshold FHE: Prepare a ciphertext for Multi-Party Interactive Bootstrapping.

+ 3 - 0
src/include/docstrings/cryptoparameters_docs.h

@@ -63,6 +63,9 @@ const char* ccparams_doc = R"doc(
     :ivar int multiHopModSize: size of moduli used for PRE in the provable HRA setting
     :ivar EncryptionTechnique encryptionTechnique: STANDARD or EXTENDED mode for BFV encryption
     :ivar MultiplicationTechnique multiplicationTechnique: multiplication method in BFV: BEHZ, HPS, etc.
+    :ivar CKKSDataType ckksDataType: CKKS data type: real or complex. Noise flooding is only enabled for real values.
+    :ivar uint32_t compositeDegree: parameter to support high-precision CKKS RNS with small word sizes
+    :ivar uint32_t registerWordSize: parameter to support high-precision CKKS RNS with small word sizes
 )doc";
 
 const char* cc_GetScalingFactorReal_docs = R"pbdoc(

+ 76 - 1
src/lib/bindings.cpp

@@ -54,6 +54,13 @@ namespace py = pybind11;
 // disable the PYBIND11 template-based conversion for this type
 PYBIND11_MAKE_OPAQUE(std::map<uint32_t, EvalKey<DCRTPoly>>);
 
+inline std::shared_ptr<CryptoParametersRNS> GetParamsRNSChecked(const CryptoContext<DCRTPoly>& self, const std::string& func) {
+    auto ptr = std::dynamic_pointer_cast<CryptoParametersRNS>(self->GetCryptoParameters());
+    if (!ptr)
+        OPENFHE_THROW("Failed to cast to CryptoParametersRNS in " + func + "()");
+    return ptr;
+}
+
 template <typename T>
 void bind_parameters(py::module &m,const std::string name)
 {
@@ -90,6 +97,9 @@ void bind_parameters(py::module &m,const std::string name)
         .def("GetMultiplicationTechnique", &CCParams<T>::GetMultiplicationTechnique)
         .def("GetPRENumHops", &CCParams<T>::GetPRENumHops)
         .def("GetInteractiveBootCompressionLevel", &CCParams<T>::GetInteractiveBootCompressionLevel)
+        .def("GetCompositeDegree", &CCParams<T>::GetCompositeDegree)
+        .def("GetRegisterWordSize", &CCParams<T>::GetRegisterWordSize)
+        .def("GetCKKSDataType", &CCParams<T>::GetCKKSDataType)
         // setters
         .def("SetPlaintextModulus", &CCParams<T>::SetPlaintextModulus)
         .def("SetDigitSize", &CCParams<T>::SetDigitSize)
@@ -120,6 +130,9 @@ void bind_parameters(py::module &m,const std::string name)
         .def("SetMultiplicationTechnique", &CCParams<T>::SetMultiplicationTechnique)
         .def("SetPRENumHops", &CCParams<T>::SetPRENumHops)
         .def("SetInteractiveBootCompressionLevel", &CCParams<T>::SetInteractiveBootCompressionLevel)
+        .def("SetCompositeDegree", &CCParams<T>::SetCompositeDegree)
+        .def("SetRegisterWordSize", &CCParams<T>::SetRegisterWordSize)
+        .def("SetCKKSDataType", &CCParams<T>::SetCKKSDataType)
         .def("__str__",[](const CCParams<T> &params) {
             std::stringstream stream;
             stream << params;
@@ -155,6 +168,43 @@ void bind_crypto_context(py::module &m)
         .def("GetScalingTechnique",&GetScalingTechniqueWrapper)
         .def("GetDigitSize", &GetDigitSizeWrapper)
         .def("GetCyclotomicOrder", &CryptoContextImpl<DCRTPoly>::GetCyclotomicOrder, cc_GetCyclotomicOrder_docs)
+        .def("GetCKKSDataType", &CryptoContextImpl<DCRTPoly>::GetCKKSDataType)
+        .def("GetNoiseEstimate", [](CryptoContext<DCRTPoly>& self) {
+            return GetParamsRNSChecked(self, "GetNoiseEstimate")->GetNoiseEstimate();
+        })
+        .def("SetNoiseEstimate", [](CryptoContext<DCRTPoly>& self, double noiseEstimate) {
+            GetParamsRNSChecked(self, "SetNoiseEstimate")->SetNoiseEstimate(noiseEstimate);
+        }, py::arg("noiseEstimate"))
+        .def("GetMultiplicativeDepth", [](CryptoContext<DCRTPoly>& self) {
+            return GetParamsRNSChecked(self, "GetMultiplicativeDepth")->GetMultiplicativeDepth();
+        })
+        .def("SetMultiplicativeDepth", [](CryptoContext<DCRTPoly>& self, uint32_t multiplicativeDepth) {
+            GetParamsRNSChecked(self, "SetMultiplicativeDepth")->SetMultiplicativeDepth(multiplicativeDepth);
+        }, py::arg("multiplicativeDepth"))
+        .def("GetEvalAddCount", [](CryptoContext<DCRTPoly>& self) {
+            return GetParamsRNSChecked(self, "GetEvalAddCount")->GetEvalAddCount();
+        })
+        .def("SetEvalAddCount", [](CryptoContext<DCRTPoly>& self, uint32_t evalAddCount) {
+            GetParamsRNSChecked(self, "SetEvalAddCount")->SetEvalAddCount(evalAddCount);
+        }, py::arg("evalAddCount"))
+        .def("GetKeySwitchCount", [](CryptoContext<DCRTPoly>& self) {
+            return GetParamsRNSChecked(self, "GetKeySwitchCount")->GetKeySwitchCount();
+        })
+        .def("SetKeySwitchCount", [](CryptoContext<DCRTPoly>& self, uint32_t keySwitchCount) {
+            GetParamsRNSChecked(self, "SetKeySwitchCount")->SetKeySwitchCount(keySwitchCount);
+        }, py::arg("keySwitchCount"))
+        .def("GetPRENumHops", [](CryptoContext<DCRTPoly>& self) {
+            return GetParamsRNSChecked(self, "GetPRENumHops")->GetPRENumHops();
+        })
+        .def("SetPRENumHops", [](CryptoContext<DCRTPoly>& self, uint32_t PRENumHops) {
+            GetParamsRNSChecked(self, "SetPRENumHops")->SetPRENumHops(PRENumHops);
+        }, py::arg("PRENumHops"))
+        .def("GetRegisterWordSize", [](CryptoContext<DCRTPoly>& self) {
+            return GetParamsRNSChecked(self, "GetRegisterWordSize")->GetRegisterWordSize();
+        })
+        .def("GetCompositeDegree", [](CryptoContext<DCRTPoly>& self) {
+            return GetParamsRNSChecked(self, "GetCompositeDegree")->GetCompositeDegree();
+        })
         .def("Enable", static_cast<void (CryptoContextImpl<DCRTPoly>::*)(PKESchemeFeature)>(&CryptoContextImpl<DCRTPoly>::Enable), cc_Enable_docs,
              py::arg("feature"))
         .def("KeyGen", &CryptoContextImpl<DCRTPoly>::KeyGen, cc_KeyGen_docs)
@@ -572,6 +622,21 @@ void bind_crypto_context(py::module &m)
              py::arg("evalKey1"),
              py::arg("evalKey2"),
              py::arg("keyTag") = "")
+        .def("IntBootDecrypt",&CryptoContextImpl<DCRTPoly>::IntBootDecrypt,
+            cc_IntBootDecrypt_docs,
+            py::arg("privateKey"),
+            py::arg("ciphertext"))
+        .def("IntBootEncrypt",&CryptoContextImpl<DCRTPoly>::IntBootEncrypt,
+            cc_IntBootEncrypt_docs,
+            py::arg("publicKey"),
+            py::arg("ciphertext"))
+        .def("IntBootAdd",&CryptoContextImpl<DCRTPoly>::IntBootAdd,
+            cc_IntBootAdd_docs,
+            py::arg("ciphertext1"),
+            py::arg("ciphertext2"))
+        .def("IntBootAdjustScale",&CryptoContextImpl<DCRTPoly>::IntBootAdjustScale,
+            cc_IntBootAdjustScale_docs,
+            py::arg("ciphertext"))
         .def("IntMPBootAdjustScale",&CryptoContextImpl<DCRTPoly>::IntMPBootAdjustScale,
              cc_IntMPBootAdjustScale_docs,
              py::arg("ciphertext"))
@@ -971,12 +1036,16 @@ void bind_enums_and_constants(py::module &m)
        .value("FLEXIBLEAUTO", ScalingTechnique::FLEXIBLEAUTO)
        .value("FLEXIBLEAUTOEXT", ScalingTechnique::FLEXIBLEAUTOEXT)
        .value("NORESCALE", ScalingTechnique::NORESCALE)
+       .value("COMPOSITESCALINGAUTO", ScalingTechnique::COMPOSITESCALINGAUTO)
+       .value("COMPOSITESCALINGMANUAL", ScalingTechnique::COMPOSITESCALINGMANUAL)
        .value("INVALID_RS_TECHNIQUE", ScalingTechnique::INVALID_RS_TECHNIQUE);
     m.attr("FIXEDMANUAL") = py::cast(ScalingTechnique::FIXEDMANUAL);
     m.attr("FIXEDAUTO") = py::cast(ScalingTechnique::FIXEDAUTO);
     m.attr("FLEXIBLEAUTO") = py::cast(ScalingTechnique::FLEXIBLEAUTO);
     m.attr("FLEXIBLEAUTOEXT") = py::cast(ScalingTechnique::FLEXIBLEAUTOEXT);
     m.attr("NORESCALE") = py::cast(ScalingTechnique::NORESCALE);
+    m.attr("COMPOSITESCALINGAUTO") = py::cast(ScalingTechnique::COMPOSITESCALINGAUTO);
+    m.attr("COMPOSITESCALINGMANUAL") = py::cast(ScalingTechnique::COMPOSITESCALINGMANUAL);
     m.attr("INVALID_RS_TECHNIQUE") = py::cast(ScalingTechnique::INVALID_RS_TECHNIQUE);
 
     // Key Switching Techniques
@@ -1055,7 +1124,13 @@ void bind_enums_and_constants(py::module &m)
         .value("SLACK", COMPRESSION_LEVEL::SLACK);
     m.attr("COMPACT") = py::cast(COMPRESSION_LEVEL::COMPACT);
     m.attr("SLACK") = py::cast(COMPRESSION_LEVEL::SLACK);
-        
+
+    py::enum_<CKKSDataType>(m,"CKKSDataType")
+        .value("REAL", CKKSDataType::REAL)
+        .value("COMPLEX", CKKSDataType::COMPLEX);
+    m.attr("REAL") = py::cast(CKKSDataType::REAL);
+    m.attr("COMPLEX") = py::cast(CKKSDataType::COMPLEX);
+
     /* ---- CORE enums ---- */ 
     // Security Level
     py::enum_<SecurityLevel>(m,"SecurityLevel")