Browse Source

Adding rotation related functions

Hovsep Papoyan 6 months ago
parent
commit
2400a4e0cf

+ 1 - 1
src/AssociativeContainerOfOpaqueTypes.cc

@@ -13,7 +13,7 @@ std::unordered_map<uint32_t, DCRTPoly>& UnorderedMapFromIndexToDCRTPoly::GetInte
 }
 
 MapFromIndexToEvalKey::MapFromIndexToEvalKey(
-    std::shared_ptr<std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>>
+    const std::shared_ptr<std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>>
     sharedPtrToindexToEvalKeyDCRTPolyMap)
     : m_sharedPtrToindexToEvalKeyDCRTPolyMap(sharedPtrToindexToEvalKeyDCRTPolyMap)
 { }

+ 4 - 2
src/AssociativeContainerOfOpaqueTypes.h

@@ -33,9 +33,11 @@ class MapFromIndexToEvalKey final
         m_sharedPtrToindexToEvalKeyDCRTPolyMap;
 public:
     explicit MapFromIndexToEvalKey(
-        std::shared_ptr<std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>> indexToEvalKeyDCRTPolyMap);
+        const std::shared_ptr<std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>>
+        indexToEvalKeyDCRTPolyMap);
     [[nodiscard]] const std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>& GetInternalMap() const;
-    [[nodiscard]] std::shared_ptr<std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>> GetInternal() const;
+    [[nodiscard]] std::shared_ptr<std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>>
+	    GetInternal() const;
 };
 
 } // openfhe

+ 20 - 0
src/CryptoContext.cc

@@ -1003,6 +1003,26 @@ std::unique_ptr<MapFromIndexToEvalKey> CryptoContextDCRTPoly::MultiAddEvalAutomo
         m_cryptoContextImplSharedPtr->MultiAddEvalAutomorphismKeys(evalKeyMap1.GetInternal(),
         evalKeyMap2.GetInternal(), keyId));
 }
+std::unique_ptr<VectorOfDCRTPoly> CryptoContextDCRTPoly::EvalFastRotationPrecompute(
+    const CiphertextDCRTPoly& ciphertext) const
+{
+    return std::make_unique<VectorOfDCRTPoly>(
+        m_cryptoContextImplSharedPtr->EvalFastRotationPrecompute(ciphertext.GetInternal()));
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalFastRotation(
+    const CiphertextDCRTPoly& ciphertext, const uint32_t index, const uint32_t m,
+    const VectorOfDCRTPoly& digits) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalFastRotation(
+        ciphertext.GetInternal(), index, m, digits.GetInternal()));
+}
+std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalFastRotationExt(
+    const CiphertextDCRTPoly& ciphertext, const uint32_t index, const VectorOfDCRTPoly& digits,
+    const bool addFirst) const
+{
+    return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalFastRotationExt(
+        ciphertext.GetInternal(), index, digits.GetInternal(), addFirst));
+}
 std::shared_ptr<CryptoContextImpl> CryptoContextDCRTPoly::GetInternal() const
 {
     return m_cryptoContextImplSharedPtr;

+ 9 - 0
src/CryptoContext.h

@@ -39,6 +39,7 @@ class PrivateKeyDCRTPoly;
 class PublicKeyDCRTPoly;
 class UnorderedMapFromIndexToDCRTPoly;
 class VectorOfCiphertexts;
+class VectorOfDCRTPoly;
 class VectorOfPrivateKeys;
 
 using SCHEME = lbcrypto::SCHEME;
@@ -414,6 +415,14 @@ public:
     [[nodiscard]] std::unique_ptr<MapFromIndexToEvalKey> MultiAddEvalAutomorphismKeys(
         const MapFromIndexToEvalKey& evalKeyMap1, const MapFromIndexToEvalKey& evalKeyMap2,
         const std::string& keyId /* "" */) const;
+    [[nodiscard]] std::unique_ptr<VectorOfDCRTPoly> EvalFastRotationPrecompute(
+        const CiphertextDCRTPoly& ciphertext) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalFastRotation(
+        const CiphertextDCRTPoly& ciphertext, const uint32_t index, const uint32_t m,
+        const VectorOfDCRTPoly& digits) const;
+    [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalFastRotationExt(
+        const CiphertextDCRTPoly& ciphertext, const uint32_t index, const VectorOfDCRTPoly& digits,
+        const bool addFirst) const;
     [[nodiscard]] std::shared_ptr<CryptoContextImpl> GetInternal() const;
 };
 

+ 8 - 0
src/SequenceContainerOfOpaqueTypes.cc

@@ -23,4 +23,12 @@ const std::vector<std::shared_ptr<PrivateKeyImpl>>& VectorOfPrivateKeys::GetInte
     return m_privateKeys;
 }
 
+VectorOfDCRTPoly::VectorOfDCRTPoly(const std::shared_ptr<std::vector<lbcrypto::DCRTPoly>> elements)
+    : m_elements(elements)
+{ }
+std::shared_ptr<std::vector<lbcrypto::DCRTPoly>> VectorOfDCRTPoly::GetInternal() const
+{
+    return m_elements;
+}
+
 } // openfhe

+ 8 - 0
src/SequenceContainerOfOpaqueTypes.h

@@ -32,4 +32,12 @@ public:
     [[nodiscard]] const std::vector<std::shared_ptr<PrivateKeyImpl>>& GetInternal() const;
 };
 
+class VectorOfDCRTPoly final
+{
+    std::shared_ptr<std::vector<lbcrypto::DCRTPoly>> m_elements;
+public:
+    explicit VectorOfDCRTPoly(const std::shared_ptr<std::vector<lbcrypto::DCRTPoly>> elements);
+    [[nodiscard]] std::shared_ptr<std::vector<lbcrypto::DCRTPoly>> GetInternal() const;
+};
+
 } // openfhe

+ 11 - 0
src/lib.rs

@@ -192,6 +192,7 @@ pub mod ffi
         type MapFromIndexToEvalKey;
         type UnorderedMapFromIndexToDCRTPoly;
         type VectorOfCiphertexts;
+        type VectorOfDCRTPoly;
         type VectorOfPrivateKeys;
     }
 
@@ -931,6 +932,16 @@ pub mod ffi
                                         evalKeyMap2: &MapFromIndexToEvalKey,
                                         keyId: /* "" */ &CxxString)
                                         -> UniquePtr<MapFromIndexToEvalKey>;
+        fn EvalFastRotationPrecompute(self: &CryptoContextDCRTPoly,
+                                      ciphertext: &CiphertextDCRTPoly)
+                                      -> UniquePtr<VectorOfDCRTPoly>;
+        fn EvalFastRotation(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly,
+                            index: u32, m: u32, digits: &VectorOfDCRTPoly)
+                            -> UniquePtr<CiphertextDCRTPoly>;
+        fn EvalFastRotationExt(self: &CryptoContextDCRTPoly,
+                               ciphertext: &CiphertextDCRTPoly, index: u32,
+                               digits: &VectorOfDCRTPoly, addFirst: bool)
+                               -> UniquePtr<CiphertextDCRTPoly>;
 
         // cxx currently does not support static class methods
         fn ClearEvalMultKeys();