pre-buffer.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import time
  2. import random
  3. from math import log2
  4. from openfhe import *
  5. def main():
  6. passed = run_demo_pre()
  7. if not passed: # there could be an error
  8. return 1
  9. return 0 # successful return
  10. def run_demo_pre():
  11. # Generate parameters.
  12. print("setting up BFV RNS crypto system")
  13. start_time = time.time()
  14. plaintextModulus = 65537 # can encode shorts
  15. parameters = CCParamsBFVRNS()
  16. parameters.SetPlaintextModulus(plaintextModulus)
  17. parameters.SetScalingModSize(60)
  18. cc = GenCryptoContext(parameters)
  19. print(f"\nParam generation time: {time.time() - start_time} ms")
  20. # Turn on features
  21. cc.Enable(PKE)
  22. cc.Enable(KEYSWITCH)
  23. cc.Enable(LEVELEDSHE)
  24. cc.Enable(PRE)
  25. print(f"p = {cc.GetPlaintextModulus()}")
  26. print(f"n = {cc.GetCyclotomicOrder()/2}")
  27. print(f"log2 q = {log2(cc.GetModulus())}")
  28. print(f"r = {cc.GetDigitSize()}")
  29. ringsize = cc.GetRingDimension()
  30. print(f"Alice can encrypt {ringsize * 2} bytes of data")
  31. # Perform Key Generation Operation
  32. print("\nRunning Alice key generation (used for source data)...")
  33. start_time = time.time()
  34. keyPair1 = cc.KeyGen()
  35. print(f"Key generation time: {time.time() - start_time} ms")
  36. if not keyPair1.good():
  37. print("Alice Key generation failed!")
  38. return False
  39. # Encode source data
  40. nshort = ringsize
  41. vShorts = [random.randint(0, 65536) for _ in range(nshort)]
  42. pt = cc.MakePackedPlaintext(vShorts)
  43. # Encryption
  44. start_time = time.time()
  45. ct1 = cc.Encrypt(keyPair1.publicKey, pt)
  46. print(f"Encryption time: {time.time() - start_time} ms")
  47. # Decryption of Ciphertext
  48. start_time = time.time()
  49. ptDec1 = cc.Decrypt(keyPair1.secretKey, ct1)
  50. print(f"Decryption time: {time.time() - start_time} ms")
  51. ptDec1.SetLength(pt.GetLength())
  52. # Perform Key Generation Operation
  53. print("Bob Running key generation ...")
  54. start_time = time.time()
  55. keyPair2 = cc.KeyGen()
  56. print(f"Key generation time: {time.time() - start_time} ms")
  57. if not keyPair2.good():
  58. print("Bob Key generation failed!")
  59. return False
  60. # Perform the proxy re-encryption key generation operation.
  61. # This generates the keys which are used to perform the key switching.
  62. print("\nGenerating proxy re-encryption key...")
  63. start_time = time.time()
  64. reencryptionKey12 = cc.ReKeyGen(keyPair1.secretKey, keyPair2.publicKey)
  65. print(f"Key generation time: {time.time() - start_time} ms")
  66. # Re-Encryption
  67. start_time = time.time()
  68. ct2 = cc.ReEncrypt(ct1, reencryptionKey12)
  69. print(f"Re-Encryption time: {time.time() - start_time} ms")
  70. # Decryption of Ciphertext
  71. start_time = time.time()
  72. ptDec2 = cc.Decrypt(keyPair2.secretKey, ct2)
  73. print(f"Decryption time: {time.time() - start_time} ms")
  74. ptDec2.SetLength(pt.GetLength())
  75. unpacked0 = pt.GetPackedValue()
  76. unpacked1 = ptDec1.GetPackedValue()
  77. unpacked2 = ptDec2.GetPackedValue()
  78. good = True
  79. # note that OpenFHE assumes that plaintext is in the range of -p/2..p/2
  80. # to recover 0...q simply add q if the unpacked value is negative
  81. for j in range(pt.GetLength()):
  82. if unpacked1[j] < 0:
  83. unpacked1[j] += plaintextModulus
  84. if unpacked2[j] < 0:
  85. unpacked2[j] += plaintextModulus
  86. # compare all the results for correctness
  87. for j in range(pt.GetLength()):
  88. if (unpacked0[j] != unpacked1[j]) or (unpacked0[j] != unpacked2[j]):
  89. print(f"{j}, {unpacked0[j]}, {unpacked1[j]}, {unpacked2[j]}")
  90. good = False
  91. if good:
  92. print("PRE passes")
  93. else:
  94. print("PRE fails")
  95. print("Execution Completed.")
  96. return good
  97. if __name__ == "__main__":
  98. main()