simple-real-numbers-serial.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. from openfhe import *
  2. # NOTE:
  3. # If running locally, you may want to replace the "hardcoded" datafolder with
  4. # the datafolder location below which gets the current working directory
  5. # Save-Load locations for keys
  6. datafolder = 'demoData'
  7. ccLocation = '/cryptocontext.txt'
  8. pubKeyLocation = '/key_pub.txt' # Pub key
  9. multKeyLocation = '/key_mult.txt' # relinearization key
  10. rotKeyLocation = '/key_rot.txt' # automorphism / rotation key
  11. # Save-load locations for RAW ciphertexts
  12. cipherOneLocation = '/ciphertext1.txt'
  13. cipherTwoLocation = '/ciphertext2.txt'
  14. # Save-load locations for evaluated ciphertexts
  15. cipherMultLocation = '/ciphertextMult.txt'
  16. cipherAddLocation = '/ciphertextAdd.txt'
  17. cipherRotLocation = '/ciphertextRot.txt'
  18. cipherRotNegLocation = '/ciphertextRotNegLocation.txt'
  19. clientVectorLocation = '/clientVectorFromClient.txt'
  20. # Demarcate - Visual separator between the sections of code
  21. def demarcate(msg):
  22. print("**************************************************\n")
  23. print(msg)
  24. print("**************************************************\n")
  25. """
  26. serverSetupAndWrite(multDepth, scaleModSize, batchSize)
  27. simulates a server at startup where we generate a cryptocontext and keys.
  28. then, we generate some data (akin to loading raw data on an enclave)
  29. before encrypting the data
  30. :param multDepth: multiplication depth
  31. :param scaleModSize: number of bits to use in the scale factor (not the
  32. scale factor itself)
  33. :param batchSize: batch size to use
  34. :return Tuple<cryptoContext, keyPair>
  35. """
  36. def serverSetupAndWrite(multDepth, scaleModSize, batchSize):
  37. parameters = CCParamsCKKSRNS()
  38. parameters.SetMultiplicativeDepth(multDepth)
  39. parameters.SetScalingModSize(scaleModSize)
  40. parameters.SetBatchSize(batchSize)
  41. serverCC = GenCryptoContext(parameters)
  42. serverCC.Enable(PKE)
  43. serverCC.Enable(KEYSWITCH)
  44. serverCC.Enable(LEVELEDSHE)
  45. print("Cryptocontext generated")
  46. serverKP = serverCC.KeyGen()
  47. print("Keypair generated")
  48. serverCC.EvalMultKeyGen(serverKP.secretKey)
  49. print("Eval Mult Keys/ Relinearization keys have been generated")
  50. serverCC.EvalRotateKeyGen(serverKP.secretKey, [1, 2, -1, -2])
  51. print("Rotation keys generated")
  52. vec1 = [1.0, 2.0, 3.0, 4.0]
  53. vec2 = [12.5, 13.5, 14.5, 15.5]
  54. vec3 = [10.5, 11.5, 12.5, 13.5]
  55. print("\nDisplaying first data vector: ")
  56. print(vec1)
  57. print("\n")
  58. serverP1 = serverCC.MakeCKKSPackedPlaintext(vec1)
  59. serverP2 = serverCC.MakeCKKSPackedPlaintext(vec2)
  60. serverP3 = serverCC.MakeCKKSPackedPlaintext(vec3)
  61. print("Plaintext version of first vector: "+ str(serverP1))
  62. print("Plaintexts have been generated from complex-double vectors")
  63. serverC1 = serverCC.Encrypt(serverKP.publicKey, serverP1)
  64. serverC2 = serverCC.Encrypt(serverKP.publicKey, serverP2)
  65. serverC3 = serverCC.Encrypt(serverKP.publicKey, serverP3)
  66. print("Ciphertexts have been generated from plaintexts")
  67. ###
  68. # Part 2:
  69. # We serialize the following:
  70. # Cryptocontext
  71. # Public key
  72. # relinearization (eval mult keys)
  73. # rotation keys
  74. # Some of the ciphertext
  75. #
  76. # We serialize all of them to files
  77. ###
  78. demarcate("Part 2: Data Serialization (server)")
  79. if not SerializeToFile(datafolder + ccLocation, serverCC, BINARY):
  80. raise Exception("Exception writing cryptocontext to cryptocontext.txt")
  81. print("Cryptocontext serialized")
  82. if not SerializeToFile(datafolder + pubKeyLocation, serverKP.publicKey, BINARY):
  83. raise Exception("Exception writing public key to pubkey.txt")
  84. print("Public key has been serialized")
  85. if not serverCC.SerializeEvalMultKey(datafolder + multKeyLocation, BINARY):
  86. raise Exception("Error writing eval mult keys")
  87. print("EvalMult/ relinearization keys have been serialized")
  88. if not serverCC.SerializeEvalAutomorphismKey(datafolder + rotKeyLocation, BINARY):
  89. raise Exception("Error writing rotation keys")
  90. print("Rotation keys have been serialized")
  91. if not SerializeToFile(datafolder + cipherOneLocation, serverC1, BINARY):
  92. raise Exception("Error writing ciphertext 1")
  93. if not SerializeToFile(datafolder + cipherTwoLocation, serverC2, BINARY):
  94. raise Exception("Error writing ciphertext 2")
  95. return (serverCC, serverKP, len(vec1))
  96. ###
  97. # clientProcess
  98. # - deserialize data from a file which simulates receiving data from a server
  99. # after making a request
  100. # - we then process the data by doing operations (multiplication, addition,
  101. # rotation, etc)
  102. # - !! We also create an object and encrypt it in this function before sending
  103. # it off to the server to be decrypted
  104. ###
  105. def clientProcess():
  106. # clientCC = CryptoContext()
  107. # clientCC.ClearEvalMultKeys()
  108. # clientCC.ClearEvalAutomorphismKeys()
  109. ReleaseAllContexts()
  110. clientCC, res = DeserializeCryptoContext(datafolder + ccLocation, BINARY)
  111. if not res:
  112. raise Exception(f"I cannot deserialize the cryptocontext from {datafolder+ccLocation}")
  113. print("Client CC deserialized")
  114. #clientKP = KeyPair()
  115. # We do NOT have a secret key. The client
  116. # should not have access to this
  117. clientPuclicKey, res = DeserializePublicKey(datafolder + pubKeyLocation, BINARY)
  118. if not res:
  119. raise Exception(f"I cannot deserialize the public key from {datafolder+pubKeyLocation}")
  120. print("Client KP deserialized\n")
  121. if not clientCC.DeserializeEvalMultKey(datafolder + multKeyLocation, BINARY):
  122. raise Exception(f"Cannot deserialize eval mult keys from {datafolder+multKeyLocation}")
  123. print("Deserialized eval mult keys\n")
  124. if not clientCC.DeserializeEvalAutomorphismKey(datafolder + rotKeyLocation, BINARY):
  125. raise Exception(f"Cannot deserialize eval automorphism keys from {datafolder+rotKeyLocation}")
  126. clientC1, res = DeserializeCiphertext(datafolder + cipherOneLocation, BINARY)
  127. if not res:
  128. raise Exception(f"Cannot deserialize the ciphertext from {datafolder+cipherOneLocation}")
  129. print("Deserialized ciphertext 1\n")
  130. clientC2, res = DeserializeCiphertext(datafolder + cipherTwoLocation, BINARY)
  131. if not res:
  132. raise Exception(f"Cannot deserialize the ciphertext from {datafolder+cipherTwoLocation}")
  133. print("Deserialized ciphertext 2\n")
  134. clientCiphertextMult = clientCC.EvalMult(clientC1, clientC2)
  135. clientCiphertextAdd = clientCC.EvalAdd(clientC1, clientC2)
  136. clientCiphertextRot = clientCC.EvalRotate(clientC1, 1)
  137. clientCiphertextRotNeg = clientCC.EvalRotate(clientC1, -1)
  138. # Now, we want to simulate a client who is encrypting data for the server to
  139. # decrypt. E.g weights of a machine learning algorithm
  140. demarcate("Part 3.5: Client Serialization of data that has been operated on")
  141. clientVector1 = [1.0, 2.0, 3.0, 4.0]
  142. clientPlaintext1 = clientCC.MakeCKKSPackedPlaintext(clientVector1)
  143. clientInitializedEncryption = clientCC.Encrypt(clientPuclicKey, clientPlaintext1)
  144. SerializeToFile(datafolder + cipherMultLocation, clientCiphertextMult, BINARY)
  145. SerializeToFile(datafolder + cipherAddLocation, clientCiphertextAdd, BINARY)
  146. SerializeToFile(datafolder + cipherRotLocation, clientCiphertextRot, BINARY)
  147. SerializeToFile(datafolder + cipherRotNegLocation, clientCiphertextRotNeg, BINARY)
  148. SerializeToFile(datafolder + clientVectorLocation, clientInitializedEncryption, BINARY)
  149. print("Serialized all ciphertexts from client\n")
  150. ###
  151. # serverVerification
  152. # - deserialize data from the client.
  153. # - Verify that the results are as we expect
  154. # @param cc cryptocontext that was previously generated
  155. # @param kp keypair that was previously generated
  156. # @param vectorSize vector size of the vectors supplied
  157. # @return
  158. # 5-tuple of the plaintexts of various operations
  159. ##
  160. def serverVerification(cc,kp,vectorSize):
  161. serverCiphertextFromClient_Mult, res = DeserializeCiphertext(datafolder + cipherMultLocation, BINARY)
  162. serverCiphertextFromClient_Add, res = DeserializeCiphertext(datafolder + cipherAddLocation, BINARY)
  163. serverCiphertextFromClient_Rot, res = DeserializeCiphertext(datafolder + cipherRotLocation, BINARY)
  164. serverCiphertextFromClient_RotNeg, res = DeserializeCiphertext(datafolder + cipherRotNegLocation, BINARY)
  165. serverCiphertextFromClient_Vec, res = DeserializeCiphertext(datafolder + clientVectorLocation, BINARY)
  166. print("Deserialized all data from client on server\n")
  167. print("Part 5: Correctness Verification")
  168. serverPlaintextFromClient_Mult = cc.Decrypt(kp.secretKey, serverCiphertextFromClient_Mult)
  169. serverPlaintextFromClient_Add = cc.Decrypt(kp.secretKey, serverCiphertextFromClient_Add)
  170. serverPlaintextFromClient_Rot = cc.Decrypt(kp.secretKey, serverCiphertextFromClient_Rot)
  171. serverPlaintextFromClient_RotNeg = cc.Decrypt(kp.secretKey, serverCiphertextFromClient_RotNeg)
  172. serverPlaintextFromClient_Vec = cc.Decrypt(kp.secretKey, serverCiphertextFromClient_Vec)
  173. serverPlaintextFromClient_Mult.SetLength(vectorSize)
  174. serverPlaintextFromClient_Add.SetLength(vectorSize)
  175. serverPlaintextFromClient_Vec.SetLength(vectorSize)
  176. serverPlaintextFromClient_Rot.SetLength(vectorSize + 1)
  177. serverPlaintextFromClient_RotNeg.SetLength(vectorSize + 1)
  178. return (serverPlaintextFromClient_Mult,
  179. serverPlaintextFromClient_Add,
  180. serverPlaintextFromClient_Vec,
  181. serverPlaintextFromClient_Rot,
  182. serverPlaintextFromClient_RotNeg)
  183. def main():
  184. print(f"This program requires the subdirectory `{datafolder}' to exist, otherwise you will get\n an error writing serializations.")
  185. # Set main params
  186. multDepth = 5
  187. scaleModSize = 40
  188. batchSize = 32
  189. cryptoContextIdx = 0
  190. keyPairIdx = 1
  191. vectorSizeIdx = 2
  192. cipherMultResIdx = 0
  193. cipherAddResIdx = 1
  194. cipherVecResIdx = 2
  195. cipherRotResIdx = 3
  196. cipherRotNegResIdx = 4
  197. demarcate("Part 1: Cryptocontext generation, key generation, data encryption \n(server)")
  198. tupleCryptoContext_KeyPair = serverSetupAndWrite(multDepth, scaleModSize, batchSize)
  199. cc = tupleCryptoContext_KeyPair[cryptoContextIdx]
  200. kp = tupleCryptoContext_KeyPair[keyPairIdx]
  201. vectorSize = tupleCryptoContext_KeyPair[vectorSizeIdx]
  202. demarcate("Part 3: Client deserialize all data")
  203. clientProcess()
  204. demarcate("Part 4: Server deserialization of data from client. ")
  205. tupleRes = serverVerification(cc, kp, vectorSize)
  206. multRes = tupleRes[cipherMultResIdx]
  207. addRes = tupleRes[cipherAddResIdx]
  208. vecRes = tupleRes[cipherVecResIdx]
  209. rotRes = tupleRes[cipherRotResIdx]
  210. rotNegRes = tupleRes[cipherRotNegResIdx]
  211. # vec1: [1,2,3,4]
  212. # vec2: [12.5, 13.5, 14.5, 15.5]
  213. print(multRes) # EXPECT: 12.5, 27.0, 43.5, 62
  214. print(addRes) # EXPECT: 13.5, 15.5, 17.5, 19.5
  215. print(vecRes) # EXPECT: [1,2,3,4]
  216. print("Displaying 5 elements of a 4-element vector to illustrate rotation")
  217. print(rotRes) # EXPECT: [2, 3, 4, noise, noise]
  218. print(rotNegRes) # EXPECT: [noise, 1, 2, 3, 4]
  219. if __name__ == "__main__":
  220. main()