Browse Source

directly exposing SerTypes + SerializeToFile

Rener Oliveira (Ubuntu WSL) 2 years ago
parent
commit
eb643b4f8b
4 changed files with 11 additions and 44 deletions
  1. 1 1
      include/bindings.h
  2. 7 3
      src/bindings.cpp
  3. 1 1
      src/pke/examples/simple-integers-serial.py
  4. 2 39
      src/pke/serialization.cpp

+ 1 - 1
include/bindings.h

@@ -5,7 +5,7 @@
 
 void bind_parameters(pybind11::module &m);
 void bind_crypto_context(pybind11::module &m);
-void bind_enums(pybind11::module &m);
+void bind_enums_and_constants(pybind11::module &m);
 void bind_keys(pybind11::module &m);
 void bind_encodings(pybind11::module &m);
 void bind_ciphertext(pybind11::module &m);

+ 7 - 3
src/bindings.cpp

@@ -45,7 +45,7 @@ void bind_crypto_context(py::module &m){
 
 
 
-void bind_enums(py::module &m){
+void bind_enums_and_constants(py::module &m){
     // Scheme Types
     py::enum_<SCHEME>(m, "SCHEME")
             .value("INVALID_SCHEME", SCHEME::INVALID_SCHEME)
@@ -61,6 +61,11 @@ void bind_enums(py::module &m){
             .value("ADVANCEDSHE", PKESchemeFeature::ADVANCEDSHE)
             .value("MULTIPARTY", PKESchemeFeature::MULTIPARTY)
             .value("FHE", PKESchemeFeature::FHE);
+    // Serialization Types
+    py::class_<SerType::SERJSON >(m, "SERJSON");
+    py::class_<SerType::SERBINARY>(m, "SERBINARY");
+    m.attr("JSON") = py::cast(SerType::JSON);
+    m.attr("BINARY") = py::cast(SerType::BINARY);
 }
 
 void bind_keys(py::module &m){
@@ -117,11 +122,10 @@ PYBIND11_MODULE(openfhe, m) {
     m.doc() = "Open-Source Fully Homomorphic Encryption Library";
     bind_parameters(m);
     bind_crypto_context(m);
-    bind_enums(m);
+    bind_enums_and_constants(m);
     bind_keys(m);
     bind_encodings(m);
     bind_ciphertext(m);
     bind_decryption(m);
     bind_serialization(m);
-
 }

+ 1 - 1
src/pke/examples/simple-integers-serial.py

@@ -17,7 +17,7 @@ cryptoContext.Enable(PKESchemeFeature.PKE)
 cryptoContext.Enable(PKESchemeFeature.KEYSWITCH)
 cryptoContext.Enable(PKESchemeFeature.LEVELEDSHE)
 
-if not SerializeToFile(datafolder + "/cryptocontext.txt", cryptoContext, 'binary'):
+if not SerializeToFile(datafolder + "/cryptocontext.txt", cryptoContext, JSON):
    raise Exception("Error writing serialization of the crypto context to cryptocontext.txt")
 
 

+ 2 - 39
src/pke/serialization.cpp

@@ -2,52 +2,15 @@
 #include <pybind11/stl.h>
 #include <openfhe/pke/openfhe.h>
 #include <openfhe/pke/scheme/bfvrns/bfvrns-ser.h>
+#include <openfhe/pke/cryptocontext-ser.h>
 #include "bindings.h"
 #include "serialization.h"
 
 using namespace lbcrypto;
 namespace py = pybind11;
 
-template <typename T>
-bool SerializeToFileImpl(const std::string& filename, const T& obj, const std::string& sertype_str) {
-    // call the appropriate serialization function based on the string
-    if (sertype_str == "binary") {
-        return Serial::SerializeToFile(filename, obj, SerType::BINARY);
-    } else if (sertype_str == "json") {
-        return Serial::SerializeToFile(filename, obj, SerType::JSON);
-    }else {
-        OPENFHE_THROW(serialize_error,"Serialization type not supported, use 'json' or 'binary'");
-    }
-    
-    // switch (sertype_str)
-    // {
-    // case "json":
-    //     return Serial::SerializeToFile(filename, obj, SerType::JSON);
-    //     break;
-    
-    // case "binary":
-    //     return Serial::SerializeToFile(filename, obj, SerType::BINARY);
-    //     break;
-    
-    // default:
-        
-    // }
-}
-
-bool SerializeToFileInterface(const std::string& filename, const CryptoContext<DCRTPoly>& obj, const std::string& sertype_str) {
-    return SerializeToFileImpl(filename, obj, sertype_str);
-}
-
-bool SerializeToFileInterface(const std::string& filename, const PublicKey<DCRTPoly>& obj, const std::string& sertype_str) {
-    return SerializeToFileImpl(filename, obj, sertype_str);
-}
-
-bool SerializeToFileInterface(const std::string& filename, const PrivateKey<DCRTPoly>& obj, const std::string& sertype_str) {
-    return SerializeToFileImpl(filename, obj, sertype_str);
-}
-
 void bind_serialization(pybind11::module &m) {
-    m.def("SerializeToFile", static_cast<bool (*)(const std::string&, const CryptoContext<DCRTPoly>&, const std::string&)>(&SerializeToFileInterface), py::arg("filename"), py::arg("obj"), py::arg("sertype_str")="binary");
+    m.def("SerializeToFile", static_cast<bool (*)(const std::string&, const CryptoContext<DCRTPoly>&, const SerType::SERJSON&)>(&Serial::SerializeToFile<DCRTPoly>), py::arg("filename"), py::arg("obj"), py::arg("sertype"));
     
 }