Browse Source

SealPIR v0.1

Sebastian Angel 6 years ago
parent
commit
dd7ffc3e4a
8 changed files with 1067 additions and 944 deletions
  1. 11 12
      Makefile
  2. 102 166
      main.cpp
  3. 213 637
      pir.cpp
  4. 74 129
      pir.hpp
  5. 193 0
      pir_client.cpp
  6. 35 0
      pir_client.hpp
  7. 398 0
      pir_server.cpp
  8. 41 0
      pir_server.hpp

+ 11 - 12
Makefile

@@ -1,27 +1,26 @@
-CPP=g++
+CXX=g++
 
-IDIR = ../SEAL/SEAL/
-LDIR = ../SEAL/bin/
-ODIR=obj
-BDIR=bin
+IDIR =../SEAL/SEAL/
+LDIR =../SEAL/bin/
 
 CFLAGS=-std=c++11 -I. -I$(IDIR) -O3
+ODIR=obj
+BDIR=bin
 LIBS=-L$(LDIR) -lseal
 
-_DEPS = pir.hpp
-DEPS = $(patsubst %,$(IDIR)/%,$(_DEPS))
+DEPS = pir.hpp pir_server.hpp pir_client.hpp
 
-_OBJ = pir.o main.o 
+_OBJ = pir.o main.o pir_server.o pir_client.o 
 OBJ = $(patsubst %,$(ODIR)/%,$(_OBJ))
 
 
-$(ODIR)/%.o: %.cpp
+$(ODIR)/%.o: %.cpp $(DEPS)
 	@mkdir -p $(@D)
-	$(CPP) -c -o $@ $< $(CFLAGS)
+	$(CXX) -c -o $@ $< $(CFLAGS)
 
-$(BDIR)/main: $(OBJ)
+$(BDIR)/main: $(OBJ) $(DEPS) 
 	@mkdir -p $(@D)
-	$(CPP) -o $@ $^ $(CFLAGS) $(LIBS)
+	$(CXX) -o $@ $(OBJ) $(CFLAGS) $(LIBS)
 
 all: main
 

+ 102 - 166
main.cpp

@@ -1,185 +1,121 @@
 #include "pir.hpp"
-#include <time.h>
-#define BILLION 1000000000L
-#define MILLION (1.0*1000000L)
-#define KILO (1.0*1024L)
-#include <fstream>
-#include <vector>
-#include <sstream>
-#include <algorithm>
+#include "pir_client.hpp"
+#include "pir_server.hpp"
 #include <chrono>
 #include <random>
 
-#define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"
-#define PBWIDTH 60
-#define NUM_SLOT 64
-#define NUM_THREAD 2
-
+using namespace chrono;
+using namespace seal;
 
 int main(int argc, char *argv[]) {
 
-  uint64_t number_of_items = 1 << 22;
-  uint64_t size_per_item = 288 << 3; // 288 B. 
-
-
-  int n = 2048;
-  int logt = 21;
-  uint64_t plainMod = static_cast<uint64_t> (1) << logt;
-  double hao_const =  0.5 * log2(number_of_items *size_per_item) - 0.5 * log2(n);
-
-  int logtprime = logt; 
-  while(true){
-    if (logtprime + ceil(hao_const - 0.5*log2(logtprime)) == logt) break;
-    logtprime--; 
-  }
-
-  int number_of_plaintexts = ceil (((double)(number_of_items)* size_per_item / n) / logtprime );
-
-  EncryptionParameters parms;
-  parms.set_poly_modulus("1x^" + std::to_string(n) + " + 1");
-  vector<SmallModulus> coeff_mod_array;
-  int logq = 0;
-
-  for (int i = 0; i < 1; ++i)
-  {
-    coeff_mod_array.emplace_back(SmallModulus());
-    coeff_mod_array[i] = small_mods_60bit(i);
-    logq += coeff_mod_array[i].bit_count();
-  }
-
-  parms.set_coeff_modulus(coeff_mod_array);
-  parms.set_plain_modulus(plainMod);
-
-  pirParams pirparms;
-
-  uint64_t newplainMod = 1 << logtprime;
-
-
-  int item_per_plaintext = floor((double)get_power_of_two(newplainMod) *n / size_per_item);
+    // uint64_t number_of_items = 1 << 13;
+    // uint64_t number_of_items = 4096;
+     uint64_t number_of_items = 1 << 16;
 
+    uint64_t size_per_item = 288; // in bytes
+    // uint64_t size_per_item = 1 << 10; // 1 KB.
+    // uint64_t size_per_item = 10 << 10; // 10 KB.
 
-  pirparms.d = 2;
-  pirparms.alpha = 1;
-  pirparms.dbc = 8;
-  pirparms.N = number_of_plaintexts;
+    uint32_t N = 2048;
+    uint32_t logt = 20;
+    uint32_t d = 2;
 
-  int sqrt_items = ceil(sqrt(number_of_plaintexts));
-  int bound1 = ceil((double) number_of_plaintexts / sqrt_items);
-  int bound2 = sqrt_items;
+    EncryptionParameters params;
+    EncryptionParameters expanded_params;
+    PirParams pir_params;
 
-  vector<int> Nvec = { bound1, bound2 };
-  pirparms.Nvec = Nvec;
+    // Generates all parameters
+    cout << "Generating all parameters" << endl;
+    gen_params(number_of_items, size_per_item, N, logt, d, params, expanded_params, pir_params);
 
+    // Create test database
+    uint8_t *db = (uint8_t *)malloc(number_of_items * size_per_item);
 
-  // Initialize PIR client....
-  PIRClient client(parms, pirparms);
-
-  GaloisKeys galois_keys = client.generate_galois_keys();
-
-
-  EncryptionParameters newparms = client.get_new_parms();
-  galois_keys.mutable_hash_block() = newparms.hash_block();
-  PIRServer server(client.get_new_parms(), client.get_pir_parms());
-
-  server.set_galois_key(0, galois_keys);
-
-  int index = 3; // we want to obtain the 3rd item. 
-
-
-  random_device rd;
-
-  vector<uint64_t> no_choose(n+1);
-  vector<uint64_t> choose(n+1);
-
-
-  for (int i = 0; i < n+1; i++) {
-    no_choose[i] = rd() % newplainMod;
-    choose[i] = rd() % newplainMod;
-    if (i == n) {
-      choose[i] = 0; 
-      no_choose[i] = 0; 
+    random_device rd;
+    for (uint64_t i = 0; i < number_of_items; i++) {
+        for (uint64_t j = 0; j < size_per_item; j++) {
+            *(db + (i * size_per_item) + j) = rd() % 256;
+        }
     }
-  }
-
-  unique_ptr<uint64_t> items_anchor(new uint64_t[bound1*bound2*(n + 1)]); 
-  vector<Plaintext> items;
 
-  uint64_t * items_ptr = items_anchor.get();
-
-  for (int i = 0; i < bound1*bound2; i++) {
-    items.emplace_back(n + 1, items_ptr); 
-    if (i != index) {
-      util::set_uint_uint(no_choose.data(), n+1, items_ptr);
-    } else {
-      util::set_uint_uint(choose.data(), n+1, items_ptr);
+    // Initialize PIR Server
+    cout << "Initializing server and client" << endl;
+    PIRServer server(expanded_params, pir_params);
+
+    // Initialize PIR client....
+    PIRClient client(params, expanded_params, pir_params);
+    GaloisKeys galois_keys = client.generate_galois_keys();
+
+    // Set galois key
+    cout << "Setting Galois keys" << endl;
+    server.set_galois_key(0, galois_keys);
+
+
+    // The following can be used to update parameters rather than creating new instances
+    // (here it doesn't do anything).
+    cout << "Updating database size to: " << number_of_items << " elements" << endl;
+    update_params(number_of_items, size_per_item, d, params, expanded_params, pir_params);
+
+    uint32_t logtp = ceil(log2(expanded_params.plain_modulus().value()));
+
+    cout << "logtp: " << logtp << endl;
+
+    client.update_parameters(expanded_params, pir_params);
+    server.update_parameters(expanded_params, pir_params);
+
+    // Measure database setup
+    auto time_pre_s = high_resolution_clock::now();
+    server.set_database(db, number_of_items, size_per_item);
+    server.preprocess_database();
+    auto time_pre_e = high_resolution_clock::now();
+    auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
+
+    // Choose an index of an element in the DB
+    uint64_t ele_index = rd() % number_of_items; // element in DB at random position
+    uint64_t index = client.get_fv_index(ele_index, size_per_item);   // index of FV plaintext
+    uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext
+
+    // Measure query generation
+    auto time_query_s = high_resolution_clock::now();
+    PirQuery query = client.generate_query(index);
+    auto time_query_e = high_resolution_clock::now();
+    auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
+
+    // Measure query processing (including expansion)
+    auto time_server_s = high_resolution_clock::now();
+    PirQuery query_ser = deserialize_ciphertexts(d, serialize_ciphertexts(query), CIPHER_SIZE);
+    PirReply reply = server.generate_reply(query_ser, 0);
+    auto time_server_e = high_resolution_clock::now();
+    auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
+
+    // Measure response extraction
+    auto time_decode_s = chrono::high_resolution_clock::now();
+    Plaintext result = client.decode_reply(reply);
+    auto time_decode_e = chrono::high_resolution_clock::now();
+    auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
+
+    // Convert to elements
+    vector<uint8_t> elems(N * logtp / 8);
+    coeffs_to_bytes(logtp, result, elems.data(), (N * logtp) / 8);
+
+    // Check that we retrieved the correct element
+    for (uint32_t i = 0; i < size_per_item; i++) {
+        if (elems[(offset * size_per_item) + i] != db[(ele_index * size_per_item) + i]) {
+            cout << "elems " << (int)elems[(offset * size_per_item) + i] << ", db "
+                 << (int)db[(ele_index * size_per_item) + i] << endl;
+            cout << "PIR result wrong!" << endl;
+            return -1;
+        }
     }
-    items_ptr += n + 1; 
-  }
-  server.set_database(&items);
-
-  auto time_querygen_start = chrono::high_resolution_clock::now();
-
-  pirQuery query = client.generate_query(index);
-
-  for (int i = 0; i < query.size(); i++) {
-    query[i].mutable_hash_block() = newparms.hash_block();
-  }
-
-  auto time_querygen_end = chrono::high_resolution_clock::now();
-
-  cout << "PIRClient query generation time : " << chrono::duration_cast<chrono::microseconds>(time_querygen_end - time_querygen_start).count() / 1000
-    << " ms" << endl;
-  cout << "Query size = " << (double) n * 2 * logq * pirparms.d / (1024 * 8) << "KB" << endl;
-
-  auto time_pre_start = chrono::high_resolution_clock::now();
-
-  server.preprocess_database();
-
-  auto time_pre_end = chrono::high_resolution_clock::now();
-  cout << "pre-processing time = " << chrono::duration_cast<chrono::microseconds>(time_pre_end - time_pre_start).count() / 1000
-    << " ms" << endl;
-
-  pirQuery query_ser = deserialize_ciphertexts(2, serialize_ciphertexts(query), 32828);
-
-  auto time_server_start = chrono::high_resolution_clock::now();
-
-  pirReply reply = server.generate_reply(query_ser, 0);
 
+    // Output results
+    cout << "PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
+    cout << "PIRServer query processing generation time: " << time_server_us / 1000 << " ms"
+         << endl;
+    cout << "PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
+    cout << "PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
+    cout << "Reply num ciphertexts: " << reply.size() << endl;
 
-  auto time_server_end = chrono::high_resolution_clock::now();
-
-
-  cout << "Server reply generation time : " << chrono::duration_cast<chrono::microseconds>(time_server_end - time_server_start).count() / 1000
-    << " ms" << endl;
-
-  cout<<"Reply ciphertexts"<<reply.size()<<endl;
-
-
-  cout << "Reply size = " << (double) reply.size() * n * 2 * logq  / (1024 * 8) << "KB" << endl;
-
-  auto time_decode_start = chrono::high_resolution_clock::now();
-
-  Plaintext result = client.decode_reply(reply);
-
-  auto time_decode_end = chrono::high_resolution_clock::now();
-
-  cout << "PIRClient decoding time : " << chrono::duration_cast<chrono::microseconds>(time_decode_end - time_decode_start).count() / 1000
-    << " ms" << endl;
-
-  cout << "Result = ";
-  bool pircorrect = true;
-  for (int i = 0; i < n; i++) {
-    if (result[i] != choose[i]) {
-      pircorrect = false;
-      break;
-    }
-  }
-  if (pircorrect) {
-    cout << "PIR result correct!!" << endl;
-  }
-  else {
-    cout << "PIR result wrong!" << endl;
-  }
-
-  return 0;
+    return 0;
 }

+ 213 - 637
pir.cpp

@@ -1,717 +1,293 @@
 #include "pir.hpp"
+
 using namespace std;
-#include <vector>
 using namespace seal;
 using namespace seal::util;
 
-PIRClient::PIRClient(const seal::EncryptionParameters &parms, pirParams & pirparms) {
-  parms_ = parms; 
-  SEALContext context(parms);
-  keygen_.reset(new KeyGenerator(context));
-
-  encryptor_.reset(new Encryptor(context, keygen_->public_key()));
-
-  uint64_t plainMod = parms.plain_modulus().value();
-
-  int N = pirparms.Nvec[0];
-  int logN = ceil(log(N) / log(2)); 
-
-  EncryptionParameters newparms = parms;
-  newparms.set_plain_modulus(plainMod >> logN); 
-  newparms_ = newparms;
-  SEALContext newcontext(newparms);
-  SecretKey secret_key = keygen_->secret_key();
-  secret_key.mutable_hash_block() = newparms.hash_block();
-  decryptor_.reset(new Decryptor(newcontext, secret_key));
-  evaluator_.reset(new Evaluator(newcontext));
-
-  int expansion_ratio = 0;
-  for (int i = 0; i < parms.coeff_modulus().size(); ++i)
-  {
-    double logqi = log(parms.coeff_modulus()[i].value());
-    expansion_ratio += ceil(logqi / log(newparms.plain_modulus().value()));
-  }
-  pirparms.expansion_ratio_ = expansion_ratio << 1;
-  pirparms_ = pirparms;
-}
-
-pirQuery PIRClient::generate_query(int desiredIndex) { 
-  vector<int> indices = compute_indices(desiredIndex, pirparms_.Nvec); 
-  vector<Ciphertext> result;
-  for (int i = 0; i < indices.size(); i++) {
-    Ciphertext dest;
-    encryptor_->encrypt(Plaintext("1x^" + std::to_string(indices[i])), dest);
-    result.push_back(dest); 
-  }
-  return result;
-}
-
-Plaintext PIRClient::decode_reply(pirReply reply) {
-  int exp_ratio = pirparms_.expansion_ratio_;
-  vector<Ciphertext> temp = reply;
-  int recursion_level = pirparms_.d;
-  for (int i = 0; i < recursion_level; i++) {
-    vector<Ciphertext> newtemp;
-    vector<Plaintext> tempplain;
-    for (int j = 0; j < temp.size(); j++) {
-      Plaintext ptxt;
-      decryptor_->decrypt(temp[j], ptxt);
-      tempplain.push_back(ptxt);  
-      if ( (j + 1) % exp_ratio == 0 && j > 0) {
-        // Combine into one ciphertext. 
-        Ciphertext combined = compose_to_ciphertext(tempplain); 
-        newtemp.push_back(combined);
-      }
-    }	
-    if (i == recursion_level - 1) {
-      if (temp.size() != 1) throw;
-      return tempplain[0];
-    }
-    else {
-      tempplain.clear();
-      temp = newtemp;
-    }
-  }
-
-}
-
-GaloisKeys PIRClient::generate_galois_keys() {
-  vector<uint64_t> galois_elts;
-  int n = parms_.poly_modulus().coeff_count() - 1;
-  int logn = get_power_of_two(n);
-
-  for (int i = 0; i < logn; i++)
-  {
-    galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
-  }
-
-  GaloisKeys galois_keys;
-  keygen_->generate_galois_keys(pirparms_.dbc, galois_elts, galois_keys);
-  return galois_keys;
-}
+vector<uint64_t> get_dimensions(uint64_t plaintext_num, uint32_t d) {
 
-void PIRClient::print_info(Ciphertext & encrypted)
-{
-  Plaintext ptxt;
-  decryptor_->decrypt(encrypted, ptxt);
-}
+    assert(d > 0);
+    assert(plaintext_num > 0);
 
-// Given a vector N1, ..., Nd and a number desired index j between 0 and prod(N_i). 
-// Return j indices j1, ..., jd such that  j = j1 (N/N1) + j2 (N/N1N2) + ..... 
-vector<int> compute_indices(int desiredIndex, vector<int> Nvec) {
-  int d = Nvec.size();
-  int product = 1;
-  for (int i = 0; i < Nvec.size(); i++) {
-    product *= Nvec[i];
-  }
-
-  int j = desiredIndex;
-  vector<int> result;
-  for (int i = 0; i < d; i++) {
-    product /= Nvec[i]; 
-    int ji = j / product; 
-    result.push_back(ji); 
-    j -= ji*product; 
-  }
-  return result;
-}
+    vector<uint64_t> dimensions(d);
 
-PIRServer::PIRServer(const seal::EncryptionParameters & parms, const pirParams &pirparams) {
-  parms_ = parms;
-  pirparams_ = pirparams;
-  SEALContext context(parms);
-  evaluator_.reset(new Evaluator(context));
-  is_db_preprocessed_ = false;
-}
+    for (uint32_t i = 0; i < d; i++) {
+        dimensions[i] = std::max((uint32_t) 2, (uint32_t) floor(pow(plaintext_num, 1.0/d)));
+    }
 
-void PIRServer::preprocess_database() {
-  if (!is_db_preprocessed_) {
-    for (int i = 0; i < dataBase_->size(); i++) {
-      evaluator_->transform_to_ntt(dataBase_->operator[](i));
+    uint32_t product = 1;
+    uint32_t j = 0;
+
+    // if plaintext_num is not a d-power
+    if ((double) dimensions[0] != pow(plaintext_num, 1.0 / d)) {
+        while  (product < plaintext_num && j < d) {
+            product = 1;
+            dimensions[j++]++;
+            for (uint32_t i = 0; i < d; i++) {
+                product *= dimensions[i];
+            }
+        }
     }
-    is_db_preprocessed_ = true;
-  }
-}
 
-void PIRServer::set_database(vector<Plaintext> *db) {
-  if (db == nullptr) {
-    throw invalid_argument("db cannot be null");
-  }
-  dataBase_ = db;
+    return dimensions;
 }
 
-pirReply PIRServer::generate_reply(pirQuery query, int client_id) {
-  vector<int> Nvec = pirparams_.Nvec;
-  uint64_t product = 1; 
-  for (int i = 0; i < Nvec.size(); i++) {
-    product *= Nvec[i]; 
-  }
-  int coeff_count = parms_.poly_modulus().coeff_count(); 
-
-  vector<Plaintext> *cur = dataBase_;
-  vector<Plaintext> intermediate_plain; // decompose.... 
+void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
+                uint32_t d, EncryptionParameters &params, EncryptionParameters &expanded_params,
+                PirParams &pir_params) {
+    
+    // Determine the maximum size of each dimension
+    uint32_t logtp = plainmod_after_expansion(logt, N, d, ele_num, ele_size);
 
-  auto my_pool = MemoryPoolHandle::New();
+    uint64_t plain_mod = static_cast<uint64_t>(1) << logt;
+    uint64_t expanded_plain_mod = static_cast<uint64_t>(1) << logtp;
+    uint64_t plaintext_num = plaintexts_per_db(logtp, N, ele_num, ele_size);
 
-
-  for (int i = 0; i < Nvec.size(); i++) {
-    int Ni = Nvec[i];
-    vector<Ciphertext> expanded_query = expand_query(query[i], Ni, galoisKeys_[client_id]);
 #ifdef DEBUG
-    cout << "query ciphertext check: " << endl;
-    for (int tt = 0; tt < expanded_query.size(); tt++) {
-      client.print_info(expanded_query[tt]);
-    }
+    cout << "log(plain mod) before expand = " << logt << endl;
+    cout << "log(plain mod) after expand = " << logtp << endl;
+    cout << "number of FV plaintexts = " << plaintext_num << endl;
 #endif
 
-    // Transform expanded query to NTT, and ... 
-    for (int jj = 0; jj < expanded_query.size(); jj++) {
-      evaluator_->transform_to_ntt(expanded_query[jj]);
-    }
+    vector<SmallModulus> coeff_mod_array;
+    uint32_t logq = 0;
 
-    // Transform plaintext to NTT. If database is pre-processed, can skip
-    if ((!is_db_preprocessed_) || i > 0) {
-      for (int jj = 0; jj < cur->size(); jj++) {
-        evaluator_->transform_to_ntt((*cur)[jj]);
-      }
+    for (uint32_t i = 0; i < 1; i++) {
+        coeff_mod_array.emplace_back(SmallModulus());
+        coeff_mod_array[i] = small_mods_60bit(i);
+        logq += coeff_mod_array[i].bit_count();
     }
 
-    product /= Ni;
-    vector<Ciphertext> intermediate(product);
-    Ciphertext temp1;
+    params.set_poly_modulus("1x^" + to_string(N) + " + 1");
+    params.set_coeff_modulus(coeff_mod_array);
+    params.set_plain_modulus(plain_mod);
 
-    for (int k = 0; k < product; k++) {
-      evaluator_->multiply_plain_ntt(expanded_query[0], (*cur)[k], intermediate[k]);
-      for (int j = 1; j < Ni; j++) {
-        evaluator_->multiply_plain_ntt(expanded_query[j], (*cur)[k + j*product], temp1);
-        evaluator_->add(intermediate[k], temp1); // Adds to the first component.
-      }
-    }
-    for (int jj = 0; jj < intermediate.size(); jj++) {
-      evaluator_->transform_from_ntt(intermediate[jj]);
-    }
+    expanded_params.set_poly_modulus("1x^" + to_string(N) + " + 1");
+    expanded_params.set_coeff_modulus(coeff_mod_array);
+    expanded_params.set_plain_modulus(expanded_plain_mod);
 
-#ifdef DEBUG
-    cout << "intermediate ciphertext check: " << endl;
-    for (int tt = 0; tt < intermediate.size(); tt++) {
-      cout << tt + 1 << " / " << intermediate.size() << " ";
-      client.print_info(intermediate[tt]);
+    vector<uint64_t> nvec = get_dimensions(plaintext_num, d);
+
+    uint32_t expansion_ratio = 0;
+    for (uint32_t i = 0; i < params.coeff_modulus().size(); ++i) {
+        double logqi = log2(params.coeff_modulus()[i].value());
+        expansion_ratio += ceil(logqi / logtp);
     }
-#endif
 
-    if (i == Nvec.size() - 1) {
-      return intermediate;
-    } else {
-      intermediate_plain.clear();
-      intermediate_plain.reserve(pirparams_.expansion_ratio_ * product);
-      cur = &intermediate_plain;
+    pir_params.d = d;
+    pir_params.dbc = 6;
+    pir_params.n = plaintext_num;
+    pir_params.nvec = nvec;
+    pir_params.expansion_ratio = expansion_ratio << 1;
+}
 
-      util::Pointer tempplain_ptr(allocate_zero_poly(pirparams_.expansion_ratio_ * product, coeff_count, my_pool));
+void update_params(uint64_t ele_num, uint64_t ele_size, uint32_t d, 
+                   const EncryptionParameters &old_params, EncryptionParameters &expanded_params, 
+                   PirParams &pir_params) {
 
-      for (int rr = 0; rr < product; rr++) {
-        decompose_to_plaintexts_ptr(intermediate[rr], tempplain_ptr.get() + rr * pirparams_.expansion_ratio_* coeff_count);
-#ifdef DEBUG
-        cout << "compose decompose check: " << endl;
-        client.print_info(evaluator_->compose_to_ciphertext(tempplain));
-#endif              
-        for (int jj = 0; jj < pirparams_.expansion_ratio_; jj++){
-          int offset = rr * pirparams_.expansion_ratio_* coeff_count + jj * coeff_count;
-          intermediate_plain.emplace_back(coeff_count, tempplain_ptr.get() + offset);
-        }
-      }
-      product *= pirparams_.expansion_ratio_; // multiply by expansion rate.
-    }
-  }
-}
+    uint32_t logt = ceil(log2(old_params.plain_modulus().value()));
+    uint32_t N = old_params.poly_modulus().coeff_count() - 1;
 
-vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, int d, const GaloisKeys &galkey) {
+    // Determine the maximum size of each dimension
+    uint32_t logtp = plainmod_after_expansion(logt, N, d, ele_num, ele_size);
+
+    uint64_t expanded_plain_mod = static_cast<uint64_t>(1) << logtp;
+    uint64_t plaintext_num = plaintexts_per_db(logtp, N, ele_num, ele_size);
 
-  uint64_t plainMod = parms_.plain_modulus().value();
 #ifdef DEBUG
-  cout << "PIRServer side plain modulus = " << plainMod << endl;
+    cout << "log(plain mod) before expand = " << logt << endl;
+    cout << "log(plain mod) after expand = " << logtp << endl;
+    cout << "number of FV plaintexts = " << plaintext_num << endl;
 #endif
-  
-  // Assume that d is a power of 2. If not, round it to the next power of 2. 
-  int logd = ceil(log(d) / log(2));
-  Plaintext two("2");
-  vector<int> galois_elts;
-  int n = parms_.poly_modulus().coeff_count() - 1;
-  for (int i = 0; i < logd; i++) {
-    galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
-  }
-  vector<Ciphertext> temp;
-  temp.push_back(encrypted);
-  Ciphertext tempctxt;
-  Ciphertext tempctxt_rotated;
-  Ciphertext tempctxt_shifted;
-  Ciphertext tempctxt_rotatedshifted;
-
-  int shift = 1;
-  for (int i = 0; i < logd -1; i++) {
-    vector<Ciphertext> newtemp(temp.size() << 1);
-    int index_raw = (n << 1) - (1 << i);
-    int index = (index_raw * galois_elts[i]) % (n << 1);
-    for (int a = 0; a < temp.size(); a++) {
-      evaluator_->apply_galois(temp[a], galois_elts[i], galkey, tempctxt_rotated); // Can be done in-place
-      evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
-      multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
-
-      multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
-      evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a+temp.size()]); // Enc(2^i x^j) if j = 0 (mod 2**i).
-    }
-    temp = newtemp;
-  }
-
-  // Last iteration of the loop 
-  vector<Ciphertext> newtemp(temp.size() << 1);
-  int index_raw = (n << 1) - (1 << (logd - 1));
-  int index = (index_raw * galois_elts[logd - 1]) % (n << 1);
-  for (int a = 0; a < temp.size(); a++) {
-    if(a >= (d - (1 << (logd - 1)))) { // corner case. 
-      evaluator_->multiply_plain(temp[a], two, newtemp[a]);// plain multiplication by 2.
-    }
-    else {
-      evaluator_->apply_galois(temp[a], galois_elts[logd-1], galkey, tempctxt_rotated); // Can be done in-place
-      evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
-      multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
-      multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
-      evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a + temp.size()]); // Enc(2^i x^j) if j = 0 (mod 2**i).
-    }
-  }
 
-  vector<Ciphertext>::const_iterator first = newtemp.begin();
-  vector<Ciphertext>::const_iterator last = newtemp.begin() + d;
-  vector<Ciphertext> newVec(first, last);
-  return newVec;
-}
+    expanded_params.set_poly_modulus(old_params.poly_modulus());
+    expanded_params.set_coeff_modulus(old_params.coeff_modulus());
+    expanded_params.set_plain_modulus(expanded_plain_mod);
 
+    // Assumes dimension of database is 2
+    vector<uint64_t> nvec = get_dimensions(plaintext_num, d);
 
-void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Ciphertext & destination, int index)
-{
-  // Extract parameter
-  int coeff_mod_count = parms_.coeff_modulus().size();
-  int coeff_count = parms_.poly_modulus().coeff_count();
-  int coeff_bit_count = coeff_mod_count * bits_per_uint64;
-  int encrypted_ptr_increment = coeff_count * coeff_mod_count;
-  int encrypted_count = encrypted.size();
-  
-  // First copy over. 
-  destination = encrypted;
-  
-  // Prepare for destination
-  // Multiply X^index for each ciphertext polynomial
-  for (int i = 0; i < encrypted_count; i++)
-  {
-    for (int j = 0; j < coeff_mod_count; j++)
-    {
-      negacyclic_shift_poly_coeffmod(encrypted.pointer(i) + (j * coeff_count), coeff_count - 1, index, parms_.coeff_modulus()[j], destination.mutable_pointer(i) + (j * coeff_count));
+    uint32_t expansion_ratio = 0;
+    for (uint32_t i = 0; i < old_params.coeff_modulus().size(); ++i) {
+        double logqi = log2(old_params.coeff_modulus()[i].value());
+        expansion_ratio += ceil(logqi / logtp);
     }
-  }
+
+    pir_params.d = d;
+    pir_params.dbc = 6;
+    pir_params.n = plaintext_num;
+    pir_params.nvec = nvec;
+    pir_params.expansion_ratio = expansion_ratio << 1;
 }
 
+uint32_t plainmod_after_expansion(uint32_t logt, uint32_t N, uint32_t d, 
+        uint64_t ele_num, uint64_t ele_size) {
 
-Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
-  Ciphertext result;
-  int encrypted_count = 2;
-
-
-  int coeff_count = newparms_.poly_modulus().coeff_count();
-  int coeff_mod_count = newparms_.coeff_modulus().size();
-  int array_poly_uint64_count = coeff_count * coeff_mod_count;
-
-  result.reserve(newparms_, encrypted_count);
-  int plain_bit_count = newparms_.plain_modulus().bit_count();
-  uint64_t plainMod = newparms_.plain_modulus().value();
-
-
-  // A triple for loop. Going over polys, moduli, and decomposed index.
-  for (int i = 0; i < encrypted_count; i++) {
-    uint64_t *encrypted_pointer = result.mutable_pointer(i);
-    for (int j = 0; j < coeff_mod_count; j++)
-    {
-      // populate one poly at a time.
-      // create a polynomial to store the current decomposition value which will be copied into the array to populate it at the current index.
-      double logqj = log(newparms_.coeff_modulus()[j].value());
-      int expansion_ratio = ceil(logqj / log(plainMod));
-      uint64_t cur = 1;
-      for (int k = 0; k < expansion_ratio; k++)
-      {
-        // Compose here
-        const uint64_t *plain_coeff = plains[k + j*(expansion_ratio)+i*(coeff_mod_count*expansion_ratio)].pointer();
-        for (int m = 0; m < coeff_count - 1; m++)
-        {
-          if (k == 0) {
-            *(encrypted_pointer + m + j*coeff_count) = *(plain_coeff + m) * cur;
-          }
-          else {
-            *(encrypted_pointer + m + j*coeff_count) += *(plain_coeff + m) * cur;
-          }
-        }
-        *(encrypted_pointer + coeff_count - 1 + j*coeff_count) = 0;
-        cur *= plainMod;
-      }
-
-      // Reduction modulo qj. This is needed? 
-      for (int m = 0; m < coeff_count; m++)
-      {
-        *(encrypted_pointer + m + j*coeff_count) %= newparms_.coeff_modulus()[j].value();
-      }
-    }
-  }
-  result.mutable_hash_block() = newparms_.hash_block();
-  return result;
-}
+    // Goal: find max logtp such that logtp + ceil(log(ceil(d_root(n)))) <= logt
+    // where n = ceil(ele_num / floor(N*logtp / ele_size *8))
+    for (uint32_t logtp = logt; logtp >= 2; logtp--) {
 
+        uint64_t n = plaintexts_per_db(logtp, N, ele_num, ele_size);
 
-void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, uint64_t* plain_ptr) {
-  vector<Plaintext> result;
-  int coeff_count = parms_.poly_modulus().coeff_count();
-  int coeff_mod_count = parms_.coeff_modulus().size();
-  int array_poly_uint64_count = coeff_count * coeff_mod_count;
-
-  int plain_bit_count = parms_.plain_modulus().bit_count();
-
-  int encrypted_count = encrypted.size();
-
-  // Generate powers of t.
-  uint64_t plainModMinusOne = parms_.plain_modulus().value() -1;
-  int exp = ceil(log2(plainModMinusOne + 1)); 
-
-  for (int i = 0; i < encrypted_count; i++) {
-    const uint64_t * encrypted_pointer = encrypted.pointer(i);
-    for (int j = 0; j < coeff_mod_count; j++)
-    {
-      // populate one poly at a time.
-      // create a polynomial to store the current decomposition value which will be copied into the array to populate it at the current index.
-      int shift = 0;
-      int logqj = log2(parms_.coeff_modulus()[j].value());
-      int expansion_ratio = (logqj + exp -1) / exp;
-      uint64_t curexp = 0;
-      for (int k = 0; k < expansion_ratio; k++)
-      {
-        // Decompose here
-        for (int m = 0; m < coeff_count; m++)
-        {
-          *plain_ptr = (*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & plainModMinusOne;
-          plain_ptr++;
+        if (logtp == logt && n == 1) {
+            return logtp - 1;
         }
-        curexp += exp;
-      }
-    }
-  }
-  return;
-}
-
 
-std::vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted) {
-  vector<Plaintext> result;
-  int coeff_count = parms_.poly_modulus().coeff_count();
-  int coeff_mod_count = parms_.coeff_modulus().size();
-  int array_poly_uint64_count = coeff_count * coeff_mod_count;
-
-  int plain_bit_count = parms_.plain_modulus().bit_count();
-
-  int encrypted_count = encrypted.size();
-
-  // Generate powers of t.
-  uint64_t plainMod = parms_.plain_modulus().value();
-
-  for (int i = 0; i < encrypted_count; i++) {
-    const uint64_t * encrypted_pointer = encrypted.pointer(i);
-    for (int j = 0; j < coeff_mod_count; j++)
-    {
-      // populate one poly at a time.
-      // create a polynomial to store the current decomposition value which will be copied into the array to populate it at the current index.
-      int shift = 0;
-      int logqj = log(parms_.coeff_modulus()[j].value());
-      int expansion_ratio = ceil(logqj / log(plainMod));
-      uint64_t cur = 1;
-      for (int k = 0; k < expansion_ratio; k++)
-      {
-        // Decompose here
-        BigPoly temp;
-        temp.resize(coeff_count, plain_bit_count);
-        temp.set_zero();
-        uint64_t *plain_coeff = temp.pointer();
-        for (int m = 0; m < coeff_count; m++)
-        {
-          *(plain_coeff + m) = (*(encrypted_pointer + m + (j * coeff_count)) / cur) % plainMod;
+        if ((double)logtp + ceil(log2(ceil(pow(n, 1.0/(double)d)))) <= logt) {
+            return logtp;
         }
-        result.push_back(Plaintext(temp));
-        cur *= plainMod;
-      }
     }
-  }
-  return result;
-}
 
-void vector_to_plaintext(const std::vector<std::uint64_t> &coeffs, Plaintext &plain)
-{
-  int coeff_count = coeffs.size();
-  plain.resize(coeff_count);
-util:set_uint_uint(coeffs.data(), coeff_count, plain.pointer());
+    assert(0); // this should never happen
+    return logt;
 }
 
-
-string serialize_ciphertext(Ciphertext c) {
-  std::stringstream output(std::ios::binary|std::ios::out);
-  c.save(output);
-  return output.str();
+// Number of coefficients needed to represent a database element
+uint64_t coefficients_per_element(uint32_t logtp, uint64_t ele_size) {
+    return ceil(8 * ele_size / (double)logtp);
 }
 
-string serialize_ciphertexts(vector<Ciphertext> c) {
-  string s;
-  for(int i=0; i<c.size(); i++) {
-    s.append(serialize_ciphertext(c[i]));
-  }
-  return s;
+// Number of database elements that can fit in a single FV plaintext
+uint64_t elements_per_ptxt(uint32_t logtp, uint64_t N, uint64_t ele_size) {
+    uint64_t coeff_per_ele = coefficients_per_element(logtp, ele_size);
+    uint64_t ele_per_ptxt = N / coeff_per_ele;
+    assert(ele_per_ptxt > 0);
+    return ele_per_ptxt;
 }
 
-Ciphertext* deserialize_ciphertext(string s) {
-  Ciphertext *c = new Ciphertext();
-  std::stringstream input(std::ios::binary|std::ios::in);
-  input.str(s);
-  c->load(input);
-  return c;
+// Number of FV plaintexts needed to represent the database
+uint64_t plaintexts_per_db(uint32_t logtp, uint64_t N, uint64_t ele_num, uint64_t ele_size) {
+    uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
+    return ceil((double)ele_num / ele_per_ptxt);
 }
 
-vector<Ciphertext> deserialize_ciphertexts(int count, string s, int len_ciphertext) {
-  vector<Ciphertext> c;
-  for(int i=0; i<count; i++) {
-    c.push_back(*(deserialize_ciphertext(s.substr(i*len_ciphertext, len_ciphertext))));
-  }
-  return c;
-}
+vector<uint64_t> bytes_to_coeffs(uint32_t limit, const uint8_t *bytes, uint64_t size) {
 
-string serialize_plaintext(Plaintext p) {
-  std::stringstream output(std::ios::binary|std::ios::out);
-  p.save(output);
-  return output.str();
-}
+    uint64_t size_out = coefficients_per_element(limit, size);
+    vector<uint64_t> output(size_out);
 
-string serialize_plaintexts(vector<Plaintext> p) {
-  string s;
-  for(int i=0; i<p.size(); i++) {
-    s.append(serialize_plaintext(p[i]));
-  }
-  return s;
-}
+    uint32_t room = limit;
+    uint64_t *target = &output[0];
 
-Plaintext* deserialize_plaintext(string s) {
-  Plaintext *c = new Plaintext();
-  std::stringstream input(std::ios::binary|std::ios::in);
-  input.str(s);
-  c->load(input);
-  return c;
-}
-
-vector<Plaintext> deserialize_plaintexts(int count, string s, int len_plaintext) {
-  vector<Plaintext> p;
-  for(int i=0; i<count; i++) {
-    p.push_back(*(deserialize_plaintext(s.substr(i*len_plaintext, len_plaintext))));
-  }
-  return p;
-}
-
-string serialize_galoiskeys(GaloisKeys g) {
-  std::stringstream output(std::ios::binary|std::ios::out);
-  g.save(output);
-  return output.str();
-}
+    for (uint32_t i = 0; i < size; i++) {
+        uint8_t src = bytes[i];
+        uint32_t rest = 8;
+        while (rest) {
+            if (room == 0) {
+                target++;
+                room = limit;
+            }
+            uint32_t shift = rest;
+            if (room < rest) {
+                shift = room;
+            }
+            *target = *target << shift;
+            *target = *target | (src >> (8 - shift));
+            src = src << shift;
+            room -= shift;
+            rest -= shift;
+        }
+    }
 
-GaloisKeys* deserialize_galoiskeys(string s) {
-  GaloisKeys *g = new GaloisKeys();
-  std::stringstream input(std::ios::binary|std::ios::in);
-  input.str(s);
-  g->load(input);
-  return g;
+    *target = *target << room;
+    return output;
+}
+
+void coeffs_to_bytes(uint32_t limit, const Plaintext &coeffs, uint8_t *output, uint32_t size_out) {
+    uint32_t room = 8;
+    uint32_t j = 0;
+    uint8_t *target = output;
+
+    for (uint32_t i = 0; i < coeffs.coeff_count(); i++) {
+        uint64_t src = coeffs[i];
+        uint32_t rest = limit;
+        while (rest && j < size_out) {
+            uint32_t shift = rest;
+            if (room < rest) {
+                shift = room;
+            }
+            target[j] = target[j] << shift;
+            target[j] = target[j] | (src >> (limit - shift));
+            src = src << shift;
+            room -= shift;
+            rest -= shift;
+            if (room == 0) {
+                j++;
+                room = 8;
+            }
+        }
+    }
 }
 
-void 
-cpp_buffer_free(char *buf) {
-  free(buf);
+void vector_to_plaintext(const vector<uint64_t> &coeffs, Plaintext &plain) {
+    uint32_t coeff_count = coeffs.size();
+    plain.resize(coeff_count);
+    util::set_uint_uint(coeffs.data(), coeff_count, plain.pointer());
 }
 
-void* 
-cpp_client_setup(uint64_t len_total_bytes, uint64_t num_db_entries) {
-
-  uint64_t number_of_items = num_db_entries;
-  uint64_t size_per_item = (len_total_bytes/num_db_entries) << 3;
-
-  int n = 2048;
-  int logt = 22;
-  uint64_t plainMod = static_cast<uint64_t> (1) << logt;
-
-  int number_of_plaintexts = ceil (((double)(number_of_items)* size_per_item / n) / logt );
+vector<uint64_t> compute_indices(uint64_t desiredIndex, vector<uint64_t> Nvec) {
+    uint32_t num = Nvec.size();
+    uint64_t product = 1;
 
-  EncryptionParameters parms;
-  parms.set_poly_modulus("1x^" + std::to_string(n) + " + 1");
-  vector<SmallModulus> coeff_mod_array;
-  int logq = 0;
-
-  for (int i = 0; i < 1; ++i)
-  {
-    coeff_mod_array.emplace_back(SmallModulus());
-    coeff_mod_array[i] = small_mods_60bit(i);
-    logq += coeff_mod_array[i].bit_count();
-  }
-
-  parms.set_coeff_modulus(coeff_mod_array);
-  parms.set_plain_modulus(plainMod);
-
-  pirParams pirparms;
-
-  int item_per_plaintext = floor((double)get_power_of_two(plainMod) *n / size_per_item);
-
-  pirparms.d = 2;
-  pirparms.alpha = 1;
-  pirparms.dbc = 8;
-  pirparms.N = number_of_plaintexts;
-  int sqrt_items = ceil(sqrt(number_of_plaintexts));
-
-  int bound1 = number_of_plaintexts / sqrt_items;
-  int bound2 = sqrt_items;
-
-  vector<int> Nvec = { bound1, bound2 };
-  pirparms.Nvec = Nvec;
-
-  PIRClient *client = new PIRClient(parms, pirparms);
-  return (void*) client;
-}
-
-char* 
-cpp_client_generate_query(void* pir, uint64_t chosen_idx, uint64_t* rlen_total_bytes, uint64_t* rnum_logical_entries) {
+    for (uint32_t i = 0; i < num; i++) {
+        product *= Nvec[i];
+    }
 
-  pirQuery query = ((PIRClient*) pir)->generate_query(chosen_idx);
+    uint64_t j = desiredIndex;
+    vector<uint64_t> result;
 
-  string s = serialize_ciphertexts(query);
+    for (uint32_t i = 0; i < num; i++) {
 
-  *rlen_total_bytes = s.length();
-  *rnum_logical_entries = query.size();
+        product /= Nvec[i];
+        uint64_t ji = j / product;
 
-  char *outptr, *result; 
-  result = (char*)calloc(*rlen_total_bytes, sizeof(char));
-  memcpy(result, s.c_str(), s.length());
-  return result;
-}
+        result.push_back(ji);
+        j -= ji * product;
+    }
 
-char*
-cpp_client_generate_galois_keys(void *pir, uint64_t *rlen_total_bytes) {
-  GaloisKeys g = ((PIRClient*) pir)->generate_galois_keys();
-  string s = serialize_galoiskeys(g); //.c_str();
-  char *outptr, *result; 
-  result = (char*)calloc(s.length(), sizeof(char));
-  memcpy(result, s.c_str(), s.length());
-  *rlen_total_bytes = s.length();
-  return result;
+    return result;
 }
 
-  char*
-cpp_client_process_reply(void* pir, char* r, uint64_t len_total_bytes, uint64_t num_logical_entries, uint64_t* rlen_total_bytes)
-{
-  string s(r);
-  vector<Ciphertext> reply = deserialize_ciphertexts(num_logical_entries, s, 32828);
-  Plaintext p = ((PIRClient*) pir)->decode_reply(reply);
-
-  string resp = serialize_plaintext(p);
-  *rlen_total_bytes = resp.length();
-  char *result = (char*)calloc(*rlen_total_bytes, sizeof(char));
-  memcpy(result, resp.c_str(), resp.length());
-  return result;
+inline Ciphertext deserialize_ciphertext(string s) {
+    Ciphertext c;
+    std::stringstream input(std::ios::binary | std::ios::in);
+    input.str(s);
+    c.load(input);
+    return c;
 }
 
-  void 
-cpp_client_free(void *pir)
-{
-  delete (PIRClient*) pir;
+vector<Ciphertext> deserialize_ciphertexts(uint32_t count, string s, uint32_t len_ciphertext) {
+    vector<Ciphertext> c;
+    for (uint32_t i = 0; i < count; i++) {
+        c.push_back(deserialize_ciphertext(s.substr(i * len_ciphertext, len_ciphertext)));
+    }
+    return c;
 }
 
-  void* 
-cpp_server_setup(uint64_t len_total_bytes, char *db, uint64_t num_logical_entries) 
-{
-  uint64_t max_entry_size_bytes = len_total_bytes/num_logical_entries;
-  uint64_t number_of_items = num_logical_entries;
-  uint64_t size_per_item = max_entry_size_bytes << 3; // 288 B. 
-
-  int n = 2048;
-  int logt = 22;
-  uint64_t plainMod = static_cast<uint64_t> (1) << logt;
-
-  int number_of_plaintexts = ceil (((double)(number_of_items)* size_per_item / n) / logt );
-
-  EncryptionParameters parms;
-  parms.set_poly_modulus("1x^" + std::to_string(n) + " + 1");
-  vector<SmallModulus> coeff_mod_array;
-  int logq = 0;
-
-  for (int i = 0; i < 1; ++i)
-  {
-    coeff_mod_array.emplace_back(SmallModulus());
-    coeff_mod_array[i] = small_mods_60bit(i);
-    logq += coeff_mod_array[i].bit_count();
-  }
-
-  parms.set_coeff_modulus(coeff_mod_array);
-  parms.set_plain_modulus(plainMod);
-
-  pirParams pirparms;
-
-  int item_per_plaintext = floor((double)get_power_of_two(plainMod) *n / size_per_item);
-
-  pirparms.d = 2;
-  pirparms.alpha = 1;
-
-  pirparms.dbc = 8;
-
-  pirparms.N = number_of_plaintexts;
-
-  int sqrt_items = ceil(sqrt(number_of_plaintexts));
-
-  int bound1 = number_of_plaintexts / sqrt_items;
-  int bound2 = sqrt_items;
-
-  vector<int> Nvec = { bound1, bound2 };
-  pirparms.Nvec = Nvec;
-
-  PIRServer *server = new PIRServer(parms, pirparms);
-
-  string d(db);
-  vector<Plaintext> items = deserialize_plaintexts(num_logical_entries, d, max_entry_size_bytes);
-  server->set_database(&items);
-  server->preprocess_database();
-  return (void*) server;
+inline string serialize_ciphertext(Ciphertext c) {
+    std::stringstream output(std::ios::binary | std::ios::out);
+    c.save(output);
+    return output.str();
 }
 
-  void
-cpp_server_set_galois_keys(void *pir, char *q, uint64_t len_total_bytes, int client_id)
-{
-  string s(q);
-  GaloisKeys *g = deserialize_galoiskeys(s);
-  ((PIRServer*)pir)->set_galois_key(client_id, *g);
+string serialize_ciphertexts(vector<Ciphertext> c) {
+    string s;
+    for (uint32_t i = 0; i < c.size(); i++) {
+        s.append(serialize_ciphertext(c[i]));
+    }
+    return s;
 }
 
-  char* 
-cpp_server_process_query(void* pir, char* q, uint64_t len_total_bytes, uint64_t num_logical_entries, uint64_t* rlen_total_bytes, uint64_t* rnum_logical_entries, int client_id)
-{
-  string str(q);
-  pirQuery query = deserialize_ciphertexts(num_logical_entries, str, len_total_bytes/num_logical_entries);
-
-  pirReply reply = ((PIRServer*) pir)->generate_reply(query, client_id);
-
-  string s = serialize_ciphertexts(reply);
-
-  *rlen_total_bytes = s.length();
-  *rnum_logical_entries = reply.size();
-
-  char *outptr, *result; 
-  result = (char*)calloc(*rlen_total_bytes, sizeof(char));
-  memcpy(result, s.c_str(), s.length());
-  return result;
+string serialize_galoiskeys(GaloisKeys g) {
+    std::stringstream output(std::ios::binary | std::ios::out);
+    g.save(output);
+    return output.str();
 }
 
-
-  void 
-cpp_server_free(void *pir)
-{
-  delete (PIRServer*) pir;
+GaloisKeys *deserialize_galoiskeys(string s) {
+    GaloisKeys *g = new GaloisKeys();
+    std::stringstream input(std::ios::binary | std::ios::in);
+    input.str(s);
+    g->load(input);
+    return g;
 }

+ 74 - 129
pir.hpp

@@ -1,133 +1,78 @@
-#ifndef SEAL_PIR_H
-#define SEAL_PIR_H
-
-#include <iostream>
-#include <iomanip>
-#include <math.h>
-#include <chrono>
-#include "seal/memorypoolhandle.h"
-#include "seal/encryptor.h"
-#include "seal/decryptor.h"
-#include "seal/encryptionparams.h"
-#include "seal/publickey.h"
-#include "seal/secretkey.h"
-#include "seal/evaluationkeys.h"
-#include "seal/galoiskeys.h"
+#pragma once
+
 #include "seal/seal.h"
-#include "math.h"
-#include "seal/util/polyarith.h"
-#include "seal/util/uintarith.h"
 #include "seal/util/polyarithsmallmod.h"
-
-using namespace std;
-using namespace seal;
-using namespace seal::util;
-typedef std::vector<Plaintext> dataBase;
-typedef std::vector<Ciphertext> pirQuery;
-typedef std::vector<Ciphertext> pirReply;
-
-vector<Ciphertext> deserialize_ciphertexts(int count, string s, int len_ciphertext);
-string serialize_ciphertexts(vector<Ciphertext> c);
-string serialize_plaintext(Plaintext p);
-string serialize_plaintexts(vector<Plaintext> p); 
-
-struct pirParams {
-  int N;
-  int size;
-  int alpha; 
-  int d; 
-  vector<int> Nvec;
-  int expansion_ratio_;
-  int dbc;
-};
-
-void vector_to_plaintext(const std::vector<std::uint64_t> &coeffs, Plaintext &plain); 
-
-
-vector<int> compute_indices(int desiredIndex, vector<int> Nvec); 
-
-class PIRClient {
-  public:
-    PIRClient(const seal::EncryptionParameters & parms, pirParams & pirparms);
-    pirQuery generate_query(int desiredIndex);
-    Plaintext decode_reply(pirReply reply);
-
-    GaloisKeys generate_galois_keys();
-
-    void print_info(Ciphertext &encrypted);
-
-    EncryptionParameters get_new_parms() {
-      return newparms_;
-    }
-
-    pirParams get_pir_parms() {
-      return pirparms_;
-    }
-
-    Ciphertext compose_to_ciphertext(vector<Plaintext> plains);
-
-
-  private:
-    EncryptionParameters parms_;
-    EncryptionParameters newparms_;
-
-    pirParams pirparms_;
-    unique_ptr<Encryptor> encryptor_;
-    unique_ptr<Decryptor> decryptor_;
-    unique_ptr<Evaluator> evaluator_;
-    unique_ptr<KeyGenerator> keygen_;
-};
-
-class PIRServer
-{
-  public: 
-    PIRServer(const seal::EncryptionParameters &parms, const pirParams &pirparams);
-
-    // Reads the database from file.
-    void read_database_from_file(string file); 
-
-    // Preprocess the databse 
-    void preprocess_database();
-
-    void set_database(vector<Plaintext> *db);
-
-    pirReply generate_reply(pirQuery query, int client_id);
-    vector<Ciphertext> expand_query(const Ciphertext & encrypted, int d, const GaloisKeys &galkey);
-
-    void set_galois_key(int client_id, GaloisKeys galkey) {
-      galoisKeys_[client_id] = galkey;
-    }
-
-    void multiply_power_of_X(const Ciphertext &encrypted, Ciphertext & destination, int index);
-
-    void decompose_to_plaintexts_ptr(const Ciphertext & encrypted, uint64_t* plain_ptr);
-
-    vector<Plaintext> decompose_to_plaintexts(const Ciphertext &encrypted);
-
-  private:
-    EncryptionParameters parms_;
-    pirParams pirparams_;
-    map<int, GaloisKeys> galoisKeys_;
-    dataBase *dataBase_ = nullptr;
-    unique_ptr<Evaluator> evaluator_;
-    bool is_db_preprocessed_;
+#include <cassert>
+#include <cmath>
+#include <string>
+#include <vector>
+
+#define CIPHER_SIZE 32828
+
+typedef std::vector<seal::Plaintext> Database;
+typedef std::vector<seal::Ciphertext> PirQuery;
+typedef std::vector<seal::Ciphertext> PirReply;
+
+struct PirParams {
+    std::uint64_t n;                 // number of plaintexts in database
+    std::uint32_t d;                 // number of dimensions for the database (usually 2)
+    std::uint32_t expansion_ratio;   // ratio of plaintext to ciphertext
+    std::uint32_t dbc;               // decomposition bit count (used by relinearization)
+    std::vector<std::uint64_t> nvec; // size of each of the d dimensions
 };
 
-extern "C" {
-  void cpp_buffer_free(char* buf);
-
-  // client-specific methods
-  void* cpp_client_setup(uint64_t len_db_total_bytes, uint64_t num_db_entries);
-  char* cpp_client_generate_galois_keys(void *pir); 
-  char* cpp_client_generate_query(void* pir, uint64_t chosen_idx, uint64_t* rlen_query_total_bytes, uint64_t* rnum_query_slots);
-  char* cpp_client_process_reply(void* pir, char* r, uint64_t len_response_total_bytes, uint64_t num_response_slots, uint64_t* rlen_answer_total_bytes);
-  void cpp_client_update_db_params(void* pir, uint64_t len_db_total_bytes, uint64_t num_db_entries, uint64_t alpha, uint64_t d);
-  void cpp_client_free(void* pir);
-
-  // server-specific methods
-  void* cpp_server_setup(uint64_t len_db_total_bytes, char *db, uint64_t num_db_entries); 
-  char* cpp_server_process_query(void* pir, char* q, uint64_t len_query_total_bytes, uint64_t num_query_slots, uint64_t* rlen_response_total_bytes, uint64_t* rnum_response_slots);
-  void cpp_server_set_galois_keys(void *pir, char *q, uint64_t len_total_bytes, int client_id);
-  void cpp_server_free(void* pir);
-}
-#endif
+void gen_params(std::uint64_t ele_num,  // number of elements (not FV plaintexts) in database
+                std::uint64_t ele_size, // size of each element
+                std::uint32_t N,        // degree of polynomial
+                std::uint32_t logt,     // bits of plaintext coefficient
+                std::uint32_t d,        // dimension of database
+                seal::EncryptionParameters &params, seal::EncryptionParameters &expanded_params,
+                PirParams &pir_params);
+
+void update_params(std::uint64_t ele_num, 
+                   std::uint64_t ele_size,
+                   std::uint32_t d,
+                   const seal::EncryptionParameters &old_params,
+                   seal::EncryptionParameters &expanded_params, PirParams &pir_params);
+
+// returns the plaintext modulus after expansion
+std::uint32_t plainmod_after_expansion(std::uint32_t logt, std::uint32_t N, 
+                                       std::uint32_t d, std::uint64_t ele_num,
+                                       std::uint64_t ele_size);
+
+// returns the number of plaintexts that the database can hold
+std::uint64_t plaintexts_per_db(std::uint32_t logtp, std::uint64_t N, std::uint64_t ele_num,
+                                std::uint64_t ele_size);
+
+// returns the number of elements that a single FV plaintext can hold
+std::uint64_t elements_per_ptxt(std::uint32_t logtp, std::uint64_t N, std::uint64_t ele_size);
+
+// returns the number of coefficients needed to store one element
+std::uint64_t coefficients_per_element(std::uint32_t logtp, std::uint64_t ele_size);
+
+// Converts an array of bytes to a vector of coefficients, each of which is less
+// than the plaintext modulus
+std::vector<std::uint64_t> bytes_to_coeffs(std::uint32_t limit, const std::uint8_t *bytes,
+                                           std::uint64_t size);
+
+// Converts an array of coefficients into an array of bytes
+void coeffs_to_bytes(std::uint32_t logtp, const seal::Plaintext &coeffs, std::uint8_t *output,
+                     std::uint32_t size_out);
+
+// Takes a vector of coefficients and returns the corresponding FV plaintext
+void vector_to_plaintext(const vector<std::uint64_t> &coeffs, seal::Plaintext &plain);
+
+// Since the database has d dimensions, and an item is a particular cell
+// in the d-dimensional hypercube, this function computes the corresponding
+// index for each of the d dimensions
+std::vector<std::uint64_t> compute_indices(std::uint64_t desiredIndex,
+                                           std::vector<std::uint64_t> nvec);
+
+// Serialize and deserialize ciphertexts to send them over the network
+std::vector<seal::Ciphertext> deserialize_ciphertexts(std::uint32_t count, std::string s,
+                                                      std::uint32_t len_ciphertext);
+std::string serialize_ciphertexts(vector<seal::Ciphertext> c);
+
+// Serialize and deserialize galois keys to send them over the network
+std::string serialize_galoiskeys(seal::GaloisKeys g);
+seal::GaloisKeys *deserialize_galoiskeys(std::string s);

+ 193 - 0
pir_client.cpp

@@ -0,0 +1,193 @@
+#include "pir_client.hpp"
+
+using namespace std;
+using namespace seal;
+using namespace seal::util;
+
+PIRClient::PIRClient(const EncryptionParameters &params,
+                     const EncryptionParameters &expanded_params, const PirParams &pir_parms) {
+
+    params_ = params;
+    SEALContext context(params);
+
+    expanded_params_ = expanded_params;
+    SEALContext newcontext(expanded_params);
+
+    pir_params_ = pir_parms;
+
+    keygen_.reset(new KeyGenerator(context));
+    encryptor_.reset(new Encryptor(context, keygen_->public_key()));
+
+    SecretKey secret_key = keygen_->secret_key();
+    secret_key.mutable_hash_block() = expanded_params.hash_block();
+
+    decryptor_.reset(new Decryptor(newcontext, secret_key));
+    evaluator_.reset(new Evaluator(newcontext));
+}
+
+void PIRClient::update_parameters(const EncryptionParameters &expanded_params,
+                                  const PirParams &pir_params) {
+
+    // The only thing that can change is the plaintext modulus and pir_params
+    assert(expanded_params.poly_modulus() == expanded_params_.poly_modulus());
+    assert(expanded_params.coeff_modulus() == expanded_params_.coeff_modulus());
+
+    expanded_params_ = expanded_params;
+    pir_params_ = pir_params;
+    SEALContext newcontext(expanded_params);
+
+    SecretKey secret_key = keygen_->secret_key();
+    secret_key.mutable_hash_block() = expanded_params.hash_block();
+
+    decryptor_.reset(new Decryptor(newcontext, secret_key));
+    evaluator_.reset(new Evaluator(newcontext));
+}
+
+PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
+
+    vector<uint64_t> indices = compute_indices(desiredIndex, pir_params_.nvec);
+    vector<Ciphertext> result;
+
+    for (uint32_t i = 0; i < indices.size(); i++) {
+        Ciphertext dest;
+        encryptor_->encrypt(Plaintext("1x^" + std::to_string(indices[i])), dest);
+        dest.mutable_hash_block() = expanded_params_.hash_block();
+        result.push_back(dest);
+    }
+
+    return result;
+}
+
+uint64_t PIRClient::get_fv_index(uint64_t element_idx, uint64_t ele_size) {
+    uint32_t N = params_.poly_modulus().coeff_count() - 1;
+    uint32_t logtp = ceil(log2(expanded_params_.plain_modulus().value()));
+
+    uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
+    return element_idx / ele_per_ptxt;
+}
+
+uint64_t PIRClient::get_fv_offset(uint64_t element_idx, uint64_t ele_size) {
+    uint32_t N = params_.poly_modulus().coeff_count() - 1;
+    uint32_t logtp = ceil(log2(expanded_params_.plain_modulus().value()));
+
+    uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
+    return element_idx % ele_per_ptxt;
+}
+
+Plaintext PIRClient::decode_reply(PirReply reply) {
+    uint32_t exp_ratio = pir_params_.expansion_ratio;
+    uint32_t recursion_level = pir_params_.d;
+
+    vector<Ciphertext> temp = reply;
+
+    for (uint32_t i = 0; i < recursion_level; i++) {
+
+        vector<Ciphertext> newtemp;
+        vector<Plaintext> tempplain;
+
+        for (uint32_t j = 0; j < temp.size(); j++) {
+            Plaintext ptxt;
+            decryptor_->decrypt(temp[j], ptxt);
+            tempplain.push_back(ptxt);
+
+#ifdef DEBUG
+            cout << "recursion level : " << i << " noise budget :  ";
+            cout << decryptor_->invariant_noise_budget(temp[j]) << endl;
+#endif
+
+            if ((j + 1) % exp_ratio == 0 && j > 0) {
+                // Combine into one ciphertext.
+                Ciphertext combined = compose_to_ciphertext(tempplain);
+                newtemp.push_back(combined);
+            }
+        }
+
+        if (i == recursion_level - 1) {
+            assert(temp.size() == 1);
+            return tempplain[0];
+        } else {
+            tempplain.clear();
+            temp = newtemp;
+        }
+    }
+
+    // This should never be called
+    assert(0);
+    Plaintext fail;
+    return fail;
+}
+
+GaloisKeys PIRClient::generate_galois_keys() {
+    // Generate the Galois keys needed for coeff_select.
+    vector<uint64_t> galois_elts;
+    int N = params_.poly_modulus().coeff_count() - 1;
+    int logN = get_power_of_two(N);
+
+    for (int i = 0; i < logN; i++) {
+        galois_elts.push_back((N + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
+#ifdef DEBUG
+        cout << galois_elts.back() << ", ";
+#endif
+    }
+
+    GaloisKeys galois_keys;
+    keygen_->generate_galois_keys(pir_params_.dbc, galois_elts, galois_keys);
+    return galois_keys;
+}
+
+Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
+    int encrypted_count = 2;
+    int coeff_count = expanded_params_.poly_modulus().coeff_count();
+    int coeff_mod_count = expanded_params_.coeff_modulus().size();
+    uint64_t plainMod = expanded_params_.plain_modulus().value();
+
+    Ciphertext result;
+    result.reserve(expanded_params_, encrypted_count);
+
+    // A triple for loop. Going over polys, moduli, and decomposed index.
+    for (int i = 0; i < encrypted_count; i++) {
+        uint64_t *encrypted_pointer = result.mutable_pointer(i);
+
+        for (int j = 0; j < coeff_mod_count; j++) {
+            // populate one poly at a time.
+            // create a polynomial to store the current decomposition value
+            // which will be copied into the array to populate it at the current
+            // index.
+            double logqj = log2(expanded_params_.coeff_modulus()[j].value());
+            int expansion_ratio = ceil(logqj / log2(plainMod));
+
+            // cout << "expansion ratio = " << expansion_ratio << endl;
+            uint64_t cur = 1;
+
+            for (int k = 0; k < expansion_ratio; k++) {
+
+                // Compose here
+                const uint64_t *plain_coeff =
+                    plains[k + j * (expansion_ratio) + i * (coeff_mod_count * expansion_ratio)]
+                        .pointer();
+
+                for (int m = 0; m < coeff_count - 1; m++) {
+                    if (k == 0) {
+                        *(encrypted_pointer + m + j * coeff_count) = *(plain_coeff + m) * cur;
+                    } else {
+                        *(encrypted_pointer + m + j * coeff_count) += *(plain_coeff + m) * cur;
+                    }
+                }
+
+                *(encrypted_pointer + coeff_count - 1 + j * coeff_count) = 0;
+                cur *= plainMod;
+            }
+
+            // XXX: Reduction modulo qj. This is needed?
+            /*
+            for (int m = 0; m < coeff_count; m++) {
+                *(encrypted_pointer + m + j * coeff_count) %=
+                    expanded_params_.coeff_modulus()[j].value();
+            }
+            */
+        }
+    }
+
+    result.mutable_hash_block() = expanded_params_.hash_block();
+    return result;
+}

+ 35 - 0
pir_client.hpp

@@ -0,0 +1,35 @@
+#pragma once
+
+#include "pir.hpp"
+#include <memory>
+
+class PIRClient {
+  public:
+    PIRClient(const seal::EncryptionParameters &parms,
+              const seal::EncryptionParameters &expandedParams, const PirParams &pirparms);
+
+    void update_parameters(const seal::EncryptionParameters &expandedParams,
+                           const PirParams &pirparms);
+
+    PirQuery generate_query(std::uint64_t desiredIndex);
+    seal::Plaintext decode_reply(PirReply reply);
+
+    seal::GaloisKeys generate_galois_keys();
+
+    // Index and offset of an element in an FV plaintext
+    uint64_t get_fv_index(uint64_t element_idx, uint64_t ele_size);
+    uint64_t get_fv_offset(uint64_t element_idx, uint64_t ele_size);
+
+  private:
+    // Should we store a decryptor and an encryptor?
+    seal::EncryptionParameters params_;
+    seal::EncryptionParameters expanded_params_;
+    PirParams pir_params_;
+
+    std::unique_ptr<seal::Encryptor> encryptor_;
+    std::unique_ptr<seal::Decryptor> decryptor_;
+    std::unique_ptr<seal::Evaluator> evaluator_;
+    std::unique_ptr<seal::KeyGenerator> keygen_;
+
+    seal::Ciphertext compose_to_ciphertext(std::vector<seal::Plaintext> plains);
+};

+ 398 - 0
pir_server.cpp

@@ -0,0 +1,398 @@
+#include "pir_server.hpp"
+
+using namespace std;
+using namespace seal;
+using namespace seal::util;
+
+PIRServer::PIRServer(const EncryptionParameters &expanded_params, const PirParams &pir_params) {
+    expanded_params_ = expanded_params;
+    pir_params_ = pir_params;
+    SEALContext context(expanded_params);
+    evaluator_.reset(new Evaluator(context));
+    is_db_preprocessed_ = false;
+}
+
+PIRServer::~PIRServer() {
+    delete db_;
+}
+
+void PIRServer::update_parameters(const EncryptionParameters &expanded_params,
+                                  const PirParams &pir_params) {
+
+    // The only thing that can change is the plaintext modulus and pir_params
+    assert(expanded_params.poly_modulus() == expanded_params_.poly_modulus());
+    assert(expanded_params.coeff_modulus() == expanded_params_.coeff_modulus());
+
+    expanded_params_ = expanded_params;
+    pir_params_ = pir_params;
+    SEALContext context(expanded_params);
+    evaluator_.reset(new Evaluator(context));
+    is_db_preprocessed_ = false;
+
+    // Update all the galois keys
+    for (std::pair<const int, GaloisKeys> &key : galoisKeys_) {
+        key.second.mutable_hash_block() = expanded_params_.hash_block();
+    }
+}
+
+void PIRServer::preprocess_database() {
+    if (!is_db_preprocessed_) {
+
+        for (uint32_t i = 0; i < db_->size(); i++) {
+            evaluator_->transform_to_ntt(db_->operator[](i));
+        }
+
+        is_db_preprocessed_ = true;
+    }
+}
+
+// Server takes over ownership of db and will free it when it exits
+void PIRServer::set_database(vector<Plaintext> *db) {
+    if (db == nullptr) {
+        throw invalid_argument("db cannot be null");
+    }
+
+    db_ = db;
+    is_db_preprocessed_ = false;
+}
+
+void PIRServer::set_database(const uint8_t *bytes, uint64_t ele_num, uint64_t ele_size) {
+
+    uint32_t logtp = ceil(log2(expanded_params_.plain_modulus().value()));
+    uint32_t N = expanded_params_.poly_modulus().coeff_count() - 1;
+
+    // number of FV plaintexts needed to represent all elements
+    uint64_t total = plaintexts_per_db(logtp, N, ele_num, ele_size);
+
+    // number of FV plaintexts needed to create the d-dimensional matrix
+    uint64_t prod = 1;
+    for (uint32_t i = 0; i < pir_params_.nvec.size(); i++) {
+        prod *= pir_params_.nvec[i];
+    }
+    uint64_t matrix_plaintexts = prod;
+    assert(total <= matrix_plaintexts);
+
+    vector<Plaintext> *result = new vector<Plaintext>();
+    result->reserve(matrix_plaintexts);
+
+    uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
+    uint64_t bytes_per_ptxt = ele_per_ptxt * ele_size;
+
+    uint64_t db_size = ele_num * ele_size;
+
+    uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logtp, ele_size);
+    assert(coeff_per_ptxt <= N);
+
+    uint32_t offset = 0;
+
+    for (uint64_t i = 0; i < total; i++) {
+
+        uint64_t process_bytes = 0;
+
+        if (db_size <= offset) {
+            break;
+        } else if (db_size < offset + bytes_per_ptxt) {
+            process_bytes = db_size - offset;
+        } else {
+            process_bytes = bytes_per_ptxt;
+        }
+
+        // Get the coefficients of the elements that will be packed in plaintext i
+        vector<uint64_t> coefficients = bytes_to_coeffs(logtp, bytes + offset, process_bytes);
+        offset += process_bytes;
+
+        uint64_t used = coefficients.size();
+
+        assert(used <= coeff_per_ptxt);
+
+        // Pad the rest with 1s
+        for (uint64_t j = 0; j < (N - used); j++) {
+            coefficients.push_back(1);
+        }
+
+        Plaintext plain;
+        vector_to_plaintext(coefficients, plain);
+        result->push_back(plain);
+    }
+
+    // Add padding to make database a matrix
+    uint64_t current_plaintexts = result->size();
+    assert(current_plaintexts <= total);
+
+#ifdef DEBUG
+    cout << "adding: " << matrix_plaintexts - current_plaintexts
+         << " FV plaintexts of padding (equivalent to: "
+         << (matrix_plaintexts - current_plaintexts) * elements_per_ptxt(logtp, N, ele_size)
+         << " elements)" << endl;
+#endif
+
+    vector<uint64_t> padding(N, 1);
+
+    for (uint64_t i = 0; i < (matrix_plaintexts - current_plaintexts); i++) {
+        Plaintext plain;
+        vector_to_plaintext(padding, plain);
+        result->push_back(plain);
+    }
+
+    set_database(result);
+}
+
+void PIRServer::set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey) {
+    galkey.mutable_hash_block() = expanded_params_.hash_block();
+    galoisKeys_[client_id] = galkey;
+}
+
+PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
+
+    vector<uint64_t> nvec = pir_params_.nvec;
+    uint64_t product = 1;
+
+    for (uint32_t i = 0; i < nvec.size(); i++) {
+        product *= nvec[i];
+    }
+
+    int coeff_count = expanded_params_.poly_modulus().coeff_count();
+
+    vector<Plaintext> *cur = db_;
+    vector<Plaintext> intermediate_plain; // decompose....
+
+    auto my_pool = MemoryPoolHandle::New();
+
+    for (uint32_t i = 0; i < nvec.size(); i++) {
+        uint64_t n_i = nvec[i];
+        vector<Ciphertext> expanded_query = expand_query(query[i], n_i, client_id);
+
+        // Transform expanded query to NTT, and ...
+        for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
+            evaluator_->transform_to_ntt(expanded_query[jj]);
+        }
+
+        // Transform plaintext to NTT. If database is pre-processed, can skip
+        if ((!is_db_preprocessed_) || i > 0) {
+            for (uint32_t jj = 0; jj < cur->size(); jj++) {
+                evaluator_->transform_to_ntt((*cur)[jj]);
+            }
+        }
+
+        product /= n_i;
+
+        vector<Ciphertext> intermediate(product);
+        Ciphertext temp;
+
+        for (uint64_t k = 0; k < product; k++) {
+            evaluator_->multiply_plain_ntt(expanded_query[0], (*cur)[k], intermediate[k]);
+
+            for (uint64_t j = 1; j < n_i; j++) {
+                evaluator_->multiply_plain_ntt(expanded_query[j], (*cur)[k + j * product], temp);
+                evaluator_->add(intermediate[k],
+                                temp); // Adds to first component.
+            }
+        }
+
+        for (uint32_t jj = 0; jj < intermediate.size(); jj++) {
+            evaluator_->transform_from_ntt(intermediate[jj]);
+        }
+
+        if (i == nvec.size() - 1) {
+            return intermediate;
+        } else {
+            intermediate_plain.clear();
+            intermediate_plain.reserve(pir_params_.expansion_ratio * product);
+            cur = &intermediate_plain;
+
+            util::Pointer tempplain_ptr(allocate_zero_poly(
+                pir_params_.expansion_ratio * product, coeff_count, my_pool));
+
+            for (uint64_t rr = 0; rr < product; rr++) {
+
+                decompose_to_plaintexts_ptr(intermediate[rr],
+                                            tempplain_ptr.get() +
+                                                rr * pir_params_.expansion_ratio * coeff_count);
+
+                for (uint32_t jj = 0; jj < pir_params_.expansion_ratio; jj++) {
+                    int offset = rr * pir_params_.expansion_ratio * coeff_count + jj * coeff_count;
+                    intermediate_plain.emplace_back(coeff_count, tempplain_ptr.get() + offset);
+                }
+            }
+
+            product *= pir_params_.expansion_ratio; // multiply by expansion rate.
+        }
+    }
+
+    // This should never get here
+    assert(0);
+    vector<Ciphertext> fail(1);
+    return fail;
+}
+
+inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, uint32_t m,
+                                           uint32_t client_id) {
+
+#ifdef DEBUG
+    uint64_t plainMod = expanded_params_.plain_modulus().value();
+    cout << "PIRServer side plain modulus = " << plainMod << endl;
+#endif
+
+    GaloisKeys &galkey = galoisKeys_[client_id];
+
+    // Assume that m is a power of 2. If not, round it to the next power of 2.
+    uint32_t logm = ceil(log2(m));
+    Plaintext two("2");
+
+    vector<int> galois_elts;
+    int n = expanded_params_.poly_modulus().coeff_count() - 1;
+
+    for (uint32_t i = 0; i < logm; i++) {
+        galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
+    }
+
+    vector<Ciphertext> temp;
+    temp.push_back(encrypted);
+    Ciphertext tempctxt;
+    Ciphertext tempctxt_rotated;
+    Ciphertext tempctxt_shifted;
+    Ciphertext tempctxt_rotatedshifted;
+
+    for (uint32_t i = 0; i < logm - 1; i++) {
+        vector<Ciphertext> newtemp(temp.size() << 1);
+        // temp[a] = (j0 = a (mod 2**i) ? ) : Enc(x^{j0 - a}) else Enc(0).  With
+        // some scaling....
+        int index_raw = (n << 1) - (1 << i);
+        int index = (index_raw * galois_elts[i]) % (n << 1);
+
+        for (uint32_t a = 0; a < temp.size(); a++) {
+
+            evaluator_->apply_galois(temp[a], galois_elts[i], galkey, tempctxt_rotated);
+            evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
+            multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
+            multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
+            // Enc(2^i x^j) if j = 0 (mod 2**i).
+            evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a + temp.size()]);
+        }
+        temp = newtemp;
+    }
+
+    // Last step of the loop
+    vector<Ciphertext> newtemp(temp.size() << 1);
+    int index_raw = (n << 1) - (1 << (logm - 1));
+    int index = (index_raw * galois_elts[logm - 1]) % (n << 1);
+    for (uint32_t a = 0; a < temp.size(); a++) {
+        if (a >= (m - (1 << (logm - 1)))) {                       // corner case.
+            evaluator_->multiply_plain(temp[a], two, newtemp[a]); // plain multiplication by 2.
+        } else {
+            evaluator_->apply_galois(temp[a], galois_elts[logm - 1], galkey, tempctxt_rotated);
+            evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
+            multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
+            multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
+            evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a + temp.size()]);
+        }
+    }
+
+    vector<Ciphertext>::const_iterator first = newtemp.begin();
+    vector<Ciphertext>::const_iterator last = newtemp.begin() + m;
+    vector<Ciphertext> newVec(first, last);
+    return newVec;
+}
+
+inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Ciphertext &destination,
+                                    uint32_t index) {
+
+    int coeff_mod_count = expanded_params_.coeff_modulus().size();
+    int coeff_count = expanded_params_.poly_modulus().coeff_count();
+    int encrypted_count = encrypted.size();
+
+    // First copy over.
+    destination = encrypted;
+
+    // Prepare for destination
+    // Multiply X^index for each ciphertext polynomial
+    for (int i = 0; i < encrypted_count; i++) {
+        for (int j = 0; j < coeff_mod_count; j++) {
+            negacyclic_shift_poly_coeffmod(encrypted.pointer(i) + (j * coeff_count),
+                                           coeff_count - 1, index,
+                                           expanded_params_.coeff_modulus()[j],
+                                           destination.mutable_pointer(i) + (j * coeff_count));
+        }
+    }
+}
+
+inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, uint64_t *plain_ptr) {
+
+    vector<Plaintext> result;
+    int coeff_count = expanded_params_.poly_modulus().coeff_count();
+    int coeff_mod_count = expanded_params_.coeff_modulus().size();
+    int encrypted_count = encrypted.size();
+
+    // Generate powers of t.
+    uint64_t plainModMinusOne = expanded_params_.plain_modulus().value() - 1;
+    int exp = ceil(log2(plainModMinusOne + 1));
+
+    // A triple for loop. Going over polys, moduli, and decomposed index.
+
+    for (int i = 0; i < encrypted_count; i++) {
+        const uint64_t *encrypted_pointer = encrypted.pointer(i);
+        for (int j = 0; j < coeff_mod_count; j++) {
+            // populate one poly at a time.
+            // create a polynomial to store the current decomposition value
+            // which will be copied into the array to populate it at the current
+            // index.
+            int logqj = log2(expanded_params_.coeff_modulus()[j].value());
+            int expansion_ratio = ceil(logqj + exp - 1) / exp;
+
+            // cout << "expansion ratio = " << expansion_ratio << endl;
+            uint64_t curexp = 0;
+            for (int k = 0; k < expansion_ratio; k++) {
+                // Decompose here
+                for (int m = 0; m < coeff_count; m++) {
+                    *plain_ptr =
+                        (*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & plainModMinusOne;
+                    plain_ptr++;
+                }
+                curexp += exp;
+            }
+        }
+    }
+}
+
+vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted) {
+    vector<Plaintext> result;
+    int coeff_count = expanded_params_.poly_modulus().coeff_count();
+    int coeff_mod_count = expanded_params_.coeff_modulus().size();
+    int plain_bit_count = expanded_params_.plain_modulus().bit_count();
+    int encrypted_count = encrypted.size();
+
+    // Generate powers of t.
+    uint64_t plainMod = expanded_params_.plain_modulus().value();
+
+    // A triple for loop. Going over polys, moduli, and decomposed index.
+    for (int i = 0; i < encrypted_count; i++) {
+        const uint64_t *encrypted_pointer = encrypted.pointer(i);
+        for (int j = 0; j < coeff_mod_count; j++) {
+            // populate one poly at a time.
+            // create a polynomial to store the current decomposition value
+            // which will be copied into the array to populate it at the current
+            // index.
+            int logqj = log2(expanded_params_.coeff_modulus()[j].value());
+            int expansion_ratio = ceil(logqj / log2(plainMod));
+
+            // cout << "expansion ratio = " << expansion_ratio << endl;
+            uint64_t cur = 1;
+            for (int k = 0; k < expansion_ratio; k++) {
+                // Decompose here
+                BigPoly temp;
+                temp.resize(coeff_count, plain_bit_count);
+                temp.set_zero();
+                uint64_t *plain_coeff = temp.pointer();
+                for (int m = 0; m < coeff_count; m++) {
+                    *(plain_coeff + m) =
+                        (*(encrypted_pointer + m + (j * coeff_count)) / cur) % plainMod;
+                }
+
+                result.push_back(Plaintext(temp));
+                cur *= plainMod;
+            }
+        }
+    }
+
+    return result;
+}

+ 41 - 0
pir_server.hpp

@@ -0,0 +1,41 @@
+#pragma once
+
+#include "pir.hpp"
+#include <map>
+#include <memory>
+#include <vector>
+
+class PIRServer {
+  public:
+    PIRServer(const seal::EncryptionParameters &expanded_params, const PirParams &pir_params);
+    ~PIRServer();
+
+    void update_parameters(const seal::EncryptionParameters &expanded_params,
+                           const PirParams &pir_params);
+
+    // NOTE: server takes over ownership of db and frees it when it exits.
+    // Caller cannot free db
+    void set_database(std::vector<seal::Plaintext> *db);
+    void set_database(const std::uint8_t *bytes, std::uint64_t ele_num, std::uint64_t ele_size);
+    void preprocess_database();
+
+    std::vector<seal::Ciphertext> expand_query(const seal::Ciphertext &encrypted, std::uint32_t m,
+                                               uint32_t client_id);
+
+    PirReply generate_reply(PirQuery query, std::uint32_t client_id);
+
+    void set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey);
+
+  private:
+    seal::EncryptionParameters expanded_params_; // SEAL parameters
+    PirParams pir_params_;                       // PIR parameters
+    Database *db_ = nullptr;
+    bool is_db_preprocessed_;
+    std::map<int, seal::GaloisKeys> galoisKeys_;
+    std::unique_ptr<seal::Evaluator> evaluator_;
+
+    void decompose_to_plaintexts_ptr(const seal::Ciphertext &encrypted, std::uint64_t *plain_ptr);
+    std::vector<seal::Plaintext> decompose_to_plaintexts(const seal::Ciphertext &encrypted);
+    void multiply_power_of_X(const seal::Ciphertext &encrypted, seal::Ciphertext &destination,
+                             std::uint32_t index);
+};