SerialDeserial.cc 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. #include "SerialDeserial.h"
  2. #include "openfhe/pke/cryptocontext-ser.h"
  3. #include "Ciphertext.h"
  4. #include "CryptoContext.h"
  5. #include "PrivateKey.h"
  6. #include "PublicKey.h"
  7. namespace openfhe
  8. {
  9. template <typename ST, typename Object>
  10. [[nodiscard]] bool SerialDeserial(const std::string& location,
  11. bool (* const funcPtr) (const std::string&, Object&, const ST&), Object& object)
  12. {
  13. return funcPtr(location, object, ST{});
  14. }
  15. template <typename Object>
  16. [[nodiscard]] bool Deserial(const std::string& location, Object& object,
  17. const SerialMode serialMode)
  18. {
  19. if (serialMode == SerialMode::BINARY)
  20. {
  21. return SerialDeserial<lbcrypto::SerType::SERBINARY, decltype(object.GetRef())>(location,
  22. lbcrypto::Serial::DeserializeFromFile, object.GetRef());
  23. }
  24. if (serialMode == SerialMode::JSON)
  25. {
  26. return SerialDeserial<lbcrypto::SerType::SERJSON, decltype(object.GetRef())>(location,
  27. lbcrypto::Serial::DeserializeFromFile, object.GetRef());
  28. }
  29. return false;
  30. }
  31. template <typename Object>
  32. [[nodiscard]] bool Serial(const std::string& location, Object& object, const SerialMode serialMode)
  33. {
  34. if (serialMode == SerialMode::BINARY)
  35. {
  36. return SerialDeserial<lbcrypto::SerType::SERBINARY, decltype(object.GetRef())>(location,
  37. lbcrypto::Serial::SerializeToFile, object.GetRef());
  38. }
  39. if (serialMode == SerialMode::JSON)
  40. {
  41. return SerialDeserial<lbcrypto::SerType::SERJSON, decltype(object.GetRef())>(location,
  42. lbcrypto::Serial::SerializeToFile, object.GetRef());
  43. }
  44. return false;
  45. }
  46. template <typename ST, typename Stream, typename FStream, typename... Types>
  47. [[nodiscard]] bool SerialDeserial(const std::string& location,
  48. bool (* const funcPtr) (Stream&, const ST&, Types... args), Types... args)
  49. {
  50. const auto close = [](FStream* const fs){ if (fs->is_open()) { fs->close(); } };
  51. const std::unique_ptr<FStream, decltype(close)> fs(
  52. new FStream(location, std::ios::binary), close);
  53. return fs->is_open() ? funcPtr(*fs, ST{}, args...) : false;
  54. }
  55. // Ciphertext
  56. bool DCRTPolyDeserializeCiphertextFromFile(const std::string& ciphertextLocation,
  57. CiphertextDCRTPoly& ciphertext, const SerialMode serialMode)
  58. {
  59. return Deserial(ciphertextLocation, ciphertext, serialMode);
  60. }
  61. bool DCRTPolySerializeCiphertextToFile(const std::string& ciphertextLocation,
  62. const CiphertextDCRTPoly& ciphertext, const SerialMode serialMode)
  63. {
  64. return Serial(ciphertextLocation, ciphertext, serialMode);
  65. }
  66. // CryptoContext
  67. bool DCRTPolyDeserializeCryptoContextFromFile(const std::string& ccLocation,
  68. CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
  69. {
  70. return Deserial(ccLocation, cryptoContext, serialMode);
  71. }
  72. bool DCRTPolySerializeCryptoContextToFile(const std::string& ccLocation,
  73. const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
  74. {
  75. return Serial(ccLocation, cryptoContext, serialMode);
  76. }
  77. // EvalAutomorphismKey
  78. bool DCRTPolyDeserializeEvalAutomorphismKeyFromFile(const std::string& automorphismKeyLocation,
  79. const SerialMode serialMode)
  80. {
  81. if (serialMode == SerialMode::BINARY)
  82. {
  83. return SerialDeserial<lbcrypto::SerType::SERBINARY, std::istream, std::ifstream>(
  84. automorphismKeyLocation, CryptoContextImpl::DeserializeEvalAutomorphismKey);
  85. }
  86. if (serialMode == SerialMode::JSON)
  87. {
  88. return SerialDeserial<lbcrypto::SerType::SERJSON, std::istream, std::ifstream>(
  89. automorphismKeyLocation, CryptoContextImpl::DeserializeEvalAutomorphismKey);
  90. }
  91. return false;
  92. }
  93. bool DCRTPolySerializeEvalAutomorphismKeyByIdToFile(const std::string& automorphismKeyLocation,
  94. const SerialMode serialMode, const std::string& id)
  95. {
  96. if (serialMode == SerialMode::BINARY)
  97. {
  98. return SerialDeserial<lbcrypto::SerType::SERBINARY, std::ostream, std::ofstream>(
  99. automorphismKeyLocation, CryptoContextImpl::SerializeEvalAutomorphismKey, id);
  100. }
  101. if (serialMode == SerialMode::JSON)
  102. {
  103. return SerialDeserial<lbcrypto::SerType::SERJSON, std::ostream, std::ofstream>(
  104. automorphismKeyLocation, CryptoContextImpl::SerializeEvalAutomorphismKey, id);
  105. }
  106. return false;
  107. }
  108. bool DCRTPolySerializeEvalAutomorphismKeyToFile(const std::string& automorphismKeyLocation,
  109. const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
  110. {
  111. if (serialMode == SerialMode::BINARY)
  112. {
  113. return SerialDeserial<lbcrypto::SerType::SERBINARY, std::ostream, std::ofstream>(
  114. automorphismKeyLocation, CryptoContextImpl::SerializeEvalAutomorphismKey,
  115. cryptoContext.GetRef());
  116. }
  117. if (serialMode == SerialMode::JSON)
  118. {
  119. return SerialDeserial<lbcrypto::SerType::SERJSON, std::ostream, std::ofstream>(
  120. automorphismKeyLocation, CryptoContextImpl::SerializeEvalAutomorphismKey,
  121. cryptoContext.GetRef());
  122. }
  123. return false;
  124. }
  125. // EvalMultKey
  126. bool DCRTPolyDeserializeEvalMultKeyFromFile(const std::string& multKeyLocation,
  127. const SerialMode serialMode)
  128. {
  129. if (serialMode == SerialMode::BINARY)
  130. {
  131. return SerialDeserial<lbcrypto::SerType::SERBINARY, std::istream, std::ifstream>(
  132. multKeyLocation, CryptoContextImpl::DeserializeEvalMultKey);
  133. }
  134. if (serialMode == SerialMode::JSON)
  135. {
  136. return SerialDeserial<lbcrypto::SerType::SERJSON, std::istream, std::ifstream>(
  137. multKeyLocation, CryptoContextImpl::DeserializeEvalMultKey);
  138. }
  139. return false;
  140. }
  141. bool SerializeEvalMultKeyDCRTPolyByIdToFile(const std::string& multKeyLocation,
  142. const SerialMode serialMode, const std::string& id)
  143. {
  144. if (serialMode == SerialMode::BINARY)
  145. {
  146. return SerialDeserial<lbcrypto::SerType::SERBINARY, std::ostream, std::ofstream>(
  147. multKeyLocation, CryptoContextImpl::SerializeEvalMultKey, id);
  148. }
  149. if (serialMode == SerialMode::JSON)
  150. {
  151. return SerialDeserial<lbcrypto::SerType::SERJSON, std::ostream, std::ofstream>(
  152. multKeyLocation, CryptoContextImpl::SerializeEvalMultKey, id);
  153. }
  154. return false;
  155. }
  156. bool DCRTPolySerializeEvalMultKeyToFile(const std::string& multKeyLocation,
  157. const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
  158. {
  159. if (serialMode == SerialMode::BINARY)
  160. {
  161. return SerialDeserial<lbcrypto::SerType::SERBINARY, std::ostream, std::ofstream>(
  162. multKeyLocation, CryptoContextImpl::SerializeEvalMultKey, cryptoContext.GetRef());
  163. }
  164. if (serialMode == SerialMode::JSON)
  165. {
  166. return SerialDeserial<lbcrypto::SerType::SERJSON, std::ostream, std::ofstream>(
  167. multKeyLocation, CryptoContextImpl::SerializeEvalMultKey, cryptoContext.GetRef());
  168. }
  169. return false;
  170. }
  171. // EvalSumKey
  172. bool DCRTPolyDeserializeEvalSumKeyFromFile(const std::string& sumKeyLocation, const SerialMode serialMode)
  173. {
  174. if (serialMode == SerialMode::BINARY)
  175. {
  176. return SerialDeserial<lbcrypto::SerType::SERBINARY, std::istream, std::ifstream>(
  177. sumKeyLocation, CryptoContextImpl::DeserializeEvalAutomorphismKey);
  178. }
  179. if (serialMode == SerialMode::JSON)
  180. {
  181. return SerialDeserial<lbcrypto::SerType::SERJSON, std::istream, std::ifstream>(
  182. sumKeyLocation, CryptoContextImpl::DeserializeEvalAutomorphismKey);
  183. }
  184. return false;
  185. }
  186. bool DCRTPolySerializeEvalSumKeyByIdToFile(const std::string& sumKeyLocation,
  187. const SerialMode serialMode, const std::string& id)
  188. {
  189. if (serialMode == SerialMode::BINARY)
  190. {
  191. return SerialDeserial<lbcrypto::SerType::SERBINARY, std::ostream, std::ofstream>(
  192. sumKeyLocation, CryptoContextImpl::SerializeEvalSumKey, id);
  193. }
  194. if (serialMode == SerialMode::JSON)
  195. {
  196. return SerialDeserial<lbcrypto::SerType::SERJSON, std::ostream, std::ofstream>(
  197. sumKeyLocation, CryptoContextImpl::SerializeEvalSumKey, id);
  198. }
  199. return false;
  200. }
  201. bool DCRTPolySerializeEvalSumKeyToFile(const std::string& sumKeyLocation,
  202. const CryptoContextDCRTPoly& cryptoContext, const SerialMode serialMode)
  203. {
  204. if (serialMode == SerialMode::BINARY)
  205. {
  206. return SerialDeserial<lbcrypto::SerType::SERBINARY, std::ostream, std::ofstream>(
  207. sumKeyLocation, CryptoContextImpl::SerializeEvalAutomorphismKey,
  208. cryptoContext.GetRef());
  209. }
  210. if (serialMode == SerialMode::JSON)
  211. {
  212. return SerialDeserial<lbcrypto::SerType::SERJSON, std::ostream, std::ofstream>(
  213. sumKeyLocation, CryptoContextImpl::SerializeEvalAutomorphismKey,
  214. cryptoContext.GetRef());
  215. }
  216. return false;
  217. }
  218. // PublicKey
  219. bool DCRTPolyDeserializePublicKeyFromFile(const std::string& publicKeyLocation,
  220. PublicKeyDCRTPoly& publicKey, const SerialMode serialMode)
  221. {
  222. return Deserial(publicKeyLocation, publicKey, serialMode);
  223. }
  224. bool DCRTPolySerializePublicKeyToFile(const std::string& publicKeyLocation,
  225. const PublicKeyDCRTPoly& publicKey, const SerialMode serialMode)
  226. {
  227. return Serial(publicKeyLocation, publicKey, serialMode);
  228. }
  229. bool DCRTPolyDeserializePrivateKeyFromFile(const std::string& privateKeyLocation,
  230. PrivateKeyDCRTPoly& privateKey, const SerialMode serialMode)
  231. {
  232. return Deserial(privateKeyLocation, privateKey, serialMode);
  233. }
  234. bool DCRTPolySerializePrivateKeyToFile(const std::string& privateKeyLocation,
  235. const PrivateKeyDCRTPoly& privateKey, const SerialMode serialMode)
  236. {
  237. return Serial(privateKeyLocation, privateKey, serialMode);
  238. }
  239. } // openfhe