Parcourir la source

204 fix py byte wrappers (#205)

* Improved the code for the *BytesWrapper functions, fixed input parameters for some functions, added Compress to binding

* Cleanup and formatting

---------

Co-authored-by: Dmitriy Suponitskiy <dsuponitskiy@dualitytech.com>
dsuponitskiy il y a 10 mois
Parent
commit
608717f5f2

+ 4 - 3
src/include/binfhe/binfhecontext_wrapper.h

@@ -25,8 +25,8 @@
 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-#ifndef BINFHE_CRYPTOCONTEXT_BINDINGS_H
-#define BINFHE_CRYPTOCONTEXT_BINDINGS_H
+#ifndef __BINFHECONTEXT_WRAPPER_H__
+#define __BINFHECONTEXT_WRAPPER_H__
 
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
@@ -66,4 +66,5 @@ NativeInteger StaticFunction(NativeInteger m, NativeInteger p);
 // extern py::function static_f;
 
 LWECiphertext EvalFuncWrapper(BinFHEContext &self, ConstLWECiphertext &ct, const std::vector<uint64_t> &LUT);
-#endif // BINFHE_CRYPTOCONTEXT_BINDINGS_H
+
+#endif // __BINFHECONTEXT_WRAPPER_H__

+ 3 - 3
src/include/binfhe_bindings.h

@@ -25,8 +25,8 @@
 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-#ifndef BINFHE_BINDINGS_H
-#define BINFHE_BINDINGS_H
+#ifndef __BINFHE_BINDINGS_H__
+#define __BINFHE_BINDINGS_H__
 
 #include <pybind11/pybind11.h>
 
@@ -34,4 +34,4 @@ void bind_binfhe_enums(pybind11::module &m);
 void bind_binfhe_context(pybind11::module &m);
 void bind_binfhe_keys(pybind11::module &m);
 void bind_binfhe_ciphertext(pybind11::module &m);
-#endif // BINFHE_BINDINGS_H
+#endif // __BINFHE_BINDINGS_H__

+ 3 - 3
src/include/docstrings/binfhecontext_docs.h

@@ -25,8 +25,8 @@
 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-#ifndef BINFHECONTEXT_DOCSTRINGS_H
-#define BINFHECONTEXT_DOCSTRINGS_H
+#ifndef __BINFHECONTEXT_DOCS_H
+#define __BINFHECONTEXT_DOCS_H
 
 // GenerateBinFHEContext
 const char* binfhe_GenerateBinFHEContext_parset_docs = R"pbdoc(
@@ -199,4 +199,4 @@ const char* binfhe_SerializedObjectName_docs = R"pbdoc(
    :return: object name
    :rtype: std::string
 )pbdoc";
-#endif // BINFHECONTEXT_DOCSTRINGS_H
+#endif // __BINFHECONTEXT_DOCS_H

+ 3 - 3
src/include/docstrings/ciphertext_docs.h

@@ -25,8 +25,8 @@
 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-#ifndef CIPHERTEXT_DOCSTRINGS_H
-#define CIPHERTEXT_DOCSTRINGS_H
+#ifndef __CIPHERTEXT_DOCS_H__
+#define __CIPHERTEXT_DOCS_H__
 
 // GetLevel
 const char* ctx_GetLevel_docs = R"pbdoc(
@@ -58,4 +58,4 @@ const char* cc_RemoveElement_docs = R"pbdoc(
     :param index: The index of the element to remove.
     :type index: int
 )pbdoc";
-#endif // CIPHERTEXT_DOCSTRINGS_H
+#endif // __CIPHERTEXT_DOCS_H__

+ 3 - 3
src/include/docstrings/cryptocontext_docs.h

@@ -25,8 +25,8 @@
 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-#ifndef CRYPTOCONTEXT_DOCSTRINGS_H
-#define CRYPTOCONTEXT_DOCSTRINGS_H
+#ifndef __CRYPTOCONTEXT_DOCS_H__
+#define __CRYPTOCONTEXT_DOCS_H__
 
 #include "pybind11/pybind11.h"
 #include "pybind11/attr.h"
@@ -1385,4 +1385,4 @@ const char* cc_DeserializeEvalMultKey_docs = R"pbdoc(
 )pbdoc";
 
 
-#endif //CRYPTOCONTEXT_DOCSTRINGS_H
+#endif // __CRYPTOCONTEXT_DOCS_H__

+ 3 - 3
src/include/docstrings/cryptoparameters_docs.h

@@ -25,8 +25,8 @@
 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-#ifndef CRYPTOPARAMS_DOCSTRINGS_H
-#define CRYPTOPARAMS_DOCSTRINGS_H
+#ifndef __CRYPTOPARAMETERS_DOCS_H__
+#define __CRYPTOPARAMETERS_DOCS_H__
 
 const char* ccparams_doc = R"doc(
     Crypto parameters for the BFV, BGV and CKKS scheme.
@@ -72,4 +72,4 @@ const char* cc_GetScalingFactorReal_docs = R"pbdoc(
 )pbdoc";
 
 
-#endif // CRYPTOPARAMS_DOCSTRINGS_H
+#endif // __CRYPTOPARAMETERS_DOCS_H__

+ 3 - 3
src/include/docstrings/plaintext_docs.h

@@ -25,8 +25,8 @@
 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-#ifndef PLAINTEXT_DOCSTRINGS_H
-#define PLAINTEXT_DOCSTRINGS_H
+#ifndef __PLAINTEXT_DOCS_H__
+#define __PLAINTEXT_DOCS_H__
 
 // GetScalingFactor
 const char* ptx_GetScalingFactor_docs = R"doc(
@@ -133,4 +133,4 @@ const char* ptx_GetRealPackedValue_docs = R"pbdoc(
 )pbdoc";
 
 
-#endif // PLAINTEXT_DOCSTRINGS_H
+#endif // __PLAINTEXT_DOCS_H__

+ 7 - 4
src/lib/bindings.cpp

@@ -759,14 +759,17 @@ void bind_crypto_context(py::module &m)
                 std::vector<std::shared_ptr<lbcrypto::CiphertextImpl<DCRTPoly> > >&
             ) const>(
             &CryptoContextImpl<DCRTPoly>::EvalLinearWSumMutable),
-             py::arg("ciphertext"),
-             py::arg("coefficients"))
+            py::arg("ciphertext"),
+            py::arg("coefficients"))
         .def("EvalLinearWSum",
             static_cast<lbcrypto::Ciphertext<DCRTPoly> (lbcrypto::CryptoContextImpl<DCRTPoly>::*)(
             std::vector<std::shared_ptr<const lbcrypto::CiphertextImpl<DCRTPoly> > >&,const std::vector<double>&) const>(
             &CryptoContextImpl<DCRTPoly>::EvalLinearWSum),
-             py::arg("ciphertext"),
-             py::arg("coefficients"))
+            py::arg("ciphertext"),
+            py::arg("coefficients"))
+       .def("Compress", &CryptoContextImpl<DCRTPoly>::Compress,
+            py::arg("ciphertext"),
+            py::arg("towersLeft"))
         .def("EvalMultMany", &CryptoContextImpl<DCRTPoly>::EvalMultMany,
             py::arg("ciphertextVec"))
         .def("EvalAddManyInPlace", &CryptoContextImpl<DCRTPoly>::EvalAddManyInPlace,

+ 57 - 49
src/lib/pke/serialization.cpp

@@ -35,13 +35,9 @@
 
 #include "openfhe.h"
 // header files needed for serialization
-#include "metadata-ser.h"
 #include "ciphertext-ser.h"
 #include "cryptocontext-ser.h"
 #include "key/key-ser.h"
-#include "scheme/bfvrns/bfvrns-ser.h"
-#include "scheme/bgvrns/bgvrns-ser.h"
-#include "scheme/ckksrns/ckksrns-ser.h"
 
 using namespace lbcrypto;
 namespace py = pybind11;
@@ -51,8 +47,7 @@ PYBIND11_MAKE_OPAQUE(std::map<uint32_t, EvalKey<DCRTPoly>>);
 
 
 template <typename ST>
-bool SerializeEvalMultKeyWrapper(const std::string &filename, const ST &sertype, std::string id)
-{
+bool SerializeEvalMultKeyWrapper(const std::string& filename, const ST& sertype, std::string id) {
     std::ofstream outfile(filename, std::ios::out | std::ios::binary);
     bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey<ST>(outfile, sertype, id);
     outfile.close();
@@ -60,8 +55,7 @@ bool SerializeEvalMultKeyWrapper(const std::string &filename, const ST &sertype,
 }
 
 template <typename ST>
-bool SerializeEvalAutomorphismKeyWrapper(const std::string& filename, const ST& sertype, std::string id)
-{
+bool SerializeEvalAutomorphismKeyWrapper(const std::string& filename, const ST& sertype, std::string id) {
     std::ofstream outfile(filename, std::ios::out | std::ios::binary);
     bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey<ST>(outfile, sertype, id);
     outfile.close();
@@ -69,15 +63,14 @@ bool SerializeEvalAutomorphismKeyWrapper(const std::string& filename, const ST&
 }
 
 template <typename ST>
-bool DeserializeEvalMultKeyWrapper(const std::string &filename, const ST &sertype)
-{
+bool DeserializeEvalMultKeyWrapper(const std::string& filename, const ST& sertype) {
     std::ifstream emkeys(filename, std::ios::in | std::ios::binary);
-    if (!emkeys.is_open())
-    {
+    if (!emkeys.is_open()) {
         std::cerr << "I cannot read serialization from " << filename << std::endl;
     }
     bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<ST>(emkeys, sertype);
-    return res; }
+    return res;
+}
 
 template <typename T, typename ST>
 std::tuple<T, bool> DeserializeFromFileWrapper(const std::string& filename, const ST& sertype) {
@@ -101,10 +94,14 @@ std::string SerializeToStringWrapper(const T& obj, const ST& sertype) {
 
 template <typename T, typename ST>
 py::bytes SerializeToBytesWrapper(const T& obj, const ST& sertype) {
-    std::ostringstream oss(std::ios::binary);
+    // let strbuf be dynamically allocated as we may be dealing with large keys
+    auto strbuf = std::make_unique<std::stringbuf>(std::ios::out | std::ios::binary);
+    std::ostream oss(strbuf.get());
+
     Serial::Serialize<T>(obj, oss, sertype);
-    std::string str = oss.str();
-    return py::bytes(str);
+
+    const std::string& str = strbuf->str();
+    return py::bytes(str.data(), str.size());
 }
 
 template <typename T, typename ST>
@@ -125,18 +122,20 @@ CryptoContext<DCRTPoly> DeserializeCCFromStringWrapper(const std::string& str, c
 
 template <typename T, typename ST>
 T DeserializeFromBytesWrapper(const py::bytes& bytes, const ST& sertype) {
-    T obj;
-    std::string str(bytes);
+    std::string str{static_cast<std::string>(bytes)};
     std::istringstream iss(str, std::ios::binary);
+
+    T obj;
     Serial::Deserialize<T>(obj, iss, sertype);
     return obj;
 }
 
 template <typename ST>
 CryptoContext<DCRTPoly> DeserializeCCFromBytesWrapper(const py::bytes& bytes, const ST& sertype) {
-    CryptoContext<DCRTPoly> obj;
-    std::string str(bytes);
+    std::string str{static_cast<std::string>(bytes)};
     std::istringstream iss(str, std::ios::binary);
+
+    CryptoContext<DCRTPoly> obj;
     Serial::Deserialize<DCRTPoly>(obj, iss, sertype);
     return obj;
 }
@@ -153,15 +152,17 @@ std::string SerializeEvalMultKeyToStringWrapper(const ST& sertype, const std::st
 
 template <typename ST>
 py::bytes SerializeEvalMultKeyToBytesWrapper(const ST& sertype, const std::string& id) {
-    std::ostringstream oss(std::ios::binary);
-    bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey(oss, sertype, id);
-    if (!res) {
+    // let strbuf be dynamically allocated as we may be dealing with large keys
+    auto strbuf = std::make_unique<std::stringbuf>(std::ios::out | std::ios::binary);
+    std::ostream oss(strbuf.get());
+
+    if (!CryptoContextImpl<DCRTPoly>::SerializeEvalMultKey(oss, sertype, id)) {
         throw std::runtime_error("Failed to serialize EvalMultKey");
     }
-    std::string str = oss.str();
-    return py::bytes(str);
-}
 
+    const std::string& str = strbuf->str();
+    return py::bytes(str.data(), str.size());
+}
 
 template <typename ST>
 std::string SerializeEvalAutomorphismKeyToStringWrapper(const ST& sertype, const std::string& id) {
@@ -173,15 +174,18 @@ std::string SerializeEvalAutomorphismKeyToStringWrapper(const ST& sertype, const
     return oss.str();
 }
 
-
 template <typename ST>
 py::bytes SerializeEvalAutomorphismKeyToBytesWrapper(const ST& sertype, const std::string& id) {
-    std::ostringstream oss(std::ios::binary);
-    bool res = CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey(oss, sertype, id);
-    if (!res) {
+    // let strbuf be dynamically allocated as we may be dealing with large keys
+    auto strbuf = std::make_unique<std::stringbuf>(std::ios::out | std::ios::binary);
+    std::ostream oss(strbuf.get());
+
+    if (!CryptoContextImpl<DCRTPoly>::SerializeEvalAutomorphismKey(oss, sertype, id)) {
         throw std::runtime_error("Failed to serialize EvalAutomorphismKey");
     }
-    return oss.str();
+
+    const std::string& str = strbuf->str();
+    return py::bytes(str.data(), str.size());
 }
 
 template <typename ST>
@@ -194,11 +198,11 @@ void DeserializeEvalMultKeyFromStringWrapper(const std::string& data, const ST&
 }
 
 template <typename ST>
-void DeserializeEvalMultKeyFromBytesWrapper(const std::string& data, const ST& sertype) {
-    std::string str(data);
+void DeserializeEvalMultKeyFromBytesWrapper(const py::bytes& bytes, const ST& sertype) {
+    std::string str{static_cast<std::string>(bytes)};
     std::istringstream iss(str, std::ios::binary);
-    bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<ST>(iss, sertype);
-    if (!res) {
+
+    if (!CryptoContextImpl<DCRTPoly>::DeserializeEvalMultKey<ST>(iss, sertype)) {
         throw std::runtime_error("Failed to deserialize EvalMultKey");
     }
 }
@@ -214,11 +218,11 @@ void DeserializeEvalAutomorphismKeyFromStringWrapper(const std::string& data, co
 }
 
 template <typename ST>
-void DeserializeEvalAutomorphismKeyFromBytesWrapper(const std::string& data, const ST& sertype) {
-    std::string str(data);
+void DeserializeEvalAutomorphismKeyFromBytesWrapper(const py::bytes& bytes, const ST& sertype) {
+    std::string str{static_cast<std::string>(bytes)};
     std::istringstream iss(str, std::ios::binary);
-    bool res = CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<ST>(iss, sertype);
-    if (!res) {
+
+    if (!CryptoContextImpl<DCRTPoly>::DeserializeEvalAutomorphismKey<ST>(iss, sertype)) {
         throw std::runtime_error("Failed to deserialize EvalAutomorphismKey");
     }
 }
@@ -272,17 +276,19 @@ void bind_serialization(pybind11::module &m) {
     m.def("DeserializeEvalKeyString", &DeserializeFromStringWrapper<EvalKey<DCRTPoly>, SerType::SERJSON>,
           py::arg("str"), py::arg("sertype"));
     m.def("Serialize", &SerializeToBytesWrapper<std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>>, SerType::SERJSON>,
-            py::arg("obj"), py::arg("sertype"));
+          py::arg("obj"), py::arg("sertype"));
     m.def("DeserializeEvalKeyMapString", &DeserializeFromBytesWrapper<std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>>, SerType::SERJSON>,
-            py::arg("str"), py::arg("sertype"));
+          py::arg("str"), py::arg("sertype"));
 
     m.def("SerializeEvalMultKeyString", &SerializeEvalMultKeyToStringWrapper<SerType::SERJSON>,
           py::arg("sertype"), py::arg("id") = "");
-    m.def("DeserializeEvalMultKeyString", &DeserializeEvalMultKeyFromStringWrapper<SerType::SERJSON>,
+    m.def("DeserializeEvalMultKeyString",
+          static_cast<void (*)(const std::string&, const SerType::SERJSON&)>(&DeserializeEvalMultKeyFromStringWrapper<SerType::SERJSON>),
           py::arg("data"), py::arg("sertype"));
     m.def("SerializeEvalAutomorphismKeyString", &SerializeEvalAutomorphismKeyToStringWrapper<SerType::SERJSON>,
           py::arg("sertype"), py::arg("id") = "");
-    m.def("DeserializeEvalAutomorphismKeyString", &DeserializeEvalAutomorphismKeyFromStringWrapper<SerType::SERJSON>,
+    m.def("DeserializeEvalAutomorphismKeyString",
+          static_cast<void (*)(const std::string&, const SerType::SERJSON&)>(&DeserializeEvalAutomorphismKeyFromStringWrapper<SerType::SERJSON>),
           py::arg("data"), py::arg("sertype"));
 
     // Binary Serialization
@@ -333,16 +339,18 @@ void bind_serialization(pybind11::module &m) {
     m.def("DeserializeEvalKeyString", &DeserializeFromBytesWrapper<EvalKey<DCRTPoly>, SerType::SERBINARY>,
           py::arg("str"), py::arg("sertype"));
     m.def("Serialize", &SerializeToBytesWrapper<std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>>, SerType::SERBINARY>,
-            py::arg("obj"), py::arg("sertype"));
+          py::arg("obj"), py::arg("sertype"));
     m.def("DeserializeEvalKeyMapString", &DeserializeFromBytesWrapper<std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>>, SerType::SERBINARY>,
-            py::arg("str"), py::arg("sertype"));
+          py::arg("str"), py::arg("sertype"));
 
     m.def("SerializeEvalMultKeyString", &SerializeEvalMultKeyToBytesWrapper<SerType::SERBINARY>,
           py::arg("sertype"), py::arg("id") = "");
-    m.def("DeserializeEvalMultKeyString", &DeserializeEvalMultKeyFromBytesWrapper<SerType::SERBINARY>,
-          py::arg("data"), py::arg("sertype"));
+    m.def("DeserializeEvalMultKeyString",
+          static_cast<void (*)(const py::bytes&, const SerType::SERBINARY&)>(&DeserializeEvalMultKeyFromBytesWrapper<SerType::SERBINARY>),
+          py::arg("bytes"), py::arg("sertype"));
     m.def("SerializeEvalAutomorphismKeyString", &SerializeEvalAutomorphismKeyToBytesWrapper<SerType::SERBINARY>,
           py::arg("sertype"), py::arg("id") = "");
-    m.def("DeserializeEvalAutomorphismKeyString", &DeserializeEvalAutomorphismKeyFromBytesWrapper<SerType::SERBINARY>,
-          py::arg("data"), py::arg("sertype"));
+    m.def("DeserializeEvalAutomorphismKeyString",
+          static_cast<void (*)(const py::bytes&, const SerType::SERBINARY&)>(&DeserializeEvalAutomorphismKeyFromBytesWrapper<SerType::SERBINARY>),
+          py::arg("bytes"), py::arg("sertype"));
 }