Browse Source

Redesign SerialDeserial

Hovsep Papoyan 1 year ago
parent
commit
494c4f9005
8 changed files with 201 additions and 239 deletions
  1. 8 0
      src/Ciphertext.cc
  2. 2 7
      src/Ciphertext.h
  3. 8 0
      src/CryptoContext.cc
  4. 4 15
      src/CryptoContext.h
  5. 8 0
      src/PublicKey.cc
  6. 2 7
      src/PublicKey.h
  7. 134 188
      src/SerialDeserial.cc
  8. 35 22
      src/SerialDeserial.h

+ 8 - 0
src/Ciphertext.cc

@@ -12,6 +12,14 @@ std::shared_ptr<CiphertextImpl> CiphertextDCRTPoly::GetInternal() const noexcept
 {
     return m_ciphertext;
 }
+std::shared_ptr<CiphertextImpl>& CiphertextDCRTPoly::GetRef() noexcept
+{
+    return m_ciphertext;
+}
+const std::shared_ptr<CiphertextImpl>& CiphertextDCRTPoly::GetRef() const noexcept
+{
+    return m_ciphertext;
+}
 
 // Generator functions
 std::unique_ptr<CiphertextDCRTPoly> GenNullCiphertext()

+ 2 - 7
src/Ciphertext.h

@@ -3,8 +3,6 @@
 #include "openfhe/core/lattice/hal/lat-backend.h"
 #include "openfhe/pke/ciphertext-fwd.h"
 
-#include "SerialMode.h"
-
 namespace openfhe
 {
 
@@ -14,11 +12,6 @@ class CiphertextDCRTPoly final
 {
     std::shared_ptr<CiphertextImpl> m_ciphertext;
 public:
-    friend bool SerializeCiphertextToFile(const std::string& ciphertextLocation,
-        const CiphertextDCRTPoly& ciphertext, const SerialMode serialMode);
-    friend bool DeserializeCiphertextFromFile(const std::string& ciphertextLocation,
-        CiphertextDCRTPoly& ciphertext, const SerialMode serialMode);
-
     CiphertextDCRTPoly() = default;
     CiphertextDCRTPoly(std::shared_ptr<CiphertextImpl>&& ciphertext) noexcept;
     CiphertextDCRTPoly(const CiphertextDCRTPoly&) = delete;
@@ -27,6 +20,8 @@ public:
     CiphertextDCRTPoly& operator=(CiphertextDCRTPoly&&) = delete;
 
     [[nodiscard]] std::shared_ptr<CiphertextImpl> GetInternal() const noexcept;
+    [[nodiscard]] std::shared_ptr<CiphertextImpl>& GetRef() noexcept;
+    [[nodiscard]] const std::shared_ptr<CiphertextImpl>& GetRef() const noexcept;
 };
 
 // Generator functions

+ 8 - 0
src/CryptoContext.cc

@@ -1071,6 +1071,14 @@ std::shared_ptr<CryptoContextImpl> CryptoContextDCRTPoly::GetInternal() const
 {
     return m_cryptoContextImplSharedPtr;
 }
+std::shared_ptr<CryptoContextImpl>& CryptoContextDCRTPoly::GetRef() noexcept
+{
+    return m_cryptoContextImplSharedPtr;
+}
+const std::shared_ptr<CryptoContextImpl>& CryptoContextDCRTPoly::GetRef() const noexcept
+{
+    return m_cryptoContextImplSharedPtr;
+}
 
 // cxx currently does not support static class methods
 void ClearEvalMultKeys()

+ 4 - 15
src/CryptoContext.h

@@ -7,8 +7,6 @@
 
 #include "rust/cxx.h"
 
-#include "SerialMode.h"
-
 namespace lbcrypto
 {
 
@@ -64,17 +62,6 @@ class CryptoContextDCRTPoly final
 {
     std::shared_ptr<CryptoContextImpl> m_cryptoContextImplSharedPtr;
 public:
-    friend bool SerializeCryptoContextToFile(const std::string& ccLocation,
-        const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
-    friend bool DeserializeCryptoContextFromFile(const std::string& ccLocation,
-        CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
-    friend bool SerializeEvalMultKeyToFile(const std::string& multKeyLocation,
-        const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
-    friend bool SerializeEvalSumKeyToFile(const std::string& sumKeyLocation,
-        const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
-    friend bool SerializeEvalAutomorphismKeyToFile(const std::string& automorphismKeyLocation,
-        const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
-
     CryptoContextDCRTPoly() = default;
     explicit CryptoContextDCRTPoly(const ParamsBFVRNS& params);
     explicit CryptoContextDCRTPoly(const ParamsBGVRNS& params);
@@ -289,11 +276,11 @@ public:
         const uint32_t level /* 0 */) const;
     [[nodiscard]] std::unique_ptr<Plaintext> MakeCKKSPackedPlaintext(
         const std::vector<double>& value, const size_t scaleDeg /* 1 */,
-        const uint32_t level /* 0 */, const DCRTPolyParams& params /* GenNullDCRTPolyParams */,
+        const uint32_t level /* 0 */, const DCRTPolyParams& params /* GenNullDCRTPolyParams() */,
         const uint32_t slots /* 0 */) const;
     [[nodiscard]] std::unique_ptr<Plaintext> MakeCKKSPackedPlaintextByVectorOfComplex(
         const std::vector<ComplexPair>& value, const size_t scaleDeg /* 1 */,
-        const uint32_t level /* 0 */, const DCRTPolyParams& params /* GenNullDCRTPolyParams */,
+        const uint32_t level /* 0 */, const DCRTPolyParams& params /* GenNullDCRTPolyParams() */,
         const uint32_t slots /* 0 */) const;
     [[nodiscard]] std::unique_ptr<std::vector<uint32_t>> FindAutomorphismIndices(
         const std::vector<uint32_t>& idxList) const;
@@ -442,6 +429,8 @@ public:
         const CiphertextDCRTPoly& ciphertext) const;
     [[nodiscard]] std::unique_ptr<DCRTPolyParams> GetElementParams() const;
     [[nodiscard]] std::shared_ptr<CryptoContextImpl> GetInternal() const;
+    [[nodiscard]] std::shared_ptr<CryptoContextImpl>& GetRef() noexcept;
+    [[nodiscard]] const std::shared_ptr<CryptoContextImpl>& GetRef() const noexcept;
 };
 
 // cxx currently does not support static class methods

+ 8 - 0
src/PublicKey.cc

@@ -12,6 +12,14 @@ std::shared_ptr<PublicKeyImpl> PublicKeyDCRTPoly::GetInternal() const noexcept
 {
     return m_publicKey;
 }
+std::shared_ptr<PublicKeyImpl>& PublicKeyDCRTPoly::GetRef() noexcept
+{
+    return m_publicKey;
+}
+const std::shared_ptr<PublicKeyImpl>& PublicKeyDCRTPoly::GetRef() const noexcept
+{
+    return m_publicKey;
+}
 
 // Generator functions
 std::unique_ptr<PublicKeyDCRTPoly> GenNullPublicKey()

+ 2 - 7
src/PublicKey.h

@@ -3,8 +3,6 @@
 #include "openfhe/core/lattice/hal/lat-backend.h"
 #include "openfhe/pke/key/publickey-fwd.h"
 
-#include "SerialMode.h"
-
 namespace openfhe
 {
 
@@ -14,11 +12,6 @@ class PublicKeyDCRTPoly final
 {
     std::shared_ptr<PublicKeyImpl> m_publicKey;
 public:
-    friend bool SerializePublicKeyToFile(const std::string& publicKeyLocation,
-        const PublicKeyDCRTPoly& publicKey, const SerialMode serialMode);
-    friend bool DeserializePublicKeyFromFile(const std::string& publicKeyLocation,
-        PublicKeyDCRTPoly& publicKey, const SerialMode serialMode);
-
     PublicKeyDCRTPoly() = default;
     PublicKeyDCRTPoly(const std::shared_ptr<PublicKeyImpl>& publicKey) noexcept;
     PublicKeyDCRTPoly(const PublicKeyDCRTPoly&) = delete;
@@ -27,6 +20,8 @@ public:
     PublicKeyDCRTPoly& operator=(PublicKeyDCRTPoly&&) = delete;
 
     [[nodiscard]] std::shared_ptr<PublicKeyImpl> GetInternal() const noexcept;
+    [[nodiscard]] std::shared_ptr<PublicKeyImpl>& GetRef() noexcept;
+    [[nodiscard]] const std::shared_ptr<PublicKeyImpl>& GetRef() const noexcept;
 };
 
 // Generator functions

+ 134 - 188
src/SerialDeserial.cc

@@ -9,286 +9,232 @@
 namespace openfhe
 {
 
-bool SerializeCryptoContextToFile(const std::string& ccLocation,
-    const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
+template <typename ST, typename Object>
+[[nodiscard]] bool SerialDeserial(const std::string& location,
+    bool (* const funcPtr) (const std::string&, Object&, const ST&), Object& object)
+{
+    return funcPtr(location, object, ST{});
+}
+template <typename Object>
+[[nodiscard]] bool Serial(const std::string& location, Object& object, const SerialMode serialMode)
 {
     if (serialMode == SerialMode::BINARY)
     {
-        return lbcrypto::Serial::SerializeToFile(ccLocation,
-            cryptoContext.m_cryptoContextImplSharedPtr, lbcrypto::SerType::BINARY);
+        return SerialDeserial<lbcrypto::SerType::SERBINARY, decltype(object.GetRef())>(location,
+            lbcrypto::Serial::SerializeToFile, object.GetRef());
     }
     if (serialMode == SerialMode::JSON)
     {
-        return lbcrypto::Serial::SerializeToFile(ccLocation,
-            cryptoContext.m_cryptoContextImplSharedPtr, lbcrypto::SerType::JSON);
+        return SerialDeserial<lbcrypto::SerType::SERJSON, decltype(object.GetRef())>(location,
+            lbcrypto::Serial::SerializeToFile, object.GetRef());
     }
     return false;
 }
-bool DeserializeCryptoContextFromFile(const std::string& ccLocation,
-    CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
+template <typename Object>
+[[nodiscard]] bool Deserial(const std::string& location, Object& object,
+    const SerialMode serialMode)
 {
     if (serialMode == SerialMode::BINARY)
     {
-        return lbcrypto::Serial::DeserializeFromFile(ccLocation,
-            cryptoContext.m_cryptoContextImplSharedPtr, lbcrypto::SerType::BINARY);
+        return SerialDeserial<lbcrypto::SerType::SERBINARY, decltype(object.GetRef())>(location,
+            lbcrypto::Serial::DeserializeFromFile, object.GetRef());
     }
     if (serialMode == SerialMode::JSON)
     {
-        return lbcrypto::Serial::DeserializeFromFile(ccLocation,
-            cryptoContext.m_cryptoContextImplSharedPtr, lbcrypto::SerType::JSON);
+        return SerialDeserial<lbcrypto::SerType::SERJSON, decltype(object.GetRef())>(location,
+            lbcrypto::Serial::DeserializeFromFile, object.GetRef());
     }
     return false;
 }
-bool SerializeEvalMultKeyToFile(const std::string& multKeyLocation,
-    const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
-{
-    const auto close = [](std::ofstream* const ofs){ if (ofs->is_open()) { ofs->close(); } };
-    const std::unique_ptr<std::ofstream, decltype(close)> ofs(
-        new std::ofstream(multKeyLocation, std::ios::out | std::ios::binary), close);
 
-    if (ofs->is_open())
-    {
-        if (serialMode == SerialMode::BINARY)
-        {
-            return CryptoContextImpl::SerializeEvalMultKey(*ofs,
-                lbcrypto::SerType::BINARY, cryptoContext.m_cryptoContextImplSharedPtr);
-        }
-        if (serialMode == SerialMode::JSON)
-        {
-            return CryptoContextImpl::SerializeEvalMultKey(*ofs,
-                lbcrypto::SerType::JSON, cryptoContext.m_cryptoContextImplSharedPtr);
-        }
-    }
-    return false;
+template <typename ST, typename Stream, typename FStream, typename... Types>
+[[nodiscard]] bool SerialDeserial(const std::string& location,
+    bool (* const funcPtr) (Stream&, const ST&, Types... args), Types... args)
+{
+    const auto close = [](FStream* const fs){ if (fs->is_open()) { fs->close(); } };
+    const std::unique_ptr<FStream, decltype(close)> fs(
+        new FStream(location, std::ios::binary), close);
+    return fs->is_open() ? funcPtr(*fs, ST{}, args...) : false;
 }
-bool SerializeEvalMultKeyByIdToFile(const std::string& multKeyLocation,
-    const SerialMode serialMode, const std::string& id)
-{
-    const auto close = [](std::ofstream* const ofs){ if (ofs->is_open()) { ofs->close(); } };
-    const std::unique_ptr<std::ofstream, decltype(close)> ofs(
-        new std::ofstream(multKeyLocation, std::ios::out | std::ios::binary), close);
 
-    if (ofs->is_open())
-    {
-        if (serialMode == SerialMode::BINARY)
-        {
-            return CryptoContextImpl::SerializeEvalMultKey(*ofs, lbcrypto::SerType::BINARY, id);
-        }
-        if (serialMode == SerialMode::JSON)
-        {
-            return CryptoContextImpl::SerializeEvalMultKey(*ofs, lbcrypto::SerType::JSON, id);
-        }
-    }
-    return false;
+// Ciphertext
+bool DeserializeCiphertextFromFile(const std::string& ciphertextLocation,
+    CiphertextDCRTPoly& ciphertext, const SerialMode serialMode)
+{
+    return Deserial(ciphertextLocation, ciphertext, serialMode);
 }
-bool DeserializeEvalMultKeyFromFile(const std::string& multKeyLocation,
-    const SerialMode serialMode)
+bool SerializeCiphertextToFile(const std::string& ciphertextLocation,
+    const CiphertextDCRTPoly& ciphertext, const SerialMode serialMode)
 {
-    const auto close = [](std::ifstream* const ifs){ if (ifs->is_open()) { ifs->close(); } };
-    const std::unique_ptr<std::ifstream, decltype(close)> ifs(
-        new std::ifstream(multKeyLocation, std::ios::in | std::ios::binary), close);
+    return Serial(ciphertextLocation, ciphertext, serialMode);
+}
 
-    if (ifs->is_open())
-    {
-        if (serialMode == SerialMode::BINARY)
-        {
-            return CryptoContextImpl::DeserializeEvalMultKey(*ifs, lbcrypto::SerType::BINARY);
-        }
-        if (serialMode == SerialMode::JSON)
-        {
-            return CryptoContextImpl::DeserializeEvalMultKey(*ifs, lbcrypto::SerType::JSON);
-        }
-    }
-    return false;
+// CryptoContextDCRTPoly
+bool DeserializeCryptoContextFromFile(const std::string& ccLocation,
+    CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
+{
+    return Deserial(ccLocation, cryptoContext, serialMode);
 }
-bool SerializeEvalSumKeyToFile(const std::string& sumKeyLocation,
+bool SerializeCryptoContextToFile(const std::string& ccLocation,
     const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
 {
-    const auto close = [](std::ofstream* const ofs){ if (ofs->is_open()) { ofs->close(); } };
-    const std::unique_ptr<std::ofstream, decltype(close)> ofs(
-        new std::ofstream(sumKeyLocation, std::ios::out | std::ios::binary), close);
+    return Serial(ccLocation, cryptoContext, serialMode);
+}
 
-    if (ofs->is_open())
+// EvalAutomorphismKey
+bool DeserializeEvalAutomorphismKeyFromFile(const std::string& automorphismKeyLocation,
+    const SerialMode serialMode)
+{
+    if (serialMode == SerialMode::BINARY)
+    {
+        return SerialDeserial<lbcrypto::SerType::SERBINARY, std::istream, std::ifstream>(
+            automorphismKeyLocation, CryptoContextImpl::DeserializeEvalAutomorphismKey);
+    }
+    if (serialMode == SerialMode::JSON)
     {
-        if (serialMode == SerialMode::BINARY)
-        {
-            return CryptoContextImpl::SerializeEvalAutomorphismKey(*ofs, lbcrypto::SerType::BINARY,
-                cryptoContext.m_cryptoContextImplSharedPtr);
-        }
-        if (serialMode == SerialMode::JSON)
-        {
-            return CryptoContextImpl::SerializeEvalAutomorphismKey(*ofs, lbcrypto::SerType::JSON,
-                cryptoContext.m_cryptoContextImplSharedPtr);
-        }
+        return SerialDeserial<lbcrypto::SerType::SERJSON, std::istream, std::ifstream>(
+            automorphismKeyLocation, CryptoContextImpl::DeserializeEvalAutomorphismKey);
     }
     return false;
 }
-bool SerializeEvalSumKeyByIdToFile(const std::string& sumKeyLocation,
+bool SerializeEvalAutomorphismKeyByIdToFile(const std::string& automorphismKeyLocation,
     const SerialMode serialMode, const std::string& id)
 {
-    const auto close = [](std::ofstream* const ofs){ if (ofs->is_open()) { ofs->close(); } };
-    const std::unique_ptr<std::ofstream, decltype(close)> ofs(
-        new std::ofstream(sumKeyLocation, std::ios::out | std::ios::binary), close);
-
-    if (ofs->is_open())
+    if (serialMode == SerialMode::BINARY)
     {
-        if (serialMode == SerialMode::BINARY)
-        {
-            return CryptoContextImpl::SerializeEvalSumKey(*ofs, lbcrypto::SerType::BINARY, id);
-        }
-        if (serialMode == SerialMode::JSON)
-        {
-            return CryptoContextImpl::SerializeEvalSumKey(*ofs, lbcrypto::SerType::JSON, id);
-        }
+        return SerialDeserial<lbcrypto::SerType::SERBINARY, std::ostream, std::ofstream>(
+            automorphismKeyLocation, CryptoContextImpl::SerializeEvalAutomorphismKey, id);
     }
-    return false;
-}
-bool DeserializeEvalSumKeyFromFile(const std::string& sumKeyLocation, const SerialMode serialMode)
-{
-    const auto close = [](std::ifstream* const ifs){ if (ifs->is_open()) { ifs->close(); } };
-    const std::unique_ptr<std::ifstream, decltype(close)> ifs(
-        new std::ifstream(sumKeyLocation, std::ios::in | std::ios::binary), close);
-
-    if (ifs->is_open())
+    if (serialMode == SerialMode::JSON)
     {
-        if (serialMode == SerialMode::BINARY)
-        {
-            return CryptoContextImpl::DeserializeEvalAutomorphismKey(*ifs,
-                lbcrypto::SerType::BINARY);
-        }
-        if (serialMode == SerialMode::JSON)
-        {
-            return CryptoContextImpl::DeserializeEvalAutomorphismKey(*ifs,
-                lbcrypto::SerType::JSON);
-        }
+        return SerialDeserial<lbcrypto::SerType::SERJSON, std::ostream, std::ofstream>(
+            automorphismKeyLocation, CryptoContextImpl::SerializeEvalAutomorphismKey, id);
     }
     return false;
 }
 bool SerializeEvalAutomorphismKeyToFile(const std::string& automorphismKeyLocation,
     const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
 {
-    const auto close = [](std::ofstream* const ofs){ if (ofs->is_open()) { ofs->close(); } };
-    const std::unique_ptr<std::ofstream, decltype(close)> ofs(
-        new std::ofstream(automorphismKeyLocation, std::ios::out | std::ios::binary), close);
-
-    if (ofs->is_open())
+    if (serialMode == SerialMode::BINARY)
+    {
+        return SerialDeserial<lbcrypto::SerType::SERBINARY, std::ostream, std::ofstream>(
+            automorphismKeyLocation, CryptoContextImpl::SerializeEvalAutomorphismKey,
+            cryptoContext.GetRef());
+    }
+    if (serialMode == SerialMode::JSON)
     {
-        if (serialMode == SerialMode::BINARY)
-        {
-            return CryptoContextImpl::SerializeEvalAutomorphismKey(*ofs, lbcrypto::SerType::BINARY,
-                cryptoContext.m_cryptoContextImplSharedPtr);
-        }
-        if (serialMode == SerialMode::JSON)
-        {
-            return CryptoContextImpl::SerializeEvalAutomorphismKey(*ofs, lbcrypto::SerType::JSON,
-                cryptoContext.m_cryptoContextImplSharedPtr);
-        }
+        return SerialDeserial<lbcrypto::SerType::SERJSON, std::ostream, std::ofstream>(
+            automorphismKeyLocation, CryptoContextImpl::SerializeEvalAutomorphismKey,
+            cryptoContext.GetRef());
     }
     return false;
 }
-bool SerializeEvalAutomorphismKeyByIdToFile(const std::string& automorphismKeyLocation,
-    const SerialMode serialMode, const std::string& id)
-{
-    const auto close = [](std::ofstream* const ofs){ if (ofs->is_open()) { ofs->close(); } };
-    const std::unique_ptr<std::ofstream, decltype(close)> ofs(
-        new std::ofstream(automorphismKeyLocation, std::ios::out | std::ios::binary), close);
 
-    if (ofs->is_open())
+// EvalMultKey
+bool DeserializeEvalMultKeyFromFile(const std::string& multKeyLocation,
+    const SerialMode serialMode)
+{
+    if (serialMode == SerialMode::BINARY)
+    {
+        return SerialDeserial<lbcrypto::SerType::SERBINARY, std::istream, std::ifstream>(
+            multKeyLocation, CryptoContextImpl::DeserializeEvalMultKey);
+    }
+    if (serialMode == SerialMode::JSON)
     {
-        if (serialMode == SerialMode::BINARY)
-        {
-            return CryptoContextImpl::SerializeEvalAutomorphismKey(*ofs, lbcrypto::SerType::BINARY,
-                id);
-        }
-        if (serialMode == SerialMode::JSON)
-        {
-            return CryptoContextImpl::SerializeEvalAutomorphismKey(*ofs, lbcrypto::SerType::JSON,
-                id);
-        }
+        return SerialDeserial<lbcrypto::SerType::SERJSON, std::istream, std::ifstream>(
+            multKeyLocation, CryptoContextImpl::DeserializeEvalMultKey);
     }
     return false;
 }
-bool DeserializeEvalAutomorphismKeyFromFile(const std::string& automorphismKeyLocation,
-    const SerialMode serialMode)
+bool SerializeEvalMultKeyByIdToFile(const std::string& multKeyLocation,
+    const SerialMode serialMode, const std::string& id)
 {
-    const auto close = [](std::ifstream* const ifs){ if (ifs->is_open()) { ifs->close(); } };
-    const std::unique_ptr<std::ifstream, decltype(close)> ifs(
-        new std::ifstream(automorphismKeyLocation, std::ios::in | std::ios::binary), close);
-
-    if (ifs->is_open())
+    if (serialMode == SerialMode::BINARY)
+    {
+        return SerialDeserial<lbcrypto::SerType::SERBINARY, std::ostream, std::ofstream>(
+            multKeyLocation, CryptoContextImpl::SerializeEvalMultKey, id);
+    }
+    if (serialMode == SerialMode::JSON)
     {
-        if (serialMode == SerialMode::BINARY)
-        {
-            return CryptoContextImpl::DeserializeEvalAutomorphismKey(*ifs,
-                lbcrypto::SerType::BINARY);
-        }
-        if (serialMode == SerialMode::JSON)
-        {
-            return CryptoContextImpl::DeserializeEvalAutomorphismKey(*ifs,
-                lbcrypto::SerType::JSON);
-        }
+        return SerialDeserial<lbcrypto::SerType::SERJSON, std::ostream, std::ofstream>(
+            multKeyLocation, CryptoContextImpl::SerializeEvalMultKey, id);
     }
     return false;
 }
-bool SerializePublicKeyToFile(const std::string& publicKeyLocation,
-    const PublicKeyDCRTPoly& publicKey, const SerialMode serialMode)
+bool SerializeEvalMultKeyToFile(const std::string& multKeyLocation,
+    const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
 {
     if (serialMode == SerialMode::BINARY)
     {
-        return lbcrypto::Serial::SerializeToFile(publicKeyLocation,
-            publicKey.m_publicKey, lbcrypto::SerType::BINARY);
+        return SerialDeserial<lbcrypto::SerType::SERBINARY, std::ostream, std::ofstream>(
+            multKeyLocation, CryptoContextImpl::SerializeEvalMultKey, cryptoContext.GetRef());
     }
     if (serialMode == SerialMode::JSON)
     {
-        return lbcrypto::Serial::SerializeToFile(publicKeyLocation,
-            publicKey.m_publicKey, lbcrypto::SerType::JSON);
+        return SerialDeserial<lbcrypto::SerType::SERJSON, std::ostream, std::ofstream>(
+            multKeyLocation, CryptoContextImpl::SerializeEvalMultKey, cryptoContext.GetRef());
     }
     return false;
 }
-bool DeserializePublicKeyFromFile(const std::string& publicKeyLocation,
-    PublicKeyDCRTPoly& publicKey, const SerialMode serialMode)
+
+// EvalSumKey
+bool DeserializeEvalSumKeyFromFile(const std::string& sumKeyLocation, const SerialMode serialMode)
 {
     if (serialMode == SerialMode::BINARY)
     {
-        return lbcrypto::Serial::DeserializeFromFile(publicKeyLocation,
-            publicKey.m_publicKey, lbcrypto::SerType::BINARY);
+        return SerialDeserial<lbcrypto::SerType::SERBINARY, std::istream, std::ifstream>(
+            sumKeyLocation, CryptoContextImpl::DeserializeEvalAutomorphismKey);
     }
     if (serialMode == SerialMode::JSON)
     {
-        return lbcrypto::Serial::DeserializeFromFile(publicKeyLocation,
-            publicKey.m_publicKey, lbcrypto::SerType::JSON);
+        return SerialDeserial<lbcrypto::SerType::SERJSON, std::istream, std::ifstream>(
+            sumKeyLocation, CryptoContextImpl::DeserializeEvalAutomorphismKey);
     }
     return false;
 }
-bool SerializeCiphertextToFile(const std::string& ciphertextLocation,
-    const CiphertextDCRTPoly& ciphertext, const SerialMode serialMode)
+bool SerializeEvalSumKeyByIdToFile(const std::string& sumKeyLocation,
+    const SerialMode serialMode, const std::string& id)
 {
     if (serialMode == SerialMode::BINARY)
     {
-        return lbcrypto::Serial::SerializeToFile(ciphertextLocation,
-            ciphertext.m_ciphertext, lbcrypto::SerType::BINARY);
+        return SerialDeserial<lbcrypto::SerType::SERBINARY, std::ostream, std::ofstream>(
+            sumKeyLocation, CryptoContextImpl::SerializeEvalSumKey, id);
     }
     if (serialMode == SerialMode::JSON)
     {
-        return lbcrypto::Serial::SerializeToFile(ciphertextLocation,
-            ciphertext.m_ciphertext, lbcrypto::SerType::JSON);
+        return SerialDeserial<lbcrypto::SerType::SERJSON, std::ostream, std::ofstream>(
+            sumKeyLocation, CryptoContextImpl::SerializeEvalSumKey, id);
     }
     return false;
 }
-bool DeserializeCiphertextFromFile(const std::string& ciphertextLocation,
-    CiphertextDCRTPoly& ciphertext, const SerialMode serialMode)
+bool SerializeEvalSumKeyToFile(const std::string& sumKeyLocation,
+    const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
 {
     if (serialMode == SerialMode::BINARY)
     {
-        return lbcrypto::Serial::DeserializeFromFile(ciphertextLocation,
-            ciphertext.m_ciphertext, lbcrypto::SerType::BINARY);
+        return SerialDeserial<lbcrypto::SerType::SERBINARY, std::ostream, std::ofstream>(
+            sumKeyLocation, CryptoContextImpl::SerializeEvalAutomorphismKey,
+            cryptoContext.GetRef());
     }
     if (serialMode == SerialMode::JSON)
     {
-        return lbcrypto::Serial::DeserializeFromFile(ciphertextLocation,
-            ciphertext.m_ciphertext, lbcrypto::SerType::JSON);
+        return SerialDeserial<lbcrypto::SerType::SERJSON, std::ostream, std::ofstream>(
+            sumKeyLocation, CryptoContextImpl::SerializeEvalAutomorphismKey,
+            cryptoContext.GetRef());
     }
     return false;
 }
 
+// PublicKey
+bool DeserializePublicKeyFromFile(const std::string& publicKeyLocation,
+    PublicKeyDCRTPoly& publicKey, const SerialMode serialMode)
+{
+    return Deserial(publicKeyLocation, publicKey, serialMode);
+}
+bool SerializePublicKeyToFile(const std::string& publicKeyLocation,
+    const PublicKeyDCRTPoly& publicKey, const SerialMode serialMode)
+{
+    return Serial(publicKeyLocation, publicKey, serialMode);
+}
+
 } // openfhe

+ 35 - 22
src/SerialDeserial.h

@@ -11,34 +11,47 @@ class CiphertextDCRTPoly;
 class CryptoContextDCRTPoly;
 class PublicKeyDCRTPoly;
 
-bool SerializeCryptoContextToFile(const std::string& ccLocation,
-    const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
-bool DeserializeCryptoContextFromFile(const std::string& ccLocation,
+// Ciphertext
+[[nodiscard]] bool DeserializeCiphertextFromFile(const std::string& ciphertextLocation,
+    CiphertextDCRTPoly& ciphertext, const SerialMode serialMode);
+[[nodiscard]] bool SerializeCiphertextToFile(const std::string& ciphertextLocation,
+    const CiphertextDCRTPoly& ciphertext, const SerialMode serialMode);
+
+// CryptoContextDCRTPoly
+[[nodiscard]] bool DeserializeCryptoContextFromFile(const std::string& ccLocation,
     CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
-bool SerializeEvalMultKeyToFile(const std::string& multKeyLocation,
+[[nodiscard]] bool SerializeCryptoContextToFile(const std::string& ccLocation,
     const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
-bool SerializeEvalMultKeyByIdToFile(const std::string& multKeyLocation,
-    const SerialMode serialMode, const std::string& id);
-bool DeserializeEvalMultKeyFromFile(const std::string& multKeyLocation,
+
+// EvalAutomorphismKey
+[[nodiscard]] bool DeserializeEvalMultKeyFromFile(const std::string& multKeyLocation,
     const SerialMode serialMode);
-bool SerializeEvalSumKeyToFile(const std::string& sumKeyLocation,
-    const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
-bool SerializeEvalSumKeyByIdToFile(const std::string& sumKeyLocation,
+[[nodiscard]] bool SerializeEvalMultKeyByIdToFile(const std::string& multKeyLocation,
     const SerialMode serialMode, const std::string& id);
-bool DeserializeEvalSumKeyFromFile(const std::string& sumKeyLocation, const SerialMode serialMode);
-bool SerializeEvalAutomorphismKeyToFile(const std::string& automorphismKeyLocation,
+[[nodiscard]] bool SerializeEvalMultKeyToFile(const std::string& multKeyLocation,
     const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
-bool SerializeEvalAutomorphismKeyByIdToFile(const std::string& automorphismKeyLocation,
-    const SerialMode serialMode, const std::string& id);
-bool DeserializeEvalAutomorphismKeyFromFile(const std::string& automorphismKeyLocation,
+
+// EvalMultKey
+[[nodiscard]] bool DeserializeEvalAutomorphismKeyFromFile(
+    const std::string& automorphismKeyLocation, const SerialMode serialMode);
+[[nodiscard]] bool SerializeEvalAutomorphismKeyByIdToFile(
+    const std::string& automorphismKeyLocation, const SerialMode serialMode,
+    const std::string& id);
+[[nodiscard]] bool SerializeEvalAutomorphismKeyToFile(const std::string& automorphismKeyLocation,
+    const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
+
+// EvalSumKey
+[[nodiscard]] bool DeserializeEvalSumKeyFromFile(const std::string& sumKeyLocation,
     const SerialMode serialMode);
-bool SerializePublicKeyToFile(const std::string& publicKeyLocation,
-    const PublicKeyDCRTPoly& publicKey, const SerialMode serialMode);
-bool DeserializePublicKeyFromFile(const std::string& publicKeyLocation,
+[[nodiscard]] bool SerializeEvalSumKeyByIdToFile(const std::string& sumKeyLocation,
+    const SerialMode serialMode, const std::string& id);
+[[nodiscard]] bool SerializeEvalSumKeyToFile(const std::string& sumKeyLocation,
+    const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode);
+
+// PublicKey
+[[nodiscard]] bool DeserializePublicKeyFromFile(const std::string& publicKeyLocation,
     PublicKeyDCRTPoly& publicKey, const SerialMode serialMode);
-bool SerializeCiphertextToFile(const std::string& ciphertextLocation,
-    const CiphertextDCRTPoly& ciphertext, const SerialMode serialMode);
-bool DeserializeCiphertextFromFile(const std::string& ciphertextLocation,
-    CiphertextDCRTPoly& ciphertext, const SerialMode serialMode);
+[[nodiscard]] bool SerializePublicKeyToFile(const std::string& publicKeyLocation,
+    const PublicKeyDCRTPoly& publicKey, const SerialMode serialMode);
 
 } // openfhe