Sfoglia il codice sorgente

Merge pull request #22 from XPIR-team/feature/improvingsimplepir

Feature/improvingsimplepir
Kirija 8 anni fa
parent
commit
71fe08aa7b

+ 11 - 3
apps/server/DBDirectoryProcessor.cpp

@@ -56,11 +56,16 @@ DBDirectoryProcessor::DBDirectoryProcessor() : filesSplitting(false) {
 			}
 		}
 		std::cout << "DBDirectoryProcessor: " << i << " entries processed" << std::endl;
+    if (i==0) {
+      std::cout <<"DBDirectoryProcessor: No entries in the database" << std::endl;
+      error = true;
+    }
 		closedir (dir);
 	}
 	else // If there was a problem opening the directory
 	{
 		std::cout << "DBDirectoryProcessor: Error opening database directory" << std::endl;
+    error = true;
 	}
 
 	std::cout << "DBDirectoryProcessor: The size of the database is " << maxFileBytesize*file_list.size() << " bytes" << std::endl;
@@ -100,7 +105,7 @@ DBDirectoryProcessor::DBDirectoryProcessor(uint64_t nbStreams) : filesSplitting(
 		if(maxFileBytesize==0) {
 			std::cout << "DBDirectoryProcessor: ERROR cannot split a file en less than one byte elements!" << std::endl;
 			std::cout << "DBDirectoryProcessor: file " << realFileName << " is only "<< realFileSize << " long" << std::endl;
-			exit(1);
+			error = true;
 		}
 
 		closedir (dir);
@@ -111,8 +116,8 @@ DBDirectoryProcessor::DBDirectoryProcessor(uint64_t nbStreams) : filesSplitting(
 	else // If there was a problem opening the directory
 	{
 		std::cout << "DBDirectoryProcessor: Error when opening directory " <<directory<< std::endl;
-		exit(1);
-	}
+	  error = true;
+  }
 
 #ifdef DEBUG
 	std::cout << "maxFileBytesize." <<maxFileBytesize<< std::endl;
@@ -161,6 +166,9 @@ uint64_t DBDirectoryProcessor::getNbStream() {
 uint64_t DBDirectoryProcessor::getmaxFileBytesize() {
 	return maxFileBytesize;
 }
+bool DBDirectoryProcessor::getErrorStatus() {
+	return error;
+}
 
 std::ifstream* DBDirectoryProcessor::openStream(uint64_t streamNb, uint64_t requested_offset) {
 	std::string local_directory(DEFAULT_DIR_NAME);

+ 2 - 0
apps/server/DBDirectoryProcessor.hpp

@@ -40,6 +40,7 @@ private:
   std::vector<std::ifstream*> fdPool; // a pool of file descriptors
   std::vector <std::string> file_list; // the output file list
   bool filesSplitting;
+  bool error = false;
   std::string realFileName; // The name of the unique file in case of splitting
   
 public:
@@ -52,6 +53,7 @@ public:
   uint64_t getDBSizeBits();
   uint64_t getNbStream();
   uint64_t getmaxFileBytesize();
+  bool getErrorStatus();
   
   std::ifstream* openStream(uint64_t streamNb, uint64_t requested_offset);
   uint64_t readStream(std::ifstream* s,char * buf, uint64_t size);

+ 2 - 2
apps/server/PIRSession.cpp

@@ -406,10 +406,8 @@ void PIRSession::downloadWorker()
     {
       // Compute and allocate the size in bytes of a query ELEMENT of dimension j 
       msg_size = cryptoMethod->getPublicParameters().getQuerySizeFromRecLvl(j+1) / 8;
-      char buf[msg_size];
       boost::asio::socket_base::receive_buffer_size opt(65535);
       sessionSocket.set_option(opt);
-      auto boost_buffer = boost::asio::buffer(buf,msg_size); 
   //    boost_buffer = new boost::asio::buffer(buf, msg_size); 
 #ifdef DEBUG
       cout << "PIRSession: Size of the query element to be received is " << msg_size << endl;
@@ -419,6 +417,8 @@ void PIRSession::downloadWorker()
       // Iterate over all the elements of the query corresponding to the j-th dimension
       for (unsigned int i = 0; i < pirParam.n[j]; i++)
       {
+      char *buf = (char *) malloc(msg_size*sizeof(char));
+      auto boost_buffer = boost::asio::buffer(buf,msg_size); 
       if (i==0 && j == 0) cout << "PIRSession: Waiting for query elements ..." << endl;
         // Get a query element 
          //( async_read(sessionSocket, boost_buffer,boost::bind(&blo,boost::asio::placeholders::error)) );

+ 107 - 24
apps/simplepir/simplePIR.cpp

@@ -38,8 +38,9 @@ bool run(DBHandler *db, uint64_t chosen_element, PIRParameters params){
   ******************************************************************************/
 	
   // Create the reply generator object
-	PIRReplyGenerator r_generator(params,*crypto,db);
-  r_generator.setPirParams(params);
+  // We could have also defined PIRReplyGenerator *r_generator(params,*crypto,db);
+  // But we prefer a pointer to show (below) how to use multiple generators for a given db
+  PIRReplyGenerator *r_generator = new PIRReplyGenerator(params,*crypto,db);
 
   // In a real application the client would pop the queries from q with popQuery and 
   // send them through the network and the server would receive and push them into s 
@@ -47,7 +48,7 @@ bool run(DBHandler *db, uint64_t chosen_element, PIRParameters params){
   char* query_element;
   while (q_generator.popQuery(&query_element))
   {
-    r_generator.pushQuery(query_element);
+    r_generator->pushQuery(query_element);
   }
  
 	// Import database
@@ -58,18 +59,61 @@ bool run(DBHandler *db, uint64_t chosen_element, PIRParameters params){
   std::cout << "SimplePIR: Importing database ..." << std::endl;
   // Warning aggregation is dealt with internally the bytes_per_db_element parameter here
   // is to be given WITHOUT multiplying it by params.alpha
-	imported_database* imported_db = r_generator.importData(/* uint64_t offset*/ 0, /*uint64_t
+	imported_database* imported_db = r_generator->importData(/* uint64_t offset*/ 0, /*uint64_t
     bytes_per_db_element */ db->getmaxFileBytesize());
   std::cout << "SimplePIR: Database imported" << std::endl;
 
 	// Once the query is known and the database imported launch the reply generation
   std::cout << "SimplePIR: Generating reply ..." << std::endl;
 	double start = omp_get_wtime();
-	r_generator.generateReply(imported_db);
+	r_generator->generateReply(imported_db);
 	double end = omp_get_wtime();
   std::cout << "SimplePIR: Reply generated in " << end-start << " seconds" << std::endl;
+
+  /********************************************************************************
+   * Advanced example: uncomment it to test
+   * The object imported_db is separated from r_generator in purpose
+   * Here is an example on how to use the same imported_db for multiple queries
+   * DO NOT try to use the same reply generator more than once, this causes issues
+   * ******************************************************************************/
+
+#if 0
+  // Generate 3 replies from 3 queries
+  for (int i = 0 ; i < 3 ; i++){
+
+    // Pop (and drop for this simple example) the generated reply
+    char* reply_element_tmp;
+    while (r_generator->popReply(&reply_element_tmp)){
+      free(reply_element_tmp);
+    }
+
+    // If you are unable to reuse a r_generator object (e.g. if you want 
+    // to change the crypto object) you can always recreate a new generator
+    //delete r_generator;
+    //r_generator = new PIRReplyGenerator(params,*crypto,db);
 	
+    // In this example we want to use the same generator for 
+    // multiply queries. Before giving a new query to r_generator
+    // we must free the previous one. 
+    r_generator->freeQueries();
+    
+    // It is also possible to change the pir parameters with the 
+    // (unexposed) setPirParams(PIRParameters newparams) function
 	
+    // Generate a new query
+  	q_generator.generateQuery(chosen_element);
+
+    // Push it to the reply generator
+    while (q_generator.popQuery(&query_element))
+    {
+      r_generator->pushQuery(query_element);
+    }
+
+    // Generate again the reply
+	  r_generator->generateReply(imported_db);
+  }
+#endif
+
 
   /******************************************************************************
   * Reply extraction phase (client-side)
@@ -80,11 +124,11 @@ bool run(DBHandler *db, uint64_t chosen_element, PIRParameters params){
   // In a real application the server would pop the replies from s with popReply and 
   // send them through the network together with nbRepliesGenerated and aggregated_maxFileSize 
   // and the client would receive the replies and push them into r using pushEncryptedReply
-  std::cout << "SimplePIR: "<< r_generator.getnbRepliesGenerated()<< " Replies generated " << std::endl;
+  std::cout << "SimplePIR: "<< r_generator->getnbRepliesGenerated()<< " Replies generated " << std::endl;
 
   uint64_t clientside_maxFileBytesize = db->getmaxFileBytesize();
   char* reply_element;
-  while (r_generator.popReply(&reply_element))
+  while (r_generator->popReply(&reply_element))
   {
     r_extractor->pushEncryptedReply(reply_element);
   }
@@ -102,7 +146,7 @@ bool run(DBHandler *db, uint64_t chosen_element, PIRParameters params){
     outptr+=r_extractor->getPlaintextReplyBytesize();
     free(tmp);
   }
-  // Result is in ... result  
+  // Result is in ... result
   
 
   /******************************************************************************
@@ -127,8 +171,12 @@ bool run(DBHandler *db, uint64_t chosen_element, PIRParameters params){
   ******************************************************************************/
   
   delete imported_db;
-  r_generator.freeQueries();
-
+  r_generator->freeQueries();
+  delete r_generator;
+  delete r_extractor;
+  delete crypto;
+  free(result);
+  free(db_element);
   
 	return fail;
 	
@@ -163,7 +211,11 @@ int main(int argc, char * argv[]) {
   // maxFileBytesize = database_size/nb_files;
 
   // Simple test
-  database_size = 1ULL<<31; nb_files = 20; maxFileBytesize = database_size/nb_files;
+  std::cout << "======================================================================" << std::endl;
+  std::cout << "Test 1/7: database_size = 1ULL<<30; nb_files = 20;" << std::endl;
+  std::cout << "params.alpha = 1; params.d = 1; crypto_params = LWE:80:2048:120;" << std::endl; 
+  std::cout << "======================================================================" << std::endl;
+  database_size = 1ULL<<20; nb_files = 20; maxFileBytesize = database_size/nb_files;
   DBGenerator db(nb_files, maxFileBytesize, /*bool silent*/ false); 
   chosen_element = 3;
   params.alpha = 1; params.d = 1; params.n[0] = nb_files; 
@@ -173,18 +225,25 @@ int main(int argc, char * argv[]) {
   params.crypto_params = "LWE:80:2048:120"; 
   tests_failed |= run(&db, chosen_element, params);
   
- 
   // Test with aggregation
   // WARNING we must provide the representation of the database GIVEN recursion and aggregation
   // as here we have 100 elements and aggregate them in a unique group we have params.n[0]=1
+  std::cout << "======================================================================" << std::endl;
+  std::cout << "Test 2/7: database_size = 1ULL<<25; nb_files = 100;" << std::endl;
+  std::cout << "params.alpha = 100; params.d = 1; crypto_params = LWE:80:2048:120;" << std::endl; 
+  std::cout << "======================================================================" << std::endl;
   database_size = 1ULL<<25; nb_files = 100; maxFileBytesize = database_size/nb_files;
   DBGenerator db2(nb_files, maxFileBytesize, /*bool silent*/ false); 
   chosen_element = 0;
   params.alpha = 100; params.d = 1; params.n[0] = 1; 
   params.crypto_params = "LWE:80:2048:120";
   tests_failed |= run(&db2, chosen_element, params);
-  
+
   // Test with recursion 2
+  std::cout << "======================================================================" << std::endl;
+  std::cout << "Test 3/7: database_size = 1ULL<<25; nb_files = 100;" << std::endl;
+  std::cout << "params.alpha = 1; params.d = 2; crypto_params = LWE:80:2048:120;" << std::endl; 
+  std::cout << "======================================================================" << std::endl;
   database_size = 1ULL<<25; nb_files = 100; maxFileBytesize = database_size/nb_files;
   DBGenerator db3(nb_files, maxFileBytesize, /*bool silent*/ false); 
   chosen_element = 3;
@@ -193,6 +252,10 @@ int main(int argc, char * argv[]) {
   tests_failed |= run(&db3, chosen_element, params);
   
   // Test with recursion 2 and aggregation
+  std::cout << "======================================================================" << std::endl;
+  std::cout << "Test 4/7: database_size = 1ULL<<25; nb_files = 100;" << std::endl;
+  std::cout << "params.alpha = 2; params.d = 2; crypto_params = LWE:80:2048:120;" << std::endl; 
+  std::cout << "======================================================================" << std::endl;
   database_size = 1ULL<<25; nb_files = 100; maxFileBytesize = database_size/nb_files;
   DBGenerator db4(nb_files, maxFileBytesize, /*bool silent*/ false); 
   chosen_element = 3;
@@ -201,6 +264,10 @@ int main(int argc, char * argv[]) {
   tests_failed |= run(&db4, chosen_element, params);
   
   // Test with recursion 3
+  std::cout << "======================================================================" << std::endl;
+  std::cout << "Test 5/7: database_size = 1ULL<<25; nb_files = 100;" << std::endl;
+  std::cout << "params.alpha = 1; params.d = 3; crypto_params = LWE:80:2048:120;" << std::endl; 
+  std::cout << "======================================================================" << std::endl;
   database_size = 1ULL<<25; nb_files = 100; maxFileBytesize = database_size/nb_files;
   DBGenerator db5(nb_files, maxFileBytesize, /*bool silent*/ false); 
   chosen_element = 3;
@@ -209,21 +276,37 @@ int main(int argc, char * argv[]) {
   tests_failed |= run(&db5, chosen_element, params);
   
   // Test with a DBDirectoryProcessor splitting a big real file
+  std::cout << "======================================================================" << std::endl;
+  std::cout << "Test 6/7: DBDirectoryProcessor with split; database_size = 1ULL<<25; nb_files = 4;" << std::endl;
+  std::cout << "params.alpha = 1; params.d = 1; crypto_params = LWE:80:2048:120;" << std::endl; 
+  std::cout << "======================================================================" << std::endl;
   database_size = 1ULL<<25; nb_files = 4; maxFileBytesize = database_size/nb_files;
-  DBDirectoryProcessor db6(/*split the bit file in*/ nb_files /*files*/);
-  chosen_element = 3;
-  params.alpha = 1; params.d = 1; params.n[0] = nb_files; 
-  params.crypto_params = "LWE:80:2048:120";
-  tests_failed |= run(&db6, chosen_element, params);
+  DBDirectoryProcessor db6(/*split the first file in*/ nb_files /*files*/);
+  if (db6.getErrorStatus()==true){
+    std::cout << "SimplePIR : Error with db directory skipping test ..." << std::endl << std::endl;
+  } else {
+    chosen_element = 3;
+    params.alpha = 1; params.d = 1; params.n[0] = nb_files; 
+    params.crypto_params = "LWE:80:2048:120";
+    tests_failed |= run(&db6, chosen_element, params);
+  }
   
   // Test with a DBDirectoryProcessor reading real files
+  std::cout << "======================================================================" << std::endl;
+  std::cout << "Test 7/7: DBDirectoryProcessor without split;" << std::endl;
+  std::cout << "params.alpha = 1; params.d = 1; crypto_params = LWE:80:2048:120;" << std::endl; 
+  std::cout << "======================================================================" << std::endl;
   DBDirectoryProcessor db7;
-  database_size = db7.getDBSizeBits()/8; nb_files = db7.getNbStream(); 
-  maxFileBytesize = database_size/nb_files;
-  chosen_element = 0;
-  params.alpha = 1; params.d = 1; params.n[0] = nb_files; 
-  params.crypto_params = "LWE:80:2048:120";
-  tests_failed |= run(&db7, chosen_element, params);
+  if (db6.getErrorStatus()==true){
+    std::cout << "SimplePIR : Error with db directory skipping test ..." << std::endl << std::endl;
+  } else {
+    database_size = db7.getDBSizeBits()/8; nb_files = db7.getNbStream(); 
+    maxFileBytesize = database_size/nb_files;
+    chosen_element = 0;
+    params.alpha = 1; params.d = 1; params.n[0] = nb_files; 
+    params.crypto_params = "LWE:80:2048:120";
+    tests_failed |= run(&db7, chosen_element, params);
+  }
 
   if (tests_failed) 
   {

+ 1 - 1
apps/tools/check-correctness.sh

@@ -207,7 +207,7 @@ if [[ TEST_PAILLIER -eq 1 ]]; then
     done
 fi
 
-echo -e "\n\nOther tests\n#################\n"
+echo -e "\n\nTests\n#################\n"
 for DB in $ONE_MBIT $TEN_MBIT $HUNDRED_MBIT $ONE_GBIT 
 do
     for L in $ONE_KBIT $HUNDRED_KBIT $TEN_MBIT $ONE_GBIT  

+ 7 - 0
apps/tools/mkdb-correctness.sh

@@ -27,6 +27,13 @@ ONE_GBIT=1024000000
 #files: 1kbits, 100kbits, 10mbits 1gbit
 #bases: 1Mbits, 10M, 100M, 1G, 10G
 
+if [[ -e check.repo ]]
+then
+  echo "Check repo exists, not rebuilding it"
+  echo "Remove it manually if you want it rebuilt"
+  exit 
+fi
+
 mkdir check.repo
 cp -r ../client/exp check.repo/
 cd check.repo

+ 4 - 2
crypto/NFLLWE.cpp

@@ -383,7 +383,7 @@ void NFLLWE::dec(poly64 m, lwe_cipher *c)
 
 	for(unsigned short currentModulus=0;currentModulus<nbModuli;currentModulus++) {
 
-		// We firs% moduli[cm] t get the amplified noise plus message (e*A+m =b-a*S)
+		// We first get the amplified noise plus message (e*A+m =b-a*S)
 		for (unsigned int i=0 ; i < polyDegree; i++) 
     {
 			uint64_t temp=0;
@@ -449,7 +449,9 @@ void NFLLWE::dec(poly64 m, lwe_cipher *c)
 	    mpz_clear(tmprez[i]);
 	  }
     
-    free(tmprez);
+    delete[] tmprez;
+    mpz_clears(moduliProduct, tmpz, magicConstz, bitmaskz, NULL);
+
   } else { // nbModuli=1
 	
 	

+ 4 - 3
crypto/NFLlib.cpp

@@ -59,6 +59,7 @@ void  NFLlib::configureNTT()
   shoupinvomegas = (uint64_t **) malloc(nbModuli * sizeof(uint64_t *));  
   invpolyDegree = (uint64_t *) malloc(nbModuli * sizeof(uint64_t));
   liftingIntegers = new mpz_t[nbModuli];
+  moduli=new uint64_t[nbModuli]();
 
   // From now on, we have to do everything nbModuli times
   for(unsigned short currentModulus=0;currentModulus<nbModuli;currentModulus++) 
@@ -223,8 +224,6 @@ void NFLlib::setmodulus(uint64_t aggregatedModulusBitsize_)
   }
   nbModuli=aggregatedModulusBitsize_/kModulusBitsize;
   
-  moduli=new uint64_t[nbModuli]();
-
   configureNTT();
 }
 
@@ -528,7 +527,7 @@ mpz_t* NFLlib::poly2mpz(poly64 p)
   for(int cm = 0; cm < nbModuli;cm++) {
   	mpz_clear(tmpzbuffer[cm]);
   	}
-  free(tmpzbuffer);
+  delete[] tmpzbuffer;
   return resultmpz;
 }
 
@@ -584,6 +583,7 @@ void NFLlib::freeNTTMemory(){
   
     if (i == alreadyInit - 1)
     {
+      free(phis);
       free(shoupphis);
       free(invpoly_times_invphis);
       free(shoupinvpoly_times_invphis);
@@ -595,6 +595,7 @@ void NFLlib::freeNTTMemory(){
       delete[] liftingIntegers;
       free(inv_indexes);
       mpz_clear(moduliProduct);
+      delete[] moduli;
     }
   }
 

+ 17 - 10
pir/libpir.cpp

@@ -7,6 +7,7 @@
       free(((lwe_in_data *)imported_database_ptr)[i].p[0]);
       free(((lwe_in_data *)imported_database_ptr)[i].p);
     }
+    free(imported_database_ptr);
   }
 
 
@@ -77,14 +78,10 @@
 	: PIRReplyGeneratorNFL_internal (param,db)
   {
 		PIRReplyGeneratorNFL_internal::setCryptoMethod(&cryptoMethod_);
-    PIRReplyGeneratorNFL_internal::initQueriesBuffer();
     PIRReplyGeneratorNFL_internal::setPirParams(param);
-		nbRepliesToHandle=0;
-		nbRepliesGenerated=0;
-		currentReply=0;
   }
 
-
+  
   void PIRReplyGenerator::pushQuery(char* rawQuery) {
 		PIRReplyGeneratorNFL_internal::pushQuery(rawQuery);
   }
@@ -108,6 +105,13 @@
 
   void PIRReplyGenerator::generateReply(const imported_database* database)
   {
+    // Init
+		nbRepliesToHandle=0;
+		nbRepliesGenerated=0;
+		currentReply=0;
+    freeResult();
+
+    // Test memory
 		uint64_t usable_memory = getTotalSystemMemory();
 		nbRepliesGenerated=nbRepliesToHandle=computeReplySizeInChunks(database->beforeImportElementBytesize);
 		uint64_t polysize = cryptoMethod->getpolyDegree() * cryptoMethod->getnbModuli()*sizeof(uint64_t);
@@ -119,16 +123,19 @@
 		input_data = (lwe_in_data*) database->imported_database_ptr;
 		currentMaxNbPolys = database->polysPerElement;
     
-   		// The internal generator is locked by default waiting for the query to be received 
-    		// in this API we let the user deal with synchronisation so the lock is not needed
-    		PIRReplyGeneratorNFL_internal::mutex.unlock();
-
+   	// The internal generator is locked by default waiting for the query to be received 
+    // in this API we let the user deal with synchronisation so the lock is not needed
+    PIRReplyGeneratorNFL_internal::mutex.try_lock();
+    PIRReplyGeneratorNFL_internal::mutex.unlock();
+    
+    // Define the reply size
+    repliesAmount = computeReplySizeInChunks(database->beforeImportElementBytesize);
 		PIRReplyGeneratorNFL_internal::generateReply();
 
   }
 
   void PIRReplyGenerator::freeQueries(){
-    freeQuery();
+    PIRReplyGeneratorNFL_internal::freeQueries();
   }
 
 	

+ 7 - 7
pir/queryGen/PIRQueryGenerator_internal.cpp

@@ -50,6 +50,7 @@ void PIRQueryGenerator_internal::generateQuery()
   std::cout << "PIRQueryGenerator_internal: Generated a " << pirParams.n[j] << " element query" << std::endl;
   }
   double end = omp_get_wtime();
+  delete[] coord;
   
   std::cout << "PIRQueryGenerator_internal: All the queries have been generated, total time is " << end - start << " seconds" << std::endl;
 }
@@ -99,16 +100,15 @@ void PIRQueryGenerator_internal::joinThread()
 	if(queryThread.joinable()) queryThread.join();
 }
 
+void PIRQueryGenerator_internal::cleanQueryBuffer()
+{
+	while (!queryBuffer.empty())
+		free(queryBuffer.pop_front());
+}
+
 PIRQueryGenerator_internal::~PIRQueryGenerator_internal() 
 {
 	joinThread();
   cleanQueryBuffer();
-
-	delete[] coord;
 }
 
-void PIRQueryGenerator_internal::cleanQueryBuffer()
-{
-	while (!queryBuffer.empty())
-		free(queryBuffer.pop_front());
-}

+ 4 - 1
pir/replyGenerator/GenericPIRReplyGenerator.cpp

@@ -23,7 +23,10 @@ GenericPIRReplyGenerator::GenericPIRReplyGenerator():
   repliesAmount(0),
   repliesIndex(0)
 {
-   mutex.lock();
+  pirParam.d = 0;
+  pirParam.alpha = 0;
+  for (int i = 0 ; i < MAX_REC_LVL; i++) pirParam.n[i] = 0;
+  mutex.lock();
 }
 
 GenericPIRReplyGenerator::GenericPIRReplyGenerator(PIRParameters& param, DBHandler *db):

+ 4 - 4
pir/replyGenerator/GenericPIRReplyGenerator.hpp

@@ -45,16 +45,16 @@ class GenericPIRReplyGenerator
   protected:
 
 
-    PIRParameters emptyPIRParams;
+  PIRParameters emptyPIRParams;
 	PIRParameters& pirParam;
 	unsigned int maxChunkSize;
 	DBHandler* dbhandler;	
 
   public:
 	boost::mutex mutex;
-  char** repliesArray;
-  unsigned repliesAmount;
-  unsigned repliesIndex;
+  char** repliesArray = NULL;
+  unsigned repliesAmount = 0;
+  unsigned repliesIndex = 0;
 
     GenericPIRReplyGenerator();
     GenericPIRReplyGenerator(PIRParameters& param, DBHandler* db);

+ 3 - 1
pir/replyGenerator/PIROptimizer.cpp

@@ -145,7 +145,7 @@ std::string PIROptimizer::computeOptimData(const std::string& crypto_name)
     optim_data2write += crypto_param + " " + out.str() + "\n";
   }
 
-  //delete generator_ptr;
+  delete generator_ptr;
   delete crypto_ptr;
 
   return optim_data2write;
@@ -166,6 +166,7 @@ double PIROptimizer::getAbs1PlaintextTime(HomomorphicCrypto* crypto_ptr, Generic
 
   do
   {
+    generator->mutex.try_lock();
     generator->mutex.unlock();
     result = generator->generateReplySimulation(pir_params, plaintext_nbr);
     plaintext_nbr *= 2;
@@ -193,6 +194,7 @@ double PIROptimizer::getPrecompute1PlaintextTime(HomomorphicCrypto* crypto_ptr,
 
   do
   {
+    generator->mutex.try_lock();
     generator->mutex.unlock();
     result = generator->precomputationSimulation(pir_params, plaintext_nbr);
     plaintext_nbr *= 2;

+ 129 - 68
pir/replyGenerator/PIRReplyGeneratorNFL_internal.cpp

@@ -27,7 +27,13 @@
 //#define SNIFFER //Use this to activate a sniffer like behavior
 
 PIRReplyGeneratorNFL_internal::PIRReplyGeneratorNFL_internal():
-  lwe(false)
+  lwe(false),
+  currentMaxNbPolys(0),
+  queriesBuf(NULL),
+  current_query_index(0),
+  current_dim_index(0),
+  input_data(NULL),
+  cryptoMethod(NULL)
 {
 }
 
@@ -39,10 +45,13 @@ PIRReplyGeneratorNFL_internal::PIRReplyGeneratorNFL_internal():
  **/
 PIRReplyGeneratorNFL_internal::PIRReplyGeneratorNFL_internal( PIRParameters& param, DBHandler* db):
   lwe(false),
-  currentMaxNbPolys(1),
+  currentMaxNbPolys(0),
   GenericPIRReplyGenerator(param,db),
+  queriesBuf(NULL),
   current_query_index(0),
-  current_dim_index(0)
+  current_dim_index(0),
+  input_data(NULL),
+  cryptoMethod(NULL)
 {
   // cryptoMethod will be set later by setCryptoMethod
 }
@@ -75,7 +84,7 @@ void PIRReplyGeneratorNFL_internal::importDataNFL(uint64_t offset, uint64_t byte
 
 	for (unsigned int i = 0 ; i < pirParam.d ; i++) theoretical_files_nbr *= pirParam.n[i];
 
-	input_data = new lwe_in_data[theoretical_files_nbr];
+	input_data = (lwe_in_data *) malloc(sizeof(lwe_in_data)*theoretical_files_nbr);
 	char *rawBits = (char*)calloc(fileByteSize*pirParam.alpha, sizeof(char));
 
 	currentMaxNbPolys=0;
@@ -201,6 +210,7 @@ imported_database_t PIRReplyGeneratorNFL_internal::generateReplyGeneric(bool kee
   database_wrapper.imported_database_ptr = NULL;
   database_wrapper.nbElements = 0;
   database_wrapper.polysPerElement = 0;
+  database_wrapper.beforeImportElementBytesize = 0;
 
   // Don't use more than half of the computer's memory 
   usable_memory = getTotalSystemMemory()/2;
@@ -258,12 +268,15 @@ imported_database_t PIRReplyGeneratorNFL_internal::generateReplyGeneric(bool kee
     } 
 
     boost::mutex::scoped_lock l(mutex);
+    repliesAmount = computeReplySizeInChunks(dbhandler->getmaxFileBytesize());
     generateReply();
     end = omp_get_wtime();
 
     if(keep_imported_data && iteration == nbr_of_iterations - 1)  // && added for Perf test but is no harmful
     {
       database_wrapper.imported_database_ptr = (void*)input_data;
+      database_wrapper.beforeImportElementBytesize = dbhandler->getmaxFileBytesize();
+      database_wrapper.nbElements = dbhandler->getNbStream();
     } 
     else 
     {
@@ -274,7 +287,7 @@ imported_database_t PIRReplyGeneratorNFL_internal::generateReplyGeneric(bool kee
 	std::cout<<"PIRReplyGeneratorNFL_internal: Total process time " << end - start << " seconds" << std::endl;
 	std::cout<<"PIRReplyGeneratorNFL_internal: DB processing throughput " << 8*database_size/(end - start) << "bps" << std::endl;
 	std::cout<<"PIRReplyGeneratorNFL_internal: Client cleartext reception throughput  " << 8*dbhandler->getmaxFileBytesize()/(end - start) << "bps" << std::endl;
-  freeQuery();
+  freeQueries();
 
   return database_wrapper;
 }
@@ -289,16 +302,17 @@ void PIRReplyGeneratorNFL_internal::generateReplyGenericFromData(const imported_
   currentMaxNbPolys = database.polysPerElement;
 	boost::mutex::scoped_lock l(mutex);
   double start = omp_get_wtime();
+  repliesAmount = computeReplySizeInChunks(database.beforeImportElementBytesize);
   generateReply();
 #else
   uint64_t max_readable_size, database_size, nbr_of_iterations;
 
-  database_size = dbhandler->getmaxFileBytesize() * dbhandler->getNbStream();
-  max_readable_size = 1280000000UL/dbhandler->getNbStream();
+  database_size = database.beforeImportElementBytesize * database.nbElements;
+  max_readable_size = 1280000000UL/database.nbElements;
   // Ensure it is not larger than maxfilebytesize
-  max_readable_size = min(max_readable_size, dbhandler->getmaxFileBytesize());
+  max_readable_size = min(max_readable_size, database.beforeImportElementBytesize);
   // Given readable size we get how many iterations we need
-  nbr_of_iterations = ceil((double)dbhandler->getmaxFileBytesize()/max_readable_size);
+  nbr_of_iterations = ceil((double)database.beforeImportElementBytesize/max_readable_size);
 
 
   boost::mutex::scoped_lock l(mutex);
@@ -308,6 +322,7 @@ void PIRReplyGeneratorNFL_internal::generateReplyGenericFromData(const imported_
 
     input_data = (lwe_in_data*) database.imported_database_ptr;
     currentMaxNbPolys = database.polysPerElement;
+    repliesAmount = computeReplySizeInChunks(database.beforeImportElementBytesize);
     generateReply();
   }
   freeInputData();
@@ -316,7 +331,7 @@ void PIRReplyGeneratorNFL_internal::generateReplyGenericFromData(const imported_
 	std::cout<<"PIRReplyGeneratorNFL_internal: Total process time " << end - start << " seconds" << std::endl;
 	std::cout<<"PIRReplyGeneratorNFL_internal: DB processing throughput " << 8*dbhandler->getmaxFileBytesize()*dbhandler->getNbStream()/(end - start) << "bps" << std::endl;
 	std::cout<<"PIRReplyGeneratorNFL_internal: Client cleartext reception throughput  " << 8*dbhandler->getmaxFileBytesize()/(end - start) << "bps" << std::endl;
-  freeQuery();
+  freeQueries();
 }
 
 
@@ -327,12 +342,12 @@ void PIRReplyGeneratorNFL_internal::generateReplyExternal(imported_database_t* d
 {
   uint64_t max_readable_size, database_size, nbr_of_iterations;
 
-  database_size = dbhandler->getmaxFileBytesize() * dbhandler->getNbStream();
-  max_readable_size = 1280000000UL/dbhandler->getNbStream();
+  database_size = database->beforeImportElementBytesize * database->nbElements;
+  max_readable_size = 1280000000UL/database->nbElements;
   // Ensure it is not larger than maxfilebytesize
-  max_readable_size = min(max_readable_size, dbhandler->getmaxFileBytesize());
+  max_readable_size = min(max_readable_size, database->beforeImportElementBytesize);
   // Given readable size we get how many iterations we need
-  nbr_of_iterations = ceil((double)dbhandler->getmaxFileBytesize()/max_readable_size);
+  nbr_of_iterations = ceil((double)database->beforeImportElementBytesize/max_readable_size);
 
 
   boost::mutex::scoped_lock l(mutex);
@@ -342,6 +357,7 @@ void PIRReplyGeneratorNFL_internal::generateReplyExternal(imported_database_t* d
 
     input_data = (lwe_in_data*) database->imported_database_ptr;
     currentMaxNbPolys = database->polysPerElement;
+    repliesAmount = computeReplySizeInChunks(database->beforeImportElementBytesize);
     generateReply();
   }
   freeInputData();
@@ -349,7 +365,7 @@ void PIRReplyGeneratorNFL_internal::generateReplyExternal(imported_database_t* d
 	std::cout<<"PIRReplyGeneratorNFL_internal: Total process time " << end - start << " seconds" << std::endl;
 	std::cout<<"PIRReplyGeneratorNFL_internal: DB processing throughput " << 8*dbhandler->getmaxFileBytesize()*dbhandler->getNbStream()/(end - start) << "bps" << std::endl;
 	std::cout<<"PIRReplyGeneratorNFL_internal: Client cleartext reception throughput  " << 8*dbhandler->getmaxFileBytesize()/(end - start) << "bps" << std::endl;
-  freeQuery();
+  freeQueries();
 }
 
 
@@ -370,7 +386,7 @@ void PIRReplyGeneratorNFL_internal::generateReply()
   uint64_t old_poly_nbr = 1;
   
   // Allocate memory for the reply array
-  repliesAmount = computeReplySizeInChunks(dbhandler->getmaxFileBytesize());
+  if (repliesArray != NULL) freeResult();
   repliesArray = (char**)calloc(repliesAmount,sizeof(char*)); 
 
 
@@ -427,29 +443,36 @@ void PIRReplyGeneratorNFL_internal::generateReply()
     /*****************/
     /*MEMORY CLEANING*/
     /*****************/
+#ifdef DEBUG
     if ( i > 0)
     {
-#ifdef DEBUG
       cout << "PIRReplyGeneratorNFL_internal: reply_elt_nbr_OLD: " << old_reply_elt_nbr << endl;
+    }
 #endif
-     // for (unsigned int j = 0 ; j < old_reply_elt_nbr ; j++) {
-     //   free(in_data[j].p[0]);
-     //   free(in_data[j].p);
-     // }
-     // delete[] in_data;
+    if (i > 0)
+    { 
+      for (int j = 0 ; j < old_reply_elt_nbr ; j++)
+      {
+        free(in_data[j].p[0]);
+        free(in_data[j].p);
+      }
+      delete[] in_data;
     }
-  // When i i=> 2 clean old in_data.
     if (i < pirParam.d - 1) { 
       old_poly_nbr = currentMaxNbPolys;
       in_data = fromResulttoInData(inter_reply, reply_elt_nbr, i);
     }
 
     for (uint64_t j = 0 ; j < reply_elt_nbr ; j++) {
-      for (uint64_t k = 0 ; (k < old_poly_nbr) && (i < pirParam.d - 1); k++) free(inter_reply[j][k].a);
-
+      for (uint64_t k = 0 ; (k < old_poly_nbr) && (i < pirParam.d - 1); k++){
+        free(inter_reply[j][k].a);
+        inter_reply[j][k].a = NULL;
+      }
       delete[] inter_reply[j];
+      inter_reply[j] = NULL;
     }
     delete[] inter_reply; // allocated with a 'new' above. 
+    inter_reply = NULL;
   }
 
   // Compute execution time
@@ -459,22 +482,21 @@ void PIRReplyGeneratorNFL_internal::generateReply()
 
 double PIRReplyGeneratorNFL_internal::generateReplySimulation(const PIRParameters& pir_params, uint64_t plaintext_nbr)
 {
+  
   setPirParams((PIRParameters&)pir_params);
-  initQueriesBuffer();
   pushFakeQuery();
   
   importFakeData(plaintext_nbr);
 
 
-  uint64_t repliesAmount = computeReplySizeInChunks(cryptoMethod->getPublicParameters().getCiphertextBitsize() / CHAR_BIT);
-  repliesArray = (char**)calloc(repliesAmount,sizeof(char*)); 
+  repliesAmount = computeReplySizeInChunks(plaintext_nbr*cryptoMethod->getPublicParameters().getCiphertextBitsize() / CHAR_BIT);
 	repliesIndex = 0;
 
   double start = omp_get_wtime();
   generateReply();
   double result = omp_get_wtime() - start;
 
-  freeQuery();
+  freeQueries();
   freeInputData();
   freeResult();
   delete dbhandler;
@@ -487,7 +509,6 @@ double PIRReplyGeneratorNFL_internal::precomputationSimulation(const PIRParamete
 {
   NFLlib *nflptr = &(cryptoMethod->getnflInstance());
   setPirParams((PIRParameters&)pir_params);
-  initQueriesBuffer();
   pushFakeQuery();
   importFakeData(plaintext_nbr);
 
@@ -497,15 +518,14 @@ double PIRReplyGeneratorNFL_internal::precomputationSimulation(const PIRParamete
   double start = omp_get_wtime();
   for (unsigned int i = 0 ; i < files_nbr ; i++)
   {
-    {
       poly64 *tmp;
-      tmp= cryptoMethod->deserializeDataNFL((unsigned char**)(input_data[i].p), (uint64_t) plaintext_nbr, cryptoMethod->getPublicParameters().getCiphertextBitsize()/2 , input_data[i].nbPolys);
+      tmp = cryptoMethod->deserializeDataNFL((unsigned char**)(input_data[i].p), (uint64_t) plaintext_nbr, cryptoMethod->getPublicParameters().getCiphertextBitsize()/2 , input_data[i].nbPolys);
 	    free(tmp[0]);	
-    }
+      tmp = NULL;
   }
   double result = omp_get_wtime() - start;
   std::cout << "PIRReplyGeneratorNFL_internal: Deserialize took " << result << " (omp)seconds" << std::endl;
-  freeQuery();
+  freeQueries();
   freeInputData();
   freeResult();
   delete dbhandler;
@@ -633,9 +653,7 @@ lwe_in_data* PIRReplyGeneratorNFL_internal::fromResulttoInData(lwe_cipher** inte
 														currentMaxNbPolys, 
 														cryptoMethod->getPublicParameters().getCiphertextBitsize(), 
 														in_data2b[i].nbPolys);
-      //delete[] inter_reply[i]; free in generateReplyGeneric
     }
-    //delete[] inter_reply;
     free(bufferOfBuffers);
 
     currentMaxNbPolys = in_data2b_nbr_polys; 
@@ -725,15 +743,14 @@ void PIRReplyGeneratorNFL_internal::initQueriesBuffer() {
 
 void PIRReplyGeneratorNFL_internal::pushFakeQuery()
 {
-  char* query_element = cryptoMethod->encrypt(0, 1); 
+  char* query_element;
 
   for (unsigned int dim  = 0 ; dim < pirParam.d ; dim++) {
     for(unsigned int j = 0 ; j < pirParam.n[dim] ; j++) {
+      query_element = cryptoMethod->encrypt(0, 1); 
       pushQuery(query_element, cryptoMethod->getPublicParameters().getCiphertextBitsize()/8, dim, j); 
     }
   }
-
-  free(query_element);
 }
 
 
@@ -759,8 +776,11 @@ void PIRReplyGeneratorNFL_internal::pushQuery(char* rawQuery, unsigned int size,
   unsigned int nbModuli = cryptoMethod->getnbModuli();
   // Trick, we get both a and b at the same time, b needs to be set afterwards
   uint64_t *a,*b;
-  a = (poly64) calloc(size, 1);
-  memcpy(a,rawQuery,size);
+  
+  // We push the query we do not copy it 
+  //a = (poly64) calloc(size, 1);
+  //memcpy(a,rawQuery,size);
+  a = (poly64) rawQuery;
   if (lwe) b = a+nbModuli*polyDegree;
 #ifdef CRYPTO_DEBUG
 	std::cout<<"\nQuery received.a ";NFLTools::print_poly64(a,4);	
@@ -816,16 +836,13 @@ size_t PIRReplyGeneratorNFL_internal::getTotalSystemMemory()
 #endif
 }
 
-PIRReplyGeneratorNFL_internal::~PIRReplyGeneratorNFL_internal()
-{
-  freeResult();
-}
-
-
 void PIRReplyGeneratorNFL_internal::setPirParams(PIRParameters& param)
 {
+  freeQueries();
+  freeQueriesBuffer();
   pirParam = param;
   cryptoMethod->setandgetAbsBitPerCiphertext(pirParam.n[0]);
+  initQueriesBuffer();
 }
 
 
@@ -838,47 +855,79 @@ void PIRReplyGeneratorNFL_internal::setCryptoMethod(CryptographicSystem* cm)
 
 void PIRReplyGeneratorNFL_internal::freeInputData()
 {
+#ifdef DEBUG
+  std:cout << "PIRReplyGeneratorNFL_internal: freeing input_data" << std::endl;
+#endif
   uint64_t theoretical_files_nbr = 1;
 	for (unsigned int i = 0 ; i < pirParam.d ; i++) theoretical_files_nbr *= pirParam.n[i];
 
-  for (unsigned int i = 0 ; i < theoretical_files_nbr ; i++){
-#ifdef DEBUG
-  printf( "PIRReplyGeneratorNFL_internal: freeing input_data[%d]\n",i);
-#endif
-    free(input_data[i].p[0]);
-    free(input_data[i].p);
+  if (input_data != NULL){
+    for (unsigned int i = 0 ; i < theoretical_files_nbr ; i++){
+      if (input_data[i].p != NULL){
+        if (input_data[i].p[0] != NULL){
+          free(input_data[i].p[0]);
+          input_data[i].p[0] = NULL;
+        }
+        free(input_data[i].p);
+        input_data[i].p = NULL;
+      }
+    }
+    delete[] input_data;
+    input_data = NULL;
   }
-  delete[] input_data;
-
 #ifdef DEBUG
   printf( "PIRReplyGeneratorNFL_internal: input_data freed\n");
 #endif
 }
 
-void PIRReplyGeneratorNFL_internal::freeQuery()
+void PIRReplyGeneratorNFL_internal::freeQueries()
 {
   for (unsigned int i = 0; i < pirParam.d; i++)
   {
-#ifdef SHOUP
-      for (unsigned int j = 0 ; j < pirParam.n[i] ; j++) {
-		  free(queriesBuf[i][0][j].a); //only free a because a and b and contingus, see pushQuery
-		  free(queriesBuf[i][1][j].a); //only free a because a and b and contingus, see pushQuery
+    for (unsigned int j = 0 ; j < pirParam.n[i] ; j++) 
+    {
+		  if (queriesBuf != NULL && queriesBuf[i] != NULL && queriesBuf[i][0][j].a != NULL)
+      {
+        free(queriesBuf[i][0][j].a); //only free a because a and b and contingus, see pushQuery
+        queriesBuf[i][0][j].a = NULL;
+      }
+		  if (queriesBuf != NULL && queriesBuf[i] != NULL && queriesBuf[i][1][j].a != NULL)
+      {
+        free(queriesBuf[i][1][j].a); //only free a because a and b and contingus, see pushQuery
+        queriesBuf[i][1][j].a = NULL;
+      }
 	  }
-      delete[] queriesBuf[i][0]; //allocated in intQueriesBuf with new.
-      delete[] queriesBuf[i][1]; //allocated in intQueriesBuf with new.
-      delete[] queriesBuf[i];
-#else
-	  for (unsigned int j = 0 ; j < pirParam.n[i] ; j++) free(queriesBuf[i][j].a); //only free a because a and b and contingus, see pushQuery
-      delete[] queriesBuf[i]; //allocated in intQueriesBuf with new.
-#endif
   }
-  delete[] queriesBuf;//allocated in intQueriesBuf with new.
+  current_query_index = 0;
+  current_dim_index = 0;
 #ifdef DEBUG
   printf( "queriesBuf freed\n");
 #endif
   
 }
 
+void PIRReplyGeneratorNFL_internal::freeQueriesBuffer()
+{
+  if (queriesBuf != NULL){ 
+    for (unsigned int i = 0; i < pirParam.d; i++){
+      if (queriesBuf[i] != NULL){
+        if (queriesBuf[i][0] != NULL){
+          delete[] queriesBuf[i][0]; //allocated in intQueriesBuf with new.
+          queriesBuf[i][0] = NULL;
+        }
+        if (queriesBuf[i][1] != NULL){
+          delete[] queriesBuf[i][1]; //allocated in intQueriesBuf with new.
+          queriesBuf[i][1] = NULL;
+        }
+        delete[] queriesBuf[i]; //allocated in intQueriesBuf with new.
+        queriesBuf[i] = NULL;
+      }
+    }
+    delete[] queriesBuf; //allocated in intQueriesBuf with new.
+    queriesBuf = NULL;
+  }
+}
+
 void PIRReplyGeneratorNFL_internal::freeResult()
 {
   if(repliesArray!=NULL)
@@ -892,3 +941,15 @@ void PIRReplyGeneratorNFL_internal::freeResult()
     repliesArray=NULL;
   }
 }
+
+
+PIRReplyGeneratorNFL_internal::~PIRReplyGeneratorNFL_internal()
+{
+  freeQueries();
+  freeQueriesBuffer();
+  freeResult();
+  mutex.try_lock();
+  mutex.unlock();
+}
+
+

+ 3 - 2
pir/replyGenerator/PIRReplyGeneratorNFL_internal.hpp

@@ -50,9 +50,10 @@ private:
     void pushFakeQuery();
     void freeInputData();
     void freeFakeInputData();
-    void freeResult();
 protected:
-    void freeQuery();
+    void freeResult();
+    void freeQueries();
+    void freeQueriesBuffer();
     void generateReply();
     imported_database_t generateReplyGeneric(bool keep_imported_data = false);
     void generateReplyGenericFromData(const imported_database_t database);