Pārlūkot izejas kodu

Adding 22 new functions

Hovsep Papoyan 9 mēneši atpakaļ
vecāks
revīzija
a7ccaafd2c
3 mainītis faili ar 200 papildinājumiem un 4 dzēšanām
  1. 122 1
      src/bindings.cc
  2. 41 3
      src/bindings.hpp
  3. 37 0
      src/lib.rs

+ 122 - 1
src/bindings.cc

@@ -109,14 +109,55 @@ std::unique_ptr<Plaintext> CryptoContextDCRTPoly::MakeCKKSPackedPlaintextByVecto
     return std::make_unique<Plaintext>(m_cryptoContextImplSharedPtr->MakeCKKSPackedPlaintext(
         v, scaleDeg, level, params, slots));
 }
+void CryptoContextDCRTPoly::SetSchemeId(const SCHEME schemeTag) const
+{
+    m_cryptoContextImplSharedPtr->setSchemeId(schemeTag);
+}
+SCHEME CryptoContextDCRTPoly::GetSchemeId() const
+{
+    return m_cryptoContextImplSharedPtr->getSchemeId();
+}
+size_t CryptoContextDCRTPoly::GetKeyGenLevel() const
+{
+    return m_cryptoContextImplSharedPtr->GetKeyGenLevel();
+}
+void CryptoContextDCRTPoly::SetKeyGenLevel(const size_t level) const
+{
+    m_cryptoContextImplSharedPtr->SetKeyGenLevel(level);
+}
+void CryptoContextDCRTPoly::SetSwkFC(const CiphertextDCRTPoly& FHEWtoCKKSswk) const
+{
+    m_cryptoContextImplSharedPtr->SetSwkFC(FHEWtoCKKSswk.GetInternal());
+}
+void CryptoContextDCRTPoly::EvalCompareSwitchPrecompute(const uint32_t pLWE,
+    const double scaleSign, const bool unit) const
+{
+    m_cryptoContextImplSharedPtr->EvalCompareSwitchPrecompute(pLWE, scaleSign, unit);
+}
+uint32_t CryptoContextDCRTPoly::FindAutomorphismIndex(const uint32_t idx) const
+{
+    return m_cryptoContextImplSharedPtr->FindAutomorphismIndex(idx);
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::GetSwkFC() const
+{
+    return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->GetSwkFC());
+}
 void CryptoContextDCRTPoly::Enable(const PKESchemeFeature feature) const
 {
     m_cryptoContextImplSharedPtr->Enable(feature);
 }
+void CryptoContextDCRTPoly::EnableByMask(const uint32_t featureMask) const
+{
+    m_cryptoContextImplSharedPtr->Enable(featureMask);
+}
 std::unique_ptr<KeyPairDCRTPoly> CryptoContextDCRTPoly::KeyGen() const
 {
     return std::make_unique<KeyPairDCRTPoly>(m_cryptoContextImplSharedPtr->KeyGen());
 }
+std::unique_ptr<KeyPairDCRTPoly> CryptoContextDCRTPoly::SparseKeyGen() const
+{
+    return std::make_unique<KeyPairDCRTPoly>(m_cryptoContextImplSharedPtr->SparseKeyGen());
+}
 void CryptoContextDCRTPoly::EvalMultKeyGen(const std::shared_ptr<PrivateKeyImpl> key) const
 {
     m_cryptoContextImplSharedPtr->EvalMultKeyGen(key);
@@ -232,6 +273,42 @@ std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalSum(
     return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalSum(
         ciphertext.GetInternal(), batchSize));
 }
+
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalPolyLinear(
+    const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalPolyLinear(
+        ciphertext.GetInternal(), coefficients));
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalPolyPS(
+    const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalPolyPS(
+        ciphertext.GetInternal(), coefficients));
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalChebyshevSeriesLinear(
+    const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients,
+    const double a, const double b) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(
+        m_cryptoContextImplSharedPtr->EvalChebyshevSeriesLinear(ciphertext.GetInternal(),
+        coefficients, a, b));
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalChebyshevSeriesPS(
+    const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients,
+    const double a, const double b) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(
+        m_cryptoContextImplSharedPtr->EvalChebyshevSeriesPS(ciphertext.GetInternal(), coefficients,
+        a, b));
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalDivide(
+    const CiphertextDCRTPoly& ciphertext, const double a, const double b,
+    const uint32_t degree) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalDivide(
+        ciphertext.GetInternal(), a, b, degree));
+}
 std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::IntMPBootAdjustScale(
     const CiphertextDCRTPoly& ciphertext) const
 {
@@ -269,7 +346,7 @@ std::unique_ptr<DecryptResult> CryptoContextDCRTPoly::Decrypt(
 {
     std::shared_ptr<PlaintextImpl> res;
     std::unique_ptr<DecryptResult> result = std::make_unique<DecryptResult>(
-        m_cryptoContextImplSharedPtr->Decrypt(privateKey, ciphertext.GetInternal(), &res));
+    m_cryptoContextImplSharedPtr->Decrypt(privateKey, ciphertext.GetInternal(), &res));
     plaintext = res;
     return result;
 }
@@ -281,6 +358,20 @@ uint32_t CryptoContextDCRTPoly::GetCyclotomicOrder() const
 {
     return m_cryptoContextImplSharedPtr->GetCyclotomicOrder();
 }
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalSin(
+    const CiphertextDCRTPoly& ciphertext,
+    const double a, const double b, const uint32_t degree) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalSin(
+        ciphertext.GetInternal(), a, b, degree));
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalCos(
+    const CiphertextDCRTPoly& ciphertext,
+    const double a, const double b, const uint32_t degree) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalCos(
+        ciphertext.GetInternal(), a, b, degree));
+}
 std::unique_ptr<Plaintext> CryptoContextDCRTPoly::MakeCKKSPackedPlaintext(
     const std::vector<double>& value, const size_t scaleDeg, const uint32_t level,
     const std::shared_ptr<DCRTPolyParams> params, const uint32_t slots) const
@@ -294,6 +385,36 @@ std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalPoly(
     return std::make_unique<CiphertextDCRTPoly>(
         m_cryptoContextImplSharedPtr->EvalPoly(ciphertext.GetInternal(), coefficients));
 }
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalNegate(
+    const CiphertextDCRTPoly& ciphertext) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(
+        m_cryptoContextImplSharedPtr->EvalNegate(ciphertext.GetInternal()));
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalSquare(
+    const CiphertextDCRTPoly& ciphertext) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(
+        m_cryptoContextImplSharedPtr->EvalSquare(ciphertext.GetInternal()));
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalAtIndex(
+    const CiphertextDCRTPoly& ciphertext, const uint32_t index) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(
+        m_cryptoContextImplSharedPtr->EvalAtIndex(ciphertext.GetInternal(), index));
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::ComposedEvalMult(
+    const CiphertextDCRTPoly& ciphertext1, const CiphertextDCRTPoly& ciphertext2) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(
+        m_cryptoContextImplSharedPtr->ComposedEvalMult(ciphertext1.GetInternal(), ciphertext2.GetInternal()));
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::Relinearize(
+    const CiphertextDCRTPoly& ciphertext) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(
+        m_cryptoContextImplSharedPtr->Relinearize(ciphertext.GetInternal()));
+}
 
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 

+ 41 - 3
src/bindings.hpp

@@ -152,8 +152,19 @@ public:
     CryptoContextDCRTPoly& operator=(const CryptoContextDCRTPoly&) = delete;
     CryptoContextDCRTPoly& operator=(CryptoContextDCRTPoly&&) = delete;
 
+    void SetSchemeId(const SCHEME schemeTag) const;
+    [[nodiscard]] SCHEME GetSchemeId() const;
+    [[nodiscard]] size_t GetKeyGenLevel() const;
+    void SetKeyGenLevel(const size_t level) const;
+    void SetSwkFC(const CiphertextDCRTPoly& FHEWtoCKKSswk) const;
+    void EvalCompareSwitchPrecompute(const uint32_t pLWE /* 0 */, const double scaleSign /* 1.0 */,
+        const bool unit /* false */) const;
+    [[nodiscard]] uint32_t FindAutomorphismIndex(const uint32_t idx) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> GetSwkFC() const;
     void Enable(const PKESchemeFeature feature) const;
+    void EnableByMask(const uint32_t featureMask) const;
     [[nodiscard]] std::unique_ptr<KeyPairDCRTPoly> KeyGen() const;
+    [[nodiscard]] std::unique_ptr<KeyPairDCRTPoly> SparseKeyGen() const;
     void EvalMultKeyGen(const std::shared_ptr<PrivateKeyImpl> key) const;
     void EvalMultKeysGen(const std::shared_ptr<PrivateKeyImpl> key) const;
     void EvalRotateKeyGen(
@@ -162,6 +173,10 @@ public:
     void EvalCKKStoFHEWPrecompute(const double scale /* 1.0 */) const;
     [[nodiscard]] uint32_t GetRingDimension() const;
     [[nodiscard]] uint32_t GetCyclotomicOrder() const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalSin(const CiphertextDCRTPoly& ciphertext,
+        const double a, const double b, const uint32_t degree) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalCos(const CiphertextDCRTPoly& ciphertext,
+        const double a, const double b, const uint32_t degree) const;
     [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> Encrypt(
         const std::shared_ptr<PublicKeyImpl> publicKey, const Plaintext& plaintext) const;
     [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalAdd(
@@ -180,8 +195,18 @@ public:
         const CiphertextDCRTPoly& ciphertext, const int32_t index) const;
     [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalPoly(
         const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalNegate(
+        const CiphertextDCRTPoly& ciphertext) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalSquare(
+        const CiphertextDCRTPoly& ciphertext) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalAtIndex(
+        const CiphertextDCRTPoly& ciphertext, const uint32_t index) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> ComposedEvalMult(
+        const CiphertextDCRTPoly& ciphertext1, const CiphertextDCRTPoly& ciphertext2) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> Relinearize(
+        const CiphertextDCRTPoly& ciphertext) const;
     [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalChebyshevSeries(
-   	    const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients,
+        const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients,
         const double a, const double b) const;
     [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalChebyshevFunction(
         rust::Fn<void(const double x, double& ret)> func, const CiphertextDCRTPoly& ciphertext,
@@ -190,13 +215,26 @@ public:
         const CiphertextDCRTPoly& ciphertext, const uint32_t numIterations /* 1 */,
         const uint32_t precision /* 0 */) const;
     [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> Rescale(
-   	    const CiphertextDCRTPoly& ciphertext) const;
+        const CiphertextDCRTPoly& ciphertext) const;
     [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> ModReduce(
         const CiphertextDCRTPoly& ciphertext) const;
     [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalSum(const CiphertextDCRTPoly& ciphertext,
         const uint32_t batchSize) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalPolyLinear(
+        const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalPolyPS(
+        const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalChebyshevSeriesLinear(
+        const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients,
+        const double a, const double b) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalChebyshevSeriesPS(
+        const CiphertextDCRTPoly& ciphertext, const std::vector<double>& coefficients,
+        const double a, const double b) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalDivide(
+        const CiphertextDCRTPoly& ciphertext, const double a, const double b,
+        const uint32_t degree) const;
     [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> IntMPBootAdjustScale(
-   	    const CiphertextDCRTPoly& ciphertext) const;
+        const CiphertextDCRTPoly& ciphertext) const;
     [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalLogistic(
         const CiphertextDCRTPoly& ciphertext, const double a, const double b,
         const uint32_t degree) const;

+ 37 - 0
src/lib.rs

@@ -558,6 +558,43 @@ pub mod ffi
                                                     slots: /* 0 */ u32) -> UniquePtr<Plaintext>;
         fn EvalPoly(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly,
                     coefficients: &CxxVector<f64>) -> UniquePtr<CiphertextDCRTPoly>;
+        fn SetSchemeId(self: &CryptoContextDCRTPoly, schemeTag: SCHEME);
+        fn GetSchemeId(self: &CryptoContextDCRTPoly) -> SCHEME;
+        fn GetKeyGenLevel(self: &CryptoContextDCRTPoly) -> usize;
+        fn SetKeyGenLevel(self: &CryptoContextDCRTPoly, level: usize);
+        fn SetSwkFC(self: &CryptoContextDCRTPoly, FHEWtoCKKSswk: &CiphertextDCRTPoly);
+        fn EvalCompareSwitchPrecompute(self: &CryptoContextDCRTPoly, pLWE: u32, scaleSign: f64,
+                                       unit: bool);
+        fn FindAutomorphismIndex(self: &CryptoContextDCRTPoly, idx: u32) -> u32;
+        fn GetSwkFC(self: &CryptoContextDCRTPoly) -> UniquePtr<CiphertextDCRTPoly>;
+        fn EnableByMask(self: &CryptoContextDCRTPoly, featureMask: u32);
+        fn SparseKeyGen(self: &CryptoContextDCRTPoly) -> UniquePtr<KeyPairDCRTPoly>;
+        fn EvalPolyLinear(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly,
+                          coefficients: &CxxVector<f64>)-> UniquePtr<CiphertextDCRTPoly>;
+        fn EvalPolyPS(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly,
+                      coefficients: &CxxVector<f64>) -> UniquePtr<CiphertextDCRTPoly>;
+        fn EvalChebyshevSeriesLinear(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly,
+                                     coefficients: &CxxVector<f64>, a: f64, b: f64)
+                                     -> UniquePtr<CiphertextDCRTPoly>;
+        fn EvalChebyshevSeriesPS(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly,
+                                 coefficients: &CxxVector<f64>, a: f64, b: f64)
+                                 -> UniquePtr<CiphertextDCRTPoly>;
+        fn EvalDivide(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly, a: f64,
+                      b: f64, degree: u32) -> UniquePtr<CiphertextDCRTPoly>;
+        fn EvalSin(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly, a: f64, b: f64,
+                   degree: u32) -> UniquePtr<CiphertextDCRTPoly>;
+        fn EvalCos(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly, a: f64, b: f64,
+                   degree: u32) -> UniquePtr<CiphertextDCRTPoly>;
+        fn EvalNegate(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly)
+                      -> UniquePtr<CiphertextDCRTPoly>;
+        fn EvalSquare(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly)
+                      -> UniquePtr<CiphertextDCRTPoly>;
+        fn EvalAtIndex(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly, index: u32)
+                       -> UniquePtr<CiphertextDCRTPoly>;
+        fn ComposedEvalMult(self: &CryptoContextDCRTPoly, ciphertext1: &CiphertextDCRTPoly,
+                            ciphertext2: &CiphertextDCRTPoly) -> UniquePtr<CiphertextDCRTPoly>;
+        fn Relinearize(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly)
+                       -> UniquePtr<CiphertextDCRTPoly>;
     }
 
     // Serialize / Deserialize