PIRSession.cpp 20 KB


  1. /* Copyright (C) 2014 Carlos Aguilar Melchor, Joris Barrier, Marc-Olivier Killijian
  2. * This file is part of XPIR.
  3. *
  4. * XPIR is free software: you can redistribute it and/or modify
  5. * it under the terms of the GNU General Public License as published by
  6. * the Free Software Foundation, either version 3 of the License, or
  7. * (at your option) any later version.
  8. *
  9. * XPIR is distributed in the hope that it will be useful,
  10. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. * GNU General Public License for more details.
  13. *
  14. * You should have received a copy of the GNU General Public License
  15. * along with XPIR. If not, see <http://www.gnu.org/licenses/>.
  16. */
  17. #include "PIRSession.hpp"
  18. #include "pir/replyGenerator/PIROptimizer.hpp"
  19. #define NDSS_UPLOAD_SPEED 100000000UL
  20. PIRSession::pointer PIRSession::create(boost::asio::io_service& ios)
  21. {
  22. return PIRSession::pointer(new PIRSession(ios));
  23. }
  24. PIRSession::PIRSession(boost::asio::io_service& ios) :
  25. sessionSocket(ios),
  26. handmadeExceptionRaised(false),
  27. finished(false),
  28. cryptoMethod(NULL),
  29. generator(NULL),
  30. no_pipeline_mode(false)
  31. {
  32. }
  33. void PIRSession::setDBHandler(DBHandler *db)
  34. {
  35. dbhandler = db;
  36. }
  37. /**
  38. * Do the PIR protocol
  39. **/
  40. bool PIRSession::start(session_option_t session_option)
  41. {
  42. uint64_t nbFiles=dbhandler->getNbStream();
  43. maxFileBytesize = dbhandler->getmaxFileBytesize();
  44. short exchange_method = (session_option.driven_mode) ? CLIENT_DRIVEN : SERVER_DRIVEN;
  45. bool rcv_paramsandkey = true;
  46. // Deal with Optimizer dry-run queries if requested
  47. if(!rcvIsClient())
  48. {
  49. std::cout << "PIRSession: Incoming dry-run optimizer request, dealing with it ..." << std::endl;
  50. PIROptimizer optimizer(dbhandler);
  51. optimizer.prepareOptimData();
  52. optimizer.controlAndCommand(sessionSocket);
  53. std::cout << "PIRSession: Session finished" << std::endl << std::endl;
  54. finished = true;
  55. return 1; // Did deal with a dry-run query
  56. }
  57. sendCatalog();
  58. sendPIRParamsExchangeMethod(exchange_method);
  59. if(session_option.driven_mode) {
  60. PIROptimizer optimizer(dbhandler);
  61. optimizer.prepareOptimData();
  62. optimizer.controlAndCommand(sessionSocket);
  63. rcvCryptoParams(rcv_paramsandkey);
  64. rcvPirParams();
  65. }
  66. else //not driven
  67. {
  68. std::vector<std::string> fields;
  69. boost::algorithm::split(fields, pirParam.crypto_params, boost::algorithm::is_any_of(GlobalConstant::kDelim));
  70. cryptoMethod = HomomorphicCryptoFactory_internal::getCrypto(fields.at(0));
  71. cryptoMethod->setNewParameters(pirParam.crypto_params);
  72. generator = PIRReplyGeneratorFactory::getPIRReplyGenerator(fields.at(0), pirParam,dbhandler);
  73. generator->setCryptoMethod(cryptoMethod);
  74. sendCryptoParams();
  75. rcvCryptoParams(!rcv_paramsandkey);
  76. sendPirParams();
  77. }
  78. // If one of the functions above generates an error, handmadeExceptionRaised is set
  79. if (!handmadeExceptionRaised)
  80. {
  81. // This is just a download thread. Reply generation is unlocked (by a mutex)
  82. // when this thread finishes.
  83. startProcessQuery();
  84. // Import the database
  85. // Start reply generation when mutex unlocked
  86. // Start a thread which uploads the reply as it is generated
  87. startProcessResult(session_option);
  88. }
  89. // Wait for child threads
  90. if (upThread.joinable()) upThread.join();
  91. if (downThread.joinable()) downThread.join();
  92. std::cout << "PIRSession: Session finished" << std::endl << std::endl;
  93. // Note that we have finished so that the PIRServer can do garbage collection
  94. finished = true;
  95. if (generator != NULL) delete generator;
  96. return 0; // Did deal with a client query
  97. }
  98. /**
  99. * Getter for other classes (such as PIRServer)
  100. **/
  101. tcp::socket& PIRSession::getSessionSocket()
  102. {
  103. return sessionSocket;
  104. }
  105. /**
  106. * Send the catalog to the client as a string.
  107. * Format : file_list_size \n filename1 \n filesize1 \n filename2 \n ... filesizeN \n
  108. * if SEND_CATALOG is defined and catalog size is less than 1000, the full catalog is sent
  109. * otherwise only the size of the catalog is sent
  110. **/
  111. void PIRSession::sendCatalog()
  112. {
  113. string buf;
  114. #ifdef SEND_CATALOG
  115. if(dbhandler->getNbStream()>1000) {
  116. buf=dbhandler->getCatalog(false);
  117. } else {
  118. buf=dbhandler->getCatalog(true);
  119. }
  120. #else
  121. buf = dbhandler->getCatalog(false);
  122. #endif
  123. // Send the buffer
  124. const boost::uint64_t size = buf.size();
  125. try
  126. {
  127. if (write(sessionSocket, boost::asio::buffer(&size, sizeof(size))) <= 0)
  128. exitWithErrorMessage(__FUNCTION__,"Error sending catalog size");
  129. if (write(sessionSocket, boost::asio::buffer(buf.c_str(), size)) < size)
  130. exitWithErrorMessage(__FUNCTION__,"Error sending catalog");
  131. } catch (std::exception const& ex) {
  132. exitWithErrorMessage(__FUNCTION__,"Error sending catalog: " + string(ex.what()));
  133. }
  134. writeWarningMessage(__FUNCTION__ , "done.");
  135. }
  136. void PIRSession::sendCryptoParams()
  137. {
  138. std::cout << "PIRSession: Mandatory crypto params sent to the client are " << pirParam.crypto_params << std::endl;
  139. try
  140. {
  141. int crypto_params_size = pirParam.crypto_params.size();
  142. write(sessionSocket, boost::asio::buffer(&crypto_params_size, sizeof(crypto_params_size)));
  143. write(sessionSocket, boost::asio::buffer(pirParam.crypto_params));
  144. }
  145. catch(std::exception const& ex)
  146. {
  147. exitWithErrorMessage(__FUNCTION__, string(ex.what()));
  148. }
  149. }
  150. bool PIRSession::rcvIsClient()
  151. {
  152. try{
  153. int is_client;
  154. while ( read(sessionSocket, boost::asio::buffer(&(is_client), sizeof(int))) == 0)
  155. boost::this_thread::yield();
  156. //exitWithErrorMessage(__FUNCTION__, "No client or optim choice recieved");
  157. return is_client == 1;
  158. }catch (std::exception const& ex)
  159. {
  160. exitWithErrorMessage(__FUNCTION__, string(ex.what()));
  161. }
  162. }
  163. /**
  164. * Receive client's cryptographic parameters as:
  165. * - An int describing a string size
  166. * - A string of the given size containing the public parameters
  167. * - A second int describing a byte size
  168. * - A buffer of bytes of the given size with key material
  169. **/
  170. void PIRSession::rcvCryptoParams(bool paramsandkey)
  171. {
  172. unsigned int size = 0;
  173. try{
  174. std::vector<std::string> fields;
  175. if(paramsandkey == true)
  176. {
  177. // First get the int and allocate space for the string (plus the null caracter)
  178. read(sessionSocket, boost::asio::buffer(&size, sizeof(int)));
  179. char params_buf[size + 1];
  180. // Get the string in the buffer and add a null character
  181. read(sessionSocket, boost::asio::buffer(params_buf, size));
  182. params_buf[size] = '\0';
  183. cout << "PIRSession: Received crypto parameters " << params_buf << " processing them ..." << endl;
  184. #ifdef DEBUG
  185. cout << "PIRSession: Parameter string size is " << size << endl;
  186. #endif
  187. // Extract cryptosystem's name
  188. string crypto_system_desc(params_buf);
  189. boost::algorithm::split(fields, crypto_system_desc, boost::algorithm::is_any_of(":"));
  190. // Create cryptosystem using a factory and the extracted name
  191. cryptoMethod = HomomorphicCryptoFactory_internal::getCrypto(fields[0]);
  192. // Set cryptosystem with received parameters and key material
  193. cryptoMethod->setNewParameters(params_buf);
  194. // Create the PIR reply generator object using a factory to have
  195. // the correct object given the cryptosystem used (for optimization
  196. // reply generation is cryptosystem dependent)
  197. generator = PIRReplyGeneratorFactory::getPIRReplyGenerator(fields[0], pirParam, dbhandler);
  198. if (generator == NULL)
  199. {
  200. std::cout << "PIRSession: CRITICAL no reply generator found, exiting session" << std::endl;
  201. pthread_exit(this);
  202. }
  203. else
  204. {
  205. generator->setCryptoMethod(cryptoMethod);
  206. }
  207. }
  208. // Use again size to describe the key size
  209. if (read(sessionSocket, boost::asio::buffer(&size ,sizeof(size))) <= 0)
  210. exitWithErrorMessage(__FUNCTION__,"No key received, abort.");
  211. #ifdef DEBUG
  212. cout << "PIRSession: Size of received key material is " << size << endl;
  213. #endif
  214. // Get the key material only if there is some
  215. if(size > 0)
  216. {
  217. // This time we don't use a string so no need for an extra character
  218. char buf[size];
  219. if (read(sessionSocket, boost::asio::buffer(buf, size)) < size)
  220. exitWithErrorMessage(__FUNCTION__,"No parameters received, abort.");
  221. cryptoMethod->getPublicParameters().setModulus(buf);
  222. }
  223. }catch(std::exception const& ex)
  224. {
  225. exitWithErrorMessage(__FUNCTION__, string(ex.what()));
  226. }
  227. std::cout << "PIRSession: Finished processing crypto parameters" << std::endl;
  228. }
  229. void PIRSession::sendPIRParamsExchangeMethod(short exchange_method)
  230. {
  231. if (exchange_method == CLIENT_DRIVEN)
  232. {
  233. cout << "PIRSession: Notifying the client this is a client-driven session" << endl;
  234. }
  235. else
  236. {
  237. cout << "PIRSession: Notifying the client this is a server-driven session" << endl;
  238. }
  239. try
  240. {
  241. write(sessionSocket, boost::asio::buffer(&exchange_method, sizeof(exchange_method)));
  242. }catch(std::exception const& ex)
  243. {
  244. exitWithErrorMessage(__FUNCTION__, string(ex.what()));
  245. }
  246. }
  247. /**
  248. * Receive client's PIR parameters.
  249. **/
  250. void PIRSession::rcvPirParams()
  251. {
  252. try{
  253. // First we get an int with the recursion level
  254. if ( read(sessionSocket, boost::asio::buffer(&(pirParam.d), sizeof(int))) <= 0)
  255. exitWithErrorMessage(__FUNCTION__, "No pir param recieved");
  256. // Then we get an int with the aggregation
  257. if ( read(sessionSocket, boost::asio::buffer(&(pirParam.alpha), sizeof(int))) <= 0)
  258. exitWithErrorMessage(__FUNCTION__, "No pir param recieved");
  259. // Finally for each level we get an int withe the corresponding dimension size
  260. for (unsigned int i = 0 ; i < pirParam.d ; i++)
  261. {
  262. if ( read(sessionSocket, boost::asio::buffer( &(pirParam.n[i]), sizeof(int))) <= 0)
  263. exitWithErrorMessage(__FUNCTION__, "No pir param recieved");
  264. }
  265. // The last dimension + 1 is set to 1 (used in some functions to compute the number of
  266. // elements after a PIR recursion level)
  267. pirParam.n[pirParam.d] = 1;
  268. // Pass the PIR parameters to the reply generator
  269. generator->setPirParams(pirParam);
  270. cout << "PIRSession: PIR params recieved from the client are d=" << pirParam.d << ", alpha=" << pirParam.alpha << ", data_layout=";
  271. for (unsigned int i = 0; i < pirParam.d; i++)
  272. {
  273. if (pirParam.d == 1) cout << "1x";
  274. cout << pirParam.n[i];
  275. if (i != pirParam.d-1) cout << "x";
  276. }
  277. cout << endl;
  278. }catch (std::exception const& ex)
  279. {
  280. exitWithErrorMessage(__FUNCTION__, string(ex.what()));
  281. }
  282. }
  283. void PIRSession::sendPirParams()
  284. {
  285. cout << "PIRSession: Mandatory PIR params sent to the client are d=" << pirParam.d << ", alpha=" << pirParam.alpha << ", data_layout=";
  286. for (unsigned int i = 0; i < pirParam.d; i++)
  287. {
  288. if (pirParam.d == 1) cout << "1x";
  289. cout << pirParam.n[i];
  290. if (i != pirParam.d-1) cout << "x";
  291. }
  292. cout << endl;
  293. try{
  294. // First we send an int with the recursion level
  295. if ( write(sessionSocket, boost::asio::buffer(&(pirParam.d), sizeof(int))) <= 0)
  296. exitWithErrorMessage(__FUNCTION__, "No pir param sended");
  297. // Then we send an int with the aggregation
  298. if ( write(sessionSocket, boost::asio::buffer(&(pirParam.alpha), sizeof(int))) <= 0)
  299. exitWithErrorMessage(__FUNCTION__, "No pir param sended");
  300. // Finally for each level we send an int withe the corresponding dimension size
  301. for (unsigned int i = 0 ; i < pirParam.d ; i++)
  302. {
  303. if ( write(sessionSocket, boost::asio::buffer( &(pirParam.n[i]), sizeof(int))) <= 0)
  304. exitWithErrorMessage(__FUNCTION__, "No pir param sended");
  305. }
  306. // The last dimension + 1 is set to 1 (used in some functions to compute the number of
  307. // elements after a PIR recursion level)
  308. pirParam.n[pirParam.d] = 1;
  309. // Pass the PIR parameters to the reply generator
  310. generator->setPirParams(pirParam);
  311. }catch (std::exception const& ex)
  312. {
  313. exitWithErrorMessage(__FUNCTION__, string(ex.what()));
  314. }
  315. }
  316. /**
  317. * Start downloadworker in downThread.
  318. **/
  319. void PIRSession::startProcessQuery ()
  320. {
  321. if(no_pipeline_mode) {
  322. std::cout << "No pipeline in query processing." << std::endl;
  323. downloadWorker();
  324. } else {
  325. downThread = boost::thread(&PIRSession::downloadWorker, this);
  326. }
  327. }
  328. /**
  329. * Recieve queries n messages with n = nbr of files.
  330. **/
  331. void PIRSession::downloadWorker()
  332. {
  333. double start = omp_get_wtime();
  334. unsigned int msg_size = 0;
  335. // Allocate an array with d dimensions with pointers to arrays of n[i] lwe_query elements
  336. generator->initQueriesBuffer();
  337. #ifdef PERF_TIMERS
  338. double vtstart = omp_get_wtime();
  339. bool wasVerbose = false;
  340. unsigned int previous_elts = 0;
  341. unsigned int total_elts = 0;
  342. for (unsigned int k = 0 ; k < pirParam.d ; k++) total_elts += pirParam.n[k];
  343. #endif
  344. try{
  345. for (unsigned int j = 0 ; j < pirParam.d ; j++)
  346. {
  347. // Compute and allocate the size in bytes of a query ELEMENT of dimension j
  348. msg_size = cryptoMethod->getPublicParameters().getQuerySizeFromRecLvl(j+1) / 8;
  349. boost::asio::socket_base::receive_buffer_size opt(65535);
  350. sessionSocket.set_option(opt);
  351. // boost_buffer = new boost::asio::buffer(buf, msg_size);
  352. #ifdef DEBUG
  353. cout << "PIRSession: Size of the query element to be received is " << msg_size << endl;
  354. cout << "PIRSession: Number of query elements to be received is " << pirParam.n[j] << endl;
  355. #endif
  356. // Iterate over all the elements of the query corresponding to the j-th dimension
  357. for (unsigned int i = 0; i < pirParam.n[j]; i++)
  358. {
  359. char *buf = (char *) malloc(msg_size*sizeof(char));
  360. auto boost_buffer = boost::asio::buffer(buf,msg_size);
  361. if (i==0 && j == 0) cout << "PIRSession: Waiting for query elements ..." << endl;
  362. // Get a query element
  363. //( async_read(sessionSocket, boost_buffer,boost::bind(&blo,boost::asio::placeholders::error)) );
  364. if (read(sessionSocket, boost_buffer) < msg_size )
  365. writeWarningMessage(__FUNCTION__, "Query element not entirely recieved");
  366. // std::cout <<"PIRSession: " << total_elts << " query elements received in " << omp_get_wtime() - start << std::endl;
  367. // Allocate the memory for the element, copy it, and point to it with the query buffer
  368. if (i==0 && j == 0) cout << "PIRSession: Starting query element reception" << endl;
  369. #ifdef PERF_TIMERS
  370. // Give some feedback if it takes too long
  371. double vtstop = omp_get_wtime();
  372. if (vtstop - vtstart > 1)
  373. {
  374. vtstart = vtstop;
  375. previous_elts = 0;
  376. for (unsigned int k = 0 ; k < j ; k++) previous_elts += pirParam.n[k];
  377. std::cout <<"PIRSession: Query element " << i+1+previous_elts << "/" << total_elts << " received\r" << std::flush;
  378. wasVerbose = true;
  379. }
  380. #endif
  381. generator->pushQuery(buf, msg_size, j, i);
  382. }
  383. }
  384. }catch (std::exception const& ex)
  385. {
  386. exitWithErrorMessage(__FUNCTION__, string(ex.what()));
  387. return;
  388. }
  389. #ifdef PERF_TIMERS
  390. std::cout <<"PIRSession: Query element " << total_elts << "/" << total_elts << " received" << std::endl;
  391. std::cout <<"PIRSession: " << total_elts << " query elements received in " << omp_get_wtime() - start << std::endl;
  392. #endif
  393. // All the query elements received, unlock reply generation
  394. generator->mutex.unlock();
  395. // Output we are done
  396. writeWarningMessage(__FUNCTION__, "done.");
  397. }
  398. /**
  399. * Start uploadWorker in a thread.
  400. **/
  401. void PIRSession::startProcessResult(session_option_t session_option)
  402. {
  403. if (no_pipeline_mode) {
  404. std::cout << "No pipeline in query Generation." << std::endl;
  405. }
  406. else {
  407. upThread = boost::thread(&PIRSession::uploadWorker, this);
  408. }
  409. // Import and generate reply once unlocked by the query downloader thread
  410. // If we got a preimported database generate reply directly from it
  411. if (session_option.got_preimported_database == true)
  412. {
  413. std::cout << "PIRSession: Already got an imported database available, using it" << std::endl;
  414. generator->generateReplyGenericFromData(session_option.data);
  415. }
  416. else if (session_option.keep_database == true) {
  417. savedDatabase = generator->generateReplyGeneric(true);
  418. }
  419. else if (session_option.keep_database == false)
  420. {
  421. generator->generateReplyGeneric(false);
  422. }
  423. if (no_pipeline_mode) {
  424. uploadWorker();
  425. }
  426. }
  /**
   * Throttles the upload to emulate a link of NDSS_UPLOAD_SPEED bits per
   * second: sleeps for the time needed to transmit `bytes` bytes.
   * Does nothing when NDSS_UPLOAD_SPEED is not defined.
   */
  void sleepForBytes(unsigned int bytes) {
  #ifdef NDSS_UPLOAD_SPEED
    // Widen before multiplying: bytes*8 in unsigned int wraps for
    // bytes >= 2^29, which produced a tv_nsec >= 1e9 that nanosleep rejects
    // (no throttling at all).
    const uint64_t bits = static_cast<uint64_t>(bytes) * 8;
    const uint64_t speed = NDSS_UPLOAD_SPEED;
    struct timespec req = {0, 0}, rem = {0, 0};
    req.tv_sec = static_cast<time_t>(bits / speed);
    // Fractional part of the second, always in [0, 1e9) as nanosleep requires
    req.tv_nsec = static_cast<long>(static_cast<double>(bits % speed) * 1e9 / static_cast<double>(speed));
    nanosleep(&req, &rem);
  #endif
  }
  437. /**
  438. * Send PIR's result, asynchronously.
  439. **/
  440. void PIRSession::uploadWorker()
  441. {
  442. // Ciphertext byte size
  443. unsigned int byteSize = cryptoMethod->getPublicParameters().getCiphBitsizeFromRecLvl(pirParam.d)/GlobalConstant::kBitsPerByte;
  444. uint64_t totalbytesent=0;
  445. // Number of ciphertexts in the reply
  446. unsigned long reply_nbr = generator->computeReplySizeInChunks(maxFileBytesize), i = 0;
  447. #ifdef DEBUG
  448. cout << "PIRSession: Number of ciphertexts to send is " << reply_nbr << endl;
  449. cout << "PIRSession: maxFileBytesize is " << maxFileBytesize << endl;
  450. cout << "PIRSession: Ciphertext bytesize is " << byteSize << endl;
  451. #endif
  452. try
  453. {
  454. // Pointer for the ciphertexts to be sent
  455. char *ptr;
  456. // For each ciphertext in the reply
  457. for (unsigned i = 0 ; i < reply_nbr ; i++)
  458. {
  459. while (generator->repliesArray == NULL || generator->repliesArray[i] == NULL)
  460. {
  461. boost::this_thread::sleep(boost::posix_time::milliseconds(10));
  462. }
  463. ptr = generator->repliesArray[i];
  464. // Send it
  465. //int byteSent=sessionSocket.send(boost::asio::buffer(ptr, byteSize));
  466. if (write(sessionSocket,boost::asio::buffer(ptr, byteSize)) <= 0)
  467. exitWithErrorMessage(__FUNCTION__,"Error sending request" );
  468. totalbytesent+=byteSize;
  469. #ifdef NDSS_UPLOAD_SPEED
  470. sleepForBytes(byteSize);
  471. #endif
  472. // Free its memory
  473. free(ptr);
  474. generator->repliesArray[i]=NULL;
  475. }
  476. // When everythind is send close the socket
  477. sessionSocket.close();
  478. }
  479. // If there was a problem sending the reply
  480. catch (std::exception const& ex)
  481. {
  482. #ifdef DEBUG
  483. std::cerr << "Number of chunks sent: " << i << "/" << reply_nbr << std::endl;
  484. #endif
  485. exitWithErrorMessage(__FUNCTION__, string(ex.what()));
  486. return;
  487. }
  488. // Tell when we have finished
  489. writeWarningMessage(__FUNCTION__ , "done.");;
  490. }
  491. // Functions for displaying logs
  492. void PIRSession::writeErrorMessage(string funcName, string message)
  493. {
  494. cerr << BOLD << funcName << " : " << RESET_COLOR << RED << message << RESET_COLOR <<endl;
  495. }
  496. void PIRSession::writeWarningMessage(string funcName, string message)
  497. {
  498. cerr << BOLD << funcName << " : " << RESET_COLOR << ORANGE << message << RESET_COLOR <<endl;
  499. }
  500. // For critical erros
  501. void PIRSession::exitWithErrorMessage(string funcName, string message)
  502. {
  503. writeErrorMessage(funcName, message);
  504. // This is used in the main function (start()) to skip costly operations if an error occurred
  505. handmadeExceptionRaised = true;
  506. }
  507. // Used by the PIRServer for garbage collection
  508. bool PIRSession::isFinished()
  509. {
  510. return finished;
  511. }
  512. // Destructor
  513. PIRSession::~PIRSession()
  514. {
  515. if (cryptoMethod != NULL) delete cryptoMethod;
  516. }
  517. PIRParameters PIRSession::getPIRParams()
  518. {
  519. pirParam.crypto_params = cryptoMethod->getPublicParameters().getSerializedParams(false);
  520. return pirParam;
  521. }
  522. void PIRSession::setPIRParams(PIRParameters pir_parameters)
  523. {
  524. pirParam = pir_parameters;
  525. }
  526. imported_database_t PIRSession::getSavedDatabase() {
  527. return savedDatabase;
  528. }
  529. void PIRSession::no_pipeline(bool b)
  530. {
  531. no_pipeline_mode = b;
  532. }