瀏覽代碼

Merge pull request #155 from openfheorg/td-issue-145

#145 Add Serialize/Deserialize functions
Thadah D. Denyse 1 年之前
父節點
當前提交
b7ee5b0255
共有 3 個文件被更改,包括 147 次插入2 次删除
  1. 13 0
      src/include/pke/serialization.h
  2. 79 2
      src/lib/pke/serialization.cpp
  3. 55 0
      tests/test_serial_cc.py

+ 13 - 0
src/include/pke/serialization.h

@@ -33,6 +33,7 @@
 #include "bindings.h"
 
 using namespace lbcrypto;
+namespace py = pybind11;
 
 template <typename ST>
 bool SerializeEvalMultKeyWrapper(const std::string& filename, const ST& sertype, std::string id);
@@ -43,4 +44,16 @@ bool SerializeEvalAutomorphismKeyWrapper(const std::string& filename, const ST&
 template <typename ST>
 bool DeserializeEvalMultKeyWrapper(const std::string& filename, const ST& sertype);
 
+template <typename T, typename ST>
+std::string SerializeToStringWrapper(const T& obj, const ST& sertype);
+
+template <typename T, typename ST>
+py::bytes SerializeToBytesWrapper(const T& obj, const ST& sertype);
+
+template <typename T, typename ST>
+T DeserializeFromStringWrapper(const std::string& str, const ST& sertype);
+
+template <typename T, typename ST>
+T DeserializeFromBytesWrapper(const py::bytes& bytes, const ST& sertype);
+
 #endif // OPENFHE_SERIALIZATION_BINDINGS_H

+ 79 - 2
src/lib/pke/serialization.cpp

@@ -88,6 +88,39 @@ std::tuple<CryptoContext<DCRTPoly>, bool> DeserializeCCWrapper(const std::string
     return std::make_tuple(newob, result);
 }
 
+template <typename T, typename ST>
+std::string SerializeToStringWrapper(const T& obj, const ST& sertype) {
+    std::ostringstream oss;
+    Serial::Serialize<T>(obj, oss, sertype);
+    return oss.str();
+}
+
+template <typename T, typename ST>
+py::bytes SerializeToBytesWrapper(const T& obj, const ST& sertype) {
+    std::ostringstream oss(std::ios::binary);
+    Serial::Serialize<T>(obj, oss, sertype);
+    std::string str = oss.str();
+    return py::bytes(str);
+}
+
+template <typename T, typename ST>
+T DeserializeFromStringWrapper(const std::string& str, const ST& sertype) {
+    T obj;
+    std::istringstream iss(str);
+    Serial::Deserialize<T>(obj, iss, sertype);
+    return obj;
+}
+
+template <typename T, typename ST>
+T DeserializeFromBytesWrapper(const py::bytes& bytes, const ST& sertype) {
+    T obj;
+    std::string str(bytes);
+    std::istringstream iss(str, std::ios::binary);
+    Serial::Deserialize<T>(obj, iss, sertype);
+    return obj;
+}
+
+
 void bind_serialization(pybind11::module &m) {
     // Json Serialization
     m.def("SerializeToFile", static_cast<bool (*)(const std::string &, const CryptoContext<DCRTPoly> &, const SerType::SERJSON &)>(&Serial::SerializeToFile<DCRTPoly>),
@@ -110,6 +143,29 @@ void bind_serialization(pybind11::module &m) {
           py::arg("filename"), py::arg("obj"), py::arg("sertype"));
     m.def("DeserializeEvalKey", static_cast<std::tuple<EvalKey<DCRTPoly>,bool> (*)(const std::string&, const SerType::SERJSON&)>(&DeserializeFromFileWrapper<EvalKey<DCRTPoly>, SerType::SERJSON>),
           py::arg("filename"), py::arg("sertype"));
+
+    // JSON Serialization to string
+    m.def("Serialize", &SerializeToStringWrapper<CryptoContext<DCRTPoly>, SerType::SERJSON>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("DeserializeCryptoContextString", &DeserializeFromStringWrapper<CryptoContext<DCRTPoly>, SerType::SERJSON>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToStringWrapper<PublicKey<DCRTPoly>, SerType::SERJSON>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("DeserializePublicKeyString", &DeserializeFromStringWrapper<PublicKey<DCRTPoly>, SerType::SERJSON>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToStringWrapper<PrivateKey<DCRTPoly>, SerType::SERJSON>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("DeserializePrivateKeyString", &DeserializeFromStringWrapper<PrivateKey<DCRTPoly>, SerType::SERJSON>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToStringWrapper<Ciphertext<DCRTPoly>, SerType::SERJSON>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("DeserializeCiphertextString", &DeserializeFromStringWrapper<Ciphertext<DCRTPoly>, SerType::SERJSON>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToStringWrapper<EvalKey<DCRTPoly>, SerType::SERJSON>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("DeserializeEvalKeyString", &DeserializeFromStringWrapper<EvalKey<DCRTPoly>, SerType::SERJSON>,
+          py::arg("str"), py::arg("sertype"));
+
     // Binary Serialization
     m.def("SerializeToFile", static_cast<bool (*)(const std::string&,const CryptoContext<DCRTPoly>&, const SerType::SERBINARY&)>(&Serial::SerializeToFile<DCRTPoly>),
           py::arg("filename"), py::arg("obj"), py::arg("sertype"));
@@ -130,7 +186,28 @@ void bind_serialization(pybind11::module &m) {
     m.def("SerializeToFile", static_cast<bool (*)(const std::string&, const EvalKey<DCRTPoly>&, const SerType::SERBINARY&)>(&Serial::SerializeToFile<EvalKey<DCRTPoly>>),
           py::arg("filename"), py::arg("obj"), py::arg("sertype"));
     m.def("DeserializeEvalKey", static_cast<std::tuple<EvalKey<DCRTPoly>,bool> (*)(const std::string&, const SerType::SERBINARY&)>(&DeserializeFromFileWrapper<EvalKey<DCRTPoly>, SerType::SERBINARY>),
-            py::arg("filename"), py::arg("sertype"));  
-    
+          py::arg("filename"), py::arg("sertype"));
+
+    // Binary Serialization to bytes
+    m.def("Serialize", &SerializeToBytesWrapper<CryptoContext<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("DeserializeCryptoContextString", &DeserializeFromBytesWrapper<CryptoContext<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToBytesWrapper<PublicKey<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("DeserializePublicKeyString", &DeserializeFromBytesWrapper<PublicKey<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToBytesWrapper<PrivateKey<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("DeserializePrivateKeyString", &DeserializeFromBytesWrapper<PrivateKey<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToBytesWrapper<Ciphertext<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("DeserializeCiphertextString", &DeserializeFromBytesWrapper<Ciphertext<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToBytesWrapper<EvalKey<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("DeserializeEvalKeyString", &DeserializeFromBytesWrapper<EvalKey<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("str"), py::arg("sertype"));
 }
 

+ 55 - 0
tests/test_serial_cc.py

@@ -1,4 +1,5 @@
 import logging
+import pytest
 
 import openfhe as fhe
 
@@ -37,3 +38,57 @@ def test_serial_cryptocontext(tmp_path):
     assert isinstance(ct1, fhe.Ciphertext)
     LOGGER.debug("Cryptocontext deserializes to %s %s", success, ct1)
     assert fhe.SerializeToFile(str(tmp_path / "ciphertext12.json"), ct1, fhe.JSON)
+
+
+@pytest.mark.parametrize("mode", [fhe.JSON, fhe.BINARY])
+def test_serial_cryptocontext_str(mode):
+    parameters = fhe.CCParamsBFVRNS()
+    parameters.SetPlaintextModulus(65537)
+    parameters.SetMultiplicativeDepth(2)
+
+    cryptoContext = fhe.GenCryptoContext(parameters)
+    cryptoContext.Enable(fhe.PKESchemeFeature.PKE)
+    cryptoContext.Enable(fhe.PKESchemeFeature.PRE)
+
+    keypair = cryptoContext.KeyGen()
+    vectorOfInts = list(range(12))
+    plaintext = cryptoContext.MakePackedPlaintext(vectorOfInts)
+    ciphertext = cryptoContext.Encrypt(keypair.publicKey, plaintext)
+    evalKey = cryptoContext.ReKeyGen(keypair.secretKey, keypair.publicKey)
+
+
+    cryptoContext_ser = fhe.Serialize(cryptoContext, mode)
+    LOGGER.debug("The cryptocontext has been serialized.")
+    publickey_ser = fhe.Serialize(keypair.publicKey, mode)
+    LOGGER.debug("The public key has been serialized.")
+    secretkey_ser = fhe.Serialize(keypair.secretKey, mode)
+    LOGGER.debug("The private key has been serialized.")
+    ciphertext_ser = fhe.Serialize(ciphertext, mode)
+    LOGGER.debug("The ciphertext has been serialized.")
+    evalKey_ser = fhe.Serialize(evalKey, mode)
+    LOGGER.debug("The evaluation key has been serialized.")
+
+
+    cryptoContext.ClearEvalMultKeys()
+    cryptoContext.ClearEvalAutomorphismKeys()
+    fhe.ReleaseAllContexts()
+
+    cc = fhe.DeserializeCryptoContextString(cryptoContext_ser, mode)
+    assert isinstance(cc, fhe.CryptoContext)
+    LOGGER.debug("The cryptocontext has been deserialized.")
+
+    pk = fhe.DeserializePublicKeyString(publickey_ser, mode)
+    assert isinstance(pk, fhe.PublicKey)
+    LOGGER.debug("The public key has been deserialized.")
+
+    sk = fhe.DeserializePrivateKeyString(secretkey_ser, mode)
+    assert isinstance(sk, fhe.PrivateKey)
+    LOGGER.debug("The private key has been deserialized.")
+
+    ct = fhe.DeserializeCiphertextString(ciphertext_ser, mode)
+    assert isinstance(ct, fhe.Ciphertext)
+    LOGGER.debug("The ciphertext has been reserialized.")
+
+    ek = fhe.DeserializeEvalKeyString(evalKey_ser, mode)
+    assert isinstance(ek, fhe.EvalKey)
+    LOGGER.debug("The evaluation key has been deserialized.")