Przeglądaj źródła

Changes required by openfhe-numpy (#230)

* Fixes for openfhe-numpy

* Exposed API functions for openfhe-numpy

* Small bug fix and formatting

* Exposed more APIs

* Changes to __init__.py and cryptocontext_wrapper.h

* Changes to __init__.py v2

* Changes to binding

* Cleanup

* Corrected find_package(Python...) and added a new variable OPENFHE_REQUIRED_VERSION to specify the OpenFHE version externally"

* Minor change for OPENFHE_REQUIRED_VERSION

---------

Co-authored-by: Dmitriy Suponitskiy <dsuponitskiy@dualitytech.com>
dsuponitskiy 8 miesięcy temu
rodzic
commit
b5595ebf9c

+ 16 - 11
CMakeLists.txt

@@ -8,6 +8,14 @@ set(OPENFHE_PYTHON_VERSION_PATCH 0)
 set(OPENFHE_PYTHON_VERSION_TWEAK 0)
 set(OPENFHE_PYTHON_VERSION ${OPENFHE_PYTHON_VERSION_MAJOR}.${OPENFHE_PYTHON_VERSION_MINOR}.${OPENFHE_PYTHON_VERSION_PATCH}.${OPENFHE_PYTHON_VERSION_TWEAK})
 
+# OpenFHE version can be specified externally (-DOPENFHE_REQUIRED_VERSION=1.3.0)
+if(NOT DEFINED OPENFHE_REQUIRED_VERSION)
+    set(OPENFHE_REQUIRED_VERSION "1.3.0" CACHE STRING "Required OpenFHE version")
+else()
+    # User provided OPENFHE_REQUIRED_VERSION via -D
+    message(STATUS "Using user-specified OpenFHE version: ${OPENFHE_REQUIRED_VERSION}")
+endif()
+
 set(CMAKE_CXX_STANDARD 17)
 option( BUILD_STATIC "Set to ON to include static versions of the library" OFF)
 
@@ -15,7 +23,9 @@ if(APPLE)
     set(CMAKE_CXX_VISIBILITY_PRESET default)
 endif()
 
-find_package(OpenFHE 1.3.0 REQUIRED)
+find_package(OpenFHE ${OPENFHE_REQUIRED_VERSION} REQUIRED)
+message(STATUS "Building with OpenFHE version: ${OPENFHE_REQUIRED_VERSION}")
+
 set(PYBIND11_FINDPYTHON ON)
 find_package(pybind11 REQUIRED)
 
@@ -66,20 +76,13 @@ pybind11_add_module(openfhe
 ### Python installation 
 # Allow the user to specify the path to Python executable (if not provided, find it)
 option(PYTHON_EXECUTABLE_PATH "Path to Python executable" "")
-
-if(NOT PYTHON_EXECUTABLE_PATH)
-    # Find Python and its development components
-    find_package(Python REQUIRED COMPONENTS Interpreter Development)
-else()
-    # Set Python_EXECUTABLE to the specified path
+if(PYTHON_EXECUTABLE_PATH)
     set(Python_EXECUTABLE "${PYTHON_EXECUTABLE_PATH}")
 endif()
-
-# Find Python interpreter
-find_package(PythonInterp REQUIRED)
+find_package(Python REQUIRED COMPONENTS Interpreter Development)
 
 # Check Python version
-if(${PYTHON_VERSION_MAJOR} EQUAL 3 AND ${PYTHON_VERSION_MINOR} GREATER_EQUAL 10)
+if(${Python_VERSION_MAJOR} EQUAL 3 AND ${Python_VERSION_MINOR} GREATER_EQUAL 10)
     execute_process(
         COMMAND "${Python_EXECUTABLE}" -c "from sys import exec_prefix; print(exec_prefix)"
         OUTPUT_VARIABLE PYTHON_SITE_PACKAGES
@@ -101,3 +104,5 @@ else()
 endif()
 message("***** INSTALL IS AT ${Python_Install_Location}; to change, run cmake with -DCMAKE_INSTALL_PREFIX=/your/path")
 install(TARGETS openfhe LIBRARY DESTINATION ${Python_Install_Location})
+install(FILES ${CMAKE_SOURCE_DIR}/__init__.py DESTINATION ${Python_Install_Location})
+

+ 48 - 0
__init__.py

@@ -0,0 +1,48 @@
+import os
+import ctypes
+
+
+def load_shared_library(libname, paths):
+    for path in paths:
+        lib_path = os.path.join(path, libname)
+        if os.path.exists(lib_path):
+            return ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
+
+    raise FileNotFoundError(
+        f"Shared library {libname} not found in {paths}"
+    )
+
+# Search LD_LIBRARY_PATH
+ld_paths = os.environ.get("LD_LIBRARY_PATH", "").split(":")
+
+if not any(ld_paths):
+    # Path to the bundled `lib/` directory inside site-packages
+    package_dir = os.path.abspath(os.path.dirname(__file__))
+    internal_lib_dir = [os.path.join(package_dir, 'lib')]
+
+    # Shared libraries required
+    shared_libs = [
+        'libgomp.so',
+        'libOPENFHEcore.so.1',
+        'libOPENFHEbinfhe.so.1',
+        'libOPENFHEpke.so.1',
+    ]
+
+    for libname in shared_libs:
+        load_shared_library(libname, internal_lib_dir)
+
+    from .openfhe import *
+
+else:
+    # Shared libraries required
+    # skip 'libgomp.so' if LD_LIBRARY_PATH is set as we should get it from the libgomp.so location
+    shared_libs = [
+        'libOPENFHEcore.so.1',
+        'libOPENFHEbinfhe.so.1',
+        'libOPENFHEpke.so.1',
+    ]
+
+    for libname in shared_libs:
+        load_shared_library(libname, ld_paths)
+
+    # from .openfhe import *

+ 0 - 1
openfhe/__init__.py

@@ -1 +0,0 @@
-from openfhe.openfhe import *

+ 0 - 81
setup.py

@@ -1,81 +0,0 @@
-import os
-import subprocess
-import sys
-from setuptools import setup, Extension
-from setuptools.command.sdist import sdist as _sdist
-from setuptools.command.build_ext import build_ext as _build_ext
-from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
-import glob
-import shutil
-
-__version__ = '0.9.0'
-OPENFHE_PATH = 'openfhe/'
-OPENFHE_LIB = 'openfhe.so'
-
-class CMakeExtension(Extension):
-    def __init__(self, name, sourcedir=''):
-        super().__init__(name, sources=[])
-        self.sourcedir = os.path.abspath(sourcedir)
-
-class CMakeBuild(_build_ext):
-
-    def run(self):
-        for ext in self.extensions:
-            self.build_cmake(ext)
-
-    def build_cmake(self, ext):
-        if os.path.exists(OPENFHE_PATH + OPENFHE_LIB):
-            return
-        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
-        print(extdir)
-        cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
-                      '-DPYTHON_EXECUTABLE=' + sys.executable]
-
-        cfg = 'Debug' if self.debug else 'Release'
-        build_args = ['--config', cfg]
-
-        build_temp = os.path.abspath(self.build_temp)
-        os.makedirs(build_temp, exist_ok=True)
-
-        num_cores = os.cpu_count() or 1
-        build_args += ['--parallel', str(num_cores)]
-
-        subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=build_temp)
-        subprocess.check_call(['cmake', '--build', '.', '--target', ext.name] + build_args, cwd=build_temp)
-
-        so_files = glob.glob(os.path.join(extdir, '*.so'))
-        if not so_files:
-            raise RuntimeError("Cannot find any built .so file in " + extdir)
-
-        src_file = so_files[0] 
-        dst_file = os.path.join('openfhe', OPENFHE_LIB)
-        shutil.move(src_file, dst_file)
-
-# Run build_ext before sdist
-class SDist(_sdist):
-    def run(self):
-        if os.path.exists(OPENFHE_PATH + OPENFHE_LIB):
-            os.remove(OPENFHE_PATH + OPENFHE_LIB)
-        self.run_command('build_ext')
-        super().run()
-
-setup(
-    name='openfhe',
-    version=__version__,
-    description='Python wrapper for OpenFHE C++ library.',
-    author='OpenFHE Team',
-    author_email='contact@openfhe.org',
-    url='https://github.com/openfheorg/openfhe-python',
-    license='BSD-2-Clause',
-    packages=['openfhe'],
-    package_data={'openfhe': ['*.so', '*.pyi']},
-    ext_modules=[CMakeExtension('openfhe', sourcedir='')],
-    cmdclass={
-        'build_ext': CMakeBuild,
-        'sdist': SDist
-    },
-    include_package_data=True,
-    python_requires=">=3.6",
-    install_requires=['pybind11', 'pybind11-global', 'pybind11-stubgen'],
-    tests_require = ['pytest'],
-)

+ 1 - 0
src/include/pke/cryptocontext_wrapper.h

@@ -54,6 +54,7 @@ Plaintext MultipartyDecryptFusionWrapper(CryptoContext<DCRTPoly>& self,const std
 
 const std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>> GetEvalSumKeyMapWrapper(CryptoContext<DCRTPoly>& self, const std::string &id);
 PlaintextModulus GetPlaintextModulusWrapper(CryptoContext<DCRTPoly>& self);
+uint32_t GetBatchSizeWrapper(CryptoContext<DCRTPoly>& self);
 double GetModulusWrapper(CryptoContext<DCRTPoly>& self);
 void RemoveElementWrapper(Ciphertext<DCRTPoly>& self, uint32_t index);
 double GetScalingFactorRealWrapper(CryptoContext<DCRTPoly>& self, uint32_t l);

+ 163 - 136
src/lib/bindings.cpp

@@ -166,6 +166,7 @@ void bind_crypto_context(py::module &m)
         //.def("GetCryptoParameters", &CryptoContextImpl<DCRTPoly>::GetCryptoParameters)
         .def("GetRingDimension", &CryptoContextImpl<DCRTPoly>::GetRingDimension, cc_GetRingDimension_docs)
         .def("GetPlaintextModulus", &GetPlaintextModulusWrapper, cc_GetPlaintextModulus_docs)
+        .def("GetBatchSize", &GetBatchSizeWrapper)
         .def("GetModulus", &GetModulusWrapper, cc_GetModulus_docs)
         .def("GetModulusCKKS", &GetModulusCKKSWrapper)
         .def("GetScalingFactorReal", &GetScalingFactorRealWrapper, cc_GetScalingFactorReal_docs)
@@ -868,101 +869,91 @@ void bind_crypto_context(py::module &m)
             cc_InsertEvalAutomorphismKey_docs,
             py::arg("evalKeyMap"),
             py::arg("keyTag") = "")
-        .def_static(
-            "ClearEvalAutomorphismKeys", []()
-            { CryptoContextImpl<DCRTPoly>::ClearEvalAutomorphismKeys(); },
+        .def_static("ClearEvalAutomorphismKeys", []() {
+                CryptoContextImpl<DCRTPoly>::ClearEvalAutomorphismKeys();
+            },
             cc_ClearEvalAutomorphismKeys_docs)
         // it is safer to return by value instead of by reference (GetEvalMultKeyVector returns a const reference to std::vector)
-        .def_static("GetEvalMultKeyVector",
-            [](const std::string& keyTag) {
-              return CryptoContextImpl<DCRTPoly>::GetEvalMultKeyVector(keyTag);
+        .def_static("GetEvalMultKeyVector", [](const std::string& keyTag) {
+                return CryptoContextImpl<DCRTPoly>::GetEvalMultKeyVector(keyTag);
             },
             cc_GetEvalMultKeyVector_docs,
             py::arg("keyTag") = "")
         .def_static("GetEvalAutomorphismKeyMap", &CryptoContextImpl<DCRTPoly>::GetEvalAutomorphismKeyMapPtr,
             cc_GetEvalAutomorphismKeyMap_docs,
             py::arg("keyTag") = "")
-        .def_static(
-            "SerializeEvalMultKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string keyTag = "")
-            {
-              std::ofstream outfile(filename, std::ios::out | std::ios::binary);
-              bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey<SerType::SERBINARY>(outfile, sertype, keyTag);
-              outfile.close();
-              return res; },
+        .def_static("SerializeEvalMultKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string keyTag = "") {
+                std::ofstream outfile(filename, std::ios::out | std::ios::binary);
+                bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey<SerType::SERBINARY>(outfile, sertype, keyTag);
+                outfile.close();
+                return res;
+            },
             cc_SerializeEvalMultKey_docs,
             py::arg("filename"), py::arg("sertype"), py::arg("keyTag") = "")
-        .def_static( // SerializeEvalMultKey - JSON
-            "SerializeEvalMultKey", [](const std::string &filename, const SerType::SERJSON &sertype, std::string keyTag = "")
-            {
-              std::ofstream outfile(filename, std::ios::out | std::ios::binary);
-              bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey<SerType::SERJSON>(outfile, sertype, keyTag);
-              outfile.close();
-              return res; },
+        .def_static("SerializeEvalMultKey", [](const std::string &filename, const SerType::SERJSON &sertype, std::string keyTag = "") {
+                std::ofstream outfile(filename, std::ios::out | std::ios::binary);
+                bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey<SerType::SERJSON>(outfile, sertype, keyTag);
+                outfile.close();
+                return res;
+            },
             cc_SerializeEvalMultKey_docs,
             py::arg("filename"), py::arg("sertype"), py::arg("keyTag") = "")
-        .def_static( // SerializeEvalAutomorphismKey - Binary
-            "SerializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string keyTag = "")
-            {
-              std::ofstream outfile(filename, std::ios::out | std::ios::binary);
-              bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey<SerType::SERBINARY>(outfile, sertype, keyTag);
-              outfile.close();
-              return res; },
+        .def_static("SerializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string keyTag = "") {
+                std::ofstream outfile(filename, std::ios::out | std::ios::binary);
+                bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey<SerType::SERBINARY>(outfile, sertype, keyTag);
+                outfile.close();
+                return res;
+            },
             cc_SerializeEvalAutomorphismKey_docs,
             py::arg("filename"), py::arg("sertype"), py::arg("keyTag") = "")
-        .def_static( // SerializeEvalAutomorphismKey - JSON
-            "SerializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERJSON &sertype, std::string keyTag = "")
-            {
-              std::ofstream outfile(filename, std::ios::out | std::ios::binary);
-              bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey<SerType::SERJSON>(outfile, sertype, keyTag);
-              outfile.close();
-              return res; },
+        .def_static("SerializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERJSON &sertype, std::string keyTag = "") {
+                std::ofstream outfile(filename, std::ios::out | std::ios::binary);
+                bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey<SerType::SERJSON>(outfile, sertype, keyTag);
+                outfile.close();
+                return res;
+            },
             cc_SerializeEvalAutomorphismKey_docs,
             py::arg("filename"), py::arg("sertype"), py::arg("keyTag") = "")
-        .def_static("DeserializeEvalMultKey", // DeserializeEvalMultKey - Binary
-        [](const std::string &filename, const SerType::SERBINARY &sertype)
-                    {
-              std::ifstream emkeys(filename, std::ios::in | std::ios::binary);
-              if (!emkeys.is_open()) {
-                std::cerr << "I cannot read serialization from " << filename << std::endl;
-              }
-              bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<SerType::SERBINARY>(emkeys, sertype);
-              return res; 
-                        },
-                        cc_DeserializeEvalMultKey_docs,
-                        py::arg("filename"), py::arg("sertype"))
-        .def_static("DeserializeEvalMultKey", // DeserializeEvalMultKey - JSON
-        [](const std::string &filename, const SerType::SERJSON &sertype)
-                    {
-              std::ifstream emkeys(filename, std::ios::in | std::ios::binary);
-              if (!emkeys.is_open()) {
-                std::cerr << "I cannot read serialization from " << filename << std::endl;
-              }
-              bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<SerType::SERJSON>(emkeys, sertype);
-              return res; },
-                        cc_DeserializeEvalMultKey_docs,
-                        py::arg("filename"), py::arg("sertype"))
-        .def_static("DeserializeEvalAutomorphismKey", // DeserializeEvalAutomorphismKey - Binary
-        [](const std::string &filename, const SerType::SERBINARY &sertype)
-                    {
-              std::ifstream erkeys(filename, std::ios::in | std::ios::binary);
-              if (!erkeys.is_open()) {
-                std::cerr << "I cannot read serialization from " << filename << std::endl;
-              }
-              bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<SerType::SERBINARY>(erkeys, sertype);
-              return res; },
-                        cc_DeserializeEvalAutomorphismKey_docs,
-                        py::arg("filename"), py::arg("sertype"))
-        .def_static("DeserializeEvalAutomorphismKey", // DeserializeEvalAutomorphismKey - JSON
-        [](const std::string &filename, const SerType::SERJSON &sertype)
-                    {
-              std::ifstream erkeys(filename, std::ios::in | std::ios::binary);
-              if (!erkeys.is_open()) {
-                std::cerr << "I cannot read serialization from " << filename << std::endl;
-              }
-              bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<SerType::SERJSON>(erkeys, sertype);
-              return res; },
-                        cc_DeserializeEvalAutomorphismKey_docs,
-                        py::arg("filename"), py::arg("sertype"));
+        .def_static("DeserializeEvalMultKey", [](const std::string &filename, const SerType::SERBINARY &sertype) {
+                std::ifstream emkeys(filename, std::ios::in | std::ios::binary);
+                if (!emkeys.is_open()) {
+                    std::cerr << "I cannot read serialization from " << filename << std::endl;
+                }
+                bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<SerType::SERBINARY>(emkeys, sertype);
+                return res; 
+            },
+            cc_DeserializeEvalMultKey_docs,
+            py::arg("filename"), py::arg("sertype"))
+        .def_static("DeserializeEvalMultKey", [](const std::string &filename, const SerType::SERJSON &sertype) {
+                std::ifstream emkeys(filename, std::ios::in | std::ios::binary);
+                if (!emkeys.is_open()) {
+                    std::cerr << "I cannot read serialization from " << filename << std::endl;
+                }
+                bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<SerType::SERJSON>(emkeys, sertype);
+                return res;
+            },
+            cc_DeserializeEvalMultKey_docs,
+            py::arg("filename"), py::arg("sertype"))
+        .def_static("DeserializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERBINARY &sertype) {
+                std::ifstream erkeys(filename, std::ios::in | std::ios::binary);
+                if (!erkeys.is_open()) {
+                    std::cerr << "I cannot read serialization from " << filename << std::endl;
+                }
+                bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<SerType::SERBINARY>(erkeys, sertype);
+                return res;
+            },
+            cc_DeserializeEvalAutomorphismKey_docs,
+            py::arg("filename"), py::arg("sertype"))
+        .def_static("DeserializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERJSON &sertype) {
+                std::ifstream erkeys(filename, std::ios::in | std::ios::binary);
+                if (!erkeys.is_open()) {
+                    std::cerr << "I cannot read serialization from " << filename << std::endl;
+                }
+                bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<SerType::SERJSON>(erkeys, sertype);
+                return res;
+            },
+            cc_DeserializeEvalAutomorphismKey_docs,
+            py::arg("filename"), py::arg("sertype"));
 
     // Generator Functions
     m.def("GenCryptoContext", &GenCryptoContext<CryptoContextBFVRNS>,
@@ -1159,6 +1150,7 @@ void bind_keys(py::module &m)
         .def("SetKeyTag", &PublicKeyImpl<DCRTPoly>::SetKeyTag);
     py::class_<PrivateKeyImpl<DCRTPoly>, std::shared_ptr<PrivateKeyImpl<DCRTPoly>>>(m, "PrivateKey")
         .def(py::init<>())
+        .def("GetCryptoContext", &PrivateKeyImpl<DCRTPoly>::GetCryptoContext)
         .def("GetKeyTag", &PrivateKeyImpl<DCRTPoly>::GetKeyTag)
         .def("SetKeyTag", &PrivateKeyImpl<DCRTPoly>::SetKeyTag);
     py::class_<KeyPair<DCRTPoly>>(m, "KeyPair")
@@ -1302,61 +1294,97 @@ void bind_encodings(py::module &m)
         .def("SetStringValue", &PlaintextImpl::SetStringValue)
         .def("SetIntVectorValue", &PlaintextImpl::SetIntVectorValue)
         .def("GetFormattedValues", &PlaintextImpl::GetFormattedValues)
-        .def("__repr__", [](const PlaintextImpl &p)
-             {
-        std::stringstream ss;
-        ss << "<Plaintext Object: " << p << ">";
-        return ss.str(); })
-        .def("__str__", [](const PlaintextImpl &p)
-             {
-        std::stringstream ss;
-        ss << p;
-        return ss.str(); });
+        .def("__repr__", [](const PlaintextImpl &p) {
+                std::stringstream ss;
+                ss << "<Plaintext Object: " << p << ">";
+                return ss.str();
+            })
+        .def("__str__", [](const PlaintextImpl &p) {
+                std::stringstream ss;
+                ss << p;
+                return ss.str();
+            });
 }
 
-void bind_ciphertext(py::module &m)
-{
-  py::class_<CiphertextImpl<DCRTPoly>,
-             std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext")
-      .def(py::init<>())
-      .def(
-          "__add__",
-          [](const Ciphertext<DCRTPoly> &a, const Ciphertext<DCRTPoly> &b) {
-            return a + b;
-          },
-          py::is_operator(), pybind11::keep_alive<0, 1>())
-      // .def(py::self + py::self);
-      // .def("GetDepth", &CiphertextImpl<DCRTPoly>::GetDepth)
-      // .def("SetDepth", &CiphertextImpl<DCRTPoly>::SetDepth)
-      .def("GetLevel", &CiphertextImpl<DCRTPoly>::GetLevel, ctx_GetLevel_docs)
-      .def("SetLevel", &CiphertextImpl<DCRTPoly>::SetLevel, ctx_SetLevel_docs,
-           py::arg("level"))
-      .def("Clone", &CiphertextImpl<DCRTPoly>::Clone)
-      .def("RemoveElement", &RemoveElementWrapper, cc_RemoveElement_docs)
-      // .def("GetHopLevel", &CiphertextImpl<DCRTPoly>::GetHopLevel)
-      // .def("SetHopLevel", &CiphertextImpl<DCRTPoly>::SetHopLevel)
-      // .def("GetScalingFactor", &CiphertextImpl<DCRTPoly>::GetScalingFactor)
-      // .def("SetScalingFactor", &CiphertextImpl<DCRTPoly>::SetScalingFactor)
-      .def("GetSlots", &CiphertextImpl<DCRTPoly>::GetSlots)
-      .def("SetSlots", &CiphertextImpl<DCRTPoly>::SetSlots)
-      .def("GetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::GetNoiseScaleDeg)
-      .def("SetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::SetNoiseScaleDeg)
-      .def("GetElements", [](const CiphertextImpl<DCRTPoly>& self) -> const std::vector<DCRTPoly> & {
-            return self.GetElements();
-          },
-          py::return_value_policy::reference_internal)
-      .def("GetElementsMutable", [](CiphertextImpl<DCRTPoly>& self) -> std::vector<DCRTPoly> & {
-            return self.GetElements();
-          },
-          py::return_value_policy::reference_internal)
-      .def("SetElements", [](CiphertextImpl<DCRTPoly>& self, const std::vector<DCRTPoly> &elems) {
-             self.SetElements(elems);
-           })
-      .def("SetElementsMove", [](CiphertextImpl<DCRTPoly>& self, std::vector<DCRTPoly> &&elems) {
-             self.SetElements(std::move(elems));
-           });
+void bind_ciphertext(py::module &m) {
+    py::class_<CiphertextImpl<DCRTPoly>, std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext")
+        .def(py::init<>())
+        .def("__add__", [](const Ciphertext<DCRTPoly> &a, const Ciphertext<DCRTPoly> &b) {
+                return a + b;
+            },
+            py::is_operator(), pybind11::keep_alive<0, 1>())
+        // .def(py::self + py::self);
+        // .def("GetDepth", &CiphertextImpl<DCRTPoly>::GetDepth)
+        // .def("SetDepth", &CiphertextImpl<DCRTPoly>::SetDepth)
+        .def("GetLevel", &CiphertextImpl<DCRTPoly>::GetLevel, ctx_GetLevel_docs)
+        .def("SetLevel", &CiphertextImpl<DCRTPoly>::SetLevel, ctx_SetLevel_docs,
+            py::arg("level"))
+        .def("Clone", &CiphertextImpl<DCRTPoly>::Clone)
+        .def("RemoveElement", &RemoveElementWrapper, cc_RemoveElement_docs)
+        // .def("GetHopLevel", &CiphertextImpl<DCRTPoly>::GetHopLevel)
+        // .def("SetHopLevel", &CiphertextImpl<DCRTPoly>::SetHopLevel)
+        // .def("GetScalingFactor", &CiphertextImpl<DCRTPoly>::GetScalingFactor)
+        // .def("SetScalingFactor", &CiphertextImpl<DCRTPoly>::SetScalingFactor)
+        .def("GetSlots", &CiphertextImpl<DCRTPoly>::GetSlots)
+        .def("SetSlots", &CiphertextImpl<DCRTPoly>::SetSlots)
+        .def("GetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::GetNoiseScaleDeg)
+        .def("SetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::SetNoiseScaleDeg)
+        .def("GetCryptoContext", &CiphertextImpl<DCRTPoly>::GetCryptoContext)
+        .def("GetEncodingType", &CiphertextImpl<DCRTPoly>::GetEncodingType)
+        .def("GetElements", [](const CiphertextImpl<DCRTPoly>& self) -> const std::vector<DCRTPoly>& {
+                return self.GetElements();
+            },
+            py::return_value_policy::reference_internal)
+        .def("GetElementsMutable", [](CiphertextImpl<DCRTPoly>& self) -> std::vector<DCRTPoly>& {
+                return self.GetElements();
+            },
+            py::return_value_policy::reference_internal)
+        .def("SetElements", [](CiphertextImpl<DCRTPoly>& self, const std::vector<DCRTPoly>& elems) {
+                self.SetElements(elems);
+            })
+        .def("SetElementsMove", [](CiphertextImpl<DCRTPoly>& self, std::vector<DCRTPoly>&& elems) {
+                self.SetElements(std::move(elems));
+            });
 }
 
+// void bind_ciphertext(py::module &m) {
+//     using CiphertextImplDCRT = CiphertextImpl<DCRTPoly>;
+//     using CiphertextDCRT = Ciphertext<DCRTPoly>;  // shared_ptr<CiphertextImpl<DCRTPoly>>
+
+//     // Bind CiphertextImpl<DCRTPoly> and expose it to Python as "Ciphertext"
+//     py::class_<CiphertextImplDCRT, std::shared_ptr<CiphertextImplDCRT>>(m, "Ciphertext")
+//         .def(py::init<>())
+//         .def("__add__", [](const CiphertextDCRT &a, const CiphertextDCRT &b) {
+//                 return a + b;
+//             },
+//             py::is_operator(), pybind11::keep_alive<0, 1>())
+//         .def("GetLevel", &CiphertextImplDCRT::GetLevel, ctx_GetLevel_docs)
+//         .def("SetLevel", &CiphertextImplDCRT::SetLevel, ctx_SetLevel_docs, py::arg("level"))
+//         .def("Clone", &CiphertextImplDCRT::Clone)
+//         .def("RemoveElement", &RemoveElementWrapper, cc_RemoveElement_docs)
+//         .def("GetSlots", &CiphertextImplDCRT::GetSlots)
+//         .def("SetSlots", &CiphertextImplDCRT::SetSlots)
+//         .def("GetNoiseScaleDeg", &CiphertextImplDCRT::GetNoiseScaleDeg)
+//         .def("SetNoiseScaleDeg", &CiphertextImplDCRT::SetNoiseScaleDeg)
+//         .def("GetCryptoContext", &CiphertextImplDCRT::GetCryptoContext)
+//         .def("GetEncodingType", &CiphertextImplDCRT::GetEncodingType)
+//         .def("GetElements", [](const CiphertextImplDCRT& self) -> const std::vector<DCRTPoly>& {
+//                 return self.GetElements();
+//             }, py::return_value_policy::reference_internal)
+//         .def("GetElementsMutable", [](CiphertextImplDCRT& self) -> std::vector<DCRTPoly>& {
+//                 return self.GetElements();
+//             }, py::return_value_policy::reference_internal)
+//         .def("SetElements", [](CiphertextImplDCRT& self, const std::vector<DCRTPoly>& elems) {
+//                 self.SetElements(elems);
+//             })
+//         .def("SetElementsMove", [](CiphertextImplDCRT& self, std::vector<DCRTPoly>&& elems) {
+//                 self.SetElements(std::move(elems));
+//             });
+
+//     // Bind the shared_ptr alias (Ciphertext<DCRTPoly>) so it picks up the methods above
+//     py::class_<CiphertextDCRT>(m, "_CiphertextAlias");  // hidden helper; not necessary for users
+// }
+
 void bind_schemes(py::module &m){
     /*Bind schemes specific functionalities like bootstrapping functions and multiparty*/
     py::class_<FHECKKSRNS>(m, "FHECKKSRNS")
@@ -1409,14 +1437,13 @@ void bind_sch_swch_params(py::module &m)
         .def("SetRingDimension", &SchSwchParams::SetRingDimension)
         .def("SetScalingModSize", &SchSwchParams::SetScalingModSize)
         .def("SetBatchSize", &SchSwchParams::SetBatchSize)
-        .def("__str__",[](const SchSwchParams &params) {
-            std::stringstream stream;
-            stream << params;
-            return stream.str();
-        });
+        .def("__str__", [](const SchSwchParams &params) {
+                std::stringstream stream;
+                stream << params;
+                return stream.str();
+            });
 }
 
-
 PYBIND11_MODULE(openfhe, m)
 {
     m.doc() = "Open-Source Fully Homomorphic Encryption Library";

+ 7 - 4
src/lib/pke/cryptocontext_wrapper.cpp

@@ -34,10 +34,9 @@
 Ciphertext<DCRTPoly> EvalFastRotationPrecomputeWrapper(CryptoContext<DCRTPoly> &self,ConstCiphertext<DCRTPoly> ciphertext) {
     std::shared_ptr<std::vector<DCRTPoly>> precomp = self->EvalFastRotationPrecompute(ciphertext);
     std::vector<DCRTPoly> elements = *(precomp.get());
-    CiphertextImpl<DCRTPoly> cipherdigits = CiphertextImpl<DCRTPoly>(self);
-    std::shared_ptr<CiphertextImpl<DCRTPoly>> cipherdigitsPtr = std::make_shared<CiphertextImpl<DCRTPoly>>(cipherdigits);
-    cipherdigitsPtr->SetElements(elements);
-    return cipherdigitsPtr;
+    std::shared_ptr<CiphertextImpl<DCRTPoly>> cipherdigits = std::make_shared<CiphertextImpl<DCRTPoly>>(self);
+    cipherdigits->SetElements(std::move(elements));
+    return cipherdigits;
 }
 Ciphertext<DCRTPoly> EvalFastRotationWrapper(CryptoContext<DCRTPoly>& self,ConstCiphertext<DCRTPoly> ciphertext, uint32_t index, uint32_t m,ConstCiphertext<DCRTPoly> digits) {
     
@@ -78,6 +77,10 @@ PlaintextModulus GetPlaintextModulusWrapper(CryptoContext<DCRTPoly>& self){
     return self->GetCryptoParameters()->GetPlaintextModulus();
 }
 
+uint32_t GetBatchSizeWrapper(CryptoContext<DCRTPoly>& self){
+    return self->GetCryptoParameters()->GetBatchSize();
+}
+
 double GetModulusWrapper(CryptoContext<DCRTPoly>& self){
     return self->GetCryptoParameters()->GetElementParams()->GetModulus().ConvertToDouble();
 }