Browse Source

formatting src

Sebastian Angel 1 year ago
parent
commit
f83b53c0f2
7 changed files with 1095 additions and 1017 deletions
  1. 168 151
      src/main.cpp
  2. 234 215
      src/pir.cpp
  3. 49 46
      src/pir.hpp
  4. 229 215
      src/pir_client.cpp
  5. 37 35
      src/pir_client.hpp
  6. 332 312
      src/pir_server.cpp
  7. 46 43
      src/pir_server.hpp

+ 168 - 151
src/main.cpp

@@ -1,12 +1,12 @@
 #include "pir.hpp"
 #include "pir_client.hpp"
 #include "pir_server.hpp"
-#include <seal/seal.h>
 #include <chrono>
+#include <cstddef>
+#include <cstdint>
 #include <memory>
 #include <random>
-#include <cstdint>
-#include <cstddef>
+#include <seal/seal.h>
 
 using namespace std::chrono;
 using namespace std;
@@ -14,155 +14,172 @@ using namespace seal;
 
 int main(int argc, char *argv[]) {
 
-    uint64_t number_of_items = 1 << 16;
-    uint64_t size_per_item = 1024; // in bytes
-    uint32_t N = 4096;
-
-    // Recommended values: (logt, d) = (20, 2).
-    uint32_t logt = 20;
-    uint32_t d = 2;
-    bool use_symmetric = true; // use symmetric encryption instead of public key (recommended for smaller query)
-    bool use_batching = true; // pack as many elements as possible into a BFV plaintext (recommended)
-    bool use_recursive_mod_switching = true;
-
-    EncryptionParameters enc_params(scheme_type::bfv);
-    PirParams pir_params;
-
-    // Generates all parameters
-    
-    cout << "Main: Generating SEAL parameters" << endl;
-    gen_encryption_params(N, logt, enc_params);
-    
-    cout << "Main: Verifying SEAL parameters" << endl;
-    verify_encryption_params(enc_params);
-    cout << "Main: SEAL parameters are good" << endl;
-
-    cout << "Main: Generating PIR parameters" << endl;
-    gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params, use_symmetric, use_batching, use_recursive_mod_switching);
-    
-    
-    print_seal_params(enc_params); 
-    print_pir_params(pir_params);
-
-
-    // Initialize PIR client....
-    PIRClient client(enc_params, pir_params);
-    cout << "Main: Generating galois keys for client" << endl;
-
-    GaloisKeys galois_keys = client.generate_galois_keys();
-
-    // Initialize PIR Server
-    cout << "Main: Initializing server" << endl;
-    PIRServer server(enc_params, pir_params);
-
-    // Server maps the galois key to client 0. We only have 1 client,
-    // which is why we associate it with 0. If there are multiple PIR
-    // clients, you should have each client generate a galois key, 
-    // and assign each client an index or id, then call the procedure below.
-    server.set_galois_key(0, galois_keys);
-
-    cout << "Main: Creating the database with random data (this may take some time) ..." << endl;
-
-    // Create test database
-    auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
-
-    // Copy of the database. We use this at the end to make sure we retrieved
-    // the correct element.
-    auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
-
-    random_device rd;
-    for (uint64_t i = 0; i < number_of_items; i++) {
-        for (uint64_t j = 0; j < size_per_item; j++) {
-            uint8_t val = rd() % 256;
-            db.get()[(i * size_per_item) + j] = val;
-            db_copy.get()[(i * size_per_item) + j] = val;
-        }
-    }
+  uint64_t number_of_items = 1 << 16;
+  uint64_t size_per_item = 1024; // in bytes
+  uint32_t N = 4096;
+
+  // Recommended values: (logt, d) = (20, 2).
+  uint32_t logt = 20;
+  uint32_t d = 2;
+  bool use_symmetric = true; // use symmetric encryption instead of public key
+                             // (recommended for smaller query)
+  bool use_batching = true;  // pack as many elements as possible into a BFV
+                             // plaintext (recommended)
+  bool use_recursive_mod_switching = true;
+
+  EncryptionParameters enc_params(scheme_type::bfv);
+  PirParams pir_params;
+
+  // Generates all parameters
+
+  cout << "Main: Generating SEAL parameters" << endl;
+  gen_encryption_params(N, logt, enc_params);
 
-    // Measure database setup
-    auto time_pre_s = high_resolution_clock::now(); 
-    server.set_database(move(db), number_of_items, size_per_item);
-    server.preprocess_database();
-    auto time_pre_e = high_resolution_clock::now();
-    auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
-    cout << "Main: database pre processed " << endl;
-
-    // Choose an index of an element in the DB
-    uint64_t ele_index = rd() % number_of_items; // element in DB at random position
-    uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
-    uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
-    cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
-    cout << "Main: FV index = " << index << ", FV offset = " << offset << endl; 
-
-    // Measure query generation
-    auto time_query_s = high_resolution_clock::now();
-    PirQuery query = client.generate_query(index);
-    auto time_query_e = high_resolution_clock::now();
-    auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
-    cout << "Main: query generated" << endl;
-
-    // Measure serialized query generation (useful for sending over the network)
-    stringstream client_stream;
-    stringstream server_stream;
-    auto time_s_query_s = high_resolution_clock::now();
-    int query_size = client.generate_serialized_query(index, client_stream);
-    auto time_s_query_e = high_resolution_clock::now();
-    auto time_s_query_us = duration_cast<microseconds>(time_s_query_e - time_s_query_s).count();
-    cout << "Main: serialized query generated" << endl;
-
-    // Measure query deserialization (useful for receiving over the network)
-    auto time_deserial_s = high_resolution_clock::now();
-    PirQuery query2 = server.deserialize_query(client_stream);
-    auto time_deserial_e = high_resolution_clock::now();
-    auto time_deserial_us = duration_cast<microseconds>(time_deserial_e - time_deserial_s).count();
-    cout << "Main: query deserialized" << endl;
-
-    // Measure query processing (including expansion)
-    auto time_server_s = high_resolution_clock::now();
-    // Answer PIR query from client 0. If there are multiple clients, 
-    // enter the id of the client (to use the associated galois key).
-    PirReply reply = server.generate_reply(query2, 0); 
-    auto time_server_e = high_resolution_clock::now();
-    auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
-    cout << "Main: reply generated" << endl;
-
-    // Serialize reply (useful for sending over the network)
-    int reply_size = server.serialize_reply(reply, server_stream);
-
-    // Measure response extraction
-    auto time_decode_s = chrono::high_resolution_clock::now();
-    vector<uint8_t> elems = client.decode_reply(reply, offset);
-    auto time_decode_e = chrono::high_resolution_clock::now();
-    auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
-    cout << "Main: reply decoded" << endl;
-
-    assert(elems.size() == size_per_item);
-
-    bool failed = false;
-    // Check that we retrieved the correct element
-    for (uint32_t i = 0; i < size_per_item; i++) {
-        if (elems[i] != db_copy.get()[(ele_index * size_per_item) + i]) {
-            cout << "Main: elems " << (int)elems[i] << ", db "
-                << (int) db_copy.get()[(ele_index * size_per_item) + i] << endl;
-            cout << "Main: PIR result wrong at " << i <<  endl;
-            failed = true;
-        }
+  cout << "Main: Verifying SEAL parameters" << endl;
+  verify_encryption_params(enc_params);
+  cout << "Main: SEAL parameters are good" << endl;
+
+  cout << "Main: Generating PIR parameters" << endl;
+  gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params,
+                 use_symmetric, use_batching, use_recursive_mod_switching);
+
+  print_seal_params(enc_params);
+  print_pir_params(pir_params);
+
+  // Initialize PIR client....
+  PIRClient client(enc_params, pir_params);
+  cout << "Main: Generating galois keys for client" << endl;
+
+  GaloisKeys galois_keys = client.generate_galois_keys();
+
+  // Initialize PIR Server
+  cout << "Main: Initializing server" << endl;
+  PIRServer server(enc_params, pir_params);
+
+  // Server maps the galois key to client 0. We only have 1 client,
+  // which is why we associate it with 0. If there are multiple PIR
+  // clients, you should have each client generate a galois key,
+  // and assign each client an index or id, then call the procedure below.
+  server.set_galois_key(0, galois_keys);
+
+  cout << "Main: Creating the database with random data (this may take some "
+          "time) ..."
+       << endl;
+
+  // Create test database
+  auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  // Copy of the database. We use this at the end to make sure we retrieved
+  // the correct element.
+  auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  random_device rd;
+  for (uint64_t i = 0; i < number_of_items; i++) {
+    for (uint64_t j = 0; j < size_per_item; j++) {
+      uint8_t val = rd() % 256;
+      db.get()[(i * size_per_item) + j] = val;
+      db_copy.get()[(i * size_per_item) + j] = val;
     }
-    if(failed){
-        return -1;
+  }
+
+  // Measure database setup
+  auto time_pre_s = high_resolution_clock::now();
+  server.set_database(move(db), number_of_items, size_per_item);
+  server.preprocess_database();
+  auto time_pre_e = high_resolution_clock::now();
+  auto time_pre_us =
+      duration_cast<microseconds>(time_pre_e - time_pre_s).count();
+  cout << "Main: database pre processed " << endl;
+
+  // Choose an index of an element in the DB
+  uint64_t ele_index =
+      rd() % number_of_items; // element in DB at random position
+  uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
+  uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
+  cout << "Main: element index = " << ele_index << " from [0, "
+       << number_of_items - 1 << "]" << endl;
+  cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
+
+  // Measure query generation
+  auto time_query_s = high_resolution_clock::now();
+  PirQuery query = client.generate_query(index);
+  auto time_query_e = high_resolution_clock::now();
+  auto time_query_us =
+      duration_cast<microseconds>(time_query_e - time_query_s).count();
+  cout << "Main: query generated" << endl;
+
+  // Measure serialized query generation (useful for sending over the network)
+  stringstream client_stream;
+  stringstream server_stream;
+  auto time_s_query_s = high_resolution_clock::now();
+  int query_size = client.generate_serialized_query(index, client_stream);
+  auto time_s_query_e = high_resolution_clock::now();
+  auto time_s_query_us =
+      duration_cast<microseconds>(time_s_query_e - time_s_query_s).count();
+  cout << "Main: serialized query generated" << endl;
+
+  // Measure query deserialization (useful for receiving over the network)
+  auto time_deserial_s = high_resolution_clock::now();
+  PirQuery query2 = server.deserialize_query(client_stream);
+  auto time_deserial_e = high_resolution_clock::now();
+  auto time_deserial_us =
+      duration_cast<microseconds>(time_deserial_e - time_deserial_s).count();
+  cout << "Main: query deserialized" << endl;
+
+  // Measure query processing (including expansion)
+  auto time_server_s = high_resolution_clock::now();
+  // Answer PIR query from client 0. If there are multiple clients,
+  // enter the id of the client (to use the associated galois key).
+  PirReply reply = server.generate_reply(query2, 0);
+  auto time_server_e = high_resolution_clock::now();
+  auto time_server_us =
+      duration_cast<microseconds>(time_server_e - time_server_s).count();
+  cout << "Main: reply generated" << endl;
+
+  // Serialize reply (useful for sending over the network)
+  int reply_size = server.serialize_reply(reply, server_stream);
+
+  // Measure response extraction
+  auto time_decode_s = chrono::high_resolution_clock::now();
+  vector<uint8_t> elems = client.decode_reply(reply, offset);
+  auto time_decode_e = chrono::high_resolution_clock::now();
+  auto time_decode_us =
+      duration_cast<microseconds>(time_decode_e - time_decode_s).count();
+  cout << "Main: reply decoded" << endl;
+
+  assert(elems.size() == size_per_item);
+
+  bool failed = false;
+  // Check that we retrieved the correct element
+  for (uint32_t i = 0; i < size_per_item; i++) {
+    if (elems[i] != db_copy.get()[(ele_index * size_per_item) + i]) {
+      cout << "Main: elems " << (int)elems[i] << ", db "
+           << (int)db_copy.get()[(ele_index * size_per_item) + i] << endl;
+      cout << "Main: PIR result wrong at " << i << endl;
+      failed = true;
     }
-
-    // Output results
-    cout << "Main: PIR result correct!" << endl;
-    cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
-    cout << "Main: PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
-    cout << "Main: PIRClient serialized query generation time: " << time_s_query_us / 1000 << " ms" << endl;
-    cout << "Main: PIRServer query deserialization time: " << time_deserial_us << " us" << endl;
-    cout << "Main: PIRServer reply generation time: " << time_server_us / 1000 << " ms" << endl;
-    cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
-    cout << "Main: Query size: " << query_size << " bytes" << endl;
-    cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
-    cout << "Main: Reply size: " << reply_size << " bytes" << endl;
-
-    return 0;
+  }
+  if (failed) {
+    return -1;
+  }
+
+  // Output results
+  cout << "Main: PIR result correct!" << endl;
+  cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms"
+       << endl;
+  cout << "Main: PIRClient query generation time: " << time_query_us / 1000
+       << " ms" << endl;
+  cout << "Main: PIRClient serialized query generation time: "
+       << time_s_query_us / 1000 << " ms" << endl;
+  cout << "Main: PIRServer query deserialization time: " << time_deserial_us
+       << " us" << endl;
+  cout << "Main: PIRServer reply generation time: " << time_server_us / 1000
+       << " ms" << endl;
+  cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000
+       << " ms" << endl;
+  cout << "Main: Query size: " << query_size << " bytes" << endl;
+  cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
+  cout << "Main: Reply size: " << reply_size << " bytes" << endl;
+
+  return 0;
 }

+ 234 - 215
src/pir.cpp

@@ -4,247 +4,264 @@ using namespace std;
 using namespace seal;
 using namespace seal::util;
 
-std::vector<std::uint64_t> get_dimensions(std::uint64_t num_of_plaintexts, std::uint32_t d) {
+std::vector<std::uint64_t> get_dimensions(std::uint64_t num_of_plaintexts,
+                                          std::uint32_t d) {
 
-    assert(d > 0);
-    assert(num_of_plaintexts > 0);
+  assert(d > 0);
+  assert(num_of_plaintexts > 0);
 
-    std::uint64_t root = max(static_cast<uint32_t>(2),static_cast<uint32_t>(floor(pow(num_of_plaintexts, 1.0/d))));
+  std::uint64_t root =
+      max(static_cast<uint32_t>(2),
+          static_cast<uint32_t>(floor(pow(num_of_plaintexts, 1.0 / d))));
 
-    std::vector<std::uint64_t> dimensions(d, root);
+  std::vector<std::uint64_t> dimensions(d, root);
 
-    for(int i = 0; i < d; i++){
-        if(accumulate(dimensions.begin(), dimensions.end(), 1, multiplies<uint64_t>()) > num_of_plaintexts){
-            break;
-        } 
-        dimensions[i] += 1;
+  for (int i = 0; i < d; i++) {
+    if (accumulate(dimensions.begin(), dimensions.end(), 1,
+                   multiplies<uint64_t>()) > num_of_plaintexts) {
+      break;
     }
+    dimensions[i] += 1;
+  }
 
-    std::uint32_t prod = accumulate(dimensions.begin(), dimensions.end(), 1, multiplies<uint64_t>());
-    assert(prod >= num_of_plaintexts);
-    return dimensions;
+  std::uint32_t prod = accumulate(dimensions.begin(), dimensions.end(), 1,
+                                  multiplies<uint64_t>());
+  assert(prod >= num_of_plaintexts);
+  return dimensions;
 }
 
 void gen_encryption_params(std::uint32_t N, std::uint32_t logt,
-                           seal::EncryptionParameters &enc_params){
-    
-    enc_params.set_poly_modulus_degree(N);
-    enc_params.set_coeff_modulus(CoeffModulus::BFVDefault(N));
-    enc_params.set_plain_modulus(PlainModulus::Batching(N, logt+1)); 
-    // the +1 above ensures we get logt bits for each plaintext coefficient. Otherwise
-    // the coefficient modulus t will be logt bits, but only floor(t) = logt-1 (whp) 
-    // will be usable (since we need to ensure that all data in the coefficient is < t).
+                           seal::EncryptionParameters &enc_params) {
+
+  enc_params.set_poly_modulus_degree(N);
+  enc_params.set_coeff_modulus(CoeffModulus::BFVDefault(N));
+  enc_params.set_plain_modulus(PlainModulus::Batching(N, logt + 1));
+  // the +1 above ensures we get logt bits for each plaintext coefficient.
+  // Otherwise the coefficient modulus t will be logt bits, but only floor(t) =
+  // logt-1 (whp) will be usable (since we need to ensure that all data in the
+  // coefficient is < t).
 }
 
-void verify_encryption_params(const seal::EncryptionParameters &enc_params){
-    SEALContext context(enc_params, true);
-    if(!context.parameters_set()){
-        throw invalid_argument("SEAL parameters not valid.");
-    }
-    if(!context.using_keyswitching()){
-        throw invalid_argument("SEAL parameters do not support key switching.");
-    }
-    if(!context.first_context_data()->qualifiers().using_batching){
-        throw invalid_argument("SEAL parameters do not support batching.");
-    }
+void verify_encryption_params(const seal::EncryptionParameters &enc_params) {
+  SEALContext context(enc_params, true);
+  if (!context.parameters_set()) {
+    throw invalid_argument("SEAL parameters not valid.");
+  }
+  if (!context.using_keyswitching()) {
+    throw invalid_argument("SEAL parameters do not support key switching.");
+  }
+  if (!context.first_context_data()->qualifiers().using_batching) {
+    throw invalid_argument("SEAL parameters do not support batching.");
+  }
 
-    BatchEncoder batch_encoder(context);
-    size_t slot_count = batch_encoder.slot_count();
-    if(slot_count != enc_params.poly_modulus_degree()){
-        throw invalid_argument("Slot count not equal to poly modulus degree - this will cause issues downstream.");
-    }
+  BatchEncoder batch_encoder(context);
+  size_t slot_count = batch_encoder.slot_count();
+  if (slot_count != enc_params.poly_modulus_degree()) {
+    throw invalid_argument("Slot count not equal to poly modulus degree - this "
+                           "will cause issues downstream.");
+  }
 
-    return;
+  return;
 }
 
 void gen_pir_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
-                    const EncryptionParameters &enc_params, PirParams &pir_params,
-                    bool enable_symmetric, bool enable_batching, bool enable_mswitching){
-    std::uint32_t N = enc_params.poly_modulus_degree();
-    Modulus t = enc_params.plain_modulus();
-    std::uint32_t logt = floor(log2(t.value())); // # of usable bits
-    std::uint64_t elements_per_plaintext;
-    std::uint64_t num_of_plaintexts;
-
-    if(enable_batching){
-        elements_per_plaintext = elements_per_ptxt(logt, N, ele_size);
-        num_of_plaintexts = plaintexts_per_db(logt, N, ele_num, ele_size);
-    }
-    else{
-        elements_per_plaintext = 1;
-        num_of_plaintexts = ele_num;
-    }
+                    const EncryptionParameters &enc_params,
+                    PirParams &pir_params, bool enable_symmetric,
+                    bool enable_batching, bool enable_mswitching) {
+  std::uint32_t N = enc_params.poly_modulus_degree();
+  Modulus t = enc_params.plain_modulus();
+  std::uint32_t logt = floor(log2(t.value())); // # of usable bits
+  std::uint64_t elements_per_plaintext;
+  std::uint64_t num_of_plaintexts;
+
+  if (enable_batching) {
+    elements_per_plaintext = elements_per_ptxt(logt, N, ele_size);
+    num_of_plaintexts = plaintexts_per_db(logt, N, ele_num, ele_size);
+  } else {
+    elements_per_plaintext = 1;
+    num_of_plaintexts = ele_num;
+  }
 
-    vector<uint64_t> nvec = get_dimensions(num_of_plaintexts, d);
+  vector<uint64_t> nvec = get_dimensions(num_of_plaintexts, d);
 
-    uint32_t expansion_ratio = 0;
-    for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
-        double logqi = log2(enc_params.coeff_modulus()[i].value());
-        expansion_ratio += ceil(logqi / logt);
-    }
+  uint32_t expansion_ratio = 0;
+  for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
+    double logqi = log2(enc_params.coeff_modulus()[i].value());
+    expansion_ratio += ceil(logqi / logt);
+  }
 
-    pir_params.enable_symmetric = enable_symmetric;
-    pir_params.enable_batching = enable_batching;
-    pir_params.enable_mswitching = enable_mswitching;
-    pir_params.ele_num = ele_num;
-    pir_params.ele_size = ele_size;
-    pir_params.elements_per_plaintext = elements_per_plaintext;
-    pir_params.num_of_plaintexts = num_of_plaintexts;
-    pir_params.d = d;                 
-    pir_params.expansion_ratio = expansion_ratio << 1;           
-    pir_params.nvec = nvec;
-    pir_params.slot_count = N;
+  pir_params.enable_symmetric = enable_symmetric;
+  pir_params.enable_batching = enable_batching;
+  pir_params.enable_mswitching = enable_mswitching;
+  pir_params.ele_num = ele_num;
+  pir_params.ele_size = ele_size;
+  pir_params.elements_per_plaintext = elements_per_plaintext;
+  pir_params.num_of_plaintexts = num_of_plaintexts;
+  pir_params.d = d;
+  pir_params.expansion_ratio = expansion_ratio << 1;
+  pir_params.nvec = nvec;
+  pir_params.slot_count = N;
 }
 
-
-void print_pir_params(const PirParams &pir_params){
-    std::uint32_t prod = accumulate(pir_params.nvec.begin(), pir_params.nvec.end(), 1, multiplies<uint64_t>());
-
-    cout << "PIR Parameters" << endl;
-    cout << "number of elements: " << pir_params.ele_num << endl;
-    cout << "element size: " << pir_params.ele_size << endl;
-    cout << "elements per BFV plaintext: " << pir_params.elements_per_plaintext << endl;
-    cout << "dimensions for d-dimensional hyperrectangle: " << pir_params.d << endl;
-    cout << "number of BFV plaintexts (before padding): " << pir_params.num_of_plaintexts << endl;
-    cout << "Number of BFV plaintexts after padding (to fill d-dimensional hyperrectangle): " << prod << endl;
-    cout << "expansion ratio: " << pir_params.expansion_ratio << endl;
-    cout << "Using symmetric encryption: " << pir_params.enable_symmetric << endl;
-    cout << "Using recursive mod switching: " << pir_params.enable_mswitching << endl;
-    cout << "slot count: " << pir_params.slot_count << endl;
-    cout << "=============================="<< endl;
+void print_pir_params(const PirParams &pir_params) {
+  std::uint32_t prod =
+      accumulate(pir_params.nvec.begin(), pir_params.nvec.end(), 1,
+                 multiplies<uint64_t>());
+
+  cout << "PIR Parameters" << endl;
+  cout << "number of elements: " << pir_params.ele_num << endl;
+  cout << "element size: " << pir_params.ele_size << endl;
+  cout << "elements per BFV plaintext: " << pir_params.elements_per_plaintext
+       << endl;
+  cout << "dimensions for d-dimensional hyperrectangle: " << pir_params.d
+       << endl;
+  cout << "number of BFV plaintexts (before padding): "
+       << pir_params.num_of_plaintexts << endl;
+  cout << "Number of BFV plaintexts after padding (to fill d-dimensional "
+          "hyperrectangle): "
+       << prod << endl;
+  cout << "expansion ratio: " << pir_params.expansion_ratio << endl;
+  cout << "Using symmetric encryption: " << pir_params.enable_symmetric << endl;
+  cout << "Using recursive mod switching: " << pir_params.enable_mswitching
+       << endl;
+  cout << "slot count: " << pir_params.slot_count << endl;
+  cout << "==============================" << endl;
 }
 
-
-void print_seal_params(const EncryptionParameters &enc_params){
-    std::uint32_t N = enc_params.poly_modulus_degree();
-    Modulus t = enc_params.plain_modulus();
-    std::uint32_t logt = floor(log2(t.value()));
-
-    cout << "SEAL encryption parameters"  << endl;
-    cout << "Degree of polynomial modulus (N): " << N << endl;
-    cout << "Size of plaintext modulus (log t):" << logt << endl;
-    cout << "There are " << enc_params.coeff_modulus().size() << " coefficient modulus:" << endl;
-
-    for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
-        double logqi = log2(enc_params.coeff_modulus()[i].value());
-        cout << "Size of coefficient modulus " << i << " (log q_" << i << "): " << logqi << endl;
-    }
-    cout << "=============================="<< endl;
+void print_seal_params(const EncryptionParameters &enc_params) {
+  std::uint32_t N = enc_params.poly_modulus_degree();
+  Modulus t = enc_params.plain_modulus();
+  std::uint32_t logt = floor(log2(t.value()));
+
+  cout << "SEAL encryption parameters" << endl;
+  cout << "Degree of polynomial modulus (N): " << N << endl;
+  cout << "Size of plaintext modulus (log t):" << logt << endl;
+  cout << "There are " << enc_params.coeff_modulus().size()
+       << " coefficient modulus:" << endl;
+
+  for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
+    double logqi = log2(enc_params.coeff_modulus()[i].value());
+    cout << "Size of coefficient modulus " << i << " (log q_" << i
+         << "): " << logqi << endl;
+  }
+  cout << "==============================" << endl;
 }
 
-
 // Number of coefficients needed to represent a database element
 uint64_t coefficients_per_element(uint32_t logt, uint64_t ele_size) {
-    return ceil(8 * ele_size / (double)logt);
+  return ceil(8 * ele_size / (double)logt);
 }
 
 // Number of database elements that can fit in a single FV plaintext
 uint64_t elements_per_ptxt(uint32_t logt, uint64_t N, uint64_t ele_size) {
-    uint64_t coeff_per_ele = coefficients_per_element(logt, ele_size);
-    uint64_t ele_per_ptxt = N / coeff_per_ele;
-    assert(ele_per_ptxt > 0);
-    return ele_per_ptxt;
+  uint64_t coeff_per_ele = coefficients_per_element(logt, ele_size);
+  uint64_t ele_per_ptxt = N / coeff_per_ele;
+  assert(ele_per_ptxt > 0);
+  return ele_per_ptxt;
 }
 
 // Number of FV plaintexts needed to represent the database
-uint64_t plaintexts_per_db(uint32_t logt, uint64_t N, uint64_t ele_num, uint64_t ele_size) {
-    uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
-    return ceil((double)ele_num / ele_per_ptxt);
+uint64_t plaintexts_per_db(uint32_t logt, uint64_t N, uint64_t ele_num,
+                           uint64_t ele_size) {
+  uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
+  return ceil((double)ele_num / ele_per_ptxt);
 }
 
-vector<uint64_t> bytes_to_coeffs(uint32_t limit, const uint8_t *bytes, uint64_t size) {
-
-    uint64_t size_out = coefficients_per_element(limit, size);
-    vector<uint64_t> output(size_out);
-
-    uint32_t room = limit;
-    uint64_t *target = &output[0];
-
-    for (uint32_t i = 0; i < size; i++) {
-        uint8_t src = bytes[i];
-        uint32_t rest = 8;
-        while (rest) {
-            if (room == 0) {
-                target++;
-                room = limit;
-            }
-            uint32_t shift = rest;
-            if (room < rest) {
-                shift = room;
-            }
-            *target = *target << shift;
-            *target = *target | (src >> (8 - shift));
-            src = src << shift;
-            room -= shift;
-            rest -= shift;
-        }
+vector<uint64_t> bytes_to_coeffs(uint32_t limit, const uint8_t *bytes,
+                                 uint64_t size) {
+
+  uint64_t size_out = coefficients_per_element(limit, size);
+  vector<uint64_t> output(size_out);
+
+  uint32_t room = limit;
+  uint64_t *target = &output[0];
+
+  for (uint32_t i = 0; i < size; i++) {
+    uint8_t src = bytes[i];
+    uint32_t rest = 8;
+    while (rest) {
+      if (room == 0) {
+        target++;
+        room = limit;
+      }
+      uint32_t shift = rest;
+      if (room < rest) {
+        shift = room;
+      }
+      *target = *target << shift;
+      *target = *target | (src >> (8 - shift));
+      src = src << shift;
+      room -= shift;
+      rest -= shift;
     }
+  }
 
-    *target = *target << room;
-    return output;
+  *target = *target << room;
+  return output;
 }
 
-void coeffs_to_bytes(uint32_t limit, const vector<uint64_t> &coeffs, uint8_t *output, uint32_t size_out, uint32_t ele_size){
-    uint32_t room = 8;
-    uint32_t j = 0;
-    uint8_t *target = output;
-    uint32_t bits_left = ele_size * 8;
-    for (uint32_t i = 0; i < coeffs.size(); i++) {
-        if(bits_left == 0){
-            bits_left = ele_size * 8;
-        }
-        uint64_t src = coeffs[i];
-        uint32_t rest = min(limit, bits_left);
-        while (rest && j < size_out) {
-            uint32_t shift = rest;
-            if (room < rest) {
-                shift = room;
-            }
-            
-            target[j] = target[j] << shift;
-            target[j] = target[j] | (src >> (limit - shift));
-            src = src << shift;
-            room -= shift;
-            rest -= shift;
-            bits_left -= shift;
-            if (room == 0) {
-                j++;
-                room = 8;
-            }
-        }
+void coeffs_to_bytes(uint32_t limit, const vector<uint64_t> &coeffs,
+                     uint8_t *output, uint32_t size_out, uint32_t ele_size) {
+  uint32_t room = 8;
+  uint32_t j = 0;
+  uint8_t *target = output;
+  uint32_t bits_left = ele_size * 8;
+  for (uint32_t i = 0; i < coeffs.size(); i++) {
+    if (bits_left == 0) {
+      bits_left = ele_size * 8;
+    }
+    uint64_t src = coeffs[i];
+    uint32_t rest = min(limit, bits_left);
+    while (rest && j < size_out) {
+      uint32_t shift = rest;
+      if (room < rest) {
+        shift = room;
+      }
+
+      target[j] = target[j] << shift;
+      target[j] = target[j] | (src >> (limit - shift));
+      src = src << shift;
+      room -= shift;
+      rest -= shift;
+      bits_left -= shift;
+      if (room == 0) {
+        j++;
+        room = 8;
+      }
     }
+  }
 }
 
 void vector_to_plaintext(const vector<uint64_t> &coeffs, Plaintext &plain) {
-    uint32_t coeff_count = coeffs.size();
-    plain.resize(coeff_count);
-    util::set_uint(coeffs.data(), coeff_count, plain.data());
+  uint32_t coeff_count = coeffs.size();
+  plain.resize(coeff_count);
+  util::set_uint(coeffs.data(), coeff_count, plain.data());
 }
 
 vector<uint64_t> compute_indices(uint64_t desiredIndex, vector<uint64_t> Nvec) {
-    uint32_t num = Nvec.size();
-    uint64_t product = 1;
+  uint32_t num = Nvec.size();
+  uint64_t product = 1;
 
-    for (uint32_t i = 0; i < num; i++) {
-        product *= Nvec[i];
-    }
+  for (uint32_t i = 0; i < num; i++) {
+    product *= Nvec[i];
+  }
 
-    uint64_t j = desiredIndex;
-    vector<uint64_t> result;
+  uint64_t j = desiredIndex;
+  vector<uint64_t> result;
 
-    for (uint32_t i = 0; i < num; i++) {
+  for (uint32_t i = 0; i < num; i++) {
 
-        product /= Nvec[i];
-        uint64_t ji = j / product;
+    product /= Nvec[i];
+    uint64_t ji = j / product;
 
-        result.push_back(ji);
-        j -= ji * product;
-    }
+    result.push_back(ji);
+    j -= ji * product;
+  }
 
-    return result;
+  return result;
 }
 
-uint64_t invert_mod(uint64_t m, const seal::Modulus& mod) {
+uint64_t invert_mod(uint64_t m, const seal::Modulus &mod) {
   if (mod.uint64_count() > 1) {
     cout << "Mod too big to invert";
   }
@@ -255,7 +272,6 @@ uint64_t invert_mod(uint64_t m, const seal::Modulus& mod) {
   return inverse;
 }
 
-
 uint32_t compute_expansion_ratio(EncryptionParameters params) {
   uint32_t expansion_ratio = 0;
   uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
@@ -266,7 +282,8 @@ uint32_t compute_expansion_ratio(EncryptionParameters params) {
   return expansion_ratio;
 }
 
-vector<Plaintext> decompose_to_plaintexts(EncryptionParameters params, const Ciphertext& ct) {
+vector<Plaintext> decompose_to_plaintexts(EncryptionParameters params,
+                                          const Ciphertext &ct) {
   const uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
   const auto coeff_count = params.poly_modulus_degree();
   const auto coeff_mod_count = params.coeff_modulus().size();
@@ -298,8 +315,9 @@ vector<Plaintext> decompose_to_plaintexts(EncryptionParameters params, const Cip
   return result;
 }
 
-void compose_to_ciphertext(EncryptionParameters params, vector<Plaintext>::const_iterator pt_iter,
-  const size_t ct_poly_count, Ciphertext& ct) {
+void compose_to_ciphertext(EncryptionParameters params,
+                           vector<Plaintext>::const_iterator pt_iter,
+                           const size_t ct_poly_count, Ciphertext &ct) {
   const uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
   const auto coeff_count = params.poly_modulus_degree();
   const auto coeff_mod_count = params.coeff_modulus().size();
@@ -330,38 +348,39 @@ void compose_to_ciphertext(EncryptionParameters params, vector<Plaintext>::const
   }
 }
 
-void compose_to_ciphertext(EncryptionParameters params, const vector<Plaintext>& pts, Ciphertext& ct) {
-  return compose_to_ciphertext(params, pts.begin(), pts.size() / compute_expansion_ratio(params), ct);
+void compose_to_ciphertext(EncryptionParameters params,
+                           const vector<Plaintext> &pts, Ciphertext &ct) {
+  return compose_to_ciphertext(
+      params, pts.begin(), pts.size() / compute_expansion_ratio(params), ct);
 }
 
-
-
-PirQuery deserialize_query(uint32_t d, uint32_t count, string s, uint32_t len_ciphertext,
-shared_ptr<SEALContext> context) {
-    vector<vector<Ciphertext>> q;
-    std::istringstream input(s);
-
-    for (uint32_t i = 0; i < d; i++) {
-        vector<Ciphertext> cs;
-        for (uint32_t i = 0; i < count; i++) {
-          Ciphertext c;
-          c.load(*context, input);
-          cs.push_back(c);
-        }
-        q.push_back(cs);
+PirQuery deserialize_query(uint32_t d, uint32_t count, string s,
+                           uint32_t len_ciphertext,
+                           shared_ptr<SEALContext> context) {
+  vector<vector<Ciphertext>> q;
+  std::istringstream input(s);
+
+  for (uint32_t i = 0; i < d; i++) {
+    vector<Ciphertext> cs;
+    for (uint32_t i = 0; i < count; i++) {
+      Ciphertext c;
+      c.load(*context, input);
+      cs.push_back(c);
     }
-    return q;
+    q.push_back(cs);
+  }
+  return q;
 }
 
 string serialize_galoiskeys(Serializable<GaloisKeys> g) {
-    std::ostringstream output;
-    g.save(output);
-    return output.str();
+  std::ostringstream output;
+  g.save(output);
+  return output.str();
 }
 
 GaloisKeys *deserialize_galoiskeys(string s, shared_ptr<SEALContext> context) {
-    GaloisKeys *g = new GaloisKeys();
-    std::istringstream input(s);
-    g->load(*context, input);
-    return g;
+  GaloisKeys *g = new GaloisKeys();
+  std::istringstream input(s);
+  g->load(*context, input);
+  return g;
 }

+ 49 - 46
src/pir.hpp

@@ -12,38 +12,30 @@ typedef std::vector<std::vector<seal::Ciphertext>> PirQuery;
 typedef std::vector<seal::Ciphertext> PirReply;
 
 struct PirParams {
-    bool enable_symmetric;
-    bool enable_batching;
-    bool enable_mswitching;
-    std::uint64_t ele_num;
-    std::uint64_t ele_size;
-    std::uint64_t elements_per_plaintext;
-    std::uint64_t num_of_plaintexts;         // number of plaintexts in database
-    std::uint32_t d;                         // number of dimensions for the database
-    std::uint32_t expansion_ratio;           // ratio of ciphertext to plaintext
-    std::vector<std::uint64_t> nvec;         // size of each of the d dimensions
-    std::uint32_t slot_count;
+  bool enable_symmetric;
+  bool enable_batching;
+  bool enable_mswitching;
+  std::uint64_t ele_num;
+  std::uint64_t ele_size;
+  std::uint64_t elements_per_plaintext;
+  std::uint64_t num_of_plaintexts; // number of plaintexts in database
+  std::uint32_t d;                 // number of dimensions for the database
+  std::uint32_t expansion_ratio;   // ratio of ciphertext to plaintext
+  std::vector<std::uint64_t> nvec; // size of each of the d dimensions
+  std::uint32_t slot_count;
 };
 
-void gen_encryption_params(std::uint32_t N,        // degree of polynomial
-                           std::uint32_t logt,     // bits of plaintext coefficient
+void gen_encryption_params(std::uint32_t N,    // degree of polynomial
+                           std::uint32_t logt, // bits of plaintext coefficient
                            seal::EncryptionParameters &enc_params);
 
-void gen_pir_params(uint64_t ele_num,
-                    uint64_t ele_size,
-                    uint32_t d,
+void gen_pir_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
                     const seal::EncryptionParameters &enc_params,
-                    PirParams &pir_params,
-                    bool enable_symmetric = false,
-                    bool enable_batching = true,
-                    bool enable_mswitching = true);
-
-void gen_params(uint64_t ele_num,
-                uint64_t ele_size,
-                uint32_t N,
-                uint32_t logt,
-                uint32_t d,
-                seal::EncryptionParameters &params,
+                    PirParams &pir_params, bool enable_symmetric = false,
+                    bool enable_batching = true, bool enable_mswitching = true);
+
+void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
+                uint32_t d, seal::EncryptionParameters &params,
                 PirParams &pir_params);
 
 void verify_encryption_params(const seal::EncryptionParameters &enc_params);
@@ -52,26 +44,32 @@ void print_pir_params(const PirParams &pir_params);
 void print_seal_params(const seal::EncryptionParameters &enc_params);
 
 // returns the number of plaintexts that the database can hold
-std::uint64_t plaintexts_per_db(std::uint32_t logt, std::uint64_t N, std::uint64_t ele_num,
-                                std::uint64_t ele_size);
+std::uint64_t plaintexts_per_db(std::uint32_t logt, std::uint64_t N,
+                                std::uint64_t ele_num, std::uint64_t ele_size);
 
 // returns the number of elements that a single FV plaintext can hold
-std::uint64_t elements_per_ptxt(std::uint32_t logt, std::uint64_t N, std::uint64_t ele_size);
+std::uint64_t elements_per_ptxt(std::uint32_t logt, std::uint64_t N,
+                                std::uint64_t ele_size);
 
 // returns the number of coefficients needed to store one element
-std::uint64_t coefficients_per_element(std::uint32_t logt, std::uint64_t ele_size);
+std::uint64_t coefficients_per_element(std::uint32_t logt,
+                                       std::uint64_t ele_size);
 
 // Converts an array of bytes to a vector of coefficients, each of which is less
 // than the plaintext modulus
-std::vector<std::uint64_t> bytes_to_coeffs(std::uint32_t limit, const std::uint8_t *bytes,
+std::vector<std::uint64_t> bytes_to_coeffs(std::uint32_t limit,
+                                           const std::uint8_t *bytes,
                                            std::uint64_t size);
 
 // Converts an array of coefficients into an array of bytes
-void coeffs_to_bytes(std::uint32_t limit, const std::vector<std::uint64_t> &coeffs, std::uint8_t *output, 
-                    std::uint32_t size_out, std::uint32_t ele_size);
+void coeffs_to_bytes(std::uint32_t limit,
+                     const std::vector<std::uint64_t> &coeffs,
+                     std::uint8_t *output, std::uint32_t size_out,
+                     std::uint32_t ele_size);
 
 // Takes a vector of coefficients and returns the corresponding FV plaintext
-void vector_to_plaintext(const std::vector<std::uint64_t> &coeffs, seal::Plaintext &plain);
+void vector_to_plaintext(const std::vector<std::uint64_t> &coeffs,
+                         seal::Plaintext &plain);
 
 // Since the database has d dimensions, and an item is a particular cell
 // in the d-dimensional hypercube, this function computes the corresponding
@@ -79,19 +77,24 @@ void vector_to_plaintext(const std::vector<std::uint64_t> &coeffs, seal::Plainte
 std::vector<std::uint64_t> compute_indices(std::uint64_t desiredIndex,
                                            std::vector<std::uint64_t> nvec);
 
-uint64_t invert_mod(uint64_t m, const seal::Modulus& mod);
+uint64_t invert_mod(uint64_t m, const seal::Modulus &mod);
 
 uint32_t compute_expansion_ratio(seal::EncryptionParameters params);
-std::vector<seal::Plaintext> decompose_to_plaintexts(seal::EncryptionParameters params,
-    const seal::Ciphertext& ct);
-
-//We need the returned ciphertext to be initialized by Context so the caller will pass it in
-void compose_to_ciphertext(seal::EncryptionParameters params, 
-    const std::vector<seal::Plaintext>& pts, seal::Ciphertext& ct);
-void compose_to_ciphertext(seal::EncryptionParameters params, 
-    std::vector<seal::Plaintext>::const_iterator pt_iter, seal::Ciphertext& ct);
-
+std::vector<seal::Plaintext>
+decompose_to_plaintexts(seal::EncryptionParameters params,
+                        const seal::Ciphertext &ct);
+
+// We need the returned ciphertext to be initialized by Context so the caller
+// will pass it in
+void compose_to_ciphertext(seal::EncryptionParameters params,
+                           const std::vector<seal::Plaintext> &pts,
+                           seal::Ciphertext &ct);
+void compose_to_ciphertext(seal::EncryptionParameters params,
+                           std::vector<seal::Plaintext>::const_iterator pt_iter,
+                           seal::Ciphertext &ct);
 
 // Serialize and deserialize galois keys to send them over the network
 std::string serialize_galoiskeys(seal::Serializable<seal::GaloisKeys> g);
-seal::GaloisKeys *deserialize_galoiskeys(std::string s, std::shared_ptr<seal::SEALContext> context);
+seal::GaloisKeys *
+deserialize_galoiskeys(std::string s,
+                       std::shared_ptr<seal::SEALContext> context);

+ 229 - 215
src/pir_client.cpp

@@ -5,271 +5,285 @@ using namespace seal;
 using namespace seal::util;
 
 PIRClient::PIRClient(const EncryptionParameters &enc_params,
-                     const PirParams &pir_params) :
-    enc_params_(enc_params),
-    pir_params_(pir_params){
+                     const PirParams &pir_params)
+    : enc_params_(enc_params), pir_params_(pir_params) {
 
-    context_ = make_shared<SEALContext>(enc_params, true);
+  context_ = make_shared<SEALContext>(enc_params, true);
 
-    keygen_ = make_unique<KeyGenerator>(*context_);
-    
-    PublicKey public_key;
-    keygen_->create_public_key(public_key);
-    SecretKey secret_key = keygen_->secret_key();
+  keygen_ = make_unique<KeyGenerator>(*context_);
 
-    if(pir_params_.enable_symmetric){
-        encryptor_ = make_unique<Encryptor>(*context_, secret_key);
-    }
-    else{
-        encryptor_ = make_unique<Encryptor>(*context_, public_key);
-    }
-    
-    decryptor_ = make_unique<Decryptor>(*context_, secret_key);
-    evaluator_ = make_unique<Evaluator>(*context_);
-    encoder_ = make_unique<BatchEncoder>(*context_);
+  PublicKey public_key;
+  keygen_->create_public_key(public_key);
+  SecretKey secret_key = keygen_->secret_key();
+
+  if (pir_params_.enable_symmetric) {
+    encryptor_ = make_unique<Encryptor>(*context_, secret_key);
+  } else {
+    encryptor_ = make_unique<Encryptor>(*context_, public_key);
+  }
+
+  decryptor_ = make_unique<Decryptor>(*context_, secret_key);
+  evaluator_ = make_unique<Evaluator>(*context_);
+  encoder_ = make_unique<BatchEncoder>(*context_);
 }
 
-int PIRClient::generate_serialized_query(uint64_t desiredIndex, std::stringstream &stream) {
-
-    int N = enc_params_.poly_modulus_degree(); 
-    int output_size = 0;
-    indices_ = compute_indices(desiredIndex, pir_params_.nvec);
-    Plaintext pt(enc_params_.poly_modulus_degree());
-
-    for (uint32_t i = 0; i < indices_.size(); i++) {
-        uint32_t num_ptxts = ceil( (pir_params_.nvec[i] + 0.0) / N);
-        // initialize result. 
-        cout << "Client: index " << i + 1  <<  "/ " <<  indices_.size() << " = " << indices_[i] << endl; 
-        cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
-        
-        for (uint32_t j =0; j < num_ptxts; j++){
-            pt.set_zero();
-            if (indices_[i] >= N*j && indices_[i] <= N*(j+1)){
-                uint64_t real_index = indices_[i] - N*j; 
-                uint64_t n_i = pir_params_.nvec[i];
-                uint64_t total = N; 
-                if (j == num_ptxts - 1){
-                    total = n_i % N; 
-                }
-                uint64_t log_total = ceil(log2(total));
-
-                cout << "Client: Inverting " << pow(2, log_total) << endl;
-                pt[real_index] = invert_mod(pow(2, log_total), enc_params_.plain_modulus());
-            }
-
-            if(pir_params_.enable_symmetric){
-                output_size += encryptor_->encrypt_symmetric(pt).save(stream);
-            }
-            else{
-                output_size += encryptor_->encrypt(pt).save(stream);
-            }
-        }   
+int PIRClient::generate_serialized_query(uint64_t desiredIndex,
+                                         std::stringstream &stream) {
+
+  int N = enc_params_.poly_modulus_degree();
+  int output_size = 0;
+  indices_ = compute_indices(desiredIndex, pir_params_.nvec);
+  Plaintext pt(enc_params_.poly_modulus_degree());
+
+  for (uint32_t i = 0; i < indices_.size(); i++) {
+    uint32_t num_ptxts = ceil((pir_params_.nvec[i] + 0.0) / N);
+    // initialize result.
+    cout << "Client: index " << i + 1 << "/ " << indices_.size() << " = "
+         << indices_[i] << endl;
+    cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
+
+    for (uint32_t j = 0; j < num_ptxts; j++) {
+      pt.set_zero();
+      if (indices_[i] >= N * j && indices_[i] <= N * (j + 1)) {
+        uint64_t real_index = indices_[i] - N * j;
+        uint64_t n_i = pir_params_.nvec[i];
+        uint64_t total = N;
+        if (j == num_ptxts - 1) {
+          total = n_i % N;
+        }
+        uint64_t log_total = ceil(log2(total));
+
+        cout << "Client: Inverting " << pow(2, log_total) << endl;
+        pt[real_index] =
+            invert_mod(pow(2, log_total), enc_params_.plain_modulus());
+      }
+
+      if (pir_params_.enable_symmetric) {
+        output_size += encryptor_->encrypt_symmetric(pt).save(stream);
+      } else {
+        output_size += encryptor_->encrypt(pt).save(stream);
+      }
     }
+  }
 
-    return output_size;
+  return output_size;
 }
 
-
 PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
 
-    indices_ = compute_indices(desiredIndex, pir_params_.nvec);
-
-    PirQuery result(pir_params_.d);
-    int N = enc_params_.poly_modulus_degree(); 
-
-    Plaintext pt(enc_params_.poly_modulus_degree());
-    for (uint32_t i = 0; i < indices_.size(); i++) {
-        uint32_t num_ptxts = ceil( (pir_params_.nvec[i] + 0.0) / N);
-        // initialize result. 
-        cout << "Client: index " << i + 1  <<  "/ " <<  indices_.size() << " = " << indices_[i] << endl; 
-        cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
-        
-        for (uint32_t j =0; j < num_ptxts; j++){
-            pt.set_zero();
-            if (indices_[i] >= N*j && indices_[i] <= N*(j+1)){
-                uint64_t real_index = indices_[i] - N*j; 
-                uint64_t n_i = pir_params_.nvec[i];
-                uint64_t total = N; 
-                if (j == num_ptxts - 1){
-                    total = n_i % N; 
-                }
-                uint64_t log_total = ceil(log2(total));
-
-                cout << "Client: Inverting " << pow(2, log_total) << endl;
-                pt[real_index] = invert_mod(pow(2, log_total), enc_params_.plain_modulus());
-            }
-            Ciphertext dest;
-            if(pir_params_.enable_symmetric){
-                encryptor_->encrypt_symmetric(pt, dest);
-            }
-            else{
-                encryptor_->encrypt(pt, dest);
-            }
-            result[i].push_back(dest);
-        }   
+  indices_ = compute_indices(desiredIndex, pir_params_.nvec);
+
+  PirQuery result(pir_params_.d);
+  int N = enc_params_.poly_modulus_degree();
+
+  Plaintext pt(enc_params_.poly_modulus_degree());
+  for (uint32_t i = 0; i < indices_.size(); i++) {
+    uint32_t num_ptxts = ceil((pir_params_.nvec[i] + 0.0) / N);
+    // initialize result.
+    cout << "Client: index " << i + 1 << "/ " << indices_.size() << " = "
+         << indices_[i] << endl;
+    cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
+
+    for (uint32_t j = 0; j < num_ptxts; j++) {
+      pt.set_zero();
+      if (indices_[i] >= N * j && indices_[i] <= N * (j + 1)) {
+        uint64_t real_index = indices_[i] - N * j;
+        uint64_t n_i = pir_params_.nvec[i];
+        uint64_t total = N;
+        if (j == num_ptxts - 1) {
+          total = n_i % N;
+        }
+        uint64_t log_total = ceil(log2(total));
+
+        cout << "Client: Inverting " << pow(2, log_total) << endl;
+        pt[real_index] =
+            invert_mod(pow(2, log_total), enc_params_.plain_modulus());
+      }
+      Ciphertext dest;
+      if (pir_params_.enable_symmetric) {
+        encryptor_->encrypt_symmetric(pt, dest);
+      } else {
+        encryptor_->encrypt(pt, dest);
+      }
+      result[i].push_back(dest);
     }
+  }
 
-    return result;
+  return result;
 }
 
 uint64_t PIRClient::get_fv_index(uint64_t element_index) {
-    return static_cast<uint64_t>(element_index / pir_params_.elements_per_plaintext);
+  return static_cast<uint64_t>(element_index /
+                               pir_params_.elements_per_plaintext);
 }
 
 uint64_t PIRClient::get_fv_offset(uint64_t element_index) {
-    return element_index % pir_params_.elements_per_plaintext;
+  return element_index % pir_params_.elements_per_plaintext;
 }
 
-Plaintext PIRClient::decrypt(Ciphertext ct){
-    Plaintext pt;
-    decryptor_->decrypt(ct, pt);
-    return pt;
+Plaintext PIRClient::decrypt(Ciphertext ct) {
+  Plaintext pt;
+  decryptor_->decrypt(ct, pt);
+  return pt;
 }
 
-vector<uint8_t> PIRClient::decode_reply(PirReply &reply, uint64_t offset){
-    Plaintext result = decode_reply(reply);
-    return extract_bytes(result, offset);   
+vector<uint8_t> PIRClient::decode_reply(PirReply &reply, uint64_t offset) {
+  Plaintext result = decode_reply(reply);
+  return extract_bytes(result, offset);
 }
 
-vector<uint64_t> PIRClient::extract_coeffs(Plaintext pt){
-    vector<uint64_t> coeffs;
-    encoder_->decode(pt, coeffs);
-    return coeffs;
+vector<uint64_t> PIRClient::extract_coeffs(Plaintext pt) {
+  vector<uint64_t> coeffs;
+  encoder_->decode(pt, coeffs);
+  return coeffs;
 }
 
-std::vector<uint64_t> PIRClient::extract_coeffs(seal::Plaintext pt, uint64_t offset){
-    vector<uint64_t> coeffs;
-    encoder_->decode(pt, coeffs);
+std::vector<uint64_t> PIRClient::extract_coeffs(seal::Plaintext pt,
+                                                uint64_t offset) {
+  vector<uint64_t> coeffs;
+  encoder_->decode(pt, coeffs);
+
+  uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
 
-    uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
-    
-    uint64_t coeffs_per_element = coefficients_per_element(logt, pir_params_.ele_size);
+  uint64_t coeffs_per_element =
+      coefficients_per_element(logt, pir_params_.ele_size);
 
-    return std::vector<uint64_t>(coeffs.begin() + offset * coeffs_per_element, coeffs.begin() + (offset + 1) * coeffs_per_element);
+  return std::vector<uint64_t>(coeffs.begin() + offset * coeffs_per_element,
+                               coeffs.begin() +
+                                   (offset + 1) * coeffs_per_element);
 }
 
-std::vector<uint8_t> PIRClient::extract_bytes(seal::Plaintext pt, uint64_t offset){
-    uint32_t N = enc_params_.poly_modulus_degree(); 
-    uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
-    uint32_t bytes_per_ptxt = pir_params_.elements_per_plaintext * pir_params_.ele_size;
-
-    // Convert from FV plaintext (polynomial) to database element at the client
-    vector<uint8_t> elems(bytes_per_ptxt);
-    vector<uint64_t> coeffs;
-    encoder_->decode(pt, coeffs);
-    coeffs_to_bytes(logt, coeffs, elems.data(), bytes_per_ptxt, pir_params_.ele_size);
-    return std::vector<uint8_t>(elems.begin() + offset * pir_params_.ele_size, elems.begin() + (offset + 1) * pir_params_.ele_size);
+std::vector<uint8_t> PIRClient::extract_bytes(seal::Plaintext pt,
+                                              uint64_t offset) {
+  uint32_t N = enc_params_.poly_modulus_degree();
+  uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
+  uint32_t bytes_per_ptxt =
+      pir_params_.elements_per_plaintext * pir_params_.ele_size;
+
+  // Convert from FV plaintext (polynomial) to database element at the client
+  vector<uint8_t> elems(bytes_per_ptxt);
+  vector<uint64_t> coeffs;
+  encoder_->decode(pt, coeffs);
+  coeffs_to_bytes(logt, coeffs, elems.data(), bytes_per_ptxt,
+                  pir_params_.ele_size);
+  return std::vector<uint8_t>(elems.begin() + offset * pir_params_.ele_size,
+                              elems.begin() +
+                                  (offset + 1) * pir_params_.ele_size);
 }
 
 Plaintext PIRClient::decode_reply(PirReply &reply) {
-    EncryptionParameters parms;
-    parms_id_type parms_id;
-    if(pir_params_.enable_mswitching){
-        parms = context_->last_context_data()->parms();
-        parms_id = context_->last_parms_id();
-    }
-    else{
-        parms = context_->first_context_data()->parms();
-        parms_id = context_->first_parms_id();
-    }
-    uint32_t exp_ratio = compute_expansion_ratio(parms);
-    uint32_t recursion_level = pir_params_.d;
-
-    vector<Ciphertext> temp = reply;
-    uint32_t ciphertext_size = temp[0].size();
-
-    uint64_t t = enc_params_.plain_modulus().value();
-
-    for (uint32_t i = 0; i < recursion_level; i++) {
-        cout << "Client: " << i + 1 << "/ " << recursion_level << "-th decryption layer started." << endl; 
-        vector<Ciphertext> newtemp;
-        vector<Plaintext> tempplain;
-
-        for (uint32_t j = 0; j < temp.size(); j++) {
-            Plaintext ptxt;
-            decryptor_->decrypt(temp[j], ptxt);
+  EncryptionParameters parms;
+  parms_id_type parms_id;
+  if (pir_params_.enable_mswitching) {
+    parms = context_->last_context_data()->parms();
+    parms_id = context_->last_parms_id();
+  } else {
+    parms = context_->first_context_data()->parms();
+    parms_id = context_->first_parms_id();
+  }
+  uint32_t exp_ratio = compute_expansion_ratio(parms);
+  uint32_t recursion_level = pir_params_.d;
+
+  vector<Ciphertext> temp = reply;
+  uint32_t ciphertext_size = temp[0].size();
+
+  uint64_t t = enc_params_.plain_modulus().value();
+
+  for (uint32_t i = 0; i < recursion_level; i++) {
+    cout << "Client: " << i + 1 << "/ " << recursion_level
+         << "-th decryption layer started." << endl;
+    vector<Ciphertext> newtemp;
+    vector<Plaintext> tempplain;
+
+    for (uint32_t j = 0; j < temp.size(); j++) {
+      Plaintext ptxt;
+      decryptor_->decrypt(temp[j], ptxt);
 #ifdef DEBUG
-            cout << "Client: reply noise budget = " << decryptor_->invariant_noise_budget(temp[j]) << endl; 
+      cout << "Client: reply noise budget = "
+           << decryptor_->invariant_noise_budget(temp[j]) << endl;
 #endif
-            
-            //cout << "decoded (and scaled) plaintext = " << ptxt.to_string() << endl;
-            tempplain.push_back(ptxt);
+
+      // cout << "decoded (and scaled) plaintext = " << ptxt.to_string() <<
+      // endl;
+      tempplain.push_back(ptxt);
 
 #ifdef DEBUG
-            cout << "recursion level : " << i << " noise budget :  ";
-            cout << decryptor_->invariant_noise_budget(temp[j]) << endl;
+      cout << "recursion level : " << i << " noise budget :  ";
+      cout << decryptor_->invariant_noise_budget(temp[j]) << endl;
 #endif
 
-            if ((j + 1) % (exp_ratio * ciphertext_size) == 0 && j > 0) {
-                // Combine into one ciphertext.
-                Ciphertext combined(*context_, parms_id); 
-                compose_to_ciphertext(parms, tempplain, combined);
-                newtemp.push_back(combined);
-                tempplain.clear();
-                // cout << "Client: const term of ciphertext = " << combined[0] << endl; 
-            }
-        }
-        cout << "Client: done." << endl; 
-        cout << endl; 
-        if (i == recursion_level - 1) {
-            assert(temp.size() == 1);
-            return tempplain[0];
-        } else {
-            tempplain.clear();
-            temp = newtemp;
-        }
+      if ((j + 1) % (exp_ratio * ciphertext_size) == 0 && j > 0) {
+        // Combine into one ciphertext.
+        Ciphertext combined(*context_, parms_id);
+        compose_to_ciphertext(parms, tempplain, combined);
+        newtemp.push_back(combined);
+        tempplain.clear();
+        // cout << "Client: const term of ciphertext = " << combined[0] << endl;
+      }
+    }
+    cout << "Client: done." << endl;
+    cout << endl;
+    if (i == recursion_level - 1) {
+      assert(temp.size() == 1);
+      return tempplain[0];
+    } else {
+      tempplain.clear();
+      temp = newtemp;
     }
+  }
 
-    // This should never be called
-    assert(0);
-    Plaintext fail;
-    return fail;
+  // This should never be called
+  assert(0);
+  Plaintext fail;
+  return fail;
 }
 
 GaloisKeys PIRClient::generate_galois_keys() {
-    // Generate the Galois keys needed for coeff_select.
-    vector<uint32_t> galois_elts;
-    int N = enc_params_.poly_modulus_degree();
-    int logN = get_power_of_two(N);
-
-    //cout << "printing galois elements...";
-    for (int i = 0; i < logN; i++) {
-        galois_elts.push_back((N + exponentiate_uint(2, i)) / exponentiate_uint(2, i));
-//#ifdef DEBUG
-        // cout << galois_elts.back() << ", ";
-//#endif
-    }
-    GaloisKeys gal_keys;
-    keygen_->create_galois_keys(galois_elts, gal_keys);
-    return gal_keys;
+  // Generate the Galois keys needed for coeff_select.
+  vector<uint32_t> galois_elts;
+  int N = enc_params_.poly_modulus_degree();
+  int logN = get_power_of_two(N);
+
+  // cout << "printing galois elements...";
+  for (int i = 0; i < logN; i++) {
+    galois_elts.push_back((N + exponentiate_uint(2, i)) /
+                          exponentiate_uint(2, i));
+    //#ifdef DEBUG
+    // cout << galois_elts.back() << ", ";
+    //#endif
+  }
+  GaloisKeys gal_keys;
+  keygen_->create_galois_keys(galois_elts, gal_keys);
+  return gal_keys;
 }
 
-Plaintext PIRClient::replace_element(Plaintext pt, vector<uint64_t> new_element, uint64_t offset){
-    vector<uint64_t> coeffs = extract_coeffs(pt);
-    
-    uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
-    uint64_t coeffs_per_element = coefficients_per_element(logt, pir_params_.ele_size);
+Plaintext PIRClient::replace_element(Plaintext pt, vector<uint64_t> new_element,
+                                     uint64_t offset) {
+  vector<uint64_t> coeffs = extract_coeffs(pt);
 
-    assert(new_element.size() == coeffs_per_element);
+  uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
+  uint64_t coeffs_per_element =
+      coefficients_per_element(logt, pir_params_.ele_size);
 
-    for(uint64_t i = 0; i < coeffs_per_element; i++){
-        coeffs[i + offset * coeffs_per_element] = new_element[i];
-    }
-    
-    Plaintext new_pt;
+  assert(new_element.size() == coeffs_per_element);
+
+  for (uint64_t i = 0; i < coeffs_per_element; i++) {
+    coeffs[i + offset * coeffs_per_element] = new_element[i];
+  }
 
-    encoder_->encode(coeffs, new_pt);
-    return new_pt;
+  Plaintext new_pt;
+
+  encoder_->encode(coeffs, new_pt);
+  return new_pt;
 }
 
-Ciphertext PIRClient::get_one(){
-    Plaintext pt("1");
-    Ciphertext ct;
-    if(pir_params_.enable_symmetric){
-        encryptor_->encrypt_symmetric(pt, ct);
-    }
-    else{
-        encryptor_->encrypt(pt, ct);
-    }
-    return ct;
+Ciphertext PIRClient::get_one() {
+  Plaintext pt("1");
+  Ciphertext ct;
+  if (pir_params_.enable_symmetric) {
+    encryptor_->encrypt_symmetric(pt, ct);
+  } else {
+    encryptor_->encrypt(pt, ct);
+  }
+  return ct;
 }

+ 37 - 35
src/pir_client.hpp

@@ -4,53 +4,55 @@
 #include <memory>
 #include <vector>
 
-using namespace std; 
+using namespace std;
 
 class PIRClient {
-  public:
-    PIRClient(const seal::EncryptionParameters &encparms,
-               const PirParams &pirparams);
+public:
+  PIRClient(const seal::EncryptionParameters &encparms,
+            const PirParams &pirparams);
 
-    PirQuery generate_query(std::uint64_t desiredIndex);
-    // Serializes the query into the provided stream and returns number of bytes written
-    int generate_serialized_query(std::uint64_t desiredIndex, std::stringstream &stream);
-    seal::Plaintext decode_reply(PirReply &reply);
-    
-    std::vector<uint64_t> extract_coeffs(seal::Plaintext pt);
-    std::vector<uint64_t> extract_coeffs(seal::Plaintext pt, std::uint64_t offset);
-    std::vector<uint8_t> extract_bytes(seal::Plaintext pt, std::uint64_t offset);
+  PirQuery generate_query(std::uint64_t desiredIndex);
+  // Serializes the query into the provided stream and returns number of bytes
+  // written
+  int generate_serialized_query(std::uint64_t desiredIndex,
+                                std::stringstream &stream);
+  seal::Plaintext decode_reply(PirReply &reply);
 
-    std::vector<uint8_t> decode_reply(PirReply &reply, uint64_t offset);
+  std::vector<uint64_t> extract_coeffs(seal::Plaintext pt);
+  std::vector<uint64_t> extract_coeffs(seal::Plaintext pt,
+                                       std::uint64_t offset);
+  std::vector<uint8_t> extract_bytes(seal::Plaintext pt, std::uint64_t offset);
 
+  std::vector<uint8_t> decode_reply(PirReply &reply, uint64_t offset);
 
-    seal::Plaintext decrypt(seal::Ciphertext ct);
+  seal::Plaintext decrypt(seal::Ciphertext ct);
 
-    seal::GaloisKeys generate_galois_keys();
+  seal::GaloisKeys generate_galois_keys();
 
-    // Index and offset of an element in an FV plaintext
-    uint64_t get_fv_index(uint64_t element_index);
-    uint64_t get_fv_offset(uint64_t element_index);
+  // Index and offset of an element in an FV plaintext
+  uint64_t get_fv_index(uint64_t element_index);
+  uint64_t get_fv_offset(uint64_t element_index);
 
-    // Only used for simple_query
-    seal::Ciphertext get_one();
+  // Only used for simple_query
+  seal::Ciphertext get_one();
 
-    seal::Plaintext replace_element(seal::Plaintext pt, std::vector<std::uint64_t> new_element, std::uint64_t offset);
-   
+  seal::Plaintext replace_element(seal::Plaintext pt,
+                                  std::vector<std::uint64_t> new_element,
+                                  std::uint64_t offset);
 
-  private:
-    seal::EncryptionParameters enc_params_;
-    PirParams pir_params_;
+private:
+  seal::EncryptionParameters enc_params_;
+  PirParams pir_params_;
 
-    std::unique_ptr<seal::Encryptor> encryptor_;
-    std::unique_ptr<seal::Decryptor> decryptor_;
-    std::unique_ptr<seal::Evaluator> evaluator_;
-    std::unique_ptr<seal::KeyGenerator> keygen_;
-    std::unique_ptr<seal::BatchEncoder> encoder_;
-    std::shared_ptr<seal::SEALContext> context_;
+  std::unique_ptr<seal::Encryptor> encryptor_;
+  std::unique_ptr<seal::Decryptor> decryptor_;
+  std::unique_ptr<seal::Evaluator> evaluator_;
+  std::unique_ptr<seal::KeyGenerator> keygen_;
+  std::unique_ptr<seal::BatchEncoder> encoder_;
+  std::shared_ptr<seal::SEALContext> context_;
 
-    vector<uint64_t> indices_; // the indices for retrieval. 
-    vector<uint64_t> inverse_scales_; 
+  vector<uint64_t> indices_; // the indices for retrieval.
+  vector<uint64_t> inverse_scales_;
 
-
-    friend class PIRServer;
+  friend class PIRServer;
 };

+ 332 - 312
src/pir_server.cpp

@@ -5,418 +5,438 @@ using namespace std;
 using namespace seal;
 using namespace seal::util;
 
-PIRServer::PIRServer(const EncryptionParameters &enc_params, const PirParams &pir_params) :
-    enc_params_(enc_params), 
-    pir_params_(pir_params),
-    is_db_preprocessed_(false)
-{
-    context_ = make_shared<SEALContext>(enc_params, true);
-    evaluator_ = make_unique<Evaluator>(*context_);
-    encoder_ = make_unique<BatchEncoder>(*context_);
+PIRServer::PIRServer(const EncryptionParameters &enc_params,
+                     const PirParams &pir_params)
+    : enc_params_(enc_params), pir_params_(pir_params),
+      is_db_preprocessed_(false) {
+  context_ = make_shared<SEALContext>(enc_params, true);
+  evaluator_ = make_unique<Evaluator>(*context_);
+  encoder_ = make_unique<BatchEncoder>(*context_);
 }
 
 void PIRServer::preprocess_database() {
-    if (!is_db_preprocessed_) {
+  if (!is_db_preprocessed_) {
 
-        for (uint32_t i = 0; i < db_->size(); i++) {
-            evaluator_->transform_to_ntt_inplace(
-                db_->operator[](i), context_->first_parms_id());
-        }
-
-        is_db_preprocessed_ = true;
+    for (uint32_t i = 0; i < db_->size(); i++) {
+      evaluator_->transform_to_ntt_inplace(db_->operator[](i),
+                                           context_->first_parms_id());
     }
+
+    is_db_preprocessed_ = true;
+  }
 }
 
 // Server takes over ownership of db and will free it when it exits
 void PIRServer::set_database(unique_ptr<vector<Plaintext>> &&db) {
-    if (!db) {
-        throw invalid_argument("db cannot be null");
-    }
+  if (!db) {
+    throw invalid_argument("db cannot be null");
+  }
 
-    db_ = move(db);
-    is_db_preprocessed_ = false;
+  db_ = move(db);
+  is_db_preprocessed_ = false;
 }
 
-void PIRServer::set_database(const std::unique_ptr<const uint8_t[]> &bytes, 
-    uint64_t ele_num, uint64_t ele_size) {
+void PIRServer::set_database(const std::unique_ptr<const uint8_t[]> &bytes,
+                             uint64_t ele_num, uint64_t ele_size) {
 
-    uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
-    uint32_t N = enc_params_.poly_modulus_degree();
+  uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
+  uint32_t N = enc_params_.poly_modulus_degree();
 
-    // number of FV plaintexts needed to represent all elements
-    uint64_t num_of_plaintexts = pir_params_.num_of_plaintexts;
+  // number of FV plaintexts needed to represent all elements
+  uint64_t num_of_plaintexts = pir_params_.num_of_plaintexts;
 
-    // number of FV plaintexts needed to create the d-dimensional matrix
-    uint64_t prod = 1;
-    for (uint32_t i = 0; i < pir_params_.nvec.size(); i++) {
-        prod *= pir_params_.nvec[i];
-    }
-    uint64_t matrix_plaintexts = prod;
+  // number of FV plaintexts needed to create the d-dimensional matrix
+  uint64_t prod = 1;
+  for (uint32_t i = 0; i < pir_params_.nvec.size(); i++) {
+    prod *= pir_params_.nvec[i];
+  }
+  uint64_t matrix_plaintexts = prod;
 
-    assert(num_of_plaintexts <= matrix_plaintexts);
+  assert(num_of_plaintexts <= matrix_plaintexts);
 
-    auto result = make_unique<vector<Plaintext>>();
-    result->reserve(matrix_plaintexts);
+  auto result = make_unique<vector<Plaintext>>();
+  result->reserve(matrix_plaintexts);
 
-    uint64_t ele_per_ptxt = pir_params_.elements_per_plaintext;
-    uint64_t bytes_per_ptxt = ele_per_ptxt * ele_size;
+  uint64_t ele_per_ptxt = pir_params_.elements_per_plaintext;
+  uint64_t bytes_per_ptxt = ele_per_ptxt * ele_size;
 
-    uint64_t db_size = ele_num * ele_size;
+  uint64_t db_size = ele_num * ele_size;
 
-    uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logt, ele_size);
-    assert(coeff_per_ptxt <= N);
+  uint64_t coeff_per_ptxt =
+      ele_per_ptxt * coefficients_per_element(logt, ele_size);
+  assert(coeff_per_ptxt <= N);
 
-    cout << "Elements per plaintext: " << ele_per_ptxt << endl;
-    cout << "Coeff per ptxt: " << coeff_per_ptxt << endl;
-    cout << "Bytes per plaintext: " << bytes_per_ptxt << endl;
+  cout << "Elements per plaintext: " << ele_per_ptxt << endl;
+  cout << "Coeff per ptxt: " << coeff_per_ptxt << endl;
+  cout << "Bytes per plaintext: " << bytes_per_ptxt << endl;
 
-    uint32_t offset = 0;
+  uint32_t offset = 0;
 
-    for (uint64_t i = 0; i < num_of_plaintexts; i++) {
+  for (uint64_t i = 0; i < num_of_plaintexts; i++) {
 
-        uint64_t process_bytes = 0;
+    uint64_t process_bytes = 0;
 
-        if (db_size <= offset) {
-            break;
-        } else if (db_size < offset + bytes_per_ptxt) {
-            process_bytes = db_size - offset;
-        } else {
-            process_bytes = bytes_per_ptxt;
-        }
-        assert(process_bytes % ele_size == 0);
-        uint64_t ele_in_chunk = process_bytes / ele_size;
-
-        // Get the coefficients of the elements that will be packed in plaintext i
-        vector<uint64_t> coefficients(coeff_per_ptxt);
-        for(uint64_t ele = 0; ele < ele_in_chunk; ele++){
-            vector<uint64_t> element_coeffs = bytes_to_coeffs(logt, bytes.get() + offset + (ele_size*ele), ele_size);
-            std::copy(element_coeffs.begin(), element_coeffs.end(), coefficients.begin() + (coefficients_per_element(logt, ele_size) * ele));
-        }
-         
-        offset += process_bytes;
+    if (db_size <= offset) {
+      break;
+    } else if (db_size < offset + bytes_per_ptxt) {
+      process_bytes = db_size - offset;
+    } else {
+      process_bytes = bytes_per_ptxt;
+    }
+    assert(process_bytes % ele_size == 0);
+    uint64_t ele_in_chunk = process_bytes / ele_size;
+
+    // Get the coefficients of the elements that will be packed in plaintext i
+    vector<uint64_t> coefficients(coeff_per_ptxt);
+    for (uint64_t ele = 0; ele < ele_in_chunk; ele++) {
+      vector<uint64_t> element_coeffs = bytes_to_coeffs(
+          logt, bytes.get() + offset + (ele_size * ele), ele_size);
+      std::copy(element_coeffs.begin(), element_coeffs.end(),
+                coefficients.begin() +
+                    (coefficients_per_element(logt, ele_size) * ele));
+    }
 
-        uint64_t used = coefficients.size();
+    offset += process_bytes;
 
-        assert(used <= coeff_per_ptxt);
+    uint64_t used = coefficients.size();
 
-        // Pad the rest with 1s
-        for (uint64_t j = 0; j < (pir_params_.slot_count - used); j++) {
-            coefficients.push_back(1);
-        }
+    assert(used <= coeff_per_ptxt);
 
-        Plaintext plain;
-        encoder_->encode(coefficients, plain);
-        // cout << i << "-th encoded plaintext = " << plain.to_string() << endl; 
-        result->push_back(move(plain));
+    // Pad the rest with 1s
+    for (uint64_t j = 0; j < (pir_params_.slot_count - used); j++) {
+      coefficients.push_back(1);
     }
 
-    // Add padding to make database a matrix
-    uint64_t current_plaintexts = result->size();
-    assert(current_plaintexts <= num_of_plaintexts);
+    Plaintext plain;
+    encoder_->encode(coefficients, plain);
+    // cout << i << "-th encoded plaintext = " << plain.to_string() << endl;
+    result->push_back(move(plain));
+  }
+
+  // Add padding to make database a matrix
+  uint64_t current_plaintexts = result->size();
+  assert(current_plaintexts <= num_of_plaintexts);
 
 #ifdef DEBUG
-    cout << "adding: " << matrix_plaintexts - current_plaintexts
-         << " FV plaintexts of padding (equivalent to: "
-         << (matrix_plaintexts - current_plaintexts) * elements_per_ptxt(logtp, N, ele_size)
-         << " elements)" << endl;
+  cout << "adding: " << matrix_plaintexts - current_plaintexts
+       << " FV plaintexts of padding (equivalent to: "
+       << (matrix_plaintexts - current_plaintexts) *
+              elements_per_ptxt(logtp, N, ele_size)
+       << " elements)" << endl;
 #endif
 
-    vector<uint64_t> padding(N, 1);
+  vector<uint64_t> padding(N, 1);
 
-    for (uint64_t i = 0; i < (matrix_plaintexts - current_plaintexts); i++) {
-        Plaintext plain;
-        vector_to_plaintext(padding, plain);
-        result->push_back(plain);
-    }
+  for (uint64_t i = 0; i < (matrix_plaintexts - current_plaintexts); i++) {
+    Plaintext plain;
+    vector_to_plaintext(padding, plain);
+    result->push_back(plain);
+  }
 
-    set_database(move(result));
+  set_database(move(result));
 }
 
 void PIRServer::set_galois_key(uint32_t client_id, seal::GaloisKeys galkey) {
-    galoisKeys_[client_id] = galkey;
+  galoisKeys_[client_id] = galkey;
 }
 
 PirQuery PIRServer::deserialize_query(stringstream &stream) {
-    PirQuery q;
-
-    for (uint32_t i = 0; i < pir_params_.d; i++) {
-        // number of ciphertexts needed to encode the index for dimension i
-        // keeping into account that each ciphertext can encode up to poly_modulus_degree indexes
-        // In most cases this is usually 1.
-        uint32_t ctx_per_dimension = ceil((pir_params_.nvec[i] + 0.0) / enc_params_.poly_modulus_degree());
-
-        vector<Ciphertext> cs;
-        for (uint32_t j = 0; j < ctx_per_dimension; j++) {
-          Ciphertext c;
-          c.load(*context_, stream);
-          cs.push_back(c);
-        }
-
-        q.push_back(cs);
+  PirQuery q;
+
+  for (uint32_t i = 0; i < pir_params_.d; i++) {
+    // number of ciphertexts needed to encode the index for dimension i
+    // keeping into account that each ciphertext can encode up to
+    // poly_modulus_degree indexes In most cases this is usually 1.
+    uint32_t ctx_per_dimension =
+        ceil((pir_params_.nvec[i] + 0.0) / enc_params_.poly_modulus_degree());
+
+    vector<Ciphertext> cs;
+    for (uint32_t j = 0; j < ctx_per_dimension; j++) {
+      Ciphertext c;
+      c.load(*context_, stream);
+      cs.push_back(c);
     }
 
-    return q;
+    q.push_back(cs);
+  }
+
+  return q;
 }
 
 int PIRServer::serialize_reply(PirReply &reply, stringstream &stream) {
-    int output_size = 0;
-    for (int i = 0; i < reply.size(); i++){
-        evaluator_->mod_switch_to_inplace(reply[i], context_->last_parms_id());
-        output_size += reply[i].save(stream);
-    }
-    return output_size;
+  int output_size = 0;
+  for (int i = 0; i < reply.size(); i++) {
+    evaluator_->mod_switch_to_inplace(reply[i], context_->last_parms_id());
+    output_size += reply[i].save(stream);
+  }
+  return output_size;
 }
 
 PirReply PIRServer::generate_reply(PirQuery &query, uint32_t client_id) {
 
-    vector<uint64_t> nvec = pir_params_.nvec;
-    uint64_t product = 1;
+  vector<uint64_t> nvec = pir_params_.nvec;
+  uint64_t product = 1;
 
-    for (uint32_t i = 0; i < nvec.size(); i++) {
-        product *= nvec[i];
-    }
+  for (uint32_t i = 0; i < nvec.size(); i++) {
+    product *= nvec[i];
+  }
 
-    auto coeff_count = enc_params_.poly_modulus_degree();
+  auto coeff_count = enc_params_.poly_modulus_degree();
 
-    vector<Plaintext> *cur = db_.get();
-    vector<Plaintext> intermediate_plain; // decompose....
+  vector<Plaintext> *cur = db_.get();
+  vector<Plaintext> intermediate_plain; // decompose....
 
-    auto pool = MemoryManager::GetPool();
+  auto pool = MemoryManager::GetPool();
 
+  int N = enc_params_.poly_modulus_degree();
 
-    int N = enc_params_.poly_modulus_degree();
+  int logt = floor(log2(enc_params_.plain_modulus().value()));
 
-    int logt = floor(log2(enc_params_.plain_modulus().value()));
+  for (uint32_t i = 0; i < nvec.size(); i++) {
+    cout << "Server: " << i + 1 << "-th recursion level started " << endl;
 
-    for (uint32_t i = 0; i < nvec.size(); i++) {
-        cout << "Server: " << i + 1 << "-th recursion level started " << endl; 
+    vector<Ciphertext> expanded_query;
 
+    uint64_t n_i = nvec[i];
+    cout << "Server: n_i = " << n_i << endl;
+    cout << "Server: expanding " << query[i].size() << " query ctxts" << endl;
+    for (uint32_t j = 0; j < query[i].size(); j++) {
+      uint64_t total = N;
+      if (j == query[i].size() - 1) {
+        total = n_i % N;
+      }
+      cout << "-- expanding one query ctxt into " << total << " ctxts " << endl;
+      vector<Ciphertext> expanded_query_part =
+          expand_query(query[i][j], total, client_id);
+      expanded_query.insert(
+          expanded_query.end(),
+          std::make_move_iterator(expanded_query_part.begin()),
+          std::make_move_iterator(expanded_query_part.end()));
+      expanded_query_part.clear();
+    }
+    cout << "Server: expansion done " << endl;
+    if (expanded_query.size() != n_i) {
+      cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i
+           << endl;
+    }
 
-        vector<Ciphertext> expanded_query;
-
-        uint64_t n_i = nvec[i];
-        cout << "Server: n_i = " << n_i << endl; 
-        cout << "Server: expanding " << query[i].size() << " query ctxts" << endl;
-        for (uint32_t j = 0; j < query[i].size(); j++){
-            uint64_t total = N; 
-            if (j == query[i].size() - 1){
-                total = n_i % N;
-            }
-            cout << "-- expanding one query ctxt into " << total  << " ctxts "<< endl;
-            vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id);
-            expanded_query.insert(expanded_query.end(), std::make_move_iterator(expanded_query_part.begin()), 
-                    std::make_move_iterator(expanded_query_part.end()));
-            expanded_query_part.clear(); 
-        }
-        cout << "Server: expansion done " << endl; 
-        if (expanded_query.size() != n_i) {
-            cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i << endl; 
-        }    
-
-        // Transform expanded query to NTT, and ...
-        for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
-            evaluator_->transform_to_ntt_inplace(expanded_query[jj]);
-        }
+    // Transform expanded query to NTT, and ...
+    for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
+      evaluator_->transform_to_ntt_inplace(expanded_query[jj]);
+    }
 
-        // Transform plaintext to NTT. If database is pre-processed, can skip
-        if ((!is_db_preprocessed_) || i > 0) {
-            for (uint32_t jj = 0; jj < cur->size(); jj++) {
-                evaluator_->transform_to_ntt_inplace((*cur)[jj], context_->first_parms_id());
-            }
-        }
+    // Transform plaintext to NTT. If database is pre-processed, can skip
+    if ((!is_db_preprocessed_) || i > 0) {
+      for (uint32_t jj = 0; jj < cur->size(); jj++) {
+        evaluator_->transform_to_ntt_inplace((*cur)[jj],
+                                             context_->first_parms_id());
+      }
+    }
 
-        for (uint64_t k = 0; k < product; k++) {
-            if ((*cur)[k].is_zero()){
-                cout << k + 1 << "/ " << product <<  "-th ptxt = 0 " << endl;
-            }
-        }
+    for (uint64_t k = 0; k < product; k++) {
+      if ((*cur)[k].is_zero()) {
+        cout << k + 1 << "/ " << product << "-th ptxt = 0 " << endl;
+      }
+    }
 
-        product /= n_i;
+    product /= n_i;
 
-        vector<Ciphertext> intermediateCtxts(product);
-        Ciphertext temp;
+    vector<Ciphertext> intermediateCtxts(product);
+    Ciphertext temp;
 
-        for (uint64_t k = 0; k < product; k++) {
+    for (uint64_t k = 0; k < product; k++) {
 
-            evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediateCtxts[k]);
+      evaluator_->multiply_plain(expanded_query[0], (*cur)[k],
+                                 intermediateCtxts[k]);
 
-            for (uint64_t j = 1; j < n_i; j++) {
-                evaluator_->multiply_plain(expanded_query[j], (*cur)[k + j * product], temp);
-                evaluator_->add_inplace(intermediateCtxts[k], temp); // Adds to first component.
-            }
-        }
+      for (uint64_t j = 1; j < n_i; j++) {
+        evaluator_->multiply_plain(expanded_query[j], (*cur)[k + j * product],
+                                   temp);
+        evaluator_->add_inplace(intermediateCtxts[k],
+                                temp); // Adds to first component.
+      }
+    }
 
-        for (uint32_t jj = 0; jj < intermediateCtxts.size(); jj++) {
-            evaluator_->transform_from_ntt_inplace(intermediateCtxts[jj]);
-            // print intermediate ctxts? 
-            //cout << "const term of ctxt " << jj << " = " << intermediateCtxts[jj][0] << endl; 
-        }
+    for (uint32_t jj = 0; jj < intermediateCtxts.size(); jj++) {
+      evaluator_->transform_from_ntt_inplace(intermediateCtxts[jj]);
+      // print intermediate ctxts?
+      // cout << "const term of ctxt " << jj << " = " <<
+      // intermediateCtxts[jj][0] << endl;
+    }
 
-        if (i == nvec.size() - 1) {
-            return intermediateCtxts;
+    if (i == nvec.size() - 1) {
+      return intermediateCtxts;
+    } else {
+      intermediate_plain.clear();
+      intermediate_plain.reserve(pir_params_.expansion_ratio * product);
+      cur = &intermediate_plain;
+
+      for (uint64_t rr = 0; rr < product; rr++) {
+        EncryptionParameters parms;
+        if (pir_params_.enable_mswitching) {
+          evaluator_->mod_switch_to_inplace(intermediateCtxts[rr],
+                                            context_->last_parms_id());
+          parms = context_->last_context_data()->parms();
         } else {
-            intermediate_plain.clear();
-            intermediate_plain.reserve(pir_params_.expansion_ratio * product);
-            cur = &intermediate_plain;
-
-            for (uint64_t rr = 0; rr < product; rr++) {
-                EncryptionParameters parms;
-                if(pir_params_.enable_mswitching){
-                    evaluator_->mod_switch_to_inplace(intermediateCtxts[rr], context_->last_parms_id());
-                    parms = context_->last_context_data()->parms();
-                }
-                else{
-                    parms = context_->first_context_data()->parms();
-                }
-
-                vector<Plaintext> plains = decompose_to_plaintexts(parms,
-                    intermediateCtxts[rr]);
-
-                for (uint32_t jj = 0; jj < plains.size(); jj++) {
-                    intermediate_plain.emplace_back(plains[jj]);
-                }
-            }
-            product = intermediate_plain.size(); // multiply by expansion rate.
+          parms = context_->first_context_data()->parms();
         }
-        cout << "Server: " << i + 1 << "-th recursion level finished " << endl; 
-        cout << endl;
-    }
-    cout << "reply generated!  " << endl;
-    // This should never get here
-    assert(0);
-    vector<Ciphertext> fail(1);
-    return fail;
-}
 
-inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, uint32_t m,
-                                           uint32_t client_id) {
+        vector<Plaintext> plains =
+            decompose_to_plaintexts(parms, intermediateCtxts[rr]);
 
-    GaloisKeys &galkey = galoisKeys_[client_id];
-
-    // Assume that m is a power of 2. If not, round it to the next power of 2.
-    uint32_t logm = ceil(log2(m));
-    Plaintext two("2");
-
-    vector<int> galois_elts;
-    auto n = enc_params_.poly_modulus_degree();
-    if (logm > ceil(log2(n))){
-        throw logic_error("m > n is not allowed."); 
-    }
-    for (int i = 0; i < ceil(log2(n)); i++) {
-        galois_elts.push_back((n + exponentiate_uint(2, i)) / exponentiate_uint(2, i));
+        for (uint32_t jj = 0; jj < plains.size(); jj++) {
+          intermediate_plain.emplace_back(plains[jj]);
+        }
+      }
+      product = intermediate_plain.size(); // multiply by expansion rate.
     }
+    cout << "Server: " << i + 1 << "-th recursion level finished " << endl;
+    cout << endl;
+  }
+  cout << "reply generated!  " << endl;
+  // This should never get here
+  assert(0);
+  vector<Ciphertext> fail(1);
+  return fail;
+}
 
-    vector<Ciphertext> temp;
-    temp.push_back(encrypted);
-    Ciphertext tempctxt;
-    Ciphertext tempctxt_rotated;
-    Ciphertext tempctxt_shifted;
-    Ciphertext tempctxt_rotatedshifted;
-
-
-    for (uint32_t i = 0; i < logm - 1; i++) {
-        vector<Ciphertext> newtemp(temp.size() << 1);
-        // temp[a] = (j0 = a (mod 2**i) ? ) : Enc(x^{j0 - a}) else Enc(0).  With
-        // some scaling....
-        int index_raw = (n << 1) - (1 << i);
-        int index = (index_raw * galois_elts[i]) % (n << 1);
-
-        for (uint32_t a = 0; a < temp.size(); a++) {
-
-            evaluator_->apply_galois(temp[a], galois_elts[i], galkey, tempctxt_rotated);
+inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted,
+                                                  uint32_t m,
+                                                  uint32_t client_id) {
+
+  GaloisKeys &galkey = galoisKeys_[client_id];
+
+  // Assume that m is a power of 2. If not, round it to the next power of 2.
+  uint32_t logm = ceil(log2(m));
+  Plaintext two("2");
+
+  vector<int> galois_elts;
+  auto n = enc_params_.poly_modulus_degree();
+  if (logm > ceil(log2(n))) {
+    throw logic_error("m > n is not allowed.");
+  }
+  for (int i = 0; i < ceil(log2(n)); i++) {
+    galois_elts.push_back((n + exponentiate_uint(2, i)) /
+                          exponentiate_uint(2, i));
+  }
+
+  vector<Ciphertext> temp;
+  temp.push_back(encrypted);
+  Ciphertext tempctxt;
+  Ciphertext tempctxt_rotated;
+  Ciphertext tempctxt_shifted;
+  Ciphertext tempctxt_rotatedshifted;
+
+  for (uint32_t i = 0; i < logm - 1; i++) {
+    vector<Ciphertext> newtemp(temp.size() << 1);
+    // temp[a] = (j0 = a (mod 2**i) ? ) : Enc(x^{j0 - a}) else Enc(0).  With
+    // some scaling....
+    int index_raw = (n << 1) - (1 << i);
+    int index = (index_raw * galois_elts[i]) % (n << 1);
 
-            //cout << "rotate " << client.decryptor_->invariant_noise_budget(tempctxt_rotated) << ", "; 
+    for (uint32_t a = 0; a < temp.size(); a++) {
 
-            evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
-            multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
+      evaluator_->apply_galois(temp[a], galois_elts[i], galkey,
+                               tempctxt_rotated);
 
-            //cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_shifted) << ", "; 
+      // cout << "rotate " <<
+      // client.decryptor_->invariant_noise_budget(tempctxt_rotated) << ", ";
 
+      evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
+      multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
 
-            multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
+      // cout << "mul by x^pow: " <<
+      // client.decryptor_->invariant_noise_budget(tempctxt_shifted) << ", ";
 
-            // cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_rotatedshifted) << ", "; 
+      multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
 
+      // cout << "mul by x^pow: " <<
+      // client.decryptor_->invariant_noise_budget(tempctxt_rotatedshifted) <<
+      // ", ";
 
-            // Enc(2^i x^j) if j = 0 (mod 2**i).
-            evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a + temp.size()]);
-        }
-        temp = newtemp;
-        /*
-        cout << "end: "; 
-        for (int h = 0; h < temp.size();h++){
-            cout << client.decryptor_->invariant_noise_budget(temp[h]) << ", "; 
-        }
-        cout << endl; 
-        */
+      // Enc(2^i x^j) if j = 0 (mod 2**i).
+      evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted,
+                      newtemp[a + temp.size()]);
     }
-    // Last step of the loop
-    vector<Ciphertext> newtemp(temp.size() << 1);
-    int index_raw = (n << 1) - (1 << (logm - 1));
-    int index = (index_raw * galois_elts[logm - 1]) % (n << 1);
-    for (uint32_t a = 0; a < temp.size(); a++) {
-        if (a >= (m - (1 << (logm - 1)))) {                       // corner case.
-            evaluator_->multiply_plain(temp[a], two, newtemp[a]); // plain multiplication by 2.
-            // cout << client.decryptor_->invariant_noise_budget(newtemp[a]) << ", "; 
-        } else {
-            evaluator_->apply_galois(temp[a], galois_elts[logm - 1], galkey, tempctxt_rotated);
-            evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
-            multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
-            multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
-            evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a + temp.size()]);
-        }
+    temp = newtemp;
+    /*
+    cout << "end: ";
+    for (int h = 0; h < temp.size();h++){
+        cout << client.decryptor_->invariant_noise_budget(temp[h]) << ", ";
     }
+    cout << endl;
+    */
+  }
+  // Last step of the loop
+  vector<Ciphertext> newtemp(temp.size() << 1);
+  int index_raw = (n << 1) - (1 << (logm - 1));
+  int index = (index_raw * galois_elts[logm - 1]) % (n << 1);
+  for (uint32_t a = 0; a < temp.size(); a++) {
+    if (a >= (m - (1 << (logm - 1)))) { // corner case.
+      evaluator_->multiply_plain(temp[a], two,
+                                 newtemp[a]); // plain multiplication by 2.
+      // cout << client.decryptor_->invariant_noise_budget(newtemp[a]) << ", ";
+    } else {
+      evaluator_->apply_galois(temp[a], galois_elts[logm - 1], galkey,
+                               tempctxt_rotated);
+      evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
+      multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
+      multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
+      evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted,
+                      newtemp[a + temp.size()]);
+    }
+  }
 
-    vector<Ciphertext>::const_iterator first = newtemp.begin();
-    vector<Ciphertext>::const_iterator last = newtemp.begin() + m;
-    vector<Ciphertext> newVec(first, last);
+  vector<Ciphertext>::const_iterator first = newtemp.begin();
+  vector<Ciphertext>::const_iterator last = newtemp.begin() + m;
+  vector<Ciphertext> newVec(first, last);
 
-    return newVec;
+  return newVec;
 }
 
-inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Ciphertext &destination,
-                                    uint32_t index) {
+inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted,
+                                           Ciphertext &destination,
+                                           uint32_t index) {
 
-    auto coeff_mod_count = enc_params_.coeff_modulus().size() - 1;
-    auto coeff_count = enc_params_.poly_modulus_degree();
-    auto encrypted_count = encrypted.size();
+  auto coeff_mod_count = enc_params_.coeff_modulus().size() - 1;
+  auto coeff_count = enc_params_.poly_modulus_degree();
+  auto encrypted_count = encrypted.size();
 
-    //cout << "coeff mod count for power of X = " << coeff_mod_count << endl; 
-    //cout << "coeff count for power of X = " << coeff_count << endl; 
+  // cout << "coeff mod count for power of X = " << coeff_mod_count << endl;
+  // cout << "coeff count for power of X = " << coeff_count << endl;
 
-    // First copy over.
-    destination = encrypted;
+  // First copy over.
+  destination = encrypted;
 
-    // Prepare for destination
-    // Multiply X^index for each ciphertext polynomial
-    for (int i = 0; i < encrypted_count; i++) {
-        for (int j = 0; j < coeff_mod_count; j++) {
-            negacyclic_shift_poly_coeffmod(encrypted.data(i) + (j * coeff_count),
-                                           coeff_count, index,
-                                           enc_params_.coeff_modulus()[j],
-                                           destination.data(i) + (j * coeff_count));
-        }
+  // Prepare for destination
+  // Multiply X^index for each ciphertext polynomial
+  for (int i = 0; i < encrypted_count; i++) {
+    for (int j = 0; j < coeff_mod_count; j++) {
+      negacyclic_shift_poly_coeffmod(encrypted.data(i) + (j * coeff_count),
+                                     coeff_count, index,
+                                     enc_params_.coeff_modulus()[j],
+                                     destination.data(i) + (j * coeff_count));
     }
+  }
 }
 
-void PIRServer::simple_set(uint64_t index, Plaintext pt){
-    if(is_db_preprocessed_){
-        evaluator_->transform_to_ntt_inplace(
-                pt, context_->first_parms_id());
-    }
-    db_->operator[](index) = pt;
+void PIRServer::simple_set(uint64_t index, Plaintext pt) {
+  if (is_db_preprocessed_) {
+    evaluator_->transform_to_ntt_inplace(pt, context_->first_parms_id());
+  }
+  db_->operator[](index) = pt;
 }
 
-Ciphertext PIRServer::simple_query(uint64_t index){
-    //There is no transform_from_ntt that takes a plaintext
-    Ciphertext ct;
-    Plaintext pt = db_->operator[](index);
-    evaluator_->multiply_plain(one_, pt, ct);
-    evaluator_->transform_from_ntt_inplace(ct);
-    return ct;
+Ciphertext PIRServer::simple_query(uint64_t index) {
+  // There is no transform_from_ntt that takes a plaintext
+  Ciphertext ct;
+  Plaintext pt = db_->operator[](index);
+  evaluator_->multiply_plain(one_, pt, ct);
+  evaluator_->transform_from_ntt_inplace(ct);
+  return ct;
 }
 
-void PIRServer::set_one_ct(Ciphertext one){
-    one_ = one;
-    evaluator_->transform_to_ntt_inplace(one_);
+void PIRServer::set_one_ct(Ciphertext one) {
+  one_ = one;
+  evaluator_->transform_to_ntt_inplace(one_);
 }

+ 46 - 43
src/pir_server.hpp

@@ -1,52 +1,55 @@
 #pragma once
 
 #include "pir.hpp"
+#include "pir_client.hpp"
 #include <map>
 #include <memory>
 #include <vector>
-#include "pir_client.hpp"
 
 class PIRServer {
-  public:
-    PIRServer(const seal::EncryptionParameters &enc_params, const PirParams &pir_params);
-
-    // NOTE: server takes over ownership of db and frees it when it exits.
-    // Caller cannot free db
-    void set_database(std::unique_ptr<std::vector<seal::Plaintext>> &&db);
-    void set_database(const std::unique_ptr<const std::uint8_t[]> &bytes, std::uint64_t ele_num, std::uint64_t ele_size);
-    void preprocess_database();
-
-    std::vector<seal::Ciphertext> expand_query(
-            const seal::Ciphertext &encrypted, std::uint32_t m, std::uint32_t client_id);
-
-    PirQuery deserialize_query(std::stringstream &stream);
-    PirReply generate_reply(PirQuery &query, std::uint32_t client_id);
-    // Serializes the reply into the provided stream and returns the number of bytes written
-    int serialize_reply(PirReply &reply, std::stringstream &stream);
-
-    void set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey);
-
-    void simple_set(std::uint64_t index, seal::Plaintext pt);
-    // This is used for querying an element of the database WITHOUT PIR.
-    seal::Ciphertext simple_query(std::uint64_t index);
-    //This is only used for simple_query
-    void set_one_ct(seal::Ciphertext one);
-   
-
-
-  private:
-    seal::EncryptionParameters enc_params_; // SEAL parameters
-    PirParams pir_params_;              // PIR parameters
-    std::unique_ptr<Database> db_;
-    bool is_db_preprocessed_;
-    std::map<int, seal::GaloisKeys> galoisKeys_;
-    std::unique_ptr<seal::Evaluator> evaluator_;
-    std::unique_ptr<seal::BatchEncoder> encoder_;
-    std::shared_ptr<seal::SEALContext> context_;
-
-    //This is only uesd for simple_query
-    seal::Ciphertext one_;
-
-    void multiply_power_of_X(const seal::Ciphertext &encrypted, seal::Ciphertext &destination,
-                             std::uint32_t index);
+public:
+  PIRServer(const seal::EncryptionParameters &enc_params,
+            const PirParams &pir_params);
+
+  // NOTE: server takes over ownership of db and frees it when it exits.
+  // Caller cannot free db
+  void set_database(std::unique_ptr<std::vector<seal::Plaintext>> &&db);
+  void set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
+                    std::uint64_t ele_num, std::uint64_t ele_size);
+  void preprocess_database();
+
+  std::vector<seal::Ciphertext> expand_query(const seal::Ciphertext &encrypted,
+                                             std::uint32_t m,
+                                             std::uint32_t client_id);
+
+  PirQuery deserialize_query(std::stringstream &stream);
+  PirReply generate_reply(PirQuery &query, std::uint32_t client_id);
+  // Serializes the reply into the provided stream and returns the number of
+  // bytes written
+  int serialize_reply(PirReply &reply, std::stringstream &stream);
+
+  void set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey);
+
+  // Below simple operations are for interacting with the database WITHOUT PIR.
+  // So they can be used to modify a particular element in the database or
+  // to query a particular element (without privacy guarantees).
+  void simple_set(std::uint64_t index, seal::Plaintext pt);
+  seal::Ciphertext simple_query(std::uint64_t index);
+  void set_one_ct(seal::Ciphertext one);
+
+private:
+  seal::EncryptionParameters enc_params_; // SEAL parameters
+  PirParams pir_params_;                  // PIR parameters
+  std::unique_ptr<Database> db_;
+  bool is_db_preprocessed_;
+  std::map<int, seal::GaloisKeys> galoisKeys_;
+  std::unique_ptr<seal::Evaluator> evaluator_;
+  std::unique_ptr<seal::BatchEncoder> encoder_;
+  std::shared_ptr<seal::SEALContext> context_;
+
+  // This is only used for simple_query
+  seal::Ciphertext one_;
+
+  void multiply_power_of_X(const seal::Ciphertext &encrypted,
+                           seal::Ciphertext &destination, std::uint32_t index);
 };