Browse Source

EvalRotate CryptoContext method added. Plaintext class FFI inited.

nkaskov 1 year ago
parent
commit
df3796c7ab
2 changed files with 120 additions and 38 deletions
  1. 37 2
      cpp/include/bindings.hpp
  2. 83 36
      cpp/src/bindings.cpp

+ 37 - 2
cpp/include/bindings.hpp

@@ -209,6 +209,39 @@ public:
     FFIPrivateKeyImpl GetPrivateKey() const;
 };
 
+// Plaintext FFI
+
+class FFIPlaintext {
+protected:
+    void* plaintext_ptr;
+public:
+    explicit FFIPlaintext(void* new_plaintext_ptr){
+        plaintext_ptr = new_plaintext_ptr;
+    }
+
+    double GetScalingFactor() const;
+
+    void SetScalingFactor(double sf);
+
+    FFISCHEME GetSchemeID() const;
+
+    std::size_t GetLength() const;
+
+    void SetLength(std::size_t newSize);
+
+    bool IsEncoded() const;
+
+    double GetLogPrecision() const;
+
+    void Encode();
+
+    void Decode();
+
+    std::int64_t LowBound() const;
+
+    std::int64_t HighBound() const;
+};
+
 // Ciphertext FFI
 
 class FFICiphertext {
@@ -232,6 +265,8 @@ public:
     std::size_t GetSlots() const;
 
     void SetSlots(std::size_t slots);
+
+    friend class FFICryptoContextImpl;
 };
 
 // Params FFI
@@ -432,7 +467,7 @@ public:
 //     Plaintext MakeCKKSPackedPlaintext(const std::vector<double>& value, std::size_t scaleDeg = 1, uint32_t level = 0,
 //                                       const std::shared_ptr<ParmType> params = nullptr, usint slots = 0) const;
 
-//     Ciphertext<Element> EvalRotate(ConstCiphertext<Element> ciphertext, int32_t index) const;
+    FFICiphertext EvalRotate(const FFICiphertext ciphertext, std::int32_t index) const;
 
 // // const?
 //     Ciphertext<DCRTPoly> EvalFastRotationPrecompute(ConstCiphertext<DCRTPoly> ciphertext);
@@ -446,7 +481,7 @@ public:
 
 //     Ciphertext<Element> EvalAtIndex(ConstCiphertext<Element> ciphertext, int32_t index) const;
 
-//     Ciphertext<Element> Encrypt(const PublicKey<Element> publicKey, Plaintext plaintext) const;
+    // FFICiphertext Encrypt(const FFIPublicKeyImpl publicKey, FFIPlaintext plaintext) const;
 
 //     DecryptResult Decrypt(ConstCiphertext<Element> ciphertext, const PrivateKey<Element> privateKey,
 //                           Plaintext* plaintext);

+ 83 - 36
cpp/src/bindings.cpp

@@ -36,9 +36,17 @@ namespace {
            std::shared_ptr<KeyPair<DCRTPoly>> ptr;
     };
 
+    struct PlaintextHolder{
+           std::shared_ptr<PlaintextImpl> ptr;
+    };
+
     struct CiphertextHolder{
            std::shared_ptr<CiphertextImpl<DCRTPoly>> ptr;
     };
+
+    struct ConstCiphertextHolder{
+           std::shared_ptr<const CiphertextImpl<DCRTPoly>> ptr;
+    };
 }
 
 FFIParams::FFIParams(){
@@ -507,12 +515,15 @@ void FFICryptoContextImpl::EvalMultKeysGen(const FFIPrivateKeyImpl key){
     cc->EvalMultKeysGen(cc_key);
 }
 
-//void FFICryptoContextImpl::EvalRotateKeyGen(const FFIPrivateKey privateKey, const std::vector<int32_t>& indexList,
-//      const FFIPublicKey publicKey = nullptr){
-//    std::shared_ptr<CryptoContextImpl<DCRTPoly>> cc =
-//        reinterpret_cast<CryptoContextImplHolder*>(cc_ptr)->ptr;
-//    cc->EvalRotateKeyGen();
-//}
+FFICiphertext FFICryptoContextImpl::EvalRotate(const FFICiphertext ciphertext, std::int32_t index) const{
+    std::shared_ptr<const CryptoContextImpl<DCRTPoly>> cc =
+        reinterpret_cast<const CryptoContextImplHolder*>(cc_ptr)->ptr;
+    const std::shared_ptr<const CiphertextImpl<DCRTPoly>> cc_ciphertext = 
+        reinterpret_cast<const ConstCiphertextHolder*>(ciphertext.ciphertext_ptr)->ptr;
+    void* ciphertext_ptr = reinterpret_cast<void*>(
+        new CiphertextHolder{cc->EvalRotate(cc_ciphertext, index)});
+    return FFICiphertext(ciphertext_ptr);
+}
 
 // void bind_crypto_context(py::module &m)
 // {
@@ -548,10 +559,6 @@ void FFICryptoContextImpl::EvalMultKeysGen(const FFIPrivateKeyImpl key){
 //              py::arg("level") = static_cast<uint32_t>(0),
 //              py::arg("params") = py::none(),
 //              py::arg("slots") = 0)
-//         .def("EvalRotate", &CryptoContextImpl<DCRTPoly>::EvalRotate,
-//             cc_EvalRotate_docs,
-//             py::arg("ciphertext"),
-//             py::arg("index"))
 //         .def("EvalFastRotationPrecompute", &EvalFastRotationPrecomputeWrapper,
 //             cc_EvalFastRotationPreCompute_docs,
 //             py::arg("ciphertext"))
@@ -1366,35 +1373,75 @@ FFIPrivateKeyImpl FFIKeyPair::GetPrivateKey() const{
 //         .def(py::init<>());
 // }
 
+double FFIPlaintext::GetScalingFactor() const{
+    std::shared_ptr<const PlaintextImpl> plaintext =
+        reinterpret_cast<PlaintextHolder*>(plaintext_ptr)->ptr;
+    return plaintext->GetScalingFactor();
+}
+
+void FFIPlaintext::SetScalingFactor(double sf){
+    std::shared_ptr<PlaintextImpl> plaintext =
+        reinterpret_cast<PlaintextHolder*>(plaintext_ptr)->ptr;
+    plaintext->SetScalingFactor(sf);
+}
+
+FFISCHEME FFIPlaintext::GetSchemeID() const{
+    std::shared_ptr<const PlaintextImpl> plaintext =
+        reinterpret_cast<PlaintextHolder*>(plaintext_ptr)->ptr;
+    return FFISCHEME(plaintext->GetSchemeID());
+}
+
+std::size_t FFIPlaintext::GetLength() const{
+    std::shared_ptr<const PlaintextImpl> plaintext =
+        reinterpret_cast<PlaintextHolder*>(plaintext_ptr)->ptr;
+    return plaintext->GetLength();
+}
+
+void FFIPlaintext::SetLength(std::size_t newSize){
+    std::shared_ptr<PlaintextImpl> plaintext =
+        reinterpret_cast<PlaintextHolder*>(plaintext_ptr)->ptr;
+    plaintext->SetLength(newSize);
+}
+
+bool FFIPlaintext::IsEncoded() const{
+    std::shared_ptr<const PlaintextImpl> plaintext =
+        reinterpret_cast<PlaintextHolder*>(plaintext_ptr)->ptr;
+    return plaintext->IsEncoded();
+}
+
+double FFIPlaintext::GetLogPrecision() const{
+    std::shared_ptr<const PlaintextImpl> plaintext =
+        reinterpret_cast<PlaintextHolder*>(plaintext_ptr)->ptr;
+    return plaintext->GetLogPrecision();
+}
+
+void FFIPlaintext::Encode(){
+    std::shared_ptr<PlaintextImpl> plaintext =
+        reinterpret_cast<PlaintextHolder*>(plaintext_ptr)->ptr;
+    plaintext->Encode();
+}
+
+void FFIPlaintext::Decode(){
+    std::shared_ptr<PlaintextImpl> plaintext =
+        reinterpret_cast<PlaintextHolder*>(plaintext_ptr)->ptr;
+    plaintext->Decode();
+}
+
+std::int64_t FFIPlaintext::LowBound() const{
+    std::shared_ptr<const PlaintextImpl> plaintext =
+        reinterpret_cast<PlaintextHolder*>(plaintext_ptr)->ptr;
+    return plaintext->LowBound();
+}
+
+std::int64_t FFIPlaintext::HighBound() const{
+    std::shared_ptr<const PlaintextImpl> plaintext =
+        reinterpret_cast<PlaintextHolder*>(plaintext_ptr)->ptr;
+    return plaintext->HighBound();
+}
+
 // void bind_encodings(py::module &m)
 // {
 //     py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>>(m, "Plaintext")
-//         .def("GetScalingFactor", &PlaintextImpl::GetScalingFactor,
-//             ptx_GetScalingFactor_docs)
-//         .def("SetScalingFactor", &PlaintextImpl::SetScalingFactor,
-//             ptx_SetScalingFactor_docs,
-//             py::arg("sf"))
-//         .def("GetSchemeID", &PlaintextImpl::GetSchemeID,
-//             ptx_GetSchemeID_docs)
-//         .def("GetLength", &PlaintextImpl::GetLength,
-//             ptx_GetLength_docs)
-//         .def("GetSchemeID", &PlaintextImpl::GetSchemeID,
-//             ptx_GetSchemeID_docs)
-//         .def("SetLength", &PlaintextImpl::SetLength,
-//             ptx_SetLength_docs,
-//             py::arg("newSize"))
-//         .def("IsEncoded", &PlaintextImpl::IsEncoded,
-//             ptx_IsEncoded_docs)
-//         .def("GetLogPrecision", &PlaintextImpl::GetLogPrecision,
-//             ptx_GetLogPrecision_docs)
-//         .def("Encode", &PlaintextImpl::Encode,
-//             ptx_Encode_docs)
-//         .def("Decode", &PlaintextImpl::Decode,
-//             ptx_Decode_docs)
-//         .def("LowBound", &PlaintextImpl::LowBound,
-//             ptx_LowBound_docs)
-//         .def("HighBound", &PlaintextImpl::HighBound,
-//             ptx_HighBound_docs)
 //         .def("SetFormat", &PlaintextImpl::SetFormat,
 //             ptx_SetFormat_docs,
 //             py::arg("fmt"))