Browse Source

Methods KeyGen, EvalMultKeyGen and EvalMultKeysGen added. GenCryptoContext function added.

nkaskov 1 year ago
parent
commit
aa2e1a2b10
2 changed files with 108 additions and 34 deletions
  1. 34 17
      cpp/include/bindings.hpp
  2. 74 17
      cpp/src/bindings.cpp

+ 34 - 17
cpp/include/bindings.hpp

@@ -30,6 +30,13 @@
 
 #include <cstdint>
 
+// forward declarations
+class FFIPublicKeyImpl;
+class FFIPrivateKeyImpl;
+class FFIKeyPair;
+class FFIParams;
+class FFICryptoContextImpl;
+
 typedef std::uint32_t usint;
 typedef std::uint64_t FFIPlaintextModulus;
 
@@ -160,6 +167,7 @@ public:
     const char* GetKeyTag() const;
 
     friend class FFIKeyPair;
+    friend class FFICryptoContextImpl;
 };
 
 // PrivateKeyImpl FFI
@@ -177,6 +185,7 @@ public:
     const char* GetKeyTag() const;
 
     friend class FFIKeyPair;
+    friend class FFICryptoContextImpl;
 };
 
 // KeyPair FFI
@@ -186,6 +195,10 @@ protected:
     void* keypair_ptr;
 public:
     FFIKeyPair();
+    
+    explicit FFIKeyPair(void* new_keypair_ptr){
+        keypair_ptr = new_keypair_ptr;
+    }
 
     explicit FFIKeyPair(const FFIPublicKeyImpl& publicKey, const FFIPrivateKeyImpl& privateKey);
     
@@ -206,9 +219,9 @@ public:
 
     FFIParams(FFISCHEME scheme);
 
-    virtual FFIPlaintextModulus GetPlaintextModulus() const;
+    FFIPlaintextModulus GetPlaintextModulus() const;
 
-    virtual FFISCHEME GetScheme() const;
+    FFISCHEME GetScheme() const;
 
     usint GetDigitSize() const;
 
@@ -325,19 +338,21 @@ public:
     void SetInteractiveBootCompressionLevel(FFICOMPRESSION_LEVEL interactiveBootCompressionLevel);
 
 //     std::stream str();  
-};
 
-class CryptoContextBFVRNSCCParams : public FFIParams {
-public:
-    // CryptoContextBFVRNSCCParams();
-    CryptoContextBFVRNSCCParams():FFIParams(BFVRNS_SCHEME){};
+    friend FFICryptoContextImpl GenCryptoContext(FFIParams params);
 };
 
-class CryptoContextBGVRNSCCParams : public FFIParams {
-public:
-    // CryptoContextBGVRNSCCParams();
-    CryptoContextBGVRNSCCParams():FFIParams(BGVRNS_SCHEME){};
-};
+// class CryptoContextBFVRNSCCParams : public FFIParams {
+// public:
+//     // CryptoContextBFVRNSCCParams();
+//     CryptoContextBFVRNSCCParams():FFIParams(BFVRNS_SCHEME){};
+// };
+
+// class CryptoContextBGVRNSCCParams : public FFIParams {
+// public:
+//     // CryptoContextBGVRNSCCParams();
+//     CryptoContextBGVRNSCCParams():FFIParams(BGVRNS_SCHEME){};
+// };
 
 // CryptoContext FFI
 
@@ -369,14 +384,14 @@ public:
 
     void Enable(FFIPKESchemeFeature feature);
 
-//     KeyPair<Element> KeyGen();
+    FFIKeyPair KeyGen();
 
-//     void EvalMultKeyGen(const FFIPrivateKey key);
+    void EvalMultKeyGen(const FFIPrivateKeyImpl key);
 
-//     void EvalMultKeysGen(const FFIPrivateKey key);
+    void EvalMultKeysGen(const FFIPrivateKeyImpl key);
 
-//     void EvalRotateKeyGen(const FFIPrivateKey privateKey, const std::vector<int32_t>& indexList,
-//                           const FFIPublicKey publicKey = nullptr);
+//     void EvalRotateKeyGen(const FFIPrivateKeyImpl privateKey, const std::vector<int32_t>& indexList,
+//      const FFIPublicKey publicKey = nullptr);
 
 //     FFIPlaintext MakeStringPlaintext(const std::string& str) const;
 
@@ -461,7 +476,9 @@ public:
 //     Ciphertext<Element> EvalSubMutable(Plaintext plaintext, Ciphertext<Element>& ciphertext) const;
 
 //     void EvalSubMutableInPlace(Ciphertext<Element>& ciphertext1, Ciphertext<Element>& ciphertext2) const;
+    friend FFICryptoContextImpl GenCryptoContext(FFIParams params);
 };
 
+FFICryptoContextImpl GenCryptoContext(FFIParams params);
 
 #endif // OPENFHE_BINDINGS_H

+ 74 - 17
cpp/src/bindings.cpp

@@ -476,16 +476,44 @@ void FFICryptoContextImpl::Enable(FFIPKESchemeFeature feature) {
     cc->Enable(PKESchemeFeature(feature));
 }
 
+FFIKeyPair FFICryptoContextImpl::KeyGen(){
+    // TODO make it const method
+    // std::shared_ptr<const CryptoContextImpl<DCRTPoly>> cc =
+    std::shared_ptr<CryptoContextImpl<DCRTPoly>> cc =
+        reinterpret_cast<CryptoContextImplHolder*>(cc_ptr)->ptr;
+    cc->KeyGen();
+    void* keypair_ptr = reinterpret_cast<void*>(
+        new KeyPairHolder{std::make_shared<KeyPair<DCRTPoly>>(cc->KeyGen())});
+    
+    return FFIKeyPair(keypair_ptr);
+}
+
+void FFICryptoContextImpl::EvalMultKeyGen(const FFIPrivateKeyImpl key){
+    std::shared_ptr<CryptoContextImpl<DCRTPoly>> cc =
+        reinterpret_cast<CryptoContextImplHolder*>(cc_ptr)->ptr;
+    const std::shared_ptr<PrivateKeyImpl<DCRTPoly>> cc_key = 
+        reinterpret_cast<PrivkeyHolder*>(key.privkey_ptr)->ptr;
+    cc->EvalMultKeyGen(cc_key);
+}
+
+void FFICryptoContextImpl::EvalMultKeysGen(const FFIPrivateKeyImpl key){
+    std::shared_ptr<CryptoContextImpl<DCRTPoly>> cc =
+        reinterpret_cast<CryptoContextImplHolder*>(cc_ptr)->ptr;
+    const std::shared_ptr<PrivateKeyImpl<DCRTPoly>> cc_key = 
+        reinterpret_cast<PrivkeyHolder*>(key.privkey_ptr)->ptr;
+    cc->EvalMultKeysGen(cc_key);
+}
+
+//void FFICryptoContextImpl::EvalRotateKeyGen(const FFIPrivateKey privateKey, const std::vector<int32_t>& indexList,
+//      const FFIPublicKey publicKey = nullptr){
+//    std::shared_ptr<CryptoContextImpl<DCRTPoly>> cc =
+//        reinterpret_cast<CryptoContextImplHolder*>(cc_ptr)->ptr;
+//    cc->EvalRotateKeyGen();
+//}
+
 // void bind_crypto_context(py::module &m)
 // {
 //     py::class_<CryptoContextImpl<DCRTPoly>, std::shared_ptr<CryptoContextImpl<DCRTPoly>>>(m, "CryptoContext")
-//         .def("KeyGen", &CryptoContextImpl<DCRTPoly>::KeyGen, cc_KeyGen_docs)
-//         .def("EvalMultKeyGen", &CryptoContextImpl<DCRTPoly>::EvalMultKeyGen,
-//              cc_EvalMultKeyGen_docs,
-//              py::arg("privateKey"))
-//         .def("EvalMultKeysGen", &CryptoContextImpl<DCRTPoly>::EvalMultKeysGen,
-//              cc_EvalMultKeysGen_docs,
-//              py::arg("privateKey"))
 //         .def("EvalRotateKeyGen", &CryptoContextImpl<DCRTPoly>::EvalRotateKeyGen,
 //              cc_EvalRotateKeyGen_docs,
 //              py::arg("privateKey"),
@@ -1170,16 +1198,45 @@ void FFICryptoContextImpl::Enable(FFIPKESchemeFeature feature) {
 //                         cc_DeserializeEvalAutomorphismKey_docs,
 //                         py::arg("filename"), py::arg("sertype"));
 
-//     // Generator Functions
-//     m.def("GenCryptoContext", &GenCryptoContext<CryptoContextBFVRNS>,
-//         py::arg("params"));
-//     m.def("GenCryptoContext", &GenCryptoContext<CryptoContextBGVRNS>,
-//         py::arg("params"));
-//     m.def("GenCryptoContext", &GenCryptoContext<CryptoContextCKKSRNS>,
-//         py::arg("params"));
-//     m.def("ReleaseAllContexts", &CryptoContextFactory<DCRTPoly>::ReleaseAllContexts);
-//     m.def("GetAllContexts", &CryptoContextFactory<DCRTPoly>::GetAllContexts);
-// }
+FFICryptoContextImpl GenCryptoContext(FFIParams params){
+    FFISCHEME scheme = params.GetScheme();
+    std::shared_ptr<Params> cc_params =
+        reinterpret_cast<ParamsHolder*>(params.params_ptr)->ptr;
+
+    FFICryptoContextImpl cc;
+
+    switch(scheme){
+        case FFISCHEME::BFVRNS_SCHEME:{
+            CCParams<CryptoContextBFVRNS> bfv_cc_params_instance = CCParams<CryptoContextBFVRNS>(*cc_params);
+
+            cc.cc_ptr = reinterpret_cast<void*>(
+                new CryptoContextImplHolder{GenCryptoContext<CryptoContextBFVRNS>(
+                    bfv_cc_params_instance)});
+            break;
+        }
+        case FFISCHEME::BGVRNS_SCHEME:{
+            CCParams<CryptoContextBGVRNS> bgv_cc_params_instance = CCParams<CryptoContextBGVRNS>(*cc_params);
+
+            cc.cc_ptr = reinterpret_cast<void*>(
+                new CryptoContextImplHolder{GenCryptoContext<CryptoContextBGVRNS>(
+                    bgv_cc_params_instance)});
+            break;
+        }
+        case FFISCHEME::CKKSRNS_SCHEME:{
+            CCParams<CryptoContextCKKSRNS> ckks_cc_params_instance = CCParams<CryptoContextCKKSRNS>(*cc_params);
+
+            cc.cc_ptr = reinterpret_cast<void*>(
+                new CryptoContextImplHolder{GenCryptoContext<CryptoContextCKKSRNS>(
+                    ckks_cc_params_instance)});
+            break;
+        }
+        case FFISCHEME::INVALID_SCHEME:
+        default:
+            throw std::invalid_argument("Invalid scheme");
+    }
+
+    return cc;
+}
 
 // int get_native_int(){
 //     #if NATIVEINT == 128 && !defined(__EMSCRIPTEN__)