123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346 |
- #include "pirclient.h"
- #include "utils.h"
- EC_KEY *ENCLAVE_PUBLIC_KEY = NULL;
- class ZT_LSORAMClient : public PIRClient {
- public:
- ZT_LSORAMClient();
- // Create a PIR query. The plainquery must be exactly 32 bytes
- // long.
- virtual void create(const string &plainquery, const string ¶ms,
- void *&queryid, string &pirquery);
- //Helper functions for create()
- void setupEnclavePublicKey(const string ¶ms);
- int encryptLSORAMRequest(EC_KEY* target_public_key,
- unsigned char *serialized_request, uint32_t request_size, unsigned char
- **encrypted_request, unsigned char **client_pubkey, uint32_t *pubkey_size_x,
- uint32_t *pubkey_size_y, unsigned char **ecdh_aes_key, unsigned char **iv,
- unsigned char **tag);
- // Extract the plaintext response from a PIR response. Returns
- // true if successful, false if unsuccessful.
- virtual bool extract(void *&queryid, const string &pirresponse,
- string &plainresponse);
- // Helper functions for extract()
- int decryptLSORAMResponse(unsigned char *encrypted_response,
- uint32_t response_size, unsigned char *tag, unsigned char *aes_key,
- unsigned char *iv, unsigned char **response);
- };
- ZT_LSORAMClient::ZT_LSORAMClient() {
- }
// Per-query state that create() hands to extract() through the opaque queryid:
// everything needed to decrypt that query's response.
struct Decryptstate {
    string decrypt_key;  // AES key that encrypted the query (KEY_LENGTH bytes when set by create())
    string iv;           // IV paired with decrypt_key (IV_LENGTH bytes when set by create())
};
- /*
- Inputs: a target pub key, a seriailzed request and request size.
- Outputs: instantiates and populates:
- client_pubkey, aes_key (from target_pubkey and generated client_pubkey ECDH)
- iv, encrypted request and tag for the request
- */
- int ZT_LSORAMClient::encryptLSORAMRequest(EC_KEY* target_public_key, unsigned char *serialized_request,
- uint32_t request_size, unsigned char **encrypted_request, unsigned char **client_pubkey,
- uint32_t *pubkey_size_x, uint32_t *pubkey_size_y, unsigned char **ecdh_aes_key,
- unsigned char **iv, unsigned char **tag){
- //Generate a new key
- EC_KEY *ephemeral_key = NULL;
- BIGNUM *x, *y;
- x = BN_new();
- y = BN_new();
- BN_CTX *bn_ctx = BN_CTX_new();
- const EC_GROUP *curve = NULL;
- if(NULL == (curve = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1)))
- fprintf(stderr, "ZT_LSORAMClient: Setting EC_GROUP failed \n");
- ephemeral_key = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
- if(ephemeral_key==NULL)
- fprintf(stderr, "ZT_LSORAMClient: EC_KEY_new_by_curve_name Fail\n");
- int ret = EC_KEY_generate_key(ephemeral_key);
- if(ret!=1)
- fprintf(stderr, "ZT_LSORAMClient: EC_KEY_generate_key Fail\n");
- const EC_POINT *pub_point;
- pub_point = EC_KEY_get0_public_key((const EC_KEY *) ephemeral_key);
- if(pub_point == NULL)
- fprintf(stderr, "ZT_LSORAMClient: EC_KEY_get0_public_key Fail\n");
-
- ret = EC_POINT_get_affine_coordinates_GFp(curve, pub_point, x, y, bn_ctx);
- if(ret==0)
- fprintf(stderr, "ZT_LSORAMClient: EC_POINT_get_affine_coordinates_GFp Failed \n");
-
- unsigned char *bin_x, *bin_y;
- uint32_t size_bin_x = BN_num_bytes(x);
- uint32_t size_bin_y = BN_num_bytes(y);
- printf("(%d, %d)\n", size_bin_x, size_bin_y);
- bin_x = (unsigned char*) malloc(EC_KEY_SIZE);
- bin_y = (unsigned char*) malloc(EC_KEY_SIZE);
- BN_bn2bin(x, bin_x);
- BN_bn2bin(y, bin_y);
- *pubkey_size_x = size_bin_x;
- *pubkey_size_y = size_bin_y;
- *client_pubkey = (unsigned char*) malloc(size_bin_x + size_bin_y);
- memcpy(*client_pubkey, bin_x, size_bin_x);
- memcpy(*client_pubkey + size_bin_x, bin_y, size_bin_y);
- /*
- unsigned char *ptr = *client_pubkey;
- printf("Serialized Client's Public Key in encryptLSORAM :\n");
- for(int t = 0; t < size_bin_x; t++)
- printf("%02X", ptr[t]);
- printf("\n");
- printf("Serialized Client's Public Key in encryptLSORAM :\n");
- for(int t = 0; t < size_bin_y; t++)
- printf("%02X", ptr[size_bin_x + t]);
- printf("\n");
- */
- uint32_t field_size = EC_GROUP_get_degree(EC_KEY_get0_group(target_public_key));
- uint32_t secret_len = (field_size+7)/8;
- unsigned char *secret = (unsigned char*) malloc(secret_len);
- //Returns a 32 byte secret
- secret_len = ECDH_compute_key(secret, secret_len, EC_KEY_get0_public_key(target_public_key),
- ephemeral_key, NULL);
- //Sample IV;
- *ecdh_aes_key = (unsigned char*) malloc (KEY_LENGTH);
- *iv = (unsigned char*) malloc (IV_LENGTH);
-
- memcpy(*ecdh_aes_key, secret, KEY_LENGTH);
- memcpy(*iv, secret + KEY_LENGTH, IV_LENGTH);
- /*
- unsigned char *ecdh_ptr = (unsigned char *) *ecdh_aes_key;
- unsigned char *iv_ptr = (unsigned char *) *iv;
-
- printf("KEY_LENGTH = %d\n", KEY_LENGTH);
- printf("ecdh_key computed by Client :\n");
- for(int t = 0; t < KEY_LENGTH; t++)
- printf("%02X", ecdh_ptr[t]);
- printf("\n");
- printf("iv computed by Client :\n");
- for(int t = 0; t < IV_LENGTH; t++)
- printf("%02X", iv_ptr[t]);
- printf("\n");
- */
- BN_CTX_free(bn_ctx);
- *encrypted_request = (unsigned char*) malloc (request_size);
- *tag = (unsigned char*) malloc (TAG_SIZE);
- uint32_t encrypted_request_size;
- /*
- printf("Request bytes before encrypting: \n");
- for(int t = 0; t < request_size; t++)
- printf("%02X", serialized_request[t]);
- printf("\n");
- */
- encrypted_request_size = AES_GCM_128_encrypt(serialized_request, request_size,
- NULL, 0, (unsigned char*) *ecdh_aes_key, (unsigned char*) *iv,
- IV_LENGTH, *encrypted_request, *tag);
- /*
- unsigned char*tag_ptr = *tag;
- printf("Tag bytes after encryption: \n");
- for(uint32_t t = 0; t < TAG_SIZE; t++)
- printf("%02X", tag_ptr[t]);
- printf("\n");
-
- printf("Request_size = %d, Encrypted_request_size = %d,\n", request_size, encrypted_request_size);
- printf("Request bytes after encrypting: \n");
- unsigned char *encrypted_ptr = (unsigned char*) *encrypted_request;
- for(uint32_t t = 0; t < encrypted_request_size; t++)
- printf("%02X", encrypted_ptr[t]);
- printf("\n");
- */
- return encrypted_request_size;
- }
- void ZT_LSORAMClient::setupEnclavePublicKey(const string ¶ms){
- const char *serialized_key;
- unsigned char bin_x[PRIME256V1_KEY_SIZE];
- unsigned char bin_y[PRIME256V1_KEY_SIZE];
- BIGNUM *x, *y;
- EC_GROUP *curve;
- if(NULL == (curve = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1)))
- fprintf(stderr, "ZT_LSORAMClient:Setting EC_GROUP failed \n");
- EC_POINT *pub_point = EC_POINT_new(curve);
- serialized_key = params.c_str();
- memcpy(bin_x, serialized_key, PRIME256V1_KEY_SIZE);
- memcpy(bin_y, serialized_key + PRIME256V1_KEY_SIZE, PRIME256V1_KEY_SIZE);
- //Load the Enclave Public Key
- ENCLAVE_PUBLIC_KEY = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
-
- BN_CTX *bn_ctx = BN_CTX_new();
- x = BN_bin2bn(bin_x, PRIME256V1_KEY_SIZE, NULL);
- y = BN_bin2bn(bin_y, PRIME256V1_KEY_SIZE, NULL);
- if(EC_POINT_set_affine_coordinates_GFp(curve, pub_point, x, y, bn_ctx)==0)
- fprintf(stderr, "ZT_LSORAMClient: EC_POINT_set_affine_coordinates FAILED \n");
- if(EC_KEY_set_public_key(ENCLAVE_PUBLIC_KEY, pub_point)==0)
- fprintf(stderr, "ZT_LSORAMClient: EC_KEY_set_public_key FAILED \n");
- BN_CTX_free(bn_ctx);
- }
- void
- ZT_LSORAMClient::create(const string &plainquery, const string ¶ms,
- void *&queryid, string &pirquery)
- {
- // TODO: In ZT_LSORAMClient Lookupquery should be:
- // encrypted_query||tag_in||pk_x_size||pk_y_size||client_pubkey
- // client_pubkey (of size pk_x_size+pk_y_size)
- fprintf(stderr, "ZT_LSORAMClient: Starting create() ");
- if (plainquery.length() == 32) {
- //if (plainquery.length() == 32 && params.length() == 64) {
- //Setup Enclave_Pub_key
- setupEnclavePublicKey(params);
- fprintf(stderr, "ZT_LSORAMClient Done with setupEnclavePublicKey()\n");
-
- //Encrypt the query
- unsigned char *request = (unsigned char*) plainquery.c_str();
- unsigned char *tag, *encrypted_request, *ecdh_aes_key, *iv, *client_pubkey;
- uint32_t pubkey_size_x, pubkey_size_y;
-
- fprintf(stderr, "ZT_LSORAMClient: In create, HSDesc Key = ");
- for(int i = 0;i<BLINDED_KEY_SIZE; i++){
- fprintf(stderr, "%02X", request[i]);
- }
- fprintf(stderr, "\n");
- encryptLSORAMRequest(ENCLAVE_PUBLIC_KEY,(unsigned char*) request, BLINDED_KEY_SIZE,
- &encrypted_request, &client_pubkey, &pubkey_size_x, &pubkey_size_y,
- &ecdh_aes_key, &iv, &tag);
- // Format encrypted query into pirquery as :
- // encrypted_query||tag_in||pk_x_size||pk_y_size||client_pubkey
- // client_pubkey (of size pk_x_size+pk_y_size)
- uint32_t pirquery_cstr_size = BLINDED_KEY_SIZE + TAG_SIZE
- + (2 * sizeof(uint32_t)) + pubkey_size_x + pubkey_size_y;
- unsigned char *pirquery_cstr = (unsigned char*) malloc(pirquery_cstr_size);
- unsigned char *ptr = pirquery_cstr;
- memcpy(ptr, encrypted_request, BLINDED_KEY_SIZE);
- ptr+=BLINDED_KEY_SIZE;
- memcpy(ptr, tag, TAG_SIZE);
- ptr+=TAG_SIZE;
- memcpy(ptr, &pubkey_size_x, sizeof(uint32_t));
- ptr+=sizeof(uint32_t);
- memcpy(ptr, &pubkey_size_y, sizeof(uint32_t));
- ptr+=sizeof(uint32_t);
- memcpy(ptr, client_pubkey, pubkey_size_x+pubkey_size_y);
- pirquery.assign((const char*) pirquery_cstr, pirquery_cstr_size);
- //Store shared session key into DecryptState
- Decryptstate *ds = new Decryptstate();
- fprintf(stderr, "ZT_LSORAMClient: In create, AES_KEY = ");
- for(int i = 0;i<KEY_LENGTH; i++){
- fprintf(stderr, "%02X", ecdh_aes_key[i]);
- }
- ds->decrypt_key.assign((const char*) ecdh_aes_key, KEY_LENGTH);
- ds->iv.assign((const char*) iv, IV_LENGTH);
- queryid = ds;
- //Free all buffers
- free(pirquery_cstr);
- free(tag);
- free(client_pubkey);
- free(ecdh_aes_key);
- free(iv);
- free(encrypted_request);
- }
- else{
- fprintf(stderr, "Error ZT_LSORAMClient: plainquery.length or params.length!= 32");
- }
- }
- int ZT_LSORAMClient::decryptLSORAMResponse(unsigned char *encrypted_response, uint32_t response_size,
- unsigned char *tag, unsigned char *aes_key, unsigned char *iv, unsigned char **response) {
- *response = (unsigned char*) malloc (response_size);
- AES_GCM_128_decrypt(encrypted_response, response_size, NULL, 0, tag, aes_key, iv, IV_LENGTH, *response);
- return response_size;
- }
- bool
- ZT_LSORAMClient::extract(void *&queryid, const string &pirresponse,
- string &plainresponse)
- {
- fprintf(stderr, "ZT_LSORAMClient: Starting extract()\n");
- //pirresponse = encrypted_response||tag_out
- if(pirresponse.length()!=(DESCRIPTOR_MAX_SIZE+TAG_SIZE)){
- fprintf(stderr, "ZT_LSORAMClient: pirresponse size does not match expected value"
- "(DESCRIPTOR_MAX_SIZE + TAG_SIZE)\n");
- return 0;
- }
-
- Decryptstate *ds = (Decryptstate *)queryid;
-
- unsigned char *aes_key = (unsigned char*) ds->decrypt_key.c_str();
- unsigned char *iv = (unsigned char*) ds->iv.c_str();
- fprintf(stderr, "ZT_LSORAMClient: In extract, AES_KEY = ");
- for(int i = 0;i<KEY_LENGTH; i++){
- fprintf(stderr, "%02X", aes_key[i]);
- }
- string encrypted_response = pirresponse.substr(0, DESCRIPTOR_MAX_SIZE);
- string tag_out = pirresponse.substr(DESCRIPTOR_MAX_SIZE, TAG_SIZE);
- unsigned char *response;
-
- decryptLSORAMResponse((unsigned char *)encrypted_response.c_str(),
- DESCRIPTOR_MAX_SIZE, (unsigned char*) tag_out.c_str(), aes_key, iv,
- &response);
-
- delete ds;
- queryid = NULL;
- unsigned char *value_ptr = response;
- fprintf(stderr, "ZT_LSORAMClient: In extract, after decryption");
- fprintf(stderr, "ZT_LSORAMClient: (First 32 bytes of) HSDesc Value = ");
- for(int i = 0; i <32; i++){
- fprintf(stderr, "%02X", value_ptr[i]);
- }
- fprintf(stderr,"\n");
- //Populate plainresponse with response
- plainresponse.assign((const char*) response, DESCRIPTOR_MAX_SIZE);
- free(response);
- fprintf(stderr, "ZT_LSORAMClient:Finished extract()\n");
- return true;
-
- }
- int main(int argc, char **argv) {
- ZT_LSORAMClient client;
- client.mainloop();
- return 0;
- }
|