simplePIR.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>

#include "libpir.hpp"
#include "apps/server/DBDirectoryProcessor.hpp"
  3. bool run(DBHandler *db, uint64_t chosen_element, PIRParameters params){
  4. /******************************************************************************
  5. * PIR and Crypto Setup (must be done by both the client and the server)
  6. * In a real application the client and server must agree on the parameters
  7. * For example the client chooses and sends them to the server (or inversely)
  8. ******************************************************************************/
  9. HomomorphicCrypto *crypto = HomomorphicCryptoFactory::getCryptoMethod(params.crypto_params);
  10. // Absorption capacity of an LWE encryption scheme depends on the number of sums that are going
  11. // to be done in the PIR protocol, it must therefore be initialized
  12. // Warning here we suppose the biggest dimension is in d[0]
  13. // otherwise absorbtion needs to be computed accordingly
  14. crypto->setandgetAbsBitPerCiphertext(params.n[0]);
  15. /******************************************************************************
  16. * Query generation phase (client-side)
  17. ******************************************************************************/
  18. // Create the query generator object
  19. PIRQueryGenerator q_generator(params,*crypto);
  20. std::cout << "SimplePIR: Generating query ..." << std::endl;
  21. // Generate a query to get the FOURTH element in the database (indexes begin at 0)
  22. // Warning : if we had set params.alpha=2 elements would be aggregated 2 by 2 and
  23. // generatequery would only accept as input 0 (the two first elements) or 1 (the other two)
  24. q_generator.generateQuery(chosen_element);
  25. std::cout << "SimplePIR: Query generated" << std::endl;
  26. /******************************************************************************
  27. * Reply generation phase (server-side)
  28. ******************************************************************************/
  29. // Create the reply generator object
  30. PIRReplyGenerator r_generator(params,*crypto,db);
  31. r_generator.setPirParams(params);
  32. // In a real application the client would pop the queries from q with popQuery and
  33. // send them through the network and the server would receive and push them into s
  34. // using pushQuery
  35. char* query_element;
  36. while (q_generator.popQuery(&query_element))
  37. {
  38. r_generator.pushQuery(query_element);
  39. }
  40. // Import database
  41. // This could have been done on the "Database setup" phase if:
  42. // - the contents are static
  43. // - AND the imported database fits in RAM
  44. // - AND the server knows in advance the PIR and crypto parameters (e.g. chosen by him)
  45. std::cout << "SimplePIR: Importing database ..." << std::endl;
  46. // Warning aggregation is dealt with internally the bytes_per_db_element parameter here
  47. // is to be given WITHOUT multiplying it by params.alpha
  48. imported_database* imported_db = r_generator.importData(/* uint64_t offset*/ 0, /*uint64_t
  49. bytes_per_db_element */ db->getmaxFileBytesize());
  50. std::cout << "SimplePIR: Database imported" << std::endl;
  51. // Once the query is known and the database imported launch the reply generation
  52. std::cout << "SimplePIR: Generating reply ..." << std::endl;
  53. double start = omp_get_wtime();
  54. r_generator.generateReply(imported_db);
  55. double end = omp_get_wtime();
  56. std::cout << "SimplePIR: Reply generated in " << end-start << " seconds" << std::endl;
  57. /******************************************************************************
  58. * Reply extraction phase (client-side)
  59. ******************************************************************************/
  60. PIRReplyExtraction *r_extractor=new PIRReplyExtraction(params,*crypto);
  61. // In a real application the server would pop the replies from s with popReply and
  62. // send them through the network together with nbRepliesGenerated and aggregated_maxFileSize
  63. // and the client would receive the replies and push them into r using pushEncryptedReply
  64. std::cout << "SimplePIR: "<< r_generator.getnbRepliesGenerated()<< " Replies generated " << std::endl;
  65. uint64_t clientside_maxFileBytesize = db->getmaxFileBytesize();
  66. char* reply_element;
  67. while (r_generator.popReply(&reply_element))
  68. {
  69. r_extractor->pushEncryptedReply(reply_element);
  70. }
  71. std::cout << "SimplePIR: Extracting reply ..." << std::endl;
  72. r_extractor->extractReply(clientside_maxFileBytesize);
  73. std::cout << "SimplePIR: Reply extracted" << std::endl;
  74. // In a real application instead of writing to a buffer we could write to an output file
  75. char *outptr, *result, *tmp;
  76. outptr = result = (char*)calloc(r_extractor->getnbPlaintextReplies(clientside_maxFileBytesize)*r_extractor->getPlaintextReplyBytesize(), sizeof(char));
  77. while (r_extractor->popPlaintextResult(&tmp))
  78. {
  79. memcpy(outptr, tmp, r_extractor->getPlaintextReplyBytesize());
  80. outptr+=r_extractor->getPlaintextReplyBytesize();
  81. free(tmp);
  82. }
  83. // Result is in ... result
  84. /******************************************************************************
  85. * Test correctness
  86. ******************************************************************************/
  87. char *db_element = (char*)calloc(clientside_maxFileBytesize*params.alpha, sizeof(char));
  88. bool fail = false;
  89. db->readAggregatedStream(chosen_element, params.alpha, 0, clientside_maxFileBytesize, db_element);
  90. if (memcmp(result, db_element, clientside_maxFileBytesize*params.alpha))
  91. {
  92. std::cout << "SimplePIR: Test failed, the retrieved element is not correct" << std::endl;
  93. fail = true;
  94. }
  95. else
  96. {
  97. std::cout << "SimplePIR: Test succeeded !!!!!!!!!!!!!!!!!!!!!!!!" << std::endl<< std::endl;
  98. fail = false;
  99. }
  100. /******************************************************************************
  101. * Cleanup
  102. ******************************************************************************/
  103. delete imported_db;
  104. r_generator.freeQueries();
  105. return fail;
  106. }
  107. int main(int argc, char * argv[]) {
  108. uint64_t database_size, nb_files, chosen_element, maxFileBytesize;
  109. PIRParameters params;
  110. bool tests_failed = false;
  111. /******************************************************************************
  112. * Database setup (server-side)
  113. ******************************************************************************/
  114. // To Create the database generator object
  115. // it can be a DBGenerator that simulate nb_files files of size streamBytesize
  116. // database_size = 1ULL<<25; nb_files = 4; maxFileBytesize = database_size/nb_files;
  117. // DBGenerator db(nb_files, maxFileBytesize, /*bool silent*/ false);
  118. //
  119. // OR it can be a DBDirectoryProcessor that reads a real file in the ./db directory
  120. // and splits it into nb_files virtual files
  121. // nb_files = 4;
  122. // DBDirectoryProcessor db(nb_files);
  123. // database_size = db.getDBSizeinbits();maxFileBytesize = database_size/nb_files;
  124. //
  125. // OR it can be a DBDirectoryProcessor that reads the real files in the ./db directory
  126. // DBDirectoryProcessor db;
  127. // nb_files=db.getNbStream();database_size = db.getDBSizeinbits();
  128. // maxFileBytesize = database_size/nb_files;
  129. // Simple test
  130. database_size = 1ULL<<31; nb_files = 20; maxFileBytesize = database_size/nb_files;
  131. DBGenerator db(nb_files, maxFileBytesize, /*bool silent*/ false);
  132. chosen_element = 3;
  133. params.alpha = 1; params.d = 1; params.n[0] = nb_files;
  134. // The crypto parameters can be set to other values
  135. // You can get a list of all available cryptographic parameters with this function call
  136. // HomomorphicCryptoFactory::printAllCryptoParams();
  137. params.crypto_params = "LWE:80:2048:120";
  138. tests_failed |= run(&db, chosen_element, params);
  139. // Test with aggregation
  140. // WARNING we must provide the representation of the database GIVEN recursion and aggregation
  141. // as here we have 100 elements and aggregate them in a unique group we have params.n[0]=1
  142. database_size = 1ULL<<25; nb_files = 100; maxFileBytesize = database_size/nb_files;
  143. DBGenerator db2(nb_files, maxFileBytesize, /*bool silent*/ false);
  144. chosen_element = 0;
  145. params.alpha = 100; params.d = 1; params.n[0] = 1;
  146. params.crypto_params = "LWE:80:2048:120";
  147. tests_failed |= run(&db2, chosen_element, params);
  148. // Test with recursion 2
  149. database_size = 1ULL<<25; nb_files = 100; maxFileBytesize = database_size/nb_files;
  150. DBGenerator db3(nb_files, maxFileBytesize, /*bool silent*/ false);
  151. chosen_element = 3;
  152. params.alpha = 1; params.d = 2; params.n[0] = 50; params.n[1] = 2;
  153. params.crypto_params = "LWE:80:2048:120";
  154. tests_failed |= run(&db3, chosen_element, params);
  155. // Test with recursion 2 and aggregation
  156. database_size = 1ULL<<25; nb_files = 100; maxFileBytesize = database_size/nb_files;
  157. DBGenerator db4(nb_files, maxFileBytesize, /*bool silent*/ false);
  158. chosen_element = 3;
  159. params.alpha = 2; params.d = 2; params.n[0] = 25; params.n[1] = 2;
  160. params.crypto_params = "LWE:80:2048:120";
  161. tests_failed |= run(&db4, chosen_element, params);
  162. // Test with recursion 3
  163. database_size = 1ULL<<25; nb_files = 100; maxFileBytesize = database_size/nb_files;
  164. DBGenerator db5(nb_files, maxFileBytesize, /*bool silent*/ false);
  165. chosen_element = 3;
  166. params.alpha = 1; params.d = 3; params.n[0] = 5; params.n[1] = 5; params.n[2] = 4;
  167. params.crypto_params = "LWE:80:2048:120";
  168. tests_failed |= run(&db5, chosen_element, params);
  169. // Test with a DBDirectoryProcessor splitting a big real file
  170. database_size = 1ULL<<25; nb_files = 4; maxFileBytesize = database_size/nb_files;
  171. DBDirectoryProcessor db6(/*split the bit file in*/ nb_files /*files*/);
  172. chosen_element = 3;
  173. params.alpha = 1; params.d = 1; params.n[0] = nb_files;
  174. params.crypto_params = "LWE:80:2048:120";
  175. tests_failed |= run(&db6, chosen_element, params);
  176. // Test with a DBDirectoryProcessor reading real files
  177. DBDirectoryProcessor db7;
  178. database_size = db7.getDBSizeBits()/8; nb_files = db7.getNbStream();
  179. maxFileBytesize = database_size/nb_files;
  180. chosen_element = 0;
  181. params.alpha = 1; params.d = 1; params.n[0] = nb_files;
  182. params.crypto_params = "LWE:80:2048:120";
  183. tests_failed |= run(&db7, chosen_element, params);
  184. if (tests_failed)
  185. {
  186. std::cout << "WARNING : at least one tests failed" << std::endl;
  187. return 1;
  188. }
  189. else
  190. {
  191. std::cout << "All tests succeeded" << std::endl;
  192. return 0;
  193. }
  194. }