Browse Source

Added more CryptoContext member functions

Dmitriy Suponitskiy 11 months ago
parent
commit
611395adbf

+ 52 - 1
src/include/docstrings/cryptocontext_docs.h

@@ -932,6 +932,15 @@ const char* cc_MultiKeySwitchGen_docs = R"pbdoc(
     :rtype: EvalKey
 )pbdoc";
 
+const char* cc_GetEvalAutomorphismKeyMap_docs = R"pbdoc(
+    Get automorphism keys for a specific secret key tag
+
+    :param keyId: key identifier used for private key
+    :type keyId: str
+    :return: EvalKeyMap: map with all automorphism keys.
+    :rtype: EvalKeyMap
+)pbdoc";
+
 // TODO (Oliveira, R.) - Complete the following documentation
 const char* cc_GetEvalSumKeyMap_docs = R"pbdoc(
     Get a map of summation keys (each is composed of several automorphism keys) for a specific secret key tag
@@ -944,6 +953,22 @@ const char* cc_InsertEvalSumKey_docs = R"pbdoc(
     :param evalKeyMap: key map
     :type evalKeyMap: EvalKeyMap
 )pbdoc";
+
+const char* cc_MultiEvalAtIndexKeyGen_docs = R"pbdoc(
+    Threshold FHE: Generates joined rotation keys from the current secret key and prior joined rotation keys
+
+    :param privateKey: secret key share
+    :type privateKey: PrivateKey
+    :param evalKeyMap: a map with prior joined rotation keys
+    :type evalKeyMap: EvalKeyMap
+    :param indexList: a vector of rotation indices
+    :type indexList: List[int32]
+    :param keyId: new key identifier used for resulting evaluation key
+    :type keyId: str
+    :return: EvalKeyMap: new joined rotation keys
+    :rtype: EvalKeyMap
+)pbdoc";
+
 const char* cc_MultiEvalSumKeyGen_docs = R"pbdoc(
     Threshold FHE: Generates joined summation evaluation keys from the current secret share and prior joined summation keys
 
@@ -957,6 +982,32 @@ const char* cc_MultiEvalSumKeyGen_docs = R"pbdoc(
     :rtype: EvalKeyMap
 )pbdoc";
 
+const char* cc_MultiAddEvalAutomorphismKeys_docs = R"pbdoc(
+    Threshold FHE: Adds two prior evaluation key sets for automorphisms
+
+    :param evalKeyMap1: first automorphism key set
+    :type evalKeyMap1: EvalKeyMap
+    :param evalKeyMap2: second automorphism key set
+    :type evalKeyMap2: EvalKeyMap
+    :param keyId: new key identifier used for resulting evaluation key
+    :type keyId: str
+    :return: the new joined key set for summation
+    :rtype: evalKeyMap
+)pbdoc";
+
+const char* cc_MultiAddPubKeys_docs = R"pbdoc(
+    Threshold FHE: Adds two prior public keys
+
+    :param publicKey1: first public key
+    :type publicKey1: PublicKey
+    :param publicKey2: second public key
+    :type publicKey2: PublicKey
+    :param keyId: new key identifier used for the resulting key
+    :type keyId: str
+    :return: the new combined key
+    :rtype: PublicKey
+)pbdoc";
+
 const char* cc_MultiAddEvalKeys_docs = R"pbdoc(
     Threshold FHE: Adds two prior evaluation keys
 
@@ -1279,7 +1330,7 @@ const char* cc_ClearEvalMultKeys_docs = R"pbdoc(
 )pbdoc";
 
 const char* cc_ClearEvalAutomorphismKeys_docs = R"pbdoc(
-    ClearEvalAutomorphismKeys - flush EvalAutomorphismKey cache
+    Flush EvalAutomorphismKey cache
 )pbdoc";
 
 const char* cc_SerializeEvalAutomorphismKey_docs = R"pbdoc(

+ 0 - 1
src/include/pke/cryptocontext_wrapper.h

@@ -58,7 +58,6 @@ Plaintext DecryptWrapper(CryptoContext<DCRTPoly> &self,
                          const PrivateKey<DCRTPoly> privateKey, ConstCiphertext<DCRTPoly> ciphertext);
 Plaintext MultipartyDecryptFusionWrapper(CryptoContext<DCRTPoly>& self,const std::vector<Ciphertext<DCRTPoly>>& partialCiphertextVec);
 
-const std::map<usint, EvalKey<DCRTPoly>> EvalAutomorphismKeyGenWrapper(CryptoContext<DCRTPoly>& self,const PrivateKey<DCRTPoly> privateKey,const std::vector<usint> &indexList);
 const std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> GetEvalSumKeyMapWrapper(CryptoContext<DCRTPoly>& self, const std::string &id);
 const PlaintextModulus GetPlaintextModulusWrapper(CryptoContext<DCRTPoly>& self);
 const double GetModulusWrapper(CryptoContext<DCRTPoly>& self);

+ 33 - 7
src/lib/bindings.cpp

@@ -497,11 +497,34 @@ void bind_crypto_context(py::module &m)
              py::arg("originalPrivateKey"),
              py::arg("newPrivateKey"),
              py::arg("evalKey"))
+        .def("MultiEvalAtIndexKeyGen",
+            [](CryptoContextImpl<DCRTPoly>* self,
+                const PrivateKey<DCRTPoly>& privateKey,
+                std::shared_ptr<std::map<unsigned int, EvalKey<DCRTPoly>>> evalKeyMap,
+                const std::vector<int32_t>& indexList,
+                const std::string& keyId) {
+                 return self->MultiEvalAtIndexKeyGen(privateKey, evalKeyMap, indexList, keyId);
+             },
+             cc_MultiEvalAtIndexKeyGen_docs,
+             py::arg("privateKey"),
+             py::arg("evalKeyMap"),
+             py::arg("indexList"),
+             py::arg("keyId") = "")
         .def("MultiEvalSumKeyGen", &CryptoContextImpl<DCRTPoly>::MultiEvalSumKeyGen,
              cc_MultiEvalSumKeyGen_docs,
              py::arg("privateKey"),
              py::arg("evalKeyMap"),
              py::arg("keyId") = "")
+        .def("MultiAddEvalAutomorphismKeys", &CryptoContextImpl<DCRTPoly>::MultiAddEvalAutomorphismKeys,
+            cc_MultiAddEvalAutomorphismKeys_docs,
+            py::arg("evalKeyMap1"),
+            py::arg("evalKeyMap1"),
+            py::arg("keyId") = "")
+        .def("MultiAddPubKeys", &CryptoContextImpl<DCRTPoly>::MultiAddPubKeys,
+            cc_MultiAddPubKeys_docs,
+            py::arg("publicKey1"),
+            py::arg("publicKey2"),
+            py::arg("keyId") = "")
         .def("MultiAddEvalKeys", &CryptoContextImpl<DCRTPoly>::MultiAddEvalKeys,
              cc_MultiAddEvalKeys_docs,
              py::arg("evalKey1"),
@@ -686,11 +709,12 @@ void bind_crypto_context(py::module &m)
              py::arg("pLWE") = 0,
              py::arg("scaleSign") = 1.0)
         //TODO (Oliveira, R.): Solve pointer handling bug when returning EvalKeyMap objects for the next functions
-        .def("EvalAutomorphismKeyGen", &EvalAutomorphismKeyGenWrapper, 
+        .def("EvalAutomorphismKeyGen",
+            static_cast<std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> (CryptoContextImpl<DCRTPoly>::*)(const PrivateKey<DCRTPoly>, const std::vector<usint>&) const>
+            (&CryptoContextImpl<DCRTPoly>::EvalAutomorphismKeyGen), 
             cc_EvalAutomorphismKeyGen_docs,
             py::arg("privateKey"),
-            py::arg("indexList"),
-            py::return_value_policy::reference_internal)
+            py::arg("indexList"))
         .def("EvalLinearWSumMutable",
             static_cast<lbcrypto::Ciphertext<DCRTPoly> (lbcrypto::CryptoContextImpl<DCRTPoly>::*)(
                 const std::vector<double>&,
@@ -729,11 +753,13 @@ void bind_crypto_context(py::module &m)
             "ClearEvalAutomorphismKeys", []()
             { CryptoContextImpl<DCRTPoly>::ClearEvalAutomorphismKeys(); },
             cc_ClearEvalAutomorphismKeys_docs)
-        .def("GetEvalSumKeyMap", &GetEvalSumKeyMapWrapper,
-            cc_GetEvalSumKeyMap_docs,
+        .def_static("GetEvalAutomorphismKeyMap", &CryptoContextImpl<DCRTPoly>::GetEvalAutomorphismKeyMap,
+            cc_GetEvalAutomorphismKeyMap_docs,
+            py::arg("keyId") = "",
             py::return_value_policy::reference)
-        .def("GetBinCCForSchemeSwitch", &CryptoContextImpl<DCRTPoly>::GetBinCCForSchemeSwitch,
-        		py::return_value_policy::reference_internal)
+        .def("GetEvalSumKeyMap", &GetEvalSumKeyMapWrapper,
+            cc_GetEvalSumKeyMap_docs)
+        .def("GetBinCCForSchemeSwitch", &CryptoContextImpl<DCRTPoly>::GetBinCCForSchemeSwitch)
         .def_static(
             "SerializeEvalMultKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string id = "")
             {

+ 1 - 7
src/lib/pke/cryptocontext_wrapper.cpp

@@ -77,14 +77,8 @@ Plaintext MultipartyDecryptFusionWrapper(CryptoContext<DCRTPoly>& self,const std
     return plaintextDecResult;
 }
 
-const std::map<usint, EvalKey<DCRTPoly>> EvalAutomorphismKeyGenWrapper(CryptoContext<DCRTPoly>& self,const PrivateKey<DCRTPoly> privateKey,const std::vector<usint> &indexList){
-    return *(self->EvalAutomorphismKeyGen(privateKey, indexList));
-}
-
 const std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> GetEvalSumKeyMapWrapper(CryptoContext<DCRTPoly>& self,const std::string &id){
-    auto evalSumKeyMap = 
-        std::make_shared<std::map<usint, EvalKey<DCRTPoly>>>(self->GetEvalSumKeyMap(id));
-    return evalSumKeyMap;
+    return std::make_shared<std::map<usint, EvalKey<DCRTPoly>>>(CryptoContextImpl<DCRTPoly>::GetEvalSumKeyMap(id));;
 }
 
 const PlaintextModulus GetPlaintextModulusWrapper(CryptoContext<DCRTPoly>& self){