Browse Source

Client and server now get num_of_plaintexts and elements_per_plaintext from pir_params

Andrew Beams 3 years ago
parent
commit
d10a17fecd
4 changed files with 15 additions and 23 deletions
  1. 2 2
      main.cpp
  2. 4 12
      pir_client.cpp
  3. 2 2
      pir_client.hpp
  4. 7 7
      pir_server.cpp

+ 2 - 2
main.cpp

@@ -83,8 +83,8 @@ int main(int argc, char *argv[]) {
 
     // Choose an index of an element in the DB
     uint64_t ele_index = rd() % number_of_items; // element in DB at random position
-    uint64_t index = client.get_fv_index(ele_index, size_per_item);   // index of FV plaintext
-    uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext
+    uint64_t index = client.get_fv_index(ele_index);   // index of FV plaintext
+    uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
     cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
     cout << "Main: FV index = " << index << ", FV offset = " << offset << endl; 
 

+ 4 - 12
pir_client.cpp

@@ -61,20 +61,12 @@ PirQuery PIRClient::generate_query(uint64_t desiredIndex) {
     return result;
 }
 
-uint64_t PIRClient::get_fv_index(uint64_t element_idx, uint64_t ele_size) {
-    auto N = enc_params_.poly_modulus_degree();
-    auto logt = floor(log2(enc_params_.plain_modulus().value()));
-
-    auto ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
-    return static_cast<uint64_t>(element_idx / ele_per_ptxt);
+uint64_t PIRClient::get_fv_index(uint64_t element_index) {
+    return static_cast<uint64_t>(element_index / pir_params_.elements_per_plaintext);
 }
 
-uint64_t PIRClient::get_fv_offset(uint64_t element_idx, uint64_t ele_size) {
-    uint32_t N = enc_params_.poly_modulus_degree();
-    uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
-
-    uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
-    return element_idx % ele_per_ptxt;
+uint64_t PIRClient::get_fv_offset(uint64_t element_index) {
+    return element_index % pir_params_.elements_per_plaintext;
 }
 
 Plaintext PIRClient::decode_reply(PirReply reply) {

+ 2 - 2
pir_client.hpp

@@ -17,8 +17,8 @@ class PIRClient {
     seal::GaloisKeys generate_galois_keys();
 
     // Index and offset of an element in an FV plaintext
-    uint64_t get_fv_index(uint64_t element_idx, uint64_t ele_size);
-    uint64_t get_fv_offset(uint64_t element_idx, uint64_t ele_size);
+    uint64_t get_fv_index(uint64_t element_index);
+    uint64_t get_fv_offset(uint64_t element_index);
 
 
   private:

+ 7 - 7
pir_server.cpp

@@ -43,7 +43,7 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
     uint32_t N = enc_params_.poly_modulus_degree();
 
     // number of FV plaintexts needed to represent all elements
-    uint64_t total = plaintexts_per_db(logt, N, ele_num, ele_size);
+    uint64_t num_of_plaintexts = pir_params_.num_of_plaintexts;
 
     // number of FV plaintexts needed to create the d-dimensional matrix
     uint64_t prod = 1;
@@ -52,15 +52,15 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
     }
     uint64_t matrix_plaintexts = prod;
 
-    cout << "Total: " << total << endl;
+    cout << "Total: " << num_of_plaintexts << endl;
     cout << "Prod: " << prod << endl;
 
-    assert(total <= matrix_plaintexts);
+    assert(num_of_plaintexts <= matrix_plaintexts);
 
     auto result = make_unique<vector<Plaintext>>();
     result->reserve(matrix_plaintexts);
 
-    uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, ele_size);
+    uint64_t ele_per_ptxt = pir_params_.elements_per_plaintext;
     uint64_t bytes_per_ptxt = ele_per_ptxt * ele_size;
 
     uint64_t db_size = ele_num * ele_size;
@@ -68,12 +68,12 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
     uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logt, ele_size);
     assert(coeff_per_ptxt <= N);
 
-    cout << "Server: total number of FV plaintext = " << total << endl;
+    cout << "Server: num_of_plaintexts number of FV plaintext = " << num_of_plaintexts << endl;
     cout << "Server: elements packed into each plaintext " << ele_per_ptxt << endl; 
 
     uint32_t offset = 0;
 
-    for (uint64_t i = 0; i < total; i++) {
+    for (uint64_t i = 0; i < num_of_plaintexts; i++) {
 
         uint64_t process_bytes = 0;
 
@@ -106,7 +106,7 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
 
     // Add padding to make database a matrix
     uint64_t current_plaintexts = result->size();
-    assert(current_plaintexts <= total);
+    assert(current_plaintexts <= num_of_plaintexts);
 
 #ifdef DEBUG
     cout << "adding: " << matrix_plaintexts - current_plaintexts