// File: ZT_LSORAMclient.cc (10 KB)
  1. #include "pirclient.h"
  2. #include "utils.h"
  3. EC_KEY *ENCLAVE_PUBLIC_KEY = NULL;
  4. class ZT_LSORAMClient : public PIRClient {
  5. public:
  6. ZT_LSORAMClient();
  7. // Create a PIR query. The plainquery must be exactly 32 bytes
  8. // long.
  9. virtual void create(const string &plainquery, const string &params,
  10. void *&queryid, string &pirquery);
  11. //Helper functions for create()
  12. void setupEnclavePublicKey(const string &params);
  13. int encryptLSORAMRequest(EC_KEY* target_public_key,
  14. unsigned char *serialized_request, uint32_t request_size, unsigned char
  15. **encrypted_request, unsigned char **client_pubkey, uint32_t *pubkey_size_x,
  16. uint32_t *pubkey_size_y, unsigned char **ecdh_aes_key, unsigned char **iv,
  17. unsigned char **tag);
  18. // Extract the plaintext response from a PIR response. Returns
  19. // true if successful, false if unsuccessful.
  20. virtual bool extract(void *&queryid, const string &pirresponse,
  21. string &plainresponse);
  22. // Helper functions for extract()
  23. int decryptLSORAMResponse(unsigned char *encrypted_response,
  24. uint32_t response_size, unsigned char *tag, unsigned char *aes_key,
  25. unsigned char *iv, unsigned char **response);
  26. };
  27. ZT_LSORAMClient::ZT_LSORAMClient() {
  28. }
  // Per-query state kept between create() and extract(): everything needed
  // to decrypt the response that belongs to a query. The strings are used as
  // raw byte containers (they may hold embedded zero bytes), so always use
  // assign(ptr, len) / size() on them rather than C-string functions.
  struct Decryptstate {
    // The AES key used to encrypt the query; reused to decrypt the response.
    std::string decrypt_key;
    // The IV that goes with decrypt_key.
    std::string iv;
  };
  35. /*
  36. Inputs: a target pub key, a seriailzed request and request size.
  37. Outputs: instantiates and populates:
  38. client_pubkey, aes_key (from target_pubkey and generated client_pubkey ECDH)
  39. iv, encrypted request and tag for the request
  40. */
  41. int ZT_LSORAMClient::encryptLSORAMRequest(EC_KEY* target_public_key, unsigned char *serialized_request,
  42. uint32_t request_size, unsigned char **encrypted_request, unsigned char **client_pubkey,
  43. uint32_t *pubkey_size_x, uint32_t *pubkey_size_y, unsigned char **ecdh_aes_key,
  44. unsigned char **iv, unsigned char **tag){
  45. //Generate a new key
  46. EC_KEY *ephemeral_key = NULL;
  47. BIGNUM *x, *y;
  48. x = BN_new();
  49. y = BN_new();
  50. BN_CTX *bn_ctx = BN_CTX_new();
  51. const EC_GROUP *curve = NULL;
  52. if(NULL == (curve = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1)))
  53. printf("Setting EC_GROUP failed \n");
  54. ephemeral_key = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
  55. if(ephemeral_key==NULL)
  56. printf("Client: EC_KEY_new_by_curve_name Fail\n");
  57. int ret = EC_KEY_generate_key(ephemeral_key);
  58. if(ret!=1)
  59. printf("Client: EC_KEY_generate_key Fail\n");
  60. const EC_POINT *pub_point;
  61. pub_point = EC_KEY_get0_public_key((const EC_KEY *) ephemeral_key);
  62. if(pub_point == NULL)
  63. printf("Client: EC_KEY_get0_public_key Fail\n");
  64. ret = EC_POINT_get_affine_coordinates_GFp(curve, pub_point, x, y, bn_ctx);
  65. if(ret==0)
  66. printf("Client: EC_POINT_get_affine_coordinates_GFp Failed \n");
  67. unsigned char *bin_x, *bin_y;
  68. uint32_t size_bin_x = BN_num_bytes(x);
  69. uint32_t size_bin_y = BN_num_bytes(y);
  70. printf("(%d, %d)\n", size_bin_x, size_bin_y);
  71. bin_x = (unsigned char*) malloc(EC_KEY_SIZE);
  72. bin_y = (unsigned char*) malloc(EC_KEY_SIZE);
  73. BN_bn2bin(x, bin_x);
  74. BN_bn2bin(y, bin_y);
  75. *pubkey_size_x = size_bin_x;
  76. *pubkey_size_y = size_bin_y;
  77. *client_pubkey = (unsigned char*) malloc(size_bin_x + size_bin_y);
  78. memcpy(*client_pubkey, bin_x, size_bin_x);
  79. memcpy(*client_pubkey + size_bin_x, bin_y, size_bin_y);
  80. /*
  81. unsigned char *ptr = *client_pubkey;
  82. printf("Serialized Client's Public Key in encryptLSORAM :\n");
  83. for(int t = 0; t < size_bin_x; t++)
  84. printf("%02X", ptr[t]);
  85. printf("\n");
  86. printf("Serialized Client's Public Key in encryptLSORAM :\n");
  87. for(int t = 0; t < size_bin_y; t++)
  88. printf("%02X", ptr[size_bin_x + t]);
  89. printf("\n");
  90. */
  91. uint32_t field_size = EC_GROUP_get_degree(EC_KEY_get0_group(target_public_key));
  92. uint32_t secret_len = (field_size+7)/8;
  93. unsigned char *secret = (unsigned char*) malloc(secret_len);
  94. //Returns a 32 byte secret
  95. secret_len = ECDH_compute_key(secret, secret_len, EC_KEY_get0_public_key(target_public_key),
  96. ephemeral_key, NULL);
  97. //Sample IV;
  98. *ecdh_aes_key = (unsigned char*) malloc (KEY_LENGTH);
  99. *iv = (unsigned char*) malloc (IV_LENGTH);
  100. memcpy(*ecdh_aes_key, secret, KEY_LENGTH);
  101. memcpy(*iv, secret + KEY_LENGTH, IV_LENGTH);
  102. /*
  103. unsigned char *ecdh_ptr = (unsigned char *) *ecdh_aes_key;
  104. unsigned char *iv_ptr = (unsigned char *) *iv;
  105. printf("KEY_LENGTH = %d\n", KEY_LENGTH);
  106. printf("ecdh_key computed by Client :\n");
  107. for(int t = 0; t < KEY_LENGTH; t++)
  108. printf("%02X", ecdh_ptr[t]);
  109. printf("\n");
  110. printf("iv computed by Client :\n");
  111. for(int t = 0; t < IV_LENGTH; t++)
  112. printf("%02X", iv_ptr[t]);
  113. printf("\n");
  114. */
  115. BN_CTX_free(bn_ctx);
  116. *encrypted_request = (unsigned char*) malloc (request_size);
  117. *tag = (unsigned char*) malloc (TAG_SIZE);
  118. uint32_t encrypted_request_size;
  119. /*
  120. printf("Request bytes before encrypting: \n");
  121. for(int t = 0; t < request_size; t++)
  122. printf("%02X", serialized_request[t]);
  123. printf("\n");
  124. */
  125. encrypted_request_size = AES_GCM_128_encrypt(serialized_request, request_size,
  126. NULL, 0, (unsigned char*) *ecdh_aes_key, (unsigned char*) *iv,
  127. IV_LENGTH, *encrypted_request, *tag);
  128. /*
  129. unsigned char*tag_ptr = *tag;
  130. printf("Tag bytes after encryption: \n");
  131. for(uint32_t t = 0; t < TAG_SIZE; t++)
  132. printf("%02X", tag_ptr[t]);
  133. printf("\n");
  134. printf("Request_size = %d, Encrypted_request_size = %d,\n", request_size, encrypted_request_size);
  135. printf("Request bytes after encrypting: \n");
  136. unsigned char *encrypted_ptr = (unsigned char*) *encrypted_request;
  137. for(uint32_t t = 0; t < encrypted_request_size; t++)
  138. printf("%02X", encrypted_ptr[t]);
  139. printf("\n");
  140. */
  141. return encrypted_request_size;
  142. }
  143. void ZT_LSORAMClient::setupEnclavePublicKey(const string &params){
  144. const char *serialized_key;
  145. unsigned char bin_x[PRIME256V1_KEY_SIZE];
  146. unsigned char bin_y[PRIME256V1_KEY_SIZE];
  147. BIGNUM *x, *y;
  148. EC_GROUP *curve;
  149. if(NULL == (curve = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1)))
  150. printf("Setting EC_GROUP failed \n");
  151. EC_POINT *pub_point = EC_POINT_new(curve);
  152. serialized_key = params.c_str();
  153. memcpy(bin_x, serialized_key, PRIME256V1_KEY_SIZE);
  154. memcpy(bin_y, serialized_key + PRIME256V1_KEY_SIZE, PRIME256V1_KEY_SIZE);
  155. //Load the Enclave Public Key
  156. ENCLAVE_PUBLIC_KEY = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
  157. BN_CTX *bn_ctx = BN_CTX_new();
  158. x = BN_bin2bn(bin_x, PRIME256V1_KEY_SIZE, NULL);
  159. y = BN_bin2bn(bin_y, PRIME256V1_KEY_SIZE, NULL);
  160. if(EC_POINT_set_affine_coordinates_GFp(curve, pub_point, x, y, bn_ctx)==0)
  161. printf("EC_POINT_set_affine_coordinates FAILED \n");
  162. if(EC_KEY_set_public_key(ENCLAVE_PUBLIC_KEY, pub_point)==0)
  163. printf("EC_KEY_set_public_key FAILED \n");
  164. BN_CTX_free(bn_ctx);
  165. }
  166. void
  167. ZT_LSORAMClient::create(const string &plainquery, const string &params,
  168. void *&queryid, string &pirquery)
  169. {
  170. // TODO: In ZT_LSORAMClient Lookupquery should be:
  171. // encrypted_query||tag_in||pk_x_size||pk_y_size||client_pubkey
  172. // client_pubkey (of size pk_x_size+pk_y_size)
  173. if (plainquery.length() == 32 && params.length() == 32) {
  174. //Setup Enclave_Pub_key
  175. setupEnclavePublicKey(params);
  176. //Encrypt the query
  177. const char *request = plainquery.c_str();
  178. unsigned char *tag, *encrypted_request, *ecdh_aes_key, *iv, *client_pubkey;
  179. uint32_t pubkey_size_x, pubkey_size_y;
  180. encryptLSORAMRequest(ENCLAVE_PUBLIC_KEY,(unsigned char*) request, BLINDED_KEY_SIZE,
  181. &encrypted_request, &client_pubkey, &pubkey_size_x, &pubkey_size_y,
  182. &ecdh_aes_key, &iv, &tag);
  183. // Format encrypted query into pirquery as :
  184. // encrypted_query||tag_in||pk_x_size||pk_y_size||client_pubkey
  185. // client_pubkey (of size pk_x_size+pk_y_size)
  186. uint32_t pirquery_cstr_size = BLINDED_KEY_SIZE + TAG_SIZE
  187. + (2 * sizeof(uint32_t)) + pubkey_size_x + pubkey_size_y;
  188. unsigned char *pirquery_cstr = (unsigned char*) malloc(pirquery_cstr_size);
  189. unsigned char *ptr = pirquery_cstr;
  190. memcpy(ptr, encrypted_request, BLINDED_KEY_SIZE);
  191. ptr+=BLINDED_KEY_SIZE;
  192. memcpy(ptr, tag, TAG_SIZE);
  193. ptr+=TAG_SIZE;
  194. memcpy(ptr, &pubkey_size_x, sizeof(uint32_t));
  195. ptr+=sizeof(uint32_t);
  196. memcpy(ptr, &pubkey_size_y, sizeof(uint32_t));
  197. ptr+=sizeof(uint32_t);
  198. memcpy(ptr, client_pubkey, pubkey_size_x+pubkey_size_y);
  199. pirquery.assign((const char*) pirquery_cstr, pirquery_cstr_size);
  200. //Store shared session key into DecryptState
  201. Decryptstate *ds = new Decryptstate();
  202. ds->decrypt_key.assign((const char*) ecdh_aes_key, KEY_LENGTH);
  203. ds->iv.assign((const char*) iv, IV_LENGTH);
  204. queryid = ds;
  205. //Free all buffers
  206. free(pirquery_cstr);
  207. free(tag);
  208. free(client_pubkey);
  209. free(ecdh_aes_key);
  210. free(iv);
  211. free(encrypted_request);
  212. }
  213. }
  214. int ZT_LSORAMClient::decryptLSORAMResponse(unsigned char *encrypted_response, uint32_t response_size,
  215. unsigned char *tag, unsigned char *aes_key, unsigned char *iv, unsigned char **response) {
  216. *response = (unsigned char*) malloc (response_size);
  217. AES_GCM_128_decrypt(encrypted_response, response_size, NULL, 0, tag, aes_key, iv, IV_LENGTH, *response);
  218. return response_size;
  219. }
  220. bool
  221. ZT_LSORAMClient::extract(void *&queryid, const string &pirresponse,
  222. string &plainresponse)
  223. {
  224. //pirresponse = encrypted_response||tag_out
  225. if(pirresponse.length()!=(DESCRIPTOR_MAX_SIZE+TAG_SIZE)){
  226. printf("pirresponse size does not match expected value"
  227. "(DESCRIPTOR_MAX_SIZE + TAG_SIZE)\n");
  228. return 0;
  229. }
  230. Decryptstate *ds = (Decryptstate *)queryid;
  231. unsigned char *aes_key = (unsigned char*) ds->decrypt_key.c_str();
  232. unsigned char *iv = (unsigned char*) ds->iv.c_str();
  233. string encrypted_response = pirresponse.substr(0, DESCRIPTOR_MAX_SIZE);
  234. string tag_out = pirresponse.substr(DESCRIPTOR_MAX_SIZE, TAG_SIZE);
  235. unsigned char *response;
  236. decryptLSORAMResponse((unsigned char *)encrypted_response.c_str(),
  237. DESCRIPTOR_MAX_SIZE, (unsigned char*) tag_out.c_str(), aes_key, iv,
  238. &response);
  239. delete ds;
  240. queryid = NULL;
  241. //Populate plainresponse with response
  242. plainresponse.assign((const char*) response, DESCRIPTOR_MAX_SIZE);
  243. free(response);
  244. return true;
  245. }
  246. int main(int argc, char **argv) {
  247. ZT_LSORAMClient client;
  248. client.mainloop();
  249. return 0;
  250. }