# test_bgv.py
  1. import logging
  2. import random
  3. import pytest
  4. import openfhe as fhe
  5. LOGGER = logging.getLogger("test_bgv")
  6. @pytest.fixture(scope="module")
  7. def bgv_context():
  8. """
  9. This fixture creates a small CKKS context, with its paramters and keys.
  10. We make it because context creation can be slow.
  11. """
  12. parameters = fhe.CCParamsBGVRNS()
  13. parameters.SetPlaintextModulus(65537)
  14. parameters.SetMultiplicativeDepth(2)
  15. crypto_context = fhe.GenCryptoContext(parameters)
  16. crypto_context.Enable(fhe.PKESchemeFeature.PKE)
  17. crypto_context.Enable(fhe.PKESchemeFeature.KEYSWITCH)
  18. crypto_context.Enable(fhe.PKESchemeFeature.LEVELEDSHE)
  19. key_pair = crypto_context.KeyGen()
  20. # Generate the relinearization key
  21. crypto_context.EvalMultKeyGen(key_pair.secretKey)
  22. # Generate the rotation evaluation keys
  23. crypto_context.EvalRotateKeyGen(key_pair.secretKey, [1, 2, -1, -2])
  24. return parameters, crypto_context, key_pair
  25. def bgv_equal(raw, ciphertext, cc, keys):
  26. """Compare an unencrypted list of values with encrypted values"""
  27. pt = cc.Decrypt(ciphertext, keys.secretKey)
  28. pt.SetLength(len(raw))
  29. compare = pt.GetPackedValue()
  30. success = all([a == b for (a, b) in zip(raw, compare)])
  31. if not success:
  32. LOGGER.info("Mismatch between %s %s", raw, compare)
  33. return success
  34. def roll(a, n):
  35. """Circularly rotate a list, like numpy.roll but without numpy."""
  36. return [a[i % len(a)] for i in range(-n, len(a) - n)]
  37. @pytest.mark.parametrize("n,final", [
  38. (0, [0, 1, 2, 3, 4, 5, 6, 7]),
  39. (2, [6, 7, 0, 1, 2, 3, 4, 5]),
  40. (3, [5, 6, 7, 0, 1, 2, 3, 4]),
  41. (-1, [1, 2, 3, 4, 5, 6, 7, 0]),
  42. ])
  43. def test_roll(n, final):
  44. assert roll(list(range(8)), n) == final
  45. def shift(a, n):
  46. """Rotate a list with infill of 0."""
  47. return [(a[i] if 0 <= i < len(a) else 0) for i in range(-n, len(a) - n)]
  48. @pytest.mark.parametrize("n,final", [
  49. (0, [1, 2, 3, 4, 5, 6, 7, 8]),
  50. (2, [0, 0, 1, 2, 3, 4, 5, 6]),
  51. (3, [0, 0, 0, 1, 2, 3, 4, 5]),
  52. (-1, [2, 3, 4, 5, 6, 7, 8, 0]),
  53. ])
  54. def test_shift(n, final):
  55. assert shift(list(range(1, 9)), n) == final
  56. def test_simple_integers(bgv_context):
  57. parameters, crypto_context, key_pair = bgv_context
  58. rng = random.Random(342342)
  59. cnt = 12
  60. raw = [[rng.randint(1, 12) for _ in range(cnt)] for _ in range(3)]
  61. plaintext = [crypto_context.MakePackedPlaintext(r) for r in raw]
  62. ciphertext = [crypto_context.Encrypt(key_pair.publicKey, pt) for pt in plaintext]
  63. assert bgv_equal(raw[0], ciphertext[0], crypto_context, key_pair)
  64. # Homomorphic additions
  65. ciphertext_add12 = crypto_context.EvalAdd(ciphertext[0], ciphertext[1])
  66. ciphertext_add_result = crypto_context.EvalAdd(ciphertext_add12, ciphertext[2])
  67. assert bgv_equal(
  68. [a + b + c for (a, b, c) in zip(*raw)],
  69. ciphertext_add_result, crypto_context, key_pair
  70. )
  71. # Homomorphic Multiplication
  72. ciphertext_mult12 = crypto_context.EvalMult(ciphertext[0], ciphertext[1])
  73. ciphertext_mult_result = crypto_context.EvalMult(ciphertext_mult12, ciphertext[2])
  74. assert bgv_equal(
  75. [a * b * c for (a, b, c) in zip(*raw)],
  76. ciphertext_mult_result, crypto_context, key_pair
  77. )
  78. # Homomorphic Rotations. These values must be initialized with EvalRotateKeyGen.
  79. for rotation in [1, 2, -1, -2]:
  80. ciphertext_rot1 = crypto_context.EvalRotate(ciphertext[0], rotation)
  81. # This is a rotation with infill of 0, NOT a circular rotation.
  82. assert bgv_equal(shift(raw[0], -rotation), ciphertext_rot1, crypto_context, key_pair)