Browse Source

Adding second part of the map related functions

Hovsep Papoyan 3 months ago
parent
commit
38e7e713fd

+ 11 - 4
src/AssociativeContainerOfOpaqueTypes.cc

@@ -13,12 +13,19 @@ std::unordered_map<uint32_t, DCRTPoly>& UnorderedMapFromIndexToDCRTPoly::GetInte
 }
 
 MapFromIndexToEvalKey::MapFromIndexToEvalKey(
-    std::map<uint32_t, std::shared_ptr<EvalKeyImpl>> indexToEvalKeyDCRTPolyMap)
-    : m_indexToEvalKeyDCRTPolyMap(std::move(indexToEvalKeyDCRTPolyMap))
+    std::shared_ptr<std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>>
+    sharedPtrToindexToEvalKeyDCRTPolyMap)
+    : m_sharedPtrToindexToEvalKeyDCRTPolyMap(sharedPtrToindexToEvalKeyDCRTPolyMap)
 { }
-const std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>& MapFromIndexToEvalKey::GetInternal() const
+const std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>&
+    MapFromIndexToEvalKey::GetInternalMap() const
 {
-    return m_indexToEvalKeyDCRTPolyMap;
+    return *m_sharedPtrToindexToEvalKeyDCRTPolyMap;
+}
+std::shared_ptr<std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>>
+    MapFromIndexToEvalKey::GetInternal() const
+{
+    return m_sharedPtrToindexToEvalKeyDCRTPolyMap;
 }
 
 } // openfhe

+ 6 - 4
src/AssociativeContainerOfOpaqueTypes.h

@@ -29,11 +29,13 @@ using EvalKeyImpl = lbcrypto::EvalKeyImpl<lbcrypto::DCRTPoly>;
 
 class MapFromIndexToEvalKey final
 {
-    std::map<uint32_t, std::shared_ptr<EvalKeyImpl>> m_indexToEvalKeyDCRTPolyMap;
+    std::shared_ptr<std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>>
+        m_sharedPtrToindexToEvalKeyDCRTPolyMap;
 public:
-    explicit MapFromIndexToEvalKey(std::map<uint32_t,
-        std::shared_ptr<EvalKeyImpl>> indexToEvalKeyDCRTPolyMap);
-    [[nodiscard]] const std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>& GetInternal() const;
+    explicit MapFromIndexToEvalKey(
+        std::shared_ptr<std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>> indexToEvalKeyDCRTPolyMap);
+    [[nodiscard]] const std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>& GetInternalMap() const;
+    [[nodiscard]] std::shared_ptr<std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>> GetInternal() const;
 };
 
 } // openfhe

+ 79 - 6
src/CryptoContext.cc

@@ -928,21 +928,80 @@ std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalAutomorphism(
     const MapFromIndexToEvalKey& evalKeyMap) const
 {
     return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalAutomorphism(
-        ciphertext.GetInternal(), i, evalKeyMap.GetInternal()));
+        ciphertext.GetInternal(), i, evalKeyMap.GetInternalMap()));
 }
 std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalSumRows(
     const CiphertextDCRTPoly& ciphertext, const uint32_t rowSize,
     const MapFromIndexToEvalKey& evalSumKeyMap, const uint32_t subringDim) const
 {
     return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalSumRows(
-        ciphertext.GetInternal(), rowSize, evalSumKeyMap.GetInternal(), subringDim));
+        ciphertext.GetInternal(), rowSize, evalSumKeyMap.GetInternalMap(), subringDim));
 }
 std::unique_ptr<CiphertextDCRTPoly> CryptoContextDCRTPoly::EvalSumCols(
     const CiphertextDCRTPoly& ciphertext, const uint32_t rowSize,
     const MapFromIndexToEvalKey& evalSumKeyMap) const
 {
     return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalSumCols(
-        ciphertext.GetInternal(), rowSize, evalSumKeyMap.GetInternal()));
+        ciphertext.GetInternal(), rowSize, evalSumKeyMap.GetInternalMap()));
+}
+std::unique_ptr<MapFromIndexToEvalKey> CryptoContextDCRTPoly::EvalAutomorphismKeyGen(
+    const PrivateKeyDCRTPoly& privateKey, const std::vector<uint32_t>& indexList) const
+{
+    return std::make_unique<MapFromIndexToEvalKey>(
+        m_cryptoContextImplSharedPtr->EvalAutomorphismKeyGen(privateKey.GetInternal(), indexList));
+}
+std::unique_ptr<MapFromIndexToEvalKey> CryptoContextDCRTPoly::EvalSumRowsKeyGen(
+    const PrivateKeyDCRTPoly& privateKey, const PublicKeyDCRTPoly& publicKey,
+    const uint32_t rowSize, const uint32_t subringDim) const
+{
+    return std::make_unique<MapFromIndexToEvalKey>(m_cryptoContextImplSharedPtr->EvalSumRowsKeyGen(
+        privateKey.GetInternal(), publicKey.GetInternal(), rowSize, subringDim));
+}
+std::unique_ptr<MapFromIndexToEvalKey> CryptoContextDCRTPoly::EvalSumColsKeyGen(
+    const PrivateKeyDCRTPoly& privateKey, const PublicKeyDCRTPoly& publicKey) const
+{
+    return std::make_unique<MapFromIndexToEvalKey>(m_cryptoContextImplSharedPtr->EvalSumColsKeyGen(
+        privateKey.GetInternal(), publicKey.GetInternal()));
+}
+std::unique_ptr<MapFromIndexToEvalKey> CryptoContextDCRTPoly::MultiEvalAutomorphismKeyGen(
+    const PrivateKeyDCRTPoly& privateKey, const MapFromIndexToEvalKey& evalKeyMap,
+    const std::vector<uint32_t>& indexList, const std::string& keyId) const
+{
+    return std::make_unique<MapFromIndexToEvalKey>(
+        m_cryptoContextImplSharedPtr->MultiEvalAutomorphismKeyGen(privateKey.GetInternal(),
+        evalKeyMap.GetInternal(), indexList, keyId));
+}
+std::unique_ptr<MapFromIndexToEvalKey> CryptoContextDCRTPoly::MultiEvalAtIndexKeyGen(
+    const PrivateKeyDCRTPoly& privateKey, const MapFromIndexToEvalKey& evalKeyMap,
+    const std::vector<int32_t>& indexList, const std::string& keyId) const
+{
+    return std::make_unique<MapFromIndexToEvalKey>(
+        m_cryptoContextImplSharedPtr->MultiEvalAtIndexKeyGen(privateKey.GetInternal(),
+        evalKeyMap.GetInternal(), indexList, keyId));
+}
+std::unique_ptr<MapFromIndexToEvalKey> CryptoContextDCRTPoly::MultiEvalSumKeyGen(
+    const PrivateKeyDCRTPoly& privateKey, const MapFromIndexToEvalKey& evalKeyMap,
+    const std::string& keyId) const
+{
+    return std::make_unique<MapFromIndexToEvalKey>(
+        m_cryptoContextImplSharedPtr->MultiEvalSumKeyGen(privateKey.GetInternal(),
+        evalKeyMap.GetInternal(), keyId));
+}
+std::unique_ptr<MapFromIndexToEvalKey> CryptoContextDCRTPoly::MultiAddEvalSumKeys(
+    const MapFromIndexToEvalKey& evalKeyMap1, const MapFromIndexToEvalKey& evalKeyMap2,
+    const std::string& keyId) const
+{
+    return std::make_unique<MapFromIndexToEvalKey>(
+        m_cryptoContextImplSharedPtr->MultiAddEvalSumKeys(evalKeyMap1.GetInternal(),
+        evalKeyMap2.GetInternal(), keyId));
+}
+std::unique_ptr<MapFromIndexToEvalKey> CryptoContextDCRTPoly::MultiAddEvalAutomorphismKeys(
+    const MapFromIndexToEvalKey& evalKeyMap1, const MapFromIndexToEvalKey& evalKeyMap2,
+    const std::string& keyId) const
+{
+    return std::make_unique<MapFromIndexToEvalKey>(
+        m_cryptoContextImplSharedPtr->MultiAddEvalAutomorphismKeys(evalKeyMap1.GetInternal(),
+        evalKeyMap2.GetInternal(), keyId));
 }
 std::shared_ptr<CryptoContextImpl> CryptoContextDCRTPoly::GetInternal() const
 {
@@ -1000,12 +1059,26 @@ std::unique_ptr<std::vector<uint32_t>> GetUniqueValues(const std::vector<uint32_
 }
 std::unique_ptr<MapFromIndexToEvalKey> GetEvalAutomorphismKeyMap(const std::string& keyID)
 {
-    return std::make_unique<MapFromIndexToEvalKey>(std::move(
-        CryptoContextImpl::GetEvalAutomorphismKeyMap(keyID)));
+    return std::make_unique<MapFromIndexToEvalKey>(
+        CryptoContextImpl::GetEvalAutomorphismKeyMapPtr(keyID));
 }
 std::unique_ptr<MapFromIndexToEvalKey> GetCopyOfEvalSumKeyMap(const std::string& id)
 {
-    return std::make_unique<MapFromIndexToEvalKey>(CryptoContextImpl::GetEvalSumKeyMap(id));
+    return std::make_unique<MapFromIndexToEvalKey>(
+        std::make_shared<std::map<uint32_t, std::shared_ptr<EvalKeyImpl>>>(
+        CryptoContextImpl::GetEvalSumKeyMap(id)));
+}
+std::unique_ptr<MapFromIndexToEvalKey> GetEvalAutomorphismKeyMapPtr(const std::string& keyID)
+{
+    return std::make_unique<MapFromIndexToEvalKey>(CryptoContextImpl::GetEvalAutomorphismKeyMapPtr(keyID));
+}
+void InsertEvalAutomorphismKey(const MapFromIndexToEvalKey& evalKeyMap, const std::string& keyTag)
+{
+    CryptoContextImpl::InsertEvalAutomorphismKey(evalKeyMap.GetInternal(), keyTag);
+}
+void InsertEvalSumKey(const MapFromIndexToEvalKey& mapToInsert, const std::string& keyTag)
+{
+    CryptoContextImpl::InsertEvalSumKey(mapToInsert.GetInternal(), keyTag);
 }
 
 // Generator functions

+ 29 - 0
src/CryptoContext.h

@@ -390,6 +390,30 @@ public:
     [[nodiscard]] std::unique_ptr<CiphertextDCRTPoly> EvalSumCols(
         const CiphertextDCRTPoly& ciphertext, const uint32_t rowSize,
         const MapFromIndexToEvalKey& evalSumKeyMap) const;
+    [[nodiscard]] std::unique_ptr<MapFromIndexToEvalKey> EvalAutomorphismKeyGen(
+        const PrivateKeyDCRTPoly& privateKey, const std::vector<uint32_t>& indexList) const;
+    [[nodiscard]] std::unique_ptr<MapFromIndexToEvalKey> EvalSumRowsKeyGen(
+        const PrivateKeyDCRTPoly& privateKey,
+        const PublicKeyDCRTPoly& publicKey /* GenNullPublicKey() */,
+        const uint32_t rowSize /* 0 */, const uint32_t subringDim /* 0 */) const;
+    [[nodiscard]] std::unique_ptr<MapFromIndexToEvalKey> EvalSumColsKeyGen(
+        const PrivateKeyDCRTPoly& privateKey,
+        const PublicKeyDCRTPoly& publicKey /* GenNullPublicKey() */) const;
+    [[nodiscard]] std::unique_ptr<MapFromIndexToEvalKey> MultiEvalAutomorphismKeyGen(
+        const PrivateKeyDCRTPoly& privateKey, const MapFromIndexToEvalKey& evalKeyMap,
+        const std::vector<uint32_t>& indexList, const std::string& keyId /* "" */) const;
+    [[nodiscard]] std::unique_ptr<MapFromIndexToEvalKey> MultiEvalAtIndexKeyGen(
+        const PrivateKeyDCRTPoly& privateKey, const MapFromIndexToEvalKey& evalKeyMap,
+        const std::vector<int32_t>& indexList, const std::string& keyId /* "" */) const;
+    [[nodiscard]] std::unique_ptr<MapFromIndexToEvalKey> MultiEvalSumKeyGen(
+        const PrivateKeyDCRTPoly& privateKey, const MapFromIndexToEvalKey& evalKeyMap,
+        const std::string& keyId /* "" */) const;
+    [[nodiscard]] std::unique_ptr<MapFromIndexToEvalKey> MultiAddEvalSumKeys(
+        const MapFromIndexToEvalKey& evalKeyMap1, const MapFromIndexToEvalKey& evalKeyMap2,
+        const std::string& keyId /* "" */) const;
+    [[nodiscard]] std::unique_ptr<MapFromIndexToEvalKey> MultiAddEvalAutomorphismKeys(
+        const MapFromIndexToEvalKey& evalKeyMap1, const MapFromIndexToEvalKey& evalKeyMap2,
+        const std::string& keyId /* "" */) const;
     [[nodiscard]] std::shared_ptr<CryptoContextImpl> GetInternal() const;
 };
 
@@ -410,6 +434,11 @@ void ClearEvalAutomorphismKeysByCryptoContext(const CryptoContextDCRTPoly& crypt
 [[nodiscard]] std::unique_ptr<MapFromIndexToEvalKey> GetEvalAutomorphismKeyMap(
     const std::string& keyID);
 [[nodiscard]] std::unique_ptr<MapFromIndexToEvalKey> GetCopyOfEvalSumKeyMap(const std::string& id);
+[[nodiscard]] std::unique_ptr<MapFromIndexToEvalKey> GetEvalAutomorphismKeyMapPtr(
+    const std::string& keyID);
+void InsertEvalAutomorphismKey(const MapFromIndexToEvalKey& evalKeyMap,
+    const std::string& keyTag /* "" */);
+void InsertEvalSumKey(const MapFromIndexToEvalKey& mapToInsert, const std::string& keyTag /* "" */);
 
 // Generator functions
 [[nodiscard]] std::unique_ptr<CryptoContextDCRTPoly> GenNullCryptoContext();

+ 32 - 0
src/lib.rs

@@ -903,6 +903,34 @@ pub mod ffi
                        -> UniquePtr<CiphertextDCRTPoly>;
         fn EvalSumCols(self: &CryptoContextDCRTPoly, ciphertext: &CiphertextDCRTPoly, rowSize: u32,
                        evalSumKeyMap: &MapFromIndexToEvalKey) -> UniquePtr<CiphertextDCRTPoly>;
+        fn EvalAutomorphismKeyGen(self: &CryptoContextDCRTPoly, privateKey: &PrivateKeyDCRTPoly,
+                                  indexList: &CxxVector<u32>) -> UniquePtr<MapFromIndexToEvalKey>;
+        fn EvalSumRowsKeyGen(self: &CryptoContextDCRTPoly, privateKey: &PrivateKeyDCRTPoly,
+                             publicKey: /* GenNullPublicKey() */ &PublicKeyDCRTPoly,
+                             rowSize: /* 0 */ u32, subringDim: /* 0 */ u32)
+                             -> UniquePtr<MapFromIndexToEvalKey>;
+        fn EvalSumColsKeyGen(self: &CryptoContextDCRTPoly, privateKey: &PrivateKeyDCRTPoly,
+                             publicKey: /* GenNullPublicKey() */ &PublicKeyDCRTPoly)
+                             -> UniquePtr<MapFromIndexToEvalKey>;
+        fn MultiEvalAutomorphismKeyGen(self: &CryptoContextDCRTPoly,
+                                       privateKey: &PrivateKeyDCRTPoly,
+                                       evalKeyMap: &MapFromIndexToEvalKey,
+                                       indexList: &CxxVector<u32>, keyId: /* "" */ &CxxString)
+                                       -> UniquePtr<MapFromIndexToEvalKey>;
+        fn MultiEvalAtIndexKeyGen(self: &CryptoContextDCRTPoly, privateKey: &PrivateKeyDCRTPoly,
+                                  evalKeyMap: &MapFromIndexToEvalKey, indexList: &CxxVector<i32>,
+                                  keyId: /* "" */ &CxxString) -> UniquePtr<MapFromIndexToEvalKey>;
+        fn MultiEvalSumKeyGen(self: &CryptoContextDCRTPoly, privateKey: &PrivateKeyDCRTPoly,
+                              evalKeyMap: &MapFromIndexToEvalKey, keyId: /* "" */ &CxxString)
+                              -> UniquePtr<MapFromIndexToEvalKey>;
+        fn MultiAddEvalSumKeys(self: &CryptoContextDCRTPoly, evalKeyMap1: &MapFromIndexToEvalKey,
+                               evalKeyMap2: &MapFromIndexToEvalKey, keyId: /* "" */ &CxxString)
+                               -> UniquePtr<MapFromIndexToEvalKey>;
+        fn MultiAddEvalAutomorphismKeys(self: &CryptoContextDCRTPoly,
+                                        evalKeyMap1: &MapFromIndexToEvalKey,
+                                        evalKeyMap2: &MapFromIndexToEvalKey,
+                                        keyId: /* "" */ &CxxString)
+                                        -> UniquePtr<MapFromIndexToEvalKey>;
 
         // cxx currently does not support static class methods
         fn ClearEvalMultKeys();
@@ -919,6 +947,10 @@ pub mod ffi
                            -> UniquePtr<CxxVector<u32>>;
         fn GetEvalAutomorphismKeyMap(keyID: &CxxString) -> UniquePtr<MapFromIndexToEvalKey>;
         fn GetCopyOfEvalSumKeyMap(id: &CxxString) -> UniquePtr<MapFromIndexToEvalKey>;
+        fn GetEvalAutomorphismKeyMapPtr(keyID: &CxxString) -> UniquePtr<MapFromIndexToEvalKey>;
+        fn InsertEvalAutomorphismKey(evalKeyMap: &MapFromIndexToEvalKey,
+                                     keyTag: /* "" */ &CxxString);
+        fn InsertEvalSumKey(mapToInsert: &MapFromIndexToEvalKey, keyTag: /* "" */ &CxxString);
     }
 
     // Serialize / Deserialize