Browse Source

Changes for the new openfhe-development release (#163)

* Changes for the new openfhe-development release

* Fixed test errors and disabled some of the tests for NATIVEINT == 32

---------

Co-authored-by: Dmitriy Suponitskiy <dsuponitskiy@dualitytech.com>
dsuponitskiy 1 year ago
parent
commit
6e45fa6eb2

+ 5 - 3
README.md

@@ -38,7 +38,7 @@ To install OpenFHE-python directly to your system, ensure the dependencies are s
 pip install "pybind11[global]" 
 mkdir build
 cd build
-cmake ..  # Alternatively, cmake .. -DOpenFHE_DIR=/path/to/installed/openfhe if you installed OpenFHE elsewhere
+cmake ..  # Alternatively, cmake .. -DCMAKE_PREFIX_PATH=/path/to/installed/openfhe if you installed OpenFHE elsewhere
 make
 make install  # You may have to run sudo make install
 ```
@@ -50,11 +50,13 @@ If you see an error saying that one of OpenFHE .so files cannot be found when ru
 add the path where the .so files reside to the `PYTHONPATH` environment variable:
 
 ```
-export PYTHONPATH=(path_to_OpenFHE_so_files):$PYTHONPATH
+export PYTHONPATH=(/path/to/installed/openfhe):$PYTHONPATH
 ```
 
 In some environments (this happens rarely), it may also be necessary to add the OpenFHE libraries path to `LD_LIBRARY_PATH`.
 
+If OpenFHE is not installed in the default location, then both `PYTHONPATH and LD_LIBRARY_PATH` must be set before running any Python example.
+
 #### Conda
 
 Alternatively you can install the library and handle the linking via Conda. Clone the repository, open a terminal in the repo folder and run the following commands:
@@ -73,7 +75,7 @@ Now, you would clone the repository, and run the following commands to install :
 ```bash
 mkdir build
 cd build
-cmake .. # Add in -DOpenFHE_DIR=/path/to/installed/openfhe if you installed OpenFHE elsewhere
+cmake .. # Add in -DCMAKE_PREFIX_PATH=/path/to/installed/openfhe if you installed OpenFHE elsewhere
 make
 make install  # You may have to run sudo make install
 ```

+ 3 - 3
examples/pke/advanced-real-numbers-128.py

@@ -52,15 +52,15 @@ def automatic_rescale_demo(scal_tech):
     c_res3 = cc.EvalMult(cc.EvalAdd(c18,c9), 0.5)  # Final result 3
 
     result1 = cc.Decrypt(c_res1,keys.secretKey)
-    result.SetLength(batch_size)
+    result1.SetLength(batch_size)
     print("x^18 + x^9 + 1 = ", result1)
     
     result2 = cc.Decrypt(c_res2,keys.secretKey)
-    result.SetLength(batch_size)
+    result2.SetLength(batch_size)
     print("x^18 + x^9 - 1 = ", result2)
 
     result3 = cc.Decrypt(c_res3,keys.secretKey)
-    result.SetLength(batch_size)
+    result3.SetLength(batch_size)
     print("0.5 * (x^18 + x^9) = ", result3)
 
 

+ 8 - 4
examples/pke/advanced-real-numbers.py

@@ -124,7 +124,8 @@ def hybrid_key_switching_demo1():
     parameters.SetMultiplicativeDepth(5)
     parameters.SetScalingModSize(50)
     parameters.SetBatchSize(batch_size)
-    parameters.SetScalingTechnique(ScalingTechnique.FLEXIBLEAUTO)
+    if get_native_int()!=128:
+        parameters.SetScalingTechnique(ScalingTechnique.FLEXIBLEAUTO)
     parameters.SetNumLargeDigits(dnum)
 
     cc = GenCryptoContext(parameters)
@@ -167,7 +168,8 @@ def hybrid_key_switching_demo2():
     parameters.SetMultiplicativeDepth(5)
     parameters.SetScalingModSize(50)
     parameters.SetBatchSize(batch_size)
-    parameters.SetScalingTechnique(ScalingTechnique.FLEXIBLEAUTO)
+    if get_native_int()!=128:
+        parameters.SetScalingTechnique(ScalingTechnique.FLEXIBLEAUTO)
     parameters.SetNumLargeDigits(dnum)
 
     cc = GenCryptoContext(parameters)
@@ -287,7 +289,8 @@ def fast_rotation_demo2():
     parameters.SetMultiplicativeDepth(1)
     parameters.SetScalingModSize(50)
     parameters.SetBatchSize(batch_size)
-    parameters.SetScalingTechnique(ScalingTechnique.FLEXIBLEAUTO)
+    if get_native_int()!=128:
+        parameters.SetScalingTechnique(ScalingTechnique.FLEXIBLEAUTO)
     parameters.SetKeySwitchTechnique(KeySwitchTechnique.BV)
     parameters.SetFirstModSize(60)
     parameters.SetDigitSize(digit_size)
@@ -361,7 +364,8 @@ def fast_rotation_demo2():
 
 
 def main():
-    automatic_rescale_demo(ScalingTechnique.FLEXIBLEAUTO)
+    if get_native_int()!=128:
+        automatic_rescale_demo(ScalingTechnique.FLEXIBLEAUTO)
     automatic_rescale_demo(ScalingTechnique.FIXEDAUTO)
     manual_rescale_demo(ScalingTechnique.FIXEDMANUAL)
     hybrid_key_switching_demo1()

+ 10 - 8
examples/pke/scheme-switching.py

@@ -1002,16 +1002,17 @@ def ArgminViaSchemeSwitchingUnit():
     slots = 32  # sparsely-packed
     batchSize = slots
     numValues = 32
-    scTech = FLEXIBLEAUTOEXT
     multDepth = 9 + 3 + 1 + int(log2(numValues))  # 1 for CKKS to FHEW, 13 for FHEW to CKKS, log2(numValues) for argmin
-    if scTech == FLEXIBLEAUTOEXT:
-        multDepth += 1
 
     parameters = CCParamsCKKSRNS()
+    if get_native_int()!=128:
+        scTech = FLEXIBLEAUTOEXT
+        multDepth += 1
+        parameters.SetScalingTechnique(scTech)
+
     parameters.SetMultiplicativeDepth(multDepth)
     parameters.SetScalingModSize(scaleModSize)
     parameters.SetFirstModSize(firstModSize)
-    parameters.SetScalingTechnique(scTech)
     parameters.SetSecurityLevel(sl)
     parameters.SetRingDim(ringDim)
     parameters.SetBatchSize(batchSize)
@@ -1119,16 +1120,17 @@ def ArgminViaSchemeSwitchingAltUnit():
     slots = 32  # sparsely-packed
     batchSize = slots
     numValues = 32
-    scTech = FLEXIBLEAUTOEXT
     multDepth = 9 + 3 + 1 + int(log2(numValues))  # 1 for CKKS to FHEW, 13 for FHEW to CKKS, log2(numValues) for argmin
-    if scTech == FLEXIBLEAUTOEXT:
-        multDepth += 1
 
     parameters = CCParamsCKKSRNS()
+    if get_native_int()!=128:
+        scTech = FLEXIBLEAUTOEXT
+        multDepth += 1
+        parameters.SetScalingTechnique(scTech)
+
     parameters.SetMultiplicativeDepth(multDepth)
     parameters.SetScalingModSize(scaleModSize)
     parameters.SetFirstModSize(firstModSize)
-    parameters.SetScalingTechnique(scTech)
     parameters.SetSecurityLevel(sl)
     parameters.SetRingDim(ringDim)
     parameters.SetBatchSize(batchSize)

+ 1 - 1
examples/pke/simple-integers-serial-bgvrns.py

@@ -121,7 +121,7 @@ def main_action():
     # of the keys. When deserializing a context, OpenFHE checks for the tag and
     # if it finds it in the CryptoContext map, it will return the stored version.
     # Hence, we need to clear the context and clear the keys.
-    cryptoContext.ClearEvalMultKeys()
+    ClearEvalMultKeys()
     cryptoContext.ClearEvalAutomorphismKeys()
     ReleaseAllContexts()
 

+ 1 - 1
examples/pke/simple-integers-serial.py

@@ -121,7 +121,7 @@ def main_action():
     # of the keys. When deserializing a context, OpenFHE checks for the tag and
     # if it finds it in the CryptoContext map, it will return the stored version.
     # Hence, we need to clear the context and clear the keys.
-    cryptoContext.ClearEvalMultKeys()
+    ClearEvalMultKeys()
     cryptoContext.ClearEvalAutomorphismKeys()
     ReleaseAllContexts()
 

+ 1 - 1
examples/pke/simple-real-numbers-serial.py

@@ -142,9 +142,9 @@ def serverSetupAndWrite(multDepth, scaleModSize, batchSize):
 
 def clientProcess():
     # clientCC = CryptoContext()
-    # clientCC.ClearEvalMultKeys()
     # clientCC.ClearEvalAutomorphismKeys()
     ReleaseAllContexts()
+    ClearEvalMultKeys()
 
     clientCC, res = DeserializeCryptoContext(datafolder + ccLocation, BINARY)
     if not res:

+ 3 - 2
examples/pke/tckks-interactive-mp-bootstrapping-Chebyschev.py

@@ -6,8 +6,9 @@ def main():
     # Same test with different rescaling techniques in CKKS
     TCKKSCollectiveBoot(FIXEDMANUAL)
     TCKKSCollectiveBoot(FIXEDAUTO)
-    TCKKSCollectiveBoot(FLEXIBLEAUTO)
-    TCKKSCollectiveBoot(FLEXIBLEAUTOEXT)
+    if get_native_int()!=128:
+        TCKKSCollectiveBoot(FLEXIBLEAUTO)
+        TCKKSCollectiveBoot(FLEXIBLEAUTOEXT)
 
     print("Interactive (3P) Bootstrapping Ciphertext [Chebyshev] (TCKKS) terminated gracefully!")
 

+ 4 - 3
examples/pke/tckks-interactive-mp-bootstrapping.py

@@ -21,8 +21,9 @@ def main():
     # Same test with different rescaling techniques in CKKS
     TCKKSCollectiveBoot(FIXEDMANUAL)
     TCKKSCollectiveBoot(FIXEDAUTO)
-    TCKKSCollectiveBoot(FLEXIBLEAUTO)
-    TCKKSCollectiveBoot(FLEXIBLEAUTOEXT)
+    if get_native_int()!=128:
+        TCKKSCollectiveBoot(FLEXIBLEAUTO)
+        TCKKSCollectiveBoot(FLEXIBLEAUTOEXT)
 
     print("Interactive Multi-Party Bootstrapping Ciphertext (TCKKS) terminated gracefully!\n")
 
@@ -170,4 +171,4 @@ def TCKKSCollectiveBoot(scaleTech):
     print("\n============================ INTERACTIVE DECRYPTION ENDED ============================\n")      
 
 if __name__ == "__main__":
-    main()
+    main()

+ 3 - 0
src/include/pke/cryptocontext_wrapper.h

@@ -65,4 +65,7 @@ const double GetScalingFactorRealWrapper(CryptoContext<DCRTPoly>& self, uint32_t
 const uint64_t GetModulusCKKSWrapper(CryptoContext<DCRTPoly>& self);
 const ScalingTechnique GetScalingTechniqueWrapper(CryptoContext<DCRTPoly>& self);
 const usint GetDigitSizeWrapper(CryptoContext<DCRTPoly>& self);
+
+void ClearEvalMultKeysWrapper();
+
 #endif // OPENFHE_CRYPTOCONTEXT_BINDINGS_H

+ 9 - 7
src/lib/bindings.cpp

@@ -715,10 +715,6 @@ void bind_crypto_context(py::module &m)
         .def("FindAutomorphismIndices", &CryptoContextImpl<DCRTPoly>::FindAutomorphismIndices,
             cc_FindAutomorphismIndices_docs,
             py::arg("idxList"))
-        .def_static(
-            "ClearEvalMultKeys", []()
-            { CryptoContextImpl<DCRTPoly>::ClearEvalMultKeys(); },
-            cc_ClearEvalMultKeys_docs)
         .def_static(
             "InsertEvalSumKey", &CryptoContextImpl<DCRTPoly>::InsertEvalSumKey,
             cc_InsertEvalSumKey_docs,
@@ -727,7 +723,8 @@ void bind_crypto_context(py::module &m)
         .def_static(
             "InsertEvalMultKey", &CryptoContextImpl<DCRTPoly>::InsertEvalMultKey,
             cc_InsertEvalMultKey_docs,
-            py::arg("evalKeyVec"))
+            py::arg("evalKeyVec"),
+            py::arg("keyTag") = "")
         .def_static(
             "ClearEvalAutomorphismKeys", []()
             { CryptoContextImpl<DCRTPoly>::ClearEvalAutomorphismKeys(); },
@@ -835,15 +832,20 @@ void bind_crypto_context(py::module &m)
         py::arg("params"));
     m.def("GenCryptoContext", &GenCryptoContext<CryptoContextCKKSRNS>,
         py::arg("params"));
-    m.def("ReleaseAllContexts", &CryptoContextFactory<DCRTPoly>::ReleaseAllContexts);
+
     m.def("GetAllContexts", &CryptoContextFactory<DCRTPoly>::GetAllContexts);
+
+    m.def("ReleaseAllContexts", &CryptoContextFactory<DCRTPoly>::ReleaseAllContexts);
+    m.def("ClearEvalMultKeys", &ClearEvalMultKeysWrapper);
 }
 
 int get_native_int(){
     #if NATIVEINT == 128 && !defined(__EMSCRIPTEN__)
         return 128;
+    #elif NATIVEINT == 32
+        return 32;
     #else
-        return 64;    
+        return 64;
     #endif
 }
 

+ 4 - 0
src/lib/pke/cryptocontext_wrapper.cpp

@@ -150,3 +150,7 @@ const ScalingTechnique GetScalingTechniqueWrapper(CryptoContext<DCRTPoly> & self
     }
 
 }
+
+void ClearEvalMultKeysWrapper() {
+    CryptoContextImpl<DCRTPoly>::ClearEvalMultKeys();
+}

+ 1 - 0
tests/test_bgv.py

@@ -4,6 +4,7 @@ import random
 import pytest
 import openfhe as fhe
 
+pytestmark = pytest.mark.skipif(fhe.get_native_int() == 32, reason="Doesn't work for NATIVE_INT=32")
 
 LOGGER = logging.getLogger("test_bgv")
 

+ 4 - 3
tests/test_ckks.py

@@ -3,6 +3,7 @@ import random
 import pytest
 import openfhe as fhe
 
+pytestmark = pytest.mark.skipif(fhe.get_native_int() == 32, reason="Doesn't work for NATIVE_INT=32")
 
 @pytest.fixture(scope="module")
 def ckks_context():
@@ -13,14 +14,14 @@ def ckks_context():
     batch_size = 8
     parameters = fhe.CCParamsCKKSRNS()
     parameters.SetMultiplicativeDepth(5)
-    if fhe.get_native_int() > 90:
+    if fhe.get_native_int() == 128:
         parameters.SetFirstModSize(89)
         parameters.SetScalingModSize(78)
         parameters.SetBatchSize(batch_size)
         parameters.SetScalingTechnique(fhe.ScalingTechnique.FIXEDAUTO)
         parameters.SetNumLargeDigits(2)
 
-    elif fhe.get_native_int() > 60:
+    elif fhe.get_native_int() == 64:
         parameters.SetFirstModSize(60)
         parameters.SetScalingModSize(56)
         parameters.SetBatchSize(batch_size)
@@ -28,7 +29,7 @@ def ckks_context():
         parameters.SetNumLargeDigits(2)
 
     else:
-        raise ValueError("Expected a native int size greater than 60.")
+        raise ValueError("Expected a native int size 64 or 128.")
 
     cc = fhe.GenCryptoContext(parameters)
     cc.Enable(fhe.PKESchemeFeature.PKE)

+ 1 - 2
tests/test_cryptocontext.py

@@ -1,9 +1,8 @@
 import pytest
 import openfhe as fhe
 
+pytestmark = pytest.mark.skipif(fhe.get_native_int() != 128, reason="Only for NATIVE_INT=128")
 
-@pytest.mark.long
-@pytest.mark.skipif(fhe.get_native_int() < 80, reason="Only for NATIVE_INT=128")
 @pytest.mark.parametrize("scaling", [fhe.FIXEDAUTO, fhe.FIXEDMANUAL])
 def test_ckks_context(scaling):
     batch_size = 8

+ 3 - 0
tests/test_examples.py

@@ -5,6 +5,9 @@ import importlib.util
 import pytest
 import tempfile
 import shutil
+import openfhe as fhe
+
+pytestmark = pytest.mark.skipif(fhe.get_native_int() == 32, reason="Doesn't work for NATIVE_INT=32")
 
 EXAMPLES_SCRIPTS_PATH = os.path.join(Path(__file__).parent.parent, "examples", "pke")
 

+ 5 - 3
tests/test_serial_cc.py

@@ -3,6 +3,8 @@ import pytest
 
 import openfhe as fhe
 
+pytestmark = pytest.mark.skipif(fhe.get_native_int() == 32, reason="Doesn't work for NATIVE_INT=32")
+
 LOGGER = logging.getLogger("test_serial_cc")
 
 
@@ -23,7 +25,7 @@ def test_serial_cryptocontext(tmp_path):
     LOGGER.debug("The cryptocontext has been serialized.")
     assert fhe.SerializeToFile(str(tmp_path / "ciphertext1.json"), ciphertext1, fhe.JSON)
 
-    cryptoContext.ClearEvalMultKeys()
+    fhe.ClearEvalMultKeys()
     cryptoContext.ClearEvalAutomorphismKeys()
     fhe.ReleaseAllContexts()
 
@@ -117,7 +119,7 @@ def test_serial_cryptocontext_str(mode):
     automorphismKey_ser = fhe.SerializeEvalAutomorphismKeyString(mode, "")
     LOGGER.debug("The rotation evaluation keys have been serialized.")
 
-    cryptoContext.ClearEvalMultKeys()
+    fhe.ClearEvalMultKeys()
     cryptoContext.ClearEvalAutomorphismKeys()
     fhe.ReleaseAllContexts()
 
@@ -213,4 +215,4 @@ def rotate_vector(vector, rotation):
         rotated = [0] * rotation + vector[:n - rotation]
     else:
         rotated = vector
-    return rotated
+    return rotated