Pārlūkot izejas kodu

Added GetEvalMultKeyVector() to binding

Dmitriy Suponitskiy 10 mēneši atpakaļ
vecāks
revīzija
1a934e51aa

+ 9 - 0
src/include/docstrings/cryptocontext_docs.h

@@ -935,6 +935,15 @@ const char* cc_MultiKeySwitchGen_docs = R"pbdoc(
     :rtype: EvalKey
 )pbdoc";
 
+const char* cc_GetEvalMultKeyVector_docs = R"pbdoc(
+    Get relinearization keys for a specific secret key tag
+
+    :param keyId: key identifier used for private key
+    :type keyId: str
+    :return: EvalKeyVector: vector with all relinearization keys.
+    :rtype: EvalKeyVector
+)pbdoc";
+
 const char* cc_GetEvalAutomorphismKeyMap_docs = R"pbdoc(
     Get automorphism keys for a specific secret key tag
 

+ 19 - 21
src/lib/bindings.cpp

@@ -782,6 +782,9 @@ void bind_crypto_context(py::module &m)
         .def("FindAutomorphismIndices", &CryptoContextImpl<DCRTPoly>::FindAutomorphismIndices,
             cc_FindAutomorphismIndices_docs,
             py::arg("idxList"))
+        .def("GetEvalSumKeyMap", &GetEvalSumKeyMapWrapper,
+            cc_GetEvalSumKeyMap_docs)
+        .def("GetBinCCForSchemeSwitch", &CryptoContextImpl<DCRTPoly>::GetBinCCForSchemeSwitch)
         .def_static(
             "InsertEvalSumKey", &CryptoContextImpl<DCRTPoly>::InsertEvalSumKey,
             cc_InsertEvalSumKey_docs,
@@ -801,18 +804,21 @@ void bind_crypto_context(py::module &m)
             "ClearEvalAutomorphismKeys", []()
             { CryptoContextImpl<DCRTPoly>::ClearEvalAutomorphismKeys(); },
             cc_ClearEvalAutomorphismKeys_docs)
+        // it is safer to return by value instead of by reference (GetEvalMultKeyVector returns a const reference to std::vector)
+        .def_static("GetEvalMultKeyVector",
+            [](const std::string& keyId) {
+                return CryptoContextImpl<DCRTPoly>::GetEvalMultKeyVector(keyId);
+            },
+            cc_GetEvalMultKeyVector_docs,
+            py::arg("keyId") = "")
         .def_static("GetEvalAutomorphismKeyMap", &CryptoContextImpl<DCRTPoly>::GetEvalAutomorphismKeyMapPtr,
             cc_GetEvalAutomorphismKeyMap_docs,
             py::arg("keyId") = "")
-        .def("GetEvalSumKeyMap", &GetEvalSumKeyMapWrapper,
-            cc_GetEvalSumKeyMap_docs)
-        .def("GetBinCCForSchemeSwitch", &CryptoContextImpl<DCRTPoly>::GetBinCCForSchemeSwitch)
         .def_static(
             "SerializeEvalMultKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string id = "")
             {
                 std::ofstream outfile(filename,std::ios::out | std::ios::binary);
-                bool res;
-                res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey<SerType::SERBINARY>(outfile, sertype, id);
+                bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey<SerType::SERBINARY>(outfile, sertype, id);
                 outfile.close();
                 return res; },
             cc_SerializeEvalMultKey_docs,
@@ -821,8 +827,7 @@ void bind_crypto_context(py::module &m)
             "SerializeEvalMultKey", [](const std::string &filename, const SerType::SERJSON &sertype, std::string id = "")
             {
                 std::ofstream outfile(filename,std::ios::out | std::ios::binary);
-                bool res;
-                res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey<SerType::SERJSON>(outfile, sertype, id);
+                bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey<SerType::SERJSON>(outfile, sertype, id);
                 outfile.close();
                 return res; },
             cc_SerializeEvalMultKey_docs,
@@ -831,8 +836,7 @@ void bind_crypto_context(py::module &m)
             "SerializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string id = "")
             {
                 std::ofstream outfile(filename,std::ios::out | std::ios::binary);
-                bool res;
-                res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey<SerType::SERBINARY>(outfile, sertype, id);
+                bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey<SerType::SERBINARY>(outfile, sertype, id);
                 outfile.close();
                 return res; },
             cc_SerializeEvalAutomorphismKey_docs,
@@ -841,8 +845,7 @@ void bind_crypto_context(py::module &m)
             "SerializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERJSON &sertype, std::string id = "")
             {
                 std::ofstream outfile(filename,std::ios::out | std::ios::binary);
-                bool res;
-                res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey<SerType::SERJSON>(outfile, sertype, id);
+                bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey<SerType::SERJSON>(outfile, sertype, id);
                 outfile.close();
                 return res; },
             cc_SerializeEvalAutomorphismKey_docs,
@@ -854,10 +857,8 @@ void bind_crypto_context(py::module &m)
                          if (!emkeys.is_open()) {
                             std::cerr << "I cannot read serialization from " << filename << std::endl;
                          }
-                        bool res;
-                        res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<SerType::SERBINARY>(emkeys, sertype);
-                        return res; 
-                        
+                        bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<SerType::SERBINARY>(emkeys, sertype);
+                        return res;
                         },
                         cc_DeserializeEvalMultKey_docs,
                         py::arg("filename"), py::arg("sertype"))
@@ -868,8 +869,7 @@ void bind_crypto_context(py::module &m)
                          if (!emkeys.is_open()) {
                             std::cerr << "I cannot read serialization from " << filename << std::endl;
                          }
-                        bool res;
-                        res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<SerType::SERJSON>(emkeys, sertype);
+                        bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<SerType::SERJSON>(emkeys, sertype);
                         return res; },
                         cc_DeserializeEvalMultKey_docs,
                         py::arg("filename"), py::arg("sertype"))
@@ -880,8 +880,7 @@ void bind_crypto_context(py::module &m)
                          if (!erkeys.is_open()) {
                             std::cerr << "I cannot read serialization from " << filename << std::endl;
                          }
-                        bool res;
-                        res = CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<SerType::SERBINARY>(erkeys, sertype);
+                        bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<SerType::SERBINARY>(erkeys, sertype);
                         return res; },
                         cc_DeserializeEvalAutomorphismKey_docs,
                         py::arg("filename"), py::arg("sertype"))
@@ -892,8 +891,7 @@ void bind_crypto_context(py::module &m)
                          if (!erkeys.is_open()) {
                             std::cerr << "I cannot read serialization from " << filename << std::endl;
                          }
-                        bool res;
-                        res = CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<SerType::SERJSON>(erkeys, sertype);
+                        bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<SerType::SERJSON>(erkeys, sertype);
                         return res; },
                         cc_DeserializeEvalAutomorphismKey_docs,
                         py::arg("filename"), py::arg("sertype"));

+ 4 - 0
src/lib/pke/serialization.cpp

@@ -342,6 +342,10 @@ void bind_serialization(pybind11::module &m) {
           py::arg("obj"), py::arg("sertype"));
     m.def("DeserializeEvalKeyMapString", &DeserializeFromBytesWrapper<std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>>, SerType::SERBINARY>,
           py::arg("str"), py::arg("sertype"));
+    m.def("Serialize", &SerializeToBytesWrapper<std::vector<EvalKey<DCRTPoly>>, SerType::SERBINARY>,
+          py::arg("obj"), py::arg("sertype"));
+    m.def("DeserializeEvalKeyMapVectorString", &DeserializeFromBytesWrapper<std::vector<EvalKey<DCRTPoly>>, SerType::SERBINARY>,
+          py::arg("str"), py::arg("sertype"));
 
     m.def("SerializeEvalMultKeyString", &SerializeEvalMultKeyToBytesWrapper<SerType::SERBINARY>,
           py::arg("sertype"), py::arg("id") = "");