# test_ckks.py (1.9 KB)
  1. import random
  2. import pytest
  3. import openfhe as fhe
  @pytest.fixture(scope="module")
  def ckks_context():
      """Build a small CKKS crypto context once and share it across this module.

      Context creation can be slow, so the fixture is module-scoped and every
      test in the module reuses the same parameters, context, and keys.

      The modulus sizes and scaling technique are selected from the value
      reported by ``fhe.get_native_int()``:

      * greater than 90: first modulus 89, scaling modulus 78, ``FIXEDAUTO``
      * greater than 60: first modulus 60, scaling modulus 56, ``FLEXIBLEAUTO``

      Returns:
          A ``(parameters, crypto_context, key_pair)`` tuple.

      Raises:
          ValueError: If the native int size is not greater than 60.
      """
      slot_count = 8
      native_bits = fhe.get_native_int()

      params = fhe.CCParamsCKKSRNS()
      params.SetMultiplicativeDepth(5)

      # Only the modulus sizes and the scaling technique differ between the
      # two supported native int sizes; everything else is shared.
      if native_bits > 90:
          first_mod_bits, scaling_mod_bits = 89, 78
          technique = fhe.ScalingTechnique.FIXEDAUTO
      elif native_bits > 60:
          first_mod_bits, scaling_mod_bits = 60, 56
          technique = fhe.ScalingTechnique.FLEXIBLEAUTO
      else:
          raise ValueError("Expected a native int size greater than 60.")

      params.SetFirstModSize(first_mod_bits)
      params.SetScalingModSize(scaling_mod_bits)
      params.SetBatchSize(slot_count)
      params.SetScalingTechnique(technique)
      params.SetNumLargeDigits(2)

      context = fhe.GenCryptoContext(params)
      for feature in (
          fhe.PKESchemeFeature.PKE,
          fhe.PKESchemeFeature.KEYSWITCH,
          fhe.PKESchemeFeature.LEVELEDSHE,
      ):
          context.Enable(feature)

      key_pair = context.KeyGen()
      # Rotation keys are generated for the index list [1, -2].
      context.EvalRotateKeyGen(key_pair.secretKey, [1, -2])

      return params, context, key_pair
  def test_add_two_numbers(ckks_context):
      """Check that adding two encrypted vectors matches the plaintext sum.

      Two vectors of ``GetBatchSize()`` values drawn uniformly from [-1, 1]
      with a fixed seed are packed, encrypted, added with ``EvalAdd``, and
      decrypted. The summed absolute difference between the decrypted slots
      and the element-wise plaintext sum must stay below 1e-3.
      """
      parameters, context, key_pair = ckks_context
      slots = parameters.GetBatchSize()

      # Fixed seed keeps the inputs reproducible; the first vector is drawn
      # completely before the second one.
      generator = random.Random(42429842)
      first = [generator.uniform(-1, 1) for _ in range(slots)]
      second = [generator.uniform(-1, 1) for _ in range(slots)]

      # Pack both vectors before encrypting either of them.
      plain_a, plain_b = (
          context.MakeCKKSPackedPlaintext(vector) for vector in (first, second)
      )
      cipher_a, cipher_b = (
          context.Encrypt(key_pair.publicKey, plain) for plain in (plain_a, plain_b)
      )

      encrypted_sum = context.EvalAdd(cipher_a, cipher_b)

      decrypted = context.Decrypt(encrypted_sum, key_pair.secretKey)
      decrypted.SetLength(slots)
      measured = decrypted.GetCKKSPackedValue()

      expected = [left + right for left, right in zip(first, second)]

      # Accumulate the absolute error over all slots.
      deviation = 0
      for want, got in zip(expected, measured):
          deviation += abs(want - got)

      assert deviation < 1e-3