Browse Source

Major refactoring of pir.cpp

Andrew Beams 3 years ago
parent
commit
3b4ec66832
4 changed files with 138 additions and 28 deletions
  1. 19 6
      main.cpp
  2. 84 10
      pir.cpp
  3. 32 12
      pir.hpp
  4. 3 0
      pir_server.cpp

+ 19 - 6
main.cpp

@@ -22,12 +22,25 @@ int main(int argc, char *argv[]) {
     uint32_t logt = 20; 
     uint32_t d = 2;
 
-    EncryptionParameters params(scheme_type::bfv);
+    EncryptionParameters enc_params(scheme_type::bfv);
     PirParams pir_params;
 
     // Generates all parameters
-    cout << "Main: Generating all parameters" << endl;
-    gen_params(number_of_items, size_per_item, N, logt, d, params, pir_params);
+    
+    cout << "Main: Generating SEAL parameters" << endl;
+    gen_encryption_params(N, logt, enc_params);
+    
+    cout << "Main: Verifying SEAL parameters" << endl;
+    verify_encryption_params(enc_params);
+    cout << "Main: SEAL parameters are good" << endl;
+
+    cout << "Main: Generating PIR parameters" << endl;
+    gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params);
+    
+    
+    
+    //gen_params(number_of_items, size_per_item, N, logt, d, enc_params, pir_params);
+    print_pir_params(pir_params);
 
     cout << "Main: Initializing the database (this may take some time) ..." << endl;
 
@@ -49,10 +62,10 @@ int main(int argc, char *argv[]) {
 
     // Initialize PIR Server
     cout << "Main: Initializing server and client" << endl;
-    PIRServer server(params, pir_params);
+    PIRServer server(enc_params, pir_params);
 
     // Initialize PIR client....
-    PIRClient client(params, pir_params);
+    PIRClient client(enc_params, pir_params);
     GaloisKeys galois_keys = client.generate_galois_keys();
 
     // Set galois key for client with id 0
@@ -97,7 +110,7 @@ int main(int argc, char *argv[]) {
     auto time_decode_e = chrono::high_resolution_clock::now();
     auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
 
-    logt = floor(log2(params.plain_modulus().value()));
+    logt = floor(log2(enc_params.plain_modulus().value()));
 
     // Convert from FV plaintext (polynomial) to database element at the client
     vector<uint8_t> elems(N * logt / 8);

+ 84 - 10
pir.cpp

@@ -27,19 +27,93 @@ std::vector<std::uint64_t> get_dimensions(std::uint64_t num_of_plaintexts, std::
     return dimensions;
 }
 
+void gen_encryption_params(std::uint32_t N, std::uint32_t logt,
+                           seal::EncryptionParameters &enc_params){
+    
+    enc_params.set_poly_modulus_degree(N);
+    enc_params.set_coeff_modulus(CoeffModulus::BFVDefault(N));
+    enc_params.set_plain_modulus(PlainModulus::Batching(N, logt));
+}
+
+void verify_encryption_params(const seal::EncryptionParameters &enc_params){
+    SEALContext context(enc_params, true);
+    if(!context.parameters_set()){
+        throw invalid_argument("SEAL parameters not valid.");
+    }
+    if(!context.using_keyswitching()){
+        throw invalid_argument("SEAL parameters do not support key switching.");
+    }
+    if(!context.first_context_data()->qualifiers().using_batching){
+        throw invalid_argument("SEAL parameters do not support batching.");
+    }
+    return;
+}
+
+void gen_pir_params(uint64_t ele_num, uint64_t ele_size, uint32_t d,
+                    const EncryptionParameters &enc_params, PirParams &pir_params,
+                    bool enable_symmetric, bool enable_batching){
+    std::uint32_t N = enc_params.poly_modulus_degree();
+    Modulus t = enc_params.plain_modulus();
+    std::uint32_t logt = floor(log2(t.value()));
+
+    cout << "logt: " << logt << endl << "N: " << N << endl <<
+    "ele_num: " << ele_num << endl << "ele_size: " << ele_size << endl;
+
+    std::uint64_t elements_per_plaintext;
+    std::uint64_t num_of_plaintexts;
+
+    if(enable_batching){
+        elements_per_plaintext = elements_per_ptxt(logt, N, ele_size);
+        num_of_plaintexts = plaintexts_per_db(logt, N, ele_num, ele_size);
+    }
+    else{
+        elements_per_plaintext = 1;
+        num_of_plaintexts = ele_num;
+    }
+
+    vector<uint64_t> nvec = get_dimensions(num_of_plaintexts, d);
+
+    uint32_t expansion_ratio = 0;
+    for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
+        double logqi = log2(enc_params.coeff_modulus()[i].value());
+        cout << "PIR: logqi = " << logqi << endl;
+        expansion_ratio += ceil(logqi / logt);
+    }
+
+    if(!enable_symmetric){
+        expansion_ratio = expansion_ratio << 1;
+    }
+
+    pir_params.enable_symmetric = enable_symmetric;
+    pir_params.enable_batching = enable_batching;
+    pir_params.ele_num = ele_num;
+    pir_params.ele_size = ele_size;
+    pir_params.elements_per_plaintext = elements_per_plaintext;
+    pir_params.num_of_plaintexts = num_of_plaintexts;
+    pir_params.d = d;                 
+    pir_params.expansion_ratio = expansion_ratio;           
+    pir_params.nvec = nvec;
+    pir_params.dbc = 6;
+    pir_params.n = num_of_plaintexts;
+}
+
+
+void print_pir_params(const PirParams &pir_params){
+    cout << "Pir Params: " << endl;
+    cout << "num_of_elements: " << pir_params.ele_num << endl;
+    cout << "ele_size: " << pir_params.ele_size << endl;
+    cout << "elements_per_plaintext: " << pir_params.elements_per_plaintext << endl;
+    cout << "num_of_plaintexts: " << pir_params.num_of_plaintexts << endl;
+    cout << "dimension: " << pir_params.d << endl;
+    cout << "expansion ratio: " << pir_params.expansion_ratio << endl;
+    cout << "dbc: " << pir_params.dbc << endl;
+    cout << "n: " << pir_params.n << endl;
+}
+
 void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
                 uint32_t d, EncryptionParameters &params,
                 PirParams &pir_params) {
     
-    // Determine the maximum size of each dimension
-
-    // plain modulus = a power of 2 plus 1
-    uint64_t plain_mod = (static_cast<uint64_t>(1) << logt) + 1;
-
-#ifdef DEBUG
-    cout << "log(plain mod) before expand = " << logt << endl;
-    cout << "number of FV plaintexts = " << plaintext_num << endl;
-#endif
 
     params.set_poly_modulus_degree(N);
     params.set_coeff_modulus(CoeffModulus::BFVDefault(N));
@@ -57,7 +131,7 @@ void gen_params(uint64_t ele_num, uint64_t ele_size, uint32_t N, uint32_t logt,
     uint32_t expansion_ratio = 0;
     for (uint32_t i = 0; i < params.coeff_modulus().size(); ++i) {
         double logqi = log2(params.coeff_modulus()[i].value());
-        cout << "PIR: logqi = " << logqi << endl; 
+        cout << "PIR: logqi = " << logqi << endl;
         expansion_ratio += ceil(logqi / logt);
     }
 

+ 32 - 12
pir.hpp

@@ -7,28 +7,48 @@
 #include <string>
 #include <vector>
 
-#define CIPHER_SIZE 32841
-
 typedef std::vector<seal::Plaintext> Database;
 typedef std::vector<std::vector<seal::Ciphertext>> PirQuery;
 typedef std::vector<seal::Ciphertext> PirReply;
 
 struct PirParams {
-    std::uint64_t n;                 // number of plaintexts in database
-    std::uint32_t d;                 // number of dimensions for the database (1 or 2)
-    std::uint32_t expansion_ratio;   // ratio of ciphertext to plaintext
-    std::uint32_t dbc;               // decomposition bit count (used by relinearization)
-    std::vector<std::uint64_t> nvec; // size of each of the d dimensions
+    bool enable_symmetric;
+    bool enable_batching;
+    std::uint64_t ele_num;
+    std::uint64_t ele_size;
+    std::uint64_t elements_per_plaintext;
+    std::uint64_t num_of_plaintexts;         // number of plaintexts in database
+    std::uint32_t d;                 // number of dimensions for the database
+    std::uint32_t expansion_ratio;           // ratio of ciphertext to plaintext
+    std::vector<std::uint64_t> nvec;         // size of each of the d dimensions
+    std::uint32_t dbc;
+    std::uint32_t n;
 };
 
-void gen_params(std::uint64_t ele_num,  // number of elements (not FV plaintexts) in database
-                std::uint64_t ele_size, // size of each element
-                std::uint32_t N,        // degree of polynomial
-                std::uint32_t logt,     // bits of plaintext coefficient
-                std::uint32_t d,        // dimension of database
+void gen_encryption_params(std::uint32_t N,        // degree of polynomial
+                           std::uint32_t logt,     // bits of plaintext coefficient
+                           seal::EncryptionParameters &enc_params);
+
+void gen_pir_params(uint64_t ele_num,
+                    uint64_t ele_size,
+                    uint32_t d,
+                    const seal::EncryptionParameters &enc_params,
+                    PirParams &pir_params,
+                    bool enable_symmetric = false,
+                    bool enable_batching = true);
+
+void gen_params(uint64_t ele_num,
+                uint64_t ele_size,
+                uint32_t N,
+                uint32_t logt,
+                uint32_t d,
                 seal::EncryptionParameters &params,
                 PirParams &pir_params);
 
+void verify_encryption_params(const seal::EncryptionParameters &enc_params);
+
+void print_pir_params(const PirParams &pir_params);
+
 // returns the plaintext modulus after expansion
 std::uint32_t plainmod_after_expansion(std::uint32_t logt, std::uint32_t N, 
                                        std::uint32_t d, std::uint64_t ele_num,

+ 3 - 0
pir_server.cpp

@@ -52,6 +52,9 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
     }
     uint64_t matrix_plaintexts = prod;
 
+    cout << "Total: " << total << endl;
+    cout << "Prod: " << prod << endl;
+
     assert(total <= matrix_plaintexts);
 
     auto result = make_unique<vector<Plaintext>>();