Explorar el Código

Added examples/pke/interactive-bootstrapping.py, added symbols to bindings and got rid of a warning from CMakeLists.txt (#228)

Co-authored-by: Dmitriy Suponitskiy <dsuponitskiy@dualitytech.com>
dsuponitskiy hace 7 meses
padre
commit
aec8efef74
Se han modificado 3 ficheros con 290 adiciones y 22 borrados
  1. 1 0
      CMakeLists.txt
  2. 245 0
      examples/pke/interactive-bootstrapping.py
  3. 44 22
      src/lib/bindings.cpp

+ 1 - 0
CMakeLists.txt

@@ -16,6 +16,7 @@ if(APPLE)
 endif()
 
 find_package(OpenFHE 1.3.0 REQUIRED)
+set(PYBIND11_FINDPYTHON ON)
 find_package(pybind11 REQUIRED)
 
 # "CMAKE_INTERPROCEDURAL_OPTIMIZATION ON" (ON is the default value) causes link failure. see

+ 245 - 0
examples/pke/interactive-bootstrapping.py

@@ -0,0 +1,245 @@
+from openfhe import *
+
+
+def main():
+    # the scaling technigue can be changed to FIXEDMANUAL, FIXEDAUTO, or FLEXIBLEAUTOEXT
+    ThresholdFHE(FLEXIBLEAUTO)
+    Chebyshev(FLEXIBLEAUTO)
+
+def ThresholdFHE(scaleTech):
+    # if scaleTech not in [FIXEDMANUAL, FIXEDAUTO, FLEXIBLEAUTOEXT]:
+    #     errMsg = "ERROR: Scaling technique is not supported!"
+    #     raise Exception(errMsg)
+
+    print(f"Threshold FHE example with Scaling Technique {scaleTech}")
+
+    parameters = CCParamsCKKSRNS()
+    # 1 extra level needs to be added for FIXED* modes (2 extra levels for FLEXIBLE* modes) to the multiplicative depth
+    # to support 2-party interactive bootstrapping
+    depth = 7
+    parameters.SetMultiplicativeDepth(depth)
+    parameters.SetScalingModSize(50)
+    parameters.SetBatchSize(16)
+    parameters.SetScalingTechnique(scaleTech)
+
+    cc = GenCryptoContext(parameters)
+    cc.Enable(PKE)
+    cc.Enable(LEVELEDSHE)
+    cc.Enable(ADVANCEDSHE)
+    cc.Enable(MULTIPARTY)
+
+    #############################################################
+    # Perform Key Generation Operation
+    #############################################################
+
+    print("Running key generation (used for source data)...")
+    print("Round 1 (party A) started.")
+
+    kp1 = cc.KeyGen()
+    evalMultKey = cc.KeySwitchGen(kp1.secretKey, kp1.secretKey)
+
+    print("Round 1 of key generation completed.")
+    #############################################################
+    print("Round 2 (party B) started.")
+    print("Joint public key for (s_a + s_b) is generated...")
+    kp2 = cc.MultipartyKeyGen(kp1.publicKey)
+
+    input = [-0.9, -0.8, -0.6, -0.4, -0.2, 0., 0.2, 0.4, 0.6, 0.8, 0.9]
+
+    # This plaintext only has 3 RNS limbs, the minimum needed to perform 2-party interactive bootstrapping for FLEXIBLEAUTO
+    plaintext1 = cc.MakeCKKSPackedPlaintext(input, 1, depth - 2)
+    ciphertext1 = cc.Encrypt(kp2.publicKey, plaintext1)
+
+    # INTERACTIVE BOOTSTRAPPING STARTS
+
+    # under the hood it reduces to two towers
+    ciphertext1 = cc.IntBootAdjustScale(ciphertext1)
+    print("IntBootAdjustScale Succeeded")
+
+    # masked decryption on the server: c0 = b + a*s0
+    ciphertextOutput1 = cc.IntBootDecrypt(kp1.secretKey, ciphertext1)
+    print("IntBootDecrypt on Server Succeeded")
+
+    ciphertext2 = ciphertext1.Clone()
+    ciphertext2.SetElements([ciphertext2.GetElements()[1]])
+
+    # masked decryption on the client: c1 = a*s1
+    ciphertextOutput2 = cc.IntBootDecrypt(kp2.secretKey, ciphertext2)
+    print("IntBootDecrypt on Client Succeeded")
+
+    # Encryption of masked decryption c1 = a*s1
+    ciphertextOutput2 = cc.IntBootEncrypt(kp2.publicKey, ciphertextOutput2)
+    print("IntBootEncrypt on Client Succeeded")
+
+    # Compute Enc(c1) + c0
+    ciphertextOutput = cc.IntBootAdd(ciphertextOutput2, ciphertextOutput1)
+    print("IntBootAdd on Server Succeeded")
+
+    # INTERACTIVE BOOTSTRAPPING ENDS
+
+    # distributed decryption
+    ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextOutput], kp1.secretKey)
+    ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextOutput], kp2.secretKey)
+
+    partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0]]
+    plaintextMultiparty = cc.MultipartyDecryptFusion(partialCiphertextVec)
+
+    plaintextMultiparty.SetLength(len(input))
+
+    print(f"Original plaintext \n\t {plaintext1.GetCKKSPackedValue()}")
+    print(f"Result after bootstrapping \n\t {plaintextMultiparty.GetCKKSPackedValue()}")
+
+def Chebyshev(scaleTech):
+#     if scaleTech not in [FIXEDMANUAL, FIXEDAUTO, FLEXIBLEAUTOEXT]:
+#         errMsg = "ERROR: Scaling technique is not supported!"
+#         raise Exception(errMsg)
+
+    print(f"Threshold FHE example with Scaling Technique {scaleTech}")
+    
+    parameters = CCParamsCKKSRNS()
+    # 1 extra level needs to be added for FIXED* modes (2 extra levels for FLEXIBLE* modes) to the multiplicative depth
+    # to support 2-party interactive bootstrapping
+    parameters.SetMultiplicativeDepth(8)
+    parameters.SetScalingModSize(50)
+    parameters.SetBatchSize(16)
+    parameters.SetScalingTechnique(scaleTech)
+
+    cc = GenCryptoContext(parameters)
+    # enable features that you wish to use
+    cc.Enable(PKE)
+    cc.Enable(LEVELEDSHE)
+    cc.Enable(ADVANCEDSHE)
+    cc.Enable(MULTIPARTY)
+
+    ############################################################
+    # Perform Key Generation Operation
+    ############################################################
+
+    print("Running key generation (used for source data)...")
+    print("Round 1 (party A) started.")
+
+    kp1 = cc.KeyGen()
+
+    evalMultKey = cc.KeySwitchGen(kp1.secretKey, kp1.secretKey)
+    cc.EvalSumKeyGen(kp1.secretKey)
+    evalSumKeys = cc.GetEvalSumKeyMap(kp1.secretKey.GetKeyTag())
+
+    print("Round 1 of key generation completed.")
+    ############################################################
+    print("Round 2 (party B) started.")
+    print("Joint public key for (s_a + s_b) is generated...")
+    kp2 = cc.MultipartyKeyGen(kp1.publicKey)
+
+    evalMultKey2 = cc.MultiKeySwitchGen(kp2.secretKey, kp2.secretKey, evalMultKey)
+
+    print("Joint evaluation multiplication key for (s_a + s_b) is generated...")
+    evalMultAB = cc.MultiAddEvalKeys(evalMultKey, evalMultKey2, kp2.publicKey.GetKeyTag())
+
+    print("Joint evaluation multiplication key (s_a + s_b) is transformed into s_b*(s_a + s_b)...")
+    evalMultBAB = cc.MultiMultEvalKey(kp2.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
+
+    evalSumKeysB = cc.MultiEvalSumKeyGen(kp2.secretKey, evalSumKeys, kp2.publicKey.GetKeyTag())
+
+    print("Joint evaluation summation key for (s_a + s_b) is generated...")
+    evalSumKeysJoin = cc.MultiAddEvalSumKeys(evalSumKeys, evalSumKeysB, kp2.publicKey.GetKeyTag())
+
+    cc.InsertEvalSumKey(evalSumKeysJoin)
+
+    print("Round 2 of key generation completed.")
+
+    print("Round 3 (party A) started.")
+    print("Joint key (s_a + s_b) is transformed into s_a*(s_a + s_b)...")
+    evalMultAAB = cc.MultiMultEvalKey(kp1.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
+
+    print("Computing the final evaluation multiplication key for (s_a + s_b)*(s_a + s_b)...")
+    evalMultFinal = cc.MultiAddEvalMultKeys(evalMultAAB, evalMultBAB, evalMultAB.GetKeyTag())
+
+    cc.InsertEvalMultKey([evalMultFinal])
+
+    print("Round 3 of key generation completed.")
+
+    input = [-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
+
+    coefficients = [1.0, 0.558971, 0.0, -0.0943712, 0.0, 0.0215023, 0.0, -0.00505348, 0.0, 0.00119324,
+                    0.0, -0.000281928, 0.0, 0.0000664347, 0.0, -0.0000148709]
+
+    a = -4
+    b = 4
+
+    plaintext1 = cc.MakeCKKSPackedPlaintext(input)
+
+    ciphertext1 = cc.Encrypt(kp2.publicKey, plaintext1)
+
+    # The Chebyshev series interpolation requires 6 levels
+    ciphertext1 = cc.EvalChebyshevSeries(ciphertext1, coefficients, a, b)
+    print("Ran Chebyshev interpolation")
+
+    # INTERACTIVE BOOTSTRAPPING STARTS
+
+    ciphertext1 = cc.IntBootAdjustScale(ciphertext1)
+    print("IntBootAdjustScale Succeeded")
+
+    # masked decryption on the server: c0 = b + a*s0
+    ciphertextOutput1 = cc.IntBootDecrypt(kp1.secretKey, ciphertext1)
+    print("IntBootDecrypt on Server Succeeded")
+
+    ciphertext2 = ciphertext1.Clone()
+    ciphertext2.SetElements([ciphertext2.GetElements()[1]])
+
+    # masked decryption on the client: c1 = a*s1
+    ciphertextOutput2 = cc.IntBootDecrypt(kp2.secretKey, ciphertext2)
+    print("IntBootDecrypt on Client Succeeded")
+
+    # Encryption of masked decryption c1 = a*s1
+    ciphertextOutput2 = cc.IntBootEncrypt(kp2.publicKey, ciphertextOutput2)
+    print("IntBootEncrypt on Client Succeeded")
+
+    # Compute Enc(c1) + c0
+    ciphertextOutput = cc.IntBootAdd(ciphertextOutput2, ciphertextOutput1)
+    print("IntBootAdd on Server Succeeded")
+
+    # INTERACTIVE BOOTSTRAPPING ENDS
+
+    # distributed decryption
+
+    ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextOutput], kp1.secretKey)
+
+    ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextOutput], kp2.secretKey)
+
+    partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0]]
+    plaintextMultiparty = cc.MultipartyDecryptFusion(partialCiphertextVec)
+
+    plaintextMultiparty.SetLength(len(input))
+
+    print(f"\n Original Plaintext #1: \n {plaintext1}")
+
+    print(f"\n Results of evaluating the polynomial with coefficients {coefficients} \n")
+    print(f"\n Ciphertext result: {plaintextMultiparty}")
+
+    print("\n Plaintext result: ( 0.0179885, 0.0474289, 0.119205, 0.268936, 0.5, 0.731064, 0.880795, 0.952571, 0.982011 ) \n")
+
+    print("\n Exact result: ( 0.0179862, 0.0474259, 0.119203, 0.268941, 0.5, 0.731059, 0.880797, 0.952574, 0.982014 ) \n")
+
+    print("\n Another round of Chebyshev interpolation after interactive bootstrapping: \n")
+
+    ciphertextOutput = cc.EvalChebyshevSeries(ciphertextOutput, coefficients, a, b)
+    print("Ran Chebyshev interpolation")
+
+    # distributed decryption
+
+    ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextOutput], kp1.secretKey)
+
+    ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextOutput], kp2.secretKey)
+
+    partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0]]
+    plaintextMultiparty = cc.MultipartyDecryptFusion(partialCiphertextVec)
+
+    plaintextMultiparty.SetLength(len(input))
+
+    print(f"\n Ciphertext result: {plaintextMultiparty}")
+
+    print("\n Plaintext result: ( 0.504497, 0.511855, 0.529766, 0.566832, 0.622459, 0.675039, 0.706987, 0.721632, 0.727508 )")
+
+
+if __name__ == "__main__":
+    main()

+ 44 - 22
src/lib/bindings.cpp

@@ -61,6 +61,10 @@ inline std::shared_ptr<CryptoParametersRNS> GetParamsRNSChecked(const CryptoCont
     return ptr;
 }
 
+void bind_DCRTPoly(py::module &m) {
+  py::class_<DCRTPoly>(m, "DCRTPoly").def(py::init<>());
+}
+
 template <typename T>
 void bind_parameters(py::module &m,const std::string name)
 {
@@ -1312,28 +1316,45 @@ void bind_encodings(py::module &m)
 
 void bind_ciphertext(py::module &m)
 {
-    py::class_<CiphertextImpl<DCRTPoly>, std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext")
-        .def(py::init<>())
-        .def("__add__", [](const Ciphertext<DCRTPoly> &a, const Ciphertext<DCRTPoly> &b)
-             {return a + b; },py::is_operator(),pybind11::keep_alive<0, 1>())
-       // .def(py::self + py::self);
-    // .def("GetDepth", &CiphertextImpl<DCRTPoly>::GetDepth)
-    // .def("SetDepth", &CiphertextImpl<DCRTPoly>::SetDepth)
-     .def("GetLevel", &CiphertextImpl<DCRTPoly>::GetLevel,
-        ctx_GetLevel_docs)
-     .def("SetLevel", &CiphertextImpl<DCRTPoly>::SetLevel,
-        ctx_SetLevel_docs,
-        py::arg("level"))
-     .def("Clone", &CiphertextImpl<DCRTPoly>::Clone)
-     .def("RemoveElement", &RemoveElementWrapper, cc_RemoveElement_docs)
-    // .def("GetHopLevel", &CiphertextImpl<DCRTPoly>::GetHopLevel)
-    // .def("SetHopLevel", &CiphertextImpl<DCRTPoly>::SetHopLevel)
-    // .def("GetScalingFactor", &CiphertextImpl<DCRTPoly>::GetScalingFactor)
-    // .def("SetScalingFactor", &CiphertextImpl<DCRTPoly>::SetScalingFactor)
-     .def("GetSlots", &CiphertextImpl<DCRTPoly>::GetSlots)
-     .def("SetSlots", &CiphertextImpl<DCRTPoly>::SetSlots)
-     .def("GetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::GetNoiseScaleDeg)
-     .def("SetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::SetNoiseScaleDeg);
+  py::class_<CiphertextImpl<DCRTPoly>,
+             std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext")
+      .def(py::init<>())
+      .def(
+          "__add__",
+          [](const Ciphertext<DCRTPoly> &a, const Ciphertext<DCRTPoly> &b) {
+            return a + b;
+          },
+          py::is_operator(), pybind11::keep_alive<0, 1>())
+      // .def(py::self + py::self);
+      // .def("GetDepth", &CiphertextImpl<DCRTPoly>::GetDepth)
+      // .def("SetDepth", &CiphertextImpl<DCRTPoly>::SetDepth)
+      .def("GetLevel", &CiphertextImpl<DCRTPoly>::GetLevel, ctx_GetLevel_docs)
+      .def("SetLevel", &CiphertextImpl<DCRTPoly>::SetLevel, ctx_SetLevel_docs,
+           py::arg("level"))
+      .def("Clone", &CiphertextImpl<DCRTPoly>::Clone)
+      .def("RemoveElement", &RemoveElementWrapper, cc_RemoveElement_docs)
+      // .def("GetHopLevel", &CiphertextImpl<DCRTPoly>::GetHopLevel)
+      // .def("SetHopLevel", &CiphertextImpl<DCRTPoly>::SetHopLevel)
+      // .def("GetScalingFactor", &CiphertextImpl<DCRTPoly>::GetScalingFactor)
+      // .def("SetScalingFactor", &CiphertextImpl<DCRTPoly>::SetScalingFactor)
+      .def("GetSlots", &CiphertextImpl<DCRTPoly>::GetSlots)
+      .def("SetSlots", &CiphertextImpl<DCRTPoly>::SetSlots)
+      .def("GetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::GetNoiseScaleDeg)
+      .def("SetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::SetNoiseScaleDeg)
+      .def("GetElements", [](const CiphertextImpl<DCRTPoly>& self) -> const std::vector<DCRTPoly> & {
+            return self.GetElements();
+          },
+          py::return_value_policy::reference_internal)
+      .def("GetElementsMutable", [](CiphertextImpl<DCRTPoly>& self) -> std::vector<DCRTPoly> & {
+            return self.GetElements();
+          },
+          py::return_value_policy::reference_internal)
+      .def("SetElements", [](CiphertextImpl<DCRTPoly>& self, const std::vector<DCRTPoly> &elems) {
+             self.SetElements(elems);
+           })
+      .def("SetElementsMove", [](CiphertextImpl<DCRTPoly>& self, std::vector<DCRTPoly> &&elems) {
+             self.SetElements(std::move(elems));
+           });
 }
 
 void bind_schemes(py::module &m){
@@ -1400,6 +1421,7 @@ PYBIND11_MODULE(openfhe, m)
 {
     m.doc() = "Open-Source Fully Homomorphic Encryption Library";
     // binfhe library
+    bind_DCRTPoly(m);
     bind_binfhe_enums(m);
     bind_binfhe_context(m);
     bind_binfhe_keys(m);