Browse Source

Added C acceleration functions.

Pushing and pulling data is now done in C rather than Python in order to maximize performance.
Steven Engler 5 years ago
parent
commit
8f5679aa0f
4 changed files with 259 additions and 33 deletions
  1. 37 0
      Makefile
  2. 172 0
      src/accelerated_functions.c
  3. 1 1
      src/bandwidth_tester.py
  4. 49 32
      src/basic_protocols.py

+ 37 - 0
Makefile

@@ -0,0 +1,37 @@
+CC=gcc
+CFLAGS=-O3
+# Directory containing Python.h. Override on the command line for another
+# Python version, e.g. `make PYTHON_INC=/usr/include/python3.8`.
+PYTHON_INC=/usr/include/python3.6
+
+PY_BIN_FILES:=$(patsubst src/%.py,bin/%.py,$(wildcard src/*.py))
+PY_DEV_FILES:=$(patsubst src/%.py,dev/%.py,$(wildcard src/*.py))
+
+# None of these targets name a real file, so never let a stray file shadow them.
+.PHONY: all dev clean bin_dir dev_dir
+
+# `all` builds a standalone copy in bin/; `dev` hard-links the Python sources
+# into dev/ so edits in src/ are picked up without re-running make.
+all: bin_dir $(PY_BIN_FILES) bin/accelerated_functions.so
+dev: dev_dir $(PY_DEV_FILES) dev/accelerated_functions.so
+
+# Remove both build trees; -f keeps this from failing when one does not exist.
+clean:
+	@rm -rf bin dev
+
+bin/accelerated_functions.so: src/accelerated_functions.c
+dev/accelerated_functions.so: src/accelerated_functions.c
+
+#######
+
+# The directory targets are order-only prerequisites (after the `|`): they are
+# created before the rule runs, even under `make -j`, without forcing rebuilds.
+bin/%.so: src/%.c | bin_dir
+	$(CC) $(CFLAGS) -I $(PYTHON_INC) -shared -fPIC $< -o $@
+
+bin/%.py: src/%.py | bin_dir
+	@cp $< $@
+
+bin_dir:
+	@mkdir -p bin
+
+#######
+
+# -f replaces a link left over from a previous run instead of failing.
+dev/%.py: src/%.py | dev_dir
+	ln -f $< $@
+
+dev/%.so: src/%.c | dev_dir
+	$(CC) $(CFLAGS) -I $(PYTHON_INC) -shared -fPIC $< -o $@
+
+dev_dir:
+	@mkdir -p dev

+ 172 - 0
src/accelerated_functions.c

@@ -0,0 +1,172 @@
+#include <string.h>
+#include <sys/poll.h>
+#include <sys/socket.h>
+#include <Python.h>
+// Forward declaration and docstring for the Python-visible push_data() wrapper.
+static PyObject *py_push_data(PyObject *self, PyObject *args);
+static char push_data_docstring[] =
+    "Send data as quickly as possible into a socket.";
+// Forward declaration and docstring for the Python-visible pull_data() wrapper.
+static PyObject *py_pull_data(PyObject *self, PyObject *args);
+static char pull_data_docstring[] =
+    "Receive data as quickly as possible from a socket.";
+// Docstring attached to the module itself.
+static char module_docstring[] =
+    "This module provides accelerated functions which would perform slower in pure Python.";
+// Method table: exposes the two wrappers as push_data/pull_data; the all-NULL entry ends the table.
+static PyMethodDef module_methods[] = {
+	{"push_data", py_push_data, METH_VARARGS, push_data_docstring},
+	{"pull_data", py_pull_data, METH_VARARGS, pull_data_docstring},
+	{NULL, NULL, 0, NULL}
+};
+// Module definition consumed by PyModule_Create() in the init function below.
+static struct PyModuleDef _coremodule = {
+	PyModuleDef_HEAD_INIT,
+	"accelerated_functions", // name of module
+	module_docstring, // module documentation, may be NULL
+	-1, /* size of per-interpreter state of the module,
+	       or -1 if the module keeps state in global variables. */
+	module_methods,
+};
+// Module initialisation function: builds the module object from _coremodule.
+PyMODINIT_FUNC PyInit_accelerated_functions(void){
+	return PyModule_Create(&_coremodule);
+}
+// Return the smaller of the two long values (num1 when they are equal).
+long min(long num1, long num2){
+	return (num1 > num2) ? num2 : num1;
+}
+//
+/*
+ * Send bytes_total bytes into a socket as quickly as possible.
+ *
+ * The payload is not a single bytes_total-sized message: every send() starts
+ * again at the beginning of buffer, so the stream is made of (prefixes of)
+ * buffer repeated until bytes_total bytes have been written.
+ *
+ * socket      - file descriptor of the socket to write to
+ * bytes_total - total number of bytes to send; <= 0 sends nothing
+ * buffer      - data block that is sent repeatedly
+ * buffer_len  - number of valid bytes in buffer; must be > 0 when there is
+ *               anything to send
+ *
+ * Returns 0 on success and -1 on any failure: invalid buffer, poll() error,
+ * no writability within the timeout, an unexpected poll event such as
+ * POLLERR/POLLHUP, or a send() error.
+ */
+int push_data(int socket, long bytes_total, char* buffer, int buffer_len){
+	enum { POLL_TIMEOUT_MS = 1*60*1000 };
+	long bytes_written = 0;
+	struct pollfd poll_fd;
+	//
+	if(bytes_total <= 0){
+		// nothing to send
+		return 0;
+	}
+	//
+	if(buffer == NULL || buffer_len <= 0){
+		// send() of zero bytes returns 0, so the loop below would never make
+		// progress; a negative length would turn into a huge size_t
+		return -1;
+	}
+	//
+	memset(&poll_fd, 0, sizeof(poll_fd));
+	poll_fd.fd = socket;
+	poll_fd.events = POLLOUT;
+	//
+	while(bytes_written < bytes_total){
+		int rc = poll(&poll_fd, 1, POLL_TIMEOUT_MS);
+		//
+		if(rc <= 0){
+			// rc < 0: poll() failed, rc == 0: timed out
+			return -1;
+		}
+		//
+		if(poll_fd.revents == 0){
+			continue;
+		}else if(poll_fd.revents != POLLOUT){
+			// anything besides plain writability (POLLERR, POLLHUP, ...) is fatal
+			return -1;
+		}
+		//
+		long bytes_to_send = min(buffer_len, bytes_total-bytes_written);
+		ssize_t n = send(poll_fd.fd, buffer, (size_t)bytes_to_send, 0);
+		//
+		if(n < 0){
+			return -1;
+		}
+		//
+		bytes_written += n;
+	}
+	//
+	return 0;
+}
+//
+/*
+ * Receive (and discard) bytes_total bytes from a socket as quickly as
+ * possible, timing the transfer.
+ *
+ * socket      - file descriptor of the socket to read from
+ * bytes_total - total number of bytes to receive; <= 0 receives nothing
+ * buffer_len  - size of the scratch buffer used for each recv(); must be > 0
+ *               when there is anything to receive
+ * time_ptr    - on success, receives the number of seconds between the first
+ *               chunk of data arriving and the last byte arriving (0.0 when
+ *               bytes_total <= 0); left untouched on failure
+ *
+ * Returns 0 on success and -1 on any failure: invalid buffer_len, allocation
+ * failure, poll() error, no data within the timeout, an unexpected poll event,
+ * a recv() error, or the peer closing the connection before bytes_total bytes
+ * were received.
+ */
+int pull_data(int socket, long bytes_total, int buffer_len, double* time_ptr){
+	enum { POLL_TIMEOUT_MS = 1*60*1000 };
+	long bytes_read = 0;
+	int ret_val = -1;
+	char* buffer = NULL;
+	struct timeval time_of_first_byte = {0, 0};
+	struct timeval time_of_last_byte = {0, 0};
+	struct pollfd poll_fd;
+	//
+	if(bytes_total <= 0){
+		// nothing to receive, so there is no first byte to time from
+		*time_ptr = 0.0;
+		return 0;
+	}
+	//
+	if(buffer_len <= 0){
+		// a zero-length recv() returns 0 and would never make progress
+		return -1;
+	}
+	//
+	buffer = malloc((size_t)buffer_len);
+	if(buffer == NULL){
+		return -1;
+	}
+	//
+	memset(&poll_fd, 0, sizeof(poll_fd));
+	poll_fd.fd = socket;
+	poll_fd.events = POLLIN;
+	//
+	while(bytes_read < bytes_total){
+		int rc = poll(&poll_fd, 1, POLL_TIMEOUT_MS);
+		//
+		if(rc <= 0){
+			// rc < 0: poll() failed, rc == 0: timed out
+			goto done;
+		}
+		//
+		if(poll_fd.revents == 0){
+			continue;
+		}else if(poll_fd.revents != POLLIN){
+			// anything besides plain readability (POLLERR, POLLHUP, ...) is fatal
+			goto done;
+		}
+		//
+		long bytes_to_recv = min(buffer_len, bytes_total-bytes_read);
+		ssize_t n = recv(poll_fd.fd, buffer, (size_t)bytes_to_recv, 0);
+		//
+		if(n <= 0){
+			// n < 0: recv() failed; n == 0: the peer closed the connection early,
+			// and poll() would keep reporting POLLIN forever
+			goto done;
+		}
+		//
+		if(bytes_read == 0){
+			gettimeofday(&time_of_first_byte, NULL);
+		}
+		//
+		bytes_read += n;
+	}
+	//
+	gettimeofday(&time_of_last_byte, NULL);
+	*time_ptr = (time_of_last_byte.tv_sec-time_of_first_byte.tv_sec) + (time_of_last_byte.tv_usec-time_of_first_byte.tv_usec)/(1000.0*1000.0);
+	ret_val = 0;
+	//
+done:
+	free(buffer);
+	return ret_val;
+}
+//
+/*
+ * Python binding: push_data(socket_fd, bytes_total, data) -> int
+ *
+ * socket_fd   - integer file descriptor of the socket
+ * bytes_total - total number of bytes to send
+ * data        - bytes-like object that is sent repeatedly until bytes_total
+ *               bytes have been written
+ *
+ * Returns the result of the C push_data() (0 on success, -1 on failure) as a
+ * Python int. Returns NULL with an exception set if the arguments cannot be
+ * parsed or data is larger than INT_MAX bytes.
+ */
+static PyObject *py_push_data(PyObject *self, PyObject *args){
+	int socket;
+	long bytes_total;
+	Py_buffer data;
+	int ret_val;
+	//
+	// "y*" fills a Py_buffer whose length is a Py_ssize_t, so the parse does not
+	// depend on whether PY_SSIZE_T_CLEAN was defined (unlike "y#")
+	if(!PyArg_ParseTuple(args, "ily*", &socket, &bytes_total, &data)){
+		return NULL;
+	}
+	//
+	if(data.len > INT_MAX){
+		// push_data() takes the length as an int
+		PyBuffer_Release(&data);
+		PyErr_SetString(PyExc_OverflowError, "data is too large to push");
+		return NULL;
+	}
+	//
+	// the transfer can block for a long time and touches no Python objects, so
+	// let other Python threads run; the Py_buffer keeps the data alive
+	Py_BEGIN_ALLOW_THREADS
+	ret_val = push_data(socket, bytes_total, (char*)data.buf, (int)data.len);
+	Py_END_ALLOW_THREADS
+	//
+	PyBuffer_Release(&data);
+	return PyLong_FromLong(ret_val);
+}
+//
+/*
+ * Python binding: pull_data(socket_fd, bytes_total, buffer_len) -> (int, float)
+ *
+ * socket_fd   - integer file descriptor of the socket
+ * bytes_total - total number of bytes to receive
+ * buffer_len  - size of the receive buffer used for each recv()
+ *
+ * Returns a tuple (ret_val, elapsed_time): ret_val is the result of the C
+ * pull_data() (0 on success, -1 on failure) and elapsed_time is the transfer
+ * time in seconds it reported (0.0 if it failed). Returns NULL with an
+ * exception set if the arguments cannot be parsed.
+ */
+static PyObject *py_pull_data(PyObject *self, PyObject *args){
+	int socket;
+	long bytes_total;
+	int buffer_len;
+	double elapsed_time = 0;
+	int ret_val;
+	//
+	if(!PyArg_ParseTuple(args, "ili", &socket, &bytes_total, &buffer_len)){
+		return NULL;
+	}
+	//
+	// the transfer can block for a long time and touches no Python objects, so
+	// let other Python threads run while it is in progress
+	Py_BEGIN_ALLOW_THREADS
+	ret_val = pull_data(socket, bytes_total, buffer_len, &elapsed_time);
+	Py_END_ALLOW_THREADS
+	//
+	return Py_BuildValue("(id)", ret_val, elapsed_time);
+}

+ 1 - 1
src/bandwidth_tester.py

@@ -95,7 +95,7 @@ class ServerConnectionProtocol(basic_protocols.Protocol):
 					group_id = None
 				#
 				self.group_id_callback(self.conn_id, group_id)
-				self.sub_protocol = basic_protocols.PullDataProtocolWithMetrics(self.socket)
+				self.sub_protocol = basic_protocols.PullDataProtocol(self.socket)
 				self.state = self.states.PULL_DATA
 			#
 		#

+ 49 - 32
src/basic_protocols.py

@@ -8,6 +8,8 @@ import enum
 import select
 import os
 #
+import accelerated_functions
+#
 class ProtocolException(Exception):
     pass
 #
@@ -121,7 +123,7 @@ class Socks4Protocol(Protocol):
 	#
 #
 class PushDataProtocol(Protocol):
-	def __init__(self, socket, total_bytes, data_generator=None, send_max_bytes=1024*512):
+	def __init__(self, socket, total_bytes, data_generator=None, send_max_bytes=1024*512, use_accelerated=True):
 		if data_generator is None:
 			data_generator = self._default_data_generator
 		#
@@ -129,6 +131,7 @@ class PushDataProtocol(Protocol):
 		self.data_generator = data_generator
 		self.total_bytes = total_bytes
 		self.send_max_bytes = send_max_bytes
+		self.use_accelerated = use_accelerated
 		#
 		self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO PUSH_DATA RECV_CONFIRMATION DONE')
 		self.state = self.states.READY_TO_BEGIN
@@ -151,10 +154,22 @@ class PushDataProtocol(Protocol):
 		#
 		if self.state is self.states.PUSH_DATA:
 			max_block_size = self.send_max_bytes
-			bytes_needed = min(max_block_size, self.total_bytes-self.bytes_written)
-			data = self.data_generator(self.bytes_written, bytes_needed)
-			n = self.socket.send(data)
-			self.bytes_written += n
+			block_size = min(max_block_size, self.total_bytes-self.bytes_written)
+			data = self.data_generator(self.bytes_written, block_size)
+			#
+			if self.use_accelerated:
+				if not block:
+					logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
+				#
+				ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, data)
+				if ret_val < 0:
+					raise ProtocolException('Error while pushing data.')
+				#
+				self.bytes_written = self.total_bytes
+			else:
+				n = self.socket.send(data)
+				self.bytes_written += n
+			#
 			if self.bytes_written >= self.total_bytes:
 				# finished sending the data
 				logging.debug('Finished sending the data (%d bytes).', self.bytes_written)
@@ -180,8 +195,9 @@ class PushDataProtocol(Protocol):
 	#
 #
 class PullDataProtocol(Protocol):
-	def __init__(self, socket):
+	def __init__(self, socket, use_accelerated=True):
 		self.socket = socket
+		self.use_accelerated = use_accelerated
 		#
 		self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE')
 		self.state = self.states.READY_TO_BEGIN
@@ -190,6 +206,8 @@ class PullDataProtocol(Protocol):
 		self.recv_max_bytes = None
 		self.bytes_read = 0
 		self.protocol_helper = None
+		self._time_of_first_byte = None
+		self.elapsed_time = None
 	#
 	def _run_iteration(self, block=True):
 		if self.state is self.states.READY_TO_BEGIN:
@@ -207,10 +225,28 @@ class PullDataProtocol(Protocol):
 		#
 		if self.state is self.states.PULL_DATA:
 			max_block_size = self.recv_max_bytes
-			bytes_needed = min(max_block_size, self.data_size-self.bytes_read)	
-			data = self.socket.recv(bytes_needed)
-			self.bytes_read += len(data)
-			#logging.debug('Read %d bytes', self.bytes_read)
+			block_size = min(max_block_size, self.data_size-self.bytes_read)
+			#
+			if self.use_accelerated:
+				if not block:
+					logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
+				#
+				(ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, block_size)
+				if ret_val < 0:
+					raise ProtocolException('Error while pulling data.')
+				#
+				self.bytes_read = self.data_size
+				self.elapsed_time = elapsed_time
+			else:
+				data = self.socket.recv(block_size)
+				self.bytes_read += len(data)
+				if self.bytes_read != 0 and self._time_of_first_byte is None:
+					self._time_of_first_byte = time.time()
+				#
+				if self.bytes_read == self.data_size and self.elapsed_time is None:
+					self.elapsed_time = time.time()-self._time_of_first_byte
+				#
+			#
 			if self.bytes_read == self.data_size:
 				# finished receiving the data
 				logging.debug('Finished receiving the data.')
@@ -227,29 +263,10 @@ class PullDataProtocol(Protocol):
 		#
 		return False
 	#
-#
-class PullDataProtocolWithMetrics(PullDataProtocol):
-	def __init__(self, *args, **kwargs):
-		super().__init__(*args, **kwargs)
-		#
-		self.time_of_first_byte = None
-		self.time_of_last_byte = None
-	#
-	def _run_iteration(self, *args, **kwargs):
-		data = super()._run_iteration(*args, **kwargs)
-		#
-		if self.bytes_read != 0 and self.time_of_first_byte is None:
-			self.time_of_first_byte = time.time()
-		#
-		if self.bytes_read == self.data_size and self.time_of_last_byte is None:
-			self.time_of_last_byte = time.time()
-		#
-		return data
-	#
 	def calc_transfer_rate(self):
 		""" Returns bytes/s. """
-		assert self.data_size is not None and self.time_of_first_byte is not None and self.time_of_last_byte is not None
-		return self.data_size/(self.time_of_last_byte-self.time_of_first_byte)
+		assert self.data_size is not None and self.elapsed_time is not None
+		return self.data_size/self.elapsed_time
 	#
 #
 class SendDataProtocol(Protocol):
@@ -417,7 +434,7 @@ class SimpleServerConnectionProtocol(Protocol):
 	#
 	def _run_iteration(self, block=True):
 		if self.state is self.states.READY_TO_BEGIN:
-			self.sub_protocol = PullDataProtocolWithMetrics(self.socket)
+			self.sub_protocol = PullDataProtocol(self.socket)
 			self.state = self.states.PULL_DATA
 		#
 		if self.state is self.states.PULL_DATA: