test_ckks.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import random
  2. import pytest
  3. import openfhe as fhe
  4. pytestmark = pytest.mark.skipif(fhe.get_native_int() == 32, reason="Doesn't work for NATIVE_INT=32")
  5. @pytest.fixture(scope="module")
  6. def ckks_context():
  7. """
  8. This fixture creates a small CKKS context, with its paramters and keys.
  9. We make it because context creation can be slow.
  10. """
  11. batch_size = 8
  12. parameters = fhe.CCParamsCKKSRNS()
  13. parameters.SetMultiplicativeDepth(5)
  14. if fhe.get_native_int() == 128:
  15. parameters.SetFirstModSize(89)
  16. parameters.SetScalingModSize(78)
  17. parameters.SetBatchSize(batch_size)
  18. parameters.SetScalingTechnique(fhe.ScalingTechnique.FIXEDAUTO)
  19. parameters.SetNumLargeDigits(2)
  20. elif fhe.get_native_int() == 64:
  21. parameters.SetFirstModSize(60)
  22. parameters.SetScalingModSize(56)
  23. parameters.SetBatchSize(batch_size)
  24. parameters.SetScalingTechnique(fhe.ScalingTechnique.FLEXIBLEAUTO)
  25. parameters.SetNumLargeDigits(2)
  26. else:
  27. raise ValueError("Expected a native int size 64 or 128.")
  28. cc = fhe.GenCryptoContext(parameters)
  29. cc.Enable(fhe.PKESchemeFeature.PKE)
  30. cc.Enable(fhe.PKESchemeFeature.KEYSWITCH)
  31. cc.Enable(fhe.PKESchemeFeature.LEVELEDSHE)
  32. keys = cc.KeyGen()
  33. cc.EvalRotateKeyGen(keys.secretKey, [1, -2])
  34. return parameters, cc, keys
  35. def test_add_two_numbers(ckks_context):
  36. params, cc, keys = ckks_context
  37. batch_size = params.GetBatchSize()
  38. rng = random.Random(42429842)
  39. raw = [[rng.uniform(-1, 1) for _ in range(batch_size)] for _ in range(2)]
  40. ptxt = [cc.MakeCKKSPackedPlaintext(x) for x in raw]
  41. ctxt = [cc.Encrypt(keys.publicKey, y) for y in ptxt]
  42. ct_added = cc.EvalAdd(ctxt[0], ctxt[1])
  43. pt_added = cc.Decrypt(ct_added, keys.secretKey)
  44. pt_added.SetLength(batch_size)
  45. final_added = pt_added.GetCKKSPackedValue()
  46. raw_added = [a + b for (a, b) in zip(*raw)]
  47. total = sum(abs(a - b) for (a, b) in zip(raw_added, final_added))
  48. assert total < 1e-3