# tckks-interactive-mp-bootstrapping-Chebyschev.py (7.7 KB)
# OpenFHE example: interactive (3-party) bootstrapping of a threshold CKKS (TCKKS)
# ciphertext after evaluating a Chebyshev series on it.
import sys

from openfhe import *
  2. def main():
  3. print("Interactive (3P) Bootstrapping Ciphertext [Chebyshev] (TCKKS) started ...")
  4. # Same test with different rescaling techniques in CKKS
  5. TCKKSCollectiveBoot(FIXEDMANUAL)
  6. TCKKSCollectiveBoot(FIXEDAUTO)
  7. TCKKSCollectiveBoot(FLEXIBLEAUTO)
  8. TCKKSCollectiveBoot(FLEXIBLEAUTOEXT)
  9. print("Interactive (3P) Bootstrapping Ciphertext [Chebyshev] (TCKKS) terminated gracefully!")
  10. def checkApproximateEquality(a, b, vectorSize, epsilon):
  11. allTrue = [1] * vectorSize
  12. tmp = [abs(a[i] - b[i]) <= epsilon for i in range(vectorSize)]
  13. if tmp != allTrue:
  14. print("IntMPBoot - Ctxt Chebyshev Failed:")
  15. print(f"- is diff <= eps?: {tmp}")
  16. else:
  17. print("SUCCESSFUL Bootstrapping!")
  18. def TCKKSCollectiveBoot(scaleTech):
  19. if scaleTech not in [FIXEDMANUAL, FIXEDAUTO, FLEXIBLEAUTO, FLEXIBLEAUTOEXT]:
  20. errMsg = "ERROR: Scaling technique is not supported!"
  21. raise Exception(errMsg)
  22. parameters = CCParamsCKKSRNS()
  23. secretKeyDist = UNIFORM_TERNARY
  24. parameters.SetSecretKeyDist(secretKeyDist)
  25. parameters.SetSecurityLevel(HEStd_128_classic)
  26. dcrtBits = 50
  27. firstMod = 60
  28. parameters.SetScalingModSize(dcrtBits)
  29. parameters.SetScalingTechnique(scaleTech)
  30. parameters.SetFirstModSize(firstMod)
  31. multiplicativeDepth = 10 # Adjust according to your requirements
  32. parameters.SetMultiplicativeDepth(multiplicativeDepth)
  33. parameters.SetKeySwitchTechnique(HYBRID)
  34. batchSize = 16 # Adjust batch size if needed
  35. parameters.SetBatchSize(batchSize)
  36. compressionLevel = COMPRESSION_LEVEL.COMPACT # or COMPRESSION_LEVEL.SLACK
  37. parameters.SetInteractiveBootCompressionLevel(compressionLevel)
  38. cryptoContext = GenCryptoContext(parameters)
  39. cryptoContext.Enable(PKE)
  40. cryptoContext.Enable(KEYSWITCH)
  41. cryptoContext.Enable(LEVELEDSHE)
  42. cryptoContext.Enable(ADVANCEDSHE)
  43. cryptoContext.Enable(MULTIPARTY)
  44. ringDim = cryptoContext.GetRingDimension()
  45. maxNumSlots = ringDim // 2
  46. print(f"TCKKS scheme is using ring dimension {ringDim}")
  47. print(f"TCKKS scheme number of slots {batchSize}")
  48. print(f"TCKKS scheme max number of slots {maxNumSlots}")
  49. print(f"TCKKS example with Scaling Technique {scaleTech}")
  50. numParties = 3
  51. print("\n===========================IntMPBoot protocol parameters===========================\n")
  52. print(f"number of parties: {numParties}\n")
  53. print("===============================================================\n")
  54. # Round 1 (party A)
  55. kp1 = cryptoContext.KeyGen()
  56. # Generate evalmult key part for A
  57. evalMultKey = cryptoContext.KeySwitchGen(kp1.secretKey, kp1.secretKey)
  58. # Generate evalsum key part for A
  59. cryptoContext.EvalSumKeyGen(kp1.secretKey)
  60. evalSumKeys = cryptoContext.GetEvalSumKeyMap(kp1.secretKey.GetKeyTag())
  61. # Round 2 (party B)
  62. kp2 = cryptoContext.MultipartyKeyGen(kp1.publicKey)
  63. evalMultKey2 = cryptoContext.MultiKeySwitchGen(kp2.secretKey, kp2.secretKey, evalMultKey)
  64. evalMultAB = cryptoContext.MultiAddEvalKeys(evalMultKey, evalMultKey2, kp2.publicKey.GetKeyTag())
  65. evalMultBAB = cryptoContext.MultiMultEvalKey(kp2.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
  66. evalSumKeysB = cryptoContext.MultiEvalSumKeyGen(kp2.secretKey, evalSumKeys, kp2.publicKey.GetKeyTag())
  67. evalSumKeysJoin = cryptoContext.MultiAddEvalSumKeys(evalSumKeys, evalSumKeysB, kp2.publicKey.GetKeyTag())
  68. cryptoContext.InsertEvalSumKey(evalSumKeysJoin)
  69. evalMultAAB = cryptoContext.MultiMultEvalKey(kp1.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
  70. evalMultFinal = cryptoContext.MultiAddEvalMultKeys(evalMultAAB, evalMultBAB, evalMultAB.GetKeyTag())
  71. cryptoContext.InsertEvalMultKey([evalMultFinal])
  72. # Round 3 (party C) - Lead Party (who encrypts and finalizes the bootstrapping protocol)
  73. kp3 = cryptoContext.MultipartyKeyGen(kp2.publicKey)
  74. evalMultKey3 = cryptoContext.MultiKeySwitchGen(kp3.secretKey, kp3.secretKey, evalMultKey)
  75. evalMultABC = cryptoContext.MultiAddEvalKeys(evalMultAB, evalMultKey3, kp3.publicKey.GetKeyTag())
  76. evalMultBABC = cryptoContext.MultiMultEvalKey(kp2.secretKey, evalMultABC, kp3.publicKey.GetKeyTag())
  77. evalMultAABC = cryptoContext.MultiMultEvalKey(kp1.secretKey, evalMultABC, kp3.publicKey.GetKeyTag())
  78. evalMultCABC = cryptoContext.MultiMultEvalKey(kp3.secretKey, evalMultABC, kp3.publicKey.GetKeyTag())
  79. evalMultABABC = cryptoContext.MultiAddEvalMultKeys(evalMultBABC, evalMultAABC, evalMultBABC.GetKeyTag())
  80. evalMultFinal2 = cryptoContext.MultiAddEvalMultKeys(evalMultABABC, evalMultCABC, evalMultCABC.GetKeyTag())
  81. cryptoContext.InsertEvalMultKey([evalMultFinal2])
  82. if not kp1.good():
  83. print("Key generation failed!")
  84. exit(1)
  85. if not kp2.good():
  86. print("Key generation failed!")
  87. exit(1)
  88. if not kp3.good():
  89. print("Key generation failed!")
  90. exit(1)
  91. # END of Key Generation
  92. input = [-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
  93. # Chebyshev coefficients
  94. coefficients = [1.0, 0.558971, 0.0, -0.0943712, 0.0, 0.0215023, 0.0, -0.00505348, 0.0, 0.00119324,
  95. 0.0, -0.000281928, 0.0, 0.0000664347, 0.0, -0.0000148709]
  96. # Input range
  97. a = -4
  98. b = 4
  99. pt1 = cryptoContext.MakeCKKSPackedPlaintext(input)
  100. encodedLength = len(input)
  101. ct1 = cryptoContext.Encrypt(kp3.publicKey, pt1)
  102. ct1 = cryptoContext.EvalChebyshevSeries(ct1, coefficients, a, b)
  103. # INTERACTIVE BOOTSTRAPPING STARTS
  104. ct1 = cryptoContext.IntMPBootAdjustScale(ct1)
  105. # Leading party (party B) generates a Common Random Poly (crp) at max coefficient modulus (QNumPrime).
  106. # a is sampled at random uniformly from R_{Q}
  107. crp = cryptoContext.IntMPBootRandomElementGen(kp3.publicKey)
  108. # Each party generates its own shares: maskedDecryptionShare and reEncryptionShare
  109. # (h_{0,i}, h_{1,i}) = (masked decryption share, re-encryption share)
  110. # extract c1 - element-wise
  111. c1 = ct1.Clone()
  112. c1.RemoveElement(0)
  113. sharesPair0 = cryptoContext.IntMPBootDecrypt(kp1.secretKey, c1, crp)
  114. sharesPair1 = cryptoContext.IntMPBootDecrypt(kp2.secretKey, c1, crp)
  115. sharesPair2 = cryptoContext.IntMPBootDecrypt(kp3.secretKey, c1, crp)
  116. sharesPairVec = [sharesPair0, sharesPair1, sharesPair2]
  117. # Party B finalizes the protocol by aggregating the shares and reEncrypting the results
  118. aggregatedSharesPair = cryptoContext.IntMPBootAdd(sharesPairVec)
  119. ciphertextOutput = cryptoContext.IntMPBootEncrypt(kp3.publicKey, aggregatedSharesPair, crp, ct1)
  120. # INTERACTIVE BOOTSTRAPPING ENDS
  121. # distributed decryption
  122. ciphertextPartial1 = cryptoContext.MultipartyDecryptMain([ciphertextOutput], kp1.secretKey)
  123. ciphertextPartial2 = cryptoContext.MultipartyDecryptMain([ciphertextOutput], kp2.secretKey)
  124. ciphertextPartial3 = cryptoContext.MultipartyDecryptLead([ciphertextOutput], kp3.secretKey)
  125. partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0], ciphertextPartial3[0]]
  126. plaintextMultiparty = cryptoContext.MultipartyDecryptFusion(partialCiphertextVec)
  127. plaintextMultiparty.SetLength(encodedLength)
  128. # Ground truth result
  129. result = [0.0179885, 0.0474289, 0.119205, 0.268936, 0.5, 0.731064, 0.880795, 0.952571, 0.982011]
  130. plaintextResult = cryptoContext.MakeCKKSPackedPlaintext(result)
  131. print("Ground Truth:")
  132. print("\t", plaintextResult.GetCKKSPackedValue())
  133. print("Computed Result:")
  134. print("\t", plaintextMultiparty.GetCKKSPackedValue())
  135. checkApproximateEquality(plaintextResult.GetCKKSPackedValue(), plaintextMultiparty.GetCKKSPackedValue(), encodedLength, 0.0001)
  136. print("\n============================ INTERACTIVE DECRYPTION ENDED ============================")
  137. print(f"\nTCKKSCollectiveBoot FHE example with rescaling technique: {scaleTech} Completed!")
  138. if __name__ == "__main__":
  139. main()