Ver código fonte

Handshake complete, and encrypted data can be exchanged

Ian Goldberg 1 ano atrás
pai
commit
85b51f0e57
1 arquivos alterados com 80 adições e 20 exclusões
  1. 80 20
      Enclave/comms.cpp

+ 80 - 20
Enclave/comms.cpp

@@ -1,7 +1,6 @@
 #include <vector>
 #include <functional>
 #include <cstring>
-#include <stdio.h>
 
 #include "sgx_tcrypto.h"
 #include "sgx_tseal.h"
@@ -31,6 +30,9 @@ struct NodeCommState {
     sgx_ec256_private_t handshake_dh_privkey;
     sgx_ec256_public_t handshake_dh_pubkey;
 
+    // The server keeps this state between handshake messages 1 and 3
+    uint8_t handshake_cli_srv_mac[16];
+
     // The outgoing and incoming AES keys after the handshake
     sgx_aes_gcm_128bit_key_t out_aes_key, in_aes_key;
 
@@ -99,9 +101,9 @@ struct NodeCommState {
         memmove(&pubkey, conf_pubkey, sizeof(pubkey));
     }
 
-    void message_start(uint32_t plaintext_len);
+    void message_start(uint32_t plaintext_len, bool encrypt=true);
 
-    void message_data(uint8_t *data, uint32_t len);
+    void message_data(uint8_t *data, uint32_t len, bool encrypt=true);
 
     // Start the handshake (as the client)
     void handshake_start();
@@ -134,7 +136,8 @@ static uint8_t* default_in_msg_get_buf(NodeCommState &commst,
 static void default_in_msg_received(NodeCommState &nodest,
     uint8_t *data, uint32_t plaintext_len, uint32_t)
 {
-    printf("Received handshake_3 message of %u bytes:\n", plaintext_len);
+    printf("Received message of %u bytes from node %lu:\n",
+        plaintext_len, nodest.node_num);
     for (uint32_t i=0;i<plaintext_len;++i) {
         printf("%02x", data[i]);
     }
@@ -247,15 +250,23 @@ static void handshake_1_msg_received(NodeCommState &nodest,
 
     sgx_ecc256_close_context(ecc_handle);
 
+    // Save the state we'll need to process handshake message 3
+    memmove(&nodest.in_aes_key, h2, 16);
+    memmove(&nodest.out_aes_key, ((uint8_t*)h2)+16, 16);
+    memmove(&nodest.handshake_cli_srv_mac, cli_srv_mac, 16);
+
     // Get us ready to receive handshake message 3
     nodest.in_msg_get_buf = default_in_msg_get_buf;
     nodest.in_msg_received = handshake_3_msg_received;
     nodest.handshake_step = HANDSHAKE_S_SENT_2;
 
     // Send handshake message 2
-    nodest.message_start(sizeof(our_dh_pubkey) + sizeof(srv_cli_sig));
-    nodest.message_data((uint8_t*)&our_dh_pubkey, sizeof(our_dh_pubkey));
-    nodest.message_data((uint8_t*)&srv_cli_sig, sizeof(srv_cli_sig));
+    nodest.message_start(sizeof(our_dh_pubkey) + sizeof(srv_cli_sig),
+        false);
+    nodest.message_data((uint8_t*)&our_dh_pubkey, sizeof(our_dh_pubkey),
+        false);
+    nodest.message_data((uint8_t*)&srv_cli_sig, sizeof(srv_cli_sig),
+        false);
 }
 
 // Receive (at the client) the secong handshake message
@@ -367,36 +378,82 @@ static void handshake_2_msg_received(NodeCommState &nodest,
     // Our side of the handshake is complete
     memmove(&nodest.out_aes_key, h2, 16);
     memmove(&nodest.in_aes_key, ((uint8_t*)h2)+16, 16);
+    memset(&nodest.out_aes_iv, 0, SGX_AESGCM_IV_SIZE);
+    memset(&nodest.in_aes_iv, 0, SGX_AESGCM_IV_SIZE);
     nodest.handshake_step = HANDSHAKE_COMPLETE;
     nodest.in_msg_get_buf = default_in_msg_get_buf;
     nodest.in_msg_received = default_in_msg_received;
 
     // Send handshake message 3
-    nodest.message_start(sizeof(cli_srv_sig));
-    nodest.message_data((uint8_t*)&cli_srv_sig, sizeof(cli_srv_sig));
+    nodest.message_start(sizeof(cli_srv_sig), false);
+    nodest.message_data((uint8_t*)&cli_srv_sig, sizeof(cli_srv_sig),
+        false);
+
+    // Send a test message
+    nodest.message_start(12);
+    unsigned char buf[13];
+    memmove(buf, "Hello, world", 13);
+    nodest.message_data(buf, 12);
 }
 
 static void handshake_3_msg_received(NodeCommState &nodest,
     uint8_t *data, uint32_t plaintext_len, uint32_t)
 {
+    /*
     printf("Received handshake_3 message of %u bytes:\n", plaintext_len);
     for (uint32_t i=0;i<plaintext_len;++i) {
         printf("%02x", data[i]);
     }
     printf("\n");
+    */
 
+    if (plaintext_len != sizeof(sgx_ec256_signature_t)) {
+        printf("Received handshake_3 message of incorrect size %u\n",
+            plaintext_len);
+        return;
+    }
+    sgx_ecc_state_handle_t ecc_handle;
+    sgx_ec256_signature_t peer_sig;
+    memmove(&peer_sig, data, sizeof(peer_sig));
     delete[] data;
+    sgx_ecc256_open_context(&ecc_handle);
+
+    // Verify the signature on the client-to-server MAC
+    uint8_t result;
+    if (sgx_ecdsa_verify(nodest.handshake_cli_srv_mac, 16,
+            &nodest.pubkey, &peer_sig, &result, ecc_handle)
+            || result != SGX_EC_VALID) {
+        printf("Invalid signature received from node %hu\n",
+            nodest.node_num);
+        sgx_ecc256_close_context(ecc_handle);
+        return;
+    }
+
+    printf("Valid signature received from node %hu\n", nodest.node_num);
+
+    // Our side of the handshake is complete
+    memset(&nodest.out_aes_iv, 0, SGX_AESGCM_IV_SIZE);
+    memset(&nodest.in_aes_iv, 0, SGX_AESGCM_IV_SIZE);
+    nodest.handshake_step = HANDSHAKE_COMPLETE;
+    nodest.in_msg_get_buf = default_in_msg_get_buf;
+    nodest.in_msg_received = default_in_msg_received;
+
+    // Send a test message
+    nodest.message_start(12);
+    unsigned char buf[13];
+    memmove(buf, "Hello, world", 13);
+    nodest.message_data(buf, 12);
 }
 
 // Start a new outgoing message.  Pass the number of _plaintext_ bytes
 // the message will be.
-void NodeCommState::message_start(uint32_t plaintext_len)
+void NodeCommState::message_start(uint32_t plaintext_len, bool encrypt)
 {
     uint32_t ciphertext_len = plaintext_len;
 
     // If the handshake is complete, add SGX_AESGCM_MAC_SIZE bytes for
     // every FRAME_SIZE-SGX_AESGCM_MAC_SIZE bytes of plaintext.
-    if (handshake_step == HANDSHAKE_COMPLETE) {
+    if (encrypt) {
         uint32_t num_chunks = (plaintext_len +
             FRAME_SIZE - SGX_AESGCM_MAC_SIZE - 1) /
             (FRAME_SIZE - SGX_AESGCM_MAC_SIZE);
@@ -418,14 +475,16 @@ void NodeCommState::message_start(uint32_t plaintext_len)
         printf("Received NULL back from ocall_message\n");
     }
     if (msg_plaintext_chunk_remain > 0) {
-        *(size_t*)out_aes_iv += 1;
-        sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv, SGX_AESGCM_IV_SIZE,
-            NULL, 0, &out_aes_gcm_state);
+        if (encrypt) {
+            *(size_t*)out_aes_iv += 1;
+            sgx_aes_gcm128_enc_init(out_aes_key, out_aes_iv,
+                SGX_AESGCM_IV_SIZE, NULL, 0, &out_aes_gcm_state);
+        }
     }
 }
 
 // Process len bytes of plaintext data into the current message.
-void NodeCommState::message_data(uint8_t *data, uint32_t len)
+void NodeCommState::message_data(uint8_t *data, uint32_t len, bool encrypt)
 {
     while (len > 0) {
         if (msg_plaintext_chunk_remain == 0) {
@@ -440,7 +499,7 @@ void NodeCommState::message_data(uint8_t *data, uint32_t len)
             printf("frame is NULL when queueing message data\n");
             return;
         }
-        if (handshake_step == HANDSHAKE_COMPLETE) {
+        if (encrypt) {
             // Encrypt the data
             sgx_aes_gcm128_enc_update(data, bytes_to_process,
                 frame+frame_offset, out_aes_gcm_state);
@@ -455,7 +514,7 @@ void NodeCommState::message_data(uint8_t *data, uint32_t len)
         data += bytes_to_process;
         if (msg_plaintext_chunk_remain == 0) {
             // Complete and send this chunk
-            if (handshake_step == HANDSHAKE_COMPLETE) {
+            if (encrypt) {
                 sgx_aes_gcm128_enc_get_mac(frame+frame_offset,
                     out_aes_gcm_state);
                 frame_offset += SGX_AESGCM_MAC_SIZE;
@@ -472,7 +531,7 @@ void NodeCommState::message_data(uint8_t *data, uint32_t len)
                 msg_plaintext_chunk_remain =
                     FRAME_SIZE - SGX_AESGCM_MAC_SIZE;
             }
-            if (handshake_step == HANDSHAKE_COMPLETE) {
+            if (encrypt) {
                 sgx_aes_gcm_close(out_aes_gcm_state);
                 if (msg_plaintext_chunk_remain > 0) {
                     *(size_t*)out_aes_iv += 1;
@@ -715,9 +774,10 @@ void NodeCommState::handshake_start()
     handshake_step = HANDSHAKE_C_SENT_1;
 
     // Send the public key as the first message
-    message_start(sizeof(handshake_dh_pubkey));
+    message_start(sizeof(handshake_dh_pubkey), false);
 
-    message_data((uint8_t*)&handshake_dh_pubkey, sizeof(handshake_dh_pubkey));
+    message_data((uint8_t*)&handshake_dh_pubkey,
+        sizeof(handshake_dh_pubkey), false);
 }
 
 // Start all handshakes for which we are the client