#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <pthread.h>
#include <errno.h>
#include <semaphore.h>
#include <arpa/inet.h>

#include "flow.h"
#include "crypto.h"
#include "slitheen.h"
#include "relay.h"

/* Tagged-flow table: a linked list of flow_entry nodes allocated in
 * init_tables. add_flow, remove_flow and check_flow modify/walk it while
 * holding flow_table_lock. */
static flow_table *table;
/* Cache of resumable TLS sessions (session ids and tickets); appended to by
 * save_session_id/save_session_ticket and searched by verify_session_id. */
static session_cache *sessions;
/* Allocated in init_tables; not static, so presumably used by other
 * translation units (TODO confirm). */
data_queue *downstream_queue;
stream_table *streams;

/* Semaphore initialized to 1 in init_tables; used as a mutex for `table` */
sem_t flow_table_lock;

/* Initialize the global tables: the tagged-flow table and the semaphore that
 * guards it, the downstream data queue and the stream table.
 *
 *  Output:
 *  	0 on success, 1 if an allocation or the semaphore initialization
 *  	failed (everything allocated so far is released again)
 */
int init_tables(void) {

	table = calloc(1, sizeof(flow_table));
	if(table == NULL)
		goto err;
	table->first_entry = NULL;
	table->len = 0;

	if(sem_init(&flow_table_lock, 0, 1) != 0)
		goto err_table;

	downstream_queue = calloc(1, sizeof(data_queue));
	if(downstream_queue == NULL)
		goto err_sem;
	downstream_queue->first_block = NULL;

	streams = calloc(1, sizeof(stream_table));
	if(streams == NULL)
		goto err_queue;
	streams->first = NULL;
	printf("initialized downstream queue\n");

	return 0;

	/* Unwind in reverse order of acquisition */
err_queue:
	free(downstream_queue);
	downstream_queue = NULL;
err_sem:
	sem_destroy(&flow_table_lock);
err_table:
	free(table);
	table = NULL;
err:
	fprintf(stderr, "init_tables: initialization failed\n");
	return 1;
}


/* Add a new flow to the tagged flow table.
 *
 *  Copies newFlow into a heap-allocated flow, resets its TLS/HTTP parsing
 *  state (fields not reset here, such as addresses and ports, are kept as
 *  given), creates a SHA-384 digest context for the Finished hash,
 *  initializes the flow's semaphores and appends it to the end of the table.
 *
 *  Input:
 *  	newFlow: the observed flow
 *
 *  Output:
 *  	pointer to the stored flow (owned by the table and released by
 *  	remove_flow), or NULL if allocation or digest setup failed
 */
flow *add_flow(flow newFlow) {
	flow_entry *entry = calloc(1, sizeof(flow_entry));
	flow *ptr = calloc(1, sizeof(flow));
	if(entry == NULL || ptr == NULL){
		free(entry);
		free(ptr);
		return NULL;
	}
	entry->f = ptr;
	entry->next = NULL;

	newFlow.state = TLS_CLNT_HELLO;
	newFlow.in_encrypted = 0;
	newFlow.out_encrypted = 0;
	newFlow.application = 0;
	newFlow.resume_session = 0;
	newFlow.current_session = NULL;
	newFlow.packet_chain = NULL;
	newFlow.upstream_queue = NULL;
	newFlow.upstream_remaining = 0;
	newFlow.outbox = NULL;
	newFlow.outbox_len = 0;
	newFlow.outbox_offset = 0;
	newFlow.partial_record_header = NULL;
	newFlow.partial_record_header_len = 0;
	newFlow.remaining_record_len = 0;
	newFlow.remaining_response_len = 0;
	newFlow.httpstate = PARSE_HEADER;
	newFlow.replace_response = 0;

	/* Running hash of handshake messages, checked against Finished */
	newFlow.finish_md_ctx = EVP_MD_CTX_create();
	if(newFlow.finish_md_ctx == NULL ||
			!EVP_DigestInit_ex(newFlow.finish_md_ctx, EVP_sha384(), NULL)){
		if(newFlow.finish_md_ctx != NULL)
			EVP_MD_CTX_destroy(newFlow.finish_md_ctx);
		free(entry);
		free(ptr);
		return NULL;
	}

	newFlow.clnt_read_ctx = NULL;
	newFlow.clnt_write_ctx = NULL;
	newFlow.srvr_read_ctx = NULL;
	newFlow.srvr_write_ctx = NULL;

	memset(newFlow.read_seq, 0, 8);
	memset(newFlow.write_seq, 0, 8);

	*ptr = newFlow;

	/* Initialize the semaphores in their final location: POSIX leaves the
	 * use of a copy of an initialized sem_t undefined, so they must not be
	 * initialized in newFlow and then struct-copied into *ptr. */
	sem_init(&(ptr->flow_lock), 0, 1);
	sem_init(&(ptr->packet_chain_lock), 0, 1);

	sem_wait(&flow_table_lock);
	printf("there are %d flows in the table\n", table->len);
	if(table->first_entry == NULL){
		table->first_entry = entry;
	} else {
		/* append at the tail */
		flow_entry *last = table->first_entry;
		while(last->next != NULL){
			last = last->next;
		}
		last->next = entry;
	}
	table->len ++;
	sem_post(&flow_table_lock);

	return ptr;
}

/** Observes TLS records on the flow's packet chain and updates the state of
 *  the flow
 *
 *  Reassembles the first TLS record from f->packet_chain, which may span
 *  several in-sequence packets, and dispatches on its record type. It then
 *  removes the consumed bytes from the chain. If bytes remain after the
 *  record, it calls itself to process the next record. Handshake records
 *  are decrypted first once that direction has seen a ChangeCipherSpec.
 *  Holds f->packet_chain_lock while touching the chain.
 *
 *  Inputs:
 *  	f: the tagged flow
 *
 *  Output:
 *  	0 when processing stopped normally (including when more packets are
 *  	needed to complete the record, or unparsable data was dropped);
 *  	1 if the record buffer could not be allocated
 */
int update_flow(flow *f) {
	uint8_t *record;
	const struct record_header *record_hdr;
	const struct handshake_header *handshake_hdr;

	sem_wait(&(f->packet_chain_lock));

	if(f->packet_chain == NULL){
		sem_post(&(f->packet_chain_lock));
		return 0;
	}
	uint8_t *p = f->packet_chain->data;
	record_hdr = (struct record_header*) p;
	int record_len;
	int data_len;

	/* full record size (header + payload) vs. bytes gathered so far */
	record_len = RECORD_LEN(record_hdr)+RECORD_HEADER_LEN;
	data_len = f->packet_chain->data_len;

	packet *current = f->packet_chain;
	int incoming = current->incoming;
	record = calloc(1, record_len);
	if(record == NULL){
		sem_post(&(f->packet_chain_lock));
		return 1;
	}

	for(int i=0; (i<data_len) && (i<record_len); i++){
		record[i] = p[i];
	}

	/* Pull in following packets until the whole record is buffered; stop
	 * (leaving the chain untouched) if the next packet is missing */
	while(record_len > data_len) {
		if(current->next == NULL){
			goto err;
		}
		if(current->next->seq_num != current->seq_num + current->len){
			printf("Missing packet: seq_num= %d, datalen= %d, nextseq= %d\n", current->seq_num, current->len, current->next->seq_num);
			goto err;
		}

		current = current->next;
		p = current->data;
		int i;
		for(i=0; (i<current->data_len) && (i+data_len < record_len); i++){

			record[data_len+i] = p[i];
		}
		data_len += current->data_len;
	}
	/* current now points at the packet holding the record's last byte */

	switch(record_hdr->type){
		case HS:
			p = record;
			p += RECORD_HEADER_LEN;

			if((incoming && f->in_encrypted) || (!incoming && f->out_encrypted)){
				printf("Decrypting finished (%d bytes) (%x:%d -> %x:%d)\n", record_len - RECORD_HEADER_LEN, f->src_ip.s_addr, f->src_port, f->dst_ip.s_addr, f->dst_port);
				int32_t n = encrypt(f, p, p, record_len - RECORD_HEADER_LEN, incoming, 0x16, 0);
				if(n<=0){
					printf("Error decrypting finished  (%x:%d -> %x:%d)\n", f->src_ip.s_addr, f->src_port, f->dst_ip.s_addr, f->dst_port);
					/* Nothing valid to parse: skip the hash/state updates
					 * (n would be a bogus length) but still consume the
					 * record below. This break leaves the record-type switch. */
					break;
				}
				printf("Finished decrypted: (%x:%d -> %x:%d)\n", f->src_ip.s_addr, f->src_port, f->dst_ip.s_addr, f->dst_port);
				p += EVP_GCM_TLS_EXPLICIT_IV_LEN;

				printf("record:\n");
				for(int i=0; i< n; i++){
					printf("%02x ", p[i]);
				}
				printf("\n");

				update_context(f, p, n, incoming, 0x16, 0);
				if(incoming) f->in_encrypted = 2;
				else f->out_encrypted = 2;
			}
			handshake_hdr = (struct handshake_header*) p;
			f->state = handshake_hdr->type;

			switch(f->state){
				case TLS_CLNT_HELLO:
					printf("Received tagged client hello (%x:%d -> %x:%d)\n", f->src_ip.s_addr, f->src_port, f->dst_ip.s_addr, f->dst_port);
					update_finish_hash(f, p);
					check_session(f, p, HANDSHAKE_MESSAGE_LEN(handshake_hdr));
					break;
				case TLS_SERV_HELLO:
					printf("Received server hello (%x:%d -> %x:%d)\n", f->src_ip.s_addr, f->src_port, f->dst_ip.s_addr, f->dst_port);
					if(f->resume_session){
						verify_session_id(f,p);
					} else {
						save_session_id(f,p);
					}
					extract_server_random(f, p);
					update_finish_hash(f, p);
					break;
				case TLS_NEW_SESS:
					printf("Received new session\n");
					save_session_ticket(f, p, HANDSHAKE_MESSAGE_LEN(handshake_hdr));
					update_finish_hash(f, p);
					break;
				case TLS_CERT:
					printf("Received cert\n");
					update_finish_hash(f, p);
					break;
				case TLS_SRVR_KEYEX:
					printf("Received server keyex\n");
					update_finish_hash(f, p);

					if(extract_parameters(f, p)){
						printf("Error extracting params\n");
					}
					if(compute_master_secret(f)){
						printf("Error computing master secret\n");
					}
					break;
				case TLS_CERT_REQ:
					update_finish_hash(f, p);
					break;
				case TLS_SRVR_HELLO_DONE:
					printf("Received server hello done\n");
					update_finish_hash(f, p);
					break;
				case TLS_CERT_VERIFY:
					printf("received cert verify\n");
					update_finish_hash(f, p);
					break;
				case TLS_CLNT_KEYEX:
					printf("Received client key exchange\n");
					update_finish_hash(f, p);
					break;
				case TLS_FINISHED:
					printf("Received finished (%x:%d -> %x:%d)\n", f->src_ip.s_addr, f->src_port, f->dst_ip.s_addr, f->dst_port);
					verify_finish_hash(f,p, incoming);
					update_finish_hash(f, p);
					if((f->in_encrypted == 2) && (f->out_encrypted == 2)){
						printf("Handshake complete!\n");
						f->application = 1;
						/* Use a separate cursor: `current` must keep pointing
						 * at the record's last packet for the consumption
						 * step after this switch. */
						for(packet *scan = current; scan != NULL; scan = scan->next){
							if(scan->incoming)
								f->seq_num = scan->seq_num + scan->len;
						}
					}
					break;
				default:
					printf("Error? (%x:%d -> %x:%d)...\n", f->src_ip.s_addr, f->src_port, f->dst_ip.s_addr, f->dst_port);
					break;
			}
			break;
		case APP:
			printf("Application Data\n");
			break;
		case CCS:
			printf("CCS (%x:%d -> %x:%d) \n", f->src_ip.s_addr, f->src_port, f->dst_ip.s_addr, f->dst_port);
			if(incoming){
				f->in_encrypted = 1;
			} else {
				f->out_encrypted = 1;
			}

			/*Initialize ciphers */
			init_ciphers(f);
			break;
		case ALERT:
			p = record;
			p += RECORD_HEADER_LEN;
			if(((incoming) && (f->in_encrypted > 0)) || ((!incoming) && (f->out_encrypted > 0))){
				/* NOTE(review): passes 0x16 (handshake) as the content type
				 * for an alert record (0x15); confirm this is what encrypt()
				 * expects. The return value is also not checked here. */
				encrypt(f, p, p, record_len - RECORD_HEADER_LEN, incoming, 0x16, 0);
				p += EVP_GCM_TLS_EXPLICIT_IV_LEN;
			}
			printf("Alert (%x:%d -> %x:%d) %02x %02x \n", f->src_ip.s_addr, f->src_port, f->dst_ip.s_addr, f->dst_port, p[0], p[1]);
			fflush(stdout);
			break;
		case HB:
			printf("Heartbeat\n");
			break;
		default:
			printf("Error: Not a Record (%x:%d -> %x:%d)\n", f->src_ip.s_addr, f->src_port, f->dst_ip.s_addr, f->dst_port);
			fflush(stdout);
			//TODO: later figure this out, for now delete
			packet *tmp = f->packet_chain;
			f->packet_chain = f->packet_chain->next;
			free(tmp->data);
			free(tmp);

			if( f->packet_chain != NULL){
				sem_post(&(f->packet_chain_lock));
				free(record);
				update_flow(f);
				return 0;
			}
			goto err;
	}

	/* NOTE(review): this overwrites any seq_num set in the TLS_FINISHED
	 * branch above; this step was once wrapped in if(!f->application). */
	f->seq_num = current->seq_num;

	if(record_len == data_len){
		/* record ended on packet boundary: drop every packet up to and
		 * including current */
		current = current->next;
		packet *tmp = f->packet_chain;
		while(tmp != current){
			f->packet_chain = tmp->next;
			free(tmp->data);
			free(tmp);
			tmp = f->packet_chain;
		}
	} else {
		/* record ended inside current: drop the packets before it, keep only
		 * current's trailing (data_len - record_len) bytes, then parse the
		 * next record */
		packet *tmp = f->packet_chain;
		while(tmp != current){
			f->packet_chain = tmp->next;
			free(tmp->data);
			free(tmp);
			tmp = f->packet_chain;
		}
		memmove(current->data, current->data + (current->data_len - (data_len - record_len)), data_len - record_len);
		current->data_len = data_len - record_len;
		sem_post(&(f->packet_chain_lock));
		free(record);
		update_flow(f);
		return 0;
	}

err:
	sem_post(&(f->packet_chain_lock));
	free(record);
	return 0;
}

/** Removes the tagged flow from the flow table and releases it: happens when
 *  the station receives a TCP RST or FIN packet
 *
 *  The flow's digest and cipher contexts are always released. If the flow
 *  is in the table it is unlinked, its queued packets are freed, its
 *  semaphores are destroyed, and the flow and its table entry are freed.
 *  f must not be used after this call.
 *
 *  Input:
 *  	f: the tagged flow to remove
 *
 *  Output:
 *  	always 1 (kept for compatibility; the value does not signal
 *  	whether the flow was found)
 */
int remove_flow(flow *f) {

	if(f->finish_md_ctx != NULL){
		EVP_MD_CTX_destroy(f->finish_md_ctx);
		f->finish_md_ctx = NULL;
	}
	//Clean up cipher ctxs
	if(f->clnt_read_ctx != NULL){
		EVP_CIPHER_CTX_free(f->clnt_read_ctx);
		f->clnt_read_ctx = NULL;
	}
	if(f->clnt_write_ctx != NULL){
		EVP_CIPHER_CTX_free(f->clnt_write_ctx);
		f->clnt_write_ctx = NULL;
	}
	if(f->srvr_read_ctx != NULL){
		EVP_CIPHER_CTX_free(f->srvr_read_ctx);
		f->srvr_read_ctx = NULL;
	}
	if(f->srvr_write_ctx != NULL){
		EVP_CIPHER_CTX_free(f->srvr_write_ctx);
		f->srvr_write_ctx = NULL;
	}

	/* Unlink the entry under the table lock (an empty table is handled);
	 * the actual freeing happens after the lock is released */
	flow_entry *removed = NULL;
	sem_wait(&flow_table_lock);
	flow_entry *prev = NULL;
	flow_entry *entry = table->first_entry;
	while(entry != NULL && entry->f != f){
		prev = entry;
		entry = entry->next;
	}
	if(entry != NULL){
		if(prev == NULL){
			table->first_entry = entry->next;
		} else {
			prev->next = entry->next;
		}
		table->len --;
		removed = entry;
	}
	sem_post(&flow_table_lock);

	if(removed == NULL){
		printf("Flow not in table\n");
		return 1;
	}

	/* Free packets still queued on the flow; each packet owns its data
	 * buffer (the same ownership update_flow assumes when consuming) */
	packet *pkt = f->packet_chain;
	while(pkt != NULL){
		packet *next = pkt->next;
		free(pkt->data);
		free(pkt);
		pkt = next;
	}
	f->packet_chain = NULL;

	/* NOTE(review): current_session and other heap fields are not freed
	 * here; current_session may also be referenced from the session cache. */
	sem_destroy(&(f->flow_lock));
	sem_destroy(&(f->packet_chain_lock));

	free(f);
	free(removed);
	printf("flow removed!\n");
	fflush(stdout);

	return 1;
}

/** Expands the flow table when we run out of space.
 *  The table is a linked list that add_flow extends by one entry per flow,
 *  so there is currently nothing to do; always returns 0.
 *  TODO: implement and test
 */
int grow_table(void) {
	return 0;
}

/** Returns the index of a flow in the flow table if
 *  it exists, returns 0 if it is not present.
 *
 *  Inputs:
 *  	observed: details for the observed flow
 *
 *  Output:
 *  	index of flow in table or -1 if it doesn't exist
 */
flow *check_flow(flow observed){
	/* Loop through flows in table and see if it exists */
	int i;
	flow_entry *entry = table->first_entry;
	flow *candidate;
	flow *found = NULL;
	if(entry == NULL)
		return NULL;

	sem_wait(&flow_table_lock);
	/* Check first in this direction */
	for(i=0; i<table->len; i++){
		if(entry == NULL){
			printf("Error: entry is null\n");
			break;
		}
		candidate = entry->f;
		if(candidate->src_ip.s_addr == observed.src_ip.s_addr){
			if(candidate->dst_ip.s_addr == observed.dst_ip.s_addr){
				if(candidate->src_port == observed.src_port){
					if(candidate->dst_port == observed.dst_port){
						found = candidate;
					}
				}
			}
		}
		entry = entry->next;
	}


	entry = table->first_entry;
	/* Then in the other direction */
	for(i=0; i<table->len; i++){
		if(entry == NULL){
			printf("Error: entry is null\n");
			break;
		}
		candidate = entry->f;
		if(candidate->src_ip.s_addr == observed.dst_ip.s_addr){
			if(candidate->dst_ip.s_addr == observed.src_ip.s_addr){
				if(candidate->src_port == observed.dst_port){
					if(candidate->dst_port == observed.src_port){
						found = candidate;
					}
				}
			}
		}
		entry = entry->next;
	}
	sem_post(&flow_table_lock);

	return found;
}

/* Allocate the (initially empty) cache of resumable TLS sessions.
 *
 *  Output:
 *  	0 on success, 1 if the cache could not be allocated
 */
int init_session_cache(void){
	sessions = malloc(sizeof(session_cache));
	if(sessions == NULL){
		return 1;
	}
	sessions->length = 0;
	sessions->first_session = NULL;

	return 0;
}

/** Called from ServerHello, verifies that the session id returned matches
 *  the session id requested from the client hello
 *
 *  If the server echoed the requested id (resumption), the master secret is
 *  copied into f from the cached session with the same id. Failing that, it
 *  comes from the cached session holding the ticket the client offered. If
 *  the server returned a different id (full handshake), that new id is
 *  cached via save_session_id.
 *
 *  Input:
 *  	f: the tagged flow
 *  	hs: a pointer to the ServerHello message (handshake header included)
 *
 *  Output:
 *  	0 (lookup misses are not reported)
 */
int verify_session_id(flow *f, uint8_t *hs){

	//increment pointer to point to sessionid
	uint8_t *p = hs + HANDSHAKE_HEADER_LEN;
	p += 2; //skip version
	p += SSL3_RANDOM_SIZE; //skip random

	uint8_t id_len = (uint8_t) p[0];
	p ++;

	session *requested = f->current_session;
	if(requested == NULL || requested->session_id_len <= 0){
		//client did not offer a session id: nothing to verify
		return 0;
	}

	if(id_len != requested->session_id_len || memcmp(requested->session_id, p, id_len)){
		//server issued a different (or empty) id: full handshake, so cache
		//the new id. save_session_id parses from the start of the message.
		//NOTE(review): this replaces f->current_session without freeing the
		//session created by check_session.
		save_session_id(f, hs);
		return 0;
	}

	//server echoed the requested id: look up the cached master secret
	int found = 0;
	session *last = sessions->first_session;
	for(int i=0; (i<sessions->length) && (last != NULL) && (!found); i++){
		if(last->session_id_len == id_len &&
				!memcmp(last->session_id, requested->session_id, id_len)){
			memcpy(f->master_secret, last->master_secret, SSL3_MASTER_SECRET_SIZE);
			found = 1;
		}
		last = last->next;
	}

	if((!found) && (requested->session_ticket_len > 0)){
		//fall back to the session ticket the client offered; entries saved
		//by save_session_id carry no ticket and are skipped
		last = sessions->first_session;
		for(int i=0; (i<sessions->length) && (last != NULL) && (!found); i++){
			if(last->session_ticket != NULL &&
					last->session_ticket_len == requested->session_ticket_len &&
					!memcmp(last->session_ticket, requested->session_ticket, requested->session_ticket_len)){
				memcpy(f->master_secret, last->master_secret, SSL3_MASTER_SECRET_SIZE);
				found = 1;
				printf("Found new session ticket (%x:%d -> %x:%d)\n", f->src_ip.s_addr, f->src_port, f->dst_ip.s_addr, f->dst_port);
				for(int j=0; j< last->session_ticket_len; j++){
					printf("%02x ", last->session_ticket[j]);
				}
				printf("\n");
			}
			last = last->next;
		}
	}

	return 0;
}

/* Called from ClientHello. Checks to see if the session id len is > 0. If so,
 * saves the session id for later verification. Also checks to see if a
 * session ticket is included as an extension.
 *
 * All length fields come from the network, so every read is checked against
 * len. The session id is also checked against the session id buffer.
 *
 *  Input:
 *  	f: the tagged flow
 *  	hs: a pointer to the ClientHello message (handshake header included)
 *  	len: length of the ClientHello body (after the handshake header)
 *
 *  Output:
 *  	0 if success, 1 if the message is malformed or allocation failed
 */
int check_session(flow *f, uint8_t *hs, uint32_t len){

	const uint8_t *body = hs + HANDSHAKE_HEADER_LEN;
	uint32_t off = 2 + SSL3_RANDOM_SIZE; //skip version and random
	uint32_t ext_end;
	uint16_t ciphersuite_len;
	uint8_t compress_meth_len;

	if(off + 1 > len){
		//too short to contain a session id length
		return 1;
	}

	session *new_session = calloc(1, sizeof(session));
	if(new_session == NULL){
		return 1;
	}
	new_session->session_id_len = body[off];
	new_session->session_ticket_len = 0;
	off++;

	//reject ids that overflow the session id buffer or the message
	if((size_t) new_session->session_id_len > sizeof(new_session->session_id) ||
			off + new_session->session_id_len > len){
		free(new_session);
		return 1;
	}

	if(new_session->session_id_len > 0){
		f->resume_session = 1;
		memcpy(new_session->session_id, body + off, new_session->session_id_len);
		new_session->next = NULL;
		printf("Requested new session (%x:%d -> %x:%d)\n", f->src_ip.s_addr, f->src_port, f->dst_ip.s_addr, f->dst_port);

		f->current_session = new_session;
	}
	off += new_session->session_id_len;

	//skip cipher suites and compression methods to reach the extensions
	if(off + 2 > len)
		goto done;
	ciphersuite_len = (body[off] << 8) + body[off+1];
	off += 2 + ciphersuite_len;
	if(off + 1 > len)
		goto done;
	compress_meth_len = body[off];
	off += 1 + compress_meth_len;

	//no room for the 2-byte extensions length: no extensions
	if(off + 2 > len)
		goto done;
	ext_end = off + 2 + ((body[off] << 8) + body[off+1]);
	off += 2;
	if(ext_end > len)
		ext_end = len; //clamp a block length that overruns the message

	//search for SessionTicket TLS extension (type 0x23)
	while(off + 4 <= ext_end){
		uint16_t type = (body[off] << 8) + body[off+1];
		uint16_t ext_len = (body[off+2] << 8) + body[off+3];
		off += 4;
		if(off + ext_len > ext_end)
			break; //truncated extension
		if(type == 0x23 && ext_len > 0){
			uint8_t *ticket = calloc(1, ext_len);
			if(ticket != NULL){
				memcpy(ticket, body + off, ext_len);
				free(new_session->session_ticket);
				new_session->session_ticket = ticket;
				new_session->session_ticket_len = ext_len;
				f->resume_session = 1;
				f->current_session = new_session;
			}
		}
		off += ext_len;
	}

done:
	//new_session is kept only if it was attached to the flow
	if(f->current_session != new_session){
		free(new_session);
	}
	return 0;
}
	

/* Called from ServerHello during full handshake. Adds the session id to the
 * cache for later resumptions and makes it the flow's current session
 *
 *  Input:
 *  	f: the tagged flow
 *  	hs: a pointer to the ServerHello message (handshake header included)
 *
 *  Output:
 *  	0 if success (including an empty id, which is not cached),
 *  	1 if allocation failed or the id length is invalid
 */
int save_session_id(flow *f, uint8_t *hs){
	printf("saving session id\n");

	//increment pointer to point to sessionid
	uint8_t *p = hs + HANDSHAKE_HEADER_LEN;
	p += 2; //skip version
	p += SSL3_RANDOM_SIZE; //skip random

	session *new_session = calloc(1, sizeof(session));
	if(new_session == NULL){
		return 1;
	}
	new_session->session_id_len = (uint8_t) p[0];
	if(new_session->session_id_len <= 0){
		//if this value is zero, the session is non-resumable or the
		//server will issue a NewSessionTicket handshake message
		free(new_session);
		return 0;
	}
	if((size_t) new_session->session_id_len > sizeof(new_session->session_id)){
		//length byte comes from the network: refuse to overflow the buffer
		free(new_session);
		return 1;
	}
	p++;
	memcpy(new_session->session_id, p, new_session->session_id_len);
	new_session->next = NULL;

	f->current_session = new_session;

	//append at the tail of the session cache
	if(sessions->first_session == NULL){
		sessions->first_session = new_session;
	} else {
		session *last = sessions->first_session;
		while(last->next != NULL){
			last = last->next;
		}
		last->next = new_session;
	}

	sessions->length ++;

	printf("Saved session id:");
	for(int i=0; i< new_session->session_id_len; i++){
		printf(" %02x", p[i]);
	}
	printf("\n");

	printf("THERE ARE NOW %d saved sessions\n", sessions->length);

	return 0;

}

/* Called from NewSessionTicket. Adds the session ticket, together with the
 * flow's current master secret, to the cache for later resumptions
 *
 *  Input:
 *  	f: the tagged flow
 *  	hs: a pointer to the NewSessionTicket message (handshake header
 *  	    included)
 *  	len: length of the message body (after the handshake header)
 *
 *  Output:
 *  	0 if success (an empty ticket is not cached), 1 if the message is
 *  	truncated or allocation failed
 */
int save_session_ticket(flow *f, uint8_t *hs, uint32_t len){
	uint8_t *p = hs + HANDSHAKE_HEADER_LEN;

	//body layout: 4-byte lifetime hint, 2-byte ticket length, ticket
	if(len < 6){
		return 1;
	}
	p += 4; //skip lifetime TODO: add to session struct
	uint16_t ticket_len = (p[0] << 8) + p[1];
	p += 2;
	if((uint32_t) ticket_len + 6 > len){
		return 1;
	}
	if(ticket_len == 0){
		//RFC 5077: an empty ticket means the server is not issuing one
		return 0;
	}

	session *new_session = calloc(1,sizeof(session));
	uint8_t *ticket = calloc(1, ticket_len);
	if(new_session == NULL || ticket == NULL){
		free(new_session);
		free(ticket);
		return 1;
	}
	new_session->session_id_len = 0;
	new_session->session_ticket_len = ticket_len;
	printf("saving ticket of size %d\n", new_session->session_ticket_len);
	fflush(stdout);

	memcpy(ticket, p, ticket_len);
	new_session->session_ticket = ticket;
	memcpy(new_session->master_secret, f->master_secret, SSL3_MASTER_SECRET_SIZE);

	//append at the tail of the session cache
	if(sessions->first_session == NULL){
		sessions->first_session = new_session;
	} else {
		session *last = sessions->first_session;
		while(last->next != NULL){
			last = last->next;
		}
		last->next = new_session;
	}

	sessions->length ++;

	printf("Saved session ticket:");
	for(int i=0; i< new_session->session_ticket_len; i++){
		printf(" %02x", p[i]);
	}
	printf("\n");
	fflush(stdout);

	//NOTE(review): this logs key material to stdout
	printf("Saved session master secret:");
	for(int i=0; i< SSL3_MASTER_SECRET_SIZE; i++){
		printf(" %02x", new_session->master_secret[i]);
	}
	printf("\n");
	fflush(stdout);

	printf("THERE ARE NOW %d saved sessions\n", sessions->length);
	fflush(stdout);

	return 0;
}

/* Adds a packet to the flow's packet chain.
 *
 *  Copies the packet's application data into a new packet node and inserts
 *  it into f->packet_chain, which is kept ordered by TCP sequence number.
 *  A packet whose sequence number equals an existing one goes after it.
 *  Packets without a TCP header or without application data are ignored.
 *
 *  NOTE(review): the chain is modified without taking f->packet_chain_lock,
 *  which update_flow holds while consuming it; confirm callers serialize.
 *
 *  Input:
 *  	f: the flow the packet belongs to
 *  	info: parsed headers and application data of the observed packet
 *
 *  Output:
 *  	0 on success (including ignored packets), 1 if allocation failed
 */
int add_packet(flow *f, struct packet_info *info){
	//skip before allocating: empty packets were never linked and leaked
	if (info->tcp_hdr == NULL || info->app_data_len <= 0){
		return 0;
	}

	packet *new_packet = calloc(1, sizeof(packet));
	uint8_t *packet_data = calloc(1, info->app_data_len);
	if(new_packet == NULL || packet_data == NULL){
		free(new_packet);
		free(packet_data);
		return 1;
	}

	//the header field is in network byte order
	new_packet->seq_num = ntohl(info->tcp_hdr->sequence_num);
	new_packet->len = info->app_data_len;
	memcpy(packet_data, info->app_data, new_packet->len);

	new_packet->data = packet_data;
	new_packet->data_len = new_packet->len;
	new_packet->incoming =
		(info->ip_hdr->src.s_addr == f->src_ip.s_addr) ? 0 : 1;

	/* Find the first packet with a larger sequence number and link the new
	 * packet in front of it (or at the end of the chain) */
	packet **link = &(f->packet_chain);
	while(*link != NULL && (*link)->seq_num <= new_packet->seq_num){
		link = &((*link)->next);
	}
	new_packet->next = *link;
	*link = new_packet;

	return 0;
}