|
@@ -123,26 +123,28 @@ class Socks4Protocol(Protocol):
|
|
|
#
|
|
|
#
|
|
|
class PushDataProtocol(Protocol):
|
|
|
- def __init__(self, socket, total_bytes, data_generator=None, send_max_bytes=1024*512, use_accelerated=True):
|
|
|
- if data_generator is None:
|
|
|
- data_generator = self._default_data_generator
|
|
|
+ def __init__(self, socket, total_bytes, send_buffer_len=None, use_acceleration=None):
|
|
|
+ if send_buffer_len is None:
|
|
|
+ send_buffer_len = 1024*512
|
|
|
+ #
|
|
|
+ if use_acceleration is None:
|
|
|
+ use_acceleration = True
|
|
|
#
|
|
|
self.socket = socket
|
|
|
- self.data_generator = data_generator
|
|
|
self.total_bytes = total_bytes
|
|
|
- self.send_max_bytes = send_max_bytes
|
|
|
- self.use_accelerated = use_accelerated
|
|
|
+ self.use_acceleration = use_acceleration
|
|
|
#
|
|
|
self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO PUSH_DATA RECV_CONFIRMATION DONE')
|
|
|
self.state = self.states.READY_TO_BEGIN
|
|
|
#
|
|
|
+ self.byte_buffer = os.urandom(send_buffer_len)
|
|
|
self.bytes_written = 0
|
|
|
self.protocol_helper = None
|
|
|
#
|
|
|
def _run_iteration(self, block=True):
|
|
|
if self.state is self.states.READY_TO_BEGIN:
|
|
|
info = self.total_bytes.to_bytes(8, byteorder='big', signed=False)
|
|
|
- info += self.send_max_bytes.to_bytes(8, byteorder='big', signed=False)
|
|
|
+ info += len(self.byte_buffer).to_bytes(8, byteorder='big', signed=False)
|
|
|
self.protocol_helper = ProtocolHelper()
|
|
|
self.protocol_helper.set_buffer(info)
|
|
|
self.state = self.states.SEND_INFO
|
|
@@ -153,24 +155,28 @@ class PushDataProtocol(Protocol):
|
|
|
#
|
|
|
#
|
|
|
if self.state is self.states.PUSH_DATA:
|
|
|
- max_block_size = self.send_max_bytes
|
|
|
- block_size = min(max_block_size, self.total_bytes-self.bytes_written)
|
|
|
- data = self.data_generator(self.bytes_written, block_size)
|
|
|
- #
|
|
|
- if self.use_accelerated:
|
|
|
+ if self.use_acceleration:
|
|
|
if not block:
|
|
|
logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
|
|
|
#
|
|
|
- ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, data)
|
|
|
+ ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, self.byte_buffer)
|
|
|
if ret_val < 0:
|
|
|
raise ProtocolException('Error while pushing data.')
|
|
|
#
|
|
|
self.bytes_written = self.total_bytes
|
|
|
else:
|
|
|
+ bytes_remaining = self.total_bytes-self.bytes_written
|
|
|
+ data_size = min(len(self.byte_buffer), bytes_remaining)
|
|
|
+ if data_size != len(self.byte_buffer):
|
|
|
+ data = self.byte_buffer[:data_size]
|
|
|
+ else:
|
|
|
+ data = self.byte_buffer
|
|
|
+ # don't make a copy of the byte string each time if we don't need to
|
|
|
+ #
|
|
|
n = self.socket.send(data)
|
|
|
self.bytes_written += n
|
|
|
#
|
|
|
- if self.bytes_written >= self.total_bytes:
|
|
|
+ if self.bytes_written == self.total_bytes:
|
|
|
# finished sending the data
|
|
|
logging.debug('Finished sending the data (%d bytes).', self.bytes_written)
|
|
|
self.protocol_helper = ProtocolHelper()
|
|
@@ -190,20 +196,20 @@ class PushDataProtocol(Protocol):
|
|
|
#
|
|
|
return False
|
|
|
#
|
|
|
- def _default_data_generator(self, index, bytes_needed):
|
|
|
- return os.urandom(bytes_needed)
|
|
|
- #
|
|
|
#
|
|
|
class PullDataProtocol(Protocol):
|
|
|
- def __init__(self, socket, use_accelerated=True):
|
|
|
+ def __init__(self, socket, use_acceleration=None):
|
|
|
+ if use_acceleration is None:
|
|
|
+ use_acceleration = True
|
|
|
+ #
|
|
|
self.socket = socket
|
|
|
- self.use_accelerated = use_accelerated
|
|
|
+ self.use_acceleration = use_acceleration
|
|
|
#
|
|
|
self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE')
|
|
|
self.state = self.states.READY_TO_BEGIN
|
|
|
#
|
|
|
self.data_size = None
|
|
|
- self.recv_max_bytes = None
|
|
|
+ self.recv_buffer_len = None
|
|
|
self.bytes_read = 0
|
|
|
self.protocol_helper = None
|
|
|
self._time_of_first_byte = None
|
|
@@ -219,27 +225,28 @@ class PullDataProtocol(Protocol):
|
|
|
if self.protocol_helper.recv(self.socket, info_size):
|
|
|
response = self.protocol_helper.get_buffer()
|
|
|
self.data_size = int.from_bytes(response[0:8], byteorder='big', signed=False)
|
|
|
- self.recv_max_bytes = int.from_bytes(response[8:16], byteorder='big', signed=False)
|
|
|
+ self.recv_buffer_len = int.from_bytes(response[8:16], byteorder='big', signed=False)
|
|
|
self.state = self.states.PULL_DATA
|
|
|
#
|
|
|
#
|
|
|
if self.state is self.states.PULL_DATA:
|
|
|
- max_block_size = self.recv_max_bytes
|
|
|
- block_size = min(max_block_size, self.data_size-self.bytes_read)
|
|
|
- #
|
|
|
- if self.use_accelerated:
|
|
|
+ if self.use_acceleration:
|
|
|
if not block:
|
|
|
logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
|
|
|
#
|
|
|
- (ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, block_size)
|
|
|
+ (ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, self.recv_buffer_len)
|
|
|
if ret_val < 0:
|
|
|
raise ProtocolException('Error while pulling data.')
|
|
|
#
|
|
|
self.bytes_read = self.data_size
|
|
|
self.elapsed_time = elapsed_time
|
|
|
else:
|
|
|
+ bytes_remaining = self.data_size-self.bytes_read
|
|
|
+ block_size = min(self.recv_buffer_len, bytes_remaining)
|
|
|
+ #
|
|
|
data = self.socket.recv(block_size)
|
|
|
self.bytes_read += len(data)
|
|
|
+ #
|
|
|
if self.bytes_read != 0 and self._time_of_first_byte is None:
|
|
|
self._time_of_first_byte = time.time()
|
|
|
#
|