tckks-interactive-mp-bootstrapping-Chebyschev.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. from openfhe import *
  2. def main():
  3. print("Interactive (3P) Bootstrapping Ciphertext [Chebyshev] (TCKKS) started ...")
  4. # Same test with different rescaling techniques in CKKS
  5. TCKKSCollectiveBoot(FIXEDMANUAL)
  6. TCKKSCollectiveBoot(FIXEDAUTO)
  7. if get_native_int()!=128:
  8. TCKKSCollectiveBoot(FLEXIBLEAUTO)
  9. TCKKSCollectiveBoot(FLEXIBLEAUTOEXT)
  10. print("Interactive (3P) Bootstrapping Ciphertext [Chebyshev] (TCKKS) terminated gracefully!")
  11. def checkApproximateEquality(a, b, vectorSize, epsilon):
  12. allTrue = [1] * vectorSize
  13. tmp = [abs(a[i] - b[i]) <= epsilon for i in range(vectorSize)]
  14. if tmp != allTrue:
  15. print("IntMPBoot - Ctxt Chebyshev Failed:")
  16. print(f"- is diff <= eps?: {tmp}")
  17. else:
  18. print("SUCCESSFUL Bootstrapping!")
  19. def TCKKSCollectiveBoot(scaleTech):
  20. if scaleTech not in [FIXEDMANUAL, FIXEDAUTO, FLEXIBLEAUTO, FLEXIBLEAUTOEXT]:
  21. errMsg = "ERROR: Scaling technique is not supported!"
  22. raise Exception(errMsg)
  23. parameters = CCParamsCKKSRNS()
  24. secretKeyDist = UNIFORM_TERNARY
  25. parameters.SetSecretKeyDist(secretKeyDist)
  26. parameters.SetSecurityLevel(HEStd_128_classic)
  27. dcrtBits = 50
  28. firstMod = 60
  29. parameters.SetScalingModSize(dcrtBits)
  30. parameters.SetScalingTechnique(scaleTech)
  31. parameters.SetFirstModSize(firstMod)
  32. multiplicativeDepth = 10 # Adjust according to your requirements
  33. parameters.SetMultiplicativeDepth(multiplicativeDepth)
  34. parameters.SetKeySwitchTechnique(HYBRID)
  35. batchSize = 16 # Adjust batch size if needed
  36. parameters.SetBatchSize(batchSize)
  37. compressionLevel = COMPRESSION_LEVEL.COMPACT # or COMPRESSION_LEVEL.SLACK
  38. parameters.SetInteractiveBootCompressionLevel(compressionLevel)
  39. cryptoContext = GenCryptoContext(parameters)
  40. cryptoContext.Enable(PKE)
  41. cryptoContext.Enable(KEYSWITCH)
  42. cryptoContext.Enable(LEVELEDSHE)
  43. cryptoContext.Enable(ADVANCEDSHE)
  44. cryptoContext.Enable(MULTIPARTY)
  45. ringDim = cryptoContext.GetRingDimension()
  46. maxNumSlots = ringDim // 2
  47. print(f"TCKKS scheme is using ring dimension {ringDim}")
  48. print(f"TCKKS scheme number of slots {batchSize}")
  49. print(f"TCKKS scheme max number of slots {maxNumSlots}")
  50. print(f"TCKKS example with Scaling Technique {scaleTech}")
  51. numParties = 3
  52. print("\n===========================IntMPBoot protocol parameters===========================\n")
  53. print(f"number of parties: {numParties}\n")
  54. print("===============================================================\n")
  55. # Round 1 (party A)
  56. kp1 = cryptoContext.KeyGen()
  57. # Generate evalmult key part for A
  58. evalMultKey = cryptoContext.KeySwitchGen(kp1.secretKey, kp1.secretKey)
  59. # Generate evalsum key part for A
  60. cryptoContext.EvalSumKeyGen(kp1.secretKey)
  61. evalSumKeys = cryptoContext.GetEvalSumKeyMap(kp1.secretKey.GetKeyTag())
  62. # Round 2 (party B)
  63. kp2 = cryptoContext.MultipartyKeyGen(kp1.publicKey)
  64. evalMultKey2 = cryptoContext.MultiKeySwitchGen(kp2.secretKey, kp2.secretKey, evalMultKey)
  65. evalMultAB = cryptoContext.MultiAddEvalKeys(evalMultKey, evalMultKey2, kp2.publicKey.GetKeyTag())
  66. evalMultBAB = cryptoContext.MultiMultEvalKey(kp2.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
  67. evalSumKeysB = cryptoContext.MultiEvalSumKeyGen(kp2.secretKey, evalSumKeys, kp2.publicKey.GetKeyTag())
  68. evalSumKeysJoin = cryptoContext.MultiAddEvalSumKeys(evalSumKeys, evalSumKeysB, kp2.publicKey.GetKeyTag())
  69. cryptoContext.InsertEvalSumKey(evalSumKeysJoin)
  70. evalMultAAB = cryptoContext.MultiMultEvalKey(kp1.secretKey, evalMultAB, kp2.publicKey.GetKeyTag())
  71. evalMultFinal = cryptoContext.MultiAddEvalMultKeys(evalMultAAB, evalMultBAB, evalMultAB.GetKeyTag())
  72. cryptoContext.InsertEvalMultKey([evalMultFinal])
  73. # Round 3 (party C) - Lead Party (who encrypts and finalizes the bootstrapping protocol)
  74. kp3 = cryptoContext.MultipartyKeyGen(kp2.publicKey)
  75. evalMultKey3 = cryptoContext.MultiKeySwitchGen(kp3.secretKey, kp3.secretKey, evalMultKey)
  76. evalMultABC = cryptoContext.MultiAddEvalKeys(evalMultAB, evalMultKey3, kp3.publicKey.GetKeyTag())
  77. evalMultBABC = cryptoContext.MultiMultEvalKey(kp2.secretKey, evalMultABC, kp3.publicKey.GetKeyTag())
  78. evalMultAABC = cryptoContext.MultiMultEvalKey(kp1.secretKey, evalMultABC, kp3.publicKey.GetKeyTag())
  79. evalMultCABC = cryptoContext.MultiMultEvalKey(kp3.secretKey, evalMultABC, kp3.publicKey.GetKeyTag())
  80. evalMultABABC = cryptoContext.MultiAddEvalMultKeys(evalMultBABC, evalMultAABC, evalMultBABC.GetKeyTag())
  81. evalMultFinal2 = cryptoContext.MultiAddEvalMultKeys(evalMultABABC, evalMultCABC, evalMultCABC.GetKeyTag())
  82. cryptoContext.InsertEvalMultKey([evalMultFinal2])
  83. if not kp1.good():
  84. print("Key generation failed!")
  85. exit(1)
  86. if not kp2.good():
  87. print("Key generation failed!")
  88. exit(1)
  89. if not kp3.good():
  90. print("Key generation failed!")
  91. exit(1)
  92. # END of Key Generation
  93. input = [-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
  94. # Chebyshev coefficients
  95. coefficients = [1.0, 0.558971, 0.0, -0.0943712, 0.0, 0.0215023, 0.0, -0.00505348, 0.0, 0.00119324,
  96. 0.0, -0.000281928, 0.0, 0.0000664347, 0.0, -0.0000148709]
  97. # Input range
  98. a = -4
  99. b = 4
  100. pt1 = cryptoContext.MakeCKKSPackedPlaintext(input)
  101. encodedLength = len(input)
  102. ct1 = cryptoContext.Encrypt(kp3.publicKey, pt1)
  103. ct1 = cryptoContext.EvalChebyshevSeries(ct1, coefficients, a, b)
  104. # INTERACTIVE BOOTSTRAPPING STARTS
  105. ct1 = cryptoContext.IntMPBootAdjustScale(ct1)
  106. # Leading party (party B) generates a Common Random Poly (crp) at max coefficient modulus (QNumPrime).
  107. # a is sampled at random uniformly from R_{Q}
  108. crp = cryptoContext.IntMPBootRandomElementGen(kp3.publicKey)
  109. # Each party generates its own shares: maskedDecryptionShare and reEncryptionShare
  110. # (h_{0,i}, h_{1,i}) = (masked decryption share, re-encryption share)
  111. # extract c1 - element-wise
  112. c1 = ct1.Clone()
  113. c1.RemoveElement(0)
  114. sharesPair0 = cryptoContext.IntMPBootDecrypt(kp1.secretKey, c1, crp)
  115. sharesPair1 = cryptoContext.IntMPBootDecrypt(kp2.secretKey, c1, crp)
  116. sharesPair2 = cryptoContext.IntMPBootDecrypt(kp3.secretKey, c1, crp)
  117. sharesPairVec = [sharesPair0, sharesPair1, sharesPair2]
  118. # Party B finalizes the protocol by aggregating the shares and reEncrypting the results
  119. aggregatedSharesPair = cryptoContext.IntMPBootAdd(sharesPairVec)
  120. ciphertextOutput = cryptoContext.IntMPBootEncrypt(kp3.publicKey, aggregatedSharesPair, crp, ct1)
  121. # INTERACTIVE BOOTSTRAPPING ENDS
  122. # distributed decryption
  123. ciphertextPartial1 = cryptoContext.MultipartyDecryptMain([ciphertextOutput], kp1.secretKey)
  124. ciphertextPartial2 = cryptoContext.MultipartyDecryptMain([ciphertextOutput], kp2.secretKey)
  125. ciphertextPartial3 = cryptoContext.MultipartyDecryptLead([ciphertextOutput], kp3.secretKey)
  126. partialCiphertextVec = [ciphertextPartial1[0], ciphertextPartial2[0], ciphertextPartial3[0]]
  127. plaintextMultiparty = cryptoContext.MultipartyDecryptFusion(partialCiphertextVec)
  128. plaintextMultiparty.SetLength(encodedLength)
  129. # Ground truth result
  130. result = [0.0179885, 0.0474289, 0.119205, 0.268936, 0.5, 0.731064, 0.880795, 0.952571, 0.982011]
  131. plaintextResult = cryptoContext.MakeCKKSPackedPlaintext(result)
  132. print("Ground Truth:")
  133. print("\t", plaintextResult.GetCKKSPackedValue())
  134. print("Computed Result:")
  135. print("\t", plaintextMultiparty.GetCKKSPackedValue())
  136. checkApproximateEquality(plaintextResult.GetCKKSPackedValue(), plaintextMultiparty.GetCKKSPackedValue(), encodedLength, 0.0001)
  137. print("\n============================ INTERACTIVE DECRYPTION ENDED ============================")
  138. print(f"\nTCKKSCollectiveBoot FHE example with rescaling technique: {scaleTech} Completed!")
  139. if __name__ == "__main__":
  140. main()