Quellcode durchsuchen

Made static_f a function pointer

Dmitriy Suponitskiy vor 1 Jahr
Ursprung
Commit
d40a30c899

+ 1 - 1
src/include/binfhe/binfhecontext_wrapper.h

@@ -63,7 +63,7 @@ std::vector<uint64_t> GenerateLUTviaFunctionWrapper(BinFHEContext &self, py::fun
 NativeInteger StaticFunction(NativeInteger m, NativeInteger p);
 
 // Define static variables to hold the state
-extern py::function static_f;
+// extern py::function static_f;
 
 LWECiphertext EvalFuncWrapper(BinFHEContext &self, ConstLWECiphertext &ct, const std::vector<uint64_t> &LUT);
 #endif // BINFHE_CRYPTOCONTEXT_BINDINGS_H

+ 4 - 3
src/lib/binfhe/binfhecontext_wrapper.cpp

@@ -77,7 +77,7 @@ const uint64_t GetLWECiphertextModulusWrapper(LWECiphertext &self)
 }
 
 // Define static variables to hold the state
-py::function static_f;
+py::function* static_f = nullptr;
 
 // Define a static function that uses the static variables
 NativeInteger StaticFunction(NativeInteger m, NativeInteger p) {
@@ -85,7 +85,7 @@ NativeInteger StaticFunction(NativeInteger m, NativeInteger p) {
     uint64_t m_int = m.ConvertToInt<uint64_t>();
     uint64_t p_int = p.ConvertToInt<uint64_t>();
     // Call the Python function
-    py::object result_py = static_f(m_int, p_int);
+    py::object result_py = (*static_f)(m_int, p_int);
     // Convert the result to a NativeInteger
     return NativeInteger(py::cast<uint64_t>(result_py));
 }
@@ -93,8 +93,9 @@ NativeInteger StaticFunction(NativeInteger m, NativeInteger p) {
 std::vector<uint64_t> GenerateLUTviaFunctionWrapper(BinFHEContext &self, py::function f, uint64_t p)
 {
     NativeInteger p_native_int = NativeInteger(p);
-    static_f = f;
+    static_f = &f;
     std::vector<NativeInteger> result = self.GenerateLUTviaFunction(StaticFunction, p_native_int);
+    static_f = nullptr;
     std::vector<uint64_t> result_uint64_t;
     // int size_int = static_cast<int>(result.size());
     for (const auto& value : result)