Browse Source

formatting tests

Sebastian Angel 2 years ago
parent
commit
ed34eeacd6
6 changed files with 596 additions and 571 deletions
  1. 29 28
      test/coefficient_conversion_test.cpp
  2. 65 66
      test/decomposition_test.cpp
  3. 91 87
      test/expand_test.cpp
  4. 141 135
      test/query_test.cpp
  5. 149 141
      test/replace_test.cpp
  6. 121 114
      test/simple_query_test.cpp

+ 29 - 28
test/coefficient_conversion_test.cpp

@@ -1,13 +1,13 @@
 #include "pir.hpp"
 #include "pir_client.hpp"
 #include "pir_server.hpp"
-#include <seal/seal.h>
+#include <bitset>
 #include <chrono>
+#include <cstddef>
+#include <cstdint>
 #include <memory>
 #include <random>
-#include <cstdint>
-#include <cstddef>
-#include <bitset>
+#include <seal/seal.h>
 
 using namespace std::chrono;
 using namespace std;
@@ -15,32 +15,33 @@ using namespace seal;
 
 int main(int argc, char *argv[]) {
 
-    const uint32_t logt = 16; 
-    const uint32_t ele_size = 3;
-    const uint32_t num_ele = 3;
-    uint8_t bytes[ele_size * num_ele] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
-
-    vector<uint64_t> coeffs;
-    
-    cout << "Coeffs: " << endl;
-    for(int i = 0; i < num_ele; i++){
-        vector<uint64_t> ele_coeffs = bytes_to_coeffs(logt, bytes + (i * ele_size), ele_size);
-        for(int j = 0; j < ele_coeffs.size(); j++){
-            cout << ele_coeffs[j] << endl;
-            cout << std::bitset<logt>(ele_coeffs[j]) << endl;
-            coeffs.push_back(ele_coeffs[j]);
-        }
-    }
-     
-    cout << "Num of Coeffs: " << coeffs.size() << endl;
+  const uint32_t logt = 16;
+  const uint32_t ele_size = 3;
+  const uint32_t num_ele = 3;
+  uint8_t bytes[ele_size * num_ele] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
 
-    uint8_t output[ele_size * num_ele];
-    coeffs_to_bytes(logt, coeffs, output, ele_size * num_ele, ele_size);
+  vector<uint64_t> coeffs;
 
-    cout << "Bytes: " << endl;
-    for(int i = 0 ; i < ele_size * num_ele; i++){
-        cout << (int)output[i] << endl;
+  cout << "Coeffs: " << endl;
+  for (int i = 0; i < num_ele; i++) {
+    vector<uint64_t> ele_coeffs =
+        bytes_to_coeffs(logt, bytes + (i * ele_size), ele_size);
+    for (int j = 0; j < ele_coeffs.size(); j++) {
+      cout << ele_coeffs[j] << endl;
+      cout << std::bitset<logt>(ele_coeffs[j]) << endl;
+      coeffs.push_back(ele_coeffs[j]);
     }
+  }
+
+  cout << "Num of Coeffs: " << coeffs.size() << endl;
+
+  uint8_t output[ele_size * num_ele];
+  coeffs_to_bytes(logt, coeffs, output, ele_size * num_ele, ele_size);
+
+  cout << "Bytes: " << endl;
+  for (int i = 0; i < ele_size * num_ele; i++) {
+    cout << (int)output[i] << endl;
+  }
 
-    return 0;
+  return 0;
 }

+ 65 - 66
test/decomposition_test.cpp

@@ -1,80 +1,79 @@
 #include "pir.hpp"
 #include "pir_client.hpp"
 #include "pir_server.hpp"
-#include <seal/seal.h>
 #include <chrono>
+#include <cstddef>
+#include <cstdint>
 #include <memory>
 #include <random>
-#include <cstdint>
-#include <cstddef>
+#include <seal/seal.h>
 
 using namespace std::chrono;
 using namespace std;
 using namespace seal;
 
-
 int main(int argc, char *argv[]) {
 
-    uint64_t number_of_items = 2048;
-    uint64_t size_per_item = 288; // in bytes
-    uint32_t N = 8192;
-
-    // Recommended values: (logt, d) = (12, 2) or (8, 1). 
-    uint32_t logt = 20; 
-
-    EncryptionParameters enc_params(scheme_type::bfv);
-
-    // Generates all parameters
-    
-    cout << "Main: Generating SEAL parameters" << endl;
-    gen_encryption_params(N, logt, enc_params);
-    
-    cout << "Main: Verifying SEAL parameters" << endl;
-    verify_encryption_params(enc_params);
-    cout << "Main: SEAL parameters are good" << endl;
-    
-    SEALContext context(enc_params, true);
-    KeyGenerator keygen(context);
-    
-    SecretKey secret_key = keygen.secret_key();
-    Encryptor encryptor(context, secret_key);
-    Decryptor decryptor(context, secret_key);
-    Evaluator evaluator(context);
-    BatchEncoder encoder(context);
-    logt = floor(log2(enc_params.plain_modulus().value()));
-
-    uint32_t plain_modulus = enc_params.plain_modulus().value();
-
-    size_t slot_count = encoder.slot_count();
-
-    vector<uint64_t> coefficients(slot_count, 0ULL);
-    for(uint32_t i = 0; i < coefficients.size(); i++){
-        coefficients[i] = rand() % plain_modulus;
-    }
-    Plaintext pt; 
-    encoder.encode(coefficients, pt);
-    Ciphertext ct;
-    encryptor.encrypt_symmetric(pt, ct);
-    std::cout << "Encrypting" << std::endl;
-    auto context_data = context.last_context_data();
-    auto parms_id = context.last_parms_id();
-
-    evaluator.mod_switch_to_inplace(ct, parms_id);
-
-    EncryptionParameters params = context_data->parms();
-    std::cout << "Encoding" << std::endl;
-    vector<Plaintext> encoded = decompose_to_plaintexts(params, ct);
-    std::cout << "Expansion Factor: " << encoded.size() << std::endl;
-    std::cout << "Decoding" << std::endl;
-    Ciphertext decoded(context, parms_id);
-    compose_to_ciphertext(params, encoded, decoded);
-    std::cout << "Checking" <<std::endl;
-    Plaintext pt2;
-    decryptor.decrypt(decoded, pt2);
-
-    assert(pt == pt2);
-
-    std::cout << "Worked" << std::endl;
-
-    return 0;
+  uint64_t number_of_items = 2048;
+  uint64_t size_per_item = 288; // in bytes
+  uint32_t N = 8192;
+
+  // Recommended values: (logt, d) = (12, 2) or (8, 1).
+  uint32_t logt = 20;
+
+  EncryptionParameters enc_params(scheme_type::bfv);
+
+  // Generates all parameters
+
+  cout << "Main: Generating SEAL parameters" << endl;
+  gen_encryption_params(N, logt, enc_params);
+
+  cout << "Main: Verifying SEAL parameters" << endl;
+  verify_encryption_params(enc_params);
+  cout << "Main: SEAL parameters are good" << endl;
+
+  SEALContext context(enc_params, true);
+  KeyGenerator keygen(context);
+
+  SecretKey secret_key = keygen.secret_key();
+  Encryptor encryptor(context, secret_key);
+  Decryptor decryptor(context, secret_key);
+  Evaluator evaluator(context);
+  BatchEncoder encoder(context);
+  logt = floor(log2(enc_params.plain_modulus().value()));
+
+  uint32_t plain_modulus = enc_params.plain_modulus().value();
+
+  size_t slot_count = encoder.slot_count();
+
+  vector<uint64_t> coefficients(slot_count, 0ULL);
+  for (uint32_t i = 0; i < coefficients.size(); i++) {
+    coefficients[i] = rand() % plain_modulus;
+  }
+  Plaintext pt;
+  encoder.encode(coefficients, pt);
+  Ciphertext ct;
+  encryptor.encrypt_symmetric(pt, ct);
+  std::cout << "Encrypting" << std::endl;
+  auto context_data = context.last_context_data();
+  auto parms_id = context.last_parms_id();
+
+  evaluator.mod_switch_to_inplace(ct, parms_id);
+
+  EncryptionParameters params = context_data->parms();
+  std::cout << "Encoding" << std::endl;
+  vector<Plaintext> encoded = decompose_to_plaintexts(params, ct);
+  std::cout << "Expansion Factor: " << encoded.size() << std::endl;
+  std::cout << "Decoding" << std::endl;
+  Ciphertext decoded(context, parms_id);
+  compose_to_ciphertext(params, encoded, decoded);
+  std::cout << "Checking" << std::endl;
+  Plaintext pt2;
+  decryptor.decrypt(decoded, pt2);
+
+  assert(pt == pt2);
+
+  std::cout << "Worked" << std::endl;
+
+  return 0;
 }

+ 91 - 87
test/expand_test.cpp

@@ -1,103 +1,107 @@
 #include "pir.hpp"
 #include "pir_client.hpp"
 #include "pir_server.hpp"
-#include <seal/seal.h>
 #include <chrono>
+#include <cstddef>
+#include <cstdint>
 #include <memory>
 #include <random>
-#include <cstdint>
-#include <cstddef>
+#include <seal/seal.h>
 
 using namespace std::chrono;
 using namespace std;
 using namespace seal;
 
-//For this test, we need the parameters to be such that the number of compressed ciphertexts needed is 1.
+// For this test, we need the parameters to be such that the number of
+// compressed ciphertexts needed is 1.
 int main(int argc, char *argv[]) {
 
-    uint64_t number_of_items = 2048;
-    uint64_t size_per_item = 288; // in bytes
-    uint32_t N = 4096;
-
-    // Recommended values: (logt, d) = (12, 2) or (8, 1). 
-    uint32_t logt = 20; 
-    uint32_t d = 1;
-
-    EncryptionParameters enc_params(scheme_type::bfv);
-    PirParams pir_params;
-
-    // Generates all parameters
-    
-    cout << "Main: Generating SEAL parameters" << endl;
-    gen_encryption_params(N, logt, enc_params);
-    
-    cout << "Main: Verifying SEAL parameters" << endl;
-    verify_encryption_params(enc_params);
-    cout << "Main: SEAL parameters are good" << endl;
-
-    cout << "Main: Generating PIR parameters" << endl;
-    gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
-    
-    
-    
-    //gen_params(number_of_items, size_per_item, N, logt, d, enc_params, pir_params);
-    print_pir_params(pir_params);
-
-    // Initialize PIR Server
-    cout << "Main: Initializing server and client" << endl;
-    PIRServer server(enc_params, pir_params);
-
-    // Initialize PIR client....
-    PIRClient client(enc_params, pir_params);
-    GaloisKeys galois_keys = client.generate_galois_keys();
-
-    // Set galois key for client with id 0
-    cout << "Main: Setting Galois keys...";
-    server.set_galois_key(0, galois_keys);
-
-    random_device rd;
-    // Choose an index of an element in the DB
-    uint64_t ele_index = rd() % number_of_items; // element in DB at random position
-    uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
-    uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
-    cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
-    cout << "Main: FV index = " << index << ", FV offset = " << offset << endl; 
-
-    // Measure query generation
-    auto time_query_s = high_resolution_clock::now();
-    PirQuery query = client.generate_query(index);
-    auto time_query_e = high_resolution_clock::now();
-    auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
-    cout << "Main: query generated" << endl;
-
-    // Measure query processing (including expansion)
-    auto time_server_s = high_resolution_clock::now();
-    uint64_t n_i = pir_params.nvec[0];
-    vector<Ciphertext> expanded_query = server.expand_query(query[0][0], n_i, 0);
-    auto time_server_e = high_resolution_clock::now();
-    auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
-    cout << "Main: query expanded" << endl;
-
-    assert(expanded_query.size() == n_i);
-
-    cout << "Main: checking expansion" << endl;
-    for(size_t i = 0; i < expanded_query.size(); i++){
-        Plaintext decryption = client.decrypt(expanded_query.at(i));
-        
-        if(decryption.is_zero() && index != i){
-            continue;
-        }
-        else if(decryption.is_zero()){
-            cout << "Found zero where index should be" << endl;
-            return -1;
-        }
-        else if (std::stoi(decryption.to_string()) != 1) {
-            cout << "Query vector at index " << index << " should be 1 but is instead " << decryption.to_string() << endl;
-            return -1;
-        } else {
-            cout << "Query vector at index " << index << " is " << decryption.to_string() << endl;
-        }
+  uint64_t number_of_items = 2048;
+  uint64_t size_per_item = 288; // in bytes
+  uint32_t N = 4096;
+
+  // Recommended values: (logt, d) = (12, 2) or (8, 1).
+  uint32_t logt = 20;
+  uint32_t d = 1;
+
+  EncryptionParameters enc_params(scheme_type::bfv);
+  PirParams pir_params;
+
+  // Generates all parameters
+
+  cout << "Main: Generating SEAL parameters" << endl;
+  gen_encryption_params(N, logt, enc_params);
+
+  cout << "Main: Verifying SEAL parameters" << endl;
+  verify_encryption_params(enc_params);
+  cout << "Main: SEAL parameters are good" << endl;
+
+  cout << "Main: Generating PIR parameters" << endl;
+  gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
+
+  // gen_params(number_of_items, size_per_item, N, logt, d, enc_params,
+  // pir_params);
+  print_pir_params(pir_params);
+
+  // Initialize PIR Server
+  cout << "Main: Initializing server and client" << endl;
+  PIRServer server(enc_params, pir_params);
+
+  // Initialize PIR client....
+  PIRClient client(enc_params, pir_params);
+  GaloisKeys galois_keys = client.generate_galois_keys();
+
+  // Set galois key for client with id 0
+  cout << "Main: Setting Galois keys...";
+  server.set_galois_key(0, galois_keys);
+
+  random_device rd;
+  // Choose an index of an element in the DB
+  uint64_t ele_index =
+      rd() % number_of_items; // element in DB at random position
+  uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
+  uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
+  cout << "Main: element index = " << ele_index << " from [0, "
+       << number_of_items - 1 << "]" << endl;
+  cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
+
+  // Measure query generation
+  auto time_query_s = high_resolution_clock::now();
+  PirQuery query = client.generate_query(index);
+  auto time_query_e = high_resolution_clock::now();
+  auto time_query_us =
+      duration_cast<microseconds>(time_query_e - time_query_s).count();
+  cout << "Main: query generated" << endl;
+
+  // Measure query processing (including expansion)
+  auto time_server_s = high_resolution_clock::now();
+  uint64_t n_i = pir_params.nvec[0];
+  vector<Ciphertext> expanded_query = server.expand_query(query[0][0], n_i, 0);
+  auto time_server_e = high_resolution_clock::now();
+  auto time_server_us =
+      duration_cast<microseconds>(time_server_e - time_server_s).count();
+  cout << "Main: query expanded" << endl;
+
+  assert(expanded_query.size() == n_i);
+
+  cout << "Main: checking expansion" << endl;
+  for (size_t i = 0; i < expanded_query.size(); i++) {
+    Plaintext decryption = client.decrypt(expanded_query.at(i));
+
+    if (decryption.is_zero() && index != i) {
+      continue;
+    } else if (decryption.is_zero()) {
+      cout << "Found zero where index should be" << endl;
+      return -1;
+    } else if (std::stoi(decryption.to_string()) != 1) {
+      cout << "Query vector at index " << index
+           << " should be 1 but is instead " << decryption.to_string() << endl;
+      return -1;
+    } else {
+      cout << "Query vector at index " << index << " is "
+           << decryption.to_string() << endl;
     }
+  }
 
-    return 0;
+  return 0;
 }

+ 141 - 135
test/query_test.cpp

@@ -1,161 +1,167 @@
 #include "pir.hpp"
 #include "pir_client.hpp"
 #include "pir_server.hpp"
-#include <seal/seal.h>
 #include <chrono>
+#include <cstddef>
+#include <cstdint>
 #include <memory>
 #include <random>
-#include <cstdint>
-#include <cstddef>
+#include <seal/seal.h>
 
 using namespace std::chrono;
 using namespace std;
 using namespace seal;
 
-int query_test(uint64_t num_items, uint64_t item_size, uint32_t degree, uint32_t lt, uint32_t dim);
+int query_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
+               uint32_t lt, uint32_t dim);
 
 int main(int argc, char *argv[]) {
-    // Quick check
-    assert(query_test(1 << 10, 288, 4096, 20, 1) == 0);
+  // Quick check
+  assert(query_test(1 << 10, 288, 4096, 20, 1) == 0);
 
-    assert(query_test(1 << 10, 288, 4096, 20, 2) == 0);
+  assert(query_test(1 << 10, 288, 4096, 20, 2) == 0);
 
-    assert(query_test(1 << 10, 288, 4096, 20, 3) == 0);
+  assert(query_test(1 << 10, 288, 4096, 20, 3) == 0);
 
-    assert(query_test(1 << 10, 288, 8192, 20, 2) == 0);
+  assert(query_test(1 << 10, 288, 8192, 20, 2) == 0);
 
+  // Forces ciphertext expansion to be the same as the degree
+  assert(query_test(1 << 20, 288, 4096, 20, 1) == 0);
 
-    // Forces ciphertext expansion to be the same as the degree
-    assert(query_test(1 << 20, 288, 4096, 20, 1) == 0);
+  assert(query_test(1 << 20, 288, 4096, 20, 2) == 0);
+}
 
-    
-    assert(query_test(1 << 20, 288, 4096, 20, 2) == 0);
+int query_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
+               uint32_t lt, uint32_t dim) {
+  uint64_t number_of_items = num_items;
+  uint64_t size_per_item = item_size; // in bytes
+  uint32_t N = degree;
 
-    
-}
+  // Recommended values: (logt, d) = (12, 2) or (8, 1).
+  uint32_t logt = lt;
+  uint32_t d = dim;
 
+  EncryptionParameters enc_params(scheme_type::bfv);
+  PirParams pir_params;
 
-int query_test(uint64_t num_items, uint64_t item_size, uint32_t degree, uint32_t lt, uint32_t dim){
-    uint64_t number_of_items = num_items;
-    uint64_t size_per_item = item_size; // in bytes
-    uint32_t N = degree;
-
-    // Recommended values: (logt, d) = (12, 2) or (8, 1). 
-    uint32_t logt = lt; 
-    uint32_t d = dim;
-
-    EncryptionParameters enc_params(scheme_type::bfv);
-    PirParams pir_params;
-
-    // Generates all parameters
-    
-    cout << "Main: Generating SEAL parameters" << endl;
-    gen_encryption_params(N, logt, enc_params);
-    
-    cout << "Main: Verifying SEAL parameters" << endl;
-    verify_encryption_params(enc_params);
-    cout << "Main: SEAL parameters are good" << endl;
-
-    cout << "Main: Generating PIR parameters" << endl;
-    gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
-    
-    
-    
-    //gen_params(number_of_items, size_per_item, N, logt, d, enc_params, pir_params);
-    print_pir_params(pir_params);
-
-    cout << "Main: Initializing the database (this may take some time) ..." << endl;
-
-    // Create test database
-    auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
-
-    // Copy of the database. We use this at the end to make sure we retrieved
-    // the correct element.
-    auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
-
-    random_device rd;
-    for (uint64_t i = 0; i < number_of_items; i++) {
-        for (uint64_t j = 0; j < size_per_item; j++) {
-            uint8_t val = rd() % 256;
-            db.get()[(i * size_per_item) + j] = val;
-            db_copy.get()[(i * size_per_item) + j] = val;
-        }
-    }
+  // Generates all parameters
 
-    // Initialize PIR Server
-    cout << "Main: Initializing server and client" << endl;
-    PIRServer server(enc_params, pir_params);
-
-    // Initialize PIR client....
-    PIRClient client(enc_params, pir_params);
-    GaloisKeys galois_keys = client.generate_galois_keys();
-
-    // Set galois key for client with id 0
-    cout << "Main: Setting Galois keys...";
-    server.set_galois_key(0, galois_keys);
-
-    // Measure database setup
-    auto time_pre_s = high_resolution_clock::now();
-    server.set_database(move(db), number_of_items, size_per_item);
-    server.preprocess_database();
-    cout << "Main: database pre processed " << endl;
-    auto time_pre_e = high_resolution_clock::now();
-    auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
-
-
-    // Choose an index of an element in the DB
-    uint64_t ele_index = rd() % number_of_items; // element in DB at random position
-    uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
-    uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
-    cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
-    cout << "Main: FV index = " << index << ", FV offset = " << offset << endl; 
-
-    // Measure query generation
-    auto time_query_s = high_resolution_clock::now();
-    PirQuery query = client.generate_query(index);
-    auto time_query_e = high_resolution_clock::now();
-    auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
-    cout << "Main: query generated" << endl;
-
-    //To marshall query to send over the network, you can use serialize/deserialize:
-    //std::string query_ser = serialize_query(query);
-    //PirQuery query2 = deserialize_query(d, 1, query_ser, CIPHER_SIZE);
-
-    // Measure query processing (including expansion)
-    auto time_server_s = high_resolution_clock::now();
-    PirReply reply = server.generate_reply(query, 0);
-    auto time_server_e = high_resolution_clock::now();
-    auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
-
-    // Measure response extraction
-    auto time_decode_s = chrono::high_resolution_clock::now();
-    vector<uint8_t> elems = client.decode_reply(reply, offset);
-    auto time_decode_e = chrono::high_resolution_clock::now();
-    auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
-
-    assert(elems.size() == size_per_item);
-
-    bool failed = false;
-    // Check that we retrieved the correct element
-    for (uint32_t i = 0; i < size_per_item; i++) {
-        if (elems[i] != db_copy.get()[(ele_index * size_per_item) + i]) {
-            cout << "Main: elems " << (int)elems[i] << ", db "
-                << (int) db_copy.get()[(ele_index * size_per_item) + i] << endl;
-            cout << "Main: PIR result wrong at " << i <<  endl;
-            failed = true;
-        }
-    }
-    if(failed){
-        return -1;
-    }
+  cout << "Main: Generating SEAL parameters" << endl;
+  gen_encryption_params(N, logt, enc_params);
+
+  cout << "Main: Verifying SEAL parameters" << endl;
+  verify_encryption_params(enc_params);
+  cout << "Main: SEAL parameters are good" << endl;
+
+  cout << "Main: Generating PIR parameters" << endl;
+  gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
+
+  // gen_params(number_of_items, size_per_item, N, logt, d, enc_params,
+  // pir_params);
+  print_pir_params(pir_params);
 
-    // Output results
-    cout << "Main: PIR result correct!" << endl;
-    cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
-    cout << "Main: PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
-    cout << "Main: PIRServer reply generation time: " << time_server_us / 1000 << " ms" << endl;
-    cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
-    cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
+  cout << "Main: Initializing the database (this may take some time) ..."
+       << endl;
 
-    return 0;
+  // Create test database
+  auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  // Copy of the database. We use this at the end to make sure we retrieved
+  // the correct element.
+  auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  random_device rd;
+  for (uint64_t i = 0; i < number_of_items; i++) {
+    for (uint64_t j = 0; j < size_per_item; j++) {
+      uint8_t val = rd() % 256;
+      db.get()[(i * size_per_item) + j] = val;
+      db_copy.get()[(i * size_per_item) + j] = val;
+    }
+  }
+
+  // Initialize PIR Server
+  cout << "Main: Initializing server and client" << endl;
+  PIRServer server(enc_params, pir_params);
+
+  // Initialize PIR client....
+  PIRClient client(enc_params, pir_params);
+  GaloisKeys galois_keys = client.generate_galois_keys();
+
+  // Set galois key for client with id 0
+  cout << "Main: Setting Galois keys...";
+  server.set_galois_key(0, galois_keys);
+
+  // Measure database setup
+  auto time_pre_s = high_resolution_clock::now();
+  server.set_database(move(db), number_of_items, size_per_item);
+  server.preprocess_database();
+  cout << "Main: database pre processed " << endl;
+  auto time_pre_e = high_resolution_clock::now();
+  auto time_pre_us =
+      duration_cast<microseconds>(time_pre_e - time_pre_s).count();
+
+  // Choose an index of an element in the DB
+  uint64_t ele_index =
+      rd() % number_of_items; // element in DB at random position
+  uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
+  uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
+  cout << "Main: element index = " << ele_index << " from [0, "
+       << number_of_items - 1 << "]" << endl;
+  cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
+
+  // Measure query generation
+  auto time_query_s = high_resolution_clock::now();
+  PirQuery query = client.generate_query(index);
+  auto time_query_e = high_resolution_clock::now();
+  auto time_query_us =
+      duration_cast<microseconds>(time_query_e - time_query_s).count();
+  cout << "Main: query generated" << endl;
+
+  // To marshall query to send over the network, you can use
+  // serialize/deserialize: std::string query_ser = serialize_query(query);
+  // PirQuery query2 = deserialize_query(d, 1, query_ser, CIPHER_SIZE);
+
+  // Measure query processing (including expansion)
+  auto time_server_s = high_resolution_clock::now();
+  PirReply reply = server.generate_reply(query, 0);
+  auto time_server_e = high_resolution_clock::now();
+  auto time_server_us =
+      duration_cast<microseconds>(time_server_e - time_server_s).count();
+
+  // Measure response extraction
+  auto time_decode_s = chrono::high_resolution_clock::now();
+  vector<uint8_t> elems = client.decode_reply(reply, offset);
+  auto time_decode_e = chrono::high_resolution_clock::now();
+  auto time_decode_us =
+      duration_cast<microseconds>(time_decode_e - time_decode_s).count();
+
+  assert(elems.size() == size_per_item);
+
+  bool failed = false;
+  // Check that we retrieved the correct element
+  for (uint32_t i = 0; i < size_per_item; i++) {
+    if (elems[i] != db_copy.get()[(ele_index * size_per_item) + i]) {
+      cout << "Main: elems " << (int)elems[i] << ", db "
+           << (int)db_copy.get()[(ele_index * size_per_item) + i] << endl;
+      cout << "Main: PIR result wrong at " << i << endl;
+      failed = true;
+    }
+  }
+  if (failed) {
+    return -1;
+  }
+
+  // Output results
+  cout << "Main: PIR result correct!" << endl;
+  cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms"
+       << endl;
+  cout << "Main: PIRClient query generation time: " << time_query_us / 1000
+       << " ms" << endl;
+  cout << "Main: PIRServer reply generation time: " << time_server_us / 1000
+       << " ms" << endl;
+  cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000
+       << " ms" << endl;
+  cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
+
+  return 0;
 }

+ 149 - 141
test/replace_test.cpp

@@ -1,164 +1,172 @@
 #include "pir.hpp"
 #include "pir_client.hpp"
 #include "pir_server.hpp"
-#include <seal/seal.h>
 #include <chrono>
+#include <cstddef>
+#include <cstdint>
 #include <memory>
 #include <random>
-#include <cstdint>
-#include <cstddef>
+#include <seal/seal.h>
 
 using namespace std::chrono;
 using namespace std;
 using namespace seal;
 
-int replace_test(uint64_t num_items, uint64_t item_size, uint32_t degree, uint32_t lt, uint32_t dim);
+int replace_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
+                 uint32_t lt, uint32_t dim);
 
 int main(int argc, char *argv[]) {
-    // Quick check
-    assert(replace_test(1 << 13, 1, 4096, 20, 1) == 0);
+  // Quick check
+  assert(replace_test(1 << 13, 1, 4096, 20, 1) == 0);
 
-    // Forces ciphertext expansion to be the same as the degree
-    assert(replace_test(1 << 20, 288, 4096, 20, 1) == 0);
+  // Forces ciphertext expansion to be the same as the degree
+  assert(replace_test(1 << 20, 288, 4096, 20, 1) == 0);
 
-    
-    assert(replace_test(1 << 20, 288, 4096, 20, 2) == 0);
+  assert(replace_test(1 << 20, 288, 4096, 20, 2) == 0);
 }
 
+int replace_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
+                 uint32_t lt, uint32_t dim) {
+  uint64_t number_of_items = num_items;
+  uint64_t size_per_item = item_size; // in bytes
+  uint32_t N = degree;
 
-int replace_test(uint64_t num_items, uint64_t item_size, uint32_t degree, uint32_t lt, uint32_t dim){
-    uint64_t number_of_items = num_items;
-    uint64_t size_per_item = item_size; // in bytes
-    uint32_t N = degree;
-
-    // Recommended values: (logt, d) = (12, 2) or (8, 1). 
-    uint32_t logt = lt; 
-    uint32_t d = dim;
-
-    EncryptionParameters enc_params(scheme_type::bfv);
-    PirParams pir_params;
-
-    // Generates all parameters
-    
-    cout << "Main: Generating SEAL parameters" << endl;
-    gen_encryption_params(N, logt, enc_params);
-    
-    cout << "Main: Verifying SEAL parameters" << endl;
-    verify_encryption_params(enc_params);
-    cout << "Main: SEAL parameters are good" << endl;
-
-    cout << "Main: Generating PIR parameters" << endl;
-    gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
-    
-    
-    
-    //gen_params(number_of_items, size_per_item, N, logt, d, enc_params, pir_params);
-    print_pir_params(pir_params);
-
-    cout << "Main: Initializing the database (this may take some time) ..." << endl;
-
-    // Create test database
-    auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
-
-    // Copy of the database. We use this at the end to make sure we retrieved
-    // the correct element.
-    auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
-
-    random_device rd;
-    for (uint64_t i = 0; i < number_of_items; i++) {
-        for (uint64_t j = 0; j < size_per_item; j++) {
-            uint8_t val = rd() % 256;
-            db.get()[(i * size_per_item) + j] = val;
-            db_copy.get()[(i * size_per_item) + j] = val;
-        }
-    }
+  // Recommended values: (logt, d) = (12, 2) or (8, 1).
+  uint32_t logt = lt;
+  uint32_t d = dim;
 
-    // Initialize PIR Server
-    cout << "Main: Initializing server and client" << endl;
-    PIRServer server(enc_params, pir_params);
-
-    // Initialize PIR client....
-    PIRClient client(enc_params, pir_params);
-    Ciphertext one_ct = client.get_one();
-    GaloisKeys galois_keys = client.generate_galois_keys();
-
-    // Set galois key for client with id 0
-    cout << "Main: Setting Galois keys...";
-    server.set_galois_key(0, galois_keys);
-    
-    // Measure database setup
-    auto time_pre_s = high_resolution_clock::now();
-    server.set_database(move(db), number_of_items, size_per_item);
-    server.preprocess_database();
-    server.set_one_ct(one_ct);
-    cout << "Main: database pre processed " << endl;
-    auto time_pre_e = high_resolution_clock::now();
-    auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
-
-
-    // Choose an index of an element in the DB
-    uint64_t ele_index = rd() % number_of_items; // element in DB at random position
-    uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
-    uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
-    cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
-    cout << "Main: FV index = " << index << ", FV offset = " << offset << endl; 
-
-    // Generate a new element
-    vector<uint8_t> new_element(size_per_item);
-    vector<uint8_t> new_element_copy(size_per_item);
-    for(uint64_t i = 0; i < size_per_item; i++){
-        uint8_t val = rd() % 256;
-            new_element[i] = val;
-            new_element_copy[i] = val;
-    }
+  EncryptionParameters enc_params(scheme_type::bfv);
+  PirParams pir_params;
 
-    // Get element to replace
-    auto time_server_s = high_resolution_clock::now();
-    Ciphertext reply = server.simple_query(index);
-    auto time_server_e = high_resolution_clock::now();
-    auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
-    auto time_decode_s = chrono::high_resolution_clock::now();
-    Plaintext old_pt = client.decrypt(reply);
-    auto time_decode_e = chrono::high_resolution_clock::now();
-    auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
-
-
-    // Replace element
-    Modulus t = enc_params.plain_modulus();
-    logt = floor(log2(t.value()));
-    vector<uint64_t> new_coeffs = bytes_to_coeffs(logt, new_element.data(), size_per_item);
-    Plaintext new_pt = client.replace_element(old_pt, new_coeffs, offset);
-    server.simple_set(index, new_pt);
-
-    //Get the replaced element
-    PirQuery query = client.generate_query(index);
-    PirReply server_reply = server.generate_reply(query, 0);
-    vector<uint8_t> elems = client.decode_reply(server_reply, offset);
-    //vector<uint8_t> elems = client.extract_bytes(client.decrypt(server.simple_query(index)), offset);
-    vector<uint8_t> old_elems = client.extract_bytes(old_pt, offset);    
-
-    assert(elems.size() == size_per_item);
-
-    bool failed = false;
-    // Check that we retrieved the correct element
-    for (uint32_t i = 0; i < size_per_item; i++) {
-        if (elems[i] != new_element_copy[i]) {
-            cout << "Main: elems " << (int)elems[i] << ", new "
-                << (int) new_element_copy[i] << ", old "
-                << (int) db_copy.get()[(ele_index * size_per_item) + i] << endl;
-            cout << "Main: PIR result wrong at " << i <<  endl;
-            failed = true;
-        }
-    }
-    if(failed){
-        return -1;
-    }
+  // Generates all parameters
+
+  cout << "Main: Generating SEAL parameters" << endl;
+  gen_encryption_params(N, logt, enc_params);
+
+  cout << "Main: Verifying SEAL parameters" << endl;
+  verify_encryption_params(enc_params);
+  cout << "Main: SEAL parameters are good" << endl;
+
+  cout << "Main: Generating PIR parameters" << endl;
+  gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
+
+  // gen_params(number_of_items, size_per_item, N, logt, d, enc_params,
+  // pir_params);
+  print_pir_params(pir_params);
 
-    // Output results
-    cout << "Main: PIR result correct!" << endl;
-    cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
-    cout << "Main: PIRServer reply generation time: " << time_server_us / 1000 << " ms" << endl;
-    cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
+  cout << "Main: Initializing the database (this may take some time) ..."
+       << endl;
 
-    return 0;
+  // Create test database
+  auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  // Copy of the database. We use this at the end to make sure we retrieved
+  // the correct element.
+  auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  random_device rd;
+  for (uint64_t i = 0; i < number_of_items; i++) {
+    for (uint64_t j = 0; j < size_per_item; j++) {
+      uint8_t val = rd() % 256;
+      db.get()[(i * size_per_item) + j] = val;
+      db_copy.get()[(i * size_per_item) + j] = val;
+    }
+  }
+
+  // Initialize PIR Server
+  cout << "Main: Initializing server and client" << endl;
+  PIRServer server(enc_params, pir_params);
+
+  // Initialize PIR client....
+  PIRClient client(enc_params, pir_params);
+  Ciphertext one_ct = client.get_one();
+  GaloisKeys galois_keys = client.generate_galois_keys();
+
+  // Set galois key for client with id 0
+  cout << "Main: Setting Galois keys...";
+  server.set_galois_key(0, galois_keys);
+
+  // Measure database setup
+  auto time_pre_s = high_resolution_clock::now();
+  server.set_database(move(db), number_of_items, size_per_item);
+  server.preprocess_database();
+  server.set_one_ct(one_ct);
+  cout << "Main: database pre processed " << endl;
+  auto time_pre_e = high_resolution_clock::now();
+  auto time_pre_us =
+      duration_cast<microseconds>(time_pre_e - time_pre_s).count();
+
+  // Choose an index of an element in the DB
+  uint64_t ele_index =
+      rd() % number_of_items; // element in DB at random position
+  uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
+  uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
+  cout << "Main: element index = " << ele_index << " from [0, "
+       << number_of_items - 1 << "]" << endl;
+  cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
+
+  // Generate a new element
+  vector<uint8_t> new_element(size_per_item);
+  vector<uint8_t> new_element_copy(size_per_item);
+  for (uint64_t i = 0; i < size_per_item; i++) {
+    uint8_t val = rd() % 256;
+    new_element[i] = val;
+    new_element_copy[i] = val;
+  }
+
+  // Get element to replace
+  auto time_server_s = high_resolution_clock::now();
+  Ciphertext reply = server.simple_query(index);
+  auto time_server_e = high_resolution_clock::now();
+  auto time_server_us =
+      duration_cast<microseconds>(time_server_e - time_server_s).count();
+  auto time_decode_s = chrono::high_resolution_clock::now();
+  Plaintext old_pt = client.decrypt(reply);
+  auto time_decode_e = chrono::high_resolution_clock::now();
+  auto time_decode_us =
+      duration_cast<microseconds>(time_decode_e - time_decode_s).count();
+
+  // Replace element
+  Modulus t = enc_params.plain_modulus();
+  logt = floor(log2(t.value()));
+  vector<uint64_t> new_coeffs =
+      bytes_to_coeffs(logt, new_element.data(), size_per_item);
+  Plaintext new_pt = client.replace_element(old_pt, new_coeffs, offset);
+  server.simple_set(index, new_pt);
+
+  // Get the replaced element
+  PirQuery query = client.generate_query(index);
+  PirReply server_reply = server.generate_reply(query, 0);
+  vector<uint8_t> elems = client.decode_reply(server_reply, offset);
+  // vector<uint8_t> elems =
+  // client.extract_bytes(client.decrypt(server.simple_query(index)), offset);
+  vector<uint8_t> old_elems = client.extract_bytes(old_pt, offset);
+
+  assert(elems.size() == size_per_item);
+
+  bool failed = false;
+  // Check that we retrieved the correct element
+  for (uint32_t i = 0; i < size_per_item; i++) {
+    if (elems[i] != new_element_copy[i]) {
+      cout << "Main: elems " << (int)elems[i] << ", new "
+           << (int)new_element_copy[i] << ", old "
+           << (int)db_copy.get()[(ele_index * size_per_item) + i] << endl;
+      cout << "Main: PIR result wrong at " << i << endl;
+      failed = true;
+    }
+  }
+  if (failed) {
+    return -1;
+  }
+
+  // Output results
+  cout << "Main: PIR result correct!" << endl;
+  cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms"
+       << endl;
+  cout << "Main: PIRServer reply generation time: " << time_server_us / 1000
+       << " ms" << endl;
+  cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000
+       << " ms" << endl;
+
+  return 0;
 }

+ 121 - 114
test/simple_query_test.cpp

@@ -1,136 +1,143 @@
 #include "pir.hpp"
 #include "pir_client.hpp"
 #include "pir_server.hpp"
-#include <seal/seal.h>
 #include <chrono>
+#include <cstddef>
+#include <cstdint>
 #include <memory>
 #include <random>
-#include <cstdint>
-#include <cstddef>
+#include <seal/seal.h>
 
 using namespace std::chrono;
 using namespace std;
 using namespace seal;
 
-int simple_query_test(uint64_t num_items, uint64_t item_size, uint32_t degree, uint32_t lt, uint32_t dim);
+int simple_query_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
+                      uint32_t lt, uint32_t dim);
 
 int main(int argc, char *argv[]) {
-    // Quick check
-    assert(simple_query_test(1 << 10, 288, 4096, 20, 1) == 0);
+  // Quick check
+  assert(simple_query_test(1 << 10, 288, 4096, 20, 1) == 0);
 
-    // Forces ciphertext expansion to be the same as the degree
-    assert(simple_query_test(1 << 20, 288, 4096, 20, 1) == 0);
+  // Forces ciphertext expansion to be the same as the degree
+  assert(simple_query_test(1 << 20, 288, 4096, 20, 1) == 0);
 
-    
-    assert(simple_query_test(1 << 20, 288, 4096, 20, 2) == 0);
+  assert(simple_query_test(1 << 20, 288, 4096, 20, 2) == 0);
 }
 
+int simple_query_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
+                      uint32_t lt, uint32_t dim) {
+  uint64_t number_of_items = num_items;
+  uint64_t size_per_item = item_size; // in bytes
+  uint32_t N = degree;
 
-int simple_query_test(uint64_t num_items, uint64_t item_size, uint32_t degree, uint32_t lt, uint32_t dim){
-    uint64_t number_of_items = num_items;
-    uint64_t size_per_item = item_size; // in bytes
-    uint32_t N = degree;
-
-    // Recommended values: (logt, d) = (12, 2) or (8, 1). 
-    uint32_t logt = lt; 
-    uint32_t d = dim;
-
-    EncryptionParameters enc_params(scheme_type::bfv);
-    PirParams pir_params;
-
-    // Generates all parameters
-    
-    cout << "Main: Generating SEAL parameters" << endl;
-    gen_encryption_params(N, logt, enc_params);
-    
-    cout << "Main: Verifying SEAL parameters" << endl;
-    verify_encryption_params(enc_params);
-    cout << "Main: SEAL parameters are good" << endl;
-
-    cout << "Main: Generating PIR parameters" << endl;
-    gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
-    
-    
-    
-    //gen_params(number_of_items, size_per_item, N, logt, d, enc_params, pir_params);
-    print_pir_params(pir_params);
-
-    cout << "Main: Initializing the database (this may take some time) ..." << endl;
-
-    // Create test database
-    auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
-
-    // Copy of the database. We use this at the end to make sure we retrieved
-    // the correct element.
-    auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
-
-    random_device rd;
-    for (uint64_t i = 0; i < number_of_items; i++) {
-        for (uint64_t j = 0; j < size_per_item; j++) {
-            uint8_t val = rd() % 256;
-            db.get()[(i * size_per_item) + j] = val;
-            db_copy.get()[(i * size_per_item) + j] = val;
-        }
-    }
+  // Recommended values: (logt, d) = (12, 2) or (8, 1).
+  uint32_t logt = lt;
+  uint32_t d = dim;
 
-    // Initialize PIR Server
-    cout << "Main: Initializing server and client" << endl;
-    PIRServer server(enc_params, pir_params);
-
-    // Initialize PIR client....
-    PIRClient client(enc_params, pir_params);
-    Ciphertext one_ct = client.get_one();
-    
-    // Measure database setup
-    auto time_pre_s = high_resolution_clock::now();
-    server.set_database(move(db), number_of_items, size_per_item);
-    server.preprocess_database();
-    server.set_one_ct(one_ct);
-    cout << "Main: database pre processed " << endl;
-    auto time_pre_e = high_resolution_clock::now();
-    auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
-
-
-    // Choose an index of an element in the DB
-    uint64_t ele_index = rd() % number_of_items; // element in DB at random position
-    uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
-    uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
-    cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
-    cout << "Main: FV index = " << index << ", FV offset = " << offset << endl; 
-
-    // Measure query processing (including expansion)
-    auto time_server_s = high_resolution_clock::now();
-    Ciphertext reply = server.simple_query(index);
-    auto time_server_e = high_resolution_clock::now();
-    auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
-
-    // Measure response extraction
-    auto time_decode_s = chrono::high_resolution_clock::now();
-    vector<uint8_t> elems = client.extract_bytes(client.decrypt(reply), offset);
-    auto time_decode_e = chrono::high_resolution_clock::now();
-    auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
-
-    assert(elems.size() == size_per_item);
-
-    bool failed = false;
-    // Check that we retrieved the correct element
-    for (uint32_t i = 0; i < size_per_item; i++) {
-        if (elems[i] != db_copy.get()[(ele_index * size_per_item) + i]) {
-            cout << "Main: elems " << (int)elems[i] << ", db "
-                << (int) db_copy.get()[(ele_index * size_per_item) + i] << endl;
-            cout << "Main: PIR result wrong at " << i <<  endl;
-            failed = true;
-        }
-    }
-    if(failed){
-        return -1;
-    }
+  EncryptionParameters enc_params(scheme_type::bfv);
+  PirParams pir_params;
 
-    // Output results
-    cout << "Main: PIR result correct!" << endl;
-    cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
-    cout << "Main: PIRServer reply generation time: " << time_server_us / 1000 << " ms" << endl;
-    cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
+  // Generates all parameters
 
-    return 0;
+  cout << "Main: Generating SEAL parameters" << endl;
+  gen_encryption_params(N, logt, enc_params);
+
+  cout << "Main: Verifying SEAL parameters" << endl;
+  verify_encryption_params(enc_params);
+  cout << "Main: SEAL parameters are good" << endl;
+
+  cout << "Main: Generating PIR parameters" << endl;
+  gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
+
+  // gen_params(number_of_items, size_per_item, N, logt, d, enc_params,
+  // pir_params);
+  print_pir_params(pir_params);
+
+  cout << "Main: Initializing the database (this may take some time) ..."
+       << endl;
+
+  // Create test database
+  auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  // Copy of the database. We use this at the end to make sure we retrieved
+  // the correct element.
+  auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  random_device rd;
+  for (uint64_t i = 0; i < number_of_items; i++) {
+    for (uint64_t j = 0; j < size_per_item; j++) {
+      uint8_t val = rd() % 256;
+      db.get()[(i * size_per_item) + j] = val;
+      db_copy.get()[(i * size_per_item) + j] = val;
+    }
+  }
+
+  // Initialize PIR Server
+  cout << "Main: Initializing server and client" << endl;
+  PIRServer server(enc_params, pir_params);
+
+  // Initialize PIR client....
+  PIRClient client(enc_params, pir_params);
+  Ciphertext one_ct = client.get_one();
+
+  // Measure database setup
+  auto time_pre_s = high_resolution_clock::now();
+  server.set_database(move(db), number_of_items, size_per_item);
+  server.preprocess_database();
+  server.set_one_ct(one_ct);
+  cout << "Main: database pre processed " << endl;
+  auto time_pre_e = high_resolution_clock::now();
+  auto time_pre_us =
+      duration_cast<microseconds>(time_pre_e - time_pre_s).count();
+
+  // Choose an index of an element in the DB
+  uint64_t ele_index =
+      rd() % number_of_items; // element in DB at random position
+  uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
+  uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
+  cout << "Main: element index = " << ele_index << " from [0, "
+       << number_of_items - 1 << "]" << endl;
+  cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
+
+  // Measure query processing (including expansion)
+  auto time_server_s = high_resolution_clock::now();
+  Ciphertext reply = server.simple_query(index);
+  auto time_server_e = high_resolution_clock::now();
+  auto time_server_us =
+      duration_cast<microseconds>(time_server_e - time_server_s).count();
+
+  // Measure response extraction
+  auto time_decode_s = chrono::high_resolution_clock::now();
+  vector<uint8_t> elems = client.extract_bytes(client.decrypt(reply), offset);
+  auto time_decode_e = chrono::high_resolution_clock::now();
+  auto time_decode_us =
+      duration_cast<microseconds>(time_decode_e - time_decode_s).count();
+
+  assert(elems.size() == size_per_item);
+
+  bool failed = false;
+  // Check that we retrieved the correct element
+  for (uint32_t i = 0; i < size_per_item; i++) {
+    if (elems[i] != db_copy.get()[(ele_index * size_per_item) + i]) {
+      cout << "Main: elems " << (int)elems[i] << ", db "
+           << (int)db_copy.get()[(ele_index * size_per_item) + i] << endl;
+      cout << "Main: PIR result wrong at " << i << endl;
+      failed = true;
+    }
+  }
+  if (failed) {
+    return -1;
+  }
+
+  // Output results
+  cout << "Main: PIR result correct!" << endl;
+  cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms"
+       << endl;
+  cout << "Main: PIRServer reply generation time: " << time_server_us / 1000
+       << " ms" << endl;
+  cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000
+       << " ms" << endl;
+
+  return 0;
 }