소스 검색

Merge pull request #199 from openfheorg/187-fixes-for-openfhe

Updates required for the new OpenFHE release v1.3.0
pascoec 11 달 전
부모
커밋
eee65ddc33
1개의 변경된 파일84개의 추가작업 그리고 5개의 파일을 삭제
  1. 84 5
      src/lib/bindings.cpp

+ 84 - 5
src/lib/bindings.cpp

@@ -1050,16 +1050,95 @@ void bind_keys(py::module &m)
         .def_readwrite("secretKey", &KeyPair<DCRTPoly>::secretKey)
         .def("good", &KeyPair<DCRTPoly>::good,kp_good_docs);
     py::class_<EvalKeyImpl<DCRTPoly>, std::shared_ptr<EvalKeyImpl<DCRTPoly>>>(m, "EvalKey")
-    .def(py::init<>())
+        .def(py::init<>())
         .def("GetKeyTag", &EvalKeyImpl<DCRTPoly>::GetKeyTag)
         .def("SetKeyTag", &EvalKeyImpl<DCRTPoly>::SetKeyTag);
     py::class_<std::map<usint, EvalKey<DCRTPoly>>, std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>>>(m, "EvalKeyMap")
         .def(py::init<>());
 }
 
+// PlaintextImpl is an abstract class, so we should use a helper (trampoline) class
+class PlaintextImpl_helper : public PlaintextImpl
+{
+public:
+    using PlaintextImpl::PlaintextImpl; // inherited constructors
+
+    // the PlaintextImpl virtual functions' overrides
+    bool Encode() override {
+        PYBIND11_OVERRIDE_PURE(
+            bool,          // return type
+            PlaintextImpl, // parent class
+            Encode         // function name
+                           // no arguments
+        );
+    }
+    bool Decode() override {
+        PYBIND11_OVERRIDE_PURE(
+            bool,          // return type
+            PlaintextImpl, // parent class
+            Decode         // function name
+                           // no arguments
+        );
+    }
+    bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override {
+        PYBIND11_OVERRIDE(
+            bool,          // return type
+            PlaintextImpl, // parent class
+            Decode,        // function name
+            depth, scalingFactor, scalTech, executionMode // arguments
+        );
+    }
+    size_t GetLength() const override {
+        PYBIND11_OVERRIDE_PURE(
+            size_t,        // return type
+            PlaintextImpl, // parent class
+            GetLength      // function name
+                           // no arguments
+        );
+    }
+    void SetLength(size_t newSize) override {
+        PYBIND11_OVERRIDE(
+            void,          // return type
+            PlaintextImpl, // parent class
+            SetLength,     // function name
+            newSize        // arguments
+        );
+    }
+    double GetLogError() const override {
+        PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogError);
+    }
+    double GetLogPrecision() const override {
+        PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogPrecision);
+    }
+    const std::string& GetStringValue() const override {
+        PYBIND11_OVERRIDE(const std::string&, PlaintextImpl, GetStringValue);
+    }
+    const std::vector<int64_t>& GetCoefPackedValue() const override {
+        PYBIND11_OVERRIDE(const std::vector<int64_t>&, PlaintextImpl, GetCoefPackedValue);
+    }
+    const std::vector<int64_t>& GetPackedValue() const override {
+        PYBIND11_OVERRIDE(const std::vector<int64_t>&, PlaintextImpl, GetPackedValue);
+    }
+    const std::vector<std::complex<double>>& GetCKKSPackedValue() const override {
+        PYBIND11_OVERRIDE(const std::vector<std::complex<double>>&, PlaintextImpl, GetCKKSPackedValue);
+    }
+    std::vector<double> GetRealPackedValue() const override {
+        PYBIND11_OVERRIDE(std::vector<double>, PlaintextImpl, GetRealPackedValue);
+    }
+    void SetStringValue(const std::string& str) override {
+        PYBIND11_OVERRIDE(void, PlaintextImpl, SetStringValue, str);
+    }
+    void SetIntVectorValue(const std::vector<int64_t>& vec) override {
+        PYBIND11_OVERRIDE(void, PlaintextImpl, SetIntVectorValue, vec);
+    }
+    std::string GetFormattedValues(int64_t precision) const override {
+        PYBIND11_OVERRIDE(std::string, PlaintextImpl, GetFormattedValues, precision);
+    }
+};
+
 void bind_encodings(py::module &m)
 {
-    py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>>(m, "Plaintext")
+    py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>, PlaintextImpl_helper>(m, "Plaintext")
         .def("GetScalingFactor", &PlaintextImpl::GetScalingFactor,
             ptx_GetScalingFactor_docs)
         .def("SetScalingFactor", &PlaintextImpl::SetScalingFactor,
@@ -1069,8 +1148,6 @@ void bind_encodings(py::module &m)
             ptx_GetSchemeID_docs)
         .def("GetLength", &PlaintextImpl::GetLength,
             ptx_GetLength_docs)
-        .def("GetSchemeID", &PlaintextImpl::GetSchemeID,
-            ptx_GetSchemeID_docs)
         .def("SetLength", &PlaintextImpl::SetLength,
             ptx_SetLength_docs,
             py::arg("newSize"))
@@ -1080,7 +1157,9 @@ void bind_encodings(py::module &m)
             ptx_GetLogPrecision_docs)
         .def("Encode", &PlaintextImpl::Encode,
             ptx_Encode_docs)
-        .def("Decode", &PlaintextImpl::Decode,
+        .def("Decode", py::overload_cast<>(&PlaintextImpl::Decode),
+            ptx_Decode_docs)
+        .def("Decode", py::overload_cast<size_t, double, ScalingTechnique, ExecutionMode>(&PlaintextImpl::Decode),
             ptx_Decode_docs)
         .def("LowBound", &PlaintextImpl::LowBound,
             ptx_LowBound_docs)