#include "pir.hpp"
#include "pir_client.hpp"
#include "pir_server.hpp"
// NOTE(review): the names of the six system headers below were lost from this
// copy of the file; they are restored to match upstream SealPIR's main.cpp --
// confirm against version control.
#include <seal/seal.h>
#include <chrono>
#include <memory>
#include <random>
#include <cstdint>
#include <cstddef>
// Used by the command-line parsing in getCMDArgs().
#include <cstdio>
#include <cstdlib>
#include <limits>
#include <stdexcept>
#include <string>

using namespace std::chrono;
using namespace std;
using namespace seal;

// Benchmark configuration, filled in from the command line by getCMDArgs():
//   1) number_of_items - number of items in the test database
//   2) size_per_item   - size of each item, in bytes
//   3) d               - value passed as `d` to gen_params()
//   4) num_requests    - number of query/reply rounds main() runs
uint8_t no_of_expected_parameters = 4;
uint64_t number_of_items;
uint64_t size_per_item;
uint32_t d;
uint64_t num_requests;

// Parses `text` (the command-line value of parameter `name`) as a base-10
// unsigned integer in [min_value, max_value]. On malformed or out-of-range
// input it prints a diagnostic to stderr and terminates with EXIT_FAILURE,
// so callers only ever see a valid value.
static uint64_t parse_unsigned_arg(const char *text, const char *name,
                                   uint64_t min_value, uint64_t max_value) {
    const std::string digits(text);
    // Require digits only: std::stoull would otherwise accept leading
    // whitespace, '+', or '-' (a negative is silently wrapped to a huge
    // value), and would ignore trailing garbage such as "12abc".
    if (digits.empty() ||
        digits.find_first_not_of("0123456789") != std::string::npos) {
        std::fprintf(stderr,
                     "Invalid value for %s: \"%s\" (expected a non-negative integer)\n",
                     name, text);
        std::exit(EXIT_FAILURE);
    }

    try {
        const unsigned long long value = std::stoull(digits);
        if (value >= min_value && value <= max_value) {
            return static_cast<uint64_t>(value);
        }
    } catch (const std::out_of_range &) {
        // Too large even for unsigned long long; reported below like any
        // other out-of-range value.
    }
    std::fprintf(stderr, "Value for %s is out of range: %s (allowed: %llu to %llu)\n",
                 name, text,
                 static_cast<unsigned long long>(min_value),
                 static_cast<unsigned long long>(max_value));
    std::exit(EXIT_FAILURE);
}

// Reads the benchmark parameters from argv into the globals above:
//   ./sealpir <number_of_items> <size_per_item> <d> <num_requests>
// Prints usage and exits with EXIT_FAILURE when too few arguments are given
// or any value is malformed/out of range. Returns 1 on success.
int getCMDArgs(int argc, char * argv[]) {
    // argc also counts argv[0] (the program name), so N parameters need
    // argc >= N + 1. argv[argc] is a null pointer and must never be read.
    if (argc < no_of_expected_parameters + 1) {
        std::fprintf(stderr, "Command line parameters error, expected :\n");
        std::fprintf(stderr,
                     "./sealpir <number_of_items> <size_per_item> <d> <num_requests>\n");
        std::exit(EXIT_FAILURE);
    }

    // main() picks indices with `rd() % number_of_items`, so zero items would
    // be a modulo by zero; require at least one.
    number_of_items = parse_unsigned_arg(argv[1], "number_of_items", 1,
                                         std::numeric_limits<uint64_t>::max());
    size_per_item = parse_unsigned_arg(argv[2], "size_per_item", 0,
                                       std::numeric_limits<uint64_t>::max());
    // d is stored as uint32_t, so reject anything that would be truncated.
    d = static_cast<uint32_t>(parse_unsigned_arg(
        argv[3], "d", 0, std::numeric_limits<uint32_t>::max()));
    num_requests = parse_unsigned_arg(argv[4], "num_requests", 0,
                                      std::numeric_limits<uint64_t>::max());
    return 1;
}

int main(int argc, char *argv[]) {

    getCMDArgs(argc, argv);
    //uint64_t number_of_items = 1 << 11;
    //uint64_t number_of_items = 262144;
    //uint64_t number_of_items = 1 << 12;
    //uint64_t size_per_item = 288; // in bytes
    // uint64_t size_per_item = 1 << 10; // 1 KB.
    // uint64_t size_per_item = 10 << 10; // 10 KB.
    uint32_t N = 2048;

    // Recommended values: (logt, d) = (12, 2) or (8, 1).
    uint32_t logt = 12;
    //uint32_t logt = 19;
    //uint32_t d = 2;

    EncryptionParameters params(scheme_type::BFV);
    PirParams pir_params;

    // Generates all parameters
    cout << "Generating all parameters" << endl;
    gen_params(number_of_items, size_per_item, N, logt, d, params, pir_params);

    cout << "Initializing the database (this may take some time) ..."
<< endl; // Create test database auto db(make_unique(number_of_items * size_per_item)); // For testing purposes only auto check_db(make_unique(number_of_items * size_per_item)); random_device rd; for (uint64_t i = 0; i < number_of_items; i++) { for (uint64_t j = 0; j < size_per_item; j++) { auto val = rd() % 256; db.get()[(i * size_per_item) + j] = val; check_db.get()[(i * size_per_item) + j] = val; } } // Initialize PIR Server cout << "Initializing server and client" << endl; PIRServer server(params, pir_params); // Initialize PIR client.... PIRClient client(params, pir_params); GaloisKeys galois_keys = client.generate_galois_keys(); // Set galois key cout << "Main: Setting Galois keys..."; server.set_galois_key(0, galois_keys); // The following can be used to update parameters rather than creating new instances // (here it doesn't do anything). // cout << "Updating database size to: " << number_of_items << " elements" << endl; // update_params(number_of_items, size_per_item, d, params, expanded_params, pir_params); cout << "done" << endl; // Measure database setup auto time_pre_s = high_resolution_clock::now(); server.set_database(move(db), number_of_items, size_per_item); server.preprocess_database(); cout << "database pre processed " << endl; auto time_pre_e = high_resolution_clock::now(); auto time_pre_us = duration_cast(time_pre_e - time_pre_s).count(); //TODO: Loop query generation, processing and reply extraction num_requests times for (uint64_t l = 0; l< num_requests; l++) { // Choose an index of an element in the DB uint64_t ele_index = rd() % number_of_items; // element in DB at random position //uint64_t ele_index = 35; cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl; uint64_t index = client.get_fv_index(ele_index, size_per_item); // index of FV plaintext uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext // Measure query generation cout << "Main: FV index = " << 
index << ", FV offset = " << offset << endl; auto time_query_s = high_resolution_clock::now(); PirQuery query = client.generate_query(index); // Note: PirQuery is std::vector, so size of a PirQuery is: // sum of sizes of each of these Ciphertexts uint64_t query_size = 0; for (vector>::iterator iter = query.begin(); iter != query.end(); iter++) { uint64_t vector_size = (*iter).size(); //cout<<"d = "<< d_ctr <<", vector_size = " << (*iter).size() << endl; //cout<<"Ciphertext N = "<< ((*iter).front()).poly_modulus_degree(); //cout<<", k = "<< ((*iter).front()).coeff_mod_count(); //cout<< "query size of each element in "<::iterator iter=reply.begin(); reply_vector_size = reply.size(); reply_size=reply_vector_size * ((*iter).uint64_count()*8); cout<<"reply_vector_size = "<