interactive-bootstrapping.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. from openfhe import *
  2. def main():
  3. # the scaling technigue can be changed to FIXEDMANUAL, FIXEDAUTO, or FLEXIBLEAUTOEXT
  4. ThresholdFHE(FLEXIBLEAUTO)
  5. Chebyshev(FLEXIBLEAUTO)
  6. def ThresholdFHE(scaleTech):
  7. # if scaleTech not in [FIXEDMANUAL, FIXEDAUTO, FLEXIBLEAUTOEXT]:
  8. # errMsg = "ERROR: Scaling technique is not supported!"
  9. # raise Exception(errMsg)
  10. print(f"Threshold FHE example with Scaling Technique {scaleTech}")
  11. parameters = CCParamsCKKSRNS()
  12. # 1 extra level needs to be added for FIXED* modes (2 extra levels for FLEXIBLE* modes) to the multiplicative depth
  13. # to support 2-party interactive bootstrapping
  14. depth = 7
  15. parameters.SetMultiplicativeDepth(depth)
  16. parameters.SetScalingModSize(50)
  17. parameters.SetBatchSize(16)
  18. parameters.SetScalingTechnique(scaleTech)
  19. cc = GenCryptoContext(parameters)
  20. cc.Enable(PKE)
  21. cc.Enable(LEVELEDSHE)
  22. cc.Enable(ADVANCEDSHE)
  23. cc.Enable(MULTIPARTY)
  24. #############################################################
  25. # Perform Key Generation Operation
  26. #############################################################
  27. print("Running key generation (used for source data)...")
  28. print("Round 1 (party A) started.")
  29. kp1 = cc.KeyGen()
  30. evalMultKey = cc.KeySwitchGen(kp1.secretKey, kp1.secretKey)
  31. print("Round 1 of key generation completed.")
  32. #############################################################
  33. print("Round 2 (party B) started.")
  34. print("Joint public key for (s_a + s_b) is generated...")
  35. kp2 = cc.MultipartyKeyGen(kp1.publicKey)
  36. input = [-0.9, -0.8, -0.6, -0.4, -0.2, 0., 0.2, 0.4, 0.6, 0.8, 0.9]
  37. # This plaintext only has 3 RNS limbs, the minimum needed to perform 2-party interactive bootstrapping for FLEXIBLEAUTO
  38. plaintext1 = cc.MakeCKKSPackedPlaintext(input, 1, depth - 2)
  39. ciphertext1 = cc.Encrypt(kp2.publicKey, plaintext1)
  40. # INTERACTIVE BOOTSTRAPPING STARTS
  41. # under the hood it reduces to two towers
  42. ciphertext1 = cc.IntBootAdjustScale(ciphertext1)
  43. print("IntBootAdjustScale Succeeded")
  44. # masked decryption on the server: c0 = b + a*s0
  45. ciphertextOutput1 = cc.IntBootDecrypt(kp1.secretKey, ciphertext1)
  46. print("IntBootDecrypt on Server Succeeded")
  47. ciphertext2 = ciphertext1.Clone()
  48. ciphertext2.SetElements([ciphertext2.GetElements()[1]])
  49. # masked decryption on the client: c1 = a*s1
  50. ciphertextOutput2 = cc.IntBootDecrypt(kp2.secretKey, ciphertext2)
  51. print("IntBootDecrypt on Client Succeeded")
  52. # Encryption of masked decryption c1 = a*s1
  53. ciphertextOutput2 = cc.IntBootEncrypt(kp2.publicKey, ciphertextOutput2)
  54. print("IntBootEncrypt on Client Succeeded")
  55. # Compute Enc(c1) + c0
  56. ciphertextOutput = cc.IntBootAdd(ciphertextOutput2, ciphertextOutput1)
  57. print("IntBootAdd on Server Succeeded")
  58. # INTERACTIVE BOOTSTRAPPING ENDS
  59. # distributed decryption
  60. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextOutput], kp1.secretKey)
  61. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextOutput], kp2.secretKey)
  62. partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0]]
  63. plaintextMultiparty = cc.MultipartyDecryptFusion(partialCiphertextVec)
  64. plaintextMultiparty.SetLength(len(input))
  65. print(f"Original plaintext \n\t {plaintext1.GetCKKSPackedValue()}")
  66. print(f"Result after bootstrapping \n\t {plaintextMultiparty.GetCKKSPackedValue()}")
  67. def Chebyshev(scaleTech):
  68. # if scaleTech not in [FIXEDMANUAL, FIXEDAUTO, FLEXIBLEAUTOEXT]:
  69. # errMsg = "ERROR: Scaling technique is not supported!"
  70. # raise Exception(errMsg)
  71. print(f"Threshold FHE example with Scaling Technique {scaleTech}")
  72. parameters = CCParamsCKKSRNS()
  73. # 1 extra level needs to be added for FIXED* modes (2 extra levels for FLEXIBLE* modes) to the multiplicative depth
  74. # to support 2-party interactive bootstrapping
  75. parameters.SetMultiplicativeDepth(8)
  76. parameters.SetScalingModSize(50)
  77. parameters.SetBatchSize(16)
  78. parameters.SetScalingTechnique(scaleTech)
  79. cc = GenCryptoContext(parameters)
  80. # enable features that you wish to use
  81. cc.Enable(PKE)
  82. cc.Enable(LEVELEDSHE)
  83. cc.Enable(ADVANCEDSHE)
  84. cc.Enable(MULTIPARTY)
  85. ############################################################
  86. # Perform Key Generation Operation
  87. ############################################################
  88. print("Running key generation (used for source data)...")
  89. print("Round 1 (party A) started.")
  90. kp1 = cc.KeyGen()
  91. evalMultKey = cc.KeySwitchGen(kp1.secretKey, kp1.secretKey)
  92. cc.EvalSumKeyGen(kp1.secretKey)
  93. evalSumKeys = cc.GetEvalSumKeyMap(kp1.secretKey.GetKeyTag())
  94. print("Round 1 of key generation completed.")
  95. ############################################################
  96. print("Round 2 (party B) started.")
  97. print("Joint public key for (s_a + s_b) is generated...")
  98. kp2 = cc.MultipartyKeyGen(kp1.publicKey)
  99. evalMultKey2 = cc.MultiKeySwitchGen(kp2.secretKey, kp2.secretKey, evalMultKey)
  100. print("Joint evaluation multiplication key for (s_a + s_b) is generated...")
  101. evalMultAB = cc.MultiAddEvalKeys(evalMultKey, evalMultKey2, kp2.publicKey.GetKeyTag())
  102. print("Joint evaluation multiplication key (s_a + s_b) is transformed into s_b*(s_a + s_b)...")
  103. evalMultBAB = cc.MultiMultEvalKey(kp2.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
  104. evalSumKeysB = cc.MultiEvalSumKeyGen(kp2.secretKey, evalSumKeys, kp2.publicKey.GetKeyTag())
  105. print("Joint evaluation summation key for (s_a + s_b) is generated...")
  106. evalSumKeysJoin = cc.MultiAddEvalSumKeys(evalSumKeys, evalSumKeysB, kp2.publicKey.GetKeyTag())
  107. cc.InsertEvalSumKey(evalSumKeysJoin)
  108. print("Round 2 of key generation completed.")
  109. print("Round 3 (party A) started.")
  110. print("Joint key (s_a + s_b) is transformed into s_a*(s_a + s_b)...")
  111. evalMultAAB = cc.MultiMultEvalKey(kp1.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
  112. print("Computing the final evaluation multiplication key for (s_a + s_b)*(s_a + s_b)...")
  113. evalMultFinal = cc.MultiAddEvalMultKeys(evalMultAAB, evalMultBAB, evalMultAB.GetKeyTag())
  114. cc.InsertEvalMultKey([evalMultFinal])
  115. print("Round 3 of key generation completed.")
  116. input = [-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
  117. coefficients = [1.0, 0.558971, 0.0, -0.0943712, 0.0, 0.0215023, 0.0, -0.00505348, 0.0, 0.00119324,
  118. 0.0, -0.000281928, 0.0, 0.0000664347, 0.0, -0.0000148709]
  119. a = -4
  120. b = 4
  121. plaintext1 = cc.MakeCKKSPackedPlaintext(input)
  122. ciphertext1 = cc.Encrypt(kp2.publicKey, plaintext1)
  123. # The Chebyshev series interpolation requires 6 levels
  124. ciphertext1 = cc.EvalChebyshevSeries(ciphertext1, coefficients, a, b)
  125. print("Ran Chebyshev interpolation")
  126. # INTERACTIVE BOOTSTRAPPING STARTS
  127. ciphertext1 = cc.IntBootAdjustScale(ciphertext1)
  128. print("IntBootAdjustScale Succeeded")
  129. # masked decryption on the server: c0 = b + a*s0
  130. ciphertextOutput1 = cc.IntBootDecrypt(kp1.secretKey, ciphertext1)
  131. print("IntBootDecrypt on Server Succeeded")
  132. ciphertext2 = ciphertext1.Clone()
  133. ciphertext2.SetElements([ciphertext2.GetElements()[1]])
  134. # masked decryption on the client: c1 = a*s1
  135. ciphertextOutput2 = cc.IntBootDecrypt(kp2.secretKey, ciphertext2)
  136. print("IntBootDecrypt on Client Succeeded")
  137. # Encryption of masked decryption c1 = a*s1
  138. ciphertextOutput2 = cc.IntBootEncrypt(kp2.publicKey, ciphertextOutput2)
  139. print("IntBootEncrypt on Client Succeeded")
  140. # Compute Enc(c1) + c0
  141. ciphertextOutput = cc.IntBootAdd(ciphertextOutput2, ciphertextOutput1)
  142. print("IntBootAdd on Server Succeeded")
  143. # INTERACTIVE BOOTSTRAPPING ENDS
  144. # distributed decryption
  145. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextOutput], kp1.secretKey)
  146. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextOutput], kp2.secretKey)
  147. partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0]]
  148. plaintextMultiparty = cc.MultipartyDecryptFusion(partialCiphertextVec)
  149. plaintextMultiparty.SetLength(len(input))
  150. print(f"\n Original Plaintext #1: \n {plaintext1}")
  151. print(f"\n Results of evaluating the polynomial with coefficients {coefficients} \n")
  152. print(f"\n Ciphertext result: {plaintextMultiparty}")
  153. print("\n Plaintext result: ( 0.0179885, 0.0474289, 0.119205, 0.268936, 0.5, 0.731064, 0.880795, 0.952571, 0.982011 ) \n")
  154. print("\n Exact result: ( 0.0179862, 0.0474259, 0.119203, 0.268941, 0.5, 0.731059, 0.880797, 0.952574, 0.982014 ) \n")
  155. print("\n Another round of Chebyshev interpolation after interactive bootstrapping: \n")
  156. ciphertextOutput = cc.EvalChebyshevSeries(ciphertextOutput, coefficients, a, b)
  157. print("Ran Chebyshev interpolation")
  158. # distributed decryption
  159. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextOutput], kp1.secretKey)
  160. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextOutput], kp2.secretKey)
  161. partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0]]
  162. plaintextMultiparty = cc.MultipartyDecryptFusion(partialCiphertextVec)
  163. plaintextMultiparty.SetLength(len(input))
  164. print(f"\n Ciphertext result: {plaintextMultiparty}")
  165. print("\n Plaintext result: ( 0.504497, 0.511855, 0.529766, 0.566832, 0.622459, 0.675039, 0.706987, 0.721632, 0.727508 )")
  166. if __name__ == "__main__":
  167. main()