test_bgv.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import logging
  2. import random
  3. import pytest
  4. import openfhe as fhe
  5. pytestmark = pytest.mark.skipif(fhe.get_native_int() == 32, reason="Doesn't work for NATIVE_INT=32")
  6. LOGGER = logging.getLogger("test_bgv")
  7. @pytest.fixture(scope="module")
  8. def bgv_context():
  9. """
  10. This fixture creates a small CKKS context, with its paramters and keys.
  11. We make it because context creation can be slow.
  12. """
  13. parameters = fhe.CCParamsBGVRNS()
  14. parameters.SetPlaintextModulus(65537)
  15. parameters.SetMultiplicativeDepth(2)
  16. crypto_context = fhe.GenCryptoContext(parameters)
  17. crypto_context.Enable(fhe.PKESchemeFeature.PKE)
  18. crypto_context.Enable(fhe.PKESchemeFeature.KEYSWITCH)
  19. crypto_context.Enable(fhe.PKESchemeFeature.LEVELEDSHE)
  20. key_pair = crypto_context.KeyGen()
  21. # Generate the relinearization key
  22. crypto_context.EvalMultKeyGen(key_pair.secretKey)
  23. # Generate the rotation evaluation keys
  24. crypto_context.EvalRotateKeyGen(key_pair.secretKey, [1, 2, -1, -2])
  25. return parameters, crypto_context, key_pair
  26. def bgv_equal(raw, ciphertext, cc, keys):
  27. """Compare an unencrypted list of values with encrypted values"""
  28. pt = cc.Decrypt(ciphertext, keys.secretKey)
  29. pt.SetLength(len(raw))
  30. compare = pt.GetPackedValue()
  31. success = all([a == b for (a, b) in zip(raw, compare)])
  32. if not success:
  33. LOGGER.info("Mismatch between %s %s", raw, compare)
  34. return success
  35. def roll(a, n):
  36. """Circularly rotate a list, like numpy.roll but without numpy."""
  37. return [a[i % len(a)] for i in range(-n, len(a) - n)]
  38. @pytest.mark.parametrize("n,final", [
  39. (0, [0, 1, 2, 3, 4, 5, 6, 7]),
  40. (2, [6, 7, 0, 1, 2, 3, 4, 5]),
  41. (3, [5, 6, 7, 0, 1, 2, 3, 4]),
  42. (-1, [1, 2, 3, 4, 5, 6, 7, 0]),
  43. ])
  44. def test_roll(n, final):
  45. assert roll(list(range(8)), n) == final
  46. def shift(a, n):
  47. """Rotate a list with infill of 0."""
  48. return [(a[i] if 0 <= i < len(a) else 0) for i in range(-n, len(a) - n)]
  49. @pytest.mark.parametrize("n,final", [
  50. (0, [1, 2, 3, 4, 5, 6, 7, 8]),
  51. (2, [0, 0, 1, 2, 3, 4, 5, 6]),
  52. (3, [0, 0, 0, 1, 2, 3, 4, 5]),
  53. (-1, [2, 3, 4, 5, 6, 7, 8, 0]),
  54. ])
  55. def test_shift(n, final):
  56. assert shift(list(range(1, 9)), n) == final
  57. def test_simple_integers(bgv_context):
  58. parameters, crypto_context, key_pair = bgv_context
  59. rng = random.Random(342342)
  60. cnt = 12
  61. raw = [[rng.randint(1, 12) for _ in range(cnt)] for _ in range(3)]
  62. plaintext = [crypto_context.MakePackedPlaintext(r) for r in raw]
  63. ciphertext = [crypto_context.Encrypt(key_pair.publicKey, pt) for pt in plaintext]
  64. assert bgv_equal(raw[0], ciphertext[0], crypto_context, key_pair)
  65. # Homomorphic additions
  66. ciphertext_add12 = crypto_context.EvalAdd(ciphertext[0], ciphertext[1])
  67. ciphertext_add_result = crypto_context.EvalAdd(ciphertext_add12, ciphertext[2])
  68. assert bgv_equal(
  69. [a + b + c for (a, b, c) in zip(*raw)],
  70. ciphertext_add_result, crypto_context, key_pair
  71. )
  72. # Homomorphic Multiplication
  73. ciphertext_mult12 = crypto_context.EvalMult(ciphertext[0], ciphertext[1])
  74. ciphertext_mult_result = crypto_context.EvalMult(ciphertext_mult12, ciphertext[2])
  75. assert bgv_equal(
  76. [a * b * c for (a, b, c) in zip(*raw)],
  77. ciphertext_mult_result, crypto_context, key_pair
  78. )
  79. # Homomorphic Rotations. These values must be initialized with EvalRotateKeyGen.
  80. for rotation in [1, 2, -1, -2]:
  81. ciphertext_rot1 = crypto_context.EvalRotate(ciphertext[0], rotation)
  82. # This is a rotation with infill of 0, NOT a circular rotation.
  83. assert bgv_equal(shift(raw[0], -rotation), ciphertext_rot1, crypto_context, key_pair)