Selaa lähdekoodia

Additional bindings for python (#241)

Co-authored-by: Dmitriy Suponitskiy <dsuponitskiy@dualitytech.com>
dsuponitskiy 7 kuukautta sitten
vanhempi
commit
25f3454019
1 muutettua tiedostoa jossa 39 lisäystä ja 16 poistoa
  1. 39 16
      src/lib/bindings.cpp

+ 39 - 16
src/lib/bindings.cpp

@@ -65,8 +65,7 @@ void bind_DCRTPoly(py::module &m) {
 }
 
 template <typename T>
-void bind_parameters(py::module &m, const std::string name)
-{
+void bind_parameters(py::module &m, const std::string name) {
     py::class_<CCParams<T>>(m, name.c_str())
         .def(py::init<>())
         // getters
@@ -224,8 +223,7 @@ void bind_crypto_context_templates(py::class_<CryptoContextImpl<DCRTPoly>, std::
     ;
 }
 
-void bind_crypto_context(py::module &m)
-{
+void bind_crypto_context(py::module &m) {
     //Parameters Type
     // TODO (Oliveira): If we expose Poly's and ParmType, this block will go somewhere else
     using ParmType = typename DCRTPoly::Params;
@@ -485,6 +483,16 @@ void bind_crypto_context(py::module &m)
             py::arg("plaintext"),
             py::arg("ciphertext"),
             py::doc(""))  // TODO (dsuponit): replace this with an actual docstring
+        .def("EvalAddInPlace",
+            py::overload_cast<Ciphertext<DCRTPoly>&, double>(&CryptoContextImpl<DCRTPoly>::EvalAddInPlace, py::const_),
+            py::arg("ciphertext"),
+            py::arg("scalar"),
+            py::doc(""))  // TODO (dsuponit): replace this with an actual docstring
+        .def("EvalAddInPlace",
+            py::overload_cast<double, Ciphertext<DCRTPoly>&>(&CryptoContextImpl<DCRTPoly>::EvalAddInPlace, py::const_),
+            py::arg("scalar"),
+            py::arg("ciphertext"),
+            py::doc(""))  // TODO (dsuponit): replace this with an actual docstring
         .def("EvalAddMutable",
             py::overload_cast<Ciphertext<DCRTPoly>&, Ciphertext<DCRTPoly>&>(&CryptoContextImpl<DCRTPoly>::EvalAddMutable, py::const_),
             py::arg("ciphertext1"),
@@ -543,6 +551,18 @@ void bind_crypto_context(py::module &m)
             py::arg("scalar"),
             py::arg("ciphertext"),
             py::doc(""))  // TODO (dsuponit): replace this with an actual docstring
+        .def("EvalSubInPlace",
+            py::overload_cast<Ciphertext<DCRTPoly>&, ConstPlaintext&>(
+                &CryptoContextImpl<DCRTPoly>::EvalSubInPlace, py::const_),
+            py::arg("ciphertext"),
+            py::arg("plaintext"),
+            py::doc(""))  // TODO (dsuponit): replace this with an actual docstring
+        .def("EvalSubInPlace",
+            py::overload_cast<Plaintext&, Ciphertext<DCRTPoly>&>(
+                &CryptoContextImpl<DCRTPoly>::EvalSubInPlace, py::const_),
+            py::arg("plaintext"),
+            py::arg("ciphertext"),
+            py::doc(""))  // TODO (dsuponit): replace this with an actual docstring
         .def("EvalSubMutable",
             py::overload_cast<Ciphertext<DCRTPoly>&, Ciphertext<DCRTPoly>&>(&CryptoContextImpl<DCRTPoly>::EvalSubMutable, py::const_),
             py::arg("ciphertext1"),
@@ -1125,7 +1145,7 @@ void bind_crypto_context(py::module &m)
     m.def("ClearEvalMultKeys", static_cast<void (*)()>(&CryptoContextImpl<DCRTPoly>::ClearEvalMultKeys));
 }
 
-int get_native_int(){
+int get_native_int() {
     #if NATIVEINT == 128 && !defined(__EMSCRIPTEN__)
         return 128;
     #elif NATIVEINT == 32
@@ -1135,8 +1155,7 @@ int get_native_int(){
     #endif
 }
 
-void bind_enums_and_constants(py::module &m)
-{
+void bind_enums_and_constants(py::module &m) {
     /* ---- PKE enums ---- */ 
     // Scheme Types
     py::enum_<SCHEME>(m, "SCHEME")
@@ -1298,8 +1317,7 @@ void bind_enums_and_constants(py::module &m)
     m.def("get_native_int", &get_native_int);
 }
 
-void bind_keys(py::module &m)
-{
+void bind_keys(py::module &m) {
     py::class_<PublicKeyImpl<DCRTPoly>, std::shared_ptr<PublicKeyImpl<DCRTPoly>>>(m, "PublicKey")
         .def(py::init<>())
         .def("GetKeyTag", &PublicKeyImpl<DCRTPoly>::GetKeyTag)
@@ -1322,8 +1340,7 @@ void bind_keys(py::module &m)
 }
 
 // PlaintextImpl is an abstract class, so we should use a helper (trampoline) class
-class PlaintextImpl_helper : public PlaintextImpl
-{
+class PlaintextImpl_helper : public PlaintextImpl {
 public:
     using PlaintextImpl::PlaintextImpl; // inherited constructors
 
@@ -1400,8 +1417,7 @@ public:
     }
 };
 
-void bind_encodings(py::module &m)
-{
+void bind_encodings(py::module &m) {
     py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>, PlaintextImpl_helper>(m, "Plaintext")
         .def("GetScalingFactor", &PlaintextImpl::GetScalingFactor, ptx_GetScalingFactor_docs)
         .def("SetScalingFactor", &PlaintextImpl::SetScalingFactor,
@@ -1497,7 +1513,7 @@ void bind_ciphertext(py::module &m) {
             });
 }
 
-void bind_schemes(py::module &m){
+void bind_schemes(py::module &m) {
     // Bind schemes specific functionalities like bootstrapping functions and multiparty
     py::class_<FHECKKSRNS>(m, "FHECKKSRNS")
         .def(py::init<>())
@@ -1513,8 +1529,7 @@ void bind_schemes(py::module &m){
         ;
 }
 
-void bind_sch_swch_params(py::module &m)
-{
+void bind_sch_swch_params(py::module &m) {
     py::class_<SchSwchParams>(m, "SchSwchParams")
         .def(py::init<>())
         .def("GetSecurityLevelCKKS", &SchSwchParams::GetSecurityLevelCKKS)
@@ -1562,6 +1577,13 @@ void bind_sch_swch_params(py::module &m)
             });
 }
 
+void bind_utils(py::module& m) {
+    m.def("EnablePrecomputeCRTTablesAfterDeserializaton", &lbcrypto::EnablePrecomputeCRTTablesAfterDeserializaton,
+          py::doc("Enable CRT precomputation after deserialization"));
+    m.def("DisablePrecomputeCRTTablesAfterDeserializaton", &lbcrypto::DisablePrecomputeCRTTablesAfterDeserializaton,
+          py::doc("Disable CRT precomputation after deserialization"));
+}
+
 PYBIND11_MODULE(openfhe, m) {
     // sequence of function calls matters
     m.doc() = "Open-Source Fully Homomorphic Encryption Library";
@@ -1583,4 +1605,5 @@ PYBIND11_MODULE(openfhe, m) {
     bind_serialization(m);
     bind_schemes(m);
     bind_sch_swch_params(m);
+    bind_utils(m);
 }