diff --git a/src/basic_protocols.py b/src/basic_protocols.py index ba8b847..0dcecc8 100755 --- a/src/basic_protocols.py +++ b/src/basic_protocols.py @@ -123,26 +123,28 @@ class Socks4Protocol(Protocol): # # class PushDataProtocol(Protocol): - def __init__(self, socket, total_bytes, data_generator=None, send_max_bytes=1024*512, use_accelerated=True): - if data_generator is None: - data_generator = self._default_data_generator + def __init__(self, socket, total_bytes, send_buffer_len=None, use_acceleration=None): + if send_buffer_len is None: + send_buffer_len = 1024*512 + # + if use_acceleration is None: + use_acceleration = True # self.socket = socket - self.data_generator = data_generator self.total_bytes = total_bytes - self.send_max_bytes = send_max_bytes - self.use_accelerated = use_accelerated + self.use_acceleration = use_acceleration # self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO PUSH_DATA RECV_CONFIRMATION DONE') self.state = self.states.READY_TO_BEGIN # + self.byte_buffer = os.urandom(send_buffer_len) self.bytes_written = 0 self.protocol_helper = None # def _run_iteration(self, block=True): if self.state is self.states.READY_TO_BEGIN: info = self.total_bytes.to_bytes(8, byteorder='big', signed=False) - info += self.send_max_bytes.to_bytes(8, byteorder='big', signed=False) + info += len(self.byte_buffer).to_bytes(8, byteorder='big', signed=False) self.protocol_helper = ProtocolHelper() self.protocol_helper.set_buffer(info) self.state = self.states.SEND_INFO @@ -153,24 +155,28 @@ class PushDataProtocol(Protocol): # # if self.state is self.states.PUSH_DATA: - max_block_size = self.send_max_bytes - block_size = min(max_block_size, self.total_bytes-self.bytes_written) - data = self.data_generator(self.bytes_written, block_size) - # - if self.use_accelerated: + if self.use_acceleration: if not block: logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.') # - ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, data) + ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, self.byte_buffer) if ret_val < 0: raise ProtocolException('Error while pushing data.') # self.bytes_written = self.total_bytes else: + bytes_remaining = self.total_bytes-self.bytes_written + data_size = min(len(self.byte_buffer), bytes_remaining) + if data_size != len(self.byte_buffer): + data = self.byte_buffer[:data_size] + else: + data = self.byte_buffer + # don't make a copy of the byte string each time if we don't need to + # n = self.socket.send(data) self.bytes_written += n # - if self.bytes_written >= self.total_bytes: + if self.bytes_written == self.total_bytes: # finished sending the data logging.debug('Finished sending the data (%d bytes).', self.bytes_written) self.protocol_helper = ProtocolHelper() @@ -190,20 +196,20 @@ class PushDataProtocol(Protocol): # return False # - def _default_data_generator(self, index, bytes_needed): - return os.urandom(bytes_needed) - # # class PullDataProtocol(Protocol): - def __init__(self, socket, use_accelerated=True): + def __init__(self, socket, use_acceleration=None): + if use_acceleration is None: + use_acceleration = True + # self.socket = socket - self.use_accelerated = use_accelerated + self.use_acceleration = use_acceleration # self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE') self.state = self.states.READY_TO_BEGIN # self.data_size = None - self.recv_max_bytes = None + self.recv_buffer_len = None self.bytes_read = 0 self.protocol_helper = None self._time_of_first_byte = None @@ -219,27 +225,28 @@ class PullDataProtocol(Protocol): if self.protocol_helper.recv(self.socket, info_size): response = self.protocol_helper.get_buffer() self.data_size = int.from_bytes(response[0:8], byteorder='big', signed=False) - self.recv_max_bytes = int.from_bytes(response[8:16], byteorder='big', signed=False) + self.recv_buffer_len = int.from_bytes(response[8:16], byteorder='big', signed=False) self.state = self.states.PULL_DATA # # if self.state is self.states.PULL_DATA: - max_block_size = self.recv_max_bytes - block_size = min(max_block_size, self.data_size-self.bytes_read) - # - if self.use_accelerated: + if self.use_acceleration: if not block: logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.') # - (ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, block_size) + (ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, self.recv_buffer_len) if ret_val < 0: raise ProtocolException('Error while pulling data.') # self.bytes_read = self.data_size self.elapsed_time = elapsed_time else: + bytes_remaining = self.data_size-self.bytes_read + block_size = min(self.recv_buffer_len, bytes_remaining) + # data = self.socket.recv(block_size) self.bytes_read += len(data) + # if self.bytes_read != 0 and self._time_of_first_byte is None: self._time_of_first_byte = time.time() # diff --git a/src/throughput_client.py b/src/throughput_client.py index d45dffe..0be8289 100644 --- a/src/throughput_client.py +++ b/src/throughput_client.py @@ -17,6 +17,8 @@ if __name__ == '__main__': parser.add_argument('--proxy', type=str, help='proxy ip address and port', metavar=('ip','port'), nargs=2) parser.add_argument('--wait', type=int, help='wait until the given time before pushing data (time in seconds since epoch)', metavar='time') + parser.add_argument('--buffer-len', type=useful.parse_bytes, help='size of the send and receive buffers (can also end with \'B\', \'KiB\', \'MiB\', or \'GiB\')', metavar='bytes') + parser.add_argument('--no-accel', action='store_true', help='don\'t use C acceleration (use pure Python)') args = parser.parse_args() # endpoint = (args.ip, args.port) @@ -27,7 +29,20 @@ if __name__ == '__main__': # username = bytes([x for x in os.urandom(12) if x != 0]) #username = None + ''' + data_MB = 200 #20000 + data_B = data_MB*2**20 # - client = throughput_protocols.ClientProtocol(endpoint, args.num_bytes, proxy=proxy, username=username, wait_until=args.wait) + if len(sys.argv) > 2: + wait_until = int(sys.argv[2]) + else: + wait_until = None + # + ''' + # + client = throughput_protocols.ClientProtocol(endpoint, args.num_bytes, proxy=proxy, + username=username, wait_until=args.wait, + send_buffer_len=args.buffer_len, + use_acceleration=(not args.no_accel)) client.run() # diff --git a/src/throughput_protocols.py b/src/throughput_protocols.py index 5dec4b6..3eb3d60 100755 --- a/src/throughput_protocols.py +++ b/src/throughput_protocols.py @@ -7,13 +7,14 @@ import time import socket # class ClientProtocol(basic_protocols.Protocol): - def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None, wait_until=None): + def __init__(self, endpoint, total_bytes, proxy=None, username=None, wait_until=None, send_buffer_len=None, use_acceleration=None): self.endpoint = endpoint - self.data_generator = data_generator self.total_bytes = total_bytes self.proxy = proxy self.username = username self.wait_until = wait_until + self.send_buffer_len = send_buffer_len + self.use_acceleration = use_acceleration # self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY SEND_GROUP_ID PUSH_DATA DONE') self.state = self.states.READY_TO_BEGIN @@ -47,7 +48,6 @@ class ClientProtocol(basic_protocols.Protocol): group_id_bytes = self.group_id.to_bytes(8, byteorder='big', signed=False) self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, group_id_bytes) self.state = self.states.SEND_GROUP_ID - #logging.debug('Sent group ID.') # # if self.state is self.states.SEND_GROUP_ID: @@ -56,8 +56,8 @@ class ClientProtocol(basic_protocols.Protocol): # if (self.wait_until is None or time.time() >= self.wait_until) and self.sub_protocol.run(block=block): self.sub_protocol = basic_protocols.PushDataProtocol(self.socket, self.total_bytes, - data_generator=self.data_generator, - send_max_bytes=1024*512) + send_buffer_len=self.send_buffer_len, + use_acceleration=self.use_acceleration) self.state = self.states.PUSH_DATA # # @@ -71,11 +71,12 @@ class ClientProtocol(basic_protocols.Protocol): # # class ServerProtocol(basic_protocols.Protocol): - def __init__(self, socket, conn_id, group_id_callback=None, bandwidth_callback=None): + def __init__(self, socket, conn_id, group_id_callback=None, bandwidth_callback=None, use_acceleration=None): self.socket = socket self.conn_id = conn_id self.group_id_callback = group_id_callback self.bandwidth_callback = bandwidth_callback + self.use_acceleration = use_acceleration # self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN RECV_GROUP_ID PULL_DATA DONE') self.state = self.states.READY_TO_BEGIN @@ -95,7 +96,7 @@ class ServerProtocol(basic_protocols.Protocol): group_id = None # self.group_id_callback(self.conn_id, group_id) - self.sub_protocol = basic_protocols.PullDataProtocol(self.socket) + self.sub_protocol = basic_protocols.PullDataProtocol(self.socket, use_acceleration=self.use_acceleration) self.state = self.states.PULL_DATA # # diff --git a/src/throughput_server.py b/src/throughput_server.py index a22ed8f..0217d14 100644 --- a/src/throughput_server.py +++ b/src/throughput_server.py @@ -13,6 +13,7 @@ if __name__ == '__main__': # parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).') parser.add_argument('port', type=int, help='listen on port') + parser.add_argument('--no-accel', action='store_true', help='don\'t use C acceleration (use pure Python)') args = parser.parse_args() # endpoint = ('127.0.0.1', args.port) @@ -34,7 +35,8 @@ if __name__ == '__main__': bw_queue.put({'conn_id':conn_id, 'data_size':data_size, 'transfer_rate':transfer_rate}) # def start_server_conn(socket, conn_id): - server = throughput_protocols.ServerProtocol(socket, conn_id, group_id_callback=group_id_callback, bandwidth_callback=bw_callback) + server = throughput_protocols.ServerProtocol(socket, conn_id, group_id_callback=group_id_callback, + bandwidth_callback=bw_callback, use_acceleration=(not args.no_accel)) try: server.run() except KeyboardInterrupt: