# threshold-fhe-5p.py
  1. from openfhe import *
  2. from math import log2
  3. def main():
  4. print("\n=================RUNNING FOR BFVrns======================\n")
  5. RunBFVrns()
  6. def RunBFVrns():
  7. plaintextModulus = 65537
  8. sigma = 3.2
  9. securityLevel = SecurityLevel.HEStd_128_classic
  10. batchSize = 16
  11. multDepth = 4
  12. digitSize = 30
  13. dcrtBits = 60
  14. parameters = CCParamsBFVRNS()
  15. parameters.SetPlaintextModulus(plaintextModulus)
  16. parameters.SetSecurityLevel(securityLevel)
  17. parameters.SetStandardDeviation(sigma)
  18. parameters.SetSecretKeyDist(UNIFORM_TERNARY)
  19. parameters.SetMultiplicativeDepth(multDepth)
  20. parameters.SetBatchSize(batchSize)
  21. parameters.SetDigitSize(digitSize)
  22. parameters.SetScalingModSize(dcrtBits)
  23. parameters.SetThresholdNumOfParties(5)
  24. parameters.SetMultiplicationTechnique(HPSPOVERQLEVELED)
  25. cc = GenCryptoContext(parameters)
  26. # Enable features you wish to use
  27. cc.Enable(PKE)
  28. cc.Enable(KEYSWITCH)
  29. cc.Enable(LEVELEDSHE)
  30. cc.Enable(ADVANCEDSHE)
  31. cc.Enable(MULTIPARTY)
  32. ##########################################################
  33. # Set-up of parameters
  34. ##########################################################
  35. # Output the generated parameters
  36. print(f"p = {cc.GetPlaintextModulus()}")
  37. print(f"n = {cc.GetCyclotomicOrder() / 2}")
  38. print(f"log2 q = {log2(cc.GetModulus())}")
  39. ############################################################
  40. ## Perform Key Generation Operation
  41. ############################################################
  42. print("Running key generation (used for source data)...")
  43. # Round 1 (party A)
  44. print("Round 1 (party A) started.")
  45. kp1 = cc.KeyGen()
  46. kp2 = cc.MultipartyKeyGen(kp1.publicKey)
  47. kp3 = cc.MultipartyKeyGen(kp2.publicKey)
  48. kp4 = cc.MultipartyKeyGen(kp3.publicKey)
  49. kp5 = cc.MultipartyKeyGen(kp4.publicKey)
  50. # Generate evalmult key part for A
  51. evalMultKey = cc.KeySwitchGen(kp1.secretKey, kp1.secretKey)
  52. evalMultKey2 = cc.MultiKeySwitchGen(kp2.secretKey, kp2.secretKey, evalMultKey)
  53. evalMultKey3 = cc.MultiKeySwitchGen(kp3.secretKey, kp3.secretKey, evalMultKey)
  54. evalMultKey4 = cc.MultiKeySwitchGen(kp4.secretKey, kp4.secretKey, evalMultKey)
  55. evalMultKey5 = cc.MultiKeySwitchGen(kp5.secretKey, kp5.secretKey, evalMultKey)
  56. evalMultAB = cc.MultiAddEvalKeys(evalMultKey, evalMultKey2, kp2.publicKey.GetKeyTag())
  57. evalMultABC = cc.MultiAddEvalKeys(evalMultAB, evalMultKey3, kp3.publicKey.GetKeyTag())
  58. evalMultABCD = cc.MultiAddEvalKeys(evalMultABC, evalMultKey4, kp4.publicKey.GetKeyTag())
  59. evalMultABCDE = cc.MultiAddEvalKeys(evalMultABCD, evalMultKey5, kp5.publicKey.GetKeyTag())
  60. evalMultEABCDE = cc.MultiMultEvalKey(kp5.secretKey, evalMultABCDE, kp5.publicKey.GetKeyTag())
  61. evalMultDABCDE = cc.MultiMultEvalKey(kp4.secretKey, evalMultABCDE, kp5.publicKey.GetKeyTag())
  62. evalMultCABCDE = cc.MultiMultEvalKey(kp3.secretKey, evalMultABCDE, kp5.publicKey.GetKeyTag())
  63. evalMultBABCDE = cc.MultiMultEvalKey(kp2.secretKey, evalMultABCDE, kp5.publicKey.GetKeyTag())
  64. evalMultAABCDE = cc.MultiMultEvalKey(kp1.secretKey, evalMultABCDE, kp5.publicKey.GetKeyTag())
  65. evalMultDEABCDE = cc.MultiAddEvalMultKeys(evalMultEABCDE, evalMultDABCDE, evalMultEABCDE.GetKeyTag())
  66. evalMultCDEABCDE = cc.MultiAddEvalMultKeys(evalMultCABCDE, evalMultDEABCDE, evalMultCABCDE.GetKeyTag())
  67. evalMultBCDEABCDE = cc.MultiAddEvalMultKeys(evalMultBABCDE, evalMultCDEABCDE, evalMultBABCDE.GetKeyTag())
  68. evalMultFinal = cc.MultiAddEvalMultKeys(evalMultAABCDE, evalMultBCDEABCDE, kp5.publicKey.GetKeyTag())
  69. cc.InsertEvalMultKey([evalMultFinal])
  70. print("Round 1 of key generation completed.")
  71. ############################################################
  72. ## EvalSum Key Generation
  73. ############################################################
  74. print("Running evalsum key generation (used for source data)...")
  75. # Generate evalsum key part for A
  76. cc.EvalSumKeyGen(kp1.secretKey)
  77. evalSumKeys = cc.GetEvalSumKeyMap(kp1.secretKey.GetKeyTag())
  78. evalSumKeysB = cc.MultiEvalSumKeyGen(kp2.secretKey, evalSumKeys, kp2.publicKey.GetKeyTag())
  79. evalSumKeysC = cc.MultiEvalSumKeyGen(kp3.secretKey, evalSumKeys, kp3.publicKey.GetKeyTag())
  80. evalSumKeysD = cc.MultiEvalSumKeyGen(kp4.secretKey, evalSumKeys, kp4.publicKey.GetKeyTag())
  81. evalSumKeysE = cc.MultiEvalSumKeyGen(kp5.secretKey, evalSumKeys, kp5.publicKey.GetKeyTag())
  82. evalSumKeysAB = cc.MultiAddEvalSumKeys(evalSumKeys, evalSumKeysB, kp2.publicKey.GetKeyTag())
  83. evalSumKeysABC = cc.MultiAddEvalSumKeys(evalSumKeysC, evalSumKeysAB, kp3.publicKey.GetKeyTag())
  84. evalSumKeysABCD = cc.MultiAddEvalSumKeys(evalSumKeysABC, evalSumKeysD, kp4.publicKey.GetKeyTag())
  85. evalSumKeysJoin = cc.MultiAddEvalSumKeys(evalSumKeysE, evalSumKeysABCD, kp5.publicKey.GetKeyTag())
  86. cc.InsertEvalSumKey(evalSumKeysJoin)
  87. print("Evalsum key generation completed.")
  88. ############################################################
  89. ## Encode source data
  90. ############################################################
  91. vectorOfInts1 = [1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1, 0]
  92. vectorOfInts2 = [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]
  93. vectorOfInts3 = [2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0]
  94. plaintext1 = cc.MakePackedPlaintext(vectorOfInts1)
  95. plaintext2 = cc.MakePackedPlaintext(vectorOfInts2)
  96. plaintext3 = cc.MakePackedPlaintext(vectorOfInts3)
  97. ############################################################
  98. ## Encryption
  99. ############################################################
  100. ciphertext1 = cc.Encrypt(kp5.publicKey, plaintext1)
  101. ciphertext2 = cc.Encrypt(kp5.publicKey, plaintext2)
  102. ciphertext3 = cc.Encrypt(kp5.publicKey, plaintext3)
  103. ############################################################
  104. ## Homomorphic Operations
  105. ############################################################
  106. ciphertextAdd12 = cc.EvalAdd(ciphertext1, ciphertext2)
  107. ciphertextAdd123 = cc.EvalAdd(ciphertextAdd12, ciphertext3)
  108. ciphertextMult1 = cc.EvalMult(ciphertext1, ciphertext1)
  109. ciphertextMult2 = cc.EvalMult(ciphertextMult1, ciphertext1)
  110. ciphertextMult3 = cc.EvalMult(ciphertextMult2, ciphertext1)
  111. ciphertextMult = cc.EvalMult(ciphertextMult3, ciphertext1)
  112. ciphertextEvalSum = cc.EvalSum(ciphertext3, batchSize)
  113. ############################################################
  114. ## Decryption after Accumulation Operation on Encrypted Data with Multiparty
  115. ############################################################
  116. # Distributed decryption
  117. # partial decryption by party A
  118. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextAdd123], kp1.secretKey)
  119. # partial decryption by party B
  120. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextAdd123], kp2.secretKey)
  121. # partial decryption by party C
  122. ciphertextPartial3 = cc.MultipartyDecryptMain([ciphertextAdd123], kp3.secretKey)
  123. # partial decryption by party D
  124. ciphertextPartial4 = cc.MultipartyDecryptMain([ciphertextAdd123], kp4.secretKey)
  125. # partial decryption by party E
  126. ciphertextPartial5 = cc.MultipartyDecryptMain([ciphertextAdd123], kp5.secretKey)
  127. partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0], ciphertextPartial3[0],
  128. ciphertextPartial4[0], ciphertextPartial5[0]]
  129. plaintextMultipartyNew = cc.MultipartyDecryptFusion(partialCiphertextVec)
  130. print("\n Original Plaintext: \n")
  131. print(plaintext1)
  132. print(plaintext2)
  133. print(plaintext3)
  134. plaintextMultipartyNew.SetLength(plaintext1.GetLength())
  135. print("\n Resulting Fused Plaintext: \n")
  136. print(plaintextMultipartyNew)
  137. print("\n")
  138. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextMult], kp1.secretKey)
  139. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextMult], kp2.secretKey)
  140. ciphertextPartial3 = cc.MultipartyDecryptMain([ciphertextMult], kp3.secretKey)
  141. ciphertextPartial4 = cc.MultipartyDecryptMain([ciphertextMult], kp4.secretKey)
  142. ciphertextPartial5 = cc.MultipartyDecryptMain([ciphertextMult], kp5.secretKey)
  143. partialCiphertextVecMult = [ciphertextPartial1[0], ciphertextPartial2[0], ciphertextPartial3[0],
  144. ciphertextPartial4[0], ciphertextPartial5[0]]
  145. plaintextMultipartyMult = cc.MultipartyDecryptFusion(partialCiphertextVecMult)
  146. plaintextMultipartyMult.SetLength(plaintext1.GetLength())
  147. print("\n Resulting Fused Plaintext after Multiplication of plaintexts 1 and 3: \n")
  148. print(plaintextMultipartyMult)
  149. print("\n")
  150. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextEvalSum], kp1.secretKey)
  151. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextEvalSum], kp2.secretKey)
  152. ciphertextPartial3 = cc.MultipartyDecryptMain([ciphertextEvalSum], kp3.secretKey)
  153. ciphertextPartial4 = cc.MultipartyDecryptMain([ciphertextEvalSum], kp4.secretKey)
  154. ciphertextPartial5 = cc.MultipartyDecryptMain([ciphertextEvalSum], kp5.secretKey)
  155. partialCiphertextVecEvalSum = [ciphertextPartial1[0], ciphertextPartial2[0], ciphertextPartial3[0],
  156. ciphertextPartial4[0], ciphertextPartial5[0]]
  157. plaintextMultipartyEvalSum = cc.MultipartyDecryptFusion(partialCiphertextVecEvalSum)
  158. plaintextMultipartyEvalSum.SetLength(plaintext1.GetLength())
  159. print("\n Fused result after the Summation of ciphertext 3: \n")
  160. print(plaintextMultipartyEvalSum)
  161. if __name__ == "__main__":
  162. main()