Browse Source

Merge pull request #12 from abeams/master

Updated SealPIR to work with SEAL v3.6
Kim Laine 2 years ago
parent
commit
12f6c0ff34
25 changed files with 2337 additions and 1275 deletions
  1. 5 0
      .gitignore
  2. 4 15
      CMakeLists.txt
  3. 65 12
      README.md
  4. 0 124
      main.cpp
  5. 0 272
      pir.cpp
  6. 0 74
      pir.hpp
  7. 0 248
      pir_client.cpp
  8. 0 41
      pir_client.hpp
  9. 0 451
      pir_server.cpp
  10. 0 38
      pir_server.hpp
  11. 8 0
      src/CMakeLists.txt
  12. 185 0
      src/main.cpp
  13. 386 0
      src/pir.cpp
  14. 100 0
      src/pir.hpp
  15. 289 0
      src/pir_client.cpp
  16. 58 0
      src/pir_client.hpp
  17. 442 0
      src/pir_server.cpp
  18. 55 0
      src/pir_server.hpp
  19. 25 0
      test/CMakeLists.txt
  20. 47 0
      test/coefficient_conversion_test.cpp
  21. 79 0
      test/decomposition_test.cpp
  22. 107 0
      test/expand_test.cpp
  23. 167 0
      test/query_test.cpp
  24. 172 0
      test/replace_test.cpp
  25. 143 0
      test/simple_query_test.cpp

+ 5 - 0
.gitignore

@@ -329,11 +329,16 @@ ASALocalRun/
 # MFractors (Xamarin productivity tool) working folder 
 .mfractor/
 
+*.vscode
+
 # CMake files.
+*/CMakeFiles/
 /CMakeCache.txt
 /CMakeFiles/
 /Makefile
 /cmake_install.cmake
+*/Makefile
+*/*.cmake
 
 # Built targets.
 libsealpir.a

+ 4 - 15
CMakeLists.txt

@@ -1,21 +1,10 @@
 cmake_minimum_required(VERSION 3.10)
-
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
+project(SealPIR VERSION 2.2 LANGUAGES CXX)
 
-project(SealPIR VERSION 2.1 LANGUAGES CXX)
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/bin)
+add_subdirectory(src)
 
-add_executable(main 
-	main.cpp
-)
-
-add_library(sealpir STATIC
-  pir.cpp
-  pir_client.cpp
-  pir_server.cpp
-)
-
-find_package(SEAL 3.2.0 EXACT REQUIRED)
-
-target_link_libraries(main sealpir SEAL::seal)
+enable_testing()
+add_subdirectory(test)

+ 65 - 12
README.md

@@ -2,29 +2,82 @@
 
 SealPIR is a research library and should not be used in production systems. 
 SealPIR allows a client to download an element from a database stored by a server without
-revealing which element was downloaded. SealPIR was introduced at 
-the Symposium on Security and Privacy (Oakland) in 2018. You can find
+revealing to the server which element was downloaded. SealPIR was introduced at 
+the IEEE Symposium on Security and Privacy (Oakland) in 2018. You can find
 a copy of the paper [here](https://eprint.iacr.org/2017/1142.pdf).
 
 # Compiling SEAL
 
-SealPIR depends on [Microsoft SEAL version 3.2.0](https://github.com/microsoft/SEAL/tree/3.2.0).
-Install SEAL before compiling SealPIR.
+SealPIR depends on [Microsoft SEAL version 4.0.0](https://github.com/microsoft/SEAL/tree/4.0.0).
+
+Download and install SEAL (follow the instructions in the above link) before before compiling SealPIR.
 
 # Compiling SealPIR
 
-Once Microsoft SEAL 3.2.0 is installed, to build SealPIR simply run:
+Once Microsoft SEAL 4.0.0 is installed, to build SealPIR simply run:
+
+```
+cmake .
+make
+```
+
+This should produce a binary file ``bin/main``.
+
+# Testing SealPIR
+
+Once you have compiled SealPIR, you can run our battery of unit tests with:
 
-	cmake .
-	make
-	
-This should produce a binary file ``bin/sealpir``.
+```
+ctest .
+```
 
 # Using SealPIR
 
-Take a look at the example in main.cpp for how to use SealPIR. 
-Note: the parameter "d" stands for recursion levels, and for the current configuration, the 
-server-to-client reply has size (pow(10, d-1) * 32) KB. Therefore we recommend using d <= 3.  
+Take a look at the example in `src/main.cpp` for how to use SealPIR. 
+You can also look at the tests in the `test` folder.
+
+
+## Default parameters
+
+*N* indicates the degree of the BFV polynomials.  Default is 4096.
+
+*t* indicates the plaintext modulus, but we specify *log t* instead. Default is 20.
+
+Each BFV ciphertext can encrypt log t * N, which is approximately 10 KB bits of information.
+
+This means that if your database has, say, 1 KB elements, then you can pack 10 
+such elements into a single BFV plaintext. 
+On the other hand, if your database has, say, 20 KB elements, then you will 
+need two BFV plaintexts to represent each of your elements.
+
+*d* represents the recursion level.  When the number of BFV plaintexts needed
+to represent your database (see above for how to map the number of database
+elements of a given size to the number of BFV plaintexts) is smaller than N,
+then setting *d = 1* minimizes communication costs. However, you can also set
+*d = 2* which doubles the size of the query and increases the size of the
+response by roughly a factor of 4, but in some cases might reduce computational
+costs a little bit (because the oblivious expansion procedure is cheaper). 
+
+When the number of BFV plaintexts is much greater than N, then *d = 2*
+minimizes communication costs. You can read the paper to understand how *d*
+affects communication costs. In general, the query consists of *d* BFV
+ciphertexts and can index a database with *N^d* BFV plaintexts;  the response
+consists of *F^(d-1)* ciphertexts, where *F* is the ciphertext
+expansion factor. In the current implementation which uses recursive
+modulo swithcing, *F* is around 4. We have not identified any setting where
+*d > 2* is beneficial.
+
+
+# Changelog
+
+This implementation of SealPIR uses the latest version of SEAL, fixes several bugs,
+and provides better serialization/deserialization of queries and responses,
+and a more streamlined code base.
+
+If you wish to use the **original** version of SealPIR which corresponds to the
+numbers reported in the paper and which uses an older version  of SEAL, check
+out [this](https://github.com/microsoft/SealPIR/tree/ccf86c50fd3291) branch in
+the git repository.
 
 # Contributing
 

+ 0 - 124
main.cpp

@@ -1,124 +0,0 @@
-#include "pir.hpp"
-#include "pir_client.hpp"
-#include "pir_server.hpp"
-#include <seal/seal.h>
-#include <chrono>
-#include <memory>
-#include <random>
-#include <cstdint>
-#include <cstddef>
-
-using namespace std::chrono;
-using namespace std;
-using namespace seal;
-
-int main(int argc, char *argv[]) {
-
-    uint64_t number_of_items = 1 << 12;
-    uint64_t size_per_item = 288; // in bytes
-    uint32_t N = 2048;
-
-    // Recommended values: (logt, d) = (12, 2) or (8, 1). 
-    uint32_t logt = 12; 
-    uint32_t d = 2;
-
-    EncryptionParameters params(scheme_type::BFV);
-    PirParams pir_params;
-
-    // Generates all parameters
-    cout << "Main: Generating all parameters" << endl;
-    gen_params(number_of_items, size_per_item, N, logt, d, params, pir_params);
-
-    cout << "Main: Initializing the database (this may take some time) ..." << endl;
-
-    // Create test database
-    auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
-
-    // Copy of the database. We use this at the end to make sure we retrieved
-    // the correct element.
-    auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
-
-    random_device rd;
-    for (uint64_t i = 0; i < number_of_items; i++) {
-        for (uint64_t j = 0; j < size_per_item; j++) {
-            auto val = rd() % 256;
-            db.get()[(i * size_per_item) + j] = val;
-            db_copy.get()[(i * size_per_item) + j] = val;
-        }
-    }
-
-    // Initialize PIR Server
-    cout << "Main: Initializing server and client" << endl;
-    PIRServer server(params, pir_params);
-
-    // Initialize PIR client....
-    PIRClient client(params, pir_params);
-    GaloisKeys galois_keys = client.generate_galois_keys();
-
-    // Set galois key for client with id 0
-    cout << "Main: Setting Galois keys...";
-    server.set_galois_key(0, galois_keys);
-
-    // Measure database setup
-    auto time_pre_s = high_resolution_clock::now();
-    server.set_database(move(db), number_of_items, size_per_item);
-    server.preprocess_database();
-    cout << "Main: database pre processed " << endl;
-    auto time_pre_e = high_resolution_clock::now();
-    auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
-
-    // Choose an index of an element in the DB
-    uint64_t ele_index = rd() % number_of_items; // element in DB at random position
-    uint64_t index = client.get_fv_index(ele_index, size_per_item);   // index of FV plaintext
-    uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext
-    cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
-    cout << "Main: FV index = " << index << ", FV offset = " << offset << endl; 
-
-    // Measure query generation
-    auto time_query_s = high_resolution_clock::now();
-    PirQuery query = client.generate_query(index);
-    auto time_query_e = high_resolution_clock::now();
-    auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
-    cout << "Main: query generated" << endl;
-
-    //To marshall query to send over the network, you can use serialize/deserialize:
-    //std::string query_ser = serialize_query(query);
-    //PirQuery query2 = deserialize_query(d, 1, query_ser, CIPHER_SIZE);
-
-    // Measure query processing (including expansion)
-    auto time_server_s = high_resolution_clock::now();
-    PirReply reply = server.generate_reply(query, 0);
-    auto time_server_e = high_resolution_clock::now();
-    auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
-
-    // Measure response extraction
-    auto time_decode_s = chrono::high_resolution_clock::now();
-    Plaintext result = client.decode_reply(reply);
-    auto time_decode_e = chrono::high_resolution_clock::now();
-    auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
-
-    // Convert from FV plaintext (polynomial) to database element at the client
-    vector<uint8_t> elems(N * logt / 8);
-    coeffs_to_bytes(logt, result, elems.data(), (N * logt) / 8);
-
-    // Check that we retrieved the correct element
-    for (uint32_t i = 0; i < size_per_item; i++) {
-        if (elems[(offset * size_per_item) + i] != db_copy.get()[(ele_index * size_per_item) + i]) {
-            cout << "Main: elems " << (int)elems[(offset * size_per_item) + i] << ", db "
-                 << (int) db_copy.get()[(ele_index * size_per_item) + i] << endl;
-            cout << "Main: PIR result wrong!" << endl;
-            return -1;
-        }
-    }
-
-    // Output results
-    cout << "Main: PIR result correct!" << endl;
-    cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
-    cout << "Main: PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
-    cout << "Main: PIRServer reply generation time: " << time_server_us / 1000 << " ms"
-         << endl;
-    cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
-    cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
-
-    return 0;
-}

+ 0 - 272
pir.cpp

@@ -1,272 +0,0 @@
-#include "pir.hpp"
-
-using namespace std;
-using namespace seal;
-using namespace seal::util;
-
-vector<uint64_t> get_dimensions(uint64_t plaintext_num, uint32_t d) {
-
-    assert(d > 0);
-    assert(plaintext_num > 0);
-
-    vector<uint64_t> dimensions(d);
-
-    for (uint32_t i = 0; i < d; i++) {
-        dimensions[i] = std::max((uint32_t) 2, (uint32_t) floor(pow(plaintext_num, 1.0/d)));
-    }
-
-    uint32_t product = 1;
-    uint32_t j = 0;
-
-    // if plaintext_num is not a d-power
-    if ((double) dimensions[0] != pow(plaintext_num, 1.0 / d)) {
-        while  (product < plaintext_num && j < d) {
-            product = 1;
-            dimensions[j++]++;
-            for (uint32_t i = 0; i < d; i++) {
-                product *= dimensions[i];
-            }
-        }
-    }
-
-    return dimensions;
-}
-
-void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
-                uint32_t d, EncryptionParameters &params,
-                PirParams &pir_params) {
-    
-    // Determine the maximum size of each dimension
-
-    // plain modulus = a power of 2 plus 1
-    uint64_t plain_mod = (static_cast<uint64_t>(1) << logt) + 1;
-    uint64_t plaintext_num = plaintexts_per_db(logt, N, ele_num, ele_size);
-
-#ifdef DEBUG
-    cout << "log(plain mod) before expand = " << logt << endl;
-    cout << "number of FV plaintexts = " << plaintext_num << endl;
-#endif
-
-    vector<SmallModulus> coeff_mod_array;
-    uint32_t logq = 0;
-
-    for (uint32_t i = 0; i < 1; i++) {
-        coeff_mod_array.emplace_back(SmallModulus());
-        coeff_mod_array[i] = DefaultParams::small_mods_60bit(i);
-        logq += coeff_mod_array[i].bit_count();
-    }
-
-    params.set_poly_modulus_degree(N);
-    params.set_coeff_modulus(coeff_mod_array);
-    params.set_plain_modulus(plain_mod);
-
-    vector<uint64_t> nvec = get_dimensions(plaintext_num, d);
-
-    uint32_t expansion_ratio = 0;
-    for (uint32_t i = 0; i < params.coeff_modulus().size(); ++i) {
-        double logqi = log2(params.coeff_modulus()[i].value());
-        cout << "PIR: logqi = " << logqi << endl; 
-        expansion_ratio += ceil(logqi / logt);
-    }
-
-    pir_params.d = d;
-    pir_params.dbc = 6;
-    pir_params.n = plaintext_num;
-    pir_params.nvec = nvec;
-    pir_params.expansion_ratio = expansion_ratio << 1; // because one ciphertext = two polys
-}
-
-
-uint32_t plainmod_after_expansion(uint32_t logt, uint32_t N, uint32_t d, 
-        uint64_t ele_num, uint64_t ele_size) {
-
-    // Goal: find max logtp such that logtp + ceil(log(ceil(d_root(n)))) <= logt
-    // where n = ceil(ele_num / floor(N*logtp / ele_size *8))
-    for (uint32_t logtp = logt; logtp >= 2; logtp--) {
-
-        uint64_t n = plaintexts_per_db(logtp, N, ele_num, ele_size);
-
-        if (logtp == logt && n == 1) {
-            return logtp - 1;
-        }
-
-        if ((double)logtp + ceil(log2(ceil(pow(n, 1.0/(double)d)))) <= logt) {
-            return logtp;
-        }
-    }
-
-    assert(0); // this should never happen
-    return logt;
-}
-
-// Number of coefficients needed to represent a database element
-uint64_t coefficients_per_element(uint32_t logtp, uint64_t ele_size) {
-    return ceil(8 * ele_size / (double)logtp);
-}
-
-// Number of database elements that can fit in a single FV plaintext
-uint64_t elements_per_ptxt(uint32_t logt, uint64_t N, uint64_t ele_size) {
-    uint64_t coeff_per_ele = coefficients_per_element(logt, ele_size);
-    uint64_t ele_per_ptxt = N / coeff_per_ele;
-    assert(ele_per_ptxt > 0);
-    return ele_per_ptxt;
-}
-
-// Number of FV plaintexts needed to represent the database
-uint64_t plaintexts_per_db(uint32_t logtp, uint64_t N, uint64_t ele_num, uint64_t ele_size) {
-    uint64_t ele_per_ptxt = elements_per_ptxt(logtp, N, ele_size);
-    return ceil((double)ele_num / ele_per_ptxt);
-}
-
-vector<uint64_t> bytes_to_coeffs(uint32_t limit, const uint8_t *bytes, uint64_t size) {
-
-    uint64_t size_out = coefficients_per_element(limit, size);
-    vector<uint64_t> output(size_out);
-
-    uint32_t room = limit;
-    uint64_t *target = &output[0];
-
-    for (uint32_t i = 0; i < size; i++) {
-        uint8_t src = bytes[i];
-        uint32_t rest = 8;
-        while (rest) {
-            if (room == 0) {
-                target++;
-                room = limit;
-            }
-            uint32_t shift = rest;
-            if (room < rest) {
-                shift = room;
-            }
-            *target = *target << shift;
-            *target = *target | (src >> (8 - shift));
-            src = src << shift;
-            room -= shift;
-            rest -= shift;
-        }
-    }
-
-    *target = *target << room;
-    return output;
-}
-
-void coeffs_to_bytes(uint32_t limit, const Plaintext &coeffs, uint8_t *output, uint32_t size_out) {
-    uint32_t room = 8;
-    uint32_t j = 0;
-    uint8_t *target = output;
-
-    for (uint32_t i = 0; i < coeffs.coeff_count(); i++) {
-        uint64_t src = coeffs[i];
-        uint32_t rest = limit;
-        while (rest && j < size_out) {
-            uint32_t shift = rest;
-            if (room < rest) {
-                shift = room;
-            }
-            target[j] = target[j] << shift;
-            target[j] = target[j] | (src >> (limit - shift));
-            src = src << shift;
-            room -= shift;
-            rest -= shift;
-            if (room == 0) {
-                j++;
-                room = 8;
-            }
-        }
-    }
-}
-
-void vector_to_plaintext(const vector<uint64_t> &coeffs, Plaintext &plain) {
-    uint32_t coeff_count = coeffs.size();
-    plain.resize(coeff_count);
-    util::set_uint_uint(coeffs.data(), coeff_count, plain.data());
-}
-
-vector<uint64_t> compute_indices(uint64_t desiredIndex, vector<uint64_t> Nvec) {
-    uint32_t num = Nvec.size();
-    uint64_t product = 1;
-
-    for (uint32_t i = 0; i < num; i++) {
-        product *= Nvec[i];
-    }
-
-    uint64_t j = desiredIndex;
-    vector<uint64_t> result;
-
-    for (uint32_t i = 0; i < num; i++) {
-
-        product /= Nvec[i];
-        uint64_t ji = j / product;
-
-        result.push_back(ji);
-        j -= ji * product;
-    }
-
-    return result;
-}
-
-inline Ciphertext deserialize_ciphertext(string s) {
-    Ciphertext c;
-    std::istringstream input(s);
-    c.unsafe_load(input);
-    return c;
-}
-
-
-vector<Ciphertext> deserialize_ciphertexts(uint32_t count, string s, uint32_t len_ciphertext) {
-    vector<Ciphertext> c;
-    for (uint32_t i = 0; i < count; i++) {
-        c.push_back(deserialize_ciphertext(s.substr(i * len_ciphertext, len_ciphertext)));
-    }
-    return c;
-}
-
-PirQuery deserialize_query(uint32_t d, uint32_t count, string s, uint32_t len_ciphertext) {
-    vector<vector<Ciphertext>> c;
-    for (uint32_t i = 0; i < d; i++) {
-        c.push_back(deserialize_ciphertexts(
-              count, 
-              s.substr(i * count * len_ciphertext, count * len_ciphertext),
-              len_ciphertext)
-        );
-    }
-    return c;
-}
-
-
-inline string serialize_ciphertext(Ciphertext c) {
-    std::ostringstream output;
-    c.save(output);
-    return output.str();
-}
-
-string serialize_ciphertexts(vector<Ciphertext> c) {
-    string s;
-    for (uint32_t i = 0; i < c.size(); i++) {
-        s.append(serialize_ciphertext(c[i]));
-    }
-    return s;
-}
-
-string serialize_query(vector<vector<Ciphertext>> c) {
-    string s;
-    for (uint32_t i = 0; i < c.size(); i++) {
-      for (uint32_t j = 0; j < c[i].size(); j++) {
-        s.append(serialize_ciphertext(c[i][j]));
-      }
-    }
-    return s;
-}
-
-string serialize_galoiskeys(GaloisKeys g) {
-    std::ostringstream output;
-    g.save(output);
-    return output.str();
-}
-
-GaloisKeys *deserialize_galoiskeys(string s) {
-    GaloisKeys *g = new GaloisKeys();
-    std::istringstream input(s);
-    g->unsafe_load(input);
-    return g;
-}

+ 0 - 74
pir.hpp

@@ -1,74 +0,0 @@
-#pragma once
-
-#include "seal/seal.h"
-#include "seal/util/polyarithsmallmod.h"
-#include <cassert>
-#include <cmath>
-#include <string>
-#include <vector>
-
-#define CIPHER_SIZE 32841
-
-typedef std::vector<seal::Plaintext> Database;
-typedef std::vector<std::vector<seal::Ciphertext>> PirQuery;
-typedef std::vector<seal::Ciphertext> PirReply;
-
-struct PirParams {
-    std::uint64_t n;                 // number of plaintexts in database
-    std::uint32_t d;                 // number of dimensions for the database (1 or 2)
-    std::uint32_t expansion_ratio;   // ratio of ciphertext to plaintext
-    std::uint32_t dbc;               // decomposition bit count (used by relinearization)
-    std::vector<std::uint64_t> nvec; // size of each of the d dimensions
-};
-
-void gen_params(std::uint64_t ele_num,  // number of elements (not FV plaintexts) in database
-                std::uint64_t ele_size, // size of each element
-                std::uint32_t N,        // degree of polynomial
-                std::uint32_t logt,     // bits of plaintext coefficient
-                std::uint32_t d,        // dimension of database
-                seal::EncryptionParameters &params,
-                PirParams &pir_params);
-
-// returns the plaintext modulus after expansion
-std::uint32_t plainmod_after_expansion(std::uint32_t logt, std::uint32_t N, 
-                                       std::uint32_t d, std::uint64_t ele_num,
-                                       std::uint64_t ele_size);
-
-// returns the number of plaintexts that the database can hold
-std::uint64_t plaintexts_per_db(std::uint32_t logtp, std::uint64_t N, std::uint64_t ele_num,
-                                std::uint64_t ele_size);
-
-// returns the number of elements that a single FV plaintext can hold
-std::uint64_t elements_per_ptxt(std::uint32_t logtp, std::uint64_t N, std::uint64_t ele_size);
-
-// returns the number of coefficients needed to store one element
-std::uint64_t coefficients_per_element(std::uint32_t logtp, std::uint64_t ele_size);
-
-// Converts an array of bytes to a vector of coefficients, each of which is less
-// than the plaintext modulus
-std::vector<std::uint64_t> bytes_to_coeffs(std::uint32_t limit, const std::uint8_t *bytes,
-                                           std::uint64_t size);
-
-// Converts an array of coefficients into an array of bytes
-void coeffs_to_bytes(std::uint32_t logtp, const seal::Plaintext &coeffs, std::uint8_t *output,
-                     std::uint32_t size_out);
-
-// Takes a vector of coefficients and returns the corresponding FV plaintext
-void vector_to_plaintext(const std::vector<std::uint64_t> &coeffs, seal::Plaintext &plain);
-
-// Since the database has d dimensions, and an item is a particular cell
-// in the d-dimensional hypercube, this function computes the corresponding
-// index for each of the d dimensions
-std::vector<std::uint64_t> compute_indices(std::uint64_t desiredIndex,
-                                           std::vector<std::uint64_t> nvec);
-
-// Serialize and deserialize ciphertexts to send them over the network
-PirQuery deserialize_query(std::uint32_t d, uint32_t count, std::string s, std::uint32_t len_ciphertext);
-std::vector<seal::Ciphertext> deserialize_ciphertexts(std::uint32_t count, std::string s,
-                                                      std::uint32_t len_ciphertext);
-std::string serialize_ciphertexts(std::vector<seal::Ciphertext> c);
-std::string serialize_query(std::vector<std::vector<seal::Ciphertext>> c);
-
-// Serialize and deserialize galois keys to send them over the network
-std::string serialize_galoiskeys(seal::GaloisKeys g);
-seal::GaloisKeys *deserialize_galoiskeys(std::string s);

+ 0 - 248
pir_client.cpp

@@ -1,248 +0,0 @@
-#include "pir_client.hpp"
-
-using namespace std;
-using namespace seal;
-using namespace seal::util;
-
-PIRClient::PIRClient(const EncryptionParameters &params,
-                     const PirParams &pir_parms) :
-    params_(params){
-
-    newcontext_ = SEALContext::Create(params_);
-
-    pir_params_ = pir_parms;
-
-    keygen_ = make_unique<KeyGenerator>(newcontext_);
-    encryptor_ = make_unique<Encryptor>(newcontext_, keygen_->public_key());
-
-    SecretKey secret_key = keygen_->secret_key();
-
-    decryptor_ = make_unique<Decryptor>(newcontext_, secret_key);
-    evaluator_ = make_unique<Evaluator>(newcontext_);
-}
-
-
-PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
-
-    indices_ = compute_indices(desiredIndex, pir_params_.nvec);
-
-    compute_inverse_scales(); 
-
-    vector<vector<Ciphertext> > result(pir_params_.d);
-    int N = params_.poly_modulus_degree(); 
-
-    Plaintext pt(params_.poly_modulus_degree());
-    for (uint32_t i = 0; i < indices_.size(); i++) {
-        uint32_t num_ptxts = ceil( (pir_params_.nvec[i] + 0.0) / N);
-        // initialize result. 
-        cout << "Client: index " << i + 1  <<  "/ " <<  indices_.size() << " = " << indices_[i] << endl; 
-        cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
-        for (uint32_t j =0; j < num_ptxts; j++){
-            pt.set_zero();
-            if (indices_[i] > N*(j+1) || indices_[i] < N*j){
-#ifdef DEBUG
-                cout << "Client: coming here: so just encrypt zero." << endl; 
-#endif 
-                // just encrypt zero
-            } else{
-#ifdef DEBUG
-                cout << "Client: encrypting a real thing " << endl; 
-#endif 
-                uint64_t real_index = indices_[i] - N*j; 
-                pt[real_index] = 1;
-            }
-            Ciphertext dest;
-            encryptor_->encrypt(pt, dest);
-            dest.parms_id() = params_.parms_id();
-            result[i].push_back(dest);
-        }   
-    }
-
-    return result;
-}
-
-uint64_t PIRClient::get_fv_index(uint64_t element_idx, uint64_t ele_size) {
-    auto N = params_.poly_modulus_degree();
-    auto logt = floor(log2(params_.plain_modulus().value()));
-
-    auto ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
-    return static_cast<uint64_t>(element_idx / ele_per_ptxt);
-}
-
-uint64_t PIRClient::get_fv_offset(uint64_t element_idx, uint64_t ele_size) {
-    uint32_t N = params_.poly_modulus_degree();
-    uint32_t logt = floor(log2(params_.plain_modulus().value()));
-
-    uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
-    return element_idx % ele_per_ptxt;
-}
-
-Plaintext PIRClient::decode_reply(PirReply reply) {
-    uint32_t exp_ratio = pir_params_.expansion_ratio;
-    uint32_t recursion_level = pir_params_.d;
-
-    vector<Ciphertext> temp = reply;
-
-    uint64_t t = params_.plain_modulus().value();
-
-    for (uint32_t i = 0; i < recursion_level; i++) {
-        cout << "Client: " << i + 1 << "/ " << recursion_level << "-th decryption layer started." << endl; 
-        vector<Ciphertext> newtemp;
-        vector<Plaintext> tempplain;
-
-        for (uint32_t j = 0; j < temp.size(); j++) {
-            Plaintext ptxt;
-            decryptor_->decrypt(temp[j], ptxt);
-#ifdef DEBUG
-            cout << "Client: reply noise budget = " << decryptor_->invariant_noise_budget(temp[j]) << endl; 
-#endif
-            // multiply by inverse_scale for every coefficient of ptxt
-            for(int h = 0; h < ptxt.coeff_count(); h++){
-                ptxt[h] *= inverse_scales_[recursion_level -  1 - i]; 
-                ptxt[h] %= t; 
-            }
-            //cout << "decoded (and scaled) plaintext = " << ptxt.to_string() << endl;
-            tempplain.push_back(ptxt);
-
-#ifdef DEBUG
-            cout << "recursion level : " << i << " noise budget :  ";
-            cout << decryptor_->invariant_noise_budget(temp[j]) << endl;
-#endif
-
-            if ((j + 1) % exp_ratio == 0 && j > 0) {
-                // Combine into one ciphertext.
-                Ciphertext combined = compose_to_ciphertext(tempplain);
-                newtemp.push_back(combined);
-                tempplain.clear();
-                // cout << "Client: const term of ciphertext = " << combined[0] << endl; 
-            }
-        }
-        cout << "Client: done." << endl; 
-        cout << endl; 
-        if (i == recursion_level - 1) {
-            assert(temp.size() == 1);
-            return tempplain[0];
-        } else {
-            tempplain.clear();
-            temp = newtemp;
-        }
-    }
-
-    // This should never be called
-    assert(0);
-    Plaintext fail;
-    return fail;
-}
-
-GaloisKeys PIRClient::generate_galois_keys() {
-    // Generate the Galois keys needed for coeff_select.
-    vector<uint64_t> galois_elts;
-    int N = params_.poly_modulus_degree();
-    int logN = get_power_of_two(N);
-
-    //cout << "printing galois elements...";
-    for (int i = 0; i < logN; i++) {
-        galois_elts.push_back((N + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
-//#ifdef DEBUG
-        // cout << galois_elts.back() << ", ";
-//#endif
-    }
-
-    return keygen_->galois_keys(pir_params_.dbc, galois_elts);
-}
-
-Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
-    size_t encrypted_count = 2;
-    auto coeff_count = params_.poly_modulus_degree();
-    auto coeff_mod_count = params_.coeff_modulus().size();
-    uint64_t plainMod = params_.plain_modulus().value();
-    int logt = floor(log2(plainMod)); 
-
-    Ciphertext result(newcontext_);
-    result.resize(encrypted_count);
-
-    // A triple for loop. Going over polys, moduli, and decomposed index.
-    for (int i = 0; i < encrypted_count; i++) {
-        uint64_t *encrypted_pointer = result.data(i);
-
-        for (int j = 0; j < coeff_mod_count; j++) {
-            // populate one poly at a time.
-            // create a polynomial to store the current decomposition value
-            // which will be copied into the array to populate it at the current
-            // index.
-            double logqj = log2(params_.coeff_modulus()[j].value());
-            int expansion_ratio = ceil(logqj / logt);
-            uint64_t cur = 1;
-            // cout << "Client: expansion_ratio = " << expansion_ratio << endl; 
-
-            for (int k = 0; k < expansion_ratio; k++) {
-                // Compose here
-                const uint64_t *plain_coeff =
-                    plains[k + j * (expansion_ratio) + i * (coeff_mod_count * expansion_ratio)]
-                        .data();
-
-                for (int m = 0; m < coeff_count; m++) {
-                    if (k == 0) {
-                        *(encrypted_pointer + m + j * coeff_count) = *(plain_coeff + m) * cur;
-                    } else {
-                        *(encrypted_pointer + m + j * coeff_count) += *(plain_coeff + m) * cur;
-                    }
-                }
-                // *(encrypted_pointer + coeff_count - 1 + j * coeff_count) = 0;
-                cur <<= logt;
-            }
-
-            // XXX: Reduction modulo qj. This is needed?
-            /*
-            for (int m = 0; m < coeff_count; m++) {
-                *(encrypted_pointer + m + j * coeff_count) %=
-                    params_.coeff_modulus()[j].value();
-            }
-            */
-        }
-    }
-
-    result.parms_id() = params_.parms_id();
-    return result;
-}
-
-
-void PIRClient::compute_inverse_scales(){
-    if (indices_.size() != pir_params_.nvec.size()){
-        throw invalid_argument("size mismatch"); 
-    }
-    int logt = floor(log2(params_.plain_modulus().value())); 
-
-    uint64_t N = params_.poly_modulus_degree(); 
-    uint64_t t = params_.plain_modulus().value();
-    int logN = log2(N);
-    int logm = logN;
-
-    inverse_scales_.clear(); 
-
-    for(int i = 0; i < pir_params_.nvec.size(); i++){
-        uint64_t index_modN = indices_[i] % N; 
-        uint64_t numCtxt = ceil ( (pir_params_.nvec[i] + 0.0) / N);  // number of query ciphertexts. 
-        uint64_t batchId = indices_[i] / N;  
-        if (batchId == numCtxt - 1) {
-            cout << "Client: adjusting the logm value..." << endl; 
-            logm = ceil(log2((pir_params_.nvec[i] % N)));
-        }
-
-        uint64_t inverse_scale; 
- 
-
-        int quo = logm / logt; 
-        int mod = logm % logt; 
-        inverse_scale = pow(2, logt - mod); 
-        if ((quo +1) %2 != 0){
-            inverse_scale =  params_.plain_modulus().value() - pow(2, logt - mod); 
-        }
-        inverse_scales_.push_back(inverse_scale); 
-        if ( (inverse_scale << logm)  % t != 1){
-            throw logic_error("something wrong"); 
-        }
-        cout << "Client: logm, inverse scale, t = " << logm << ", " << inverse_scale << ", " << t << endl; 
-    }
-}
-

+ 0 - 41
pir_client.hpp

@@ -1,41 +0,0 @@
-#pragma once
-
-#include "pir.hpp"
-#include <memory>
-#include <vector>
-
-using namespace std; 
-
-class PIRClient {
-  public:
-    PIRClient(const seal::EncryptionParameters &parms,
-               const PirParams &pirparms);
-
-    PirQuery generate_query(std::uint64_t desiredIndex);
-    seal::Plaintext decode_reply(PirReply reply);
-
-    seal::GaloisKeys generate_galois_keys();
-
-    // Index and offset of an element in an FV plaintext
-    uint64_t get_fv_index(uint64_t element_idx, uint64_t ele_size);
-    uint64_t get_fv_offset(uint64_t element_idx, uint64_t ele_size);
-
-    void compute_inverse_scales(); 
-
-  private:
-    seal::EncryptionParameters params_;
-    PirParams pir_params_;
-
-    std::unique_ptr<seal::Encryptor> encryptor_;
-    std::unique_ptr<seal::Decryptor> decryptor_;
-    std::unique_ptr<seal::Evaluator> evaluator_;
-    std::unique_ptr<seal::KeyGenerator> keygen_;
-    std::shared_ptr<seal::SEALContext> newcontext_;
-
-    vector<uint64_t> indices_; // the indices for retrieval. 
-    vector<uint64_t> inverse_scales_; 
-
-    seal::Ciphertext compose_to_ciphertext(std::vector<seal::Plaintext> plains);
-
-    friend class PIRServer;
-};

+ 0 - 451
pir_server.cpp

@@ -1,451 +0,0 @@
-#include "pir_server.hpp"
-#include "pir_client.hpp"
-
-using namespace std;
-using namespace seal;
-using namespace seal::util;
-
-PIRServer::PIRServer(const EncryptionParameters &params, const PirParams &pir_params) :
-    params_(params), 
-    pir_params_(pir_params),
-    is_db_preprocessed_(false)
-{
-    auto context = SEALContext::Create(params, false);
-    evaluator_ = make_unique<Evaluator>(context);
-}
-
-void PIRServer::preprocess_database() {
-    if (!is_db_preprocessed_) {
-
-        for (uint32_t i = 0; i < db_->size(); i++) {
-            evaluator_->transform_to_ntt_inplace(
-                db_->operator[](i), params_.parms_id());
-        }
-
-        is_db_preprocessed_ = true;
-    }
-}
-
-// Server takes over ownership of db and will free it when it exits
-void PIRServer::set_database(unique_ptr<vector<Plaintext>> &&db) {
-    if (!db) {
-        throw invalid_argument("db cannot be null");
-    }
-
-    db_ = move(db);
-    is_db_preprocessed_ = false;
-}
-
-void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes, 
-    uint64_t ele_num, uint64_t ele_size) {
-
-    uint32_t logt = floor(log2(params_.plain_modulus().value()));
-    uint32_t N = params_.poly_modulus_degree();
-
-    // number of FV plaintexts needed to represent all elements
-    uint64_t total = plaintexts_per_db(logt, N, ele_num, ele_size);
-
-    // number of FV plaintexts needed to create the d-dimensional matrix
-    uint64_t prod = 1;
-    for (uint32_t i = 0; i < pir_params_.nvec.size(); i++) {
-        prod *= pir_params_.nvec[i];
-    }
-    uint64_t matrix_plaintexts = prod;
-    assert(total <= matrix_plaintexts);
-
-    auto result = make_unique<vector<Plaintext>>();
-    result->reserve(matrix_plaintexts);
-
-    uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
-    uint64_t bytes_per_ptxt = ele_per_ptxt * ele_size;
-
-    uint64_t db_size = ele_num * ele_size;
-
-    uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logt, ele_size);
-    assert(coeff_per_ptxt <= N);
-
-    cout << "Server: total number of FV plaintext = " << total << endl;
-    cout << "Server: elements packed into each plaintext " << ele_per_ptxt << endl; 
-
-    uint32_t offset = 0;
-
-    for (uint64_t i = 0; i < total; i++) {
-
-        uint64_t process_bytes = 0;
-
-        if (db_size <= offset) {
-            break;
-        } else if (db_size < offset + bytes_per_ptxt) {
-            process_bytes = db_size - offset;
-        } else {
-            process_bytes = bytes_per_ptxt;
-        }
-
-        // Get the coefficients of the elements that will be packed in plaintext i
-        vector<uint64_t> coefficients = bytes_to_coeffs(logt, bytes.get() + offset, process_bytes);
-        offset += process_bytes;
-
-        uint64_t used = coefficients.size();
-
-        assert(used <= coeff_per_ptxt);
-
-        // Pad the rest with 1s
-        for (uint64_t j = 0; j < (N - used); j++) {
-            coefficients.push_back(1);
-        }
-
-        Plaintext plain;
-        vector_to_plaintext(coefficients, plain);
-        // cout << i << "-th encoded plaintext = " << plain.to_string() << endl; 
-        result->push_back(move(plain));
-    }
-
-    // Add padding to make database a matrix
-    uint64_t current_plaintexts = result->size();
-    assert(current_plaintexts <= total);
-
-#ifdef DEBUG
-    cout << "adding: " << matrix_plaintexts - current_plaintexts
-         << " FV plaintexts of padding (equivalent to: "
-         << (matrix_plaintexts - current_plaintexts) * elements_per_ptxt(logtp, N, ele_size)
-         << " elements)" << endl;
-#endif
-
-    vector<uint64_t> padding(N, 1);
-
-    for (uint64_t i = 0; i < (matrix_plaintexts - current_plaintexts); i++) {
-        Plaintext plain;
-        vector_to_plaintext(padding, plain);
-        result->push_back(plain);
-    }
-
-    set_database(move(result));
-}
-
-void PIRServer::set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey) {
-    galkey.parms_id() = params_.parms_id();
-    galoisKeys_[client_id] = galkey;
-}
-
-PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
-
-    vector<uint64_t> nvec = pir_params_.nvec;
-    uint64_t product = 1;
-
-    for (uint32_t i = 0; i < nvec.size(); i++) {
-        product *= nvec[i];
-    }
-
-    auto coeff_count = params_.poly_modulus_degree();
-
-    vector<Plaintext> *cur = db_.get();
-    vector<Plaintext> intermediate_plain; // decompose....
-
-    auto pool = MemoryManager::GetPool();
-
-
-    int N = params_.poly_modulus_degree();
-
-    int logt = floor(log2(params_.plain_modulus().value()));
-
-    cout << "expansion ratio = " << pir_params_.expansion_ratio << endl; 
-    for (uint32_t i = 0; i < nvec.size(); i++) {
-        cout << "Server: " << i + 1 << "-th recursion level started " << endl; 
-
-
-        vector<Ciphertext> expanded_query; 
-
-        uint64_t n_i = nvec[i];
-        cout << "Server: n_i = " << n_i << endl; 
-        cout << "Server: expanding " << query[i].size() << " query ctxts" << endl;
-        for (uint32_t j = 0; j < query[i].size(); j++){
-            uint64_t total = N; 
-            if (j == query[i].size() - 1){
-                total = n_i % N; 
-            }
-            cout << "-- expanding one query ctxt into " << total  << " ctxts "<< endl;
-            vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id);
-            expanded_query.insert(expanded_query.end(), std::make_move_iterator(expanded_query_part.begin()), 
-                    std::make_move_iterator(expanded_query_part.end()));
-            expanded_query_part.clear(); 
-        }
-        cout << "Server: expansion done " << endl; 
-        if (expanded_query.size() != n_i) {
-            cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i << endl; 
-        }    
-
-        /*
-        cout << "Checking expanded query " << endl; 
-        Plaintext tempPt; 
-        for (int h = 0 ; h < expanded_query.size(); h++){
-            cout << "noise budget = " << client.decryptor_->invariant_noise_budget(expanded_query[h]) << ", "; 
-            client.decryptor_->decrypt(expanded_query[h], tempPt); 
-            cout << tempPt.to_string()  << endl; 
-        }
-        cout << endl;
-        */
-
-        // Transform expanded query to NTT, and ...
-        for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
-            evaluator_->transform_to_ntt_inplace(expanded_query[jj]);
-        }
-
-        // Transform plaintext to NTT. If database is pre-processed, can skip
-        if ((!is_db_preprocessed_) || i > 0) {
-            for (uint32_t jj = 0; jj < cur->size(); jj++) {
-                evaluator_->transform_to_ntt_inplace((*cur)[jj], params_.parms_id());
-            }
-        }
-
-        for (uint64_t k = 0; k < product; k++) {
-            if ((*cur)[k].is_zero()){
-                cout << k + 1 << "/ " << product <<  "-th ptxt = 0 " << endl; 
-            }
-        }
-
-        product /= n_i;
-
-        vector<Ciphertext> intermediateCtxts(product);
-        Ciphertext temp;
-
-        for (uint64_t k = 0; k < product; k++) {
-
-            evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediateCtxts[k]);
-
-            for (uint64_t j = 1; j < n_i; j++) {
-                evaluator_->multiply_plain(expanded_query[j], (*cur)[k + j * product], temp);
-                evaluator_->add_inplace(intermediateCtxts[k], temp); // Adds to first component.
-            }
-        }
-
-        for (uint32_t jj = 0; jj < intermediateCtxts.size(); jj++) {
-            evaluator_->transform_from_ntt_inplace(intermediateCtxts[jj]);
-            // print intermediate ctxts? 
-            //cout << "const term of ctxt " << jj << " = " << intermediateCtxts[jj][0] << endl; 
-        }
-
-        if (i == nvec.size() - 1) {
-            return intermediateCtxts;
-        } else {
-            intermediate_plain.clear();
-            intermediate_plain.reserve(pir_params_.expansion_ratio * product);
-            cur = &intermediate_plain;
-
-            auto tempplain = util::allocate<Plaintext>(
-                pir_params_.expansion_ratio * product,
-                pool, coeff_count);
-
-            for (uint64_t rr = 0; rr < product; rr++) {
-
-                decompose_to_plaintexts_ptr(intermediateCtxts[rr],
-                    tempplain.get() + rr * pir_params_.expansion_ratio, logt);
-
-                for (uint32_t jj = 0; jj < pir_params_.expansion_ratio; jj++) {
-                    auto offset = rr * pir_params_.expansion_ratio + jj;
-                    intermediate_plain.emplace_back(tempplain[offset]);
-                }
-            }
-            product *= pir_params_.expansion_ratio; // multiply by expansion rate.
-        }
-        cout << "Server: " << i + 1 << "-th recursion level finished " << endl; 
-        cout << endl;
-    }
-    cout << "reply generated!  " << endl;
-    // This should never get here
-    assert(0);
-    vector<Ciphertext> fail(1);
-    return fail;
-}
-
-inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, uint32_t m,
-                                           uint32_t client_id) {
-
-#ifdef DEBUG
-    uint64_t plainMod = params_.plain_modulus().value();
-    cout << "PIRServer side plain modulus = " << plainMod << endl;
-#endif
-
-    GaloisKeys &galkey = galoisKeys_[client_id];
-
-    // Assume that m is a power of 2. If not, round it to the next power of 2.
-    uint32_t logm = ceil(log2(m));
-    Plaintext two("2");
-
-    vector<int> galois_elts;
-    auto n = params_.poly_modulus_degree();
-    if (logm > ceil(log2(n))){
-        throw logic_error("m > n is not allowed."); 
-    }
-    for (int i = 0; i < ceil(log2(n)); i++) {
-        galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
-    }
-
-    vector<Ciphertext> temp;
-    temp.push_back(encrypted);
-    Ciphertext tempctxt;
-    Ciphertext tempctxt_rotated;
-    Ciphertext tempctxt_shifted;
-    Ciphertext tempctxt_rotatedshifted;
-
-
-    for (uint32_t i = 0; i < logm - 1; i++) {
-        vector<Ciphertext> newtemp(temp.size() << 1);
-        // temp[a] = (j0 = a (mod 2**i) ? ) : Enc(x^{j0 - a}) else Enc(0).  With
-        // some scaling....
-        int index_raw = (n << 1) - (1 << i);
-        int index = (index_raw * galois_elts[i]) % (n << 1);
-
-        for (uint32_t a = 0; a < temp.size(); a++) {
-
-            evaluator_->apply_galois(temp[a], galois_elts[i], galkey, tempctxt_rotated);
-
-            //cout << "rotate " << client.decryptor_->invariant_noise_budget(tempctxt_rotated) << ", "; 
-
-            evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
-            multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
-
-            //cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_shifted) << ", "; 
-
-
-            multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
-
-            // cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_rotatedshifted) << ", "; 
-
-
-            // Enc(2^i x^j) if j = 0 (mod 2**i).
-            evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a + temp.size()]);
-        }
-        temp = newtemp;
-        /*
-        cout << "end: "; 
-        for (int h = 0; h < temp.size();h++){
-            cout << client.decryptor_->invariant_noise_budget(temp[h]) << ", "; 
-        }
-        cout << endl; 
-        */
-    }
-    // Last step of the loop
-    vector<Ciphertext> newtemp(temp.size() << 1);
-    int index_raw = (n << 1) - (1 << (logm - 1));
-    int index = (index_raw * galois_elts[logm - 1]) % (n << 1);
-    for (uint32_t a = 0; a < temp.size(); a++) {
-        if (a >= (m - (1 << (logm - 1)))) {                       // corner case.
-            evaluator_->multiply_plain(temp[a], two, newtemp[a]); // plain multiplication by 2.
-            // cout << client.decryptor_->invariant_noise_budget(newtemp[a]) << ", "; 
-        } else {
-            evaluator_->apply_galois(temp[a], galois_elts[logm - 1], galkey, tempctxt_rotated);
-            evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
-            multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
-            multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
-            evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a + temp.size()]);
-        }
-    }
-
-    vector<Ciphertext>::const_iterator first = newtemp.begin();
-    vector<Ciphertext>::const_iterator last = newtemp.begin() + m;
-    vector<Ciphertext> newVec(first, last);
-    return newVec;
-}
-
-inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Ciphertext &destination,
-                                    uint32_t index) {
-
-    auto coeff_mod_count = params_.coeff_modulus().size();
-    auto coeff_count = params_.poly_modulus_degree();
-    auto encrypted_count = encrypted.size();
-
-    //cout << "coeff mod count for power of X = " << coeff_mod_count << endl; 
-    //cout << "coeff count for power of X = " << coeff_count << endl; 
-
-    // First copy over.
-    destination = encrypted;
-
-    // Prepare for destination
-    // Multiply X^index for each ciphertext polynomial
-    for (int i = 0; i < encrypted_count; i++) {
-        for (int j = 0; j < coeff_mod_count; j++) {
-            negacyclic_shift_poly_coeffmod(encrypted.data(i) + (j * coeff_count),
-                                           coeff_count, index,
-                                           params_.coeff_modulus()[j],
-                                           destination.data(i) + (j * coeff_count));
-        }
-    }
-}
-
-inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, Plaintext *plain_ptr, int logt) {
-
-    vector<Plaintext> result;
-    auto coeff_count = params_.poly_modulus_degree();
-    auto coeff_mod_count = params_.coeff_modulus().size();
-    auto encrypted_count = encrypted.size();
-
-    uint64_t t1 = 1 << logt;  //  t1 <= t. 
-
-    uint64_t t1minusone =  t1 -1; 
-    // A triple for loop. Going over polys, moduli, and decomposed index.
-
-    for (int i = 0; i < encrypted_count; i++) {
-        const uint64_t *encrypted_pointer = encrypted.data(i);
-        for (int j = 0; j < coeff_mod_count; j++) {
-            // populate one poly at a time.
-            // create a polynomial to store the current decomposition value
-            // which will be copied into the array to populate it at the current
-            // index.
-            double logqj = log2(params_.coeff_modulus()[j].value());
-            //int expansion_ratio = ceil(logqj + exponent - 1) / exponent;
-            int expansion_ratio =  ceil(logqj / logt); 
-            // cout << "local expansion ratio = " << expansion_ratio << endl;
-            uint64_t curexp = 0;
-            for (int k = 0; k < expansion_ratio; k++) {
-                // Decompose here
-                for (int m = 0; m < coeff_count; m++) {
-                    plain_ptr[i * coeff_mod_count * expansion_ratio
-                        + j * expansion_ratio + k][m] =
-                        (*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & t1minusone;
-                }
-                curexp += logt;
-            }
-        }
-    }
-}
-
-vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted) {
-    vector<Plaintext> result;
-    auto coeff_count = params_.poly_modulus_degree();
-    auto coeff_mod_count = params_.coeff_modulus().size();
-    auto plain_bit_count = params_.plain_modulus().bit_count();
-    auto encrypted_count = encrypted.size();
-
-    // Generate powers of t.
-    uint64_t plainMod = params_.plain_modulus().value();
-
-    // A triple for loop. Going over polys, moduli, and decomposed index.
-    for (int i = 0; i < encrypted_count; i++) {
-        const uint64_t *encrypted_pointer = encrypted.data(i);
-        for (int j = 0; j < coeff_mod_count; j++) {
-            // populate one poly at a time.
-            // create a polynomial to store the current decomposition value
-            // which will be copied into the array to populate it at the current
-            // index.
-            int logqj = log2(params_.coeff_modulus()[j].value());
-            int expansion_ratio = ceil(logqj / log2(plainMod));
-
-            // cout << "expansion ratio = " << expansion_ratio << endl;
-            uint64_t cur = 1;
-            for (int k = 0; k < expansion_ratio; k++) {
-                // Decompose here
-                Plaintext temp(coeff_count);
-                transform(encrypted_pointer + (j * coeff_count), 
-                        encrypted_pointer + ((j + 1) * coeff_count), 
-                        temp.data(),
-                        [cur, &plainMod](auto &in) { return (in / cur) % plainMod; }
-                );
-
-                result.emplace_back(move(temp));
-                cur *= plainMod;
-            }
-        }
-    }
-
-    return result;
-}

+ 0 - 38
pir_server.hpp

@@ -1,38 +0,0 @@
-#pragma once
-
-#include "pir.hpp"
-#include <map>
-#include <memory>
-#include <vector>
-#include "pir_client.hpp"
-
-class PIRServer {
-  public:
-    PIRServer(const seal::EncryptionParameters &params, const PirParams &pir_params);
-
-    // NOTE: server takes over ownership of db and frees it when it exits.
-    // Caller cannot free db
-    void set_database(std::unique_ptr<std::vector<seal::Plaintext>> &&db);
-    void set_database(const std::unique_ptr<const std::uint8_t[]> &bytes, std::uint64_t ele_num, std::uint64_t ele_size);
-    void preprocess_database();
-
-    std::vector<seal::Ciphertext> expand_query(
-            const seal::Ciphertext &encrypted, std::uint32_t m, uint32_t client_id);
-
-    PirReply generate_reply(PirQuery query, std::uint32_t client_id);
-
-    void set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey);
-
-  private:
-    seal::EncryptionParameters params_; // SEAL parameters
-    PirParams pir_params_;              // PIR parameters
-    std::unique_ptr<Database> db_;
-    bool is_db_preprocessed_;
-    std::map<int, seal::GaloisKeys> galoisKeys_;
-    std::unique_ptr<seal::Evaluator> evaluator_;
-
-    void decompose_to_plaintexts_ptr(const seal::Ciphertext &encrypted, seal::Plaintext *plain_ptr, int logt);
-    std::vector<seal::Plaintext> decompose_to_plaintexts(const seal::Ciphertext &encrypted);
-    void multiply_power_of_X(const seal::Ciphertext &encrypted, seal::Ciphertext &destination,
-                             std::uint32_t index);
-};

+ 8 - 0
src/CMakeLists.txt

@@ -0,0 +1,8 @@
+find_package(SEAL 4.0 REQUIRED)
+
+add_library(sealpir pir.hpp pir.cpp pir_client.hpp pir_client.cpp pir_server.hpp
+  pir_server.cpp)
+target_link_libraries(sealpir SEAL::seal)
+
+add_executable(main main.cpp)
+target_link_libraries(main sealpir)

+ 185 - 0
src/main.cpp

@@ -0,0 +1,185 @@
+#include "pir.hpp"
+#include "pir_client.hpp"
+#include "pir_server.hpp"
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <random>
+#include <seal/seal.h>
+
+using namespace std::chrono;
+using namespace std;
+using namespace seal;
+
+int main(int argc, char *argv[]) {
+
+  uint64_t number_of_items = 1 << 16;
+  uint64_t size_per_item = 1024; // in bytes
+  uint32_t N = 4096;
+
+  // Recommended values: (logt, d) = (20, 2).
+  uint32_t logt = 20;
+  uint32_t d = 2;
+  bool use_symmetric = true; // use symmetric encryption instead of public key
+                             // (recommended for smaller query)
+  bool use_batching = true;  // pack as many elements as possible into a BFV
+                             // plaintext (recommended)
+  bool use_recursive_mod_switching = true;
+
+  EncryptionParameters enc_params(scheme_type::bfv);
+  PirParams pir_params;
+
+  // Generates all parameters
+
+  cout << "Main: Generating SEAL parameters" << endl;
+  gen_encryption_params(N, logt, enc_params);
+
+  cout << "Main: Verifying SEAL parameters" << endl;
+  verify_encryption_params(enc_params);
+  cout << "Main: SEAL parameters are good" << endl;
+
+  cout << "Main: Generating PIR parameters" << endl;
+  gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params,
+                 use_symmetric, use_batching, use_recursive_mod_switching);
+
+  print_seal_params(enc_params);
+  print_pir_params(pir_params);
+
+  // Initialize PIR client....
+  PIRClient client(enc_params, pir_params);
+  cout << "Main: Generating galois keys for client" << endl;
+
+  GaloisKeys galois_keys = client.generate_galois_keys();
+
+  // Initialize PIR Server
+  cout << "Main: Initializing server" << endl;
+  PIRServer server(enc_params, pir_params);
+
+  // Server maps the galois key to client 0. We only have 1 client,
+  // which is why we associate it with 0. If there are multiple PIR
+  // clients, you should have each client generate a galois key,
+  // and assign each client an index or id, then call the procedure below.
+  server.set_galois_key(0, galois_keys);
+
+  cout << "Main: Creating the database with random data (this may take some "
+          "time) ..."
+       << endl;
+
+  // Create test database
+  auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  // Copy of the database. We use this at the end to make sure we retrieved
+  // the correct element.
+  auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  random_device rd;
+  for (uint64_t i = 0; i < number_of_items; i++) {
+    for (uint64_t j = 0; j < size_per_item; j++) {
+      uint8_t val = rd() % 256;
+      db.get()[(i * size_per_item) + j] = val;
+      db_copy.get()[(i * size_per_item) + j] = val;
+    }
+  }
+
+  // Measure database setup
+  auto time_pre_s = high_resolution_clock::now();
+  server.set_database(move(db), number_of_items, size_per_item);
+  server.preprocess_database();
+  auto time_pre_e = high_resolution_clock::now();
+  auto time_pre_us =
+      duration_cast<microseconds>(time_pre_e - time_pre_s).count();
+  cout << "Main: database pre processed " << endl;
+
+  // Choose an index of an element in the DB
+  uint64_t ele_index =
+      rd() % number_of_items; // element in DB at random position
+  uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
+  uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
+  cout << "Main: element index = " << ele_index << " from [0, "
+       << number_of_items - 1 << "]" << endl;
+  cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
+
+  // Measure query generation
+  auto time_query_s = high_resolution_clock::now();
+  PirQuery query = client.generate_query(index);
+  auto time_query_e = high_resolution_clock::now();
+  auto time_query_us =
+      duration_cast<microseconds>(time_query_e - time_query_s).count();
+  cout << "Main: query generated" << endl;
+
+  // Measure serialized query generation (useful for sending over the network)
+  stringstream client_stream;
+  stringstream server_stream;
+  auto time_s_query_s = high_resolution_clock::now();
+  int query_size = client.generate_serialized_query(index, client_stream);
+  auto time_s_query_e = high_resolution_clock::now();
+  auto time_s_query_us =
+      duration_cast<microseconds>(time_s_query_e - time_s_query_s).count();
+  cout << "Main: serialized query generated" << endl;
+
+  // Measure query deserialization (useful for receiving over the network)
+  auto time_deserial_s = high_resolution_clock::now();
+  PirQuery query2 = server.deserialize_query(client_stream);
+  auto time_deserial_e = high_resolution_clock::now();
+  auto time_deserial_us =
+      duration_cast<microseconds>(time_deserial_e - time_deserial_s).count();
+  cout << "Main: query deserialized" << endl;
+
+  // Measure query processing (including expansion)
+  auto time_server_s = high_resolution_clock::now();
+  // Answer PIR query from client 0. If there are multiple clients,
+  // enter the id of the client (to use the associated galois key).
+  PirReply reply = server.generate_reply(query2, 0);
+  auto time_server_e = high_resolution_clock::now();
+  auto time_server_us =
+      duration_cast<microseconds>(time_server_e - time_server_s).count();
+  cout << "Main: reply generated" << endl;
+
+  // Serialize reply (useful for sending over the network)
+  int reply_size = server.serialize_reply(reply, server_stream);
+
+  // Measure response extraction
+  auto time_decode_s = chrono::high_resolution_clock::now();
+  vector<uint8_t> elems = client.decode_reply(reply, offset);
+  auto time_decode_e = chrono::high_resolution_clock::now();
+  auto time_decode_us =
+      duration_cast<microseconds>(time_decode_e - time_decode_s).count();
+  cout << "Main: reply decoded" << endl;
+
+  assert(elems.size() == size_per_item);
+
+  bool failed = false;
+  // Check that we retrieved the correct element
+  for (uint32_t i = 0; i < size_per_item; i++) {
+    if (elems[i] != db_copy.get()[(ele_index * size_per_item) + i]) {
+      cout << "Main: elems " << (int)elems[i] << ", db "
+           << (int)db_copy.get()[(ele_index * size_per_item) + i] << endl;
+      cout << "Main: PIR result wrong at " << i << endl;
+      failed = true;
+    }
+  }
+  if (failed) {
+    return -1;
+  }
+
+  // Output results
+  cout << "Main: PIR result correct!" << endl;
+  cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms"
+       << endl;
+  cout << "Main: PIRClient query generation time: " << time_query_us / 1000
+       << " ms" << endl;
+  cout << "Main: PIRClient serialized query generation time: "
+       << time_s_query_us / 1000 << " ms" << endl;
+  cout << "Main: PIRServer query deserialization time: " << time_deserial_us
+       << " us" << endl;
+  cout << "Main: PIRServer reply generation time: " << time_server_us / 1000
+       << " ms" << endl;
+  cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000
+       << " ms" << endl;
+  cout << "Main: Query size: " << query_size << " bytes" << endl;
+  cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
+  cout << "Main: Reply size: " << reply_size << " bytes" << endl;
+
+  return 0;
+}

+ 386 - 0
src/pir.cpp

@@ -0,0 +1,386 @@
+#include "pir.hpp"
+
+using namespace std;
+using namespace seal;
+using namespace seal::util;
+
+std::vector<std::uint64_t> get_dimensions(std::uint64_t num_of_plaintexts,
+                                          std::uint32_t d) {
+
+  assert(d > 0);
+  assert(num_of_plaintexts > 0);
+
+  std::uint64_t root =
+      max(static_cast<uint32_t>(2),
+          static_cast<uint32_t>(floor(pow(num_of_plaintexts, 1.0 / d))));
+
+  std::vector<std::uint64_t> dimensions(d, root);
+
+  for (int i = 0; i < d; i++) {
+    if (accumulate(dimensions.begin(), dimensions.end(), 1,
+                   multiplies<uint64_t>()) > num_of_plaintexts) {
+      break;
+    }
+    dimensions[i] += 1;
+  }
+
+  std::uint32_t prod = accumulate(dimensions.begin(), dimensions.end(), 1,
+                                  multiplies<uint64_t>());
+  assert(prod >= num_of_plaintexts);
+  return dimensions;
+}
+
+void gen_encryption_params(std::uint32_t N, std::uint32_t logt,
+                           seal::EncryptionParameters &enc_params) {
+
+  enc_params.set_poly_modulus_degree(N);
+  enc_params.set_coeff_modulus(CoeffModulus::BFVDefault(N));
+  enc_params.set_plain_modulus(PlainModulus::Batching(N, logt + 1));
+  // the +1 above ensures we get logt bits for each plaintext coefficient.
+  // Otherwise the coefficient modulus t will be logt bits, but only floor(t) =
+  // logt-1 (whp) will be usable (since we need to ensure that all data in the
+  // coefficient is < t).
+}
+
+void verify_encryption_params(const seal::EncryptionParameters &enc_params) {
+  SEALContext context(enc_params, true);
+  if (!context.parameters_set()) {
+    throw invalid_argument("SEAL parameters not valid.");
+  }
+  if (!context.using_keyswitching()) {
+    throw invalid_argument("SEAL parameters do not support key switching.");
+  }
+  if (!context.first_context_data()->qualifiers().using_batching) {
+    throw invalid_argument("SEAL parameters do not support batching.");
+  }
+
+  BatchEncoder batch_encoder(context);
+  size_t slot_count = batch_encoder.slot_count();
+  if (slot_count != enc_params.poly_modulus_degree()) {
+    throw invalid_argument("Slot count not equal to poly modulus degree - this "
+                           "will cause issues downstream.");
+  }
+
+  return;
+}
+
+void gen_pir_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
+                    const EncryptionParameters &enc_params,
+                    PirParams &pir_params, bool enable_symmetric,
+                    bool enable_batching, bool enable_mswitching) {
+  std::uint32_t N = enc_params.poly_modulus_degree();
+  Modulus t = enc_params.plain_modulus();
+  std::uint32_t logt = floor(log2(t.value())); // # of usable bits
+  std::uint64_t elements_per_plaintext;
+  std::uint64_t num_of_plaintexts;
+
+  if (enable_batching) {
+    elements_per_plaintext = elements_per_ptxt(logt, N, ele_size);
+    num_of_plaintexts = plaintexts_per_db(logt, N, ele_num, ele_size);
+  } else {
+    elements_per_plaintext = 1;
+    num_of_plaintexts = ele_num;
+  }
+
+  vector<uint64_t> nvec = get_dimensions(num_of_plaintexts, d);
+
+  uint32_t expansion_ratio = 0;
+  for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
+    double logqi = log2(enc_params.coeff_modulus()[i].value());
+    expansion_ratio += ceil(logqi / logt);
+  }
+
+  pir_params.enable_symmetric = enable_symmetric;
+  pir_params.enable_batching = enable_batching;
+  pir_params.enable_mswitching = enable_mswitching;
+  pir_params.ele_num = ele_num;
+  pir_params.ele_size = ele_size;
+  pir_params.elements_per_plaintext = elements_per_plaintext;
+  pir_params.num_of_plaintexts = num_of_plaintexts;
+  pir_params.d = d;
+  pir_params.expansion_ratio = expansion_ratio << 1;
+  pir_params.nvec = nvec;
+  pir_params.slot_count = N;
+}
+
+void print_pir_params(const PirParams &pir_params) {
+  std::uint32_t prod =
+      accumulate(pir_params.nvec.begin(), pir_params.nvec.end(), 1,
+                 multiplies<uint64_t>());
+
+  cout << "PIR Parameters" << endl;
+  cout << "number of elements: " << pir_params.ele_num << endl;
+  cout << "element size: " << pir_params.ele_size << endl;
+  cout << "elements per BFV plaintext: " << pir_params.elements_per_plaintext
+       << endl;
+  cout << "dimensions for d-dimensional hyperrectangle: " << pir_params.d
+       << endl;
+  cout << "number of BFV plaintexts (before padding): "
+       << pir_params.num_of_plaintexts << endl;
+  cout << "Number of BFV plaintexts after padding (to fill d-dimensional "
+          "hyperrectangle): "
+       << prod << endl;
+  cout << "expansion ratio: " << pir_params.expansion_ratio << endl;
+  cout << "Using symmetric encryption: " << pir_params.enable_symmetric << endl;
+  cout << "Using recursive mod switching: " << pir_params.enable_mswitching
+       << endl;
+  cout << "slot count: " << pir_params.slot_count << endl;
+  cout << "==============================" << endl;
+}
+
+void print_seal_params(const EncryptionParameters &enc_params) {
+  std::uint32_t N = enc_params.poly_modulus_degree();
+  Modulus t = enc_params.plain_modulus();
+  std::uint32_t logt = floor(log2(t.value()));
+
+  cout << "SEAL encryption parameters" << endl;
+  cout << "Degree of polynomial modulus (N): " << N << endl;
+  cout << "Size of plaintext modulus (log t):" << logt << endl;
+  cout << "There are " << enc_params.coeff_modulus().size()
+       << " coefficient modulus:" << endl;
+
+  for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
+    double logqi = log2(enc_params.coeff_modulus()[i].value());
+    cout << "Size of coefficient modulus " << i << " (log q_" << i
+         << "): " << logqi << endl;
+  }
+  cout << "==============================" << endl;
+}
+
+// Number of coefficients needed to represent a database element
+uint64_t coefficients_per_element(uint32_t logt, uint64_t ele_size) {
+  return ceil(8 * ele_size / (double)logt);
+}
+
+// Number of database elements that can fit in a single FV plaintext
+uint64_t elements_per_ptxt(uint32_t logt, uint64_t N, uint64_t ele_size) {
+  uint64_t coeff_per_ele = coefficients_per_element(logt, ele_size);
+  uint64_t ele_per_ptxt = N / coeff_per_ele;
+  assert(ele_per_ptxt > 0);
+  return ele_per_ptxt;
+}
+
+// Number of FV plaintexts needed to represent the database
+uint64_t plaintexts_per_db(uint32_t logt, uint64_t N, uint64_t ele_num,
+                           uint64_t ele_size) {
+  uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
+  return ceil((double)ele_num / ele_per_ptxt);
+}
+
+vector<uint64_t> bytes_to_coeffs(uint32_t limit, const uint8_t *bytes,
+                                 uint64_t size) {
+
+  uint64_t size_out = coefficients_per_element(limit, size);
+  vector<uint64_t> output(size_out);
+
+  uint32_t room = limit;
+  uint64_t *target = &output[0];
+
+  for (uint32_t i = 0; i < size; i++) {
+    uint8_t src = bytes[i];
+    uint32_t rest = 8;
+    while (rest) {
+      if (room == 0) {
+        target++;
+        room = limit;
+      }
+      uint32_t shift = rest;
+      if (room < rest) {
+        shift = room;
+      }
+      *target = *target << shift;
+      *target = *target | (src >> (8 - shift));
+      src = src << shift;
+      room -= shift;
+      rest -= shift;
+    }
+  }
+
+  *target = *target << room;
+  return output;
+}
+
+void coeffs_to_bytes(uint32_t limit, const vector<uint64_t> &coeffs,
+                     uint8_t *output, uint32_t size_out, uint32_t ele_size) {
+  uint32_t room = 8;
+  uint32_t j = 0;
+  uint8_t *target = output;
+  uint32_t bits_left = ele_size * 8;
+  for (uint32_t i = 0; i < coeffs.size(); i++) {
+    if (bits_left == 0) {
+      bits_left = ele_size * 8;
+    }
+    uint64_t src = coeffs[i];
+    uint32_t rest = min(limit, bits_left);
+    while (rest && j < size_out) {
+      uint32_t shift = rest;
+      if (room < rest) {
+        shift = room;
+      }
+
+      target[j] = target[j] << shift;
+      target[j] = target[j] | (src >> (limit - shift));
+      src = src << shift;
+      room -= shift;
+      rest -= shift;
+      bits_left -= shift;
+      if (room == 0) {
+        j++;
+        room = 8;
+      }
+    }
+  }
+}
+
+void vector_to_plaintext(const vector<uint64_t> &coeffs, Plaintext &plain) {
+  uint32_t coeff_count = coeffs.size();
+  plain.resize(coeff_count);
+  util::set_uint(coeffs.data(), coeff_count, plain.data());
+}
+
+vector<uint64_t> compute_indices(uint64_t desiredIndex, vector<uint64_t> Nvec) {
+  uint32_t num = Nvec.size();
+  uint64_t product = 1;
+
+  for (uint32_t i = 0; i < num; i++) {
+    product *= Nvec[i];
+  }
+
+  uint64_t j = desiredIndex;
+  vector<uint64_t> result;
+
+  for (uint32_t i = 0; i < num; i++) {
+
+    product /= Nvec[i];
+    uint64_t ji = j / product;
+
+    result.push_back(ji);
+    j -= ji * product;
+  }
+
+  return result;
+}
+
+uint64_t invert_mod(uint64_t m, const seal::Modulus &mod) {
+  if (mod.uint64_count() > 1) {
+    cout << "Mod too big to invert";
+  }
+  uint64_t inverse = 0;
+  if (!seal::util::try_invert_uint_mod(m, mod.value(), inverse)) {
+    cout << "Could not invert value";
+  }
+  return inverse;
+}
+
+uint32_t compute_expansion_ratio(EncryptionParameters params) {
+  uint32_t expansion_ratio = 0;
+  uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
+  for (size_t i = 0; i < params.coeff_modulus().size(); ++i) {
+    double coeff_bit_size = log2(params.coeff_modulus()[i].value());
+    expansion_ratio += ceil(coeff_bit_size / pt_bits_per_coeff);
+  }
+  return expansion_ratio;
+}
+
+vector<Plaintext> decompose_to_plaintexts(EncryptionParameters params,
+                                          const Ciphertext &ct) {
+  const uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
+  const auto coeff_count = params.poly_modulus_degree();
+  const auto coeff_mod_count = params.coeff_modulus().size();
+  const uint64_t pt_bitmask = (1 << pt_bits_per_coeff) - 1;
+
+  vector<Plaintext> result(compute_expansion_ratio(params) * ct.size());
+  auto pt_iter = result.begin();
+  for (size_t poly_index = 0; poly_index < ct.size(); ++poly_index) {
+    for (size_t coeff_mod_index = 0; coeff_mod_index < coeff_mod_count;
+         ++coeff_mod_index) {
+      const double coeff_bit_size =
+          log2(params.coeff_modulus()[coeff_mod_index].value());
+      const size_t local_expansion_ratio =
+          ceil(coeff_bit_size / pt_bits_per_coeff);
+      size_t shift = 0;
+      for (size_t i = 0; i < local_expansion_ratio; ++i) {
+        pt_iter->resize(coeff_count);
+        for (size_t c = 0; c < coeff_count; ++c) {
+          (*pt_iter)[c] =
+              (ct.data(poly_index)[coeff_mod_index * coeff_count + c] >>
+               shift) &
+              pt_bitmask;
+        }
+        ++pt_iter;
+        shift += pt_bits_per_coeff;
+      }
+    }
+  }
+  return result;
+}
+
+void compose_to_ciphertext(EncryptionParameters params,
+                           vector<Plaintext>::const_iterator pt_iter,
+                           const size_t ct_poly_count, Ciphertext &ct) {
+  const uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
+  const auto coeff_count = params.poly_modulus_degree();
+  const auto coeff_mod_count = params.coeff_modulus().size();
+
+  ct.resize(ct_poly_count);
+  for (size_t poly_index = 0; poly_index < ct_poly_count; ++poly_index) {
+    for (size_t coeff_mod_index = 0; coeff_mod_index < coeff_mod_count;
+         ++coeff_mod_index) {
+      const double coeff_bit_size =
+          log2(params.coeff_modulus()[coeff_mod_index].value());
+      const size_t local_expansion_ratio =
+          ceil(coeff_bit_size / pt_bits_per_coeff);
+      size_t shift = 0;
+      for (size_t i = 0; i < local_expansion_ratio; ++i) {
+        for (size_t c = 0; c < pt_iter->coeff_count(); ++c) {
+          if (shift == 0) {
+            ct.data(poly_index)[coeff_mod_index * coeff_count + c] =
+                (*pt_iter)[c];
+          } else {
+            ct.data(poly_index)[coeff_mod_index * coeff_count + c] +=
+                ((*pt_iter)[c] << shift);
+          }
+        }
+        ++pt_iter;
+        shift += pt_bits_per_coeff;
+      }
+    }
+  }
+}
+
+void compose_to_ciphertext(EncryptionParameters params,
+                           const vector<Plaintext> &pts, Ciphertext &ct) {
+  return compose_to_ciphertext(
+      params, pts.begin(), pts.size() / compute_expansion_ratio(params), ct);
+}
+
+PirQuery deserialize_query(uint32_t d, uint32_t count, string s,
+                           uint32_t len_ciphertext,
+                           shared_ptr<SEALContext> context) {
+  vector<vector<Ciphertext>> q;
+  std::istringstream input(s);
+
+  for (uint32_t i = 0; i < d; i++) {
+    vector<Ciphertext> cs;
+    for (uint32_t i = 0; i < count; i++) {
+      Ciphertext c;
+      c.load(*context, input);
+      cs.push_back(c);
+    }
+    q.push_back(cs);
+  }
+  return q;
+}
+
+string serialize_galoiskeys(Serializable<GaloisKeys> g) {
+  std::ostringstream output;
+  g.save(output);
+  return output.str();
+}
+
+GaloisKeys *deserialize_galoiskeys(string s, shared_ptr<SEALContext> context) {
+  GaloisKeys *g = new GaloisKeys();
+  std::istringstream input(s);
+  g->load(*context, input);
+  return g;
+}

+ 100 - 0
src/pir.hpp

@@ -0,0 +1,100 @@
+#pragma once
+
+#include "seal/seal.h"
+#include "seal/util/polyarithsmallmod.h"
+#include <cassert>
+#include <cmath>
+#include <string>
+#include <vector>
+
+typedef std::vector<seal::Plaintext> Database;
+typedef std::vector<std::vector<seal::Ciphertext>> PirQuery;
+typedef std::vector<seal::Ciphertext> PirReply;
+
+struct PirParams {
+  bool enable_symmetric;
+  bool enable_batching;
+  bool enable_mswitching;
+  std::uint64_t ele_num;
+  std::uint64_t ele_size;
+  std::uint64_t elements_per_plaintext;
+  std::uint64_t num_of_plaintexts; // number of plaintexts in database
+  std::uint32_t d;                 // number of dimensions for the database
+  std::uint32_t expansion_ratio;   // ratio of ciphertext to plaintext
+  std::vector<std::uint64_t> nvec; // size of each of the d dimensions
+  std::uint32_t slot_count;
+};
+
+void gen_encryption_params(std::uint32_t N,    // degree of polynomial
+                           std::uint32_t logt, // bits of plaintext coefficient
+                           seal::EncryptionParameters &enc_params);
+
+void gen_pir_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
+                    const seal::EncryptionParameters &enc_params,
+                    PirParams &pir_params, bool enable_symmetric = false,
+                    bool enable_batching = true, bool enable_mswitching = true);
+
+void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
+                uint32_t d, seal::EncryptionParameters &params,
+                PirParams &pir_params);
+
+void verify_encryption_params(const seal::EncryptionParameters &enc_params);
+
+void print_pir_params(const PirParams &pir_params);
+void print_seal_params(const seal::EncryptionParameters &enc_params);
+
+// returns the number of plaintexts that the database can hold
+std::uint64_t plaintexts_per_db(std::uint32_t logt, std::uint64_t N,
+                                std::uint64_t ele_num, std::uint64_t ele_size);
+
+// returns the number of elements that a single FV plaintext can hold
+std::uint64_t elements_per_ptxt(std::uint32_t logt, std::uint64_t N,
+                                std::uint64_t ele_size);
+
+// returns the number of coefficients needed to store one element
+std::uint64_t coefficients_per_element(std::uint32_t logt,
+                                       std::uint64_t ele_size);
+
+// Converts an array of bytes to a vector of coefficients, each of which is less
+// than the plaintext modulus
+std::vector<std::uint64_t> bytes_to_coeffs(std::uint32_t limit,
+                                           const std::uint8_t *bytes,
+                                           std::uint64_t size);
+
+// Converts an array of coefficients into an array of bytes
+void coeffs_to_bytes(std::uint32_t limit,
+                     const std::vector<std::uint64_t> &coeffs,
+                     std::uint8_t *output, std::uint32_t size_out,
+                     std::uint32_t ele_size);
+
+// Takes a vector of coefficients and returns the corresponding FV plaintext
+void vector_to_plaintext(const std::vector<std::uint64_t> &coeffs,
+                         seal::Plaintext &plain);
+
+// Since the database has d dimensions, and an item is a particular cell
+// in the d-dimensional hypercube, this function computes the corresponding
+// index for each of the d dimensions
+std::vector<std::uint64_t> compute_indices(std::uint64_t desiredIndex,
+                                           std::vector<std::uint64_t> nvec);
+
+uint64_t invert_mod(uint64_t m, const seal::Modulus &mod);
+
+uint32_t compute_expansion_ratio(seal::EncryptionParameters params);
+std::vector<seal::Plaintext>
+decompose_to_plaintexts(seal::EncryptionParameters params,
+                        const seal::Ciphertext &ct);
+
+// We need the returned ciphertext to be initialized by Context so the caller
+// will pass it in
+void compose_to_ciphertext(seal::EncryptionParameters params,
+                           const std::vector<seal::Plaintext> &pts,
+                           seal::Ciphertext &ct);
+void compose_to_ciphertext(seal::EncryptionParameters params,
+                           std::vector<seal::Plaintext>::const_iterator pt_iter,
+                           seal::Ciphertext &ct);
+
+// Serialize and deserialize galois keys to send them over the network
+std::string serialize_galoiskeys(seal::Serializable<seal::GaloisKeys> g);
+seal::GaloisKeys *
+deserialize_galoiskeys(std::string s,
+                       std::shared_ptr<seal::SEALContext> context);

+ 289 - 0
src/pir_client.cpp

@@ -0,0 +1,289 @@
+#include "pir_client.hpp"
+
+using namespace std;
+using namespace seal;
+using namespace seal::util;
+
+PIRClient::PIRClient(const EncryptionParameters &enc_params,
+                     const PirParams &pir_params)
+    : enc_params_(enc_params), pir_params_(pir_params) {
+
+  context_ = make_shared<SEALContext>(enc_params, true);
+
+  keygen_ = make_unique<KeyGenerator>(*context_);
+
+  PublicKey public_key;
+  keygen_->create_public_key(public_key);
+  SecretKey secret_key = keygen_->secret_key();
+
+  if (pir_params_.enable_symmetric) {
+    encryptor_ = make_unique<Encryptor>(*context_, secret_key);
+  } else {
+    encryptor_ = make_unique<Encryptor>(*context_, public_key);
+  }
+
+  decryptor_ = make_unique<Decryptor>(*context_, secret_key);
+  evaluator_ = make_unique<Evaluator>(*context_);
+  encoder_ = make_unique<BatchEncoder>(*context_);
+}
+
+int PIRClient::generate_serialized_query(uint64_t desiredIndex,
+                                         std::stringstream &stream) {
+
+  int N = enc_params_.poly_modulus_degree();
+  int output_size = 0;
+  indices_ = compute_indices(desiredIndex, pir_params_.nvec);
+  Plaintext pt(enc_params_.poly_modulus_degree());
+
+  for (uint32_t i = 0; i < indices_.size(); i++) {
+    uint32_t num_ptxts = ceil((pir_params_.nvec[i] + 0.0) / N);
+    // initialize result.
+    cout << "Client: index " << i + 1 << "/ " << indices_.size() << " = "
+         << indices_[i] << endl;
+    cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
+
+    for (uint32_t j = 0; j < num_ptxts; j++) {
+      pt.set_zero();
+      if (indices_[i] >= N * j && indices_[i] <= N * (j + 1)) {
+        uint64_t real_index = indices_[i] - N * j;
+        uint64_t n_i = pir_params_.nvec[i];
+        uint64_t total = N;
+        if (j == num_ptxts - 1) {
+          total = n_i % N;
+        }
+        uint64_t log_total = ceil(log2(total));
+
+        cout << "Client: Inverting " << pow(2, log_total) << endl;
+        pt[real_index] =
+            invert_mod(pow(2, log_total), enc_params_.plain_modulus());
+      }
+
+      if (pir_params_.enable_symmetric) {
+        output_size += encryptor_->encrypt_symmetric(pt).save(stream);
+      } else {
+        output_size += encryptor_->encrypt(pt).save(stream);
+      }
+    }
+  }
+
+  return output_size;
+}
+
+PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
+
+  indices_ = compute_indices(desiredIndex, pir_params_.nvec);
+
+  PirQuery result(pir_params_.d);
+  int N = enc_params_.poly_modulus_degree();
+
+  Plaintext pt(enc_params_.poly_modulus_degree());
+  for (uint32_t i = 0; i < indices_.size(); i++) {
+    uint32_t num_ptxts = ceil((pir_params_.nvec[i] + 0.0) / N);
+    // initialize result.
+    cout << "Client: index " << i + 1 << "/ " << indices_.size() << " = "
+         << indices_[i] << endl;
+    cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
+
+    for (uint32_t j = 0; j < num_ptxts; j++) {
+      pt.set_zero();
+      if (indices_[i] >= N * j && indices_[i] <= N * (j + 1)) {
+        uint64_t real_index = indices_[i] - N * j;
+        uint64_t n_i = pir_params_.nvec[i];
+        uint64_t total = N;
+        if (j == num_ptxts - 1) {
+          total = n_i % N;
+        }
+        uint64_t log_total = ceil(log2(total));
+
+        cout << "Client: Inverting " << pow(2, log_total) << endl;
+        pt[real_index] =
+            invert_mod(pow(2, log_total), enc_params_.plain_modulus());
+      }
+      Ciphertext dest;
+      if (pir_params_.enable_symmetric) {
+        encryptor_->encrypt_symmetric(pt, dest);
+      } else {
+        encryptor_->encrypt(pt, dest);
+      }
+      result[i].push_back(dest);
+    }
+  }
+
+  return result;
+}
+
+uint64_t PIRClient::get_fv_index(uint64_t element_index) {
+  return static_cast<uint64_t>(element_index /
+                               pir_params_.elements_per_plaintext);
+}
+
+uint64_t PIRClient::get_fv_offset(uint64_t element_index) {
+  return element_index % pir_params_.elements_per_plaintext;
+}
+
+Plaintext PIRClient::decrypt(Ciphertext ct) {
+  Plaintext pt;
+  decryptor_->decrypt(ct, pt);
+  return pt;
+}
+
+vector<uint8_t> PIRClient::decode_reply(PirReply &reply, uint64_t offset) {
+  Plaintext result = decode_reply(reply);
+  return extract_bytes(result, offset);
+}
+
+vector<uint64_t> PIRClient::extract_coeffs(Plaintext pt) {
+  vector<uint64_t> coeffs;
+  encoder_->decode(pt, coeffs);
+  return coeffs;
+}
+
+std::vector<uint64_t> PIRClient::extract_coeffs(seal::Plaintext pt,
+                                                uint64_t offset) {
+  vector<uint64_t> coeffs;
+  encoder_->decode(pt, coeffs);
+
+  uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
+
+  uint64_t coeffs_per_element =
+      coefficients_per_element(logt, pir_params_.ele_size);
+
+  return std::vector<uint64_t>(coeffs.begin() + offset * coeffs_per_element,
+                               coeffs.begin() +
+                                   (offset + 1) * coeffs_per_element);
+}
+
+std::vector<uint8_t> PIRClient::extract_bytes(seal::Plaintext pt,
+                                              uint64_t offset) {
+  uint32_t N = enc_params_.poly_modulus_degree();
+  uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
+  uint32_t bytes_per_ptxt =
+      pir_params_.elements_per_plaintext * pir_params_.ele_size;
+
+  // Convert from FV plaintext (polynomial) to database element at the client
+  vector<uint8_t> elems(bytes_per_ptxt);
+  vector<uint64_t> coeffs;
+  encoder_->decode(pt, coeffs);
+  coeffs_to_bytes(logt, coeffs, elems.data(), bytes_per_ptxt,
+                  pir_params_.ele_size);
+  return std::vector<uint8_t>(elems.begin() + offset * pir_params_.ele_size,
+                              elems.begin() +
+                                  (offset + 1) * pir_params_.ele_size);
+}
+
+Plaintext PIRClient::decode_reply(PirReply &reply) {
+  EncryptionParameters parms;
+  parms_id_type parms_id;
+  if (pir_params_.enable_mswitching) {
+    parms = context_->last_context_data()->parms();
+    parms_id = context_->last_parms_id();
+  } else {
+    parms = context_->first_context_data()->parms();
+    parms_id = context_->first_parms_id();
+  }
+  uint32_t exp_ratio = compute_expansion_ratio(parms);
+  uint32_t recursion_level = pir_params_.d;
+
+  vector<Ciphertext> temp = reply;
+  uint32_t ciphertext_size = temp[0].size();
+
+  uint64_t t = enc_params_.plain_modulus().value();
+
+  for (uint32_t i = 0; i < recursion_level; i++) {
+    cout << "Client: " << i + 1 << "/ " << recursion_level
+         << "-th decryption layer started." << endl;
+    vector<Ciphertext> newtemp;
+    vector<Plaintext> tempplain;
+
+    for (uint32_t j = 0; j < temp.size(); j++) {
+      Plaintext ptxt;
+      decryptor_->decrypt(temp[j], ptxt);
+#ifdef DEBUG
+      cout << "Client: reply noise budget = "
+           << decryptor_->invariant_noise_budget(temp[j]) << endl;
+#endif
+
+      // cout << "decoded (and scaled) plaintext = " << ptxt.to_string() <<
+      // endl;
+      tempplain.push_back(ptxt);
+
+#ifdef DEBUG
+      cout << "recursion level : " << i << " noise budget :  ";
+      cout << decryptor_->invariant_noise_budget(temp[j]) << endl;
+#endif
+
+      if ((j + 1) % (exp_ratio * ciphertext_size) == 0 && j > 0) {
+        // Combine into one ciphertext.
+        Ciphertext combined(*context_, parms_id);
+        compose_to_ciphertext(parms, tempplain, combined);
+        newtemp.push_back(combined);
+        tempplain.clear();
+        // cout << "Client: const term of ciphertext = " << combined[0] << endl;
+      }
+    }
+    cout << "Client: done." << endl;
+    cout << endl;
+    if (i == recursion_level - 1) {
+      assert(temp.size() == 1);
+      return tempplain[0];
+    } else {
+      tempplain.clear();
+      temp = newtemp;
+    }
+  }
+
+  // This should never be called
+  assert(0);
+  Plaintext fail;
+  return fail;
+}
+
+GaloisKeys PIRClient::generate_galois_keys() {
+  // Generate the Galois keys needed for coeff_select.
+  vector<uint32_t> galois_elts;
+  int N = enc_params_.poly_modulus_degree();
+  int logN = get_power_of_two(N);
+
+  // cout << "printing galois elements...";
+  for (int i = 0; i < logN; i++) {
+    galois_elts.push_back((N + exponentiate_uint(2, i)) /
+                          exponentiate_uint(2, i));
+    //#ifdef DEBUG
+    // cout << galois_elts.back() << ", ";
+    //#endif
+  }
+  GaloisKeys gal_keys;
+  keygen_->create_galois_keys(galois_elts, gal_keys);
+  return gal_keys;
+}
+
+Plaintext PIRClient::replace_element(Plaintext pt, vector<uint64_t> new_element,
+                                     uint64_t offset) {
+  vector<uint64_t> coeffs = extract_coeffs(pt);
+
+  uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
+  uint64_t coeffs_per_element =
+      coefficients_per_element(logt, pir_params_.ele_size);
+
+  assert(new_element.size() == coeffs_per_element);
+
+  for (uint64_t i = 0; i < coeffs_per_element; i++) {
+    coeffs[i + offset * coeffs_per_element] = new_element[i];
+  }
+
+  Plaintext new_pt;
+
+  encoder_->encode(coeffs, new_pt);
+  return new_pt;
+}
+
+Ciphertext PIRClient::get_one() {
+  Plaintext pt("1");
+  Ciphertext ct;
+  if (pir_params_.enable_symmetric) {
+    encryptor_->encrypt_symmetric(pt, ct);
+  } else {
+    encryptor_->encrypt(pt, ct);
+  }
+  return ct;
+}

+ 58 - 0
src/pir_client.hpp

@@ -0,0 +1,58 @@
+#pragma once
+
+#include "pir.hpp"
+#include <memory>
+#include <vector>
+
+using namespace std;
+
+class PIRClient {
+public:
+  PIRClient(const seal::EncryptionParameters &encparms,
+            const PirParams &pirparams);
+
+  PirQuery generate_query(std::uint64_t desiredIndex);
+  // Serializes the query into the provided stream and returns number of bytes
+  // written
+  int generate_serialized_query(std::uint64_t desiredIndex,
+                                std::stringstream &stream);
+  seal::Plaintext decode_reply(PirReply &reply);
+
+  std::vector<uint64_t> extract_coeffs(seal::Plaintext pt);
+  std::vector<uint64_t> extract_coeffs(seal::Plaintext pt,
+                                       std::uint64_t offset);
+  std::vector<uint8_t> extract_bytes(seal::Plaintext pt, std::uint64_t offset);
+
+  std::vector<uint8_t> decode_reply(PirReply &reply, uint64_t offset);
+
+  seal::Plaintext decrypt(seal::Ciphertext ct);
+
+  seal::GaloisKeys generate_galois_keys();
+
+  // Index and offset of an element in an FV plaintext
+  uint64_t get_fv_index(uint64_t element_index);
+  uint64_t get_fv_offset(uint64_t element_index);
+
+  // Only used for simple_query
+  seal::Ciphertext get_one();
+
+  seal::Plaintext replace_element(seal::Plaintext pt,
+                                  std::vector<std::uint64_t> new_element,
+                                  std::uint64_t offset);
+
+private:
+  seal::EncryptionParameters enc_params_;
+  PirParams pir_params_;
+
+  std::unique_ptr<seal::Encryptor> encryptor_;
+  std::unique_ptr<seal::Decryptor> decryptor_;
+  std::unique_ptr<seal::Evaluator> evaluator_;
+  std::unique_ptr<seal::KeyGenerator> keygen_;
+  std::unique_ptr<seal::BatchEncoder> encoder_;
+  std::shared_ptr<seal::SEALContext> context_;
+
+  vector<uint64_t> indices_; // the indices for retrieval.
+  vector<uint64_t> inverse_scales_;
+
+  friend class PIRServer;
+};

+ 442 - 0
src/pir_server.cpp

@@ -0,0 +1,442 @@
+#include "pir_server.hpp"
+#include "pir_client.hpp"
+
+using namespace std;
+using namespace seal;
+using namespace seal::util;
+
+PIRServer::PIRServer(const EncryptionParameters &enc_params,
+                     const PirParams &pir_params)
+    : enc_params_(enc_params), pir_params_(pir_params),
+      is_db_preprocessed_(false) {
+  context_ = make_shared<SEALContext>(enc_params, true);
+  evaluator_ = make_unique<Evaluator>(*context_);
+  encoder_ = make_unique<BatchEncoder>(*context_);
+}
+
+void PIRServer::preprocess_database() {
+  if (!is_db_preprocessed_) {
+
+    for (uint32_t i = 0; i < db_->size(); i++) {
+      evaluator_->transform_to_ntt_inplace(db_->operator[](i),
+                                           context_->first_parms_id());
+    }
+
+    is_db_preprocessed_ = true;
+  }
+}
+
+// Server takes over ownership of db and will free it when it exits
+void PIRServer::set_database(unique_ptr<vector<Plaintext>> &&db) {
+  if (!db) {
+    throw invalid_argument("db cannot be null");
+  }
+
+  db_ = move(db);
+  is_db_preprocessed_ = false;
+}
+
+void PIRServer::set_database(const std::unique_ptr<const uint8_t[]> &bytes,
+                             uint64_t ele_num, uint64_t ele_size) {
+
+  uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
+  uint32_t N = enc_params_.poly_modulus_degree();
+
+  // number of FV plaintexts needed to represent all elements
+  uint64_t num_of_plaintexts = pir_params_.num_of_plaintexts;
+
+  // number of FV plaintexts needed to create the d-dimensional matrix
+  uint64_t prod = 1;
+  for (uint32_t i = 0; i < pir_params_.nvec.size(); i++) {
+    prod *= pir_params_.nvec[i];
+  }
+  uint64_t matrix_plaintexts = prod;
+
+  assert(num_of_plaintexts <= matrix_plaintexts);
+
+  auto result = make_unique<vector<Plaintext>>();
+  result->reserve(matrix_plaintexts);
+
+  uint64_t ele_per_ptxt = pir_params_.elements_per_plaintext;
+  uint64_t bytes_per_ptxt = ele_per_ptxt * ele_size;
+
+  uint64_t db_size = ele_num * ele_size;
+
+  uint64_t coeff_per_ptxt =
+      ele_per_ptxt * coefficients_per_element(logt, ele_size);
+  assert(coeff_per_ptxt <= N);
+
+  cout << "Elements per plaintext: " << ele_per_ptxt << endl;
+  cout << "Coeff per ptxt: " << coeff_per_ptxt << endl;
+  cout << "Bytes per plaintext: " << bytes_per_ptxt << endl;
+
+  uint32_t offset = 0;
+
+  for (uint64_t i = 0; i < num_of_plaintexts; i++) {
+
+    uint64_t process_bytes = 0;
+
+    if (db_size <= offset) {
+      break;
+    } else if (db_size < offset + bytes_per_ptxt) {
+      process_bytes = db_size - offset;
+    } else {
+      process_bytes = bytes_per_ptxt;
+    }
+    assert(process_bytes % ele_size == 0);
+    uint64_t ele_in_chunk = process_bytes / ele_size;
+
+    // Get the coefficients of the elements that will be packed in plaintext i
+    vector<uint64_t> coefficients(coeff_per_ptxt);
+    for (uint64_t ele = 0; ele < ele_in_chunk; ele++) {
+      vector<uint64_t> element_coeffs = bytes_to_coeffs(
+          logt, bytes.get() + offset + (ele_size * ele), ele_size);
+      std::copy(element_coeffs.begin(), element_coeffs.end(),
+                coefficients.begin() +
+                    (coefficients_per_element(logt, ele_size) * ele));
+    }
+
+    offset += process_bytes;
+
+    uint64_t used = coefficients.size();
+
+    assert(used <= coeff_per_ptxt);
+
+    // Pad the rest with 1s
+    for (uint64_t j = 0; j < (pir_params_.slot_count - used); j++) {
+      coefficients.push_back(1);
+    }
+
+    Plaintext plain;
+    encoder_->encode(coefficients, plain);
+    // cout << i << "-th encoded plaintext = " << plain.to_string() << endl;
+    result->push_back(move(plain));
+  }
+
+  // Add padding to make database a matrix
+  uint64_t current_plaintexts = result->size();
+  assert(current_plaintexts <= num_of_plaintexts);
+
+#ifdef DEBUG
+  cout << "adding: " << matrix_plaintexts - current_plaintexts
+       << " FV plaintexts of padding (equivalent to: "
+       << (matrix_plaintexts - current_plaintexts) *
+              elements_per_ptxt(logtp, N, ele_size)
+       << " elements)" << endl;
+#endif
+
+  vector<uint64_t> padding(N, 1);
+
+  for (uint64_t i = 0; i < (matrix_plaintexts - current_plaintexts); i++) {
+    Plaintext plain;
+    vector_to_plaintext(padding, plain);
+    result->push_back(plain);
+  }
+
+  set_database(move(result));
+}
+
+void PIRServer::set_galois_key(uint32_t client_id, seal::GaloisKeys galkey) {
+  galoisKeys_[client_id] = galkey;
+}
+
+PirQuery PIRServer::deserialize_query(stringstream &stream) {
+  PirQuery q;
+
+  for (uint32_t i = 0; i < pir_params_.d; i++) {
+    // number of ciphertexts needed to encode the index for dimension i
+    // keeping into account that each ciphertext can encode up to
+    // poly_modulus_degree indexes In most cases this is usually 1.
+    uint32_t ctx_per_dimension =
+        ceil((pir_params_.nvec[i] + 0.0) / enc_params_.poly_modulus_degree());
+
+    vector<Ciphertext> cs;
+    for (uint32_t j = 0; j < ctx_per_dimension; j++) {
+      Ciphertext c;
+      c.load(*context_, stream);
+      cs.push_back(c);
+    }
+
+    q.push_back(cs);
+  }
+
+  return q;
+}
+
+int PIRServer::serialize_reply(PirReply &reply, stringstream &stream) {
+  int output_size = 0;
+  for (int i = 0; i < reply.size(); i++) {
+    evaluator_->mod_switch_to_inplace(reply[i], context_->last_parms_id());
+    output_size += reply[i].save(stream);
+  }
+  return output_size;
+}
+
+PirReply PIRServer::generate_reply(PirQuery &query, uint32_t client_id) {
+
+  vector<uint64_t> nvec = pir_params_.nvec;
+  uint64_t product = 1;
+
+  for (uint32_t i = 0; i < nvec.size(); i++) {
+    product *= nvec[i];
+  }
+
+  auto coeff_count = enc_params_.poly_modulus_degree();
+
+  vector<Plaintext> *cur = db_.get();
+  vector<Plaintext> intermediate_plain; // decompose....
+
+  auto pool = MemoryManager::GetPool();
+
+  int N = enc_params_.poly_modulus_degree();
+
+  int logt = floor(log2(enc_params_.plain_modulus().value()));
+
+  for (uint32_t i = 0; i < nvec.size(); i++) {
+    cout << "Server: " << i + 1 << "-th recursion level started " << endl;
+
+    vector<Ciphertext> expanded_query;
+
+    uint64_t n_i = nvec[i];
+    cout << "Server: n_i = " << n_i << endl;
+    cout << "Server: expanding " << query[i].size() << " query ctxts" << endl;
+    for (uint32_t j = 0; j < query[i].size(); j++) {
+      uint64_t total = N;
+      if (j == query[i].size() - 1) {
+        total = n_i % N;
+      }
+      cout << "-- expanding one query ctxt into " << total << " ctxts " << endl;
+      vector<Ciphertext> expanded_query_part =
+          expand_query(query[i][j], total, client_id);
+      expanded_query.insert(
+          expanded_query.end(),
+          std::make_move_iterator(expanded_query_part.begin()),
+          std::make_move_iterator(expanded_query_part.end()));
+      expanded_query_part.clear();
+    }
+    cout << "Server: expansion done " << endl;
+    if (expanded_query.size() != n_i) {
+      cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i
+           << endl;
+    }
+
+    // Transform expanded query to NTT, and ...
+    for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
+      evaluator_->transform_to_ntt_inplace(expanded_query[jj]);
+    }
+
+    // Transform plaintext to NTT. If database is pre-processed, can skip
+    if ((!is_db_preprocessed_) || i > 0) {
+      for (uint32_t jj = 0; jj < cur->size(); jj++) {
+        evaluator_->transform_to_ntt_inplace((*cur)[jj],
+                                             context_->first_parms_id());
+      }
+    }
+
+    for (uint64_t k = 0; k < product; k++) {
+      if ((*cur)[k].is_zero()) {
+        cout << k + 1 << "/ " << product << "-th ptxt = 0 " << endl;
+      }
+    }
+
+    product /= n_i;
+
+    vector<Ciphertext> intermediateCtxts(product);
+    Ciphertext temp;
+
+    for (uint64_t k = 0; k < product; k++) {
+
+      evaluator_->multiply_plain(expanded_query[0], (*cur)[k],
+                                 intermediateCtxts[k]);
+
+      for (uint64_t j = 1; j < n_i; j++) {
+        evaluator_->multiply_plain(expanded_query[j], (*cur)[k + j * product],
+                                   temp);
+        evaluator_->add_inplace(intermediateCtxts[k],
+                                temp); // Adds to first component.
+      }
+    }
+
+    for (uint32_t jj = 0; jj < intermediateCtxts.size(); jj++) {
+      evaluator_->transform_from_ntt_inplace(intermediateCtxts[jj]);
+      // print intermediate ctxts?
+      // cout << "const term of ctxt " << jj << " = " <<
+      // intermediateCtxts[jj][0] << endl;
+    }
+
+    if (i == nvec.size() - 1) {
+      return intermediateCtxts;
+    } else {
+      intermediate_plain.clear();
+      intermediate_plain.reserve(pir_params_.expansion_ratio * product);
+      cur = &intermediate_plain;
+
+      for (uint64_t rr = 0; rr < product; rr++) {
+        EncryptionParameters parms;
+        if (pir_params_.enable_mswitching) {
+          evaluator_->mod_switch_to_inplace(intermediateCtxts[rr],
+                                            context_->last_parms_id());
+          parms = context_->last_context_data()->parms();
+        } else {
+          parms = context_->first_context_data()->parms();
+        }
+
+        vector<Plaintext> plains =
+            decompose_to_plaintexts(parms, intermediateCtxts[rr]);
+
+        for (uint32_t jj = 0; jj < plains.size(); jj++) {
+          intermediate_plain.emplace_back(plains[jj]);
+        }
+      }
+      product = intermediate_plain.size(); // multiply by expansion rate.
+    }
+    cout << "Server: " << i + 1 << "-th recursion level finished " << endl;
+    cout << endl;
+  }
+  cout << "reply generated!  " << endl;
+  // This should never get here
+  assert(0);
+  vector<Ciphertext> fail(1);
+  return fail;
+}
+
+inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted,
+                                                  uint32_t m,
+                                                  uint32_t client_id) {
+
+  GaloisKeys &galkey = galoisKeys_[client_id];
+
+  // Assume that m is a power of 2. If not, round it to the next power of 2.
+  uint32_t logm = ceil(log2(m));
+  Plaintext two("2");
+
+  vector<int> galois_elts;
+  auto n = enc_params_.poly_modulus_degree();
+  if (logm > ceil(log2(n))) {
+    throw logic_error("m > n is not allowed.");
+  }
+  for (int i = 0; i < ceil(log2(n)); i++) {
+    galois_elts.push_back((n + exponentiate_uint(2, i)) /
+                          exponentiate_uint(2, i));
+  }
+
+  vector<Ciphertext> temp;
+  temp.push_back(encrypted);
+  Ciphertext tempctxt;
+  Ciphertext tempctxt_rotated;
+  Ciphertext tempctxt_shifted;
+  Ciphertext tempctxt_rotatedshifted;
+
+  for (uint32_t i = 0; i < logm - 1; i++) {
+    vector<Ciphertext> newtemp(temp.size() << 1);
+    // temp[a] = (j0 = a (mod 2**i) ? ) : Enc(x^{j0 - a}) else Enc(0).  With
+    // some scaling....
+    int index_raw = (n << 1) - (1 << i);
+    int index = (index_raw * galois_elts[i]) % (n << 1);
+
+    for (uint32_t a = 0; a < temp.size(); a++) {
+
+      evaluator_->apply_galois(temp[a], galois_elts[i], galkey,
+                               tempctxt_rotated);
+
+      // cout << "rotate " <<
+      // client.decryptor_->invariant_noise_budget(tempctxt_rotated) << ", ";
+
+      evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
+      multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
+
+      // cout << "mul by x^pow: " <<
+      // client.decryptor_->invariant_noise_budget(tempctxt_shifted) << ", ";
+
+      multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
+
+      // cout << "mul by x^pow: " <<
+      // client.decryptor_->invariant_noise_budget(tempctxt_rotatedshifted) <<
+      // ", ";
+
+      // Enc(2^i x^j) if j = 0 (mod 2**i).
+      evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted,
+                      newtemp[a + temp.size()]);
+    }
+    temp = newtemp;
+    /*
+    cout << "end: ";
+    for (int h = 0; h < temp.size();h++){
+        cout << client.decryptor_->invariant_noise_budget(temp[h]) << ", ";
+    }
+    cout << endl;
+    */
+  }
+  // Last step of the loop
+  vector<Ciphertext> newtemp(temp.size() << 1);
+  int index_raw = (n << 1) - (1 << (logm - 1));
+  int index = (index_raw * galois_elts[logm - 1]) % (n << 1);
+  for (uint32_t a = 0; a < temp.size(); a++) {
+    if (a >= (m - (1 << (logm - 1)))) { // corner case.
+      evaluator_->multiply_plain(temp[a], two,
+                                 newtemp[a]); // plain multiplication by 2.
+      // cout << client.decryptor_->invariant_noise_budget(newtemp[a]) << ", ";
+    } else {
+      evaluator_->apply_galois(temp[a], galois_elts[logm - 1], galkey,
+                               tempctxt_rotated);
+      evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
+      multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
+      multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
+      evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted,
+                      newtemp[a + temp.size()]);
+    }
+  }
+
+  vector<Ciphertext>::const_iterator first = newtemp.begin();
+  vector<Ciphertext>::const_iterator last = newtemp.begin() + m;
+  vector<Ciphertext> newVec(first, last);
+
+  return newVec;
+}
+
+inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted,
+                                           Ciphertext &destination,
+                                           uint32_t index) {
+
+  auto coeff_mod_count = enc_params_.coeff_modulus().size() - 1;
+  auto coeff_count = enc_params_.poly_modulus_degree();
+  auto encrypted_count = encrypted.size();
+
+  // cout << "coeff mod count for power of X = " << coeff_mod_count << endl;
+  // cout << "coeff count for power of X = " << coeff_count << endl;
+
+  // First copy over.
+  destination = encrypted;
+
+  // Prepare for destination
+  // Multiply X^index for each ciphertext polynomial
+  for (int i = 0; i < encrypted_count; i++) {
+    for (int j = 0; j < coeff_mod_count; j++) {
+      negacyclic_shift_poly_coeffmod(encrypted.data(i) + (j * coeff_count),
+                                     coeff_count, index,
+                                     enc_params_.coeff_modulus()[j],
+                                     destination.data(i) + (j * coeff_count));
+    }
+  }
+}
+
+void PIRServer::simple_set(uint64_t index, Plaintext pt) {
+  if (is_db_preprocessed_) {
+    evaluator_->transform_to_ntt_inplace(pt, context_->first_parms_id());
+  }
+  db_->operator[](index) = pt;
+}
+
+Ciphertext PIRServer::simple_query(uint64_t index) {
+  // There is no transform_from_ntt that takes a plaintext
+  Ciphertext ct;
+  Plaintext pt = db_->operator[](index);
+  evaluator_->multiply_plain(one_, pt, ct);
+  evaluator_->transform_from_ntt_inplace(ct);
+  return ct;
+}
+
+void PIRServer::set_one_ct(Ciphertext one) {
+  one_ = one;
+  evaluator_->transform_to_ntt_inplace(one_);
+}

+ 55 - 0
src/pir_server.hpp

@@ -0,0 +1,55 @@
+#pragma once
+
+#include "pir.hpp"
+#include "pir_client.hpp"
+#include <map>
+#include <memory>
+#include <vector>
+
+class PIRServer {
+public:
+  PIRServer(const seal::EncryptionParameters &enc_params,
+            const PirParams &pir_params);
+
+  // NOTE: server takes over ownership of db and frees it when it exits.
+  // Caller cannot free db
+  void set_database(std::unique_ptr<std::vector<seal::Plaintext>> &&db);
+  void set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
+                    std::uint64_t ele_num, std::uint64_t ele_size);
+  void preprocess_database();
+
+  std::vector<seal::Ciphertext> expand_query(const seal::Ciphertext &encrypted,
+                                             std::uint32_t m,
+                                             std::uint32_t client_id);
+
+  PirQuery deserialize_query(std::stringstream &stream);
+  PirReply generate_reply(PirQuery &query, std::uint32_t client_id);
+  // Serializes the reply into the provided stream and returns the number of
+  // bytes written
+  int serialize_reply(PirReply &reply, std::stringstream &stream);
+
+  void set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey);
+
+  // Below simple operations are for interacting with the database WITHOUT PIR.
+  // So they can be used to modify a particular element in the database or
+  // to query a particular element (without privacy guarantees).
+  void simple_set(std::uint64_t index, seal::Plaintext pt);
+  seal::Ciphertext simple_query(std::uint64_t index);
+  void set_one_ct(seal::Ciphertext one);
+
+private:
+  seal::EncryptionParameters enc_params_; // SEAL parameters
+  PirParams pir_params_;                  // PIR parameters
+  std::unique_ptr<Database> db_;
+  bool is_db_preprocessed_;
+  std::map<int, seal::GaloisKeys> galoisKeys_;
+  std::unique_ptr<seal::Evaluator> evaluator_;
+  std::unique_ptr<seal::BatchEncoder> encoder_;
+  std::shared_ptr<seal::SEALContext> context_;
+
+  // This is only used for simple_query
+  seal::Ciphertext one_;
+
+  void multiply_power_of_X(const seal::Ciphertext &encrypted,
+                           seal::Ciphertext &destination, std::uint32_t index);
+};

+ 25 - 0
test/CMakeLists.txt

@@ -0,0 +1,25 @@
+include_directories (${SealPIR_SOURCE_DIR}/src)
+
+add_executable(coefficient_conversion_test coefficient_conversion_test.cpp)
+target_link_libraries(coefficient_conversion_test sealpir)
+add_test(NAME coefficient_conversion_test COMMAND coefficient_conversion_test)
+
+add_executable(expand_test expand_test.cpp)
+target_link_libraries(expand_test sealpir)
+add_test(NAME expand_test COMMAND expand_test)
+
+add_executable(query_test query_test.cpp)
+target_link_libraries(query_test sealpir)
+add_test(NAME query_test COMMAND query_test)
+
+add_executable(simple_query_test simple_query_test.cpp)
+target_link_libraries(simple_query_test sealpir)
+add_test(NAME simple_query_test COMMAND simple_query_test)
+
+add_executable(replace_test replace_test.cpp)
+target_link_libraries(replace_test sealpir)
+add_test(NAME replace_test COMMAND replace_test)
+
+add_executable(decomposition_test decomposition_test.cpp)
+target_link_libraries(decomposition_test sealpir)
+add_test(NAME decomposition_test COMMAND decomposition_test)

+ 47 - 0
test/coefficient_conversion_test.cpp

@@ -0,0 +1,47 @@
+#include "pir.hpp"
+#include "pir_client.hpp"
+#include "pir_server.hpp"
+#include <bitset>
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <random>
+#include <seal/seal.h>
+
+using namespace std::chrono;
+using namespace std;
+using namespace seal;
+
+int main(int argc, char *argv[]) {
+
+  const uint32_t logt = 16;
+  const uint32_t ele_size = 3;
+  const uint32_t num_ele = 3;
+  uint8_t bytes[ele_size * num_ele] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+
+  vector<uint64_t> coeffs;
+
+  cout << "Coeffs: " << endl;
+  for (int i = 0; i < num_ele; i++) {
+    vector<uint64_t> ele_coeffs =
+        bytes_to_coeffs(logt, bytes + (i * ele_size), ele_size);
+    for (int j = 0; j < ele_coeffs.size(); j++) {
+      cout << ele_coeffs[j] << endl;
+      cout << std::bitset<logt>(ele_coeffs[j]) << endl;
+      coeffs.push_back(ele_coeffs[j]);
+    }
+  }
+
+  cout << "Num of Coeffs: " << coeffs.size() << endl;
+
+  uint8_t output[ele_size * num_ele];
+  coeffs_to_bytes(logt, coeffs, output, ele_size * num_ele, ele_size);
+
+  cout << "Bytes: " << endl;
+  for (int i = 0; i < ele_size * num_ele; i++) {
+    cout << (int)output[i] << endl;
+  }
+
+  return 0;
+}

+ 79 - 0
test/decomposition_test.cpp

@@ -0,0 +1,79 @@
+#include "pir.hpp"
+#include "pir_client.hpp"
+#include "pir_server.hpp"
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <random>
+#include <seal/seal.h>
+
+using namespace std::chrono;
+using namespace std;
+using namespace seal;
+
+int main(int argc, char *argv[]) {
+
+  uint64_t number_of_items = 2048;
+  uint64_t size_per_item = 288; // in bytes
+  uint32_t N = 8192;
+
+  // Recommended values: (logt, d) = (12, 2) or (8, 1).
+  uint32_t logt = 20;
+
+  EncryptionParameters enc_params(scheme_type::bfv);
+
+  // Generates all parameters
+
+  cout << "Main: Generating SEAL parameters" << endl;
+  gen_encryption_params(N, logt, enc_params);
+
+  cout << "Main: Verifying SEAL parameters" << endl;
+  verify_encryption_params(enc_params);
+  cout << "Main: SEAL parameters are good" << endl;
+
+  SEALContext context(enc_params, true);
+  KeyGenerator keygen(context);
+
+  SecretKey secret_key = keygen.secret_key();
+  Encryptor encryptor(context, secret_key);
+  Decryptor decryptor(context, secret_key);
+  Evaluator evaluator(context);
+  BatchEncoder encoder(context);
+  logt = floor(log2(enc_params.plain_modulus().value()));
+
+  uint32_t plain_modulus = enc_params.plain_modulus().value();
+
+  size_t slot_count = encoder.slot_count();
+
+  vector<uint64_t> coefficients(slot_count, 0ULL);
+  for (uint32_t i = 0; i < coefficients.size(); i++) {
+    coefficients[i] = rand() % plain_modulus;
+  }
+  Plaintext pt;
+  encoder.encode(coefficients, pt);
+  Ciphertext ct;
+  encryptor.encrypt_symmetric(pt, ct);
+  std::cout << "Encrypting" << std::endl;
+  auto context_data = context.last_context_data();
+  auto parms_id = context.last_parms_id();
+
+  evaluator.mod_switch_to_inplace(ct, parms_id);
+
+  EncryptionParameters params = context_data->parms();
+  std::cout << "Encoding" << std::endl;
+  vector<Plaintext> encoded = decompose_to_plaintexts(params, ct);
+  std::cout << "Expansion Factor: " << encoded.size() << std::endl;
+  std::cout << "Decoding" << std::endl;
+  Ciphertext decoded(context, parms_id);
+  compose_to_ciphertext(params, encoded, decoded);
+  std::cout << "Checking" << std::endl;
+  Plaintext pt2;
+  decryptor.decrypt(decoded, pt2);
+
+  assert(pt == pt2);
+
+  std::cout << "Worked" << std::endl;
+
+  return 0;
+}

+ 107 - 0
test/expand_test.cpp

@@ -0,0 +1,107 @@
+#include "pir.hpp"
+#include "pir_client.hpp"
+#include "pir_server.hpp"
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <random>
+#include <seal/seal.h>
+
+using namespace std::chrono;
+using namespace std;
+using namespace seal;
+
+// For this test, we need the parameters to be such that the number of
+// compressed ciphertexts needed is 1.
+int main(int argc, char *argv[]) {
+
+  uint64_t number_of_items = 2048;
+  uint64_t size_per_item = 288; // in bytes
+  uint32_t N = 4096;
+
+  // Recommended values: (logt, d) = (12, 2) or (8, 1).
+  uint32_t logt = 20;
+  uint32_t d = 1;
+
+  EncryptionParameters enc_params(scheme_type::bfv);
+  PirParams pir_params;
+
+  // Generates all parameters
+
+  cout << "Main: Generating SEAL parameters" << endl;
+  gen_encryption_params(N, logt, enc_params);
+
+  cout << "Main: Verifying SEAL parameters" << endl;
+  verify_encryption_params(enc_params);
+  cout << "Main: SEAL parameters are good" << endl;
+
+  cout << "Main: Generating PIR parameters" << endl;
+  gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
+
+  // gen_params(number_of_items, size_per_item, N, logt, d, enc_params,
+  // pir_params);
+  print_pir_params(pir_params);
+
+  // Initialize PIR Server
+  cout << "Main: Initializing server and client" << endl;
+  PIRServer server(enc_params, pir_params);
+
+  // Initialize PIR client....
+  PIRClient client(enc_params, pir_params);
+  GaloisKeys galois_keys = client.generate_galois_keys();
+
+  // Set galois key for client with id 0
+  cout << "Main: Setting Galois keys...";
+  server.set_galois_key(0, galois_keys);
+
+  random_device rd;
+  // Choose an index of an element in the DB
+  uint64_t ele_index =
+      rd() % number_of_items; // element in DB at random position
+  uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
+  uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
+  cout << "Main: element index = " << ele_index << " from [0, "
+       << number_of_items - 1 << "]" << endl;
+  cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
+
+  // Measure query generation
+  auto time_query_s = high_resolution_clock::now();
+  PirQuery query = client.generate_query(index);
+  auto time_query_e = high_resolution_clock::now();
+  auto time_query_us =
+      duration_cast<microseconds>(time_query_e - time_query_s).count();
+  cout << "Main: query generated" << endl;
+
+  // Measure query processing (including expansion)
+  auto time_server_s = high_resolution_clock::now();
+  uint64_t n_i = pir_params.nvec[0];
+  vector<Ciphertext> expanded_query = server.expand_query(query[0][0], n_i, 0);
+  auto time_server_e = high_resolution_clock::now();
+  auto time_server_us =
+      duration_cast<microseconds>(time_server_e - time_server_s).count();
+  cout << "Main: query expanded" << endl;
+
+  assert(expanded_query.size() == n_i);
+
+  cout << "Main: checking expansion" << endl;
+  for (size_t i = 0; i < expanded_query.size(); i++) {
+    Plaintext decryption = client.decrypt(expanded_query.at(i));
+
+    if (decryption.is_zero() && index != i) {
+      continue;
+    } else if (decryption.is_zero()) {
+      cout << "Found zero where index should be" << endl;
+      return -1;
+    } else if (std::stoi(decryption.to_string()) != 1) {
+      cout << "Query vector at index " << index
+           << " should be 1 but is instead " << decryption.to_string() << endl;
+      return -1;
+    } else {
+      cout << "Query vector at index " << index << " is "
+           << decryption.to_string() << endl;
+    }
+  }
+
+  return 0;
+}

+ 167 - 0
test/query_test.cpp

@@ -0,0 +1,167 @@
+#include "pir.hpp"
+#include "pir_client.hpp"
+#include "pir_server.hpp"
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <random>
+#include <seal/seal.h>
+
+using namespace std::chrono;
+using namespace std;
+using namespace seal;
+
+int query_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
+               uint32_t lt, uint32_t dim);
+
+int main(int argc, char *argv[]) {
+  // Quick check
+  assert(query_test(1 << 10, 288, 4096, 20, 1) == 0);
+
+  assert(query_test(1 << 10, 288, 4096, 20, 2) == 0);
+
+  assert(query_test(1 << 10, 288, 4096, 20, 3) == 0);
+
+  assert(query_test(1 << 10, 288, 8192, 20, 2) == 0);
+
+  // Forces ciphertext expansion to be the same as the degree
+  assert(query_test(1 << 20, 288, 4096, 20, 1) == 0);
+
+  assert(query_test(1 << 20, 288, 4096, 20, 2) == 0);
+}
+
+int query_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
+               uint32_t lt, uint32_t dim) {
+  uint64_t number_of_items = num_items;
+  uint64_t size_per_item = item_size; // in bytes
+  uint32_t N = degree;
+
+  // Recommended values: (logt, d) = (12, 2) or (8, 1).
+  uint32_t logt = lt;
+  uint32_t d = dim;
+
+  EncryptionParameters enc_params(scheme_type::bfv);
+  PirParams pir_params;
+
+  // Generates all parameters
+
+  cout << "Main: Generating SEAL parameters" << endl;
+  gen_encryption_params(N, logt, enc_params);
+
+  cout << "Main: Verifying SEAL parameters" << endl;
+  verify_encryption_params(enc_params);
+  cout << "Main: SEAL parameters are good" << endl;
+
+  cout << "Main: Generating PIR parameters" << endl;
+  gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
+
+  // gen_params(number_of_items, size_per_item, N, logt, d, enc_params,
+  // pir_params);
+  print_pir_params(pir_params);
+
+  cout << "Main: Initializing the database (this may take some time) ..."
+       << endl;
+
+  // Create test database
+  auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  // Copy of the database. We use this at the end to make sure we retrieved
+  // the correct element.
+  auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  random_device rd;
+  for (uint64_t i = 0; i < number_of_items; i++) {
+    for (uint64_t j = 0; j < size_per_item; j++) {
+      uint8_t val = rd() % 256;
+      db.get()[(i * size_per_item) + j] = val;
+      db_copy.get()[(i * size_per_item) + j] = val;
+    }
+  }
+
+  // Initialize PIR Server
+  cout << "Main: Initializing server and client" << endl;
+  PIRServer server(enc_params, pir_params);
+
+  // Initialize PIR client....
+  PIRClient client(enc_params, pir_params);
+  GaloisKeys galois_keys = client.generate_galois_keys();
+
+  // Set galois key for client with id 0
+  cout << "Main: Setting Galois keys...";
+  server.set_galois_key(0, galois_keys);
+
+  // Measure database setup
+  auto time_pre_s = high_resolution_clock::now();
+  server.set_database(move(db), number_of_items, size_per_item);
+  server.preprocess_database();
+  cout << "Main: database pre processed " << endl;
+  auto time_pre_e = high_resolution_clock::now();
+  auto time_pre_us =
+      duration_cast<microseconds>(time_pre_e - time_pre_s).count();
+
+  // Choose an index of an element in the DB
+  uint64_t ele_index =
+      rd() % number_of_items; // element in DB at random position
+  uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
+  uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
+  cout << "Main: element index = " << ele_index << " from [0, "
+       << number_of_items - 1 << "]" << endl;
+  cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
+
+  // Measure query generation
+  auto time_query_s = high_resolution_clock::now();
+  PirQuery query = client.generate_query(index);
+  auto time_query_e = high_resolution_clock::now();
+  auto time_query_us =
+      duration_cast<microseconds>(time_query_e - time_query_s).count();
+  cout << "Main: query generated" << endl;
+
+  // To marshall query to send over the network, you can use
+  // serialize/deserialize: std::string query_ser = serialize_query(query);
+  // PirQuery query2 = deserialize_query(d, 1, query_ser, CIPHER_SIZE);
+
+  // Measure query processing (including expansion)
+  auto time_server_s = high_resolution_clock::now();
+  PirReply reply = server.generate_reply(query, 0);
+  auto time_server_e = high_resolution_clock::now();
+  auto time_server_us =
+      duration_cast<microseconds>(time_server_e - time_server_s).count();
+
+  // Measure response extraction
+  auto time_decode_s = chrono::high_resolution_clock::now();
+  vector<uint8_t> elems = client.decode_reply(reply, offset);
+  auto time_decode_e = chrono::high_resolution_clock::now();
+  auto time_decode_us =
+      duration_cast<microseconds>(time_decode_e - time_decode_s).count();
+
+  assert(elems.size() == size_per_item);
+
+  bool failed = false;
+  // Check that we retrieved the correct element
+  for (uint32_t i = 0; i < size_per_item; i++) {
+    if (elems[i] != db_copy.get()[(ele_index * size_per_item) + i]) {
+      cout << "Main: elems " << (int)elems[i] << ", db "
+           << (int)db_copy.get()[(ele_index * size_per_item) + i] << endl;
+      cout << "Main: PIR result wrong at " << i << endl;
+      failed = true;
+    }
+  }
+  if (failed) {
+    return -1;
+  }
+
+  // Output results
+  cout << "Main: PIR result correct!" << endl;
+  cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms"
+       << endl;
+  cout << "Main: PIRClient query generation time: " << time_query_us / 1000
+       << " ms" << endl;
+  cout << "Main: PIRServer reply generation time: " << time_server_us / 1000
+       << " ms" << endl;
+  cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000
+       << " ms" << endl;
+  cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
+
+  return 0;
+}

+ 172 - 0
test/replace_test.cpp

@@ -0,0 +1,172 @@
+#include "pir.hpp"
+#include "pir_client.hpp"
+#include "pir_server.hpp"
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <random>
+#include <seal/seal.h>
+
+using namespace std::chrono;
+using namespace std;
+using namespace seal;
+
+int replace_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
+                 uint32_t lt, uint32_t dim);
+
+int main(int argc, char *argv[]) {
+  // Quick check
+  assert(replace_test(1 << 13, 1, 4096, 20, 1) == 0);
+
+  // Forces ciphertext expansion to be the same as the degree
+  assert(replace_test(1 << 20, 288, 4096, 20, 1) == 0);
+
+  assert(replace_test(1 << 20, 288, 4096, 20, 2) == 0);
+}
+
+int replace_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
+                 uint32_t lt, uint32_t dim) {
+  uint64_t number_of_items = num_items;
+  uint64_t size_per_item = item_size; // in bytes
+  uint32_t N = degree;
+
+  // Recommended values: (logt, d) = (12, 2) or (8, 1).
+  uint32_t logt = lt;
+  uint32_t d = dim;
+
+  EncryptionParameters enc_params(scheme_type::bfv);
+  PirParams pir_params;
+
+  // Generates all parameters
+
+  cout << "Main: Generating SEAL parameters" << endl;
+  gen_encryption_params(N, logt, enc_params);
+
+  cout << "Main: Verifying SEAL parameters" << endl;
+  verify_encryption_params(enc_params);
+  cout << "Main: SEAL parameters are good" << endl;
+
+  cout << "Main: Generating PIR parameters" << endl;
+  gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
+
+  // gen_params(number_of_items, size_per_item, N, logt, d, enc_params,
+  // pir_params);
+  print_pir_params(pir_params);
+
+  cout << "Main: Initializing the database (this may take some time) ..."
+       << endl;
+
+  // Create test database
+  auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  // Copy of the database. We use this at the end to make sure we retrieved
+  // the correct element.
+  auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  random_device rd;
+  for (uint64_t i = 0; i < number_of_items; i++) {
+    for (uint64_t j = 0; j < size_per_item; j++) {
+      uint8_t val = rd() % 256;
+      db.get()[(i * size_per_item) + j] = val;
+      db_copy.get()[(i * size_per_item) + j] = val;
+    }
+  }
+
+  // Initialize PIR Server
+  cout << "Main: Initializing server and client" << endl;
+  PIRServer server(enc_params, pir_params);
+
+  // Initialize PIR client....
+  PIRClient client(enc_params, pir_params);
+  Ciphertext one_ct = client.get_one();
+  GaloisKeys galois_keys = client.generate_galois_keys();
+
+  // Set galois key for client with id 0
+  cout << "Main: Setting Galois keys...";
+  server.set_galois_key(0, galois_keys);
+
+  // Measure database setup
+  auto time_pre_s = high_resolution_clock::now();
+  server.set_database(move(db), number_of_items, size_per_item);
+  server.preprocess_database();
+  server.set_one_ct(one_ct);
+  cout << "Main: database pre processed " << endl;
+  auto time_pre_e = high_resolution_clock::now();
+  auto time_pre_us =
+      duration_cast<microseconds>(time_pre_e - time_pre_s).count();
+
+  // Choose an index of an element in the DB
+  uint64_t ele_index =
+      rd() % number_of_items; // element in DB at random position
+  uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
+  uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
+  cout << "Main: element index = " << ele_index << " from [0, "
+       << number_of_items - 1 << "]" << endl;
+  cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
+
+  // Generate a new element
+  vector<uint8_t> new_element(size_per_item);
+  vector<uint8_t> new_element_copy(size_per_item);
+  for (uint64_t i = 0; i < size_per_item; i++) {
+    uint8_t val = rd() % 256;
+    new_element[i] = val;
+    new_element_copy[i] = val;
+  }
+
+  // Get element to replace
+  auto time_server_s = high_resolution_clock::now();
+  Ciphertext reply = server.simple_query(index);
+  auto time_server_e = high_resolution_clock::now();
+  auto time_server_us =
+      duration_cast<microseconds>(time_server_e - time_server_s).count();
+  auto time_decode_s = chrono::high_resolution_clock::now();
+  Plaintext old_pt = client.decrypt(reply);
+  auto time_decode_e = chrono::high_resolution_clock::now();
+  auto time_decode_us =
+      duration_cast<microseconds>(time_decode_e - time_decode_s).count();
+
+  // Replace element
+  Modulus t = enc_params.plain_modulus();
+  logt = floor(log2(t.value()));
+  vector<uint64_t> new_coeffs =
+      bytes_to_coeffs(logt, new_element.data(), size_per_item);
+  Plaintext new_pt = client.replace_element(old_pt, new_coeffs, offset);
+  server.simple_set(index, new_pt);
+
+  // Get the replaced element
+  PirQuery query = client.generate_query(index);
+  PirReply server_reply = server.generate_reply(query, 0);
+  vector<uint8_t> elems = client.decode_reply(server_reply, offset);
+  // vector<uint8_t> elems =
+  // client.extract_bytes(client.decrypt(server.simple_query(index)), offset);
+  vector<uint8_t> old_elems = client.extract_bytes(old_pt, offset);
+
+  assert(elems.size() == size_per_item);
+
+  bool failed = false;
+  // Check that we retrieved the correct element
+  for (uint32_t i = 0; i < size_per_item; i++) {
+    if (elems[i] != new_element_copy[i]) {
+      cout << "Main: elems " << (int)elems[i] << ", new "
+           << (int)new_element_copy[i] << ", old "
+           << (int)db_copy.get()[(ele_index * size_per_item) + i] << endl;
+      cout << "Main: PIR result wrong at " << i << endl;
+      failed = true;
+    }
+  }
+  if (failed) {
+    return -1;
+  }
+
+  // Output results
+  cout << "Main: PIR result correct!" << endl;
+  cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms"
+       << endl;
+  cout << "Main: PIRServer reply generation time: " << time_server_us / 1000
+       << " ms" << endl;
+  cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000
+       << " ms" << endl;
+
+  return 0;
+}

+ 143 - 0
test/simple_query_test.cpp

@@ -0,0 +1,143 @@
+#include "pir.hpp"
+#include "pir_client.hpp"
+#include "pir_server.hpp"
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <random>
+#include <seal/seal.h>
+
+using namespace std::chrono;
+using namespace std;
+using namespace seal;
+
+int simple_query_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
+                      uint32_t lt, uint32_t dim);
+
+int main(int argc, char *argv[]) {
+  // Quick check
+  assert(simple_query_test(1 << 10, 288, 4096, 20, 1) == 0);
+
+  // Forces ciphertext expansion to be the same as the degree
+  assert(simple_query_test(1 << 20, 288, 4096, 20, 1) == 0);
+
+  assert(simple_query_test(1 << 20, 288, 4096, 20, 2) == 0);
+}
+
+int simple_query_test(uint64_t num_items, uint64_t item_size, uint32_t degree,
+                      uint32_t lt, uint32_t dim) {
+  uint64_t number_of_items = num_items;
+  uint64_t size_per_item = item_size; // in bytes
+  uint32_t N = degree;
+
+  // Recommended values: (logt, d) = (12, 2) or (8, 1).
+  uint32_t logt = lt;
+  uint32_t d = dim;
+
+  EncryptionParameters enc_params(scheme_type::bfv);
+  PirParams pir_params;
+
+  // Generates all parameters
+
+  cout << "Main: Generating SEAL parameters" << endl;
+  gen_encryption_params(N, logt, enc_params);
+
+  cout << "Main: Verifying SEAL parameters" << endl;
+  verify_encryption_params(enc_params);
+  cout << "Main: SEAL parameters are good" << endl;
+
+  cout << "Main: Generating PIR parameters" << endl;
+  gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
+
+  // gen_params(number_of_items, size_per_item, N, logt, d, enc_params,
+  // pir_params);
+  print_pir_params(pir_params);
+
+  cout << "Main: Initializing the database (this may take some time) ..."
+       << endl;
+
+  // Create test database
+  auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  // Copy of the database. We use this at the end to make sure we retrieved
+  // the correct element.
+  auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
+
+  random_device rd;
+  for (uint64_t i = 0; i < number_of_items; i++) {
+    for (uint64_t j = 0; j < size_per_item; j++) {
+      uint8_t val = rd() % 256;
+      db.get()[(i * size_per_item) + j] = val;
+      db_copy.get()[(i * size_per_item) + j] = val;
+    }
+  }
+
+  // Initialize PIR Server
+  cout << "Main: Initializing server and client" << endl;
+  PIRServer server(enc_params, pir_params);
+
+  // Initialize PIR client....
+  PIRClient client(enc_params, pir_params);
+  Ciphertext one_ct = client.get_one();
+
+  // Measure database setup
+  auto time_pre_s = high_resolution_clock::now();
+  server.set_database(move(db), number_of_items, size_per_item);
+  server.preprocess_database();
+  server.set_one_ct(one_ct);
+  cout << "Main: database pre processed " << endl;
+  auto time_pre_e = high_resolution_clock::now();
+  auto time_pre_us =
+      duration_cast<microseconds>(time_pre_e - time_pre_s).count();
+
+  // Choose an index of an element in the DB
+  uint64_t ele_index =
+      rd() % number_of_items; // element in DB at random position
+  uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
+  uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
+  cout << "Main: element index = " << ele_index << " from [0, "
+       << number_of_items - 1 << "]" << endl;
+  cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
+
+  // Measure query processing (including expansion)
+  auto time_server_s = high_resolution_clock::now();
+  Ciphertext reply = server.simple_query(index);
+  auto time_server_e = high_resolution_clock::now();
+  auto time_server_us =
+      duration_cast<microseconds>(time_server_e - time_server_s).count();
+
+  // Measure response extraction
+  auto time_decode_s = chrono::high_resolution_clock::now();
+  vector<uint8_t> elems = client.extract_bytes(client.decrypt(reply), offset);
+  auto time_decode_e = chrono::high_resolution_clock::now();
+  auto time_decode_us =
+      duration_cast<microseconds>(time_decode_e - time_decode_s).count();
+
+  assert(elems.size() == size_per_item);
+
+  bool failed = false;
+  // Check that we retrieved the correct element
+  for (uint32_t i = 0; i < size_per_item; i++) {
+    if (elems[i] != db_copy.get()[(ele_index * size_per_item) + i]) {
+      cout << "Main: elems " << (int)elems[i] << ", db "
+           << (int)db_copy.get()[(ele_index * size_per_item) + i] << endl;
+      cout << "Main: PIR result wrong at " << i << endl;
+      failed = true;
+    }
+  }
+  if (failed) {
+    return -1;
+  }
+
+  // Output results
+  cout << "Main: PIR result correct!" << endl;
+  cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms"
+       << endl;
+  cout << "Main: PIRServer reply generation time: " << time_server_us / 1000
+       << " ms" << endl;
+  cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000
+       << " ms" << endl;
+
+  return 0;
+}