Browse Source

Add Serialize/Deserialize functions

Thadah D. Denyse 1 year ago
parent
commit
b900fd08c9
1 changed files with 79 additions and 2 deletions
  1. 79 2
      src/lib/pke/serialization.cpp

+ 79 - 2
src/lib/pke/serialization.cpp

@@ -88,6 +88,39 @@ std::tuple<CryptoContext<DCRTPoly>, bool> DeserializeCCWrapper(const std::string
     return std::make_tuple(newob, result);
 }
 
+template <typename T, typename ST>
+std::string SerializeToStringWrapper(const T& obj, const ST& sertype) {
+    std::ostringstream oss;
+    Serial::Serialize<T>(obj, oss, sertype);
+    return oss.str();
+}
+
+template <typename T, typename ST>
+py::bytes SerializeToBytesWrapper(const T& obj, const ST& sertype) {
+    std::ostringstream oss(std::ios::binary);
+    Serial::Serialize<T>(obj, oss, sertype);
+    std::string str = oss.str();
+    return py::bytes(str);
+}
+
+template <typename T, typename ST>
+T DeserializeFromStringWrapper(const std::string& str, const ST& sertype) {
+    T obj;
+    std::istringstream iss(str);
+    Serial::Deserialize<T>(obj, iss, sertype);
+    return obj;
+}
+
+template <typename T, typename ST>
+T DeserializeFromBytesWrapper(const py::bytes& bytes, const ST& sertype) {
+    T obj;
+    std::string str(bytes);
+    std::istringstream iss(str, std::ios::binary);
+    Serial::Deserialize<T>(obj, iss, sertype);
+    return obj;
+}
+
+
 void bind_serialization(pybind11::module &m) {
     // Json Serialization
     m.def("SerializeToFile", static_cast<bool (*)(const std::string &, const CryptoContext<DCRTPoly> &, const SerType::SERJSON &)>(&Serial::SerializeToFile<DCRTPoly>),
@@ -110,6 +143,29 @@ void bind_serialization(pybind11::module &m) {
           py::arg("filename"), py::arg("obj"), py::arg("sertype"));
     m.def("DeserializeEvalKey", static_cast<std::tuple<EvalKey<DCRTPoly>,bool> (*)(const std::string&, const SerType::SERJSON&)>(&DeserializeFromFileWrapper<EvalKey<DCRTPoly>, SerType::SERJSON>),
           py::arg("filename"), py::arg("sertype"));
+
+    // JSON Serialization to string
+    m.def("Serialize", &SerializeToStringWrapper<CryptoContext<DCRTPoly>, SerType::SERJSON>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("Deserialize", &DeserializeFromStringWrapper<CryptoContext<DCRTPoly>, SerType::SERJSON>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToStringWrapper<PublicKey<DCRTPoly>, SerType::SERJSON>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("Deserialize", &DeserializeFromStringWrapper<PublicKey<DCRTPoly>, SerType::SERJSON>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToStringWrapper<PrivateKey<DCRTPoly>, SerType::SERJSON>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("Deserialize", &DeserializeFromStringWrapper<PrivateKey<DCRTPoly>, SerType::SERJSON>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToStringWrapper<Ciphertext<DCRTPoly>, SerType::SERJSON>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("Deserialize", &DeserializeFromStringWrapper<Ciphertext<DCRTPoly>, SerType::SERJSON>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToStringWrapper<EvalKey<DCRTPoly>, SerType::SERJSON>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("Deserialize", &DeserializeFromStringWrapper<EvalKey<DCRTPoly>, SerType::SERJSON>,
+          py::arg("str"), py::arg("sertype"));
+
     // Binary Serialization
     m.def("SerializeToFile", static_cast<bool (*)(const std::string&,const CryptoContext<DCRTPoly>&, const SerType::SERBINARY&)>(&Serial::SerializeToFile<DCRTPoly>),
           py::arg("filename"), py::arg("obj"), py::arg("sertype"));
@@ -130,7 +186,28 @@ void bind_serialization(pybind11::module &m) {
     m.def("SerializeToFile", static_cast<bool (*)(const std::string&, const EvalKey<DCRTPoly>&, const SerType::SERBINARY&)>(&Serial::SerializeToFile<EvalKey<DCRTPoly>>),
           py::arg("filename"), py::arg("obj"), py::arg("sertype"));
     m.def("DeserializeEvalKey", static_cast<std::tuple<EvalKey<DCRTPoly>,bool> (*)(const std::string&, const SerType::SERBINARY&)>(&DeserializeFromFileWrapper<EvalKey<DCRTPoly>, SerType::SERBINARY>),
-            py::arg("filename"), py::arg("sertype"));  
-    
+          py::arg("filename"), py::arg("sertype"));
+
+    // Binary Serialization to bytes
+    m.def("Serialize", &SerializeToBytesWrapper<CryptoContext<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("Deserialize", &DeserializeFromBytesWrapper<CryptoContext<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToBytesWrapper<PublicKey<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("Deserialize", &DeserializeFromBytesWrapper<PublicKey<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToBytesWrapper<PrivateKey<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("Deserialize", &DeserializeFromBytesWrapper<PrivateKey<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToBytesWrapper<Ciphertext<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("Deserialize", &DeserializeFromBytesWrapper<Ciphertext<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToBytesWrapper<EvalKey<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("Deserialize", &DeserializeFromBytesWrapper<EvalKey<DCRTPoly>, SerType::SERBINARY>,
+          py::arg("str"), py::arg("sertype"));
 }