Browse Source

refactored process_packet function for readability

cecylia 7 years ago
parent
commit
f1c1745d0d
1 changed files with 255 additions and 206 deletions
  1. 255 206
      relay_station/slitheen-proxy.c

+ 255 - 206
relay_station/slitheen-proxy.c

@@ -21,6 +21,11 @@
 #include "crypto.h"
 #include "cryptothread.h"
 
+
+void save_packet(flow *f, struct packet_info *info);
+void update_window_expiration(flow *f, struct packet_info *info);
+void retransmit(flow *f, struct packet_info *info, uint32_t data_to_fill);
+
 void usage(void){
 	printf("Usage: slitheen-proxy [internal network interface] [NAT interface]\n");
 }
@@ -144,7 +149,7 @@ end:
 #ifdef DEBUG
 	fprintf(stderr, "injected the following packet:\n");
 	for(int i=0; i< header->len; i++){
-		fprintf(stderr, "%02x ", packet[i]);
+		fprintf(stderr, "%02x ", tmp_packet[i]);
 	}
 	fprintf(stderr, "\n");
 
@@ -178,71 +183,25 @@ void process_packet(struct packet_info *info){
 	flow *observed;
 	if((observed = check_flow(info)) != NULL){
 	
-#ifdef DEBUG
+//#ifdef DEBUG
 		/*Check sequence number and replay application data if necessary*/
 		fprintf(stdout,"Flow: %x:%d > %x:%d (%s)\n", info->ip_hdr->src.s_addr, ntohs(info->tcp_hdr->src_port), info->ip_hdr->dst.s_addr, ntohs(info->tcp_hdr->dst_port), (info->ip_hdr->src.s_addr != observed->src_ip.s_addr)? "incoming":"outgoing");
 		fprintf(stdout,"ID number: %u\n", htonl(info->ip_hdr->id));
 		fprintf(stdout,"Sequence number: %u\n", htonl(info->tcp_hdr->sequence_num));
 		fprintf(stdout,"Acknowledgement number: %u\n", htonl(info->tcp_hdr->ack_num));
-#endif
+//#endif
 
 		uint8_t incoming = (info->ip_hdr->src.s_addr != observed->src_ip.s_addr)? 1 : 0;
 		uint32_t seq_num = htonl(info->tcp_hdr->sequence_num);
 		uint32_t expected_seq = (incoming)? observed->downstream_seq_num : observed->upstream_seq_num;
-#ifdef DEBUG
+//#ifdef DEBUG
 		fprintf(stdout,"Expected sequence number: %u\n", expected_seq);
-#endif
-
-
-		/* Remove acknknowledged data from queue after TCP window is exceeded */
-		uint32_t ack_num = htonl(info->tcp_hdr->ack_num);
-		uint32_t end_seq = seq_num + info->app_data_len - 1;
-		uint32_t window = ack_num + htons(info->tcp_hdr->win_size);
-
-#ifdef DEBUG
-		printf("Received sequence number %u\n", seq_num);
-		printf("Acknowledged up to %u with window expiring at %u\n", ack_num, window);
-		printf("Removing all packets up to %u\n", end_seq);
-#endif
-
-		packet *saved_data = (incoming)? observed->downstream_app_data->first_packet :
-			observed->upstream_app_data->first_packet;
-		while((saved_data != NULL) && (saved_data->expiration != 0) && (end_seq > saved_data->expiration)){
-			//remove entire block
-			if(incoming){
-				observed->downstream_app_data->first_packet = saved_data->next;
-			} else {
-				observed->upstream_app_data->first_packet = saved_data->next;
-			}
+//#endif
 
-			free(saved_data->data);
-			free(saved_data);
-			saved_data = (incoming)? observed->downstream_app_data->first_packet :
-				observed->upstream_app_data->first_packet;
+		/* Remove acknowledged data from queue after TCP window is exceeded */
+        update_window_expiration(observed, info);
 
-#ifdef DEBUG
-			if(saved_data != NULL){
-				printf("Currently saved seq_num is now %u\n", saved_data->seq_num);
-			} else {
-				printf("Acked all data, queue is empty\n");
-			}
-#endif
-
-		}
-
-		/* Update expiration for packets based on TCP window size */
-		saved_data = (incoming)? observed->upstream_app_data->first_packet :
-			observed->downstream_app_data->first_packet;
-		while((saved_data != NULL) && (ack_num > saved_data->seq_num)){
-			//update window
-			if(ack_num >= saved_data->seq_num + saved_data->len){
-				//remove entire block
-				saved_data->expiration = window;
-			}
-			saved_data = saved_data->next;
-		}
-
-		//fill with retransmit data, process new data
+		/* fill with retransmit data, process new data */
 		uint32_t data_to_fill;
 		uint32_t data_to_process;
 
@@ -260,72 +219,10 @@ void process_packet(struct packet_info *info){
 		uint8_t *p = info->app_data;
 
 		if(data_to_fill){ //retransmit
-			packet *saved_data = (incoming)? observed->downstream_app_data->first_packet :
-				observed->upstream_app_data->first_packet;
-
-
-			while(data_to_fill > 0){
-				if(saved_data == NULL){
-					//have already acked all data
-					p += data_to_fill;
-					seq_num += data_to_fill;
-					data_to_fill -= data_to_fill;
-					continue;
-				}
-
-				if(seq_num < saved_data->seq_num){
-					//we are missing a block. Use what was given
-					if(saved_data->seq_num - seq_num > data_to_fill){
-						//skip the rest
-						p += data_to_fill;
-						seq_num += data_to_fill;
-						data_to_fill -= data_to_fill;
-					} else {
-						p += saved_data->seq_num - seq_num;
-						data_to_fill -= saved_data->seq_num - seq_num;
-						seq_num += saved_data->seq_num - seq_num;
-					}
-				} else if ( seq_num == saved_data->seq_num) {
-
-					if(data_to_fill >= saved_data->len){
-						//exhaust this block and move onto next one
-						memcpy(p, saved_data->data, saved_data->len);
-						p += saved_data->len;
-						seq_num += saved_data->len;
-						data_to_fill -= saved_data->len;
-						saved_data = saved_data->next;
-					} else {
-						//fill with partial block
-						memcpy(p, saved_data->data, data_to_fill);
-						p += data_to_fill;
-						seq_num += data_to_fill;
-						data_to_fill -= data_to_fill;
-					}
-				} else { //seq_num > saved_data->seq_num
-					uint32_t offset = seq_num - saved_data->seq_num;
-					
-					if(offset > saved_data->len){
-						saved_data = saved_data->next;
-						offset -= saved_data->len;
-					} else {
-						if(data_to_fill > saved_data->len - offset){
-							memcpy(p, saved_data->data + offset, saved_data->len - offset);
-							p += saved_data->len - offset;
-							seq_num += saved_data->len - offset;
-							data_to_fill -= saved_data->len - offset;
-							saved_data = saved_data->next;
-						} else {
-							memcpy(p, saved_data->data + offset, data_to_fill);
-							p += data_to_fill;
-							seq_num += data_to_fill;
-							data_to_fill -= data_to_fill;
-						}
-					}
-				}
-			}
-
+            retransmit(observed, info, data_to_fill);
 		}
-		tcp_checksum(info);//update checksum
+
+        p += data_to_fill;
 
 		if(data_to_process){
 
@@ -334,10 +231,21 @@ void process_packet(struct packet_info *info){
 			}
 
 			if(observed->application){
+                if(seq_num > expected_seq){
+                    //For now, enters into FORFEIT state
+                    //TODO: change upstream behaviour to try to mask slitheen hdr
+                    printf("ERROR: future packet in app data, forfeiting flow\n");
+                    remove_flow(observed);
+                    return;
+                }
+
 				replace_packet(observed, info);
 			} else {
 
 				/* Pass data to packet chain */
+				if(observed->stall){
+
+				}
 				if(add_packet(observed, info)){//removed_flow
 					return;
 				}
@@ -349,109 +257,252 @@ void process_packet(struct packet_info *info){
 				remove_flow(observed);
 				return;
 			}
+
 			/* add packet to application data queue */
+            save_packet(observed, info);
+
+		}
+
+		observed->ref_ctr--;
+	}
+
+
+}
 
-			//add new app block
-			packet *new_block = ecalloc(1, sizeof(packet));
-			new_block->seq_num = seq_num;
-			new_block->data = ecalloc(1, info->app_data_len);
-			memcpy(new_block->data, info->app_data, info->app_data_len);
-			new_block->len = info->app_data_len;
-			new_block->next = NULL;
-			new_block->expiration = 0;
-
-			packet *saved_data = (incoming)? observed->downstream_app_data->first_packet :
-				observed->upstream_app_data->first_packet;
-
-			//put app data block in queue
-			if(saved_data == NULL){
-				if(incoming){
-					observed->downstream_app_data->first_packet = new_block;
-					if(new_block->seq_num ==
-							observed->downstream_seq_num){
-						observed->downstream_seq_num += new_block->len;
+void save_packet(flow *f, struct packet_info *info){
+
+    uint8_t incoming = (info->ip_hdr->src.s_addr != f->src_ip.s_addr)? 1 : 0;
+    uint32_t seq_num = htonl(info->tcp_hdr->sequence_num);
+
+    //add new app block
+    packet *new_block = ecalloc(1, sizeof(packet));
+    new_block->seq_num = htonl(info->tcp_hdr->sequence_num);
+    new_block->data = ecalloc(1, info->app_data_len);
+    memcpy(new_block->data, info->app_data, info->app_data_len);
+    new_block->len = info->app_data_len;
+    new_block->next = NULL;
+    new_block->expiration = 0;
+
+    packet *saved_data = (incoming)? f->downstream_app_data->first_packet :
+        f->upstream_app_data->first_packet;
+
+    //put app data block in queue
+    if(saved_data == NULL){
+        if(incoming){
+            f->downstream_app_data->first_packet = new_block;
+            if(new_block->seq_num ==
+                    f->downstream_seq_num){
+                f->downstream_seq_num += new_block->len;
 #ifdef DEBUG
-						printf("Updated downstream expected seqnum to %u\n",
-								observed->downstream_seq_num );
+                printf("Updated downstream expected seqnum to %u\n",
+                        f->downstream_seq_num );
 #endif
-					}
-				} else {
-					observed->upstream_app_data->first_packet = new_block;
-					if(new_block->seq_num ==
-							observed->upstream_seq_num){
-						observed->upstream_seq_num += new_block->len;
+            }
+        } else {
+            f->upstream_app_data->first_packet = new_block;
+            if(new_block->seq_num ==
+                    f->upstream_seq_num){
+                f->upstream_seq_num += new_block->len;
 #ifdef DEBUG
-						printf("Updated upstream expected seqnum to %u\n",
-								observed->upstream_seq_num );
+                printf("Updated upstream expected seqnum to %u\n",
+                        f->upstream_seq_num );
 #endif
-					}
-				}
-
-			}
-			else{
-				uint8_t saved = 0;
-				while(saved_data->next != NULL){
-					if(!saved && (saved_data->next->seq_num > seq_num)){
-						new_block->next = saved_data->next;
-						saved_data->next = new_block;
-						saved = 1;
-					}
-
-					//update expected sequence number
-					if(incoming){
-						if(saved_data->next->seq_num ==
-								observed->downstream_seq_num){
-							observed->downstream_seq_num += saved_data->next->len;
+            }
+        }
+
+    } else {
+        uint8_t saved = 0;
+        while(saved_data->next != NULL){
+            if(!saved && (saved_data->next->seq_num > seq_num)){
+                new_block->next = saved_data->next;
+                saved_data->next = new_block;
+                saved = 1;
+            }
+
+            //update expected sequence number
+            if(incoming){
+                if(saved_data->next->seq_num ==
+                        f->downstream_seq_num){
+                    f->downstream_seq_num += saved_data->next->len;
 #ifdef DEBUG
-							printf("Updated downstream expected seqnum to %u\n",
-									observed->downstream_seq_num );
+                    printf("Updated downstream expected seqnum to %u\n",
+                            f->downstream_seq_num );
 #endif
-						}
-					} else {//outgoing
-						if(saved_data->next->seq_num ==
-								observed->upstream_seq_num){
-							observed->upstream_seq_num += saved_data->next->len;
+                }
+            } else {//outgoing
+                if(saved_data->next->seq_num ==
+                        f->upstream_seq_num){
+                    f->upstream_seq_num += saved_data->next->len;
 #ifdef DEBUG
-							printf("Updated upstream expected seqnum to %u\n",
-									observed->upstream_seq_num );
+                    printf("Updated upstream expected seqnum to %u\n",
+                            f->upstream_seq_num );
 #endif
-						}
-					}
-						
-					saved_data = saved_data->next;
-
-				}
-				if(!saved){
-					saved_data->next = new_block;
-					//update expected sequence number
-					if(incoming){
-						if(saved_data->next->seq_num ==
-								observed->downstream_seq_num){
-							observed->downstream_seq_num += saved_data->next->len;
+                }
+            }
+                
+            saved_data = saved_data->next;
+
+        }
+        if(!saved){
+            saved_data->next = new_block;
+            //update expected sequence number
+            if(incoming){
+                if(saved_data->next->seq_num ==
+                        f->downstream_seq_num){
+                    f->downstream_seq_num += saved_data->next->len;
 #ifdef DEBUG
-							printf("Updated downstream expected seqnum to %u\n",
-									observed->downstream_seq_num );
+                    printf("Updated downstream expected seqnum to %u\n",
+                            f->downstream_seq_num );
 #endif
-						}
-					} else {//outgoing
-						if(saved_data->next->seq_num ==
-								observed->upstream_seq_num){
-							observed->upstream_seq_num += saved_data->next->len;
+                }
+            } else {//outgoing
+                if(saved_data->next->seq_num ==
+                        f->upstream_seq_num){
+                    f->upstream_seq_num += saved_data->next->len;
 #ifdef DEBUG
-							printf("Updated upstream expected seqnum to %u\n",
-									observed->upstream_seq_num );
+                    printf("Updated upstream expected seqnum to %u\n",
+                            f->upstream_seq_num );
 #endif
-						}
-					}
+                }
+            }
 
-				}
-			}
-		}
+        }
+    }
+}
 
-		observed->ref_ctr--;
-	}
+/**
+ * This function cleans up data that has been acked, after the TCP window of the recipient has been
+ * exceeded. This ensures that a retransmisson of the data will no longer occur.
+ *
+ * Sets the expiration for recent data base on the TCP window
+ */
+void update_window_expiration(flow *f, struct packet_info *info){
+
+    uint8_t incoming = (info->ip_hdr->src.s_addr != f->src_ip.s_addr)? 1 : 0;
+    uint32_t ack_num = htonl(info->tcp_hdr->ack_num);
+    uint32_t end_seq = htonl(info->tcp_hdr->sequence_num) + info->app_data_len - 1;
+    uint32_t window = ack_num + htons(info->tcp_hdr->win_size);
+
+//#ifdef DEBUG
+    printf("Received sequence number %u\n", htonl(info->tcp_hdr->sequence_num));
+    printf("Acknowledged up to %u with window expiring at %u\n", ack_num, window);
+    printf("Removing all packets up to %u\n", end_seq);
+//#endif
+
+    packet *saved_data = (incoming)? f->downstream_app_data->first_packet :
+        f->upstream_app_data->first_packet;
+    while((saved_data != NULL) && (saved_data->expiration != 0) && (end_seq > saved_data->expiration)){
+        //remove entire block
+        if(incoming){
+            f->downstream_app_data->first_packet = saved_data->next;
+        } else {
+            f->upstream_app_data->first_packet = saved_data->next;
+        }
+
+        free(saved_data->data);
+        free(saved_data);
+        saved_data = (incoming)? f->downstream_app_data->first_packet :
+            f->upstream_app_data->first_packet;
+
+#ifdef DEBUG
+        if(saved_data != NULL){
+            printf("Currently saved seq_num is now %u\n", saved_data->seq_num);
+        } else {
+            printf("Acked all data, queue is empty\n");
+        }
+#endif
 
+    }
 
+    /* Update expiration for packets based on TCP window size */
+    saved_data = (incoming)? f->upstream_app_data->first_packet :
+        f->downstream_app_data->first_packet;
+    while((saved_data != NULL) && (ack_num > saved_data->seq_num)){
+        //update window
+        if(ack_num >= saved_data->seq_num + saved_data->len){
+            //remove entire block
+            saved_data->expiration = window;
+        }
+        saved_data = saved_data->next;
+    }
+
+}
+
+/**
+ * This function retransmits previously sent (and possibly modified) data
+ *
+ */
+void retransmit(flow *f, struct packet_info *info, uint32_t data_to_fill){
+
+    uint8_t *p = info->app_data;
+    uint32_t seq_num = htonl(info->tcp_hdr->sequence_num);
+    uint8_t incoming = (info->ip_hdr->src.s_addr != f->src_ip.s_addr)? 1 : 0;
+
+    packet *saved_data = (incoming)? f->downstream_app_data->first_packet :
+        f->upstream_app_data->first_packet;
+
+    printf("Filling with %d retransmitted bytes\n", data_to_fill);
+
+    while(data_to_fill > 0){
+        if(saved_data == NULL){
+            //have already acked all data
+            p += data_to_fill;
+            seq_num += data_to_fill;
+            data_to_fill -= data_to_fill;
+            continue;
+        }
+
+        if(seq_num < saved_data->seq_num){
+            //we are missing a block. Use what was given
+            if(saved_data->seq_num - seq_num > data_to_fill){
+                //skip the rest
+                p += data_to_fill;
+                seq_num += data_to_fill;
+                data_to_fill -= data_to_fill;
+            } else {
+                p += saved_data->seq_num - seq_num;
+                data_to_fill -= saved_data->seq_num - seq_num;
+                seq_num += saved_data->seq_num - seq_num;
+            }
+        } else if ( seq_num == saved_data->seq_num) {
+
+            if(data_to_fill >= saved_data->len){
+                //exhaust this block and move onto next one
+                memcpy(p, saved_data->data, saved_data->len);
+                p += saved_data->len;
+                seq_num += saved_data->len;
+                data_to_fill -= saved_data->len;
+                saved_data = saved_data->next;
+            } else {
+                //fill with partial block
+                memcpy(p, saved_data->data, data_to_fill);
+                p += data_to_fill;
+                seq_num += data_to_fill;
+                data_to_fill -= data_to_fill;
+            }
+        } else { //seq_num > saved_data->seq_num
+            uint32_t offset = seq_num - saved_data->seq_num;
+            
+            if(offset > saved_data->len){
+                saved_data = saved_data->next;
+                offset -= saved_data->len;
+            } else {
+                if(data_to_fill > saved_data->len - offset){
+                    memcpy(p, saved_data->data + offset, saved_data->len - offset);
+                    p += saved_data->len - offset;
+                    seq_num += saved_data->len - offset;
+                    data_to_fill -= saved_data->len - offset;
+                    saved_data = saved_data->next;
+                } else {
+                    memcpy(p, saved_data->data + offset, data_to_fill);
+                    p += data_to_fill;
+                    seq_num += data_to_fill;
+                    data_to_fill -= data_to_fill;
+                }
+            }
+        }
+    }
+    tcp_checksum(info);//update checksum
 }
 
 /** This function extracts the ip, tcp, and tls record headers
@@ -528,5 +579,3 @@ struct packet_info *copy_packet_info(struct packet_info *src_info){
 
 	return dst_info;
 }
-
-