Browse Source

- update __len__ in plaintext to map to GetLength() to make the API more pythonic
- use more Python APIs

Muthu Annamalai 1 year ago
parent
commit
4c79d3210d
4 changed files with 13 additions and 4 deletions
  1. 2 0
      src/lib/bindings.cpp
  2. 2 0
      src/lib/binfhe_bindings.cpp
  3. 6 3
      tests/test_boolean.py
  4. 3 1
      tests/test_serial_cc.py

+ 2 - 0
src/lib/bindings.cpp

@@ -1041,6 +1041,8 @@ void bind_encodings(py::module &m)
             py::arg("sf"))
             py::arg("sf"))
         .def("GetSchemeID", &PlaintextImpl::GetSchemeID,
         .def("GetSchemeID", &PlaintextImpl::GetSchemeID,
             ptx_GetSchemeID_docs)
             ptx_GetSchemeID_docs)
+        .def("__len__", &PlaintextImpl::GetLength,
+            ptx_GetLength_docs)
         .def("GetLength", &PlaintextImpl::GetLength,
         .def("GetLength", &PlaintextImpl::GetLength,
             ptx_GetLength_docs)
             ptx_GetLength_docs)
         .def("GetSchemeID", &PlaintextImpl::GetSchemeID,
         .def("GetSchemeID", &PlaintextImpl::GetSchemeID,

+ 2 - 0
src/lib/binfhe_bindings.cpp

@@ -140,6 +140,7 @@ void bind_binfhe_keys(py::module &m) {
   py::class_<LWEPrivateKeyImpl, std::shared_ptr<LWEPrivateKeyImpl>>(
   py::class_<LWEPrivateKeyImpl, std::shared_ptr<LWEPrivateKeyImpl>>(
       m, "LWEPrivateKey")
       m, "LWEPrivateKey")
       .def(py::init<>())
       .def(py::init<>())
+      .def("__len__", &LWEPrivateKeyImpl::GetLength)    
       .def("GetLength", &LWEPrivateKeyImpl::GetLength)
       .def("GetLength", &LWEPrivateKeyImpl::GetLength)
       .def(py::self == py::self)
       .def(py::self == py::self)
       .def(py::self != py::self);
       .def(py::self != py::self);
@@ -148,6 +149,7 @@ void bind_binfhe_ciphertext(py::module &m) {
   py::class_<LWECiphertextImpl, std::shared_ptr<LWECiphertextImpl>>(
   py::class_<LWECiphertextImpl, std::shared_ptr<LWECiphertextImpl>>(
       m, "LWECiphertext")
       m, "LWECiphertext")
       .def(py::init<>())
       .def(py::init<>())
+      .def("__len__", &LWECiphertextImpl::GetLength)
       .def("GetLength", &LWECiphertextImpl::GetLength)
       .def("GetLength", &LWECiphertextImpl::GetLength)
       .def("GetModulus", &GetLWECiphertextModulusWrapper)
       .def("GetModulus", &GetLWECiphertextModulusWrapper)
       .def(py::self == py::self)
       .def(py::self == py::self)

+ 6 - 3
tests/test_boolean.py

@@ -3,9 +3,10 @@ import pytest
 
 
 
 
 ## Sample Program: Step 1: Set CryptoContext
 ## Sample Program: Step 1: Set CryptoContext
+@pytest.mark.parametrize("context",[TOY,MEDIUM,STD128])
 @pytest.mark.parametrize("a", [0, 1])
 @pytest.mark.parametrize("a", [0, 1])
 @pytest.mark.parametrize("b", [0, 1])
 @pytest.mark.parametrize("b", [0, 1])
-def test_boolean_AND(a, b):
+def test_boolean_AND(context,a, b):
     cc = BinFHEContext()
     cc = BinFHEContext()
 
 
     """
     """
@@ -14,13 +15,13 @@ def test_boolean_AND(a, b):
     MEDIUM corresponds to the level of more than 100 bits for both quantum and
     MEDIUM corresponds to the level of more than 100 bits for both quantum and
     classical computer attacks
     classical computer attacks
     """
     """
-    cc.GenerateBinFHEContext(STD128, GINX)
+    cc.GenerateBinFHEContext(context, GINX)
 
 
     ## Sample Program: Step 2: Key Generation
     ## Sample Program: Step 2: Key Generation
 
 
     # Generate the secret key
     # Generate the secret key
     sk = cc.KeyGen()
     sk = cc.KeyGen()
-
+    assert sk.GetLength() == len(sk)
     print("Generating the bootstrapping keys...\n")
     print("Generating the bootstrapping keys...\n")
 
 
     # Generate the bootstrapping keys (refresh and switching keys)
     # Generate the bootstrapping keys (refresh and switching keys)
@@ -37,6 +38,8 @@ def test_boolean_AND(a, b):
     ct1 = cc.Encrypt(sk, a)
     ct1 = cc.Encrypt(sk, a)
     ct2 = cc.Encrypt(sk, b)
     ct2 = cc.Encrypt(sk, b)
 
 
+    assert ct1.GetLength() == len(ct1)
+    
     # Sample Program: Step 4: Evaluation
     # Sample Program: Step 4: Evaluation
 
 
     # Compute (1 AND 1) = 1; Other binary gate options are OR, NAND, and NOR
     # Compute (1 AND 1) = 1; Other binary gate options are OR, NAND, and NOR

+ 3 - 1
tests/test_serial_cc.py

@@ -64,7 +64,9 @@ def test_serial_cryptocontext_str(mode):
     # First plaintext vector is encoded
     # First plaintext vector is encoded
     vectorOfInts1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
     vectorOfInts1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
     plaintext1 = cryptoContext.MakePackedPlaintext(vectorOfInts1)
     plaintext1 = cryptoContext.MakePackedPlaintext(vectorOfInts1)
-
+    assert len(plaintext1) == plaintext1.GetLength()
+    assert len(plaintext1) == 12
+    
     # Second plaintext vector is encoded
     # Second plaintext vector is encoded
     vectorOfInts2 = [3, 2, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12]
     vectorOfInts2 = [3, 2, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12]
     plaintext2 = cryptoContext.MakePackedPlaintext(vectorOfInts2)
     plaintext2 = cryptoContext.MakePackedPlaintext(vectorOfInts2)