Browse Source

Started on replace functionality

Andrew Beams 2 years ago
parent
commit
e90710931c
3 changed files with 40 additions and 2 deletions
  1. 20 0
      src/pir_client.cpp
  2. 16 2
      src/pir_server.cpp
  3. 4 0
      test/CMakeLists.txt

+ 20 - 0
src/pir_client.cpp

@@ -276,6 +276,26 @@ Ciphertext PIRClient::compose_to_ciphertext(vector<Plaintext> plains) {
     return result;
 }
 
+Plaintext PIRClient::replace_element(Plaintext pt, vector<uint64_t> new_element, uint64_t offset){
+    vector<uint64_t> coeffs = extract_coeffs(pt);
+    
+    uint32_t logt = floor(log2(enc_params_.plain_modulus().value()));
+    uint64_t coeffs_per_element = coefficients_per_element(logt, pir_params_.ele_size);
+
+    assert(new_element.size() == coeffs_per_element);
+
+    for(uint64_t i = 0; i < coeffs_per_element; i++){
+        cout << "Replacing " << coeffs[i + offset * coeffs_per_element];
+        cout << " with " << new_element[i] << endl;
+        coeffs[i + offset * coeffs_per_element] = new_element[i];
+    }
+    
+    Plaintext new_pt;
+
+    encoder_->encode(coeffs, new_pt);
+    return new_pt;
+}
+
 Ciphertext PIRClient::get_one(){
     Plaintext pt("1");
     Ciphertext ct;

+ 16 - 2
src/pir_server.cpp

@@ -61,6 +61,7 @@ void PIRServer::set_database(const std::unique_ptr<const uint8_t[]> &bytes,
     uint64_t ele_per_ptxt = pir_params_.elements_per_plaintext;
     uint64_t bytes_per_ptxt = ele_per_ptxt * ele_size;
 
+    cout << "Bytes per ptxt: " << bytes_per_ptxt << endl;
     uint64_t db_size = ele_num * ele_size;
 
     uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logt, ele_size);
@@ -79,13 +80,18 @@ void PIRServer::set_database(const std::unique_ptr<const uint8_t[]> &bytes,
         } else {
             process_bytes = bytes_per_ptxt;
         }
+        cout << "Process bytes: " << process_bytes << endl;
 
         // Get the coefficients of the elements that will be packed in plaintext i
-        vector<uint64_t> coefficients = bytes_to_coeffs(logt, bytes.get() + offset, process_bytes);
+        vector<uint64_t> coefficients;
+        for(int j = 0; j < process_bytes; j += ele_size){
+            vector<uint64_t> to_add = bytes_to_coeffs(logt, bytes.get() + offset + j, ele_size);
+            coefficients.insert(coefficients.end(),to_add.begin(),to_add.end());
+        } 
         offset += process_bytes;
 
         uint64_t used = coefficients.size();
-
+        cout << "Used: " << used << endl;
         assert(used <= coeff_per_ptxt);
 
         // Pad the rest with 1s
@@ -462,6 +468,14 @@ vector<Plaintext> PIRServer::decompose_to_plaintexts(const Ciphertext &encrypted
     return result;
 }
 
+void PIRServer::simple_set(uint64_t index, Plaintext pt){
+    if(is_db_preprocessed_){
+        evaluator_->transform_to_ntt_inplace(
+                pt, context_->first_parms_id());
+    }
+    db_->operator[](index) = pt;
+}
+
 Ciphertext PIRServer::simple_query(uint64_t index){
     //There is no transform_from_ntt that takes a plaintext
     Ciphertext ct;

+ 4 - 0
test/CMakeLists.txt

@@ -11,3 +11,7 @@ add_test(NAME query_test COMMAND query_test)
 add_executable(simple_query_test simple_query_test.cpp)
 target_link_libraries(simple_query_test sealpir)
 add_test(NAME simple_query_test COMMAND simple_query_test)
+
+add_executable(replace_test replace_test.cpp)
+target_link_libraries(replace_test sealpir)
+add_test(NAME replace_test COMMAND replace_test)