瀏覽代碼

adding basic tests to dev work

Drew Dolgert 2 年之前
父節點
當前提交
820ba008f8
共有 6 個文件被更改,包括 409 次插入0 次删除
  1. 89 0
      tests/README.md
  2. 105 0
      tests/conftest.py
  3. 103 0
      tests/test_bgv.py
  4. 56 0
      tests/test_ckks.py
  5. 17 0
      tests/test_cryptocontext.py
  6. 39 0
      tests/test_serial_cc.py

+ 89 - 0
tests/README.md

@@ -0,0 +1,89 @@
+# Working with Tests
+
+These tests use Pytest (https://docs.pytest.org/).
+
+## Running and Using Tests
+
+These tests assume that openfhe-python is installed in the current python environment, which you can check by running
+```bash
+python -c "__import__('openfhe')"
+```
+and that the `pytest` package is installed, either through pip or by installing `python3-pytest` in the operating system package manager.
+
+### Specific to the OpenFHE unit tests
+
+Some tests are marked with `@pytest.mark.long` if they are not meant to run
+on Github Actions. Run these locally with:
+
+```bash
+pytest --run-long
+pytest --run-all
+```
+
+### General Pytest usage
+
+Test a particular file:
+
+```bash
+pytest test_particular_file.py
+```
+
+Test all functions matching a name. For instance, this would pick up
+`test_add_two_numbers`:
+
+```bash
+pytest -k add
+```
+
+As a reminder, pytest can be helpful for debugging.
+
+```bash
+pytest --log-cli-level=debug
+```
+
+If a test is failing, pytest can drop into the debugger when an exception
+happens.
+
+```bash
+pytest --pdb
+```
+
+## Guidelines for Writing Tests
+
+**Mark long-running tests with long** -- These tests run with default settings
+on Github Actions, which can be underpowered, so there is a way to mark tests
+that can be run by hand or on other automation servers.
+
+```python
+@pytest.mark.long
+def test_ckks_large_context():
+    assert true
+```
+
+The goal is for the Github Actions tests to reassure a committer that they have
+not broken the Python wrapper.
+
+**Import OpenFHE as fhe** -- Unit tests tend to use more imports than most
+code, for instance JSON, which conflicts with an OpenFHE name, so quality
+imports in the tests.
+
+```python
+import openfhe as fhe
+
+def test_something():
+    parameters = fhe.CCParamsCKKSRNS()
+```
+
+**Use logging instead of print statements** -- Pytest has nice support for
+making logging statements visible, in the case that you are using tests
+for debugging.
+
+```python
+import logging
+
+LOGGER = logging.getLogger("test_file_name")
+
+def test_something():
+    arg = 3
+    LOGGER.debug("My message has an argument %s", arg)
+```

+ 105 - 0
tests/conftest.py

@@ -0,0 +1,105 @@
+"""
+This is a specially-named file that pytest finds in order to
+configure testing. Most of the logic comes from
+https://docs.pytest.org/en/7.1.x/example/simple.html#control-skipping-of-tests-according-to-command-line-option
+"""
+import pytest
+
+
+class CustomMarker:
+    """
+    Custom Markers are used to annotate tests.
+
+    Tests marked with a custom marker will be skipped by default. Pass either
+
+        --run-NAME_OF_MARKER or --run-NAME-OF-MARKER
+
+    to override this behavior.
+
+    --run-all may also be used to run all marked tests.
+    """
+    def __init__(self, name, desc, dest=None):
+        self.name = name
+        self.desc = desc
+        self.dest = name if dest is None else dest
+
+    def option_flags(self):
+        """
+        Return option flags for this marker.
+
+        >>> marker = CustomMarker('foo_bar', 'my desc')
+        >>> marker.option_flags()
+        ['--run-foo_bar', '--run-foo-bar']
+        >>> marker2 = CustomMarker('foo', 'my desc')
+        >>> marker2.option_flags()
+        ['--run-foo']
+        """
+        # NOTE: pytest is not testing the above doctest
+        # instead, run this file directly (see doctest.testmod at bottom)
+        result = ['--run-{}'.format(self.name)]
+        as_hyphen = '--run-{}'.format(self.name.replace('_', '-'))
+        if as_hyphen != result[0]:
+            result.append(as_hyphen)
+
+        return result
+
+
+CUSTOM_MARKERS = (
+    CustomMarker('long',
+                 'this test runs too long for Github Actions'),
+    CustomMarker('uses_card',
+                 'must have acceleration card installed to run test'),
+)
+
+
+def pytest_addoption(parser):
+    """
+    pytest hook - adds options to argument parser.
+    """
+
+    parser.addoption('--run-all',
+                     dest='run_all',
+                     action='store_true',
+                     help='Run all tests normally skipped by default')
+
+    for marker in CUSTOM_MARKERS:
+        parser.addoption(*marker.option_flags(),
+                         dest=marker.dest,
+                         action='store_true',
+                         help='Run tests marked with {}'.format(marker.name))
+
+
+def pytest_configure(config):
+    # Adds explicit marker definitions
+    # with these, pytest will error if `--strict` is applied and unregistered
+    # markers are present.
+    for marker in CUSTOM_MARKERS:
+        config.addinivalue_line("markers",
+                                "{}: {}".format(marker.name, marker.desc))
+
+
+def pytest_collection_modifyitems(config, items):
+    """
+    pytest hook which runs after tests have been collected.
+    """
+    skip_marked_tests(config, items)
+
+
+def skip_marked_tests(config, items):
+    """
+    Dynamically applies pytest.mark.skip to tests with custom markers.
+
+    Tests with explicit --run-FOO flags are not skipped.
+
+    This keeps `pytest` from footshooting with tests that should only be run
+    under particular conditions.
+    """
+    run_all = config.getoption('--run-all', default=False)
+    run_mark = {marker.name: config.getoption(marker.dest)
+                for marker in CUSTOM_MARKERS}
+
+    for item in items:
+        for marker_name, run_marker in run_mark.items():
+            if marker_name in item.keywords and not (run_all or run_marker):
+                item.add_marker(pytest.mark.skip)
+                break

+ 103 - 0
tests/test_bgv.py

@@ -0,0 +1,103 @@
+import logging
+import random
+
+import pytest
+import openfhe as fhe
+
+
+LOGGER = logging.getLogger("test_bgv")
+
+
+@pytest.fixture(scope="module")
+def bgv_context():
+    """
+    This fixture creates a small CKKS context, with its paramters and keys.
+    We make it because context creation can be slow.
+    """
+    parameters = fhe.CCParamsBGVRNS()
+    parameters.SetPlaintextModulus(65537)
+    parameters.SetMultiplicativeDepth(2)
+
+    crypto_context = fhe.GenCryptoContext(parameters)
+    crypto_context.Enable(fhe.PKESchemeFeature.PKE)
+    crypto_context.Enable(fhe.PKESchemeFeature.KEYSWITCH)
+    crypto_context.Enable(fhe.PKESchemeFeature.LEVELEDSHE)
+    key_pair = crypto_context.KeyGen()
+    # Generate the relinearization key
+    crypto_context.EvalMultKeyGen(key_pair.secretKey)
+    # Generate the rotation evaluation keys
+    crypto_context.EvalRotateKeyGen(key_pair.secretKey, [1, 2, -1, -2])
+    return parameters, crypto_context, key_pair
+
+
+def bgv_equal(raw, ciphertext, cc, keys):
+    """Compare an unencrypted list of values with encrypted values"""
+    pt = cc.Decrypt(ciphertext, keys.secretKey)
+    pt.SetLength(len(raw))
+    compare = pt.GetPackedValue()
+    success = all([a == b for (a, b) in zip(raw, compare)])
+    if not success:
+        LOGGER.info("Mismatch between %s %s", raw, compare)
+    return success
+
+
+def roll(a, n):
+    """Circularly rotate a list, like numpy.roll but without numpy."""
+    return [a[i % len(a)] for i in range(-n, len(a) - n)]
+
+
+@pytest.mark.parametrize("n,final", [
+    (0, [0, 1, 2, 3, 4, 5, 6, 7]),
+    (2, [6, 7, 0, 1, 2, 3, 4, 5]),
+    (3, [5, 6, 7, 0, 1, 2, 3, 4]),
+    (-1, [1, 2, 3, 4, 5, 6, 7, 0]),
+    ])
+def test_roll(n, final):
+    assert roll(list(range(8)), n) == final
+
+
+def shift(a, n):
+    """Rotate a list with infill of 0."""
+    return [(a[i] if 0 <= i < len(a) else 0) for i in range(-n, len(a) - n)]
+
+
+@pytest.mark.parametrize("n,final", [
+    (0, [1, 2, 3, 4, 5, 6, 7, 8]),
+    (2, [0, 0, 1, 2, 3, 4, 5, 6]),
+    (3, [0, 0, 0, 1, 2, 3, 4, 5]),
+    (-1, [2, 3, 4, 5, 6, 7, 8, 0]),
+    ])
+def test_shift(n, final):
+    assert shift(list(range(1, 9)), n) == final
+
+
+def test_simple_integers(bgv_context):
+    parameters, crypto_context, key_pair = bgv_context
+    rng = random.Random(342342)
+    cnt = 12
+    raw = [[rng.randint(1, 12) for _ in range(cnt)] for _ in range(3)]
+    plaintext = [crypto_context.MakePackedPlaintext(r) for r in raw]
+    ciphertext = [crypto_context.Encrypt(key_pair.publicKey, pt) for pt in plaintext]
+    assert bgv_equal(raw[0], ciphertext[0], crypto_context, key_pair)
+
+    # Homomorphic additions
+    ciphertext_add12 = crypto_context.EvalAdd(ciphertext[0], ciphertext[1])
+    ciphertext_add_result = crypto_context.EvalAdd(ciphertext_add12, ciphertext[2])
+    assert bgv_equal(
+        [a + b + c for (a, b, c) in zip(*raw)],
+        ciphertext_add_result, crypto_context, key_pair
+        )
+
+    # Homomorphic Multiplication
+    ciphertext_mult12 = crypto_context.EvalMult(ciphertext[0], ciphertext[1])
+    ciphertext_mult_result = crypto_context.EvalMult(ciphertext_mult12, ciphertext[2])
+    assert bgv_equal(
+        [a * b * c for (a, b, c) in zip(*raw)],
+        ciphertext_mult_result, crypto_context, key_pair
+        )
+
+    # Homomorphic Rotations. These values must be initialized with EvalRotateKeyGen.
+    for rotation in [1, 2, -1, -2]:
+        ciphertext_rot1 = crypto_context.EvalRotate(ciphertext[0], rotation)
+        # This is a rotation with infill of 0, NOT a circular rotation.
+        assert bgv_equal(shift(raw[0], -rotation), ciphertext_rot1, crypto_context, key_pair)

+ 56 - 0
tests/test_ckks.py

@@ -0,0 +1,56 @@
+import random
+
+import pytest
+import openfhe as fhe
+
+
+@pytest.fixture(scope="module")
+def ckks_context():
+    """
+    This fixture creates a small CKKS context, with its paramters and keys.
+    We make it because context creation can be slow.
+    """
+    batch_size = 8
+    parameters = fhe.CCParamsCKKSRNS()
+    parameters.SetMultiplicativeDepth(5)
+    if fhe.get_native_int() > 90:
+        parameters.SetFirstModSize(89)
+        parameters.SetScalingModSize(78)
+        parameters.SetBatchSize(batch_size)
+        parameters.SetScalingTechnique(fhe.ScalingTechnique.FIXEDAUTO)
+        parameters.SetNumLargeDigits(2)
+
+    elif fhe.get_native_int() > 60:
+        parameters.SetFirstModSize(60)
+        parameters.SetScalingModSize(56)
+        parameters.SetBatchSize(batch_size)
+        parameters.SetScalingTechnique(fhe.ScalingTechnique.FLEXIBLEAUTO)
+        parameters.SetNumLargeDigits(2)
+
+    else:
+        raise ValueError("Expected a native int size greater than 60.")
+
+    cc = fhe.GenCryptoContext(parameters)
+    cc.Enable(fhe.PKESchemeFeature.PKE)
+    cc.Enable(fhe.PKESchemeFeature.KEYSWITCH)
+    cc.Enable(fhe.PKESchemeFeature.LEVELEDSHE)
+    keys = cc.KeyGen()
+    cc.EvalRotateKeyGen(keys.secretKey, [1, -2])
+    return parameters, cc, keys
+
+
+def test_add_two_numbers(ckks_context):
+    params, cc, keys = ckks_context
+    batch_size = params.GetBatchSize()
+    rng = random.Random(42429842)
+    raw = [[rng.uniform(-1, 1) for _ in range(batch_size)] for _ in range(2)]
+    ptxt = [cc.MakeCKKSPackedPlaintext(x) for x in raw]
+    ctxt = [cc.Encrypt(keys.publicKey, y) for y in ptxt]
+
+    ct_added = cc.EvalAdd(ctxt[0], ctxt[1])
+    pt_added = cc.Decrypt(ct_added, keys.secretKey)
+    pt_added.SetLength(batch_size)
+    final_added = pt_added.GetCKKSPackedValue()
+    raw_added = [a + b for (a, b) in zip(*raw)]
+    total = sum(abs(a - b) for (a, b) in zip(raw_added, final_added))
+    assert total < 1e-3

+ 17 - 0
tests/test_cryptocontext.py

@@ -0,0 +1,17 @@
+import pytest
+import openfhe as fhe
+
+
+@pytest.mark.long
+@pytest.mark.skipif(fhe.get_native_int() < 80, reason="Only for NATIVE_INT=128")
+@pytest.mark.parametrize("scaling", [fhe.FIXEDAUTO, fhe.FIXEDMANUAL])
+def test_ckks_context(scaling):
+    batch_size = 8
+    parameters = fhe.CCParamsCKKSRNS()
+    parameters.SetMultiplicativeDepth(5)
+    parameters.SetScalingModSize(78)
+    parameters.SetBatchSize(batch_size)
+    parameters.SetScalingTechnique(scaling)
+    parameters.SetNumLargeDigits(2)
+    cc = fhe.GenCryptoContext(parameters)
+    assert isinstance(cc, fhe.CryptoContext)

+ 39 - 0
tests/test_serial_cc.py

@@ -0,0 +1,39 @@
+import logging
+
+import openfhe as fhe
+
+LOGGER = logging.getLogger("test_serial_cc")
+
+
+def test_serial_cryptocontext(tmp_path):
+    parameters = fhe.CCParamsBFVRNS()
+    parameters.SetPlaintextModulus(65537)
+    parameters.SetMultiplicativeDepth(2)
+
+    cryptoContext = fhe.GenCryptoContext(parameters)
+    cryptoContext.Enable(fhe.PKESchemeFeature.PKE)
+
+    keypair = cryptoContext.KeyGen()
+    vectorOfInts1 = list(range(12))
+    plaintext1 = cryptoContext.MakePackedPlaintext(vectorOfInts1)
+    ciphertext1 = cryptoContext.Encrypt(keypair.publicKey, plaintext1)
+
+    assert fhe.SerializeToFile(str(tmp_path / "cryptocontext.json"), cryptoContext, fhe.JSON)
+    LOGGER.debug("The cryptocontext has been serialized.")
+    assert fhe.SerializeToFile(str(tmp_path / "ciphertext1.json"), ciphertext1, fhe.JSON)
+
+    cryptoContext.ClearEvalMultKeys()
+    cryptoContext.ClearEvalAutomorphismKeys()
+    fhe.ReleaseAllContexts()
+
+    cc, success = fhe.DeserializeCryptoContext(str(tmp_path / "cryptocontext.json"), fhe.JSON)
+    assert success
+    assert isinstance(cc, fhe.CryptoContext)
+    assert fhe.SerializeToFile(str(tmp_path / "cryptocontext2.json"), cc, fhe.JSON)
+    LOGGER.debug("The cryptocontext has been serialized.")
+
+    ct1, success = fhe.DeserializeCiphertext(str(tmp_path / "ciphertext1.json"), fhe.JSON)
+    assert success
+    assert isinstance(ct1, fhe.Ciphertext)
+    LOGGER.debug("Cryptocontext deserializes to %s %s", success, ct1)
+    assert fhe.SerializeToFile(str(tmp_path / "ciphertext12.json"), ct1, fhe.JSON)