Переглянути джерело

Fixes to binding of serialization functions, MultiEvalAtIndexKeyGen() and GetEvalAutomorphismKeyMap().

Dmitriy Suponitskiy 11 місяців тому
батько
коміт
8f1be1da29
2 змінених файлів з 7 додано та 8 видалено
  1. 3 4
      src/lib/bindings.cpp
  2. 4 4
      src/lib/pke/serialization.cpp

+ 3 - 4
src/lib/bindings.cpp

@@ -502,7 +502,7 @@ void bind_crypto_context(py::module &m)
                 const PrivateKey<DCRTPoly>& privateKey,
                 std::shared_ptr<std::map<unsigned int, EvalKey<DCRTPoly>>> evalKeyMap,
                 const std::vector<int32_t>& indexList,
-                const std::string& keyId) {
+                const std::string& keyId = "") {
                  return self->MultiEvalAtIndexKeyGen(privateKey, evalKeyMap, indexList, keyId);
              },
              cc_MultiEvalAtIndexKeyGen_docs,
@@ -758,10 +758,9 @@ void bind_crypto_context(py::module &m)
             "ClearEvalAutomorphismKeys", []()
             { CryptoContextImpl<DCRTPoly>::ClearEvalAutomorphismKeys(); },
             cc_ClearEvalAutomorphismKeys_docs)
-        .def_static("GetEvalAutomorphismKeyMap", &CryptoContextImpl<DCRTPoly>::GetEvalAutomorphismKeyMap,
+        .def_static("GetEvalAutomorphismKeyMap", &CryptoContextImpl<DCRTPoly>::GetEvalAutomorphismKeyMapPtr,
             cc_GetEvalAutomorphismKeyMap_docs,
-            py::arg("keyId") = "",
-            py::return_value_policy::reference)
+            py::arg("keyId") = "")
         .def("GetEvalSumKeyMap", &GetEvalSumKeyMapWrapper,
             cc_GetEvalSumKeyMap_docs)
         .def("GetBinCCForSchemeSwitch", &CryptoContextImpl<DCRTPoly>::GetBinCCForSchemeSwitch)

+ 4 - 4
src/lib/pke/serialization.cpp

@@ -263,11 +263,11 @@ void bind_serialization(pybind11::module &m) {
     m.def("SerializeEvalMultKeyString", &SerializeEvalMultKeyToStringWrapper<SerType::SERJSON>,
           py::arg("sertype"), py::arg("id") = "");
     m.def("DeserializeEvalMultKeyString", &DeserializeEvalMultKeyFromStringWrapper<SerType::SERJSON>,
-          py::arg("sertype"), py::arg("id") = "");
+          py::arg("data"), py::arg("sertype"));
     m.def("SerializeEvalAutomorphismKeyString", &SerializeEvalAutomorphismKeyToStringWrapper<SerType::SERJSON>,
           py::arg("sertype"), py::arg("id") = "");
     m.def("DeserializeEvalAutomorphismKeyString", &DeserializeEvalAutomorphismKeyFromStringWrapper<SerType::SERJSON>,
-          py::arg("sertype"), py::arg("id") = "");
+          py::arg("data"), py::arg("sertype"));
 
     // Binary Serialization
     m.def("SerializeToFile", static_cast<bool (*)(const std::string&,const CryptoContext<DCRTPoly>&, const SerType::SERBINARY&)>(&Serial::SerializeToFile<DCRTPoly>),
@@ -315,9 +315,9 @@ void bind_serialization(pybind11::module &m) {
     m.def("SerializeEvalMultKeyString", &SerializeEvalMultKeyToBytesWrapper<SerType::SERBINARY>,
           py::arg("sertype"), py::arg("id") = "");
     m.def("DeserializeEvalMultKeyString", &DeserializeEvalMultKeyFromBytesWrapper<SerType::SERBINARY>,
-          py::arg("sertype"), py::arg("id") = "");
+          py::arg("data"), py::arg("sertype"));
     m.def("SerializeEvalAutomorphismKeyString", &SerializeEvalAutomorphismKeyToBytesWrapper<SerType::SERBINARY>,
           py::arg("sertype"), py::arg("id") = "");
     m.def("DeserializeEvalAutomorphismKeyString", &DeserializeEvalAutomorphismKeyFromBytesWrapper<SerType::SERBINARY>,
-          py::arg("sertype"), py::arg("id") = "");
+          py::arg("data"), py::arg("sertype"));
 }