tckks-interactive-mp-bootstrapping.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. from openfhe import *
  2. #
  3. # A utility class defining a party that is involved in the collective bootstrapping protocol
  4. #
  5. class Party:
  6. def __init__(self, id, sharesPair, kpShard):
  7. self.id = id
  8. self.sharesPair = sharesPair
  9. self.kpShard = kpShard
  10. def __init__(self):
  11. self.id = None
  12. self.sharesPair = None
  13. self.kpShard = None
  14. def __str__(self):
  15. return f"Party {self.id}"
  16. def main():
  17. print( "Interactive Multi-Party Bootstrapping Ciphertext (TCKKS) started ...\n")
  18. # Same test with different rescaling techniques in CKKS
  19. TCKKSCollectiveBoot(FIXEDMANUAL)
  20. TCKKSCollectiveBoot(FIXEDAUTO)
  21. if get_native_int()!=128:
  22. TCKKSCollectiveBoot(FLEXIBLEAUTO)
  23. TCKKSCollectiveBoot(FLEXIBLEAUTOEXT)
  24. print("Interactive Multi-Party Bootstrapping Ciphertext (TCKKS) terminated gracefully!\n")
  25. # Demonstrate interactive multi-party bootstrapping for 3 parties
  26. # We follow Protocol 5 in https://eprint.iacr.org/2020/304, "Multiparty
  27. # Homomorphic Encryption from Ring-Learning-With-Errors"
  28. def TCKKSCollectiveBoot(scaleTech):
  29. if scaleTech != FIXEDMANUAL and scaleTech != FIXEDAUTO and scaleTech != FLEXIBLEAUTO and scaleTech != FLEXIBLEAUTOEXT:
  30. errMsg = "ERROR: Scaling technique is not supported!"
  31. raise Exception(errMsg)
  32. parameters = CCParamsCKKSRNS()
  33. secretKeyDist = UNIFORM_TERNARY
  34. parameters.SetSecretKeyDist(secretKeyDist)
  35. parameters.SetSecurityLevel(HEStd_128_classic)
  36. dcrtBits = 50
  37. firstMod = 60
  38. parameters.SetScalingModSize(dcrtBits)
  39. parameters.SetScalingTechnique(scaleTech)
  40. parameters.SetFirstModSize(firstMod)
  41. multiplicativeDepth = 7
  42. parameters.SetMultiplicativeDepth(multiplicativeDepth)
  43. parameters.SetKeySwitchTechnique(HYBRID)
  44. batchSize = 4
  45. parameters.SetBatchSize(batchSize)
  46. compressionLevel = COMPRESSION_LEVEL.SLACK
  47. parameters.SetInteractiveBootCompressionLevel(compressionLevel)
  48. cryptoContext = GenCryptoContext(parameters)
  49. cryptoContext.Enable(PKE)
  50. cryptoContext.Enable(KEYSWITCH)
  51. cryptoContext.Enable(LEVELEDSHE)
  52. cryptoContext.Enable(ADVANCEDSHE)
  53. cryptoContext.Enable(MULTIPARTY)
  54. ringDim = cryptoContext.GetRingDimension()
  55. maxNumSlots = ringDim / 2
  56. print(f"TCKKS scheme is using ring dimension {ringDim}")
  57. print(f"TCKKS scheme number of slots {maxNumSlots}")
  58. print(f"TCKKS scheme max number of slots {maxNumSlots}")
  59. print(f"TCKKS example with Scaling Technique {scaleTech}")
  60. numParties = 3
  61. print("\n===========================IntMPBoot protocol parameters===========================\n")
  62. print(f"number of parties: {numParties}\n")
  63. print("===============================================================\n")
  64. # List to store parties objects
  65. parties = [Party()]*numParties
  66. print("Running key generation (used for source data)...\n")
  67. for i in range(numParties):
  68. #define id of parties[i] as i
  69. parties[i].id = i
  70. print(f"Party {parties[i].id} started.")
  71. if i == 0:
  72. parties[i].kpShard = cryptoContext.KeyGen()
  73. else:
  74. parties[i].kpShard = cryptoContext.MultipartyKeyGen(parties[0].kpShard.publicKey)
  75. print(f"Party {i} key generation completed.\n")
  76. print("Joint public key for (s_0 + s_1 + ... + s_n) is generated...")
  77. # Assert everything is good
  78. for i in range(numParties):
  79. if not parties[i].kpShard.good():
  80. print(f"Key generation failed for party {i}!\n")
  81. return 1
  82. # Generate collective public key
  83. secretKeys = []
  84. for i in range(numParties):
  85. secretKeys.append(parties[i].kpShard.secretKey)
  86. kpMultiparty = cryptoContext.MultipartyKeyGen(secretKeys)
  87. # Prepare input vector
  88. msg1 = [-0.9, -0.8, 0.2, 0.4]
  89. ptxt1 = cryptoContext.MakeCKKSPackedPlaintext(msg1)
  90. # Encryption
  91. inCtxt = cryptoContext.Encrypt(kpMultiparty.publicKey, ptxt1)
  92. print("Compressing ctxt to the smallest possible number of towers!\n")
  93. inCtxt = cryptoContext.IntMPBootAdjustScale(inCtxt)
  94. print("\n============================ INTERACTIVE BOOTSTRAPPING STARTS ============================\n")
  95. #Leading party (P0) generates a Common Random Poly (a) at max coefficient modulus (QNumPrime).
  96. # a is sampled at random uniformly from R_{Q}
  97. a = cryptoContext.IntMPBootRandomElementGen(parties[0].kpShard.publicKey)
  98. print("Common Random Poly (a) has been generated with coefficient modulus Q\n")
  99. # Each party generates its own shares: maskedDecryptionShare and reEncryptionShare
  100. sharePairVec = []
  101. # Make a copy of input ciphertext and remove the first element (c0), we only
  102. # c1 for IntMPBootDecrypt
  103. c1 = inCtxt.Clone()
  104. c1.RemoveElement(0)
  105. for i in range(numParties):
  106. print(f"Party {i} started its part in Collective Bootstrapping Protocol.\n")
  107. parties[i].sharesPair = cryptoContext.IntMPBootDecrypt(parties[i].kpShard.secretKey, c1, a)
  108. sharePairVec.append(parties[i].sharesPair)
  109. # P0 finalizes the protocol by aggregating the shares and reEncrypting the results
  110. aggregatedSharesPair = cryptoContext.IntMPBootAdd(sharePairVec);
  111. # Make sure you provide the non-striped ciphertext (inCtxt) in IntMPBootEncrypt
  112. outCtxt = cryptoContext.IntMPBootEncrypt(parties[0].kpShard.publicKey, aggregatedSharesPair, a, inCtxt)
  113. # INTERACTIVE BOOTSTRAPPING ENDS
  114. print("\n============================ INTERACTIVE BOOTSTRAPPING ENDED ============================\n")
  115. # Distributed Decryption
  116. print("\n============================ INTERACTIVE DECRYPTION STARTED ============================ \n")
  117. partialCiphertextVec = []
  118. print("Party 0 started its part in the collective decryption protocol\n")
  119. partialCiphertextVec.append(cryptoContext.MultipartyDecryptLead([outCtxt], parties[0].kpShard.secretKey)[0])
  120. for i in range(1, numParties):
  121. print(f"Party {i} started its part in the collective decryption protocol\n")
  122. partialCiphertextVec.append(cryptoContext.MultipartyDecryptMain([outCtxt], parties[i].kpShard.secretKey)[0])
  123. # Checking the results
  124. print("MultipartyDecryptFusion ...\n")
  125. plaintextMultiparty = cryptoContext.MultipartyDecryptFusion(partialCiphertextVec)
  126. plaintextMultiparty.SetLength(len(msg1))
  127. # transform to python:
  128. print(f"Original plaintext \n\t {ptxt1.GetCKKSPackedValue()}\n")
  129. print(f"Result after bootstrapping \n\t {plaintextMultiparty.GetCKKSPackedValue()}\n")
  130. print("\n============================ INTERACTIVE DECRYPTION ENDED ============================\n")
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()