Browse Source

autorescale + manual + hybridkeyswitch

Rener Oliveira (Ubuntu WSL) 1 year ago
parent
commit
ccac80b292
3 changed files with 192 additions and 5 deletions
  1. 4 0
      include/pke/cryptocontext_wrapper.h
  2. 18 1
      src/bindings.cpp
  3. 170 4
      src/pke/examples/advanced-real-numbers.py

+ 4 - 0
include/pke/cryptocontext_wrapper.h

@@ -19,4 +19,8 @@ Plaintext MakeCKKSPackedPlaintextWrapper(std::shared_ptr<CryptoContextImpl<DCRTP
             const std::shared_ptr<ParmType> params,
             usint slots);
 
+// std::shared_ptr<std::vector<Element>> EvalFastRotationPrecomputeWrapper(CryptoContext<DCRTPoly> &self,ConstCiphertext<Element> ciphertext) const {
+//     return self->EvalFastRotationPrecompute(ciphertext);
+// }
+
 #endif // OPENFHE_CRYPTOCONTEXT_BINDINGS_H

+ 18 - 1
src/bindings.cpp

@@ -1,5 +1,6 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
+#include <pybind11/operators.h>
 #include <pybind11/iostream.h>
 #include <iostream>
 #include "openfhe.h"
@@ -38,6 +39,10 @@ void bind_parameters(py::module &m)
         .def("SetScalingModSize", &CCParams<CryptoContextCKKSRNS>::SetScalingModSize)
         .def("SetBatchSize", &CCParams<CryptoContextCKKSRNS>::SetBatchSize)
         .def("SetScalingTechnique", &CCParams<CryptoContextCKKSRNS>::SetScalingTechnique)
+        .def("SetNumLargeDigits", &CCParams<CryptoContextCKKSRNS>::SetNumLargeDigits)
+        .def("SetKeySwitchTechnique", &CCParams<CryptoContextCKKSRNS>::SetKeySwitchTechnique)
+        .def("SetFirstModSize", &CCParams<CryptoContextCKKSRNS>::SetFirstModSize)
+        .def("SetDigitSize", &CCParams<CryptoContextCKKSRNS>::SetDigitSize)
         // getters
         .def("GetPlaintextModulus", &CCParams<CryptoContextCKKSRNS>::GetPlaintextModulus)
         .def("GetMultiplicativeDepth", &CCParams<CryptoContextCKKSRNS>::GetMultiplicativeDepth)
@@ -52,6 +57,8 @@ void bind_crypto_context(py::module &m)
         .def(py::init<>())
         .def("GetKeyGenLevel", &CryptoContextImpl<DCRTPoly>::GetKeyGenLevel)
         .def("SetKeyGenLevel", &CryptoContextImpl<DCRTPoly>::SetKeyGenLevel)
+        //.def("GetScheme",&CryptoContextImpl<DCRTPoly>::GetScheme)
+        //.def("GetCryptoParameters", &CryptoContextImpl<DCRTPoly>::GetCryptoParameters)
         .def("GetRingDimension", &CryptoContextImpl<DCRTPoly>::GetRingDimension)
         .def("Enable", static_cast<void (CryptoContextImpl<DCRTPoly>::*)(PKESchemeFeature)>(&CryptoContextImpl<DCRTPoly>::Enable), "Enable a feature for the CryptoContext")
         .def("KeyGen", &CryptoContextImpl<DCRTPoly>::KeyGen, "Generate a key pair with public and private keys")
@@ -67,6 +74,8 @@ void bind_crypto_context(py::module &m)
             py::arg("params") = py::none(),
             py::arg("slots") = 0)
         .def("EvalRotate", &CryptoContextImpl<DCRTPoly>::EvalRotate, "Rotate a ciphertext")
+        //.def("EvalFastRotationPrecompute", &CryptoContextImpl<DCRTPoly>::EvalFastRotationPrecompute, py::return_value_policy::take_ownership)
+        //.def("EvalFastRotation", &CryptoContextImpl<DCRTPoly>::EvalFastRotation)
         .def("Encrypt", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(const PublicKey<DCRTPoly>, Plaintext) const>(&CryptoContextImpl<DCRTPoly>::Encrypt),
              "Encrypt a plaintext using public key")
         .def("EvalAdd", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalAdd), "Add two ciphertexts")
@@ -74,6 +83,7 @@ void bind_crypto_context(py::module &m)
         .def("EvalSub", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalSub), "Subtract two ciphertexts")
         .def("EvalMult", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalMult), "Multiply two ciphertexts")
         .def("EvalMult", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, double) const>(&CryptoContextImpl<DCRTPoly>::EvalMult), "Multiply a ciphertext with a scalar")
+        .def("Rescale", &CryptoContextImpl<DCRTPoly>::Rescale, "Rescale a ciphertext")
         .def_static(
             "ClearEvalMultKeys", []()
             { CryptoContextImpl<DCRTPoly>::ClearEvalMultKeys(); },
@@ -158,6 +168,11 @@ void bind_enums_and_constants(py::module &m)
        .value("FLEXIBLEAUTOEXT", ScalingTechnique::FLEXIBLEAUTOEXT)
        .value("NORESCALE", ScalingTechnique::NORESCALE)
        .value("INVALID_RS_TECHNIQUE", ScalingTechnique::INVALID_RS_TECHNIQUE);
+    // Key Switching Techniques
+    py::enum_<KeySwitchTechnique>(m, "KeySwitchTechnique")
+        .value("INVALID_KS_TECH", KeySwitchTechnique::INVALID_KS_TECH)
+        .value("BV", KeySwitchTechnique::BV)
+        .value("HYBRID", KeySwitchTechnique::HYBRID);
 
     //Parameters Type
     using ParmType = typename DCRTPoly::Params;
@@ -204,7 +219,8 @@ void bind_encodings(py::module &m)
 void bind_ciphertext(py::module &m)
 {
     py::class_<CiphertextImpl<DCRTPoly>, std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext")
-        .def(py::init<>());
+        .def(py::init<>())
+        .def(py::self + py::self);
     // .def("GetDepth", &CiphertextImpl<DCRTPoly>::GetDepth)
     // .def("SetDepth", &CiphertextImpl<DCRTPoly>::SetDepth)
     // .def("GetLevel", &CiphertextImpl<DCRTPoly>::GetLevel)
@@ -217,6 +233,7 @@ void bind_ciphertext(py::module &m)
     // .def("SetSlots", &CiphertextImpl<DCRTPoly>::SetSlots);
 }
 
+
 PYBIND11_MODULE(openfhe, m)
 {
     m.doc() = "Open-Source Fully Homomorphic Encryption Library";

+ 170 - 4
src/pke/examples/advanced-real-numbers.py

@@ -1,10 +1,11 @@
 from openfhe import *
+import time # to enable TIC-TOC timing measurements
 
-def AutmaticRescaleDemo(scalTech):
+def AutomaticRescaleDemo(scalTech):
     if(scalTech == ScalingTechnique.FLEXIBLEAUTO):
-        print("\n\n\n===== FlexibleAutoDemo =============\n") 
+        print("\n\n\n ===== FlexibleAutoDemo =============\n") 
     else:
-         print("\n\n\n===== FixedAutoDemo =============\n")
+         print("\n\n\n ===== FixedAutoDemo =============\n")
 
     batchSize = 8
     parameters = CCParamsCKKSRNS()
@@ -53,5 +54,170 @@ def AutmaticRescaleDemo(scalTech):
     result.SetLength(batchSize)
     print(f"Result: {result}")
 
+def ManualRescaleDemo(ScalingTechnique):
+    print("\n\n\n ===== FixedManualDemo =============\n")
+    
+    batchSize = 8
+    parameters = CCParamsCKKSRNS()
+    parameters.SetMultiplicativeDepth(5)
+    parameters.SetScalingModSize(50)
+    parameters.SetBatchSize(batchSize)
+
+    cc = GenCryptoContext(parameters)
+
+    print(f"CKKS scheme is using ring dimension {cc.GetRingDimension()}\n")
+    
+    cc.Enable(PKESchemeFeature.PKE)
+    cc.Enable(PKESchemeFeature.KEYSWITCH)
+    cc.Enable(PKESchemeFeature.LEVELEDSHE)
+
+    keys = cc.KeyGen()
+    cc.EvalMultKeyGen(keys.secretKey)
+
+    # Input
+    x = [1.0, 1.01, 1.02, 1.03, 1.04, 1.05, 1.06, 1.07]
+    ptxt = cc.MakeCKKSPackedPlaintext(x)
+
+    print(f"Input x: {ptxt}")
+
+    c = cc.Encrypt(keys.publicKey,ptxt)
+
+    # Computing f(x) = x^18 + x^9 + 1
+    #
+    # Compare the following with the corresponding code
+    # for FLEXIBLEAUTO. Here we need to track the depth of ciphertexts
+    # and call Rescale whenever needed. In this instance it's still
+    # not hard to do so, but this can be quite tedious in other
+    # complicated computations. (e.g. in bootstrapping)
+    #
+    #
+
+    # x^2
+    c2_depth_2 = cc.EvalMult(c, c)
+    c2_depth_1 = cc.Rescale(c2_depth_2)
+    # x^4
+    c4_depth2 = cc.EvalMult(c2_depth_1, c2_depth_1)
+    c4_depth1 = cc.Rescale(c4_depth2)
+    # x^8
+    c8_depth2 = cc.EvalMult(c4_depth1, c4_depth1)
+    c8_depth1 = cc.Rescale(c8_depth2)
+    # x^16
+    c16_depth2 = cc.EvalMult(c8_depth1, c8_depth1)
+    c16_depth1 = cc.Rescale(c16_depth2)
+    # x^9
+    c9_depth2 = cc.EvalMult(c8_depth1, c)
+    # x^18
+    c18_depth2 = cc.EvalMult(c16_depth1, c2_depth_1)
+    # Final result
+    cRes_depth2 = cc.EvalAdd(cc.EvalAdd(c18_depth2, c9_depth2), 1.0)
+    cRes_depth1 = cc.Rescale(cRes_depth2)
+
+    result = Decrypt(cRes_depth1,keys.secretKey)
+    result.SetLength(batchSize)
+    print("x^18 + x^9 + 1 = ", result)
+
+def HybridKeySwitchingDemo1():
+    
+    dnum = 2
+    batchSize = 8
+    parameters = CCParamsCKKSRNS()
+    parameters.SetMultiplicativeDepth(5)
+    parameters.SetScalingModSize(50)
+    parameters.SetBatchSize(batchSize)
+    parameters.SetScalingTechnique(ScalingTechnique.FLEXIBLEAUTO)
+    parameters.SetNumLargeDigits(dnum)
+
+    cc = GenCryptoContext(parameters)
+
+    print(f"CKKS scheme is using ring dimension {cc.GetRingDimension()}\n")
+
+    print(f"- Using HYBRID key switching with {dnum} digits\n")
+
+    cc.Enable(PKESchemeFeature.PKE)
+    cc.Enable(PKESchemeFeature.KEYSWITCH)
+    cc.Enable(PKESchemeFeature.LEVELEDSHE)
+
+    keys = cc.KeyGen()
+    cc.EvalRotateKeyGen(keys.secretKey,[1,-2])
+
+    # Input
+    x = [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7]
+    ptxt = cc.MakeCKKSPackedPlaintext(x)
+
+    print(f"Input x: {ptxt}")
+
+    c = cc.Encrypt(keys.publicKey,ptxt)
+
+    t = time.time()
+    cRot1 = cc.EvalRotate(c,1)
+    cRot2 = cc.EvalRotate(cRot1,-2)
+    time2digits = time.time() - t
+
+    result = Decrypt(cRot2,keys.secretKey)
+    result.SetLength(batchSize)
+    print(f"x rotate by -1 = {result}")
+    print(f" - 2 rotations with HYBRID (2 digits) took {time2digits*1000} ms")
+
+
+def HybridKeySwitchingDemo2():
+    print("\n\n\n ===== HybridKeySwitchingDemo2 =============\n")
+    dnum = 3
+    batchSize = 8
+    parameters = CCParamsCKKSRNS()
+    parameters.SetMultiplicativeDepth(5)
+    parameters.SetScalingModSize(50)
+    parameters.SetBatchSize(batchSize)
+    parameters.SetScalingTechnique(ScalingTechnique.FLEXIBLEAUTO)
+    parameters.SetNumLargeDigits(dnum)
+
+    cc = GenCryptoContext(parameters)
+
+    # Compare the ring dimension in this demo to the one in the previous
+    print(f"CKKS scheme is using ring dimension {cc.GetRingDimension()}\n")
+
+    print(f"- Using HYBRID key switching with {dnum} digits\n")
+
+    cc.Enable(PKESchemeFeature.PKE)
+    cc.Enable(PKESchemeFeature.KEYSWITCH)
+    cc.Enable(PKESchemeFeature.LEVELEDSHE)
+
+    keys = cc.KeyGen()
+    cc.EvalRotateKeyGen(keys.secretKey,[1,-2])
+
+    # Input
+    x = [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7]
+    ptxt = cc.MakeCKKSPackedPlaintext(x)
+
+    print(f"Input x: {ptxt}")
+
+    c = cc.Encrypt(keys.publicKey,ptxt)
+
+    t = time.time()
+    cRot1 = cc.EvalRotate(c,1)
+    cRot2 = cc.EvalRotate(cRot1,-2)
+    time3digits = time.time() - t
+    # The runtime here is smaller than the previous demo
+
+    result = Decrypt(cRot2,keys.secretKey)
+    result.SetLength(batchSize)
+    print(f"x rotate by -1 = {result}")
+    print(f" - 2 rotations with HYBRID (3 digits) took {time3digits*1000} ms")
+
+def FastRotationDemo1():
+    pass
+
+def FastRotationDemo2():
+    pass
+
+def main():
+    AutomaticRescaleDemo(ScalingTechnique.FLEXIBLEAUTO)
+    AutomaticRescaleDemo(ScalingTechnique.FIXEDAUTO)
+    ManualRescaleDemo(ScalingTechnique.FIXEDMANUAL)
+    HybridKeySwitchingDemo1()
+    HybridKeySwitchingDemo2()
+    FastRotationDemo1()
+    FastRotationDemo2()
+
+if __name__ == "__main__":
+    main()
 
-AutmaticRescaleDemo(ScalingTechnique.FLEXIBLEAUTO)