Browse Source

fixed race condition with flow access

cecylia 7 years ago
parent
commit
599ac29ef1
5 changed files with 140 additions and 101 deletions
  1. 28 16
      relay_station/crypto.c
  2. 31 2
      relay_station/flow.c
  3. 3 0
      relay_station/flow.h
  4. 74 80
      relay_station/slitheen-proxy.c
  5. 4 3
      relay_station/slitheen.h

+ 28 - 16
relay_station/crypto.c

@@ -582,7 +582,6 @@ int verify_finish_hash(flow *f, uint8_t *hs, int32_t incoming){
 
 	if(update_finish_hash(f, old_finished)){
 		fprintf(stderr, "Error updating finish hash with FINISHED msg\n");
-		remove_flow(f);
 		goto err;
 	}
 
@@ -1450,27 +1449,40 @@ void check_handshake(struct packet_info *info){
 			printf(")\n");
 #endif
 
-			/* Save flow in table */
-			flow *flow_ptr = add_flow(info);
+			/* If flow is not in table, save it */
+			flow *flow_ptr = check_flow(info);
 			if(flow_ptr == NULL){
-				fprintf(stderr, "Memory failure\n");
-				return;
-			}
+				flow_ptr = add_flow(info);
+				if(flow_ptr == NULL){
+					fprintf(stderr, "Memory failure\n");
+					return;
+				}
 
-			for(int i=0; i<16; i++){
-				flow_ptr->key[i] = key[i];
-			}
+				for(int i=0; i<16; i++){
+					flow_ptr->key[i] = key[i];
+				}
 
-			memcpy(flow_ptr->client_random, hello_rand, SSL3_RANDOM_SIZE);
+				memcpy(flow_ptr->client_random, hello_rand, SSL3_RANDOM_SIZE);
 #ifdef DEBUG
-			for(int i=0; i< SSL3_RANDOM_SIZE; i++){
-				printf("%02x ", hello_rand[i]);
-			}
-			printf("\n");
-			
-			printf("Saved new flow\n");
+				for(int i=0; i< SSL3_RANDOM_SIZE; i++){
+					printf("%02x ", hello_rand[i]);
+				}
+				printf("\n");
+				
+				printf("Saved new flow\n");
 #endif
 
+				flow_ptr->ref_ctr--;
+
+			} else { /* else update saved flow with new key and random nonce */
+				for(int i=0; i<16; i++){
+					flow_ptr->key[i] = key[i];
+				}
+
+				memcpy(flow_ptr->client_random, hello_rand, SSL3_RANDOM_SIZE);
+				flow_ptr->ref_ctr--;
+			}
+
 		}
 	}
 }

+ 31 - 2
relay_station/flow.c

@@ -68,6 +68,9 @@ flow *add_flow(struct packet_info *info) {
 	new_flow->src_port = info->tcp_hdr->src_port;
 	new_flow->dst_port = info->tcp_hdr->dst_port;
 
+	new_flow->ref_ctr = 1;
+	new_flow->removed = 0;
+
 	new_flow->upstream_app_data = emalloc(sizeof(app_data_queue));
 	new_flow->upstream_app_data->first_packet = NULL;
 	new_flow->downstream_app_data = emalloc(sizeof(app_data_queue));
@@ -350,7 +353,11 @@ int update_flow(flow *f, uint8_t *record, uint8_t incoming) {
 #ifdef DEBUG_HS
 					printf("Received finished (%d) (%x:%d -> %x:%d)\n", incoming, f->src_ip.s_addr, ntohs(f->src_port), f->dst_ip.s_addr, ntohs(f->dst_port));
 #endif
-					verify_finish_hash(f,p, incoming);
+					if(verify_finish_hash(f,p, incoming)){
+						fprintf(stderr, "Error verifying finished hash\n");
+						remove_flow(f);
+						goto err;
+					}
 					
 					//re-encrypt finished message
 					int32_t n =  encrypt(f, record+RECORD_HEADER_LEN, record+RECORD_HEADER_LEN, record_len - (RECORD_HEADER_LEN+16), incoming, 0x16, 1, 1);
@@ -450,6 +457,19 @@ err:
  */
 int remove_flow(flow *f) {
 
+	sem_wait(&flow_table_lock);
+	//decrement reference counter
+	f->ref_ctr--;
+	if(f->ref_ctr){ //if there are still references to f, wait to free it
+		printf("Cannot free, still %d reference(s)\n", f->ref_ctr);
+		f->removed = 1; 
+		sem_post(&flow_table_lock);
+		return 0;
+	}
+
+	if(f->removed)
+		printf("Trying again to free\n");
+
 	//Empty application data queues
 	packet *tmp = f->upstream_app_data->first_packet;
 	while(tmp != NULL){
@@ -540,7 +560,6 @@ int remove_flow(flow *f) {
 		}
 	}
 
-	sem_wait(&flow_table_lock);
 	flow_entry *entry = table->first_entry;
 	if(entry->f == f){
 		table->first_entry = entry->next;
@@ -635,8 +654,18 @@ flow *check_flow(struct packet_info *info){
 		}
 		entry = entry->next;
 	}
+
+	if(found != NULL){
+		found->ref_ctr++;
+	}
+
 	sem_post(&flow_table_lock);
 
+	if(found != NULL && found->removed){
+		remove_flow(found);
+		found=NULL;
+	}
+
 	return found;
 }
 

+ 3 - 0
relay_station/flow.h

@@ -86,6 +86,9 @@ typedef struct session_cache_st {
 typedef struct flow_st {
 	sem_t flow_lock;
 
+	uint32_t ref_ctr;
+	uint8_t removed;
+
 	struct in_addr src_ip, dst_ip; /* Source (client) and Destination (server) addresses */
 	uint16_t src_port, dst_port;	/* Source and Destination ports */
 

+ 74 - 80
relay_station/slitheen-proxy.c

@@ -43,8 +43,8 @@ int main(int argc, char *argv[]){
 	dev1 = argv[1];
 	dev2 = argv[2];
 
-	snprintf(filter1, 33, "ether src host %s", macaddr);
-	snprintf(filter2, 33, "ether dst host %s", macaddr);
+	snprintf(filter1, 33, "ether src host %s", macaddr1);
+	snprintf(filter2, 33, "ether src host %s", macaddr2);
 
 	if(init_tables()){
 		exit(1);
@@ -342,8 +342,6 @@ void process_packet(struct packet_info *info){
 
 		if(data_to_process){
 
-			uint8_t removed = 0;
-
 			if(p != info->app_data){
 				printf("UH OH something weird might happen\n");
 			}
@@ -354,7 +352,7 @@ void process_packet(struct packet_info *info){
 
 				/* Pass data to packet chain */
 				if(add_packet(observed, info)){//removed_flow
-					removed = 1;
+					return;
 				}
 			}
 
@@ -362,112 +360,108 @@ void process_packet(struct packet_info *info){
 			if(info->tcp_hdr->flags & (FIN | RST) ){
 				/* Remove flow from table, connection ended */
 				remove_flow(observed);
-			} else {
-				/* add packet to application data queue */
+				return;
+			}
+			/* add packet to application data queue */
 
-				//check if flow was removed
-				if(removed){
-					return;
-				}
+			//add new app block
+			packet *new_block = ecalloc(1, sizeof(packet));
+			new_block->seq_num = seq_num;
+			new_block->data = ecalloc(1, info->app_data_len);
+			memcpy(new_block->data, info->app_data, info->app_data_len);
+			new_block->len = info->app_data_len;
+			new_block->next = NULL;
+			new_block->expiration = 0;
 
-				//add new app block
-				packet *new_block = ecalloc(1, sizeof(packet));
-				new_block->seq_num = seq_num;
-				new_block->data = ecalloc(1, info->app_data_len);
-				memcpy(new_block->data, info->app_data, info->app_data_len);
-				new_block->len = info->app_data_len;
-				new_block->next = NULL;
-				new_block->expiration = 0;
+			packet *saved_data = (incoming)? observed->downstream_app_data->first_packet :
+				observed->upstream_app_data->first_packet;
+
+			//put app data block in queue
+			if(saved_data == NULL){
+				if(incoming){
+					observed->downstream_app_data->first_packet = new_block;
+					if(new_block->seq_num ==
+							observed->downstream_seq_num){
+						observed->downstream_seq_num += new_block->len;
+#ifdef DEBUG
+						printf("Updated downstream expected seqnum to %u\n",
+								observed->downstream_seq_num );
+#endif
+					}
+				} else {
+					observed->upstream_app_data->first_packet = new_block;
+					if(new_block->seq_num ==
+							observed->upstream_seq_num){
+						observed->upstream_seq_num += new_block->len;
+#ifdef DEBUG
+						printf("Updated upstream expected seqnum to %u\n",
+								observed->upstream_seq_num );
+#endif
+					}
+				}
 
-				packet *saved_data = (incoming)? observed->downstream_app_data->first_packet :
-					observed->upstream_app_data->first_packet;
+			}
+			else{
+				uint8_t saved = 0;
+				while(saved_data->next != NULL){
+					if(!saved && (saved_data->next->seq_num > seq_num)){
+						new_block->next = saved_data->next;
+						saved_data->next = new_block;
+						saved = 1;
+					}
 
-				//put app data block in queue
-				if(saved_data == NULL){
+					//update expected sequence number
 					if(incoming){
-						observed->downstream_app_data->first_packet = new_block;
-						if(new_block->seq_num ==
+						if(saved_data->next->seq_num ==
 								observed->downstream_seq_num){
-							observed->downstream_seq_num += new_block->len;
+							observed->downstream_seq_num += saved_data->next->len;
 #ifdef DEBUG
 							printf("Updated downstream expected seqnum to %u\n",
 									observed->downstream_seq_num );
 #endif
 						}
-					} else {
-						observed->upstream_app_data->first_packet = new_block;
-						if(new_block->seq_num ==
+					} else {//outgoing
+						if(saved_data->next->seq_num ==
 								observed->upstream_seq_num){
-							observed->upstream_seq_num += new_block->len;
+							observed->upstream_seq_num += saved_data->next->len;
 #ifdef DEBUG
 							printf("Updated upstream expected seqnum to %u\n",
 									observed->upstream_seq_num );
 #endif
 						}
 					}
+						
+					saved_data = saved_data->next;
 
 				}
-				else{
-					uint8_t saved = 0;
-					while(saved_data->next != NULL){
-						if(!saved && (saved_data->next->seq_num > seq_num)){
-							new_block->next = saved_data->next;
-							saved_data->next = new_block;
-							saved = 1;
-						}
-
-						//update expected sequence number
-						if(incoming){
-							if(saved_data->next->seq_num ==
-									observed->downstream_seq_num){
-								observed->downstream_seq_num += saved_data->next->len;
-#ifdef DEBUG
-								printf("Updated downstream expected seqnum to %u\n",
-										observed->downstream_seq_num );
-#endif
-							}
-						} else {//outgoing
-							if(saved_data->next->seq_num ==
-									observed->upstream_seq_num){
-								observed->upstream_seq_num += saved_data->next->len;
+				if(!saved){
+					saved_data->next = new_block;
+					//update expected sequence number
+					if(incoming){
+						if(saved_data->next->seq_num ==
+								observed->downstream_seq_num){
+							observed->downstream_seq_num += saved_data->next->len;
 #ifdef DEBUG
-								printf("Updated upstream expected seqnum to %u\n",
-										observed->upstream_seq_num );
+							printf("Updated downstream expected seqnum to %u\n",
+									observed->downstream_seq_num );
 #endif
-							}
 						}
-							
-						saved_data = saved_data->next;
-
-					}
-					if(!saved){
-						saved_data->next = new_block;
-						//update expected sequence number
-						if(incoming){
-							if(saved_data->next->seq_num ==
-									observed->downstream_seq_num){
-								observed->downstream_seq_num += saved_data->next->len;
-#ifdef DEBUG
-								printf("Updated downstream expected seqnum to %u\n",
-										observed->downstream_seq_num );
-#endif
-							}
-						} else {//outgoing
-							if(saved_data->next->seq_num ==
-									observed->upstream_seq_num){
-								observed->upstream_seq_num += saved_data->next->len;
+					} else {//outgoing
+						if(saved_data->next->seq_num ==
+								observed->upstream_seq_num){
+							observed->upstream_seq_num += saved_data->next->len;
 #ifdef DEBUG
-								printf("Updated upstream expected seqnum to %u\n",
-										observed->upstream_seq_num );
+							printf("Updated upstream expected seqnum to %u\n",
+									observed->upstream_seq_num );
 #endif
-							}
 						}
-
 					}
+
 				}
 			}
 		}
 
+		observed->ref_ctr--;
 	}
 
 

+ 4 - 3
relay_station/slitheen.h

@@ -4,10 +4,11 @@
 #include <netinet/in.h>
 #include <pcap.h>
 
-//#define macaddr1 "00:25:90:5a:26:99"
-//#define macaddr2 "00:25:90:c9:5a:09"
+#define macaddr1 "00:25:90:5a:26:99"
+#define macaddr2 "00:25:90:c9:5a:09"
 
-#define macaddr "08:00:27:0e:89:ea"
+//#define macaddr1 "08:00:27:0e:89:ea"
+//#define macaddr2 "08:00:27:0e:89:ea"
 
 /* Ethernet addresses are 6 bytes */
 #define ETHER_ADDR_LEN	6