Browse Source

more eval functions + EvalKeyImpl exposure

Rener Oliveira (Ubuntu WSL) 1 year ago
parent
commit
9b25d90673
1 changed files with 37 additions and 0 deletions
  1. 37 0
      src/bindings.cpp

+ 37 - 0
src/bindings.cpp

@@ -1,10 +1,12 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
+#include <pybind11/stl_bind.h>
 #include <pybind11/complex.h>
 #include <pybind11/functional.h>
 #include <pybind11/operators.h>
 #include <pybind11/iostream.h>
 #include <iostream>
+#include <map>
 #include "openfhe.h"
 #include "key/key-ser.h"
 #include "bindings.h"
@@ -13,6 +15,7 @@
 
 using namespace lbcrypto;
 namespace py = pybind11;
+PYBIND11_MAKE_OPAQUE(std::map<usint, EvalKey<DCRTPoly>>);
 
 template <typename T>
 void bind_parameters(py::module &m,const std::string name)
@@ -100,6 +103,7 @@ void bind_crypto_context(py::module &m)
         .def("Enable", static_cast<void (CryptoContextImpl<DCRTPoly>::*)(PKESchemeFeature)>(&CryptoContextImpl<DCRTPoly>::Enable), "Enable a feature for the CryptoContext")
         .def("KeyGen", &CryptoContextImpl<DCRTPoly>::KeyGen, "Generate a key pair with public and private keys")
         .def("EvalMultKeyGen", &CryptoContextImpl<DCRTPoly>::EvalMultKeyGen, "Generate the evaluation key for multiplication")
+        .def("EvalMultKeysGen", &CryptoContextImpl<DCRTPoly>::EvalMultKeysGen)
         .def("EvalRotateKeyGen", &CryptoContextImpl<DCRTPoly>::EvalRotateKeyGen, "Generate the evaluation key for rotation",
              py::arg("privateKey"), py::arg("indexList"), py::arg("publicKey") = nullptr)
         .def("MakePackedPlaintext", &CryptoContextImpl<DCRTPoly>::MakePackedPlaintext, "Make a plaintext from a vector of integers",
@@ -133,11 +137,35 @@ void bind_crypto_context(py::module &m)
         .def("EvalAddMutable", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(Ciphertext<DCRTPoly>&, Ciphertext<DCRTPoly>&) const>(&CryptoContextImpl<DCRTPoly>::EvalAddMutable))
         .def("EvalAddMutable", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(Ciphertext<DCRTPoly>&, Plaintext) const>(&CryptoContextImpl<DCRTPoly>::EvalAddMutable))
         .def("EvalAddMutable", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(Plaintext, Ciphertext<DCRTPoly>&) const>(&CryptoContextImpl<DCRTPoly>::EvalAddMutable))
+        .def("EvalAddMutableInPlace", &CryptoContextImpl<DCRTPoly>::EvalAddMutableInPlace)
         .def("EvalSub", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalSub), "Subtract two ciphertexts")
         .def("EvalSub", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, double) const>(&CryptoContextImpl<DCRTPoly>::EvalSub), "Subtract double from ciphertext")
         .def("EvalSub", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(double, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalSub), "Subtract ciphertext from double")
+        .def("EvalSub", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, ConstPlaintext) const>(&CryptoContextImpl<DCRTPoly>::EvalSub), "Subtract plaintext from ciphertext")
+        .def("EvalSub", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstPlaintext, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalSub), "Subtract ciphertext from plaintext")
+        .def("EvalSubInPlace", static_cast<void (CryptoContextImpl<DCRTPoly>::*)(Ciphertext<DCRTPoly> &, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalSubInPlace))
+        .def("EvalSubInPlace", static_cast<void (CryptoContextImpl<DCRTPoly>::*)(Ciphertext<DCRTPoly> &, double) const>(&CryptoContextImpl<DCRTPoly>::EvalSubInPlace))
+        .def("EvalSubInPlace", static_cast<void (CryptoContextImpl<DCRTPoly>::*)(double, Ciphertext<DCRTPoly> &) const>(&CryptoContextImpl<DCRTPoly>::EvalSubInPlace))
+        .def("EvalSubMutable", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(Ciphertext<DCRTPoly>&, Ciphertext<DCRTPoly>&) const>(&CryptoContextImpl<DCRTPoly>::EvalSubMutable))
+        .def("EvalSubMutable", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(Ciphertext<DCRTPoly>&, Plaintext) const>(&CryptoContextImpl<DCRTPoly>::EvalSubMutable))
+        .def("EvalSubMutable", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(Plaintext, Ciphertext<DCRTPoly>&) const>(&CryptoContextImpl<DCRTPoly>::EvalSubMutable))
+        .def("EvalSubMutableInPlace", &CryptoContextImpl<DCRTPoly>::EvalSubMutableInPlace)
         .def("EvalMult", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalMult), "Multiply two ciphertexts")
         .def("EvalMult", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, double) const>(&CryptoContextImpl<DCRTPoly>::EvalMult), "Multiply a ciphertext with a scalar")
+        .def("EvalMult", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstCiphertext<DCRTPoly>, ConstPlaintext) const>(&CryptoContextImpl<DCRTPoly>::EvalMult))
+        .def("EvalMult", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(ConstPlaintext, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalMult))
+        .def("EvalMult", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(double, ConstCiphertext<DCRTPoly>) const>(&CryptoContextImpl<DCRTPoly>::EvalMult))
+        .def("EvalMultMutable", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(Ciphertext<DCRTPoly>&, Ciphertext<DCRTPoly>&) const>(&CryptoContextImpl<DCRTPoly>::EvalMultMutable))
+        .def("EvalMultMutable", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(Ciphertext<DCRTPoly>&, Plaintext) const>(&CryptoContextImpl<DCRTPoly>::EvalMultMutable))
+        .def("EvalMultMutable", static_cast<Ciphertext<DCRTPoly> (CryptoContextImpl<DCRTPoly>::*)(Plaintext, Ciphertext<DCRTPoly>&) const>(&CryptoContextImpl<DCRTPoly>::EvalMultMutable))
+        .def("EvalMultMutableInPlace", &CryptoContextImpl<DCRTPoly>::EvalMultMutableInPlace)
+        .def("EvalSquare", &CryptoContextImpl<DCRTPoly>::EvalSquare)
+        .def("EvalSquareMutable", &CryptoContextImpl<DCRTPoly>::EvalSquareMutable)
+        .def("EvalSquareInPlace", &CryptoContextImpl<DCRTPoly>::EvalSquareInPlace)
+        .def("EvalMultNoRelin", &CryptoContextImpl<DCRTPoly>::EvalMultNoRelin)
+        .def("Relinearize", &CryptoContextImpl<DCRTPoly>::Relinearize)
+        .def("RelinearizeInPlace", &CryptoContextImpl<DCRTPoly>::RelinearizeInPlace)
+        .def("EvalMultAndRelinearize", &CryptoContextImpl<DCRTPoly>::EvalMultAndRelinearize)
         .def("EvalNegate",&CryptoContextImpl<DCRTPoly>::EvalNegate)
         .def("EvalNegateInPlace",&CryptoContextImpl<DCRTPoly>::EvalNegateInPlace)
         .def("EvalLogistic", &CryptoContextImpl<DCRTPoly>::EvalLogistic,
@@ -167,6 +195,10 @@ void bind_crypto_context(py::module &m)
             py::arg("ciphertext"),
             py::arg("numIterations") = 1,
             py::arg("precision") = 0)
+        .def("EvalAutomorphismKeyGen", 
+                static_cast<std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> 
+                (CryptoContextImpl<DCRTPoly>::*)(const PrivateKey<DCRTPoly>, const std::vector<usint> &) const>
+                (&CryptoContextImpl<DCRTPoly>::EvalAutomorphismKeyGen))
         .def_static(
             "ClearEvalMultKeys", []()
             { CryptoContextImpl<DCRTPoly>::ClearEvalMultKeys(); },
@@ -374,6 +406,9 @@ void bind_enums_and_constants(py::module &m)
 
     // Params
     py::class_<Params>(m, "Params");
+
+    // EvalKeyMap
+    py::bind_map<std::map<usint, EvalKey<DCRTPoly>>>(m, "EvalKeyMap");
 }
 
 void bind_keys(py::module &m)
@@ -384,6 +419,8 @@ void bind_keys(py::module &m)
     py::class_<KeyPair<DCRTPoly>>(m, "KeyPair")
         .def_readwrite("publicKey", &KeyPair<DCRTPoly>::publicKey)
         .def_readwrite("secretKey", &KeyPair<DCRTPoly>::secretKey);
+    py::class_<EvalKeyImpl<DCRTPoly>, std::shared_ptr<EvalKeyImpl<DCRTPoly>>>(m, "EvalKey")
+        .def(py::init<>());
 }
 
 void bind_encodings(py::module &m)