// ZT_LSORAMclient.cc (12 KB)

  1. #include "pirclient.h"
  2. #include "utils.h"
  3. EC_KEY *ENCLAVE_PUBLIC_KEY = NULL;
  4. class ZT_LSORAMClient : public PIRClient {
  5. public:
  6. ZT_LSORAMClient();
  7. // Create a PIR query. The plainquery must be exactly 32 bytes
  8. // long.
  9. virtual void create(const string &plainquery, const string &params,
  10. void *&queryid, string &pirquery);
  11. //Helper functions for create()
  12. void setupEnclavePublicKey(const string &params);
  13. int encryptLSORAMRequest(EC_KEY* target_public_key,
  14. unsigned char *serialized_request, uint32_t request_size, unsigned char
  15. **encrypted_request, unsigned char **client_pubkey, uint32_t *pubkey_size_x,
  16. uint32_t *pubkey_size_y, unsigned char **ecdh_aes_key, unsigned char **iv,
  17. unsigned char **tag);
  18. // Extract the plaintext response from a PIR response. Returns
  19. // true if successful, false if unsuccessful.
  20. virtual bool extract(void *&queryid, const string &pirresponse,
  21. string &plainresponse);
  22. // Helper functions for extract()
  23. int decryptLSORAMResponse(unsigned char *encrypted_response,
  24. uint32_t response_size, unsigned char *tag, unsigned char *aes_key,
  25. unsigned char *iv, unsigned char **response);
  26. };
  27. ZT_LSORAMClient::ZT_LSORAMClient() {
  28. }
  // Per-query state needed to decrypt the response. create() allocates it
  // with `new` and returns it through `queryid`; extract() reads it and then
  // deletes it.
  struct Decryptstate {
    string decrypt_key;  // AES key (KEY_LENGTH bytes) that encrypted the query
    string iv;           // IV (IV_LENGTH bytes) derived alongside decrypt_key
  };
  35. /*
  36. Inputs: a target pub key, a seriailzed request and request size.
  37. Outputs: instantiates and populates:
  38. client_pubkey, aes_key (from target_pubkey and generated client_pubkey ECDH)
  39. iv, encrypted request and tag for the request
  40. */
  41. int ZT_LSORAMClient::encryptLSORAMRequest(EC_KEY* target_public_key, unsigned char *serialized_request,
  42. uint32_t request_size, unsigned char **encrypted_request, unsigned char **client_pubkey,
  43. uint32_t *pubkey_size_x, uint32_t *pubkey_size_y, unsigned char **ecdh_aes_key,
  44. unsigned char **iv, unsigned char **tag){
  45. //Generate a new key
  46. EC_KEY *ephemeral_key = NULL;
  47. BIGNUM *x, *y;
  48. x = BN_new();
  49. y = BN_new();
  50. BN_CTX *bn_ctx = BN_CTX_new();
  51. const EC_GROUP *curve = NULL;
  52. if(NULL == (curve = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1)))
  53. fprintf(stderr, "ZT_LSORAMClient: Setting EC_GROUP failed \n");
  54. ephemeral_key = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
  55. if(ephemeral_key==NULL)
  56. fprintf(stderr, "ZT_LSORAMClient: EC_KEY_new_by_curve_name Fail\n");
  57. int ret = EC_KEY_generate_key(ephemeral_key);
  58. if(ret!=1)
  59. fprintf(stderr, "ZT_LSORAMClient: EC_KEY_generate_key Fail\n");
  60. const EC_POINT *pub_point;
  61. pub_point = EC_KEY_get0_public_key((const EC_KEY *) ephemeral_key);
  62. if(pub_point == NULL)
  63. fprintf(stderr, "ZT_LSORAMClient: EC_KEY_get0_public_key Fail\n");
  64. ret = EC_POINT_get_affine_coordinates_GFp(curve, pub_point, x, y, bn_ctx);
  65. if(ret==0)
  66. fprintf(stderr, "ZT_LSORAMClient: EC_POINT_get_affine_coordinates_GFp Failed \n");
  67. unsigned char *bin_x, *bin_y;
  68. uint32_t size_bin_x = BN_num_bytes(x);
  69. uint32_t size_bin_y = BN_num_bytes(y);
  70. printf("(%d, %d)\n", size_bin_x, size_bin_y);
  71. bin_x = (unsigned char*) malloc(EC_KEY_SIZE);
  72. bin_y = (unsigned char*) malloc(EC_KEY_SIZE);
  73. BN_bn2bin(x, bin_x);
  74. BN_bn2bin(y, bin_y);
  75. *pubkey_size_x = size_bin_x;
  76. *pubkey_size_y = size_bin_y;
  77. *client_pubkey = (unsigned char*) malloc(size_bin_x + size_bin_y);
  78. memcpy(*client_pubkey, bin_x, size_bin_x);
  79. memcpy(*client_pubkey + size_bin_x, bin_y, size_bin_y);
  80. /*
  81. unsigned char *ptr = *client_pubkey;
  82. printf("Serialized Client's Public Key in encryptLSORAM :\n");
  83. for(int t = 0; t < size_bin_x; t++)
  84. printf("%02X", ptr[t]);
  85. printf("\n");
  86. printf("Serialized Client's Public Key in encryptLSORAM :\n");
  87. for(int t = 0; t < size_bin_y; t++)
  88. printf("%02X", ptr[size_bin_x + t]);
  89. printf("\n");
  90. */
  91. uint32_t field_size = EC_GROUP_get_degree(EC_KEY_get0_group(target_public_key));
  92. uint32_t secret_len = (field_size+7)/8;
  93. unsigned char *secret = (unsigned char*) malloc(secret_len);
  94. //Returns a 32 byte secret
  95. secret_len = ECDH_compute_key(secret, secret_len, EC_KEY_get0_public_key(target_public_key),
  96. ephemeral_key, NULL);
  97. //Sample IV;
  98. *ecdh_aes_key = (unsigned char*) malloc (KEY_LENGTH);
  99. *iv = (unsigned char*) malloc (IV_LENGTH);
  100. memcpy(*ecdh_aes_key, secret, KEY_LENGTH);
  101. memcpy(*iv, secret + KEY_LENGTH, IV_LENGTH);
  102. /*
  103. unsigned char *ecdh_ptr = (unsigned char *) *ecdh_aes_key;
  104. unsigned char *iv_ptr = (unsigned char *) *iv;
  105. printf("KEY_LENGTH = %d\n", KEY_LENGTH);
  106. printf("ecdh_key computed by Client :\n");
  107. for(int t = 0; t < KEY_LENGTH; t++)
  108. printf("%02X", ecdh_ptr[t]);
  109. printf("\n");
  110. printf("iv computed by Client :\n");
  111. for(int t = 0; t < IV_LENGTH; t++)
  112. printf("%02X", iv_ptr[t]);
  113. printf("\n");
  114. */
  115. BN_CTX_free(bn_ctx);
  116. *encrypted_request = (unsigned char*) malloc (request_size);
  117. *tag = (unsigned char*) malloc (TAG_SIZE);
  118. uint32_t encrypted_request_size;
  119. /*
  120. printf("Request bytes before encrypting: \n");
  121. for(int t = 0; t < request_size; t++)
  122. printf("%02X", serialized_request[t]);
  123. printf("\n");
  124. */
  125. encrypted_request_size = AES_GCM_128_encrypt(serialized_request, request_size,
  126. NULL, 0, (unsigned char*) *ecdh_aes_key, (unsigned char*) *iv,
  127. IV_LENGTH, *encrypted_request, *tag);
  128. /*
  129. unsigned char*tag_ptr = *tag;
  130. printf("Tag bytes after encryption: \n");
  131. for(uint32_t t = 0; t < TAG_SIZE; t++)
  132. printf("%02X", tag_ptr[t]);
  133. printf("\n");
  134. printf("Request_size = %d, Encrypted_request_size = %d,\n", request_size, encrypted_request_size);
  135. printf("Request bytes after encrypting: \n");
  136. unsigned char *encrypted_ptr = (unsigned char*) *encrypted_request;
  137. for(uint32_t t = 0; t < encrypted_request_size; t++)
  138. printf("%02X", encrypted_ptr[t]);
  139. printf("\n");
  140. */
  141. return encrypted_request_size;
  142. }
  143. void ZT_LSORAMClient::setupEnclavePublicKey(const string &params){
  144. const char *serialized_key;
  145. unsigned char bin_x[PRIME256V1_KEY_SIZE];
  146. unsigned char bin_y[PRIME256V1_KEY_SIZE];
  147. BIGNUM *x, *y;
  148. EC_GROUP *curve;
  149. if(NULL == (curve = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1)))
  150. fprintf(stderr, "ZT_LSORAMClient:Setting EC_GROUP failed \n");
  151. EC_POINT *pub_point = EC_POINT_new(curve);
  152. serialized_key = params.c_str();
  153. memcpy(bin_x, serialized_key, PRIME256V1_KEY_SIZE);
  154. memcpy(bin_y, serialized_key + PRIME256V1_KEY_SIZE, PRIME256V1_KEY_SIZE);
  155. //Load the Enclave Public Key
  156. ENCLAVE_PUBLIC_KEY = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
  157. BN_CTX *bn_ctx = BN_CTX_new();
  158. x = BN_bin2bn(bin_x, PRIME256V1_KEY_SIZE, NULL);
  159. y = BN_bin2bn(bin_y, PRIME256V1_KEY_SIZE, NULL);
  160. if(EC_POINT_set_affine_coordinates_GFp(curve, pub_point, x, y, bn_ctx)==0)
  161. fprintf(stderr, "ZT_LSORAMClient: EC_POINT_set_affine_coordinates FAILED \n");
  162. if(EC_KEY_set_public_key(ENCLAVE_PUBLIC_KEY, pub_point)==0)
  163. fprintf(stderr, "ZT_LSORAMClient: EC_KEY_set_public_key FAILED \n");
  164. BN_CTX_free(bn_ctx);
  165. }
  166. void
  167. ZT_LSORAMClient::create(const string &plainquery, const string &params,
  168. void *&queryid, string &pirquery)
  169. {
  170. // TODO: In ZT_LSORAMClient Lookupquery should be:
  171. // encrypted_query||tag_in||pk_x_size||pk_y_size||client_pubkey
  172. // client_pubkey (of size pk_x_size+pk_y_size)
  173. fprintf(stderr, "ZT_LSORAMClient: Starting create() ");
  174. if (plainquery.length() == 32) {
  175. //if (plainquery.length() == 32 && params.length() == 64) {
  176. //Setup Enclave_Pub_key
  177. setupEnclavePublicKey(params);
  178. fprintf(stderr, "ZT_LSORAMClient Done with setupEnclavePublicKey()\n");
  179. //Encrypt the query
  180. unsigned char *request = (unsigned char*) plainquery.c_str();
  181. unsigned char *tag, *encrypted_request, *ecdh_aes_key, *iv, *client_pubkey;
  182. uint32_t pubkey_size_x, pubkey_size_y;
  183. fprintf(stderr, "ZT_LSORAMClient: In create, HSDesc Key = ");
  184. for(int i = 0;i<BLINDED_KEY_SIZE; i++){
  185. fprintf(stderr, "%02X", request[i]);
  186. }
  187. fprintf(stderr, "\n");
  188. encryptLSORAMRequest(ENCLAVE_PUBLIC_KEY,(unsigned char*) request, BLINDED_KEY_SIZE,
  189. &encrypted_request, &client_pubkey, &pubkey_size_x, &pubkey_size_y,
  190. &ecdh_aes_key, &iv, &tag);
  191. // Format encrypted query into pirquery as :
  192. // encrypted_query||tag_in||pk_x_size||pk_y_size||client_pubkey
  193. // client_pubkey (of size pk_x_size+pk_y_size)
  194. uint32_t pirquery_cstr_size = BLINDED_KEY_SIZE + TAG_SIZE
  195. + (2 * sizeof(uint32_t)) + pubkey_size_x + pubkey_size_y;
  196. unsigned char *pirquery_cstr = (unsigned char*) malloc(pirquery_cstr_size);
  197. unsigned char *ptr = pirquery_cstr;
  198. memcpy(ptr, encrypted_request, BLINDED_KEY_SIZE);
  199. ptr+=BLINDED_KEY_SIZE;
  200. memcpy(ptr, tag, TAG_SIZE);
  201. ptr+=TAG_SIZE;
  202. memcpy(ptr, &pubkey_size_x, sizeof(uint32_t));
  203. ptr+=sizeof(uint32_t);
  204. memcpy(ptr, &pubkey_size_y, sizeof(uint32_t));
  205. ptr+=sizeof(uint32_t);
  206. memcpy(ptr, client_pubkey, pubkey_size_x+pubkey_size_y);
  207. pirquery.assign((const char*) pirquery_cstr, pirquery_cstr_size);
  208. //Store shared session key into DecryptState
  209. Decryptstate *ds = new Decryptstate();
  210. fprintf(stderr, "ZT_LSORAMClient: In create, AES_KEY = ");
  211. for(int i = 0;i<KEY_LENGTH; i++){
  212. fprintf(stderr, "%02X", ecdh_aes_key[i]);
  213. }
  214. ds->decrypt_key.assign((const char*) ecdh_aes_key, KEY_LENGTH);
  215. ds->iv.assign((const char*) iv, IV_LENGTH);
  216. queryid = ds;
  217. //Free all buffers
  218. free(pirquery_cstr);
  219. free(tag);
  220. free(client_pubkey);
  221. free(ecdh_aes_key);
  222. free(iv);
  223. free(encrypted_request);
  224. }
  225. else{
  226. fprintf(stderr, "Error ZT_LSORAMClient: plainquery.length or params.length!= 32");
  227. }
  228. }
  229. int ZT_LSORAMClient::decryptLSORAMResponse(unsigned char *encrypted_response, uint32_t response_size,
  230. unsigned char *tag, unsigned char *aes_key, unsigned char *iv, unsigned char **response) {
  231. *response = (unsigned char*) malloc (response_size);
  232. AES_GCM_128_decrypt(encrypted_response, response_size, NULL, 0, tag, aes_key, iv, IV_LENGTH, *response);
  233. return response_size;
  234. }
  235. bool
  236. ZT_LSORAMClient::extract(void *&queryid, const string &pirresponse,
  237. string &plainresponse)
  238. {
  239. fprintf(stderr, "ZT_LSORAMClient: Starting extract()\n");
  240. //pirresponse = encrypted_response||tag_out
  241. if(pirresponse.length()!=(DESCRIPTOR_MAX_SIZE+TAG_SIZE)){
  242. fprintf(stderr, "ZT_LSORAMClient: pirresponse size does not match expected value"
  243. "(DESCRIPTOR_MAX_SIZE + TAG_SIZE)\n");
  244. return 0;
  245. }
  246. Decryptstate *ds = (Decryptstate *)queryid;
  247. unsigned char *aes_key = (unsigned char*) ds->decrypt_key.c_str();
  248. unsigned char *iv = (unsigned char*) ds->iv.c_str();
  249. fprintf(stderr, "ZT_LSORAMClient: In extract, AES_KEY = ");
  250. for(int i = 0;i<KEY_LENGTH; i++){
  251. fprintf(stderr, "%02X", aes_key[i]);
  252. }
  253. string encrypted_response = pirresponse.substr(0, DESCRIPTOR_MAX_SIZE);
  254. string tag_out = pirresponse.substr(DESCRIPTOR_MAX_SIZE, TAG_SIZE);
  255. unsigned char *response;
  256. decryptLSORAMResponse((unsigned char *)encrypted_response.c_str(),
  257. DESCRIPTOR_MAX_SIZE, (unsigned char*) tag_out.c_str(), aes_key, iv,
  258. &response);
  259. delete ds;
  260. queryid = NULL;
  261. unsigned char *value_ptr = response;
  262. fprintf(stderr, "ZT_LSORAMClient: In extract, after decryption");
  263. fprintf(stderr, "ZT_LSORAMClient: (First 32 bytes of) HSDesc Value = ");
  264. for(int i = 0; i <32; i++){
  265. fprintf(stderr, "%02X", value_ptr[i]);
  266. }
  267. fprintf(stderr,"\n");
  268. //Populate plainresponse with response
  269. plainresponse.assign((const char*) response, DESCRIPTOR_MAX_SIZE);
  270. free(response);
  271. fprintf(stderr, "ZT_LSORAMClient:Finished extract()\n");
  272. return true;
  273. }
  274. int main(int argc, char **argv) {
  275. ZT_LSORAMClient client;
  276. client.mainloop();
  277. return 0;
  278. }