# threshold-fhe.py: threshold FHE examples for BGVrns, BFVrns and CKKS using OpenFHE.
import sys
from math import log2

from openfhe import *
  3. def main():
  4. print("\n=================RUNNING FOR BGVrns - Additive =====================")
  5. RunBGVrnsAdditive()
  6. print("\n=================RUNNING FOR BFVrns=====================")
  7. RunBFVrns()
  8. print("\n=================RUNNING FOR CKKS=====================")
  9. RunCKKS()
  10. def RunBGVrnsAdditive():
  11. parameters = CCParamsBGVRNS()
  12. parameters.SetPlaintextModulus(65537)
  13. # NOISE_FLOODING_MULTIPARTY adds extra noise to the ciphertext before decrypting
  14. # and is most secure mode of threshold FHE for BFV and BGV.
  15. parameters.SetMultipartyMode(NOISE_FLOODING_MULTIPARTY)
  16. cc = GenCryptoContext(parameters)
  17. # Enable Features you wish to use
  18. cc.Enable(PKE)
  19. cc.Enable(KEYSWITCH)
  20. cc.Enable(LEVELEDSHE)
  21. cc.Enable(ADVANCEDSHE)
  22. cc.Enable(MULTIPARTY)
  23. ##########################################################
  24. # Set-up of parameters
  25. ##########################################################
  26. # Print out the parameters
  27. print(f"p = {cc.GetPlaintextModulus()}")
  28. print(f"n = {cc.GetCyclotomicOrder()/2}")
  29. print(f"lo2 q = {log2(cc.GetModulus())}")
  30. ############################################################
  31. ## Perform Key Generation Operation
  32. ############################################################
  33. print("Running key generation (used for source data)...")
  34. # generate the public key for first share
  35. kp1 = cc.KeyGen()
  36. # generate the public key for two shares
  37. kp2 = cc.MultipartyKeyGen(kp1.publicKey)
  38. # generate the public key for all three secret shares
  39. kp3 = cc.MultipartyKeyGen(kp2.publicKey)
  40. if not kp1.good():
  41. print("Key generation failed!")
  42. return 1
  43. if not kp2.good():
  44. print("Key generation failed!")
  45. return 1
  46. if not kp3.good():
  47. print("Key generation failed!")
  48. return 1
  49. ############################################################
  50. ## Encode source data
  51. ############################################################
  52. vectorOfInts1 = [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
  53. vectorOfInts2 = [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]
  54. vectorOfInts3 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0]
  55. plaintext1 = cc.MakePackedPlaintext(vectorOfInts1)
  56. plaintext2 = cc.MakePackedPlaintext(vectorOfInts2)
  57. plaintext3 = cc.MakePackedPlaintext(vectorOfInts3)
  58. ############################################################
  59. ## Encryption
  60. ############################################################
  61. ciphertext1 = cc.Encrypt(kp3.publicKey, plaintext1)
  62. ciphertext2 = cc.Encrypt(kp3.publicKey, plaintext2)
  63. ciphertext3 = cc.Encrypt(kp3.publicKey, plaintext3)
  64. ############################################################
  65. ## EvalAdd Operation on Re-Encrypted Data
  66. ############################################################
  67. ciphertextAdd12 = cc.EvalAdd(ciphertext1, ciphertext2)
  68. ciphertextAdd123 = cc.EvalAdd(ciphertextAdd12, ciphertext3)
  69. ############################################################
  70. ## Decryption after Accumulation Operation on Encrypted Data with Multiparty
  71. ############################################################
  72. # partial decryption by first party
  73. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextAdd123], kp1.secretKey)
  74. # partial decryption by second party
  75. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextAdd123], kp2.secretKey)
  76. # partial decryption by third party
  77. ciphertextPartial3 = cc.MultipartyDecryptMain([ciphertextAdd123], kp3.secretKey)
  78. partialCiphertextVec = []
  79. partialCiphertextVec.append(ciphertextPartial1[0])
  80. partialCiphertextVec.append(ciphertextPartial2[0])
  81. partialCiphertextVec.append(ciphertextPartial3[0])
  82. # partial decryption are combined together
  83. plaintextMultipartyNew = cc.MultipartyDecryptFusion(partialCiphertextVec)
  84. print("\n Original Plaintext: \n")
  85. print(plaintext1)
  86. print(plaintext2)
  87. print(plaintext3)
  88. plaintextMultipartyNew.SetLength(plaintext1.GetLength())
  89. print("\n Resulting Fused Plaintext adding 3 ciphertexts: \n")
  90. print(plaintextMultipartyNew)
  91. print("\n")
  92. def RunBFVrns():
  93. batchSize = 16
  94. parameters = CCParamsBFVRNS()
  95. parameters.SetPlaintextModulus(65537)
  96. parameters.SetBatchSize(batchSize)
  97. parameters.SetMultiplicativeDepth(2)
  98. ## NOISE_FLOODING_MULTIPARTY adds extra noise to the ciphertext before decrypting
  99. ## and is most secure mode of threshold FHE for BFV and BGV.
  100. parameters.SetMultipartyMode(NOISE_FLOODING_MULTIPARTY)
  101. cc = GenCryptoContext(parameters)
  102. cc.Enable(PKE)
  103. cc.Enable(KEYSWITCH)
  104. cc.Enable(LEVELEDSHE)
  105. cc.Enable(ADVANCEDSHE)
  106. cc.Enable(MULTIPARTY)
  107. ##########################################################
  108. # Set-up of parameters
  109. ##########################################################
  110. # Output the generated parameters
  111. print(f"p = {cc.GetPlaintextModulus()}")
  112. print(f"n = {cc.GetCyclotomicOrder()/2}")
  113. print(f"lo2 q = {log2(cc.GetModulus())}")
  114. ############################################################
  115. # Perform Key Generation Operation
  116. ############################################################
  117. print("Running key generation (used for source data)...")
  118. # Round 1 (party A)
  119. print("Round 1 (party A) started.")
  120. kp1 = cc.KeyGen()
  121. # Generate evalmult key part for A
  122. evalMultKey = cc.KeySwitchGen(kp1.secretKey, kp1.secretKey)
  123. # Generate evalsum key part for A
  124. cc.EvalSumKeyGen(kp1.secretKey)
  125. evalSumKeys = cc.GetEvalSumKeyMap(kp1.secretKey.GetKeyTag())
  126. print("Round 1 of key generation completed.")
  127. # Round 2 (party B)
  128. print("Round 2 (party B) started.")
  129. print("Joint public key for (s_a + s_b) is generated...")
  130. kp2 = cc.MultipartyKeyGen(kp1.publicKey)
  131. evalMultKey2 = cc.MultiKeySwitchGen(kp2.secretKey, kp2.secretKey, evalMultKey)
  132. print("Joint evaluation multiplication key for (s_a + s_b) is generated...")
  133. evalMultAB = cc.MultiAddEvalKeys(evalMultKey, evalMultKey2, kp2.publicKey.GetKeyTag())
  134. print("Joint evaluation multiplication key (s_a + s_b) is transformed into s_b*(s_a + s_b)...")
  135. evalMultBAB = cc.MultiMultEvalKey(kp2.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
  136. evalSumKeysB = cc.MultiEvalSumKeyGen(kp2.secretKey, evalSumKeys, kp2.publicKey.GetKeyTag())
  137. print("Joint evaluation summation key for (s_a + s_b) is generated...")
  138. evalSumKeysJoin = cc.MultiAddEvalSumKeys(evalSumKeys, evalSumKeysB, kp2.publicKey.GetKeyTag())
  139. cc.InsertEvalSumKey(evalSumKeysJoin)
  140. print("Round 2 of key generation completed.")
  141. print("Round 3 (party A) started.")
  142. print("Joint key (s_a + s_b) is transformed into s_a*(s_a + s_b)...")
  143. evalMultAAB = cc.MultiMultEvalKey(kp1.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
  144. print("Computing the final evaluation multiplication key for (s_a + s_b)*(s_a + s_b)...")
  145. evalMultFinal = cc.MultiAddEvalMultKeys(evalMultAAB, evalMultBAB, evalMultAB.GetKeyTag())
  146. cc.InsertEvalMultKey([evalMultFinal])
  147. print("Round 3 of key generation completed.")
  148. ############################################################
  149. ## Encode source data
  150. ############################################################
  151. vectorOfInts1 = [1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1, 0]
  152. vectorOfInts2 = [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]
  153. vectorOfInts3 = [2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0]
  154. plaintext1 = cc.MakePackedPlaintext(vectorOfInts1)
  155. plaintext2 = cc.MakePackedPlaintext(vectorOfInts2)
  156. plaintext3 = cc.MakePackedPlaintext(vectorOfInts3)
  157. ############################################################
  158. ## Encryption
  159. ############################################################
  160. ciphertext1 = cc.Encrypt(kp2.publicKey, plaintext1)
  161. ciphertext2 = cc.Encrypt(kp2.publicKey, plaintext2)
  162. ciphertext3 = cc.Encrypt(kp2.publicKey, plaintext3)
  163. ############################################################
  164. ## Homomorphic Operations
  165. ############################################################
  166. ciphertextAdd12 = cc.EvalAdd(ciphertext1, ciphertext2)
  167. ciphertextAdd123 = cc.EvalAdd(ciphertextAdd12, ciphertext3)
  168. ciphertextMult = cc.EvalMult(ciphertext1, ciphertext3)
  169. ciphertextEvalSum = cc.EvalSum(ciphertext3, batchSize)
  170. ############################################################
  171. ## Decryption after Accumulation Operation on Encrypted Data with Multiparty
  172. ############################################################
  173. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextAdd123], kp1.secretKey)
  174. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextAdd123], kp2.secretKey)
  175. partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0]]
  176. plaintextMultipartyNew = cc.MultipartyDecryptFusion(partialCiphertextVec)
  177. print("\n Original Plaintext: \n")
  178. print(plaintext1)
  179. print(plaintext2)
  180. print(plaintext3)
  181. plaintextMultipartyNew.SetLength(plaintext1.GetLength())
  182. print("\n Resulting Fused Plaintext: \n")
  183. print(plaintextMultipartyNew)
  184. print("\n")
  185. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextMult], kp1.secretKey)
  186. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextMult], kp2.secretKey)
  187. partialCiphertextVecMult = [ciphertextPartial1[0], ciphertextPartial2[0]]
  188. plaintextMultipartyMult = cc.MultipartyDecryptFusion(partialCiphertextVecMult)
  189. plaintextMultipartyMult.SetLength(plaintext1.GetLength())
  190. print("\n Resulting Fused Plaintext after Multiplication of plaintexts 1 and 3: \n")
  191. print(plaintextMultipartyMult)
  192. print("\n")
  193. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextEvalSum], kp1.secretKey)
  194. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextEvalSum], kp2.secretKey)
  195. partialCiphertextVecEvalSum = [ciphertextPartial1[0], ciphertextPartial2[0]]
  196. plaintextMultipartyEvalSum = cc.MultipartyDecryptFusion(partialCiphertextVecEvalSum)
  197. plaintextMultipartyEvalSum.SetLength(plaintext1.GetLength())
  198. print("\n Fused result after summation of ciphertext 3: \n")
  199. print(plaintextMultipartyEvalSum)
  200. def RunCKKS():
  201. batchSize = 16
  202. parameters = CCParamsCKKSRNS()
  203. parameters.SetMultiplicativeDepth(3)
  204. parameters.SetScalingModSize(50)
  205. parameters.SetBatchSize(batchSize)
  206. cc = GenCryptoContext(parameters)
  207. # Enable features you wish to use
  208. cc.Enable(PKE)
  209. cc.Enable(KEYSWITCH)
  210. cc.Enable(LEVELEDSHE)
  211. cc.Enable(ADVANCEDSHE)
  212. cc.Enable(MULTIPARTY)
  213. ##########################################################
  214. # Set-up of parameters
  215. ##########################################################
  216. # Output the generated parameters
  217. print(f"p = {cc.GetPlaintextModulus()}")
  218. print(f"n = {cc.GetCyclotomicOrder()/2}")
  219. print(f"lo2 q = {log2(cc.GetModulus())}")
  220. ############################################################
  221. ## Perform Key Generation Operation
  222. ############################################################
  223. print("Running key generation (used for source data)...")
  224. # Round 1 (party A)
  225. print("Round 1 (party A) started.")
  226. kp1 = cc.KeyGen()
  227. # Generate evalmult key part for A
  228. evalMultKey = cc.KeySwitchGen(kp1.secretKey, kp1.secretKey)
  229. # Generate evalsum key part for A
  230. cc.EvalSumKeyGen(kp1.secretKey)
  231. evalSumKeys = cc.GetEvalSumKeyMap(kp1.secretKey.GetKeyTag())
  232. print("Round 1 of key generation completed.")
  233. # Round 2 (party B)
  234. print("Round 2 (party B) started.")
  235. print("Joint public key for (s_a + s_b) is generated...")
  236. kp2 = cc.MultipartyKeyGen(kp1.publicKey)
  237. evalMultKey2 = cc.MultiKeySwitchGen(kp2.secretKey, kp2.secretKey, evalMultKey)
  238. print("Joint evaluation multiplication key for (s_a + s_b) is generated...")
  239. evalMultAB = cc.MultiAddEvalKeys(evalMultKey, evalMultKey2, kp2.publicKey.GetKeyTag())
  240. print("Joint evaluation multiplication key (s_a + s_b) is transformed into s_b*(s_a + s_b)...")
  241. evalMultBAB = cc.MultiMultEvalKey(kp2.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
  242. evalSumKeysB = cc.MultiEvalSumKeyGen(kp2.secretKey, evalSumKeys, kp2.publicKey.GetKeyTag())
  243. print("Joint evaluation summation key for (s_a + s_b) is generated...")
  244. evalSumKeysJoin = cc.MultiAddEvalSumKeys(evalSumKeys, evalSumKeysB, kp2.publicKey.GetKeyTag())
  245. cc.InsertEvalSumKey(evalSumKeysJoin)
  246. print("Round 2 of key generation completed.")
  247. print("Round 3 (party A) started.")
  248. print("Joint key (s_a + s_b) is transformed into s_a*(s_a + s_b)...")
  249. evalMultAAB = cc.MultiMultEvalKey(kp1.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
  250. print("Computing the final evaluation multiplication key for (s_a + s_b)*(s_a + s_b)...")
  251. evalMultFinal = cc.MultiAddEvalMultKeys(evalMultAAB, evalMultBAB, evalMultAB.GetKeyTag())
  252. cc.InsertEvalMultKey([evalMultFinal])
  253. print("Round 3 of key generation completed.")
  254. ############################################################
  255. ## Encode source data
  256. ############################################################
  257. vectorOfInts1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
  258. vectorOfInts2 = [1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
  259. vectorOfInts3 = [2.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0]
  260. plaintext1 = cc.MakeCKKSPackedPlaintext(vectorOfInts1)
  261. plaintext2 = cc.MakeCKKSPackedPlaintext(vectorOfInts2)
  262. plaintext3 = cc.MakeCKKSPackedPlaintext(vectorOfInts3)
  263. ############################################################
  264. ## Encryption
  265. ############################################################
  266. ciphertext1 = cc.Encrypt(kp2.publicKey, plaintext1)
  267. ciphertext2 = cc.Encrypt(kp2.publicKey, plaintext2)
  268. ciphertext3 = cc.Encrypt(kp2.publicKey, plaintext3)
  269. ############################################################
  270. ## EvalAdd Operation on Re-Encrypted Data
  271. ############################################################
  272. ciphertextAdd12 = cc.EvalAdd(ciphertext1, ciphertext2)
  273. ciphertextAdd123 = cc.EvalAdd(ciphertextAdd12, ciphertext3)
  274. ciphertextMultTemp = cc.EvalMult(ciphertext1, ciphertext3)
  275. ciphertextMult = cc.ModReduce(ciphertextMultTemp)
  276. ciphertextEvalSum = cc.EvalSum(ciphertext3, batchSize)
  277. ############################################################
  278. ## Decryption after Accumulation Operation on Encrypted Data with Multiparty
  279. ############################################################
  280. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextAdd123], kp1.secretKey)
  281. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextAdd123], kp2.secretKey)
  282. partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0]]
  283. plaintextMultipartyNew = cc.MultipartyDecryptFusion(partialCiphertextVec)
  284. print("\n Original Plaintext: \n")
  285. print(plaintext1)
  286. print(plaintext2)
  287. print(plaintext3)
  288. plaintextMultipartyNew.SetLength(plaintext1.GetLength())
  289. print("\n Resulting Fused Plaintext: \n")
  290. print(plaintextMultipartyNew)
  291. print("\n")
  292. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextMult], kp1.secretKey)
  293. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextMult], kp2.secretKey)
  294. partialCiphertextVecMult = [ciphertextPartial1[0], ciphertextPartial2[0]]
  295. plaintextMultipartyMult = cc.MultipartyDecryptFusion(partialCiphertextVecMult)
  296. plaintextMultipartyMult.SetLength(plaintext1.GetLength())
  297. print("\n Resulting Fused Plaintext after Multiplication of plaintexts 1 and 3: \n")
  298. print(plaintextMultipartyMult)
  299. print("\n")
  300. ciphertextPartial1 = cc.MultipartyDecryptLead([ciphertextEvalSum], kp1.secretKey)
  301. ciphertextPartial2 = cc.MultipartyDecryptMain([ciphertextEvalSum], kp2.secretKey)
  302. partialCiphertextVecEvalSum = [ciphertextPartial1[0], ciphertextPartial2[0]]
  303. plaintextMultipartyEvalSum = cc.MultipartyDecryptFusion(partialCiphertextVecEvalSum)
  304. plaintextMultipartyEvalSum.SetLength(plaintext1.GetLength())
  305. print("\n Fused result after the Summation of ciphertext 3: \n")
  306. print(plaintextMultipartyEvalSum)
  307. if __name__ == '__main__':
  308. main()