Browse Source

fixed bugs in socks proxy and webm state machine

cecylia 6 years ago
parent
commit
d6eb8ccbfa
2 changed files with 205 additions and 169 deletions
  1. 199 165
      client/socks5proxy.c
  2. 6 4
      relay_station/webm.c

+ 199 - 165
client/socks5proxy.c

@@ -234,6 +234,9 @@ void *ous_IO(void *args){
     uint8_t *buffer = emalloc(BUFSIZ);
     int32_t buffer_len = BUFSIZ;
 
+    uint8_t *input_buffer = NULL;
+    uint32_t input_buffer_len;
+
     int32_t bytes_read;
 
     /* Select on proxy pipes, demux thread, and ous to send and receive data*/
@@ -303,7 +306,7 @@ void *ous_IO(void *args){
 
         if(FD_ISSET(ous, &read_fds) && FD_ISSET(ous_out, &write_fds)){
 
-            bytes_read = recv(ous, buffer, 4, 0);
+            bytes_read = recv(ous, (uint8_t *) &input_buffer_len, 4, 0);
 #ifdef DEBUG_IO
             printf("Received %d bytes from OUS\n", bytes_read);
             for(int i=0; i< bytes_read; i++){
@@ -318,27 +321,29 @@ void *ous_IO(void *args){
                 break;
             }
 
-            uint32_t *chunk_len = (uint32_t*) buffer;
+            uint32_t chunk_len = input_buffer_len;
 
-            fprintf(stderr, "Length of this chunk: %u\n", *chunk_len);
+            bytes_sent = write(ous_out, (uint8_t *) &input_buffer_len, bytes_read);
+            //TODO: check return
 
+            input_buffer = malloc(input_buffer_len);
 
-            bytes_read = recv(ous, buffer, *chunk_len, 0);
+            bytes_read = recv(ous, input_buffer, chunk_len, 0);
 #ifdef DEBUG_IO
             printf("Received %d bytes from OUS\n", bytes_read);
             for(int i=0; i< bytes_read; i++){
-                printf("%02x ", buffer[i]);
+                printf("%02x ", input_buffer[i]);
             }
             printf("\n");
             fflush(stdout);
 #endif
 
             if(bytes_read > 0){
-                bytes_sent = write(ous_out, buffer, bytes_read);
+                bytes_sent = write(ous_out, input_buffer, bytes_read);
 #ifdef DEBUG_IO
                 printf("Sent %d bytes to demultiplexer\n", bytes_sent);
                 for(int i=0; i< bytes_sent; i++){
-                    printf("%02x ", buffer[i]);
+                    printf("%02x ", input_buffer[i]);
                 }
                 printf("\n");
                 fflush(stdout);
@@ -358,6 +363,9 @@ void *ous_IO(void *args){
                 fprintf(stderr, "Error reading from OUS\n");
                 break;
             }
+
+            free(input_buffer);
+            input_buffer = NULL;
         }
 
     }
@@ -612,8 +620,7 @@ void *multiplex_data(void *args){
 void *demultiplex_data(void *args){
     ous_pipes *pipes = (ous_pipes *) args;
 
-    int32_t buffer_len = BUFSIZ;
-    uint8_t *buffer = calloc(1, buffer_len);
+    uint8_t *buffer = NULL;
     uint8_t *p;
 
     uint8_t *partial_block = NULL;
@@ -622,208 +629,235 @@ void *demultiplex_data(void *args){
     data_block *saved_data = NULL;
 
     for(;;){
-        printf("Demux thread waiting to read\n");
-        int32_t bytes_read = read(pipes->out, buffer, buffer_len-partial_block_len);
-
-        if(bytes_read > 0){
-            int32_t chunk_remaining = bytes_read;
-            p = buffer;
-
-            //didn't read a full slitheen block last time
-            if(partial_block_len > 0){
-                //process first part of slitheen info
-                memmove(buffer+partial_block_len, buffer, bytes_read);
-                memcpy(buffer, partial_block, partial_block_len);
-                chunk_remaining += partial_block_len;
-                free(partial_block);
-                partial_block = NULL;
-                partial_block_len = 0;
-            }
+        uint32_t chunk_len;
 
-            while(chunk_remaining > 0){
+        int32_t bytes_read = read(pipes->out, (uint8_t *) &chunk_len, 4);
 
-#ifdef DEBUG_PARSE
-                printf("Received a new chunk of len %d bytes\n", chunk_remaining);
-#endif
+        fprintf(stdout, "Length of this chunk: %u\n", chunk_len);
+
+        buffer = calloc(1, chunk_len);
 
-                if(chunk_remaining < SLITHEEN_HEADER_LEN){
+        bytes_read = read(pipes->out, buffer, chunk_len);
+
+        if(bytes_read <= 0){
+            printf("Error: read %d bytes from OUS_out\n", bytes_read);
+            goto err;
+        }
+
+        if(bytes_read < chunk_len) {
+            printf("Error: read %d out of %d bytes\n", bytes_read, chunk_len);
+        }
+
+        int32_t chunk_remaining = bytes_read;
+        p = buffer;
 
 #ifdef DEBUG_PARSE
-                    printf("Partial header: ");
-                    int i;
-                    for(i = 0; i< chunk_remaining; i++){
-                        printf("%02x ", p[i]);
-                    }
-                    printf("\n");
+        printf("Received a new chunk of len %d bytes\n", chunk_remaining);
 #endif
+        //didn't read a full slitheen block last time
+        if(partial_block_len > 0){
+            //process first part of slitheen info
+            memmove(buffer+partial_block_len, buffer, bytes_read);
+            memcpy(buffer, partial_block, partial_block_len);
+            chunk_remaining += partial_block_len;
+            free(partial_block);
+            partial_block = NULL;
+            partial_block_len = 0;
+        }
 
-                    if(partial_block != NULL) printf("UH OH (PB)\n");
-                    partial_block = calloc(1, chunk_remaining);
-                    memcpy(partial_block, p, chunk_remaining);
-                    partial_block_len = chunk_remaining;
-                    chunk_remaining = 0;
-                    break;
-                }
+        while(chunk_remaining > 0){
+            /*TODO: investigate assumption that we only ever receive chunks
+             * that contain full slitheen blocks */
 
-                //decrypt header to see if we have entire block
-                uint8_t *tmp_header = malloc(SLITHEEN_HEADER_LEN);
-                memcpy(tmp_header, p, SLITHEEN_HEADER_LEN);
+#ifdef DEBUG_PARSE
+            printf("Chunk remaining: %d bytes\n", chunk_remaining);
+#endif
 
-                if(!peek_header(tmp_header)){
-                    printf("This chunk doesn't contain a Slitheen block\n");
-                    break;
-                }
+            if(chunk_remaining < SLITHEEN_HEADER_LEN){
 
-                struct slitheen_hdr *sl_hdr = (struct slitheen_hdr *) tmp_header;
-                //first see if sl_hdr corresponds to a valid stream. If not, ignore rest of read bytes
 #ifdef DEBUG_PARSE
-                printf("Slitheen header:\n");
+                printf("Incomplete Slitheen block: ");
                 int i;
-                for(i = 0; i< SLITHEEN_HEADER_LEN; i++){
-                    printf("%02x ", tmp_header[i]);
+                for(i = 0; i< chunk_remaining; i++){
+                    printf("%02x ", p[i]);
                 }
                 printf("\n");
 #endif
-                if(ntohs(sl_hdr->len) > chunk_remaining){
-                    printf("ERROR: slitheen block doesn't fit in resource remaining!\n");
-                    printf("Saving in partial block\n");
-
-                    if(partial_block != NULL) printf("UH OH (PB)\n");
-                    partial_block = calloc(1, ntohs(sl_hdr->len));
-                    memcpy(partial_block, p, chunk_remaining);
-                    partial_block_len = chunk_remaining;
-                    chunk_remaining = 0;
-                    free(tmp_header);
-                    break;
-                }
+                /*
+                if(partial_block != NULL) printf("UH OH (PB)\n");
+
+                partial_block = calloc(1, chunk_remaining);
+                memcpy(partial_block, p, chunk_remaining);
+                partial_block_len = chunk_remaining;
+                chunk_remaining = 0;
+                */
+
+                break;
+            }
+
+            //decrypt header to see if we have entire block
+            uint8_t *tmp_header = malloc(SLITHEEN_HEADER_LEN);
+            memcpy(tmp_header, p, SLITHEEN_HEADER_LEN);
 
-                super_decrypt(p);
+            if(!peek_header(tmp_header)){
+                printf("This chunk doesn't contain a Slitheen block\n");
+                break;
+            }
 
-                sl_hdr = (struct slitheen_hdr *) p;
+            struct slitheen_hdr *sl_hdr = (struct slitheen_hdr *) tmp_header;
+            //first see if sl_hdr corresponds to a valid stream. If not, ignore rest of read bytes
+#ifdef DEBUG
+            printf("Slitheen header:\n");
+            int i;
+            for(i = 0; i< SLITHEEN_HEADER_LEN; i++){
+                printf("%02x ", tmp_header[i]);
+            }
+            printf("\n");
+#endif
+            if(ntohs(sl_hdr->len) > chunk_remaining){
+                printf("ERROR: slitheen block doesn't fit in resource remaining!\n");
+                printf("Saving in partial block\n");
+
+                if(partial_block != NULL) printf("UH OH (PB)\n");
+                partial_block = calloc(1, ntohs(sl_hdr->len));
+                memcpy(partial_block, p, chunk_remaining);
+                partial_block_len = chunk_remaining;
+                chunk_remaining = 0;
                 free(tmp_header);
+                break;
+            }
+
+            super_decrypt(p);
+
+            sl_hdr = (struct slitheen_hdr *) p;
+            free(tmp_header);
 
-                p += SLITHEEN_HEADER_LEN;
-                chunk_remaining -= SLITHEEN_HEADER_LEN;
+            p += SLITHEEN_HEADER_LEN;
+            chunk_remaining -= SLITHEEN_HEADER_LEN;
 
-                if((!sl_hdr->len) && (sl_hdr->garbage)){
+            if((!sl_hdr->len) && (sl_hdr->garbage)){
 
 #ifdef DEBUG_PARSE
-                    printf("%d Garbage bytes\n", ntohs(sl_hdr->garbage));
+                printf("%d Garbage bytes\n", ntohs(sl_hdr->garbage));
 #endif
-                    p += ntohs(sl_hdr->garbage);
-                    chunk_remaining -= ntohs(sl_hdr->garbage);
-                    continue;
-                }
 
-                int32_t sock =-1;
-                if(connections->first == NULL){
-                    printf("Error: there are no connections\n");
-                } else {
-                    connection *last = connections->first;
+                //there might be more garbage bytes than we have chunk left
+                p += ntohs(sl_hdr->garbage);
+                chunk_remaining -= ntohs(sl_hdr->garbage);
+                continue;
+            }
+
+            int32_t sock =-1;
+            if(connections->first == NULL){
+                printf("Error: there are no connections\n");
+            } else {
+                connection *last = connections->first;
+                if (last->stream_id == sl_hdr->stream_id){
+                    sock = last->socket;
+                }
+                while(last->next != NULL){
+                    last = last->next;
                     if (last->stream_id == sl_hdr->stream_id){
                         sock = last->socket;
                     }
-                    while(last->next != NULL){
-                        last = last->next;
-                        if (last->stream_id == sl_hdr->stream_id){
-                            sock = last->socket;
-                        }
-                    }
                 }
+            }
 
-                if(sock == -1){
-                    printf("No stream id exists. Possibly invalid header\n");
-                    break;
-                }
+            if(sock == -1){
+                printf("No stream id exists. Possibly invalid header\n");
+                break;
+            }
 
 #ifdef DEBUG_PARSE
-                printf("Received information for stream id: %d of length: %u\n", sl_hdr->stream_id, ntohs(sl_hdr->len));
+            printf("Received information for stream id: %d of length: %u\n", sl_hdr->stream_id, ntohs(sl_hdr->len));
 #endif
 
-                //figure out how much to skip
-                int32_t padding = 0;
-                if(ntohs(sl_hdr->len) %16){
-                    padding = 16 - ntohs(sl_hdr->len)%16;
-                }
-                p += 16; //IV
-
-                //check counter to see if we are missing data
-                if(sl_hdr->counter > expected_next_count){
-                    //save any future data
-                    printf("Received header with count %lu. Expected count %lu.\n",
-                            sl_hdr->counter, expected_next_count);
-                    if((saved_data == NULL) || (saved_data->count > sl_hdr->counter)){
-                        data_block *new_block = malloc(sizeof(data_block));
-                        new_block->count = sl_hdr->counter;
-                        new_block->len = ntohs(sl_hdr->len);
-                        new_block->data = malloc(ntohs(sl_hdr->len));
-
-                        memcpy(new_block->data, p, ntohs(sl_hdr->len));
-
-                        new_block->socket = sock;
-                        new_block->next = saved_data;
-
-                        saved_data = new_block;
-
-                    } else {
-                        data_block *last = saved_data;
-                        while((last->next != NULL) && (last->next->count < sl_hdr->counter)){
-                            last = last->next;
-                        }
-                        data_block *new_block = malloc(sizeof(data_block));
-                        new_block->count = sl_hdr->counter;
-                        new_block->len = ntohs(sl_hdr->len);
-                        new_block->data = malloc(ntohs(sl_hdr->len));
-                        memcpy(new_block->data, p, ntohs(sl_hdr->len));
-                        new_block->socket = sock;
-                        new_block->next = last->next;
-
-                        last->next = new_block;
-                    }
+            //figure out how much to skip
+            int32_t padding = 0;
+            if(ntohs(sl_hdr->len) %16){
+                padding = 16 - ntohs(sl_hdr->len)%16;
+            }
+            p += 16; //IV
+
+            //check counter to see if we are missing data
+            if(sl_hdr->counter > expected_next_count){
+                //save any future data
+                printf("Received header with count %lu. Expected count %lu.\n",
+                        sl_hdr->counter, expected_next_count);
+                if((saved_data == NULL) || (saved_data->count > sl_hdr->counter)){
+                    data_block *new_block = malloc(sizeof(data_block));
+                    new_block->count = sl_hdr->counter;
+                    new_block->len = ntohs(sl_hdr->len);
+                    new_block->data = malloc(ntohs(sl_hdr->len));
+
+                    memcpy(new_block->data, p, ntohs(sl_hdr->len));
+
+                    new_block->socket = sock;
+                    new_block->next = saved_data;
+
+                    saved_data = new_block;
+
                 } else {
-                    int32_t bytes_sent = send(sock, p, ntohs(sl_hdr->len), 0);
-                    if(bytes_sent <= 0){
-                        printf("Error writing to socket for stream id %d\n", sl_hdr->stream_id);
+                    data_block *last = saved_data;
+                    while((last->next != NULL) && (last->next->count < sl_hdr->counter)){
+                        last = last->next;
                     }
-
-                    //increment expected counter
-                    expected_next_count++;
+                    data_block *new_block = malloc(sizeof(data_block));
+                    new_block->count = sl_hdr->counter;
+                    new_block->len = ntohs(sl_hdr->len);
+                    new_block->data = malloc(ntohs(sl_hdr->len));
+                    memcpy(new_block->data, p, ntohs(sl_hdr->len));
+                    new_block->socket = sock;
+                    new_block->next = last->next;
+
+                    last->next = new_block;
+                }
+            } else {
+                int32_t bytes_sent = send(sock, p, ntohs(sl_hdr->len), 0);
+                if(bytes_sent <= 0){
+                    printf("Error writing to socket for stream id %d\n", sl_hdr->stream_id);
                 }
 
-                //now check to see if there is saved data to write out
-                if(saved_data != NULL){
-                    data_block *current_block = saved_data;
-                    while((current_block != NULL) && (expected_next_count == current_block->count)){
-                        int32_t bytes_sent = send(current_block->socket, current_block->data,
-                                current_block->len, 0);
-                        if(bytes_sent <= 0){
-                            printf("Error writing to socket for stream id %d\n", sl_hdr->stream_id);
-                        }
-                        expected_next_count++;
-                        saved_data = current_block->next;
-                        free(current_block->data);
-                        free(current_block);
-                        current_block = saved_data;
+                //increment expected counter
+                expected_next_count++;
+            }
+
+            //now check to see if there is saved data to write out
+            if(saved_data != NULL){
+                data_block *current_block = saved_data;
+                while((current_block != NULL) && (expected_next_count == current_block->count)){
+                    int32_t bytes_sent = send(current_block->socket, current_block->data,
+                            current_block->len, 0);
+                    if(bytes_sent <= 0){
+                        printf("Error writing to socket for stream id %d\n", sl_hdr->stream_id);
                     }
+                    expected_next_count++;
+                    saved_data = current_block->next;
+                    free(current_block->data);
+                    free(current_block);
+                    current_block = saved_data;
                 }
+            }
 
-                p += ntohs(sl_hdr->len); //encrypted data
-                p += 16; //mac
-                p += padding;
-                p += ntohs(sl_hdr->garbage);
-
-                chunk_remaining -= ntohs(sl_hdr->len) + 16 + padding + 16 + ntohs(sl_hdr->garbage);
+            p += ntohs(sl_hdr->len); //encrypted data
+            p += 16; //mac
+            p += padding;
+            p += ntohs(sl_hdr->garbage);
 
-            }
+            chunk_remaining -= ntohs(sl_hdr->len) + 16 + padding + 16 + ntohs(sl_hdr->garbage);
 
-        } else {
-            printf("Error: read %d bytes from OUS_out\n", bytes_read);
-            goto err;
         }
 
+        free(buffer);
+        buffer = NULL;
+
     }
 err:
-    free(buffer);
+
+    if (buffer != NULL) {
+        free(buffer);
+    }
+
     close(pipes->out);
     pthread_exit(NULL);
 

+ 6 - 4
relay_station/webm.c

@@ -72,7 +72,8 @@ int32_t parse_webm(flow *f, uint8_t *ptr, uint32_t len) {
 
                 printf("Received header: %x\n", f->element_header);
 
-                if(f->element_header == 0xa3){
+                if((f->element_header == 0xa3) &&
+                        (remaining_len >= (SLITHEEN_HEADER_LEN + 9))){
                     //we want to replace this block
                     printf("Replaced simple block!\n");
                     p[0] = 0xef;
@@ -121,13 +122,14 @@ int32_t parse_webm(flow *f, uint8_t *ptr, uint32_t len) {
 
                 if (f->element_header == 0xa3) {
                     //replace content
-                    printf("Replaceable data (%d bytes):\n", parse_len);
+
+                    fill_with_downstream(f, p, parse_len);
+
+                    printf("Replaced data (%d bytes):\n", parse_len);
                     for(int i=0; i< parse_len; i++){
                         printf("%02x ", p[i]);
                     }
                     printf("\n");
-
-                    fill_with_downstream(f, p, parse_len);
                 }
 
                 p += parse_len;