# tckks-interactive-mp-bootstrapping.py
  1. from openfhe import *
  2. #
  3. # A utility class defining a party that is involved in the collective bootstrapping protocol
  4. #
  5. class Party:
  6. def __init__(self, id, sharesPair, kpShard):
  7. self.id = id
  8. self.sharesPair = sharesPair
  9. self.kpShard = kpShard
  10. def __init__(self):
  11. self.id = None
  12. self.sharesPair = None
  13. self.kpShard = None
  14. def __str__(self):
  15. return f"Party {self.id}"
  16. def main():
  17. print( "Interactive Multi-Party Bootstrapping Ciphertext (TCKKS) started ...\n")
  18. # Same test with different rescaling techniques in CKKS
  19. TCKKSCollectiveBoot(FIXEDMANUAL)
  20. TCKKSCollectiveBoot(FIXEDAUTO)
  21. TCKKSCollectiveBoot(FLEXIBLEAUTO)
  22. TCKKSCollectiveBoot(FLEXIBLEAUTOEXT)
  23. print("Interactive Multi-Party Bootstrapping Ciphertext (TCKKS) terminated gracefully!\n")
  24. # Demonstrate interactive multi-party bootstrapping for 3 parties
  25. # We follow Protocol 5 in https://eprint.iacr.org/2020/304, "Multiparty
  26. # Homomorphic Encryption from Ring-Learning-With-Errors"
  27. def TCKKSCollectiveBoot(scaleTech):
  28. if scaleTech != FIXEDMANUAL and scaleTech != FIXEDAUTO and scaleTech != FLEXIBLEAUTO and scaleTech != FLEXIBLEAUTOEXT:
  29. errMsg = "ERROR: Scaling technique is not supported!"
  30. raise Exception(errMsg)
  31. parameters = CCParamsCKKSRNS()
  32. secretKeyDist = UNIFORM_TERNARY
  33. parameters.SetSecretKeyDist(secretKeyDist)
  34. parameters.SetSecurityLevel(HEStd_128_classic)
  35. dcrtBits = 50
  36. firstMod = 60
  37. parameters.SetScalingModSize(dcrtBits)
  38. parameters.SetScalingTechnique(scaleTech)
  39. parameters.SetFirstModSize(firstMod)
  40. multiplicativeDepth = 7
  41. parameters.SetMultiplicativeDepth(multiplicativeDepth)
  42. parameters.SetKeySwitchTechnique(HYBRID)
  43. batchSize = 4
  44. parameters.SetBatchSize(batchSize)
  45. compressionLevel = COMPRESSION_LEVEL.SLACK
  46. parameters.SetInteractiveBootCompressionLevel(compressionLevel)
  47. cryptoContext = GenCryptoContext(parameters)
  48. cryptoContext.Enable(PKE)
  49. cryptoContext.Enable(KEYSWITCH)
  50. cryptoContext.Enable(LEVELEDSHE)
  51. cryptoContext.Enable(ADVANCEDSHE)
  52. cryptoContext.Enable(MULTIPARTY)
  53. ringDim = cryptoContext.GetRingDimension()
  54. maxNumSlots = ringDim / 2
  55. print(f"TCKKS scheme is using ring dimension {ringDim}")
  56. print(f"TCKKS scheme number of slots {maxNumSlots}")
  57. print(f"TCKKS scheme max number of slots {maxNumSlots}")
  58. print(f"TCKKS example with Scaling Technique {scaleTech}")
  59. numParties = 3
  60. print("\n===========================IntMPBoot protocol parameters===========================\n")
  61. print(f"number of parties: {numParties}\n")
  62. print("===============================================================\n")
  63. # List to store parties objects
  64. parties = [Party()]*numParties
  65. print("Running key generation (used for source data)...\n")
  66. for i in range(numParties):
  67. #define id of parties[i] as i
  68. parties[i].id = i
  69. print(f"Party {parties[i].id} started.")
  70. if i == 0:
  71. parties[i].kpShard = cryptoContext.KeyGen()
  72. else:
  73. parties[i].kpShard = cryptoContext.MultipartyKeyGen(parties[0].kpShard.publicKey)
  74. print(f"Party {i} key generation completed.\n")
  75. print("Joint public key for (s_0 + s_1 + ... + s_n) is generated...")
  76. # Assert everything is good
  77. for i in range(numParties):
  78. if not parties[i].kpShard.good():
  79. print(f"Key generation failed for party {i}!\n")
  80. return 1
  81. # Generate collective public key
  82. secretKeys = []
  83. for i in range(numParties):
  84. secretKeys.append(parties[i].kpShard.secretKey)
  85. kpMultiparty = cryptoContext.MultipartyKeyGen(secretKeys)
  86. # Prepare input vector
  87. msg1 = [-0.9, -0.8, 0.2, 0.4]
  88. ptxt1 = cryptoContext.MakeCKKSPackedPlaintext(msg1)
  89. # Encryption
  90. inCtxt = cryptoContext.Encrypt(kpMultiparty.publicKey, ptxt1)
  91. print("Compressing ctxt to the smallest possible number of towers!\n")
  92. inCtxt = cryptoContext.IntMPBootAdjustScale(inCtxt)
  93. print("\n============================ INTERACTIVE BOOTSTRAPPING STARTS ============================\n")
  94. #Leading party (P0) generates a Common Random Poly (a) at max coefficient modulus (QNumPrime).
  95. # a is sampled at random uniformly from R_{Q}
  96. a = cryptoContext.IntMPBootRandomElementGen(parties[0].kpShard.publicKey)
  97. print("Common Random Poly (a) has been generated with coefficient modulus Q\n")
  98. # Each party generates its own shares: maskedDecryptionShare and reEncryptionShare
  99. sharePairVec = []
  100. # Make a copy of input ciphertext and remove the first element (c0), we only
  101. # c1 for IntMPBootDecrypt
  102. c1 = inCtxt.Clone()
  103. c1.RemoveElement(0)
  104. for i in range(numParties):
  105. print(f"Party {i} started its part in Collective Bootstrapping Protocol.\n")
  106. parties[i].sharesPair = cryptoContext.IntMPBootDecrypt(parties[i].kpShard.secretKey, c1, a)
  107. sharePairVec.append(parties[i].sharesPair)
  108. # P0 finalizes the protocol by aggregating the shares and reEncrypting the results
  109. aggregatedSharesPair = cryptoContext.IntMPBootAdd(sharePairVec);
  110. # Make sure you provide the non-striped ciphertext (inCtxt) in IntMPBootEncrypt
  111. outCtxt = cryptoContext.IntMPBootEncrypt(parties[0].kpShard.publicKey, aggregatedSharesPair, a, inCtxt)
  112. # INTERACTIVE BOOTSTRAPPING ENDS
  113. print("\n============================ INTERACTIVE BOOTSTRAPPING ENDED ============================\n")
  114. # Distributed Decryption
  115. print("\n============================ INTERACTIVE DECRYPTION STARTED ============================ \n")
  116. partialCiphertextVec = []
  117. print("Party 0 started its part in the collective decryption protocol\n")
  118. partialCiphertextVec.append(cryptoContext.MultipartyDecryptLead([outCtxt], parties[0].kpShard.secretKey)[0])
  119. for i in range(1, numParties):
  120. print(f"Party {i} started its part in the collective decryption protocol\n")
  121. partialCiphertextVec.append(cryptoContext.MultipartyDecryptMain([outCtxt], parties[i].kpShard.secretKey)[0])
  122. # Checking the results
  123. print("MultipartyDecryptFusion ...\n")
  124. plaintextMultiparty = cryptoContext.MultipartyDecryptFusion(partialCiphertextVec)
  125. plaintextMultiparty.SetLength(len(msg1))
  126. # transform to python:
  127. print(f"Original plaintext \n\t {ptxt1.GetCKKSPackedValue()}\n")
  128. print(f"Result after bootstrapping \n\t {plaintextMultiparty.GetCKKSPackedValue()}\n")
  129. print("\n============================ INTERACTIVE DECRYPTION ENDED ============================\n")
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()