|
|
@@ -35,13 +35,9 @@
|
|
|
|
|
|
#include "openfhe.h"
|
|
|
// header files needed for serialization
|
|
|
-#include "metadata-ser.h"
|
|
|
#include "ciphertext-ser.h"
|
|
|
#include "cryptocontext-ser.h"
|
|
|
#include "key/key-ser.h"
|
|
|
-#include "scheme/bfvrns/bfvrns-ser.h"
|
|
|
-#include "scheme/bgvrns/bgvrns-ser.h"
|
|
|
-#include "scheme/ckksrns/ckksrns-ser.h"
|
|
|
|
|
|
using namespace lbcrypto;
|
|
|
namespace py = pybind11;
|
|
|
@@ -51,8 +47,7 @@ PYBIND11_MAKE_OPAQUE(std::map<uint32_t, EvalKey<DCRTPoly>>);
|
|
|
|
|
|
|
|
|
template <typename ST>
|
|
|
-bool SerializeEvalMultKeyWrapper(const std::string &filename, const ST &sertype, std::string id)
|
|
|
-{
|
|
|
+bool SerializeEvalMultKeyWrapper(const std::string& filename, const ST& sertype, std::string id) {
|
|
|
std::ofstream outfile(filename, std::ios::out | std::ios::binary);
|
|
|
bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey<ST>(outfile, sertype, id);
|
|
|
outfile.close();
|
|
|
@@ -60,8 +55,7 @@ bool SerializeEvalMultKeyWrapper(const std::string &filename, const ST &sertype,
|
|
|
}
|
|
|
|
|
|
template <typename ST>
|
|
|
-bool SerializeEvalAutomorphismKeyWrapper(const std::string& filename, const ST& sertype, std::string id)
|
|
|
-{
|
|
|
+bool SerializeEvalAutomorphismKeyWrapper(const std::string& filename, const ST& sertype, std::string id) {
|
|
|
std::ofstream outfile(filename, std::ios::out | std::ios::binary);
|
|
|
bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey<ST>(outfile, sertype, id);
|
|
|
outfile.close();
|
|
|
@@ -69,15 +63,14 @@ bool SerializeEvalAutomorphismKeyWrapper(const std::string& filename, const ST&
|
|
|
}
|
|
|
|
|
|
template <typename ST>
|
|
|
-bool DeserializeEvalMultKeyWrapper(const std::string &filename, const ST &sertype)
|
|
|
-{
|
|
|
+bool DeserializeEvalMultKeyWrapper(const std::string& filename, const ST& sertype) {
|
|
|
std::ifstream emkeys(filename, std::ios::in | std::ios::binary);
|
|
|
- if (!emkeys.is_open())
|
|
|
- {
|
|
|
+ if (!emkeys.is_open()) {
|
|
|
std::cerr << "I cannot read serialization from " << filename << std::endl;
|
|
|
}
|
|
|
bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<ST>(emkeys, sertype);
|
|
|
- return res; }
|
|
|
+ return res;
|
|
|
+}
|
|
|
|
|
|
template <typename T, typename ST>
|
|
|
std::tuple<T, bool> DeserializeFromFileWrapper(const std::string& filename, const ST& sertype) {
|
|
|
@@ -101,10 +94,14 @@ std::string SerializeToStringWrapper(const T& obj, const ST& sertype) {
|
|
|
|
|
|
template <typename T, typename ST>
|
|
|
py::bytes SerializeToBytesWrapper(const T& obj, const ST& sertype) {
|
|
|
- std::ostringstream oss(std::ios::binary);
|
|
|
+ // let strbuf be dynamically allocated as we may be dealing with large keys
|
|
|
+ auto strbuf = std::make_unique<std::stringbuf>(std::ios::out | std::ios::binary);
|
|
|
+ std::ostream oss(strbuf.get());
|
|
|
+
|
|
|
Serial::Serialize<T>(obj, oss, sertype);
|
|
|
- std::string str = oss.str();
|
|
|
- return py::bytes(str);
|
|
|
+
|
|
|
+ const std::string& str = strbuf->str();
|
|
|
+ return py::bytes(str.data(), str.size());
|
|
|
}
|
|
|
|
|
|
template <typename T, typename ST>
|
|
|
@@ -125,18 +122,20 @@ CryptoContext<DCRTPoly> DeserializeCCFromStringWrapper(const std::string& str, c
|
|
|
|
|
|
template <typename T, typename ST>
|
|
|
T DeserializeFromBytesWrapper(const py::bytes& bytes, const ST& sertype) {
|
|
|
- T obj;
|
|
|
- std::string str(bytes);
|
|
|
+ std::string str{static_cast<std::string>(bytes)};
|
|
|
std::istringstream iss(str, std::ios::binary);
|
|
|
+
|
|
|
+ T obj;
|
|
|
Serial::Deserialize<T>(obj, iss, sertype);
|
|
|
return obj;
|
|
|
}
|
|
|
|
|
|
template <typename ST>
|
|
|
CryptoContext<DCRTPoly> DeserializeCCFromBytesWrapper(const py::bytes& bytes, const ST& sertype) {
|
|
|
- CryptoContext<DCRTPoly> obj;
|
|
|
- std::string str(bytes);
|
|
|
+ std::string str{static_cast<std::string>(bytes)};
|
|
|
std::istringstream iss(str, std::ios::binary);
|
|
|
+
|
|
|
+ CryptoContext<DCRTPoly> obj;
|
|
|
Serial::Deserialize<DCRTPoly>(obj, iss, sertype);
|
|
|
return obj;
|
|
|
}
|
|
|
@@ -153,15 +152,17 @@ std::string SerializeEvalMultKeyToStringWrapper(const ST& sertype, const std::st
|
|
|
|
|
|
template <typename ST>
|
|
|
py::bytes SerializeEvalMultKeyToBytesWrapper(const ST& sertype, const std::string& id) {
|
|
|
- std::ostringstream oss(std::ios::binary);
|
|
|
- bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey(oss, sertype, id);
|
|
|
- if (!res) {
|
|
|
+ // let strbuf be dynamically allocated as we may be dealing with large keys
|
|
|
+ auto strbuf = std::make_unique<std::stringbuf>(std::ios::out | std::ios::binary);
|
|
|
+ std::ostream oss(strbuf.get());
|
|
|
+
|
|
|
+ if (!CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey(oss, sertype, id)) {
|
|
|
throw std::runtime_error("Failed to serialize EvalMultKey");
|
|
|
}
|
|
|
- std::string str = oss.str();
|
|
|
- return py::bytes(str);
|
|
|
-}
|
|
|
|
|
|
+ const std::string& str = strbuf->str();
|
|
|
+ return py::bytes(str.data(), str.size());
|
|
|
+}
|
|
|
|
|
|
template <typename ST>
|
|
|
std::string SerializeEvalAutomorphismKeyToStringWrapper(const ST& sertype, const std::string& id) {
|
|
|
@@ -173,15 +174,18 @@ std::string SerializeEvalAutomorphismKeyToStringWrapper(const ST& sertype, const
|
|
|
return oss.str();
|
|
|
}
|
|
|
|
|
|
-
|
|
|
template <typename ST>
|
|
|
py::bytes SerializeEvalAutomorphismKeyToBytesWrapper(const ST& sertype, const std::string& id) {
|
|
|
- std::ostringstream oss(std::ios::binary);
|
|
|
- bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey(oss, sertype, id);
|
|
|
- if (!res) {
|
|
|
+ // let strbuf be dynamically allocated as we may be dealing with large keys
|
|
|
+ auto strbuf = std::make_unique<std::stringbuf>(std::ios::out | std::ios::binary);
|
|
|
+ std::ostream oss(strbuf.get());
|
|
|
+
|
|
|
+ if (!CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey(oss, sertype, id)) {
|
|
|
throw std::runtime_error("Failed to serialize EvalAutomorphismKey");
|
|
|
}
|
|
|
- return oss.str();
|
|
|
+
|
|
|
+ const std::string& str = strbuf->str();
|
|
|
+ return py::bytes(str.data(), str.size());
|
|
|
}
|
|
|
|
|
|
template <typename ST>
|
|
|
@@ -194,11 +198,11 @@ void DeserializeEvalMultKeyFromStringWrapper(const std::string& data, const ST&
|
|
|
}
|
|
|
|
|
|
template <typename ST>
|
|
|
-void DeserializeEvalMultKeyFromBytesWrapper(const std::string& data, const ST& sertype) {
|
|
|
- std::string str(data);
|
|
|
+void DeserializeEvalMultKeyFromBytesWrapper(const py::bytes& bytes, const ST& sertype) {
|
|
|
+ std::string str{static_cast<std::string>(bytes)};
|
|
|
std::istringstream iss(str, std::ios::binary);
|
|
|
- bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<ST>(iss, sertype);
|
|
|
- if (!res) {
|
|
|
+
|
|
|
+ if (!CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<ST>(iss, sertype)) {
|
|
|
throw std::runtime_error("Failed to deserialize EvalMultKey");
|
|
|
}
|
|
|
}
|
|
|
@@ -214,11 +218,11 @@ void DeserializeEvalAutomorphismKeyFromStringWrapper(const std::string& data, co
|
|
|
}
|
|
|
|
|
|
template <typename ST>
|
|
|
-void DeserializeEvalAutomorphismKeyFromBytesWrapper(const std::string& data, const ST& sertype) {
|
|
|
- std::string str(data);
|
|
|
+void DeserializeEvalAutomorphismKeyFromBytesWrapper(const py::bytes& bytes, const ST& sertype) {
|
|
|
+ std::string str{static_cast<std::string>(bytes)};
|
|
|
std::istringstream iss(str, std::ios::binary);
|
|
|
- bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<ST>(iss, sertype);
|
|
|
- if (!res) {
|
|
|
+
|
|
|
+ if (!CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<ST>(iss, sertype)) {
|
|
|
throw std::runtime_error("Failed to deserialize EvalAutomorphismKey");
|
|
|
}
|
|
|
}
|
|
|
@@ -272,17 +276,19 @@ void bind_serialization(pybind11::module &m) {
|
|
|
m.def("DeserializeEvalKeyString", &DeserializeFromStringWrapper<EvalKey<DCRTPoly>, SerType::SERJSON>,
|
|
|
py::arg("str"), py::arg("sertype"));
|
|
|
m.def("Serialize", &SerializeToBytesWrapper<std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>>, SerType::SERJSON>,
|
|
|
- py::arg("obj"), py::arg("sertype"));
|
|
|
+ py::arg("obj"), py::arg("sertype"));
|
|
|
m.def("DeserializeEvalKeyMapString", &DeserializeFromBytesWrapper<std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>>, SerType::SERJSON>,
|
|
|
- py::arg("str"), py::arg("sertype"));
|
|
|
+ py::arg("str"), py::arg("sertype"));
|
|
|
|
|
|
m.def("SerializeEvalMultKeyString", &SerializeEvalMultKeyToStringWrapper<SerType::SERJSON>,
|
|
|
py::arg("sertype"), py::arg("id") = "");
|
|
|
- m.def("DeserializeEvalMultKeyString", &DeserializeEvalMultKeyFromStringWrapper<SerType::SERJSON>,
|
|
|
+ m.def("DeserializeEvalMultKeyString",
|
|
|
+ static_cast<void (*)(const std::string&, const SerType::SERJSON&)>(&DeserializeEvalMultKeyFromStringWrapper<SerType::SERJSON>),
|
|
|
py::arg("data"), py::arg("sertype"));
|
|
|
m.def("SerializeEvalAutomorphismKeyString", &SerializeEvalAutomorphismKeyToStringWrapper<SerType::SERJSON>,
|
|
|
py::arg("sertype"), py::arg("id") = "");
|
|
|
- m.def("DeserializeEvalAutomorphismKeyString", &DeserializeEvalAutomorphismKeyFromStringWrapper<SerType::SERJSON>,
|
|
|
+ m.def("DeserializeEvalAutomorphismKeyString",
|
|
|
+ static_cast<void (*)(const std::string&, const SerType::SERJSON&)>(&DeserializeEvalAutomorphismKeyFromStringWrapper<SerType::SERJSON>),
|
|
|
py::arg("data"), py::arg("sertype"));
|
|
|
|
|
|
// Binary Serialization
|
|
|
@@ -333,16 +339,18 @@ void bind_serialization(pybind11::module &m) {
|
|
|
m.def("DeserializeEvalKeyString", &DeserializeFromBytesWrapper<EvalKey<DCRTPoly>, SerType::SERBINARY>,
|
|
|
py::arg("str"), py::arg("sertype"));
|
|
|
m.def("Serialize", &SerializeToBytesWrapper<std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>>, SerType::SERBINARY>,
|
|
|
- py::arg("obj"), py::arg("sertype"));
|
|
|
+ py::arg("obj"), py::arg("sertype"));
|
|
|
m.def("DeserializeEvalKeyMapString", &DeserializeFromBytesWrapper<std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>>, SerType::SERBINARY>,
|
|
|
- py::arg("str"), py::arg("sertype"));
|
|
|
+ py::arg("str"), py::arg("sertype"));
|
|
|
|
|
|
m.def("SerializeEvalMultKeyString", &SerializeEvalMultKeyToBytesWrapper<SerType::SERBINARY>,
|
|
|
py::arg("sertype"), py::arg("id") = "");
|
|
|
- m.def("DeserializeEvalMultKeyString", &DeserializeEvalMultKeyFromBytesWrapper<SerType::SERBINARY>,
|
|
|
- py::arg("data"), py::arg("sertype"));
|
|
|
+ m.def("DeserializeEvalMultKeyString",
|
|
|
+ static_cast<void (*)(const py::bytes&, const SerType::SERBINARY&)>(&DeserializeEvalMultKeyFromBytesWrapper<SerType::SERBINARY>),
|
|
|
+ py::arg("bytes"), py::arg("sertype"));
|
|
|
m.def("SerializeEvalAutomorphismKeyString", &SerializeEvalAutomorphismKeyToBytesWrapper<SerType::SERBINARY>,
|
|
|
py::arg("sertype"), py::arg("id") = "");
|
|
|
- m.def("DeserializeEvalAutomorphismKeyString", &DeserializeEvalAutomorphismKeyFromBytesWrapper<SerType::SERBINARY>,
|
|
|
- py::arg("data"), py::arg("sertype"));
|
|
|
+ m.def("DeserializeEvalAutomorphismKeyString",
|
|
|
+ static_cast<void (*)(const py::bytes&, const SerType::SERBINARY&)>(&DeserializeEvalAutomorphismKeyFromBytesWrapper<SerType::SERBINARY>),
|
|
|
+ py::arg("bytes"), py::arg("sertype"));
|
|
|
}
|