simple-real-numbers-serial.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. from openfhe import *
  2. import os
  3. from pathlib import Path
  4. import tempfile
  5. # NOTE:
  6. # If running locally, you may want to replace the "hardcoded" datafolder with
  7. # the datafolder location below which gets the current working directory
  8. # Save-Load locations for keys
  9. datafolder = "demoData"
  10. ccLocation = "/cryptocontext.txt"
  11. pubKeyLocation = "/key_pub.txt" # Pub key
  12. multKeyLocation = "/key_mult.txt" # relinearization key
  13. rotKeyLocation = "/key_rot.txt" # automorphism / rotation key
  14. # Save-load locations for RAW ciphertexts
  15. cipherOneLocation = "/ciphertext1.txt"
  16. cipherTwoLocation = "/ciphertext2.txt"
  17. # Save-load locations for evaluated ciphertexts
  18. cipherMultLocation = "/ciphertextMult.txt"
  19. cipherAddLocation = "/ciphertextAdd.txt"
  20. cipherRotLocation = "/ciphertextRot.txt"
  21. cipherRotNegLocation = "/ciphertextRotNegLocation.txt"
  22. clientVectorLocation = "/clientVectorFromClient.txt"
  23. # Demarcate - Visual separator between the sections of code
  24. def demarcate(msg):
  25. print("**************************************************\n")
  26. print(msg)
  27. print("**************************************************\n")
  28. """
  29. serverSetupAndWrite(multDepth, scaleModSize, batchSize)
  30. simulates a server at startup where we generate a cryptocontext and keys.
  31. then, we generate some data (akin to loading raw data on an enclave)
  32. before encrypting the data
  33. :param multDepth: multiplication depth
  34. :param scaleModSize: number of bits to use in the scale factor (not the
  35. scale factor itself)
  36. :param batchSize: batch size to use
  37. :return Tuple<cryptoContext, keyPair>
  38. """
  39. def serverSetupAndWrite(multDepth, scaleModSize, batchSize):
  40. parameters = CCParamsCKKSRNS()
  41. parameters.SetMultiplicativeDepth(multDepth)
  42. parameters.SetScalingModSize(scaleModSize)
  43. parameters.SetBatchSize(batchSize)
  44. serverCC = GenCryptoContext(parameters)
  45. serverCC.Enable(PKE)
  46. serverCC.Enable(KEYSWITCH)
  47. serverCC.Enable(LEVELEDSHE)
  48. print("Cryptocontext generated")
  49. serverKP = serverCC.KeyGen()
  50. print("Keypair generated")
  51. serverCC.EvalMultKeyGen(serverKP.secretKey)
  52. print("Eval Mult Keys/ Relinearization keys have been generated")
  53. serverCC.EvalRotateKeyGen(serverKP.secretKey, [1, 2, -1, -2])
  54. print("Rotation keys generated")
  55. vec1 = [1.0, 2.0, 3.0, 4.0]
  56. vec2 = [12.5, 13.5, 14.5, 15.5]
  57. vec3 = [10.5, 11.5, 12.5, 13.5]
  58. print("\nDisplaying first data vector: ")
  59. print(vec1)
  60. print("\n")
  61. serverP1 = serverCC.MakeCKKSPackedPlaintext(vec1)
  62. serverP2 = serverCC.MakeCKKSPackedPlaintext(vec2)
  63. serverP3 = serverCC.MakeCKKSPackedPlaintext(vec3)
  64. print("Plaintext version of first vector: " + str(serverP1))
  65. print("Plaintexts have been generated from complex-double vectors")
  66. serverC1 = serverCC.Encrypt(serverKP.publicKey, serverP1)
  67. serverC2 = serverCC.Encrypt(serverKP.publicKey, serverP2)
  68. serverC3 = serverCC.Encrypt(serverKP.publicKey, serverP3)
  69. print("Ciphertexts have been generated from plaintexts")
  70. ###
  71. # Part 2:
  72. # We serialize the following:
  73. # Cryptocontext
  74. # Public key
  75. # relinearization (eval mult keys)
  76. # rotation keys
  77. # Some of the ciphertext
  78. #
  79. # We serialize all of them to files
  80. ###
  81. demarcate("Part 2: Data Serialization (server)")
  82. if not SerializeToFile(datafolder + ccLocation, serverCC, BINARY):
  83. raise Exception("Exception writing cryptocontext to cryptocontext.txt")
  84. print("Cryptocontext serialized")
  85. if not SerializeToFile(datafolder + pubKeyLocation, serverKP.publicKey, BINARY):
  86. raise Exception("Exception writing public key to pubkey.txt")
  87. print("Public key has been serialized")
  88. if not serverCC.SerializeEvalMultKey(datafolder + multKeyLocation, BINARY):
  89. raise Exception("Error writing eval mult keys")
  90. print("EvalMult/ relinearization keys have been serialized")
  91. if not serverCC.SerializeEvalAutomorphismKey(datafolder + rotKeyLocation, BINARY):
  92. raise Exception("Error writing rotation keys")
  93. print("Rotation keys have been serialized")
  94. if not SerializeToFile(datafolder + cipherOneLocation, serverC1, BINARY):
  95. raise Exception("Error writing ciphertext 1")
  96. if not SerializeToFile(datafolder + cipherTwoLocation, serverC2, BINARY):
  97. raise Exception("Error writing ciphertext 2")
  98. return (serverCC, serverKP, len(vec1))
  99. ###
  100. # clientProcess
  101. # - deserialize data from a file which simulates receiving data from a server
  102. # after making a request
  103. # - we then process the data by doing operations (multiplication, addition,
  104. # rotation, etc)
  105. # - !! We also create an object and encrypt it in this function before sending
  106. # it off to the server to be decrypted
  107. ###
  108. def clientProcess():
  109. # clientCC = CryptoContext()
  110. # clientCC.ClearEvalAutomorphismKeys()
  111. ReleaseAllContexts()
  112. ClearEvalMultKeys()
  113. clientCC, res = DeserializeCryptoContext(datafolder + ccLocation, BINARY)
  114. if not res:
  115. raise Exception(
  116. f"I cannot deserialize the cryptocontext from {datafolder+ccLocation}"
  117. )
  118. print("Client CC deserialized")
  119. # clientKP = KeyPair()
  120. # We do NOT have a secret key. The client
  121. # should not have access to this
  122. clientPuclicKey, res = DeserializePublicKey(datafolder + pubKeyLocation, BINARY)
  123. if not res:
  124. raise Exception(
  125. f"I cannot deserialize the public key from {datafolder+pubKeyLocation}"
  126. )
  127. print("Client KP deserialized\n")
  128. if not clientCC.DeserializeEvalMultKey(datafolder + multKeyLocation, BINARY):
  129. raise Exception(
  130. f"Cannot deserialize eval mult keys from {datafolder+multKeyLocation}"
  131. )
  132. print("Deserialized eval mult keys\n")
  133. if not clientCC.DeserializeEvalAutomorphismKey(datafolder + rotKeyLocation, BINARY):
  134. raise Exception(
  135. f"Cannot deserialize eval automorphism keys from {datafolder+rotKeyLocation}"
  136. )
  137. clientC1, res = DeserializeCiphertext(datafolder + cipherOneLocation, BINARY)
  138. if not res:
  139. raise Exception(
  140. f"Cannot deserialize the ciphertext from {datafolder+cipherOneLocation}"
  141. )
  142. print("Deserialized ciphertext 1\n")
  143. clientC2, res = DeserializeCiphertext(datafolder + cipherTwoLocation, BINARY)
  144. if not res:
  145. raise Exception(
  146. f"Cannot deserialize the ciphertext from {datafolder+cipherTwoLocation}"
  147. )
  148. print("Deserialized ciphertext 2\n")
  149. clientCiphertextMult = clientCC.EvalMult(clientC1, clientC2)
  150. clientCiphertextAdd = clientCC.EvalAdd(clientC1, clientC2)
  151. clientCiphertextRot = clientCC.EvalRotate(clientC1, 1)
  152. clientCiphertextRotNeg = clientCC.EvalRotate(clientC1, -1)
  153. # Now, we want to simulate a client who is encrypting data for the server to
  154. # decrypt. E.g weights of a machine learning algorithm
  155. demarcate("Part 3.5: Client Serialization of data that has been operated on")
  156. clientVector1 = [1.0, 2.0, 3.0, 4.0]
  157. clientPlaintext1 = clientCC.MakeCKKSPackedPlaintext(clientVector1)
  158. clientInitializedEncryption = clientCC.Encrypt(clientPuclicKey, clientPlaintext1)
  159. SerializeToFile(datafolder + cipherMultLocation, clientCiphertextMult, BINARY)
  160. SerializeToFile(datafolder + cipherAddLocation, clientCiphertextAdd, BINARY)
  161. SerializeToFile(datafolder + cipherRotLocation, clientCiphertextRot, BINARY)
  162. SerializeToFile(datafolder + cipherRotNegLocation, clientCiphertextRotNeg, BINARY)
  163. SerializeToFile(
  164. datafolder + clientVectorLocation, clientInitializedEncryption, BINARY
  165. )
  166. print("Serialized all ciphertexts from client\n")
  167. ###
  168. # serverVerification
  169. # - deserialize data from the client.
  170. # - Verify that the results are as we expect
  171. # @param cc cryptocontext that was previously generated
  172. # @param kp keypair that was previously generated
  173. # @param vectorSize vector size of the vectors supplied
  174. # @return
  175. # 5-tuple of the plaintexts of various operations
  176. ##
  177. def serverVerification(cc, kp, vectorSize):
  178. serverCiphertextFromClient_Mult, res = DeserializeCiphertext(
  179. datafolder + cipherMultLocation, BINARY
  180. )
  181. serverCiphertextFromClient_Add, res = DeserializeCiphertext(
  182. datafolder + cipherAddLocation, BINARY
  183. )
  184. serverCiphertextFromClient_Rot, res = DeserializeCiphertext(
  185. datafolder + cipherRotLocation, BINARY
  186. )
  187. serverCiphertextFromClient_RotNeg, res = DeserializeCiphertext(
  188. datafolder + cipherRotNegLocation, BINARY
  189. )
  190. serverCiphertextFromClient_Vec, res = DeserializeCiphertext(
  191. datafolder + clientVectorLocation, BINARY
  192. )
  193. print("Deserialized all data from client on server\n")
  194. print("Part 5: Correctness Verification")
  195. serverPlaintextFromClient_Mult = cc.Decrypt(
  196. kp.secretKey, serverCiphertextFromClient_Mult
  197. )
  198. serverPlaintextFromClient_Add = cc.Decrypt(
  199. kp.secretKey, serverCiphertextFromClient_Add
  200. )
  201. serverPlaintextFromClient_Rot = cc.Decrypt(
  202. kp.secretKey, serverCiphertextFromClient_Rot
  203. )
  204. serverPlaintextFromClient_RotNeg = cc.Decrypt(
  205. kp.secretKey, serverCiphertextFromClient_RotNeg
  206. )
  207. serverPlaintextFromClient_Vec = cc.Decrypt(
  208. kp.secretKey, serverCiphertextFromClient_Vec
  209. )
  210. serverPlaintextFromClient_Mult.SetLength(vectorSize)
  211. serverPlaintextFromClient_Add.SetLength(vectorSize)
  212. serverPlaintextFromClient_Vec.SetLength(vectorSize)
  213. serverPlaintextFromClient_Rot.SetLength(vectorSize + 1)
  214. serverPlaintextFromClient_RotNeg.SetLength(vectorSize + 1)
  215. return (
  216. serverPlaintextFromClient_Mult,
  217. serverPlaintextFromClient_Add,
  218. serverPlaintextFromClient_Vec,
  219. serverPlaintextFromClient_Rot,
  220. serverPlaintextFromClient_RotNeg,
  221. )
  222. def main():
  223. global datafolder
  224. with tempfile.TemporaryDirectory() as td:
  225. datafolder = td + "/" + datafolder
  226. os.makedirs(datafolder)
  227. main_action()
  228. def main_action():
  229. print(
  230. f"This program requires the subdirectory `{datafolder}' to exist, otherwise you will get\n an error writing serializations."
  231. )
  232. # Set main params
  233. multDepth = 5
  234. scaleModSize = 40
  235. batchSize = 32
  236. cryptoContextIdx = 0
  237. keyPairIdx = 1
  238. vectorSizeIdx = 2
  239. cipherMultResIdx = 0
  240. cipherAddResIdx = 1
  241. cipherVecResIdx = 2
  242. cipherRotResIdx = 3
  243. cipherRotNegResIdx = 4
  244. demarcate(
  245. "Part 1: Cryptocontext generation, key generation, data encryption \n(server)"
  246. )
  247. tupleCryptoContext_KeyPair = serverSetupAndWrite(multDepth, scaleModSize, batchSize)
  248. cc = tupleCryptoContext_KeyPair[cryptoContextIdx]
  249. kp = tupleCryptoContext_KeyPair[keyPairIdx]
  250. vectorSize = tupleCryptoContext_KeyPair[vectorSizeIdx]
  251. demarcate("Part 3: Client deserialize all data")
  252. clientProcess()
  253. demarcate("Part 4: Server deserialization of data from client. ")
  254. tupleRes = serverVerification(cc, kp, vectorSize)
  255. multRes = tupleRes[cipherMultResIdx]
  256. addRes = tupleRes[cipherAddResIdx]
  257. vecRes = tupleRes[cipherVecResIdx]
  258. rotRes = tupleRes[cipherRotResIdx]
  259. rotNegRes = tupleRes[cipherRotNegResIdx]
  260. # vec1: [1,2,3,4]
  261. # vec2: [12.5, 13.5, 14.5, 15.5]
  262. print(multRes) # EXPECT: 12.5, 27.0, 43.5, 62
  263. print(addRes) # EXPECT: 13.5, 15.5, 17.5, 19.5
  264. print(vecRes) # EXPECT: [1,2,3,4]
  265. print("Displaying 5 elements of a 4-element vector to illustrate rotation")
  266. print(rotRes) # EXPECT: [2, 3, 4, noise, noise]
  267. print(rotNegRes) # EXPECT: [noise, 1, 2, 3, 4]
  268. if __name__ == "__main__":
  269. main()