Browse Source

EvalFastRotation Wrappers

Rener Oliveira (Ubuntu WSL) 1 year ago
parent
commit
2ffe6f9b05

+ 8 - 3
include/pke/cryptocontext_wrapper.h

@@ -19,8 +19,13 @@ Plaintext MakeCKKSPackedPlaintextWrapper(std::shared_ptr<CryptoContextImpl<DCRTP
             const std::shared_ptr<ParmType> params,
             usint slots);
 
-// std::shared_ptr<std::vector<Element>> EvalFastRotationPrecomputeWrapper(CryptoContext<DCRTPoly> &self,ConstCiphertext<Element> ciphertext) const {
-//     return self->EvalFastRotationPrecompute(ciphertext);
-// }
+Ciphertext<DCRTPoly> EvalFastRotationPrecomputeWrapper(CryptoContext<DCRTPoly>& self,
+                                                        ConstCiphertext<DCRTPoly> ciphertext);
+
+Ciphertext<DCRTPoly> EvalFastRotationWrapper(CryptoContext<DCRTPoly>& self,
+                                            ConstCiphertext<DCRTPoly> ciphertext,
+                                              const usint index,
+                                              const usint m,
+                                              ConstCiphertext<DCRTPoly> digits);
 
 #endif // OPENFHE_CRYPTOCONTEXT_BINDINGS_H

+ 5 - 4
src/bindings.cpp

@@ -74,8 +74,8 @@ void bind_crypto_context(py::module &m)
             py::arg("params") = py::none(),
             py::arg("slots") = 0)
         .def("EvalRotate", &CryptoContextImpl<DCRTPoly>::EvalRotate, "Rotate a ciphertext")
-        //.def("EvalFastRotationPrecompute", &CryptoContextImpl<DCRTPoly>::EvalFastRotationPrecompute, py::return_value_policy::take_ownership)
-        //.def("EvalFastRotation", &CryptoContextImpl<DCRTPoly>::EvalFastRotation)
+        .def("EvalFastRotationPrecompute", &EvalFastRotationPrecomputeWrapper)
+        .def("EvalFastRotation", &EvalFastRotationWrapper)
         .def("Encrypt", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(const PublicKey<DCRTPoly>, Plaintext) const>(&CryptoContextImpl<DCRTPoly>::Encrypt),
              "Encrypt a plaintext using public key")
         .def("EvalAdd", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalAdd), "Add two ciphertexts")
@@ -220,7 +220,9 @@ void bind_ciphertext(py::module &m)
 {
     py::class_<CiphertextImpl<DCRTPoly>, std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext")
         .def(py::init<>())
-        .def(py::self + py::self);
+        .def("__add__", [](const Ciphertext<DCRTPoly> &a, const Ciphertext<DCRTPoly> &b)
+             {return a + b; },py::is_operator(),pybind11::keep_alive<0, 1>());
+       // .def(py::self + py::self);
     // .def("GetDepth", &CiphertextImpl<DCRTPoly>::GetDepth)
     // .def("SetDepth", &CiphertextImpl<DCRTPoly>::SetDepth)
     // .def("GetLevel", &CiphertextImpl<DCRTPoly>::GetLevel)
@@ -233,7 +235,6 @@ void bind_ciphertext(py::module &m)
     // .def("SetSlots", &CiphertextImpl<DCRTPoly>::SetSlots);
 }
 
-
 PYBIND11_MODULE(openfhe, m)
 {
     m.doc() = "Open-Source Fully Homomorphic Encryption Library";

+ 16 - 1
src/pke/cryptocontext_wrapper.cpp

@@ -21,4 +21,19 @@ Plaintext MakeCKKSPackedPlaintextWrapper(std::shared_ptr<CryptoContextImpl<DCRTP
                 std::vector<std::complex<double>> complexValue(value.size());
                 std::transform(value.begin(), value.end(), complexValue.begin(),
                        [](float da) { return std::complex<double>(da); });
-                return self->MakeCKKSPackedPlaintext(complexValue, depth, level, params, slots); }
+                return self->MakeCKKSPackedPlaintext(complexValue, depth, level, params, slots); }
+
+Ciphertext<DCRTPoly> EvalFastRotationPrecomputeWrapper(CryptoContext<DCRTPoly> &self,ConstCiphertext<DCRTPoly> ciphertext) {
+    std::shared_ptr<std::vector<DCRTPoly>> precomp = self->EvalFastRotationPrecompute(ciphertext);
+    std::vector<DCRTPoly> elements = *(precomp.get());
+    CiphertextImpl<DCRTPoly> cipherdigits = CiphertextImpl<DCRTPoly>(self);
+    std::shared_ptr<CiphertextImpl<DCRTPoly>> cipherdigitsPtr = std::make_shared<CiphertextImpl<DCRTPoly>>(cipherdigits);
+    cipherdigitsPtr->SetElements(elements);
+    return cipherdigitsPtr;
+}
+Ciphertext<DCRTPoly> EvalFastRotationWrapper(CryptoContext<DCRTPoly>& self,ConstCiphertext<DCRTPoly> ciphertext, const usint index, const usint m,ConstCiphertext<DCRTPoly> digits) {
+    
+        std::vector<DCRTPoly> digitsElements = digits->GetElements();
+        std::shared_ptr<std::vector<DCRTPoly>> digitsElementsPtr = std::make_shared<std::vector<DCRTPoly>>(digitsElements);
+        return self->EvalFastRotation(ciphertext, index, m, digitsElementsPtr);
+    }

+ 153 - 2
src/pke/examples/advanced-real-numbers.py

@@ -204,10 +204,161 @@ def HybridKeySwitchingDemo2():
     print(f" - 2 rotations with HYBRID (3 digits) took {time3digits*1000} ms")
 
 def FastRotationDemo1():
-    pass
+    print("\n\n\n ===== FastRotationDemo1 =============\n")
+    batchSize = 8
+    parameters = CCParamsCKKSRNS()
+    parameters.SetMultiplicativeDepth(5)
+    parameters.SetScalingModSize(50)
+    parameters.SetBatchSize(batchSize)
+
+    cc = GenCryptoContext(parameters)
+
+    N = cc.GetRingDimension()
+    print(f"CKKS scheme is using ring dimension {N}\n")
+
+    cc.Enable(PKESchemeFeature.PKE)
+    cc.Enable(PKESchemeFeature.KEYSWITCH)
+    cc.Enable(PKESchemeFeature.LEVELEDSHE)
+
+    keys = cc.KeyGen()
+    cc.EvalRotateKeyGen(keys.secretKey,[1,2,3,4,5,6,7])
+
+    # Input
+    x = [0, 0, 0, 0, 0, 0, 0, 1]
+    ptxt = cc.MakeCKKSPackedPlaintext(x)
+
+    print(f"Input x: {ptxt}")
+
+    c = cc.Encrypt(keys.publicKey,ptxt)
+
+    # First, we perform 7 regular (non-hoisted) rotations
+    # and measure the runtime
+    t = time.time()
+    cRot1 = cc.EvalRotate(c,1)
+    cRot2 = cc.EvalRotate(c,2)
+    cRot3 = cc.EvalRotate(c,3)
+    cRot4 = cc.EvalRotate(c,4)
+    cRot5 = cc.EvalRotate(c,5)
+    cRot6 = cc.EvalRotate(c,6)
+    cRot7 = cc.EvalRotate(c,7)
+    timeNoHoisting = time.time() - t
+
+    cResNoHoist = c + cRot1 + cRot2 + cRot3 + cRot4 + cRot5 + cRot6 + cRot7
+
+    # M is the cyclotomic order and we need it to call EvalFastRotation
+    M = 2*N
+
+    # Then, we perform 7 rotations with hoisting.
+    t = time.time()
+    cPrecomp = cc.EvalFastRotationPrecompute(c)
+    cRot1 = cc.EvalFastRotation(c, 1, M, cPrecomp)
+    cRot2 = cc.EvalFastRotation(c, 2, M, cPrecomp)
+    cRot3 = cc.EvalFastRotation(c, 3, M, cPrecomp)
+    cRot4 = cc.EvalFastRotation(c, 4, M, cPrecomp)
+    cRot5 = cc.EvalFastRotation(c, 5, M, cPrecomp)
+    cRot6 = cc.EvalFastRotation(c, 6, M, cPrecomp)
+    cRot7 = cc.EvalFastRotation(c, 7, M, cPrecomp)
+    timeHoisting = time.time() - t
+    # The time with hoisting should be faster than without hoisting.
+
+    cResHoist = c + cRot1 + cRot2 + cRot3 + cRot4 + cRot5 + cRot6 + cRot7
+    
+    result = Decrypt(cResNoHoist,keys.secretKey)
+    result.SetLength(batchSize)
+    print(f"Result without hoisting: {result}")
+    print(f" - 7 rotations without hoisting took {timeNoHoisting*1000} ms")
+
+    
+    result = Decrypt(cResHoist,keys.secretKey)
+    result.SetLength(batchSize)
+    print(f"Result with hoisting: {result}")
+    print(f" - 7 rotations with hoisting took {timeHoisting*1000} ms")
+
+
+
 
 def FastRotationDemo2():
-    pass
+    print("\n\n\n ===== FastRotationDemo2 =============\n")
+
+    digitSize = 3
+    batchSize = 8
+
+    parameters = CCParamsCKKSRNS()
+    parameters.SetMultiplicativeDepth(1)
+    parameters.SetScalingModSize(50)
+    parameters.SetBatchSize(batchSize)
+    parameters.SetScalingTechnique(ScalingTechnique.FLEXIBLEAUTO)
+    parameters.SetKeySwitchTechnique(KeySwitchTechnique.BV)
+    parameters.SetFirstModSize(60)
+    parameters.SetDigitSize(digitSize)
+
+    cc = GenCryptoContext(parameters)
+
+    N = cc.GetRingDimension()
+    print(f"CKKS scheme is using ring dimension {N}\n")
+
+    cc.Enable(PKESchemeFeature.PKE)
+    cc.Enable(PKESchemeFeature.KEYSWITCH)
+    cc.Enable(PKESchemeFeature.LEVELEDSHE)
+
+    keys = cc.KeyGen()
+    cc.EvalRotateKeyGen(keys.secretKey,[1,2,3,4,5,6,7])
+
+    # Input
+    x = [0, 0, 0, 0, 0, 0, 0, 1]
+    ptxt = cc.MakeCKKSPackedPlaintext(x)
+
+    print(f"Input x: {ptxt}")
+
+    c = cc.Encrypt(keys.publicKey,ptxt)
+
+    # First, we perform 7 regular (non-hoisted) rotations
+    # and measure the runtime
+    t = time.time()
+    cRot1 = cc.EvalRotate(c,1)
+    cRot2 = cc.EvalRotate(c,2)
+    cRot3 = cc.EvalRotate(c,3)
+    cRot4 = cc.EvalRotate(c,4)
+    cRot5 = cc.EvalRotate(c,5)
+    cRot6 = cc.EvalRotate(c,6)
+    cRot7 = cc.EvalRotate(c,7)
+    timeNoHoisting = time.time() - t
+
+    cResNoHoist = c + cRot1 + cRot2 + cRot3 + cRot4 + cRot5 + cRot6 + cRot7
+
+    # M is the cyclotomic order and we need it to call EvalFastRotation
+    M = 2*N
+
+    # Then, we perform 7 rotations with hoisting.
+    t = time.time()
+    cPrecomp = cc.EvalFastRotationPrecompute(c)
+    cRot1 = cc.EvalFastRotation(c, 1, M, cPrecomp)
+    cRot2 = cc.EvalFastRotation(c, 2, M, cPrecomp)
+    cRot3 = cc.EvalFastRotation(c, 3, M, cPrecomp)
+    cRot4 = cc.EvalFastRotation(c, 4, M, cPrecomp)
+    cRot5 = cc.EvalFastRotation(c, 5, M, cPrecomp)
+    cRot6 = cc.EvalFastRotation(c, 6, M, cPrecomp)
+    cRot7 = cc.EvalFastRotation(c, 7, M, cPrecomp)
+    timeHoisting = time.time() - t
+    # The time with hoisting should be faster than without hoisting.
+    # Also, the benefits from hoisting should be more pronounced in this
+    # case because we're using BV. Of course, we also observe less
+    # accurate results than when using HYBRID, because of using
+    # digitSize = 10 (Users can decrease digitSize to see the accuracy
+    # increase, and performance decrease).
+
+    cResHoist = c + cRot1 + cRot2 + cRot3 + cRot4 + cRot5 + cRot6 + cRot7
+
+    result = Decrypt(cResNoHoist,keys.secretKey)
+    result.SetLength(batchSize)
+    print(f"Result without hoisting: {result}")
+    print(f" - 7 rotations without hoisting took {timeNoHoisting*1000} ms")
+
+    result = Decrypt(cResHoist,keys.secretKey)
+    result.SetLength(batchSize)
+    print(f"Result with hoisting: {result}")
+    print(f" - 7 rotations with hoisting took {timeHoisting*1000} ms")
+
 
 def main():
     AutomaticRescaleDemo(ScalingTechnique.FLEXIBLEAUTO)