Browse Source

More cleaning up

Andrew Beams 3 years ago
parent
commit
4c2a028e87
4 changed files with 6 additions and 128 deletions
  1. 0 107
      encoding_test.cpp
  2. 1 1
      main.cpp
  3. 3 18
      pir_server.cpp
  4. 2 2
      pir_server.hpp

+ 0 - 107
encoding_test.cpp

@@ -1,107 +0,0 @@
-#include "pir.hpp"
-#include "pir_client.hpp"
-#include "pir_server.hpp"
-#include <seal/seal.h>
-#include <memory>
-#include <random>
-#include <cstdint>
-#include <cstddef>
-
-using namespace std;
-using namespace seal;
-
-int main(int argc, char *argv[]) {
-    uint64_t number_of_items = 1 << 12;
-    uint64_t size_per_item = 288; // in bytes
-    uint32_t N = 4096;
-
-    // Recommended values: (logt, d) = (12, 2) or (8, 1). 
-    uint32_t logt = 20; 
-    uint32_t d = 1;
-
-    EncryptionParameters params(scheme_type::bfv);
-    PirParams pir_params;
-
-    // Generates all parameters
-    cout << "Main: Generating all parameters" << endl;
-    gen_params(number_of_items, size_per_item, N, logt, d, params, pir_params);
-
-    logt = floor(log2(params.plain_modulus().value()));
-
-    cout << "Main: Initializing the database (this may take some time) ..." << endl;
-
-    // Create test database
-    auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
-
-    // Copy of the database. We use this at the end to make sure we retrieved
-    // the correct element.
-    auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
-
-    random_device rd;
-    for (uint64_t i = 0; i < number_of_items; i++) {
-        for (uint64_t j = 0; j < size_per_item; j++) {
-            uint8_t val = rd() % 256;
-            db.get()[(i * size_per_item) + j] = val;
-            db_copy.get()[(i * size_per_item) + j] = val;
-        }
-    }
-
-    shared_ptr<SEALContext> context = make_shared<SEALContext>(params, true);
-    unique_ptr<KeyGenerator> keygen = make_unique<KeyGenerator>(*context);
-    
-    PublicKey public_key;
-    keygen->create_public_key(public_key);
-    
-    unique_ptr<Encryptor> encryptor = make_unique<Encryptor>(*context, public_key);
-
-    SecretKey secret_key = keygen->secret_key();
-    unique_ptr<Decryptor> decryptor = make_unique<Decryptor>(*context, secret_key);
-
-    unique_ptr<Evaluator> evaluator = make_unique<Evaluator>(*context);
-
-
-    uint64_t ele_per_ptxt = elements_per_ptxt(logt, N, size_per_item);
-    uint64_t bytes_per_ptxt = ele_per_ptxt * size_per_item;
-
-    uint64_t db_size = number_of_items * size_per_item;
-
-    uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logt, size_per_item);
-    assert(coeff_per_ptxt <= N);
-
-
-    vector<uint64_t> coefficients = bytes_to_coeffs(logt, db.get(), size_per_item);
-    uint64_t used = coefficients.size();
-
-    assert(used <= coeff_per_ptxt);
-
-    // Pad the rest with 1s
-    for (uint64_t j = 0; j < (N - used); j++) {
-        coefficients.push_back(1);
-    }
-
-    Plaintext plain;
-    vector_to_plaintext(coefficients, plain);
-
-    //cout << "Plaintext: " << plain.to_string() << endl;
-
-    vector<uint8_t> elems(N * logt / 8);
-    coeffs_to_bytes(logt, plain, elems.data(), (N * logt) / 8);
-
-    bool failed = false;
-    // Check that we retrieved the correct element
-    for (uint32_t i = 0; i < size_per_item; i++) {
-        if (elems[i] != db_copy.get()[i]) {
-            cout << "Main: elems " << (int)elems[i] << ", db "
-                 << (int) db_copy.get()[i] << endl;
-            cout << "Main: PIR result wrong at " << i <<  endl;
-            failed = true;
-        }
-    }
-    if(failed){
-        return -1;
-    }
-    else{
-        cout << "succeeded" << endl;
-    }
-
-}

+ 1 - 1
main.cpp

@@ -87,7 +87,7 @@ int main(int argc, char *argv[]) {
 
     // Measure query processing (including expansion)
     auto time_server_s = high_resolution_clock::now();
-    PirReply reply = server.generate_reply(query, 0, client);
+    PirReply reply = server.generate_reply(query, 0);
     auto time_server_e = high_resolution_clock::now();
     auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
 

+ 3 - 18
pir_server.cpp

@@ -132,7 +132,7 @@ void PIRServer::set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey)
     galoisKeys_[client_id] = galkey;
 }
 
-PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, const PIRClient& client) {
+PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id) {
 
     vector<uint64_t> nvec = pir_params_.nvec;
     uint64_t product = 1;
@@ -169,7 +169,7 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, const PIR
                 total = ((n_i - 1) % N) + 1; 
             }
             cout << "-- expanding one query ctxt into " << total  << " ctxts "<< endl;
-            vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id, client);
+            vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id);
             expanded_query.insert(expanded_query.end(), std::make_move_iterator(expanded_query_part.begin()), 
                     std::make_move_iterator(expanded_query_part.end()));
             expanded_query_part.clear(); 
@@ -179,21 +179,6 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, const PIR
             cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i << endl; 
         }    
 
-        
-        cout << "Checking expanded query " << endl; 
-        Plaintext tempPt; 
-        for (int h = 0 ; h < expanded_query.size(); h++){
-            client.decryptor_->decrypt(expanded_query[h], tempPt); 
-            if(tempPt.is_zero()){
-                continue;
-            }
-            cout << "index: " << h << ", ";
-            cout << "noise budget = " << client.decryptor_->invariant_noise_budget(expanded_query[h]) << ", "; 
-            cout << tempPt.to_string()  << endl; 
-        }
-        cout << endl;
-        
-
         // Transform expanded query to NTT, and ...
         for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
             evaluator_->transform_to_ntt_inplace(expanded_query[jj]);
@@ -267,7 +252,7 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, const PIR
 }
 
 inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, uint32_t m,
-                                           uint32_t client_id, const PIRClient& client) {
+                                           uint32_t client_id) {
 
 #ifdef DEBUG
     uint64_t plainMod = params_.plain_modulus().value();

+ 2 - 2
pir_server.hpp

@@ -17,9 +17,9 @@ class PIRServer {
     void preprocess_database();
 
     std::vector<seal::Ciphertext> expand_query(
-            const seal::Ciphertext &encrypted, std::uint32_t m, uint32_t client_id, const PIRClient& client);
+            const seal::Ciphertext &encrypted, std::uint32_t m, uint32_t client_id);
 
-    PirReply generate_reply(PirQuery query, std::uint32_t client_id, const PIRClient& client);
+    PirReply generate_reply(PirQuery query, std::uint32_t client_id);
 
     void set_galois_key(std::uint32_t client_id, seal::GaloisKeys galkey);