pir.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. #include "pir.hpp"
  2. using namespace std;
  3. using namespace seal;
  4. using namespace seal::util;
  5. std::vector<std::uint64_t> get_dimensions(std::uint64_t num_of_plaintexts, std::uint32_t d) {
  6. assert(d > 0);
  7. assert(num_of_plaintexts > 0);
  8. std::uint64_t root = max(static_cast<uint32_t>(2),static_cast<uint32_t>(floor(pow(num_of_plaintexts, 1.0/d))));
  9. std::vector<std::uint64_t> dimensions(d, root);
  10. for(int i = 0; i < d; i++){
  11. if(accumulate(dimensions.begin(), dimensions.end(), 1, multiplies<uint64_t>()) > num_of_plaintexts){
  12. break;
  13. }
  14. dimensions[i] += 1;
  15. }
  16. std::uint32_t prod = accumulate(dimensions.begin(), dimensions.end(), 1, multiplies<uint64_t>());
  17. assert(prod >= num_of_plaintexts);
  18. return dimensions;
  19. }
  20. void gen_encryption_params(std::uint32_t N, std::uint32_t logt,
  21. seal::EncryptionParameters &enc_params){
  22. enc_params.set_poly_modulus_degree(N);
  23. enc_params.set_coeff_modulus(CoeffModulus::BFVDefault(N));
  24. enc_params.set_plain_modulus(PlainModulus::Batching(N, logt+1));
  25. }
  26. void verify_encryption_params(const seal::EncryptionParameters &enc_params){
  27. SEALContext context(enc_params, true);
  28. if(!context.parameters_set()){
  29. throw invalid_argument("SEAL parameters not valid.");
  30. }
  31. if(!context.using_keyswitching()){
  32. throw invalid_argument("SEAL parameters do not support key switching.");
  33. }
  34. if(!context.first_context_data()->qualifiers().using_batching){
  35. throw invalid_argument("SEAL parameters do not support batching.");
  36. }
  37. BatchEncoder batch_encoder(context);
  38. size_t slot_count = batch_encoder.slot_count();
  39. if(slot_count != enc_params.poly_modulus_degree()){
  40. throw invalid_argument("Slot count not equal to poly modulus degree - this will cause issues downstream.");
  41. }
  42. return;
  43. }
  44. void gen_pir_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
  45. const EncryptionParameters &enc_params, PirParams &pir_params,
  46. bool enable_symmetric, bool enable_batching){
  47. std::uint32_t N = enc_params.poly_modulus_degree();
  48. Modulus t = enc_params.plain_modulus();
  49. std::uint32_t logt = floor(log2(t.value()));
  50. std::uint64_t elements_per_plaintext;
  51. std::uint64_t num_of_plaintexts;
  52. if(enable_batching){
  53. elements_per_plaintext = elements_per_ptxt(logt, N, ele_size);
  54. num_of_plaintexts = plaintexts_per_db(logt, N, ele_num, ele_size);
  55. }
  56. else{
  57. elements_per_plaintext = 1;
  58. num_of_plaintexts = ele_num;
  59. }
  60. vector<uint64_t> nvec = get_dimensions(num_of_plaintexts, d);
  61. uint32_t expansion_ratio = 0;
  62. for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
  63. double logqi = log2(enc_params.coeff_modulus()[i].value());
  64. expansion_ratio += ceil(logqi / logt);
  65. }
  66. pir_params.enable_symmetric = enable_symmetric;
  67. pir_params.enable_batching = enable_batching;
  68. pir_params.ele_num = ele_num;
  69. pir_params.ele_size = ele_size;
  70. pir_params.elements_per_plaintext = elements_per_plaintext;
  71. pir_params.num_of_plaintexts = num_of_plaintexts;
  72. pir_params.d = d;
  73. pir_params.expansion_ratio = expansion_ratio << 1;
  74. pir_params.nvec = nvec;
  75. pir_params.slot_count = N;
  76. }
  77. void print_pir_params(const PirParams &pir_params){
  78. std::uint32_t prod = accumulate(pir_params.nvec.begin(), pir_params.nvec.end(), 1, multiplies<uint64_t>());
  79. cout << "PIR Parameters" << endl;
  80. cout << "number of elements: " << pir_params.ele_num << endl;
  81. cout << "element size: " << pir_params.ele_size << endl;
  82. cout << "elements per BFV plaintext: " << pir_params.elements_per_plaintext << endl;
  83. cout << "dimensions for d-dimensional hyperrectangle: " << pir_params.d << endl;
  84. cout << "number of BFV plaintexts (before padding): " << pir_params.num_of_plaintexts << endl;
  85. cout << "Number of BFV plaintexts after padding (to fill d-dimensional hyperrectangle): " << prod << endl;
  86. cout << "expansion ratio: " << pir_params.expansion_ratio << endl;
  87. cout << "slot count: " << pir_params.slot_count << endl;
  88. cout << "=============================="<< endl;
  89. }
  90. void print_seal_params(const EncryptionParameters &enc_params){
  91. std::uint32_t N = enc_params.poly_modulus_degree();
  92. Modulus t = enc_params.plain_modulus();
  93. std::uint32_t logt = floor(log2(t.value()));
  94. cout << "SEAL encryption parameters" << endl;
  95. cout << "Degree of polynomial modulus (N): " << N << endl;
  96. cout << "Size of plaintext modulus (log t):" << logt << endl;
  97. cout << "There are " << enc_params.coeff_modulus().size() << " coefficient modulus:" << endl;
  98. for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
  99. double logqi = log2(enc_params.coeff_modulus()[i].value());
  100. cout << "Size of coefficient modulus " << i << " (log q_" << i << "): " << logqi << endl;
  101. }
  102. cout << "=============================="<< endl;
  103. }
  104. uint32_t plainmod_after_expansion(uint32_t logt, uint32_t N, uint32_t d,
  105. uint64_t ele_num, uint64_t ele_size) {
  106. // Goal: find max logtp such that logtp + ceil(log(ceil(d_root(n)))) <= logt
  107. // where n = ceil(ele_num / floor(N*logtp / ele_size *8))
  108. for (uint32_t logtp = logt; logtp >= 2; logtp--) {
  109. uint64_t n = plaintexts_per_db(logtp, N, ele_num, ele_size);
  110. if (logtp == logt && n == 1) {
  111. return logtp - 1;
  112. }
  113. if ((double)logtp + ceil(log2(ceil(pow(n, 1.0/(double)d)))) <= logt) {
  114. return logtp;
  115. }
  116. }
  117. assert(0); // this should never happen
  118. return logt;
  119. }
  120. // Number of coefficients needed to represent a database element
  121. uint64_t coefficients_per_element(uint32_t logtp, uint64_t ele_size) {
  122. return ceil(8 * ele_size / (double)logtp);
  123. }
  124. // Number of database elements that can fit in a single FV plaintext
  125. uint64_t elements_per_ptxt(uint32_t logt, uint64_t N, uint64_t ele_size) {
  126. uint64_t coeff_per_ele = coefficients_per_element(logt, ele_size);
  127. uint64_t ele_per_ptxt = N / coeff_per_ele;
  128. assert(ele_per_ptxt > 0);
  129. return ele_per_ptxt;
  130. }
  131. // Number of FV plaintexts needed to represent the database
  132. uint64_t plaintexts_per_db(uint32_t logtp, uint64_t N, uint64_t ele_num, uint64_t ele_size) {
  133. uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
  134. return ceil((double)ele_num / ele_per_ptxt);
  135. }
  136. vector<uint64_t> bytes_to_coeffs(uint32_t limit, const uint8_t *bytes, uint64_t size) {
  137. uint64_t size_out = coefficients_per_element(limit, size);
  138. vector<uint64_t> output(size_out);
  139. uint32_t room = limit;
  140. uint64_t *target = &output[0];
  141. for (uint32_t i = 0; i < size; i++) {
  142. uint8_t src = bytes[i];
  143. uint32_t rest = 8;
  144. while (rest) {
  145. if (room == 0) {
  146. target++;
  147. room = limit;
  148. }
  149. uint32_t shift = rest;
  150. if (room < rest) {
  151. shift = room;
  152. }
  153. *target = *target << shift;
  154. *target = *target | (src >> (8 - shift));
  155. src = src << shift;
  156. room -= shift;
  157. rest -= shift;
  158. }
  159. }
  160. *target = *target << room;
  161. return output;
  162. }
  163. void coeffs_to_bytes(uint32_t limit, const vector<uint64_t> &coeffs, uint8_t *output, uint32_t size_out) {
  164. uint32_t room = 8;
  165. uint32_t j = 0;
  166. uint8_t *target = output;
  167. for (uint32_t i = 0; i < coeffs.size(); i++) {
  168. uint64_t src = coeffs[i];
  169. uint32_t rest = limit;
  170. while (rest && j < size_out) {
  171. uint32_t shift = rest;
  172. if (room < rest) {
  173. shift = room;
  174. }
  175. target[j] = target[j] << shift;
  176. target[j] = target[j] | (src >> (limit - shift));
  177. src = src << shift;
  178. room -= shift;
  179. rest -= shift;
  180. if (room == 0) {
  181. j++;
  182. room = 8;
  183. }
  184. }
  185. }
  186. }
  187. void vector_to_plaintext(const vector<uint64_t> &coeffs, Plaintext &plain) {
  188. uint32_t coeff_count = coeffs.size();
  189. plain.resize(coeff_count);
  190. util::set_uint(coeffs.data(), coeff_count, plain.data());
  191. }
  192. vector<uint64_t> compute_indices(uint64_t desiredIndex, vector<uint64_t> Nvec) {
  193. uint32_t num = Nvec.size();
  194. uint64_t product = 1;
  195. for (uint32_t i = 0; i < num; i++) {
  196. product *= Nvec[i];
  197. }
  198. uint64_t j = desiredIndex;
  199. vector<uint64_t> result;
  200. for (uint32_t i = 0; i < num; i++) {
  201. product /= Nvec[i];
  202. uint64_t ji = j / product;
  203. result.push_back(ji);
  204. j -= ji * product;
  205. }
  206. return result;
  207. }
  208. uint64_t invert_mod(uint64_t m, const seal::Modulus& mod) {
  209. if (mod.uint64_count() > 1) {
  210. cout << "Mod too big to invert";
  211. }
  212. uint64_t inverse = 0;
  213. if (!seal::util::try_invert_uint_mod(m, mod.value(), inverse)) {
  214. cout << "Could not invert value";
  215. }
  216. return inverse;
  217. }
  218. inline Ciphertext deserialize_ciphertext(string s, shared_ptr<SEALContext> context) {
  219. Ciphertext c;
  220. std::istringstream input(s);
  221. c.unsafe_load(*context, input);
  222. return c;
  223. }
  224. vector<Ciphertext> deserialize_ciphertexts(uint32_t count, string s, uint32_t len_ciphertext,
  225. shared_ptr<SEALContext> context) {
  226. vector<Ciphertext> c;
  227. for (uint32_t i = 0; i < count; i++) {
  228. c.push_back(deserialize_ciphertext(s.substr(i * len_ciphertext, len_ciphertext), context));
  229. }
  230. return c;
  231. }
  232. PirQuery deserialize_query(uint32_t d, uint32_t count, string s, uint32_t len_ciphertext,
  233. shared_ptr<SEALContext> context) {
  234. vector<vector<Ciphertext>> c;
  235. for (uint32_t i = 0; i < d; i++) {
  236. c.push_back(deserialize_ciphertexts(
  237. count,
  238. s.substr(i * count * len_ciphertext, count * len_ciphertext),
  239. len_ciphertext, context)
  240. );
  241. }
  242. return c;
  243. }
  244. inline string serialize_ciphertext(Ciphertext c) {
  245. std::ostringstream output;
  246. c.save(output);
  247. return output.str();
  248. }
  249. string serialize_ciphertexts(vector<Ciphertext> c) {
  250. string s;
  251. for (uint32_t i = 0; i < c.size(); i++) {
  252. s.append(serialize_ciphertext(c[i]));
  253. }
  254. return s;
  255. }
  256. string serialize_query(vector<vector<Ciphertext>> c) {
  257. string s;
  258. for (uint32_t i = 0; i < c.size(); i++) {
  259. for (uint32_t j = 0; j < c[i].size(); j++) {
  260. s.append(serialize_ciphertext(c[i][j]));
  261. }
  262. }
  263. return s;
  264. }
  265. string serialize_galoiskeys(GaloisKeys g) {
  266. std::ostringstream output;
  267. g.save(output);
  268. return output.str();
  269. }
  270. GaloisKeys *deserialize_galoiskeys(string s, shared_ptr<SEALContext> context) {
  271. GaloisKeys *g = new GaloisKeys();
  272. std::istringstream input(s);
  273. g->unsafe_load(*context, input);
  274. return g;
  275. }