|  | @@ -83,6 +83,12 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
 | 
	
		
			
				|  |  |      uint64_t coeff_per_ptxt = ele_per_ptxt * coefficients_per_element(logt, ele_size);
 | 
	
		
			
				|  |  |      assert(coeff_per_ptxt <= N);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    cout << "Server: total number of FV plaintext = " << total << endl;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    cout << "Server: elements packed into each plaintext " << ele_per_ptxt << endl; 
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      uint32_t offset = 0;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      for (uint64_t i = 0; i < total; i++) {
 | 
	
	
		
			
				|  | @@ -96,7 +102,6 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
 | 
	
		
			
				|  |  |          } else {
 | 
	
		
			
				|  |  |              process_bytes = bytes_per_ptxt;
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |          // Get the coefficients of the elements that will be packed in plaintext i
 | 
	
		
			
				|  |  |          vector<uint64_t> coefficients = bytes_to_coeffs(logt, bytes.get() + offset, process_bytes);
 | 
	
		
			
				|  |  |          offset += process_bytes;
 | 
	
	
		
			
				|  | @@ -112,7 +117,7 @@ void PIRServer::set_database(const std::unique_ptr<const std::uint8_t[]> &bytes,
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          Plaintext plain;
 | 
	
		
			
				|  |  |          vector_to_plaintext(coefficients, plain);
 | 
	
		
			
				|  |  | -        cout << i << "-th encoded plaintext = " << plain.to_string() << endl; 
 | 
	
		
			
				|  |  | +        // cout << i << "-th encoded plaintext = " << plain.to_string() << endl; 
 | 
	
		
			
				|  |  |          result->push_back(move(plain));
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -159,10 +164,39 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      auto pool = MemoryManager::GetPool();
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    int N = params_.poly_modulus_degree();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    int logt = floor(log2(params_.plain_modulus().value()));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    cout << "expansion ratio = " << pir_params_.expansion_ratio << endl; 
 | 
	
		
			
				|  |  |      for (uint32_t i = 0; i < nvec.size(); i++) {
 | 
	
		
			
				|  |  | +        cout << "Server: " << i + 1 << "-th recursion level started " << endl; 
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        vector<Ciphertext> expanded_query; 
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          uint64_t n_i = nvec[i];
 | 
	
		
			
				|  |  | -        vector<Ciphertext> expanded_query = expand_query(query[i], n_i, client_id, client);
 | 
	
		
			
				|  |  | -        cout << "Checking expanded query "; 
 | 
	
		
			
				|  |  | +        cout << "Server: n_i = " << n_i << endl; 
 | 
	
		
			
				|  |  | +        cout << "Server: expanding " << query[i].size() << "query ctxts" << endl;
 | 
	
		
			
				|  |  | +        for (uint32_t j = 0; j < query[i].size(); j++){
 | 
	
		
			
				|  |  | +            uint64_t total = N; 
 | 
	
		
			
				|  |  | +            if (j == query[i].size() - 1){
 | 
	
		
			
				|  |  | +                total = n_i % N; 
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +            cout << "-- expanding one query ctxt into " << total  << " ctxts "<< endl;
 | 
	
		
			
				|  |  | +            vector<Ciphertext> expanded_query_part = expand_query(query[i][j], total, client_id, client);
 | 
	
		
			
				|  |  | +            expanded_query.insert(expanded_query.end(), std::make_move_iterator(expanded_query_part.begin()), 
 | 
	
		
			
				|  |  | +                    std::make_move_iterator(expanded_query_part.end()));
 | 
	
		
			
				|  |  | +            expanded_query_part.clear(); 
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        cout << "Server: expansion done " << endl; 
 | 
	
		
			
				|  |  | +        if (expanded_query.size() != n_i) {
 | 
	
		
			
				|  |  | +            cout << " size mismatch!!! " << expanded_query.size() << ", " << n_i << endl; 
 | 
	
		
			
				|  |  | +        }    
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /*
 | 
	
		
			
				|  |  | +        cout << "Checking expanded query " << endl; 
 | 
	
		
			
				|  |  |          Plaintext tempPt; 
 | 
	
		
			
				|  |  |          for (int h = 0 ; h < expanded_query.size(); h++){
 | 
	
		
			
				|  |  |              cout << "noise budget = " << client.decryptor_->invariant_noise_budget(expanded_query[h]) << ", "; 
 | 
	
	
		
			
				|  | @@ -170,6 +204,7 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
 | 
	
		
			
				|  |  |              cout << tempPt.to_string()  << endl; 
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |          cout << endl;
 | 
	
		
			
				|  |  | +        */
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          // Transform expanded query to NTT, and ...
 | 
	
		
			
				|  |  |          for (uint32_t jj = 0; jj < expanded_query.size(); jj++) {
 | 
	
	
		
			
				|  | @@ -183,26 +218,40 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        for (uint64_t k = 0; k < product; k++) {
 | 
	
		
			
				|  |  | +            if ((*cur)[k].is_zero()){
 | 
	
		
			
				|  |  | +                cout << k + 1 << "/ " << product <<  "-th ptxt = 0 " << endl; 
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          product /= n_i;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        vector<Ciphertext> intermediate(product);
 | 
	
		
			
				|  |  | +        vector<Ciphertext> intermediateCtxts(product);
 | 
	
		
			
				|  |  |          Ciphertext temp;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          for (uint64_t k = 0; k < product; k++) {
 | 
	
		
			
				|  |  | -            evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediate[k]);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            evaluator_->multiply_plain(expanded_query[0], (*cur)[k], intermediateCtxts[k]);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              for (uint64_t j = 1; j < n_i; j++) {
 | 
	
		
			
				|  |  |                  evaluator_->multiply_plain(expanded_query[j], (*cur)[k + j * product], temp);
 | 
	
		
			
				|  |  | -                evaluator_->add_inplace(intermediate[k], temp); // Adds to first component.
 | 
	
		
			
				|  |  | +                evaluator_->add_inplace(intermediateCtxts[k], temp); // Adds to first component.
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        for (uint32_t jj = 0; jj < intermediate.size(); jj++) {
 | 
	
		
			
				|  |  | -            evaluator_->transform_from_ntt_inplace(intermediate[jj]);
 | 
	
		
			
				|  |  | +        for (uint32_t jj = 0; jj < intermediateCtxts.size(); jj++) {
 | 
	
		
			
				|  |  | +            evaluator_->transform_from_ntt_inplace(intermediateCtxts[jj]);
 | 
	
		
			
				|  |  | +            // print intermediate ctxts? 
 | 
	
		
			
				|  |  | +            //cout << "const term of ctxt " << jj << " = " << intermediateCtxts[jj][0] << endl; 
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          if (i == nvec.size() - 1) {
 | 
	
		
			
				|  |  | -            return intermediate;
 | 
	
		
			
				|  |  | +            return intermediateCtxts;
 | 
	
		
			
				|  |  |          } else {
 | 
	
		
			
				|  |  |              intermediate_plain.clear();
 | 
	
		
			
				|  |  |              intermediate_plain.reserve(pir_params_.expansion_ratio * product);
 | 
	
	
		
			
				|  | @@ -214,19 +263,20 @@ PirReply PIRServer::generate_reply(PirQuery query, uint32_t client_id, PIRClient
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              for (uint64_t rr = 0; rr < product; rr++) {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -                decompose_to_plaintexts_ptr(intermediate[rr],
 | 
	
		
			
				|  |  | -                    tempplain.get() + rr * pir_params_.expansion_ratio);
 | 
	
		
			
				|  |  | +                decompose_to_plaintexts_ptr(intermediateCtxts[rr],
 | 
	
		
			
				|  |  | +                    tempplain.get() + rr * pir_params_.expansion_ratio, logt);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |                  for (uint32_t jj = 0; jj < pir_params_.expansion_ratio; jj++) {
 | 
	
		
			
				|  |  |                      auto offset = rr * pir_params_.expansion_ratio + jj;
 | 
	
		
			
				|  |  |                      intermediate_plain.emplace_back(tempplain[offset]);
 | 
	
		
			
				|  |  |                  }
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |              product *= pir_params_.expansion_ratio; // multiply by expansion rate.
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  | +        cout << "Server: " << i + 1 << "-th recursion level finished " << endl; 
 | 
	
		
			
				|  |  | +        cout << endl;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +    cout << "reply generated!  " << endl;
 | 
	
		
			
				|  |  |      // This should never get here
 | 
	
		
			
				|  |  |      assert(0);
 | 
	
		
			
				|  |  |      vector<Ciphertext> fail(1);
 | 
	
	
		
			
				|  | @@ -252,13 +302,9 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
 | 
	
		
			
				|  |  |      if (logm > ceil(log2(n))){
 | 
	
		
			
				|  |  |          throw logic_error("m > n is not allowed."); 
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    cout << "galois elts at server: "; 
 | 
	
		
			
				|  |  | -    for (uint32_t i = 0; i < logm; i++) {
 | 
	
		
			
				|  |  | +    for (int i = 0; i < ceil(log2(n)); i++) {
 | 
	
		
			
				|  |  |          galois_elts.push_back((n + exponentiate_uint64(2, i)) / exponentiate_uint64(2, i));
 | 
	
		
			
				|  |  | -        cout << galois_elts.back() << ", "; 
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  | -    cout << endl;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      vector<Ciphertext> temp;
 | 
	
		
			
				|  |  |      temp.push_back(encrypted);
 | 
	
	
		
			
				|  | @@ -267,44 +313,45 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
 | 
	
		
			
				|  |  |      Ciphertext tempctxt_shifted;
 | 
	
		
			
				|  |  |      Ciphertext tempctxt_rotatedshifted;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      for (uint32_t i = 0; i < logm - 1; i++) {
 | 
	
		
			
				|  |  |          vector<Ciphertext> newtemp(temp.size() << 1);
 | 
	
		
			
				|  |  |          // temp[a] = (j0 = a (mod 2**i) ? ) : Enc(x^{j0 - a}) else Enc(0).  With
 | 
	
		
			
				|  |  |          // some scaling....
 | 
	
		
			
				|  |  |          int index_raw = (n << 1) - (1 << i);
 | 
	
		
			
				|  |  | +        // TODO: galois elements. 
 | 
	
		
			
				|  |  |          int index = (index_raw * galois_elts[i]) % (n << 1);
 | 
	
		
			
				|  |  | -        cout << i << "-th expansion round, noise budget = " << endl; 
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          for (uint32_t a = 0; a < temp.size(); a++) {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              evaluator_->apply_galois(temp[a], galois_elts[i], galkey, tempctxt_rotated);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -            cout << "rotate " << client.decryptor_->invariant_noise_budget(tempctxt_rotated) << ", "; 
 | 
	
		
			
				|  |  | +            //cout << "rotate " << client.decryptor_->invariant_noise_budget(tempctxt_rotated) << ", "; 
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
 | 
	
		
			
				|  |  |              multiply_power_of_X(temp[a], tempctxt_shifted, index_raw);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -            cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_shifted) << ", "; 
 | 
	
		
			
				|  |  | +            //cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_shifted) << ", "; 
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              multiply_power_of_X(tempctxt_rotated, tempctxt_rotatedshifted, index);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -            cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_rotatedshifted) << ", "; 
 | 
	
		
			
				|  |  | +            // cout << "mul by x^pow: " << client.decryptor_->invariant_noise_budget(tempctxt_rotatedshifted) << ", "; 
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              // Enc(2^i x^j) if j = 0 (mod 2**i).
 | 
	
		
			
				|  |  |              evaluator_->add(tempctxt_shifted, tempctxt_rotatedshifted, newtemp[a + temp.size()]);
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |          temp = newtemp;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +        /*
 | 
	
		
			
				|  |  |          cout << "end: "; 
 | 
	
		
			
				|  |  |          for (int h = 0; h < temp.size();h++){
 | 
	
		
			
				|  |  |              cout << client.decryptor_->invariant_noise_budget(temp[h]) << ", "; 
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |          cout << endl; 
 | 
	
		
			
				|  |  | +        */
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |      // Last step of the loop
 | 
	
		
			
				|  |  |      vector<Ciphertext> newtemp(temp.size() << 1);
 | 
	
		
			
				|  |  |      int index_raw = (n << 1) - (1 << (logm - 1));
 | 
	
	
		
			
				|  | @@ -312,7 +359,7 @@ inline vector<Ciphertext> PIRServer::expand_query(const Ciphertext &encrypted, u
 | 
	
		
			
				|  |  |      for (uint32_t a = 0; a < temp.size(); a++) {
 | 
	
		
			
				|  |  |          if (a >= (m - (1 << (logm - 1)))) {                       // corner case.
 | 
	
		
			
				|  |  |              evaluator_->multiply_plain(temp[a], two, newtemp[a]); // plain multiplication by 2.
 | 
	
		
			
				|  |  | -            cout << client.decryptor_->invariant_noise_budget(newtemp[a]) << ", "; 
 | 
	
		
			
				|  |  | +            // cout << client.decryptor_->invariant_noise_budget(newtemp[a]) << ", "; 
 | 
	
		
			
				|  |  |          } else {
 | 
	
		
			
				|  |  |              evaluator_->apply_galois(temp[a], galois_elts[logm - 1], galkey, tempctxt_rotated);
 | 
	
		
			
				|  |  |              evaluator_->add(temp[a], tempctxt_rotated, newtemp[a]);
 | 
	
	
		
			
				|  | @@ -353,17 +400,16 @@ inline void PIRServer::multiply_power_of_X(const Ciphertext &encrypted, Cipherte
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, Plaintext *plain_ptr) {
 | 
	
		
			
				|  |  | +inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted, Plaintext *plain_ptr, int logt) {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      vector<Plaintext> result;
 | 
	
		
			
				|  |  |      auto coeff_count = params_.poly_modulus_degree();
 | 
	
		
			
				|  |  |      auto coeff_mod_count = params_.coeff_modulus().size();
 | 
	
		
			
				|  |  |      auto encrypted_count = encrypted.size();
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    // Generate powers of t.
 | 
	
		
			
				|  |  | -    uint64_t plainModMinusOne = params_.plain_modulus().value() - 1;
 | 
	
		
			
				|  |  | -    int exp = ceil(log2(plainModMinusOne + 1));
 | 
	
		
			
				|  |  | +    uint64_t t1 = 1 << logt;  //  t1 <= t. 
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    uint64_t t1minusone =  t1 -1; 
 | 
	
		
			
				|  |  |      // A triple for loop. Going over polys, moduli, and decomposed index.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      for (int i = 0; i < encrypted_count; i++) {
 | 
	
	
		
			
				|  | @@ -373,19 +419,19 @@ inline void PIRServer::decompose_to_plaintexts_ptr(const Ciphertext &encrypted,
 | 
	
		
			
				|  |  |              // create a polynomial to store the current decomposition value
 | 
	
		
			
				|  |  |              // which will be copied into the array to populate it at the current
 | 
	
		
			
				|  |  |              // index.
 | 
	
		
			
				|  |  | -            int logqj = log2(params_.coeff_modulus()[j].value());
 | 
	
		
			
				|  |  | -            int expansion_ratio = ceil(logqj + exp - 1) / exp;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -            // cout << "expansion ratio = " << expansion_ratio << endl;
 | 
	
		
			
				|  |  | +            double logqj = log2(params_.coeff_modulus()[j].value());
 | 
	
		
			
				|  |  | +            //int expansion_ratio = ceil(logqj + exponent - 1) / exponent;
 | 
	
		
			
				|  |  | +            int expansion_ratio =  ceil(logqj / logt); 
 | 
	
		
			
				|  |  | +            // cout << "local expansion ratio = " << expansion_ratio << endl;
 | 
	
		
			
				|  |  |              uint64_t curexp = 0;
 | 
	
		
			
				|  |  |              for (int k = 0; k < expansion_ratio; k++) {
 | 
	
		
			
				|  |  |                  // Decompose here
 | 
	
		
			
				|  |  |                  for (int m = 0; m < coeff_count; m++) {
 | 
	
		
			
				|  |  |                      plain_ptr[i * coeff_mod_count * expansion_ratio
 | 
	
		
			
				|  |  |                          + j * expansion_ratio + k][m] =
 | 
	
		
			
				|  |  | -                        (*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & plainModMinusOne;
 | 
	
		
			
				|  |  | +                        (*(encrypted_pointer + m + (j * coeff_count)) >> curexp) & t1minusone;
 | 
	
		
			
				|  |  |                  }
 | 
	
		
			
				|  |  | -                curexp += exp;
 | 
	
		
			
				|  |  | +                curexp += logt;
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 |