Browse Source

Added a trampoline class PlaintextImpl_helper to override PlaintextImpl's virtual functions

Dmitriy Suponitskiy 11 months ago
parent
commit
5c3b5487bc
1 changed files with 80 additions and 1 deletions
  1. 80 1
      src/lib/bindings.cpp

+ 80 - 1
src/lib/bindings.cpp

@@ -1057,9 +1057,88 @@ void bind_keys(py::module &m)
         .def(py::init<>());
 }
 
+// PlaintextImpl is an abstract class, so we should use a helper (trampoline) class
+class PlaintextImpl_helper : public PlaintextImpl
+{
+public:
+    using PlaintextImpl::PlaintextImpl; // inherited constructors
+
+    // the PlaintextImpl virtual functions' overrides
+    bool Encode() override {
+        PYBIND11_OVERRIDE_PURE(
+            bool,          // return type
+            PlaintextImpl, // parent class
+            Encode         // function name
+                           // no arguments
+        );
+    }
+    bool Decode() override {
+        PYBIND11_OVERRIDE_PURE(
+            bool,          // return type
+            PlaintextImpl, // parent class
+            Decode         // function name
+                           // no arguments
+        );
+    }
+    bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override {
+        PYBIND11_OVERRIDE(
+            bool,          // return type
+            PlaintextImpl, // parent class
+            Decode,        // function name
+            depth, scalingFactor, scalTech, executionMode // arguments
+        );
+    }
+    size_t GetLength() const override {
+        PYBIND11_OVERRIDE_PURE(
+            size_t,        // return type
+            PlaintextImpl, // parent class
+            GetLength      // function name
+                           // no arguments
+        );
+    }
+    void SetLength(size_t newSize) override {
+        PYBIND11_OVERRIDE(
+            void,          // return type
+            PlaintextImpl, // parent class
+            SetLength,     // function name
+            newSize        // arguments
+        );
+    }
+    double GetLogError() const override {
+        PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogError);
+    }
+    double GetLogPrecision() const override {
+        PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogPrecision);
+    }
+    const std::string& GetStringValue() const override {
+        PYBIND11_OVERRIDE(const std::string&, PlaintextImpl, GetStringValue);
+    }
+    const std::vector<int64_t>& GetCoefPackedValue() const override {
+        PYBIND11_OVERRIDE(const std::vector<int64_t>&, PlaintextImpl, GetCoefPackedValue);
+    }
+    const std::vector<int64_t>& GetPackedValue() const override {
+        PYBIND11_OVERRIDE(const std::vector<int64_t>&, PlaintextImpl, GetPackedValue);
+    }
+    const std::vector<std::complex<double>>& GetCKKSPackedValue() const override {
+        PYBIND11_OVERRIDE(const std::vector<std::complex<double>>&, PlaintextImpl, GetCKKSPackedValue);
+    }
+    std::vector<double> GetRealPackedValue() const override {
+        PYBIND11_OVERRIDE(std::vector<double>, PlaintextImpl, GetRealPackedValue);
+    }
+    void SetStringValue(const std::string& str) override {
+        PYBIND11_OVERRIDE(void, PlaintextImpl, SetStringValue, str);
+    }
+    void SetIntVectorValue(const std::vector<int64_t>& vec) override {
+        PYBIND11_OVERRIDE(void, PlaintextImpl, SetIntVectorValue, vec);
+    }
+    std::string GetFormattedValues(int64_t precision) const override {
+        PYBIND11_OVERRIDE(std::string, PlaintextImpl, GetFormattedValues, precision);
+    }
+};
+
 void bind_encodings(py::module &m)
 {
-    py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>>(m, "Plaintext")
+    py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>, PlaintextImpl_helper>(m, "Plaintext")
         .def("GetScalingFactor", &PlaintextImpl::GetScalingFactor,
             ptx_GetScalingFactor_docs)
         .def("SetScalingFactor", &PlaintextImpl::SetScalingFactor,