Browse Source

Adding new functions

Hovsep Papoyan 8 months ago
parent
commit
efc0829927
6 changed files with 193 additions and 0 deletions
  1. 3 0
      build.rs
  2. 81 0
      src/CryptoContext.cc
  3. 33 0
      src/CryptoContext.h
  4. 19 0
      src/EvalKey.cc
  5. 25 0
      src/EvalKey.h
  6. 32 0
      src/lib.rs

+ 3 - 0
build.rs

@@ -8,6 +8,7 @@ fn main()
         .file("src/Plaintext.cc")
         .file("src/PublicKey.cc")
         .file("src/SerialDeserial.cc")
+        .file("src/EvalKey.cc")
         .include("/usr/local/include/openfhe")
         .include("/usr/local/include/openfhe/third-party/include")
         .include("/usr/local/include/openfhe/core")
@@ -39,6 +40,8 @@ fn main()
     println!("cargo::rerun-if-changed=src/PublicKey.cc");
     println!("cargo::rerun-if-changed=src/SerialDeserial.h");
     println!("cargo::rerun-if-changed=src/SerialDeserial.cc");
+    println!("cargo::rerun-if-changed=src/EvalKey.h");
+    println!("cargo::rerun-if-changed=src/EvalKey.cc");
 
     // linking openFHE
     println!("cargo::rustc-link-arg=-L/usr/local/lib");

+ 81 - 0
src/CryptoContext.cc

@@ -11,6 +11,7 @@
 #include "KeyPair.h"
 #include "Plaintext.h"
 #include "PublicKey.h"
+#include "EvalKey.h"
 
 namespace openfhe
 {
@@ -671,6 +672,86 @@ std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalInnerProductByPla
     return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalInnerProduct(
         ciphertext.GetInternal(), plaintext.GetInternal(), batchSize));
 }
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::KeySwitch(
+    const CiphertextDCRTPoly& ciphertext, const EvalKeyDCRTPoly& evalKey) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->KeySwitch(
+        ciphertext.GetInternal(), evalKey.GetInternal()));
+}
+void CryptoContextDCRTPoly::KeySwitchInPlace(const CiphertextDCRTPoly& ciphertext,
+    const EvalKeyDCRTPoly& evalKey) const
+{
+    std::shared_ptr<CiphertextImpl> c = ciphertext.GetInternal();
+    m_cryptoContextImplSharedPtr->KeySwitchInPlace(c, evalKey.GetInternal());
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::LevelReduce(
+    const CiphertextDCRTPoly& ciphertext, const EvalKeyDCRTPoly& evalKey,
+    const size_t levels) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->LevelReduce(
+        ciphertext.GetInternal(), evalKey.GetInternal(), levels));
+}
+void CryptoContextDCRTPoly::LevelReduceInPlace(const CiphertextDCRTPoly& ciphertext,
+    const EvalKeyDCRTPoly& evalKey, const size_t levels) const
+{
+    std::shared_ptr<CiphertextImpl> c = ciphertext.GetInternal();
+    m_cryptoContextImplSharedPtr->LevelReduceInPlace(c, evalKey.GetInternal(), levels);
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::ReEncrypt(
+    const CiphertextDCRTPoly& ciphertext, const EvalKeyDCRTPoly& evalKey,
+    const std::shared_ptr<PublicKeyImpl> publicKey) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->ReEncrypt(
+        ciphertext.GetInternal(), evalKey.GetInternal(), publicKey));
+}
+
+std::unique_ptr<EvalKeyDCRTPoly> CryptoContextDCRTPoly::KeySwitchGen(
+    const std::shared_ptr<PrivateKeyImpl> oldPrivateKey,
+    const std::shared_ptr<PrivateKeyImpl> newPrivateKey) const
+{
+    return std::make_unique<EvalKeyDCRTPoly>(m_cryptoContextImplSharedPtr->KeySwitchGen(
+        oldPrivateKey, newPrivateKey));
+}
+std::unique_ptr<EvalKeyDCRTPoly> CryptoContextDCRTPoly::ReKeyGen(
+    const std::shared_ptr<PrivateKeyImpl> oldPrivateKey,
+    const std::shared_ptr<PublicKeyImpl> newPublicKey) const
+{
+    return std::make_unique<EvalKeyDCRTPoly>(m_cryptoContextImplSharedPtr->ReKeyGen(
+        oldPrivateKey, newPublicKey));
+}
+std::unique_ptr<EvalKeyDCRTPoly> CryptoContextDCRTPoly::MultiKeySwitchGen(
+    const std::shared_ptr<PrivateKeyImpl> originalPrivateKey,
+    const std::shared_ptr<PrivateKeyImpl> newPrivateKey, const EvalKeyDCRTPoly& evalKey) const
+{
+    return std::make_unique<EvalKeyDCRTPoly>(m_cryptoContextImplSharedPtr->MultiKeySwitchGen(
+        originalPrivateKey, newPrivateKey, evalKey.GetInternal()));
+}
+std::unique_ptr<EvalKeyDCRTPoly> CryptoContextDCRTPoly::MultiAddEvalKeys(
+    const EvalKeyDCRTPoly& evalKey1, const EvalKeyDCRTPoly& evalKey2,
+    const std::string& keyId /* "" */) const
+{
+    return std::make_unique<EvalKeyDCRTPoly>(m_cryptoContextImplSharedPtr->MultiAddEvalKeys(
+        evalKey1.GetInternal(), evalKey2.GetInternal(), keyId));
+}
+std::unique_ptr<EvalKeyDCRTPoly> CryptoContextDCRTPoly::MultiMultEvalKey(
+    const std::shared_ptr<PrivateKeyImpl> privateKey, const EvalKeyDCRTPoly& evalKey,
+    const std::string& keyId /* "" */) const
+{
+    return std::make_unique<EvalKeyDCRTPoly>(m_cryptoContextImplSharedPtr->MultiMultEvalKey(
+        privateKey, evalKey.GetInternal(), keyId));
+}
+std::unique_ptr<EvalKeyDCRTPoly> CryptoContextDCRTPoly::MultiAddEvalMultKeys(
+    const EvalKeyDCRTPoly& evalKey1, const EvalKeyDCRTPoly& evalKey2,
+    const std::string& keyId /* "" */) const
+{
+    return std::make_unique<EvalKeyDCRTPoly>(m_cryptoContextImplSharedPtr->MultiAddEvalMultKeys(
+        evalKey1.GetInternal(), evalKey2.GetInternal(), keyId));
+}
+void CryptoContextDCRTPoly::EvalSumKeyGen(const std::shared_ptr<PrivateKeyImpl> privateKey,
+    const std::shared_ptr<PublicKeyImpl> publicKey /* nullptr */) const
+{
+    m_cryptoContextImplSharedPtr->EvalSumKeyGen(privateKey, publicKey);
+}
 std::shared_ptr<CryptoContextImpl> CryptoContextDCRTPoly::GetInternal() const
 {
     return m_cryptoContextImplSharedPtr;

+ 33 - 0
src/CryptoContext.h

@@ -33,6 +33,7 @@ class KeyPairDCRTPoly;
 class PublicKeyDCRTPoly;
 class Plaintext;
 class CiphertextDCRTPoly;
+class EvalKeyDCRTPoly;
 
 using SCHEME = lbcrypto::SCHEME;
 using PKESchemeFeature = lbcrypto::PKESchemeFeature;
@@ -287,6 +288,38 @@ public:
     [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalInnerProductByPlaintext(
         const CiphertextDCRTPoly& ciphertext, const Plaintext& plaintext,
         const uint32_t batchSize) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> KeySwitch(
+        const CiphertextDCRTPoly& ciphertext, const EvalKeyDCRTPoly& evalKey) const;
+    void KeySwitchInPlace(const CiphertextDCRTPoly& ciphertext,
+        const EvalKeyDCRTPoly& evalKey) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> LevelReduce(
+        const CiphertextDCRTPoly& ciphertext, const EvalKeyDCRTPoly& evalKey,
+        const size_t levels /* 1 */) const;
+    void LevelReduceInPlace(const CiphertextDCRTPoly& ciphertext, const EvalKeyDCRTPoly& evalKey,
+        const size_t levels /* 1 */) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> ReEncrypt(
+        const CiphertextDCRTPoly& ciphertext, const EvalKeyDCRTPoly& evalKey,
+        const std::shared_ptr<PublicKeyImpl> publicKey /* nullptr */) const;
+    [[nodiscard]] std::unique_ptr<EvalKeyDCRTPoly> KeySwitchGen(
+        const std::shared_ptr<PrivateKeyImpl> oldPrivateKey,
+        const std::shared_ptr<PrivateKeyImpl> newPrivateKey) const;
+    [[nodiscard]] std::unique_ptr<EvalKeyDCRTPoly> ReKeyGen(
+        const std::shared_ptr<PrivateKeyImpl> oldPrivateKey,
+        const std::shared_ptr<PublicKeyImpl> newPublicKey) const;
+    [[nodiscard]] std::unique_ptr<EvalKeyDCRTPoly> MultiKeySwitchGen(
+        const std::shared_ptr<PrivateKeyImpl> originalPrivateKey,
+        const std::shared_ptr<PrivateKeyImpl> newPrivateKey, const EvalKeyDCRTPoly& evalKey) const;
+    [[nodiscard]] std::unique_ptr<EvalKeyDCRTPoly> MultiAddEvalKeys(
+        const EvalKeyDCRTPoly& evalKey1, const EvalKeyDCRTPoly& evalKey2,
+        const std::string& keyId /* "" */) const;
+    [[nodiscard]] std::unique_ptr<EvalKeyDCRTPoly> MultiMultEvalKey(
+        const std::shared_ptr<PrivateKeyImpl> privateKey, const EvalKeyDCRTPoly& evalKey,
+        const std::string& keyId /* "" */) const;
+    [[nodiscard]] std::unique_ptr<EvalKeyDCRTPoly> MultiAddEvalMultKeys(
+        const EvalKeyDCRTPoly& evalKey1, const EvalKeyDCRTPoly& evalKey2,
+        const std::string& keyId /* "" */) const;
+    void EvalSumKeyGen(const std::shared_ptr<PrivateKeyImpl> privateKey,
+        const std::shared_ptr<PublicKeyImpl> publicKey /* nullptr */) const;
     [[nodiscard]] std::shared_ptr<CryptoContextImpl> GetInternal() const;
 };
 // cxx currently does not support static class methods

+ 19 - 0
src/EvalKey.cc

@@ -0,0 +1,19 @@
+#include "EvalKey.h"
+
+#include "openfhe/pke/key/evalkey.h"
+
+namespace openfhe
+{
+
+EvalKeyDCRTPoly::EvalKeyDCRTPoly()
+    : m_evalKey(std::make_shared<EvalKeyImpl>())
+{ }
+EvalKeyDCRTPoly::EvalKeyDCRTPoly(const std::shared_ptr<EvalKeyImpl>& evalKey)
+    : m_evalKey(evalKey)
+{ }
+std::shared_ptr<EvalKeyImpl> EvalKeyDCRTPoly::GetInternal() const
+{
+    return m_evalKey;
+}
+
+} // openfhe

+ 25 - 0
src/EvalKey.h

@@ -0,0 +1,25 @@
+#pragma once
+
+#include "openfhe/core/lattice/hal/lat-backend.h"
+#include "openfhe/pke/key/evalkey-fwd.h"
+
+namespace openfhe
+{
+
+using EvalKeyImpl = lbcrypto::EvalKeyImpl<lbcrypto::DCRTPoly>;
+
+class EvalKeyDCRTPoly final
+{
+    std::shared_ptr<EvalKeyImpl> m_evalKey;
+public:
+    explicit EvalKeyDCRTPoly();
+    explicit EvalKeyDCRTPoly(const std::shared_ptr<EvalKeyImpl>& evalKey);
+    EvalKeyDCRTPoly(const EvalKeyDCRTPoly&) = delete;
+    EvalKeyDCRTPoly(EvalKeyDCRTPoly&&) = delete;
+    EvalKeyDCRTPoly& operator=(const EvalKeyDCRTPoly&) = delete;
+    EvalKeyDCRTPoly& operator=(EvalKeyDCRTPoly&&) = delete;
+
+    [[nodiscard]] std::shared_ptr<EvalKeyImpl> GetInternal() const;
+};
+
+} // openfhe

+ 32 - 0
src/lib.rs

@@ -153,6 +153,7 @@ pub mod ffi
         include!("openfhe/src/Plaintext.h");
         include!("openfhe/src/PublicKey.h");
         include!("openfhe/src/SerialDeserial.h");
+        include!("openfhe/src/EvalKey.h");
 
         type COMPRESSION_LEVEL;
         type DecryptionNoiseMode;
@@ -183,6 +184,7 @@ pub mod ffi
         type PrivateKeyImpl;
         type PublicKeyDCRTPoly;
         type PublicKeyImpl;
+        type EvalKeyDCRTPoly;
     }
 
     // Params
@@ -794,6 +796,36 @@ pub mod ffi
         fn EvalInnerProductByPlaintext(self: &CryptoContextDCRTPoly,
                                        ciphertext: &CiphertextDCRTPoly, plaintext: &Plaintext,
                                        batchSize: u32) -> UniquePtr<CiphertextDCRTPoly>;
+        fn KeySwitch(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly,
+                     evalKey: &EvalKeyDCRTPoly) -> UniquePtr<CiphertextDCRTPoly>;
+        fn KeySwitchInPlace(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly,
+                            evalKey: &EvalKeyDCRTPoly);
+        fn LevelReduce(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly,
+                       evalKey: &EvalKeyDCRTPoly, levels: /* 1 */ usize)
+                       -> UniquePtr<CiphertextDCRTPoly>;
+        fn LevelReduceInPlace(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly,
+                              evalKey: &EvalKeyDCRTPoly, levels: /* 1 */ usize);
+        fn ReEncrypt(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly,
+                     evalKey: &EvalKeyDCRTPoly, publicKey: /* null() */ SharedPtr<PublicKeyImpl>)
+                     -> UniquePtr<CiphertextDCRTPoly>;
+        fn KeySwitchGen(self: &CryptoContextDCRTPoly, oldPrivateKey: SharedPtr<PrivateKeyImpl>,
+                        newPrivateKey: SharedPtr<PrivateKeyImpl>) -> UniquePtr<EvalKeyDCRTPoly>;
+        fn ReKeyGen(self: &CryptoContextDCRTPoly, oldPrivateKey: SharedPtr<PrivateKeyImpl>,
+                    newPublicKey: SharedPtr<PublicKeyImpl>) -> UniquePtr<EvalKeyDCRTPoly>;
+        fn MultiKeySwitchGen(self: &CryptoContextDCRTPoly, originalPrivateKey: SharedPtr<PrivateKeyImpl>,
+                             newPrivateKey: SharedPtr<PrivateKeyImpl>, evalKey: &EvalKeyDCRTPoly)
+                             -> UniquePtr<EvalKeyDCRTPoly>;
+        fn MultiAddEvalKeys(self: &CryptoContextDCRTPoly, evalKey1: &EvalKeyDCRTPoly,
+                            evalKey2: &EvalKeyDCRTPoly, keyId: /* "" */ &CxxString)
+                            -> UniquePtr<EvalKeyDCRTPoly>;
+        fn MultiMultEvalKey(self: &CryptoContextDCRTPoly, privateKey: SharedPtr<PrivateKeyImpl>,
+                            evalKey: &EvalKeyDCRTPoly, keyId: &CxxString /* "" */)
+                            -> UniquePtr<EvalKeyDCRTPoly>;
+        fn MultiAddEvalMultKeys(self: &CryptoContextDCRTPoly, evalKey1: &EvalKeyDCRTPoly,
+                                evalKey2: &EvalKeyDCRTPoly, keyId: /* "" */ &CxxString)
+                                -> UniquePtr<EvalKeyDCRTPoly>;
+        fn EvalSumKeyGen(self: &CryptoContextDCRTPoly, privateKey: SharedPtr<PrivateKeyImpl>,
+                         publicKey: /* null() */ SharedPtr<PublicKeyImpl>);
 
         // cxx currently does not support static class methods
         fn ClearEvalMultKeys();