PIRReplyGeneratorGMP.cpp 14 KB


  1. /* Copyright (C) 2014 Carlos Aguilar Melchor, Joris Barrier, Marc-Olivier Killijian
  2. * This file is part of XPIR.
  3. *
  4. * XPIR is free software: you can redistribute it and/or modify
  5. * it under the terms of the GNU General Public License as published by
  6. * the Free Software Foundation, either version 3 of the License, or
  7. * (at your option) any later version.
  8. *
  9. * XPIR is distributed in the hope that it will be useful,
  10. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. * GNU General Public License for more details.
  13. *
  14. * You should have received a copy of the GNU General Public License
  15. * along with XPIR. If not, see <http://www.gnu.org/licenses/>.
  16. */
  #include "PIRReplyGeneratorGMP.hpp"
  #include <vector>
  18. PIRReplyGeneratorGMP::PIRReplyGeneratorGMP()
  19. {}
  20. PIRReplyGeneratorGMP::PIRReplyGeneratorGMP( PIRParameters& param, DBHandler *db) :
  21. GenericPIRReplyGenerator(param,db),
  22. finished(false)
  23. {}
  24. /**
  25. * Read raw data from files and make padding if necessary
  26. **/
  27. void PIRReplyGeneratorGMP::importData()
  28. {
  29. unsigned int size = static_cast<double>(cryptoMethod->getPublicParameters().getAbsorptionBitsize()/GlobalConstant::kBitsPerByte), theoretic_nbr_elements = 1 ;
  30. char rawBits[size];
  31. uint64_t nbFiles = dbhandler->getNbStream();
  32. maxChunkSize = ceil((double)(dbhandler->getmaxFileBytesize())/(double)size);
  33. // Ugly ugly trick until this class is improved
  34. (*(PaillierPublicParameters *) &cryptoMethod->getPublicParameters()).getPubKey()->complete_key(pirParam.d+1);
  35. for (unsigned int i = 0 ; i < pirParam.d ; i++)
  36. theoretic_nbr_elements *= pirParam.n[i];
  37. datae = new mpz_t*[theoretic_nbr_elements];
  38. // For each real file
  39. for (unsigned int i = 0 ; i < nbFiles ; i++)
  40. {
  41. if (i % pirParam.alpha == 0) datae[i/pirParam.alpha] = new mpz_t[maxChunkSize*pirParam.alpha];
  42. ifstream* stream=dbhandler->openStream(i, 0);
  43. // For each chunk of size "size" of the file
  44. for (unsigned int j = 0 ; j < maxChunkSize ; j++ )
  45. {
  46. dbhandler->readStream(stream,rawBits, size);
  47. mpz_init(datae[i/pirParam.alpha][j + (i % pirParam.alpha) * maxChunkSize]);
  48. mpz_import(datae[i/pirParam.alpha][j + (i % pirParam.alpha) * maxChunkSize], size, 1, sizeof(char), 0, 0, rawBits);
  49. }
  50. // Pad if needed the last group of pirParam.alpha files
  51. if (i == nbFiles - 1)
  52. {
  53. for (unsigned int j = 0 ; j < maxChunkSize * (pirParam.alpha - 1 - (i % pirParam.alpha)); j++)
  54. {
  55. mpz_init_set_ui(datae[i/pirParam.alpha][j + ((i+1) % pirParam.alpha) * maxChunkSize], 0);
  56. }
  57. }
  58. dbhandler->closeStream(stream);
  59. }
  60. std::cout << "PIRReplyGeneratorGMP: " << pirParam.alpha*theoretic_nbr_elements - nbFiles << " non-aggregated padding files need to be added ..." << std::endl;
  61. /* make file padding if necessary **/
  62. for (uint64_t i = ceil((double)nbFiles/pirParam.alpha) ; i < theoretic_nbr_elements ; i++)
  63. {
  64. datae[i] = new mpz_t[maxChunkSize*pirParam.alpha];
  65. for (unsigned int k = 0 ; k < maxChunkSize*pirParam.alpha ; k++)
  66. {
  67. mpz_init_set_ui(datae[i][k], 0);
  68. }
  69. }
  70. // From now on all data sizes take into account aggregation
  71. maxChunkSize *= pirParam.alpha;
  72. }
  73. void PIRReplyGeneratorGMP::importFakeData(uint64_t plaintext_nbr)
  74. {
  75. unsigned int file_nbr = 1;
  76. unsigned int abs_size = cryptoMethod->getPublicParameters().getAbsorptionBitsize()/GlobalConstant::kBitsPerByte;
  77. unsigned int ciph_size = cryptoMethod->getPublicParameters().getCiphertextBitsize()/GlobalConstant::kBitsPerByte;
  78. char *raw_data, *raw_data2;
  79. for (unsigned int i = 0 ; i < pirParam.d ; i++)
  80. file_nbr *= pirParam.n[i];
  81. datae = new mpz_t*[file_nbr];
  82. raw_data = (char*) malloc(abs_size + 1);
  83. raw_data2 = (char*) malloc(ciph_size + 1);
  84. memset(raw_data, 0xaa, abs_size);
  85. memset(raw_data2, 0xaa, ciph_size);
  86. maxChunkSize = plaintext_nbr;
  87. for (unsigned int i = 0 ; i < file_nbr; i++)
  88. {
  89. datae[i] = new mpz_t[maxChunkSize];
  90. for (unsigned int j = 0 ; j < maxChunkSize ; j++)
  91. {
  92. mpz_init(datae[i][j]);
  93. mpz_import(datae[i][j], abs_size, 1, sizeof(char), 0, 0, raw_data);
  94. }
  95. }
  96. for (unsigned int i = 0 ; i < pirParam.n[0] ; i++)
  97. {
  98. mpz_import(queriesBuf[0][i], ciph_size, 1, sizeof(char), 0, 0, raw_data2);
  99. }
  100. free(raw_data);
  101. free(raw_data2);
  102. }
  103. void PIRReplyGeneratorGMP::clearFakeData(uint64_t plaintext_nbr)
  104. {
  105. unsigned int file_nbr = 1;
  106. unsigned int abs_size = cryptoMethod->getPublicParameters().getAbsorptionBitsize()/GlobalConstant::kBitsPerByte;
  107. char* raw_data;
  108. for (unsigned int i = 0 ; i < pirParam.d ; i++)
  109. file_nbr *= pirParam.n[i];
  110. maxChunkSize = plaintext_nbr;
  111. for (unsigned int i = 0 ; i < file_nbr; i++)
  112. {
  113. for (unsigned int j = 0 ; j < maxChunkSize ; j++)
  114. {
  115. mpz_clear(datae[i][j]);
  116. }
  117. }
  118. delete[] datae;
  119. }
  120. double PIRReplyGeneratorGMP::generateReplySimulation(const PIRParameters& pir_params, uint64_t plaintext_nbr)
  121. {
  122. setPirParams((PIRParameters&) pir_params);
  123. initQueriesBuffer();
  124. importFakeData(plaintext_nbr);
  125. double start = omp_get_wtime();
  126. generateReply();
  127. double stop = omp_get_wtime() - start;
  128. cleanQueryBuffer();
  129. clearFakeData(plaintext_nbr);
  130. freeResult();
  131. return stop;
  132. }
  133. imported_database_t PIRReplyGeneratorGMP::generateReplyGeneric(bool keep_imported_data = false)
  134. {
  135. imported_database_t database_wrapper;
  136. importData();
  137. boost::mutex::scoped_lock l(mutex);
  138. //mutex.lock();
  139. generateReply();
  140. if(keep_imported_data)
  141. {
  142. database_wrapper.imported_database_ptr = (void*)datae;
  143. database_wrapper.polysPerElement = maxChunkSize;
  144. }
  145. else
  146. {
  147. unsigned long theoretic_nbr_elements = 1;
  148. for (unsigned int i = 0 ; i < pirParam.d ; i++)
  149. theoretic_nbr_elements *= pirParam.n[i];
  150. for (unsigned int i = 0 ; i < theoretic_nbr_elements ; i++ )
  151. {
  152. for (unsigned int j = 0; j < maxChunkSize; j++)
  153. {
  154. mpz_clear(datae[i][j]);
  155. }
  156. delete[] datae[i];
  157. }
  158. delete[] datae;
  159. for (unsigned int i = 0 ; i < pirParam.d ; i++)
  160. {
  161. for (unsigned int j = 0 ; j < pirParam.n[i] ; j++)
  162. {
  163. mpz_clear(queriesBuf[i][j]);
  164. }
  165. delete[] queriesBuf[i];
  166. }
  167. delete[] queriesBuf;
  168. }
  169. return database_wrapper;
  170. }
  171. void PIRReplyGeneratorGMP::generateReplyGenericFromData(const imported_database_t database)
  172. {
  173. datae = (mpz_t**) database.imported_database_ptr;
  174. maxChunkSize = database.polysPerElement;
  175. boost::mutex::scoped_lock l(mutex);
  176. //mutex.lock();
  177. generateReply();
  178. }
  179. /**
  180. * Compute Lipmaa PIR Scheme with imported data.
  181. **/
  182. void PIRReplyGeneratorGMP::generateReply()
  183. {
  184. uint64_t pir_nbr;
  185. double start;
  186. mpz_t **data_z = datae;
  187. mpz_t **reply_vec;
  188. mpz_t *queries;
  189. omp_set_nested(0);
  190. start = omp_get_wtime();
  191. #ifdef PERF_TIMERS
  192. double vtstart = start;
  193. bool wasVerbose = false;
  194. #endif
  195. for (unsigned int i = 0 ; i < pirParam.d ; i++) // For each dimension
  196. {
  197. pir_nbr = 1;
  198. for (unsigned int j = i+1 ; j < pirParam.d ; j++)
  199. pir_nbr *= pirParam.n[j];
  200. if (pir_nbr!=1) std::cout << "PIRReplyGeneratorGMP: Generating " << pir_nbr << " replies in recursion level " << i+1 << std::endl;
  201. reply_vec = new mpz_t*[pir_nbr];
  202. queries = queriesBuf[i];
  203. if(pir_nbr == 1) omp_set_nested(1);
  204. //#pragma omp parallel for
  205. for (unsigned int j = 0 ; j < pir_nbr ; j++) // Do pir_nbr PIR iterations
  206. {
  207. reply_vec[j] = new mpz_t[maxChunkSize];
  208. //Do a PIR iteration between a sub-list of elements and a query, i + 1 being the dimension
  209. generateReply(queries , data_z, pirParam.n[i] * j , i, reply_vec[j]);
  210. #ifdef PERF_TIMERS
  211. // Give some feedback if it takes too long
  212. double vtstop = omp_get_wtime();
  213. if (vtstop - vtstart > 1)
  214. {
  215. vtstart = vtstop;
  216. if (pir_nbr!=1) std::cout <<"PIRReplyGeneratorGMP: Reply " << j+1 << "/" << pir_nbr << " generated\r" << std::flush;
  217. wasVerbose = true;
  218. }
  219. #endif
  220. }
  221. // Always print feedback for last reply
  222. if (pir_nbr!=1) std::cout <<"PIRReplyGeneratorGMP: Reply " << pir_nbr << "/" << pir_nbr << " generated" << std::endl;
  223. // Delete intermediate data obtained on the recursions
  224. if(i!=0)
  225. {
  226. for (unsigned int j = 0 ; j < pirParam.n[i] ; j++ ) delete[] data_z[j];
  227. delete[] data_z;
  228. }
  229. data_z = reply_vec;
  230. }
  231. printf( "PIRReplyGeneratorGMP: Global reply generation took %f seconds\n", omp_get_wtime() - start);
  232. repliesArray = (char**)calloc(maxChunkSize,sizeof(char*));
  233. repliesAmount = maxChunkSize;
  234. pushReply(data_z[0], 0, maxChunkSize);
  235. for (unsigned int i = 0 ; i < maxChunkSize ; i++)
  236. mpz_clear(data_z[0][i]);
  237. delete[] data_z[0];
  238. delete[] data_z;
  239. }
  240. /**
  241. * Compute single parallelizable PIR with almost fully homomorphic crypto system
  242. * Params :
  243. * - mpz_t* queries : queries array to compute the pir ;
  244. * - mpz_t** data : two dimensionnal array for raw data to treat ;
  245. * - int begin_data : begin data index
  246. * - int s : current dimension
  247. * - mpz_t* result : result array, no need to init
  248. **/
  249. void
  250. PIRReplyGeneratorGMP::generateReply(mpz_t *queries,
  251. mpz_t** data, int begin_data,
  252. int dimension,
  253. mpz_t* result)
  254. {
  255. unsigned int data_size = pirParam.n[dimension];
  256. mpz_t replyTmp;
  257. int init_s = (*((PaillierPublicParameters*) &(cryptoMethod->getPublicParameters()))).getPubKey()->getinit_s();
  258. #ifdef PERF_TIMERS
  259. bool wasVerbose = false;
  260. double vtstart = omp_get_wtime();
  261. #endif
  262. //#pragma omp parallel for private(replyTmp) firstprivate(s, begin_data)
  263. for (unsigned int chunk = 0 ; chunk < maxChunkSize ; chunk++)
  264. {
  265. mpz_inits(result[chunk], replyTmp, NULL);
  266. computeMul(queries[0], data[begin_data][chunk], result[chunk], dimension+init_s+1);
  267. if(dimension != 0) mpz_clear(data[begin_data][chunk]);
  268. //if ((chunk*chunkCost) << 21 == 0 || chunk == maxChunkSize - 1) std::cout <<"PIRReplyGeneratorGMP: Dealing with chunk " << chunk+1 << "/" << maxChunkSize << std::endl;
  269. for (unsigned int file = 1, k = begin_data + 1 ; file < data_size ; file++, k++)
  270. {
  271. computeMul(queries[file], data[k][chunk], replyTmp, dimension+init_s+1);
  272. if(dimension != 0) mpz_clear(data[k][chunk]);
  273. //We add the filechunks of index chunk for the files of index file
  274. // eg : file[1]->chunk[1] + file[2]->chunk[1] + ... + file[file]->chunk[chunk]
  275. computeSum(result[chunk], replyTmp, dimension+init_s+1);
  276. }
  277. mpz_clear(replyTmp);
  278. #ifdef CRYPTO_DEBUG
  279. gmp_printf("PIRReplyGeneratorGMP: Reply chunk generated %Zd\n\n",result[chunk]);
  280. #endif
  281. #ifdef PERF_TIMERS
  282. // Give some feedback if it takes too long
  283. double vtstop = omp_get_wtime();
  284. if (vtstop - vtstart > 1)
  285. {
  286. vtstart = vtstop;
  287. if(maxChunkSize != 1) std::cout <<"PIRReplyGeneratorGMP: Dealt with chunk " << chunk+1 << "/" << maxChunkSize << "\r" << std::flush;
  288. wasVerbose = true;
  289. }
  290. #endif
  291. }
  292. #ifdef PERF_TIMERS
  293. if (wasVerbose) std::cout <<" \r" << std::flush;
  294. #endif
  295. }
  296. /**
  297. * Performs homomorphic multiplication between query and n then put the result in res
  298. * Params :
  299. * - mpz_t query : a pir query
  300. * - mpz_t n : raw data
  301. * - mpz_t res : result of operation
  302. * - int s : dimension
  303. **/
  304. void PIRReplyGeneratorGMP::computeMul(mpz_t query, mpz_t n, mpz_t res, int modulus_index)
  305. {
  306. cryptoMethod->e_mul_const(res, query, n, modulus_index );
  307. #ifdef CRYPTO_DEBUG
  308. gmp_printf("PIRReplyGeneratorGMP: Raising %Zd\n To the power %Zd\nResult is %Zd\nModulus index is %d\n\n",query, n, res,modulus_index);
  309. #endif
  310. }
  311. /**
  312. * Performs homomorphic addition betwee two encrypted data and put the result in a
  313. * Params :
  314. * - mpz_t a : first data to multiply, store result also
  315. * - mpz_t b : second data to multoply
  316. * - int s : dimension
  317. **/
  318. void PIRReplyGeneratorGMP::computeSum(mpz_t a, mpz_t b, int modulus_index)
  319. {
  320. cryptoMethod->e_add(a, a, b, modulus_index);
  321. }
  322. /**
  323. * Inits queries buffer, used by PIRServer before reveiving client queries.
  324. **/
  325. void PIRReplyGeneratorGMP::initQueriesBuffer()
  326. {
  327. queriesBuf = new mpz_t*[pirParam.d];
  328. for (unsigned int i = 0 ; i < pirParam.d ; i++)
  329. {
  330. queriesBuf[i] = new mpz_t[pirParam.n[i]];
  331. for (unsigned int j = 0 ; j < pirParam.n[i] ; j++)
  332. mpz_init(queriesBuf[i][j]);
  333. }
  334. }
  335. void PIRReplyGeneratorGMP::setCryptoMethod(CryptographicSystem* crypto_method)
  336. {
  337. cryptoMethod = (PaillierAdapter*) crypto_method;
  338. }
  339. void PIRReplyGeneratorGMP::pushQuery(char* rawQuery, unsigned int size, int dim, int nbr)
  340. {
  341. mpz_import(queriesBuf[dim][nbr], size, 1, sizeof(char), 0, 0, rawQuery);
  342. #ifdef CRYPTO_DEBUG
  343. gmp_printf("Imported query element: %Zd\n", queriesBuf[dim][nbr]);
  344. #endif
  345. }
  346. void PIRReplyGeneratorGMP::pushReply(mpz_t* replies, unsigned init_index, unsigned replies_nbr)
  347. {
  348. size_t n;
  349. unsigned int size = cryptoMethod->getPublicParameters().getCiphBitsizeFromRecLvl(pirParam.d)/GlobalConstant::kBitsPerByte;
  350. char*ct, *tmp;
  351. for (int i = init_index ; i < init_index + replies_nbr; i++)
  352. {
  353. tmp = (char*)mpz_export(NULL, &n, 1, sizeof(char) , 0, 0, replies[i]);
  354. if (n < size)
  355. {
  356. ct = new char[size]();
  357. memcpy(ct+sizeof(char)*(size - n), tmp, n);
  358. repliesArray[i] = ct;
  359. free(tmp);
  360. }
  361. else
  362. {
  363. repliesArray[i] = tmp;
  364. }
  365. }
  366. }
  367. bool PIRReplyGeneratorGMP::isFinished()
  368. {
  369. return finished;
  370. }
  371. unsigned long int PIRReplyGeneratorGMP::computeReplySizeInChunks(unsigned long int maxFileBytesize)
  372. {
  373. float res = ceil(static_cast<float>(maxFileBytesize*pirParam.alpha) / static_cast<float>(cryptoMethod->getPublicParameters().getAbsorptionBitsize()/8.0));
  374. return static_cast<unsigned long int>(res);
  375. }
  376. void PIRReplyGeneratorGMP::setPirParams(PIRParameters& _pirParam)
  377. {
  378. pirParam = _pirParam;
  379. }
  380. void PIRReplyGeneratorGMP::cleanQueryBuffer()
  381. {
  382. for (unsigned int i = 0 ; i < pirParam.d; i++)
  383. {
  384. for (unsigned int j = 0 ; j < pirParam.n[i]; j++)
  385. {
  386. mpz_clear(queriesBuf[i][j]);
  387. }
  388. delete[] queriesBuf[i];
  389. }
  390. delete[] queriesBuf;
  391. }
  392. PIRReplyGeneratorGMP::~PIRReplyGeneratorGMP()
  393. {
  394. }
  395. void PIRReplyGeneratorGMP::freeResult()
  396. {
  397. for(unsigned i=0 ; i < repliesAmount; i++)
  398. {
  399. if(repliesArray[i]!=NULL) delete[] repliesArray[i];
  400. repliesArray[i] = NULL;
  401. }
  402. free(repliesArray);
  403. repliesArray=NULL;
  404. }