/* Copyright (C) 2014 Carlos Aguilar Melchor, Joris Barrier, Marc-Olivier Killijian
* This file is part of XPIR.
*
* XPIR is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* XPIR is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with XPIR. If not, see <http://www.gnu.org/licenses/>.
*/
#include "PIRSession.hpp"

#include <type_traits>

#include "pir/replyGenerator/PIROptimizer.hpp"
// Upload rate, in bits per second, that sleepForBytes() emulates after every
// reply chunk sent by uploadWorker() (sleepForBytes divides bytes*8 by it).
// Because it is defined unconditionally here, reply uploads are always
// throttled. NOTE(review): confirm throttling is meant to be on in all builds.
#define NDSS_UPLOAD_SPEED 100000000UL
/**
 * Factory: allocate a new session bound to `ios` and return it wrapped in a
 * PIRSession::pointer.
 */
PIRSession::pointer PIRSession::create(boost::asio::io_service& ios)
{
  PIRSession *fresh_session = new PIRSession(ios);
  return PIRSession::pointer(fresh_session);
}
PIRSession::PIRSession(boost::asio::io_service& ios) :
sessionSocket(ios),
handmadeExceptionRaised(false),
finished(false),
cryptoMethod(NULL),
generator(NULL),
no_pipeline_mode(false)
{
}
/**
 * Attach the database handler that start() serves queries from.
 * Only the pointer is stored; no copy of the handler is made.
 */
void PIRSession::setDBHandler(DBHandler *db)
{
  this->dbhandler = db;
}
/**
 * Run the PIR protocol for the connection on sessionSocket.
 *
 * The first int read from the socket (see rcvIsClient()) tells whether the
 * peer is a PIR client or an optimizer dry-run request. A dry-run request is
 * served by a PIROptimizer and the session ends. For a client, the catalog and
 * the exchange method are sent. The crypto and PIR parameters are then handled
 * in one of two ways:
 * - client-driven mode: they are received from the client;
 * - server-driven mode: they are taken from pirParam and sent to the client.
 * Unless an error was flagged through handmadeExceptionRaised, the query is
 * then downloaded and the reply is generated and uploaded.
 *
 * @param session_option driven mode and database-reuse options.
 * @return true if a dry-run optimizer request was served, false after a
 *         regular client session.
 */
bool PIRSession::start(session_option_t session_option)
{
  maxFileBytesize = dbhandler->getmaxFileBytesize();
  short exchange_method = (session_option.driven_mode) ? CLIENT_DRIVEN : SERVER_DRIVEN;
  bool rcv_paramsandkey = true;

  // Deal with Optimizer dry-run queries if requested
  if (!rcvIsClient())
  {
    std::cout << "PIRSession: Incoming dry-run optimizer request, dealing with it ..." << std::endl;
    PIROptimizer optimizer(dbhandler);
    optimizer.prepareOptimData();
    optimizer.controlAndCommand(sessionSocket);
    std::cout << "PIRSession: Session finished" << std::endl << std::endl;
    finished = true;
    return true; // Did deal with a dry-run query
  }

  sendCatalog();
  sendPIRParamsExchangeMethod(exchange_method);

  if (session_option.driven_mode)
  {
    // Client-driven: run the optimizer dialogue, then receive the parameters
    // chosen by the client
    PIROptimizer optimizer(dbhandler);
    optimizer.prepareOptimData();
    optimizer.controlAndCommand(sessionSocket);
    rcvCryptoParams(rcv_paramsandkey);
    rcvPirParams();
  }
  else
  {
    // Server-driven: the first field of pirParam.crypto_params names the
    // cryptosystem to build; the parameters are then sent to the client
    std::vector<std::string> fields;
    boost::algorithm::split(fields, pirParam.crypto_params, boost::algorithm::is_any_of(GlobalConstant::kDelim));
    cryptoMethod = HomomorphicCryptoFactory_internal::getCrypto(fields.at(0));
    if (cryptoMethod == NULL)
    {
      // Dereferencing a NULL crypto system would crash the whole server
      exitWithErrorMessage(__FUNCTION__, "Unknown cryptosystem " + fields.at(0));
      handmadeExceptionRaised = true;
    }
    else
    {
      cryptoMethod->setNewParameters(pirParam.crypto_params);
      generator = PIRReplyGeneratorFactory::getPIRReplyGenerator(fields.at(0), pirParam, dbhandler);
      if (generator == NULL)
      {
        exitWithErrorMessage(__FUNCTION__, "No reply generator found for " + fields.at(0));
        handmadeExceptionRaised = true;
      }
      else
      {
        generator->setCryptoMethod(cryptoMethod);
        sendCryptoParams();
        rcvCryptoParams(!rcv_paramsandkey);
        sendPirParams();
      }
    }
  }

  // If one of the functions above generates an error, handmadeExceptionRaised is set
  if (!handmadeExceptionRaised)
  {
    // This is just a download thread. Reply generation is unlocked (by a mutex)
    // when this thread finishes.
    startProcessQuery();
    // Import the database, start reply generation when the mutex is unlocked,
    // and upload the reply as it is generated
    startProcessResult(session_option);
  }

  // Wait for child threads
  if (upThread.joinable()) upThread.join();
  if (downThread.joinable()) downThread.join();
  std::cout << "PIRSession: Session finished" << std::endl << std::endl;
  // Note that we have finished so that the PIRServer can do garbage collection
  finished = true;
  if (generator != NULL)
  {
    delete generator;
    generator = NULL; // do not leave a dangling pointer in the session
  }
  return false; // Did deal with a client query
}
/**
 * Accessor used by other classes (such as PIRServer) to reach the socket this
 * session communicates over.
 */
tcp::socket& PIRSession::getSessionSocket()
{
  return this->sessionSocket;
}
/**
* Send the catalog to the client as a string.
* Format : file_list_size \n filename1 \n filesize1 \n filename2 \n ... filesizeN \n
* if SEND_CATALOG is defined and catalog size is less than 1000, the full catalog is sent
* otherwise only the size of the catalog is sent
**/
void PIRSession::sendCatalog()
{
string buf;
#ifdef SEND_CATALOG
if(dbhandler->getNbStream()>1000) {
buf=dbhandler->getCatalog(false);
} else {
buf=dbhandler->getCatalog(true);
}
#else
buf = dbhandler->getCatalog(false);
#endif
// Send the buffer
const boost::uint64_t size = buf.size();
try
{
if (write(sessionSocket, boost::asio::buffer(&size, sizeof(size))) <= 0)
exitWithErrorMessage(__FUNCTION__,"Error sending catalog size");
if (write(sessionSocket, boost::asio::buffer(buf.c_str(), size)) < size)
exitWithErrorMessage(__FUNCTION__,"Error sending catalog");
} catch (std::exception const& ex) {
exitWithErrorMessage(__FUNCTION__,"Error sending catalog: " + string(ex.what()));
}
writeWarningMessage(__FUNCTION__ , "done.");
}
/**
 * Server-driven mode: send pirParam.crypto_params to the client. The message
 * is a raw int holding the string length, followed by the characters of the
 * string.
 */
void PIRSession::sendCryptoParams()
{
  const std::string& params = pirParam.crypto_params;
  std::cout << "PIRSession: Mandatory crypto params sent to the client are " << params << std::endl;
  try
  {
    int params_length = params.size();
    write(sessionSocket, boost::asio::buffer(&params_length, sizeof(params_length)));
    write(sessionSocket, boost::asio::buffer(params));
  }
  catch (std::exception const& ex)
  {
    exitWithErrorMessage(__FUNCTION__, string(ex.what()));
  }
}
/**
 * Read the int the peer sends first to announce what it is.
 *
 * @return true if the peer sent 1 (a PIR client); false for any other value,
 *         which start() treats as an optimizer dry-run request. If reading
 *         fails, the error is reported through exitWithErrorMessage() and
 *         false is returned, since the peer could not be confirmed as a client.
 */
bool PIRSession::rcvIsClient()
{
  // Initialized so that no indeterminate value can ever be compared
  int is_client = 0;
  try
  {
    // Retry as long as read() reports that nothing was received
    while (read(sessionSocket, boost::asio::buffer(&is_client, sizeof(is_client))) == 0)
      boost::this_thread::yield();
  }
  catch (std::exception const& ex)
  {
    exitWithErrorMessage(__FUNCTION__, string(ex.what()));
    // Never fall off the end of this non-void function (undefined behavior)
    return false;
  }
  return is_client == 1;
}
/**
* Receive client's cryptographic parameters as:
* - An int describing a string size
* - A string of the given size containing the public parameters
* - A second int describing a byte size
* - A buffer of bytes of the given size with key material
**/
void PIRSession::rcvCryptoParams(bool paramsandkey)
{
unsigned int size = 0;
try{
std::vector fields;
if(paramsandkey == true)
{
// First get the int and allocate space for the string (plus the null caracter)
read(sessionSocket, boost::asio::buffer(&size, sizeof(int)));
char params_buf[size + 1];
// Get the string in the buffer and add a null character
read(sessionSocket, boost::asio::buffer(params_buf, size));
params_buf[size] = '\0';
cout << "PIRSession: Received crypto parameters " << params_buf << " processing them ..." << endl;
#ifdef DEBUG
cout << "PIRSession: Parameter string size is " << size << endl;
#endif
// Extract cryptosystem's name
string crypto_system_desc(params_buf);
boost::algorithm::split(fields, crypto_system_desc, boost::algorithm::is_any_of(":"));
// Create cryptosystem using a factory and the extracted name
cryptoMethod = HomomorphicCryptoFactory_internal::getCrypto(fields[0]);
// Set cryptosystem with received parameters and key material
cryptoMethod->setNewParameters(params_buf);
// Create the PIR reply generator object using a factory to have
// the correct object given the cryptosystem used (for optimization
// reply generation is cryptosystem dependent)
generator = PIRReplyGeneratorFactory::getPIRReplyGenerator(fields[0], pirParam, dbhandler);
if (generator == NULL)
{
std::cout << "PIRSession: CRITICAL no reply generator found, exiting session" << std::endl;
pthread_exit(this);
}
else
{
generator->setCryptoMethod(cryptoMethod);
}
}
// Use again size to describe the key size
if (read(sessionSocket, boost::asio::buffer(&size ,sizeof(size))) <= 0)
exitWithErrorMessage(__FUNCTION__,"No key received, abort.");
#ifdef DEBUG
cout << "PIRSession: Size of received key material is " << size << endl;
#endif
// Get the key material only if there is some
if(size > 0)
{
// This time we don't use a string so no need for an extra character
char buf[size];
if (read(sessionSocket, boost::asio::buffer(buf, size)) < size)
exitWithErrorMessage(__FUNCTION__,"No parameters received, abort.");
cryptoMethod->getPublicParameters().setModulus(buf);
}
}catch(std::exception const& ex)
{
exitWithErrorMessage(__FUNCTION__, string(ex.what()));
}
std::cout << "PIRSession: Finished processing crypto parameters" << std::endl;
}
/**
 * Tell the client who chooses the parameters. The value is CLIENT_DRIVEN or
 * SERVER_DRIVEN, and exchange_method is sent over the socket as a raw short.
 */
void PIRSession::sendPIRParamsExchangeMethod(short exchange_method)
{
  const bool client_driven = (exchange_method == CLIENT_DRIVEN);
  cout << (client_driven
             ? "PIRSession: Notifying the client this is a client-driven session"
             : "PIRSession: Notifying the client this is a server-driven session")
       << endl;
  try
  {
    write(sessionSocket, boost::asio::buffer(&exchange_method, sizeof(exchange_method)));
  }
  catch (std::exception const& ex)
  {
    exitWithErrorMessage(__FUNCTION__, string(ex.what()));
  }
}
/**
* Receive client's PIR parameters.
**/
void PIRSession::rcvPirParams()
{
try{
// First we get an int with the recursion level
if ( read(sessionSocket, boost::asio::buffer(&(pirParam.d), sizeof(int))) <= 0)
exitWithErrorMessage(__FUNCTION__, "No pir param recieved");
// Then we get an int with the aggregation
if ( read(sessionSocket, boost::asio::buffer(&(pirParam.alpha), sizeof(int))) <= 0)
exitWithErrorMessage(__FUNCTION__, "No pir param recieved");
// Finally for each level we get an int withe the corresponding dimension size
for (unsigned int i = 0 ; i < pirParam.d ; i++)
{
if ( read(sessionSocket, boost::asio::buffer( &(pirParam.n[i]), sizeof(int))) <= 0)
exitWithErrorMessage(__FUNCTION__, "No pir param recieved");
}
// The last dimension + 1 is set to 1 (used in some functions to compute the number of
// elements after a PIR recursion level)
pirParam.n[pirParam.d] = 1;
// Pass the PIR parameters to the reply generator
generator->setPirParams(pirParam);
cout << "PIRSession: PIR params recieved from the client are d=" << pirParam.d << ", alpha=" << pirParam.alpha << ", data_layout=";
for (unsigned int i = 0; i < pirParam.d; i++)
{
if (pirParam.d == 1) cout << "1x";
cout << pirParam.n[i];
if (i != pirParam.d-1) cout << "x";
}
cout << endl;
}catch (std::exception const& ex)
{
exitWithErrorMessage(__FUNCTION__, string(ex.what()));
}
}
void PIRSession::sendPirParams()
{
cout << "PIRSession: Mandatory PIR params sent to the client are d=" << pirParam.d << ", alpha=" << pirParam.alpha << ", data_layout=";
for (unsigned int i = 0; i < pirParam.d; i++)
{
if (pirParam.d == 1) cout << "1x";
cout << pirParam.n[i];
if (i != pirParam.d-1) cout << "x";
}
cout << endl;
try{
// First we send an int with the recursion level
if ( write(sessionSocket, boost::asio::buffer(&(pirParam.d), sizeof(int))) <= 0)
exitWithErrorMessage(__FUNCTION__, "No pir param sended");
// Then we send an int with the aggregation
if ( write(sessionSocket, boost::asio::buffer(&(pirParam.alpha), sizeof(int))) <= 0)
exitWithErrorMessage(__FUNCTION__, "No pir param sended");
// Finally for each level we send an int withe the corresponding dimension size
for (unsigned int i = 0 ; i < pirParam.d ; i++)
{
if ( write(sessionSocket, boost::asio::buffer( &(pirParam.n[i]), sizeof(int))) <= 0)
exitWithErrorMessage(__FUNCTION__, "No pir param sended");
}
// The last dimension + 1 is set to 1 (used in some functions to compute the number of
// elements after a PIR recursion level)
pirParam.n[pirParam.d] = 1;
// Pass the PIR parameters to the reply generator
generator->setPirParams(pirParam);
}catch (std::exception const& ex)
{
exitWithErrorMessage(__FUNCTION__, string(ex.what()));
}
}
/**
 * Receive the query. downloadWorker() runs in downThread, or synchronously
 * in the calling thread when pipelining is disabled (no_pipeline_mode).
 */
void PIRSession::startProcessQuery ()
{
  if (!no_pipeline_mode)
  {
    downThread = boost::thread(&PIRSession::downloadWorker, this);
    return;
  }
  std::cout << "No pipeline in query processing." << std::endl;
  downloadWorker();
}
/**
 * Receive the client's query.
 *
 * For each recursion level j (0 <= j < pirParam.d), the client sends
 * pirParam.n[j] query elements of getQuerySizeFromRecLvl(j+1)/8 bytes each.
 * Every element is read into its own malloc'ed buffer, which is handed to the
 * reply generator with pushQuery(). Ownership presumably passes to the
 * generator (NOTE(review): confirm). Once all elements are received,
 * generator->mutex is unlocked, which lets reply generation start.
 *
 * On a reception or allocation error, the failure is reported through
 * exitWithErrorMessage() and the function returns without unlocking
 * generator->mutex. NOTE(review): reply generation waiting on that mutex may
 * then block; confirm this is intended.
 */
void PIRSession::downloadWorker()
{
#ifdef PERF_TIMERS
  double start = omp_get_wtime();
#endif
  unsigned int msg_size = 0;
  // Allocate an array with d dimensions with pointers to arrays of n[i] query elements
  generator->initQueriesBuffer();
#ifdef PERF_TIMERS
  double vtstart = omp_get_wtime();
  unsigned int previous_elts = 0;
  unsigned int total_elts = 0;
  for (unsigned int k = 0 ; k < pirParam.d ; k++) total_elts += pirParam.n[k];
#endif
  try
  {
    for (unsigned int j = 0 ; j < pirParam.d ; j++)
    {
      // Size in bytes of a query ELEMENT of dimension j
      msg_size = cryptoMethod->getPublicParameters().getQuerySizeFromRecLvl(j+1) / 8;
      boost::asio::socket_base::receive_buffer_size opt(65535);
      sessionSocket.set_option(opt);
#ifdef DEBUG
      cout << "PIRSession: Size of the query element to be received is " << msg_size << endl;
      cout << "PIRSession: Number of query elements to be received is " << pirParam.n[j] << endl;
#endif
      // Iterate over all the elements of the query corresponding to the j-th dimension
      for (unsigned int i = 0; i < pirParam.n[j]; i++)
      {
        char *buf = (char *) malloc(msg_size * sizeof(char));
        // malloc(0) may legitimately return NULL, so only fail for a real size
        if (buf == NULL && msg_size > 0)
        {
          exitWithErrorMessage(__FUNCTION__, "Not enough memory for a query element");
          return;
        }
        if (i == 0 && j == 0) cout << "PIRSession: Waiting for query elements ..." << endl;
        // Get a query element
        try
        {
          if (read(sessionSocket, boost::asio::buffer(buf, msg_size)) < msg_size)
            writeWarningMessage(__FUNCTION__, "Query element not entirely recieved");
        }
        catch (...)
        {
          // buf has not been handed to the generator yet: release it, then
          // let the error propagate as before
          free(buf);
          throw;
        }
        if (i == 0 && j == 0) cout << "PIRSession: Starting query element reception" << endl;
#ifdef PERF_TIMERS
        // Give some feedback if it takes too long
        double vtstop = omp_get_wtime();
        if (vtstop - vtstart > 1)
        {
          vtstart = vtstop;
          previous_elts = 0;
          for (unsigned int k = 0 ; k < j ; k++) previous_elts += pirParam.n[k];
          std::cout <<"PIRSession: Query element " << i+1+previous_elts << "/" << total_elts << " received\r" << std::flush;
        }
#endif
        // Hand the received element to the reply generator
        generator->pushQuery(buf, msg_size, j, i);
      }
    }
  }
  catch (std::exception const& ex)
  {
    exitWithErrorMessage(__FUNCTION__, string(ex.what()));
    return;
  }
#ifdef PERF_TIMERS
  std::cout <<"PIRSession: Query element " << total_elts << "/" << total_elts << " received" << std::endl;
  std::cout <<"PIRSession: " << total_elts << " query elements received in " << omp_get_wtime() - start << std::endl;
#endif
  // All the query elements received, unlock reply generation
  generator->mutex.unlock();
  // Output we are done
  writeWarningMessage(__FUNCTION__, "done.");
}
/**
 * Generate the reply and send it.
 *
 * In pipeline mode, uploadWorker() is started in upThread before generation
 * begins, so chunks are sent as they are produced. With no_pipeline_mode it
 * runs in the calling thread once generation returns. Generation waits for the
 * query downloader to unlock it. Which generation path runs depends on the
 * session options:
 * - a pre-imported database (session_option.data), when one is available;
 * - otherwise generateReplyGeneric(), whose result is kept in savedDatabase
 *   when keep_database is set.
 */
void PIRSession::startProcessResult(session_option_t session_option)
{
  const bool pipelined = !no_pipeline_mode;
  if (pipelined)
    upThread = boost::thread(&PIRSession::uploadWorker, this);
  else
    std::cout << "No pipeline in query Generation." << std::endl;

  if (session_option.got_preimported_database == true)
  {
    std::cout << "PIRSession: Already got an imported database available, using it" << std::endl;
    generator->generateReplyGenericFromData(session_option.data);
  }
  else if (session_option.keep_database == true)
  {
    savedDatabase = generator->generateReplyGeneric(true);
  }
  else if (session_option.keep_database == false)
  {
    generator->generateReplyGeneric(false);
  }

  if (!pipelined)
    uploadWorker();
}
/**
 * Time needed to transmit `bytes` bytes at `bits_per_second`.
 *
 * Computed with exact 64-bit integer arithmetic. The earlier
 * `(bytes*8)` product in unsigned int overflowed for 512 MiB and above, and
 * could yield tv_nsec >= 1e9, which nanosleep rejects. The nanosecond part is
 * exact as long as bits_per_second stays below ~1.8e10.
 *
 * @return {0, 0} when bits_per_second is 0.
 */
static inline struct timespec uploadDelayFor(uint64_t bytes, uint64_t bits_per_second)
{
  struct timespec delay = {};
  if (bits_per_second == 0) return delay;
  const uint64_t bits = bytes * 8;
  delay.tv_sec = static_cast<time_t>(bits / bits_per_second);
  delay.tv_nsec = static_cast<long>((bits % bits_per_second) * 1000000000ULL / bits_per_second);
  return delay;
}

/**
 * Sleep for as long as sending `bytes` bytes would take at NDSS_UPLOAD_SPEED
 * bits per second. Does nothing when NDSS_UPLOAD_SPEED is not defined.
 */
void sleepForBytes(unsigned int bytes) {
#ifdef NDSS_UPLOAD_SPEED
  struct timespec req = uploadDelayFor(bytes, NDSS_UPLOAD_SPEED);
  struct timespec rem = {};
  nanosleep(&req, &rem);
#else
  (void)bytes;
#endif
}
/**
 * Send the reply to the client chunk by chunk, as generation produces it.
 *
 * Chunk i is sent once generator->repliesArray[i] is non-NULL; the array is
 * polled every 10 ms. Each chunk is byteSize bytes, the ciphertext size at
 * recursion level pirParam.d. Sent chunks are freed and their slot is reset
 * to NULL. When every chunk has been sent, the socket is closed. Uploads are
 * throttled through sleepForBytes() when NDSS_UPLOAD_SPEED is defined.
 */
void PIRSession::uploadWorker()
{
  // Ciphertext byte size
  unsigned int byteSize = cryptoMethod->getPublicParameters().getCiphBitsizeFromRecLvl(pirParam.d)/GlobalConstant::kBitsPerByte;
  // Number of ciphertexts in the reply
  unsigned long reply_nbr = generator->computeReplySizeInChunks(maxFileBytesize);
  // Index of the chunk being sent (= number of chunks already sent). It is
  // declared outside the try block so that the DEBUG report in the handler can
  // use it, and it has the same type as reply_nbr so the loop cannot wrap.
  unsigned long i = 0;
#ifdef DEBUG
  cout << "PIRSession: Number of ciphertexts to send is " << reply_nbr << endl;
  cout << "PIRSession: maxFileBytesize is " << maxFileBytesize << endl;
  cout << "PIRSession: Ciphertext bytesize is " << byteSize << endl;
#endif
  try
  {
    // Pointer for the ciphertexts to be sent
    char *ptr;
    // For each ciphertext in the reply
    for (i = 0 ; i < reply_nbr ; i++)
    {
      // Wait until the generator has produced chunk i
      while (generator->repliesArray == NULL || generator->repliesArray[i] == NULL)
      {
        boost::this_thread::sleep(boost::posix_time::milliseconds(10));
      }
      ptr = generator->repliesArray[i];
      // Send it
      if (write(sessionSocket,boost::asio::buffer(ptr, byteSize)) <= 0)
        exitWithErrorMessage(__FUNCTION__,"Error sending request" );
#ifdef NDSS_UPLOAD_SPEED
      sleepForBytes(byteSize);
#endif
      // Free its memory
      free(ptr);
      generator->repliesArray[i]=NULL;
    }
    // When everything is sent close the socket
    sessionSocket.close();
  }
  // If there was a problem sending the reply
  catch (std::exception const& ex)
  {
#ifdef DEBUG
    std::cerr << "Number of chunks sent: " << i << "/" << reply_nbr << std::endl;
#endif
    exitWithErrorMessage(__FUNCTION__, string(ex.what()));
    return;
  }
  // Tell when we have finished
  writeWarningMessage(__FUNCTION__ , "done.");
}
// Functions for displaying logs
// NOTE(review): the text below appears damaged. The end of writeErrorMessage()
// and the start of the following definition seem to be missing. That
// definition returns pirParam and is presumably getPIRParams(); confirm.
// Restore this region from version control before relying on it.
void PIRSession::writeErrorMessage(string funcName, string message)
{
cerr << BOLD << funcName << " : " << RESET_COLOR << RED << message << RESET_COLOR <getPublicParameters().getSerializedParams(false);
return pirParam;
}
/**
 * Replace the session's PIR parameters. In server-driven mode, start() uses
 * these values as-is: crypto_params selects the cryptosystem, and the rest is
 * sent to the client.
 */
void PIRSession::setPIRParams(PIRParameters pir_parameters)
{
  this->pirParam = pir_parameters;
}
/**
 * Return the database stored by startProcessResult() when the session ran
 * with session_option.keep_database set.
 */
imported_database_t PIRSession::getSavedDatabase()
{
  return this->savedDatabase;
}
void PIRSession::no_pipeline(bool b)
{
no_pipeline_mode = b;
}