소스 검색

Adding PolynomialEvaluationExample

Hovsep Papoyan 1 년 전
부모
커밋
4062fb7a4a
3개의 변경된 파일217개의 추가작업 그리고 59개의 파일을 삭제
  1. 59 4
      src/direct.cc
  2. 30 54
      src/direct.hpp
  3. 128 1
      src/main.rs

+ 59 - 4
src/direct.cc

@@ -1,9 +1,64 @@
-#include "direct.hpp"
+#include "openfhe_rs_dev/src/main.rs.h"
 
 namespace openfhe_rs_dev
 {
-
-// TODO: implementations
-
+std::unique_ptr<VectorOfComplexNumbers> GenVectorOfComplexNumbers(const std::vector<ComplexPair>& vals)
+{
+    std::vector<std::complex<double>> result;
+    result.reserve(vals.size());
+    for (const ComplexPair& p : vals)
+    {
+        result.emplace_back(p.re, p.im);
+    }
+    return std::make_unique<VectorOfComplexNumbers>(std::move(result));
+}
+std::unique_ptr<Params> GetParamsByScheme(const SCHEME scheme)
+{
+    return std::make_unique<Params>(scheme);
+}
+std::unique_ptr<Params> GetParamsByVectorOfString(const std::vector<std::string>& vals)
+{
+    return std::make_unique<Params>(vals);
+}
+std::unique_ptr<ParamsBFVRNS> GetParamsBFVRNS()
+{
+    return std::make_unique<ParamsBFVRNS>();
+}
+std::unique_ptr<ParamsBFVRNS> GetParamsBFVRNSbyVectorOfString(const std::vector<std::string>& vals)
+{
+    return std::make_unique<ParamsBFVRNS>(vals);
+}
+std::unique_ptr<ParamsBGVRNS> GetParamsBGVRNS()
+{
+    return std::make_unique<ParamsBGVRNS>();
+}
+std::unique_ptr<ParamsBGVRNS> GetParamsBGVRNSbyVectorOfString(const std::vector<std::string>& vals)
+{
+    return std::make_unique<ParamsBGVRNS>(vals);
+}
+std::unique_ptr<ParamsCKKSRNS> GetParamsCKKSRNS()
+{
+    return std::make_unique<ParamsCKKSRNS>();
+}
+std::unique_ptr<ParamsCKKSRNS> GetParamsCKKSRNSbyVectorOfString(const std::vector<std::string>& vals)
+{
+    return std::make_unique<ParamsCKKSRNS>(vals);
+}
+std::unique_ptr<Plaintext> GenEmptyPlainText()
+{
+    return std::make_unique<Plaintext>();
+}
+std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsBFVRNS(const ParamsBFVRNS& params)
+{
+    return std::make_unique<CryptoContextDCRTPoly>(params);
+}
+std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsBGVRNS(const ParamsBGVRNS& params)
+{
+    return std::make_unique<CryptoContextDCRTPoly>(params);
+}
+std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsCKKSRNS(const ParamsCKKSRNS& params)
+{
+    return std::make_unique<CryptoContextDCRTPoly>(params);
+}
 } // openfhe_rs_dev
 

+ 30 - 54
src/direct.hpp

@@ -31,7 +31,6 @@
 
 namespace openfhe_rs_dev
 {
-    // Parameter related stuff
     using ParamsBFVRNS = lbcrypto::CCParams<lbcrypto::CryptoContextBFVRNS>;
     using ParamsBGVRNS = lbcrypto::CCParams<lbcrypto::CryptoContextBGVRNS>;
     using ParamsCKKSRNS = lbcrypto::CCParams<lbcrypto::CryptoContextCKKSRNS>;
@@ -49,41 +48,14 @@ namespace openfhe_rs_dev
     using MultiplicationTechnique = lbcrypto::MultiplicationTechnique;
     using COMPRESSION_LEVEL = lbcrypto::COMPRESSION_LEVEL;
     using PKESchemeFeature = lbcrypto::PKESchemeFeature;
-    std::unique_ptr<Params> GetParamsByScheme(const SCHEME scheme)
-    {
-        return std::make_unique<Params>(scheme);
-    }
-    std::unique_ptr<Params> GetParamsByVectorOfString(const std::vector<std::string>& vals)
-    {
-        return std::make_unique<Params>(vals);
-    }
-    std::unique_ptr<ParamsBFVRNS> GetParamsBFVRNS()
-    {
-        return std::make_unique<ParamsBFVRNS>();
-    }
-    std::unique_ptr<ParamsBFVRNS> GetParamsBFVRNSbyVectorOfString(const std::vector<std::string>& vals)
-    {
-        return std::make_unique<ParamsBFVRNS>(vals);
-    }
-    std::unique_ptr<ParamsBGVRNS> GetParamsBGVRNS()
-    {
-        return std::make_unique<ParamsBGVRNS>();
-    }
-    std::unique_ptr<ParamsBGVRNS> GetParamsBGVRNSbyVectorOfString(const std::vector<std::string>& vals)
-    {
-        return std::make_unique<ParamsBGVRNS>(vals);
-    }
-    std::unique_ptr<ParamsCKKSRNS> GetParamsCKKSRNS()
-    {
-        return std::make_unique<ParamsCKKSRNS>();
-    }
-    std::unique_ptr<ParamsCKKSRNS> GetParamsCKKSRNSbyVectorOfString(const std::vector<std::string>& vals)
-    {
-        return std::make_unique<ParamsCKKSRNS>(vals);
-    }
-
     using PublicKeyImpl = lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>;
     using PrivateKeyImpl = lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>;
+    using PlaintextImpl = lbcrypto::PlaintextImpl;
+    using CiphertextImpl = lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>;
+    using DecryptResult = lbcrypto::DecryptResult;
+    using DCRTPolyParams = lbcrypto::DCRTPoly::Params;
+    struct ComplexPair;
+    using VectorOfComplexNumbers = std::vector<std::complex<double>>;
 
     class KeyPairDCRTPoly final
     {
@@ -108,7 +80,6 @@ namespace openfhe_rs_dev
         // TODO: implement necessary member functions
     };
 
-    using PlaintextImpl = lbcrypto::PlaintextImpl;
     class Plaintext final
     {
     private:
@@ -153,12 +124,8 @@ namespace openfhe_rs_dev
         // TODO: implement necessary member functions
     };
 
-    std::unique_ptr<Plaintext> GenEmptyPlainText()
-    {
-        return std::make_unique<Plaintext>();
-    }
 
-    using CiphertextImpl = lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>;
+
     class CiphertextDCRTPoly final
     {
     private:
@@ -176,8 +143,6 @@ namespace openfhe_rs_dev
         // TODO: implement necessary member functions
     };
 
-    using DecryptResult = lbcrypto::DecryptResult;
-    using DCRTPolyParams = lbcrypto::DCRTPoly::Params;
     class CryptoContextDCRTPoly final
     {
     private:
@@ -261,19 +226,30 @@ namespace openfhe_rs_dev
         {
             return std::make_unique<Plaintext>(m_cryptoContextImplSharedPtr->MakeCKKSPackedPlaintext(value, scaleDeg, level, params, slots));
         }
+        std::unique_ptr<Plaintext> MakeCKKSPackedPlaintextByVectorOfComplexNumbers(const std::vector<std::complex<double>>& value, const size_t scaleDeg, const uint32_t level,
+                                                                                   const std::shared_ptr<DCRTPolyParams> params, const uint32_t slots) const
+                                                                                   // scaleDeg = 1, level = 0, params = nullptr, slots = 0
+        {
+            return std::make_unique<Plaintext>(m_cryptoContextImplSharedPtr->MakeCKKSPackedPlaintext(value, scaleDeg, level, params, slots));
+        }
+        std::unique_ptr<CiphertextDCRTPoly> EvalPoly(std::shared_ptr<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>> ciphertext, const std::vector<double>& coefficients) const
+        {
+            return std::make_unique<CiphertextDCRTPoly>(m_cryptoContextImplSharedPtr->EvalPoly(ciphertext, coefficients));
+        }
     };
 
-    std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsBFVRNS(const ParamsBFVRNS& params)
-    {
-        return std::make_unique<CryptoContextDCRTPoly>(params);
-    }
-    std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsBGVRNS(const ParamsBGVRNS& params)
-    {
-        return std::make_unique<CryptoContextDCRTPoly>(params);
-    }
-    std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsCKKSRNS(const ParamsCKKSRNS& params)
-    {
-        return std::make_unique<CryptoContextDCRTPoly>(params);
-    }
+    std::unique_ptr<VectorOfComplexNumbers> GenVectorOfComplexNumbers(const std::vector<ComplexPair>& vals);
+    std::unique_ptr<Params> GetParamsByScheme(const SCHEME scheme);
+    std::unique_ptr<Params> GetParamsByVectorOfString(const std::vector<std::string>& vals);
+    std::unique_ptr<ParamsBFVRNS> GetParamsBFVRNS();
+    std::unique_ptr<ParamsBFVRNS> GetParamsBFVRNSbyVectorOfString(const std::vector<std::string>& vals);
+    std::unique_ptr<ParamsBGVRNS> GetParamsBGVRNS();
+    std::unique_ptr<ParamsBGVRNS> GetParamsBGVRNSbyVectorOfString(const std::vector<std::string>& vals);
+    std::unique_ptr<ParamsCKKSRNS> GetParamsCKKSRNS();
+    std::unique_ptr<ParamsCKKSRNS> GetParamsCKKSRNSbyVectorOfString(const std::vector<std::string>& vals);
+    std::unique_ptr<Plaintext> GenEmptyPlainText();
+    std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsBFVRNS(const ParamsBFVRNS& params);
+    std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsBGVRNS(const ParamsBGVRNS& params);
+    std::unique_ptr<CryptoContextDCRTPoly> GenCryptoContextByParamsCKKSRNS(const ParamsCKKSRNS& params);
 } // openfhe_rs_dev
 

+ 128 - 1
src/main.rs

@@ -2,6 +2,7 @@
 #![allow(dead_code)]
 
 use cxx::{CxxVector, SharedPtr};
+use std::time::Instant;
 
 #[cxx::bridge(namespace = "openfhe_rs_dev")]
 mod ffi
@@ -438,6 +439,20 @@ mod ffi
         fn GetCipherText(self: &CiphertextDCRTPoly) -> SharedPtr<CiphertextImpl>;
     }
 
+    // ComplexPair
+    struct ComplexPair
+    {
+        re: f64,
+        im: f64,
+    }
+
+    // VectorOfComplexNumbers
+    unsafe extern "C++"
+    {
+        type VectorOfComplexNumbers;
+        fn GenVectorOfComplexNumbers(vals: &CxxVector<ComplexPair>) -> UniquePtr<VectorOfComplexNumbers>;
+    }
+
     // CryptoContextDCRTPoly
     unsafe extern "C++"
     {
@@ -460,10 +475,121 @@ mod ffi
                    plaintext: Pin<&mut Plaintext>) -> UniquePtr<DecryptResult>;
         fn GetRingDimension(self: &CryptoContextDCRTPoly) -> u32;
         fn MakeCKKSPackedPlaintext(self: &CryptoContextDCRTPoly, value: &CxxVector<f64>, scaleDeg: usize, level: u32,
-                                   params: SharedPtr<DCRTPolyParams>, slots: u32) -> UniquePtr<Plaintext>;
+                                   params: SharedPtr<DCRTPolyParams>, slots: u32) -> UniquePtr<Plaintext>; // scaleDeg = 1, level = 0, params = nullptr, slots = 0
+        fn MakeCKKSPackedPlaintextByVectorOfComplexNumbers(self: &CryptoContextDCRTPoly, value: &VectorOfComplexNumbers, scaleDeg: usize, level: u32,
+                                   params: SharedPtr<DCRTPolyParams>, slots: u32) -> UniquePtr<Plaintext>; // scaleDeg = 1, level = 0, params = nullptr, slots = 0
+        fn EvalPoly(self: &CryptoContextDCRTPoly, ciphertext: SharedPtr<CiphertextImpl>, coefficients: &CxxVector<f64>) -> UniquePtr<CiphertextDCRTPoly>;
     }
 }
 
+fn PolynomialEvaluationExample()
+{
+    println!("\n======EXAMPLE FOR EVALPOLY========\n");
+
+    let mut _cc_params_ckksrns = ffi::GetParamsCKKSRNS();
+    _cc_params_ckksrns.pin_mut().SetMultiplicativeDepth(6);
+    _cc_params_ckksrns.pin_mut().SetScalingModSize(50);
+
+    let mut _cc = ffi::GenCryptoContextByParamsCKKSRNS(&_cc_params_ckksrns);
+    _cc.Enable(ffi::PKESchemeFeature::PKE);
+    _cc.Enable(ffi::PKESchemeFeature::KEYSWITCH);
+    _cc.Enable(ffi::PKESchemeFeature::LEVELEDSHE);
+    _cc.Enable(ffi::PKESchemeFeature::ADVANCEDSHE);
+
+    let mut _vals = CxxVector::<ffi::ComplexPair>::new();
+    _vals.pin_mut().push(ffi::ComplexPair{re: 0.5, im: 0.0});
+    _vals.pin_mut().push(ffi::ComplexPair{re: 0.7, im: 0.0});
+    _vals.pin_mut().push(ffi::ComplexPair{re: 0.9, im: 0.0});
+    _vals.pin_mut().push(ffi::ComplexPair{re: 0.95, im: 0.0});
+    _vals.pin_mut().push(ffi::ComplexPair{re: 0.93, im: 0.0});
+    let _input = ffi::GenVectorOfComplexNumbers(&_vals);
+    let _encoded_length = _vals.len(); // no len() funtion implemented for _input
+
+    let mut _coefficients_1 = CxxVector::<f64>::new();
+    _coefficients_1.pin_mut().push(0.15);
+    _coefficients_1.pin_mut().push(0.75);
+    _coefficients_1.pin_mut().push(0.0);
+    _coefficients_1.pin_mut().push(1.25);
+    _coefficients_1.pin_mut().push(0.0);
+    _coefficients_1.pin_mut().push(0.0);
+    _coefficients_1.pin_mut().push(1.0);
+    _coefficients_1.pin_mut().push(0.0);
+    _coefficients_1.pin_mut().push(1.0);
+    _coefficients_1.pin_mut().push(2.0);
+    _coefficients_1.pin_mut().push(0.0);
+    _coefficients_1.pin_mut().push(1.0);
+    _coefficients_1.pin_mut().push(0.0);
+    _coefficients_1.pin_mut().push(0.0);
+    _coefficients_1.pin_mut().push(0.0);
+    _coefficients_1.pin_mut().push(0.0);
+    _coefficients_1.pin_mut().push(1.0);
+
+    let mut _coefficients_2 = CxxVector::<f64>::new();
+    _coefficients_2.pin_mut().push(1.0);
+    _coefficients_2.pin_mut().push(2.0);
+    _coefficients_2.pin_mut().push(3.0);
+    _coefficients_2.pin_mut().push(4.0);
+    _coefficients_2.pin_mut().push(5.0);
+    _coefficients_2.pin_mut().push(-1.0);
+    _coefficients_2.pin_mut().push(-2.0);
+    _coefficients_2.pin_mut().push(-3.0);
+    _coefficients_2.pin_mut().push(-4.0);
+    _coefficients_2.pin_mut().push(-5.0);
+    _coefficients_2.pin_mut().push(0.1);
+    _coefficients_2.pin_mut().push(0.2);
+    _coefficients_2.pin_mut().push(0.3);
+    _coefficients_2.pin_mut().push(0.4);
+    _coefficients_2.pin_mut().push(0.5);
+    _coefficients_2.pin_mut().push(-0.1);
+    _coefficients_2.pin_mut().push(-0.2);
+    _coefficients_2.pin_mut().push(-0.3);
+    _coefficients_2.pin_mut().push(-0.4);
+    _coefficients_2.pin_mut().push(-0.5);
+    _coefficients_2.pin_mut().push(0.1);
+    _coefficients_2.pin_mut().push(0.2);
+    _coefficients_2.pin_mut().push(0.3);
+    _coefficients_2.pin_mut().push(0.4);
+    _coefficients_2.pin_mut().push(0.5);
+    _coefficients_2.pin_mut().push(-0.1);
+    _coefficients_2.pin_mut().push(-0.2);
+    _coefficients_2.pin_mut().push(-0.3);
+    _coefficients_2.pin_mut().push(-0.4);
+    _coefficients_2.pin_mut().push(-0.5);
+
+    let mut _plain_text_1 = _cc.MakeCKKSPackedPlaintextByVectorOfComplexNumbers(&_input, 1, 0, SharedPtr::<ffi::DCRTPolyParams>::null(), 0);
+    let mut _key_pair = _cc.KeyGen();
+    print!("Generating evaluation key for homomorphic multiplication...");
+    _cc.EvalMultKeyGen(_key_pair.GetPrivateKey());
+    println!("Completed.\n");
+    let mut _cipher_text_1 = _cc.Encrypt(_key_pair.GetPublicKey(), _plain_text_1.GetPlainText());
+
+    let mut _start = Instant::now();
+    let mut _result = _cc.EvalPoly(_cipher_text_1.GetCipherText(), &_coefficients_1);
+    let _time_eval_poly_1 = _start.elapsed();
+
+    _start = Instant::now();
+    let mut _result_2 = _cc.EvalPoly(_cipher_text_1.GetCipherText(), &_coefficients_2);
+    let _time_eval_poly_2 = _start.elapsed();
+
+    let mut _plain_text_dec = ffi::GenEmptyPlainText();
+    _cc.Decrypt(_key_pair.GetPrivateKey(), _result.GetCipherText(), _plain_text_dec.pin_mut());
+    _plain_text_dec.SetLength(_encoded_length);
+    let mut _plain_text_dec_2 = ffi::GenEmptyPlainText();
+    _cc.Decrypt(_key_pair.GetPrivateKey(), _result_2.GetCipherText(), _plain_text_dec_2.pin_mut());
+    _plain_text_dec_2.SetLength(_encoded_length);
+
+    println!("\n Original Plaintext #1:");
+    println!("{}", _plain_text_1.GetString());
+    println!("\n Result of evaluating a polynomial with coefficients [{} ]", _coefficients_1.iter().fold(String::new(), |acc, &arg| acc + " " + &arg.to_string()));
+    println!("{}", _plain_text_dec.GetString());
+    println!("\n Expected result: (0.70519107, 1.38285078, 3.97211180, 5.60215665, 4.86357575)");
+    println!("\n Evaluation time: {:.0?}", _time_eval_poly_1);
+    println!("\n Result of evaluating a polynomial with coefficients [{} ]", _coefficients_2.iter().fold(String::new(), |acc, &arg| acc + " " + &arg.to_string()));
+    println!("{}\n", _plain_text_dec_2.GetString());
+    println!(" Expected result: (3.4515092326, 5.3752765397, 4.8993108833, 3.2495023573, 4.0485229982)");
+    print!("\n Evaluation time: {:.0?}\n", _time_eval_poly_2);
+}
+
 fn SimpleRealNumbersExample()
 {
     let _mult_depth: u32 = 1;
@@ -674,4 +800,5 @@ fn main()
 {
     SimpleIntegersExample();
     SimpleRealNumbersExample();
+    PolynomialEvaluationExample();
 }