|
@@ -8,6 +8,8 @@ import enum
|
|
|
import select
|
|
|
import os
|
|
|
#
|
|
|
+import accelerated_functions
|
|
|
+#
|
|
|
class ProtocolException(Exception):
|
|
|
pass
|
|
|
#
|
|
@@ -121,7 +123,7 @@ class Socks4Protocol(Protocol):
|
|
|
#
|
|
|
#
|
|
|
class PushDataProtocol(Protocol):
|
|
|
- def __init__(self, socket, total_bytes, data_generator=None, send_max_bytes=1024*512):
|
|
|
+ def __init__(self, socket, total_bytes, data_generator=None, send_max_bytes=1024*512, use_accelerated=True):
|
|
|
if data_generator is None:
|
|
|
data_generator = self._default_data_generator
|
|
|
#
|
|
@@ -129,6 +131,7 @@ class PushDataProtocol(Protocol):
|
|
|
self.data_generator = data_generator
|
|
|
self.total_bytes = total_bytes
|
|
|
self.send_max_bytes = send_max_bytes
|
|
|
+ self.use_accelerated = use_accelerated
|
|
|
#
|
|
|
self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO PUSH_DATA RECV_CONFIRMATION DONE')
|
|
|
self.state = self.states.READY_TO_BEGIN
|
|
@@ -151,10 +154,22 @@ class PushDataProtocol(Protocol):
|
|
|
#
|
|
|
if self.state is self.states.PUSH_DATA:
|
|
|
max_block_size = self.send_max_bytes
|
|
|
- bytes_needed = min(max_block_size, self.total_bytes-self.bytes_written)
|
|
|
- data = self.data_generator(self.bytes_written, bytes_needed)
|
|
|
- n = self.socket.send(data)
|
|
|
- self.bytes_written += n
|
|
|
+ block_size = min(max_block_size, self.total_bytes-self.bytes_written)
|
|
|
+ data = self.data_generator(self.bytes_written, block_size)
|
|
|
+ #
|
|
|
+ if self.use_accelerated:
|
|
|
+ if not block:
|
|
|
+ logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
|
|
|
+ #
|
|
|
+ ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, data)
|
|
|
+ if ret_val < 0:
|
|
|
+ raise ProtocolException('Error while pushing data.')
|
|
|
+ #
|
|
|
+ self.bytes_written = self.total_bytes
|
|
|
+ else:
|
|
|
+ n = self.socket.send(data)
|
|
|
+ self.bytes_written += n
|
|
|
+ #
|
|
|
if self.bytes_written >= self.total_bytes:
|
|
|
# finished sending the data
|
|
|
logging.debug('Finished sending the data (%d bytes).', self.bytes_written)
|
|
@@ -180,8 +195,9 @@ class PushDataProtocol(Protocol):
|
|
|
#
|
|
|
#
|
|
|
class PullDataProtocol(Protocol):
|
|
|
- def __init__(self, socket):
|
|
|
+ def __init__(self, socket, use_accelerated=True):
|
|
|
self.socket = socket
|
|
|
+ self.use_accelerated = use_accelerated
|
|
|
#
|
|
|
self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE')
|
|
|
self.state = self.states.READY_TO_BEGIN
|
|
@@ -190,6 +206,8 @@ class PullDataProtocol(Protocol):
|
|
|
self.recv_max_bytes = None
|
|
|
self.bytes_read = 0
|
|
|
self.protocol_helper = None
|
|
|
+ self._time_of_first_byte = None
|
|
|
+ self.elapsed_time = None
|
|
|
#
|
|
|
def _run_iteration(self, block=True):
|
|
|
if self.state is self.states.READY_TO_BEGIN:
|
|
@@ -207,10 +225,28 @@ class PullDataProtocol(Protocol):
|
|
|
#
|
|
|
if self.state is self.states.PULL_DATA:
|
|
|
max_block_size = self.recv_max_bytes
|
|
|
- bytes_needed = min(max_block_size, self.data_size-self.bytes_read)
|
|
|
- data = self.socket.recv(bytes_needed)
|
|
|
- self.bytes_read += len(data)
|
|
|
- #logging.debug('Read %d bytes', self.bytes_read)
|
|
|
+ block_size = min(max_block_size, self.data_size-self.bytes_read)
|
|
|
+ #
|
|
|
+ if self.use_accelerated:
|
|
|
+ if not block:
|
|
|
+ logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
|
|
|
+ #
|
|
|
+ (ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, block_size)
|
|
|
+ if ret_val < 0:
|
|
|
+ raise ProtocolException('Error while pulling data.')
|
|
|
+ #
|
|
|
+ self.bytes_read = self.data_size
|
|
|
+ self.elapsed_time = elapsed_time
|
|
|
+ else:
|
|
|
+ data = self.socket.recv(block_size)
|
|
|
+ self.bytes_read += len(data)
|
|
|
+ if self.bytes_read != 0 and self._time_of_first_byte is None:
|
|
|
+ self._time_of_first_byte = time.time()
|
|
|
+ #
|
|
|
+ if self.bytes_read == self.data_size and self.elapsed_time is None:
|
|
|
+ self.elapsed_time = time.time()-self._time_of_first_byte
|
|
|
+ #
|
|
|
+ #
|
|
|
if self.bytes_read == self.data_size:
|
|
|
# finished receiving the data
|
|
|
logging.debug('Finished receiving the data.')
|
|
@@ -227,29 +263,10 @@ class PullDataProtocol(Protocol):
|
|
|
#
|
|
|
return False
|
|
|
#
|
|
|
-#
|
|
|
-class PullDataProtocolWithMetrics(PullDataProtocol):
|
|
|
- def __init__(self, *args, **kwargs):
|
|
|
- super().__init__(*args, **kwargs)
|
|
|
- #
|
|
|
- self.time_of_first_byte = None
|
|
|
- self.time_of_last_byte = None
|
|
|
- #
|
|
|
- def _run_iteration(self, *args, **kwargs):
|
|
|
- data = super()._run_iteration(*args, **kwargs)
|
|
|
- #
|
|
|
- if self.bytes_read != 0 and self.time_of_first_byte is None:
|
|
|
- self.time_of_first_byte = time.time()
|
|
|
- #
|
|
|
- if self.bytes_read == self.data_size and self.time_of_last_byte is None:
|
|
|
- self.time_of_last_byte = time.time()
|
|
|
- #
|
|
|
- return data
|
|
|
- #
|
|
|
def calc_transfer_rate(self):
|
|
|
""" Returns bytes/s. """
|
|
|
- assert self.data_size is not None and self.time_of_first_byte is not None and self.time_of_last_byte is not None
|
|
|
- return self.data_size/(self.time_of_last_byte-self.time_of_first_byte)
|
|
|
+ assert self.data_size is not None and self.elapsed_time is not None
|
|
|
+ return self.data_size/self.elapsed_time
|
|
|
#
|
|
|
#
|
|
|
class SendDataProtocol(Protocol):
|
|
@@ -417,7 +434,7 @@ class SimpleServerConnectionProtocol(Protocol):
|
|
|
#
|
|
|
def _run_iteration(self, block=True):
|
|
|
if self.state is self.states.READY_TO_BEGIN:
|
|
|
- self.sub_protocol = PullDataProtocolWithMetrics(self.socket)
|
|
|
+ self.sub_protocol = PullDataProtocol(self.socket)
|
|
|
self.state = self.states.PULL_DATA
|
|
|
#
|
|
|
if self.state is self.states.PULL_DATA:
|