Browse Source

Cleaned up code and added more program arguments.

Simplified some sections of the code and added command line arguments to set the send/receive buffer length and to disable the accelerated C functions.
Steven Engler 5 years ago
parent
commit
02cd89f34a
4 changed files with 51 additions and 35 deletions
  1. 33 26
      src/basic_protocols.py
  2. 7 1
      src/throughput_client.py
  3. 8 7
      src/throughput_protocols.py
  4. 3 1
      src/throughput_server.py

+ 33 - 26
src/basic_protocols.py

@@ -123,26 +123,28 @@ class Socks4Protocol(Protocol):
 	#
 #
 class PushDataProtocol(Protocol):
-	def __init__(self, socket, total_bytes, data_generator=None, send_max_bytes=1024*512, use_accelerated=True):
-		if data_generator is None:
-			data_generator = self._default_data_generator
+	def __init__(self, socket, total_bytes, send_buffer_len=None, use_acceleration=None):
+		if send_buffer_len is None:
+			send_buffer_len = 1024*512
+		#
+		if use_acceleration is None:
+			use_acceleration = True
 		#
 		self.socket = socket
-		self.data_generator = data_generator
 		self.total_bytes = total_bytes
-		self.send_max_bytes = send_max_bytes
-		self.use_accelerated = use_accelerated
+		self.use_acceleration = use_acceleration
 		#
 		self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO PUSH_DATA RECV_CONFIRMATION DONE')
 		self.state = self.states.READY_TO_BEGIN
 		#
+		self.byte_buffer = os.urandom(send_buffer_len)
 		self.bytes_written = 0
 		self.protocol_helper = None
 	#
 	def _run_iteration(self, block=True):
 		if self.state is self.states.READY_TO_BEGIN:
 			info = self.total_bytes.to_bytes(8, byteorder='big', signed=False)
-			info += self.send_max_bytes.to_bytes(8, byteorder='big', signed=False)
+			info += len(self.byte_buffer).to_bytes(8, byteorder='big', signed=False)
 			self.protocol_helper = ProtocolHelper()
 			self.protocol_helper.set_buffer(info)
 			self.state = self.states.SEND_INFO
@@ -153,24 +155,28 @@ class PushDataProtocol(Protocol):
 			#
 		#
 		if self.state is self.states.PUSH_DATA:
-			max_block_size = self.send_max_bytes
-			block_size = min(max_block_size, self.total_bytes-self.bytes_written)
-			data = self.data_generator(self.bytes_written, block_size)
-			#
-			if self.use_accelerated:
+			if self.use_acceleration:
 				if not block:
 					logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
 				#
-				ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, data)
+				ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, self.byte_buffer)
 				if ret_val < 0:
 					raise ProtocolException('Error while pushing data.')
 				#
 				self.bytes_written = self.total_bytes
 			else:
+				bytes_remaining = self.total_bytes-self.bytes_written
+				data_size = min(len(self.byte_buffer), bytes_remaining)
+				if data_size != len(self.byte_buffer):
+					data = self.byte_buffer[:data_size]
+				else:
+					data = self.byte_buffer
+					# don't make a copy of the byte string each time if we don't need to
+				#
 				n = self.socket.send(data)
 				self.bytes_written += n
 			#
-			if self.bytes_written >= self.total_bytes:
+			if self.bytes_written == self.total_bytes:
 				# finished sending the data
 				logging.debug('Finished sending the data (%d bytes).', self.bytes_written)
 				self.protocol_helper = ProtocolHelper()
@@ -190,20 +196,20 @@ class PushDataProtocol(Protocol):
 		#
 		return False
 	#
-	def _default_data_generator(self, index, bytes_needed):
-		return os.urandom(bytes_needed)
-	#
 #
 class PullDataProtocol(Protocol):
-	def __init__(self, socket, use_accelerated=True):
+	def __init__(self, socket, use_acceleration=None):
+		if use_acceleration is None:
+			use_acceleration = True
+		#
 		self.socket = socket
-		self.use_accelerated = use_accelerated
+		self.use_acceleration = use_acceleration
 		#
 		self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE')
 		self.state = self.states.READY_TO_BEGIN
 		#
 		self.data_size = None
-		self.recv_max_bytes = None
+		self.recv_buffer_len = None
 		self.bytes_read = 0
 		self.protocol_helper = None
 		self._time_of_first_byte = None
@@ -219,27 +225,28 @@ class PullDataProtocol(Protocol):
 			if self.protocol_helper.recv(self.socket, info_size):
 				response = self.protocol_helper.get_buffer()
 				self.data_size = int.from_bytes(response[0:8], byteorder='big', signed=False)
-				self.recv_max_bytes = int.from_bytes(response[8:16], byteorder='big', signed=False)
+				self.recv_buffer_len = int.from_bytes(response[8:16], byteorder='big', signed=False)
 				self.state = self.states.PULL_DATA
 			#
 		#
 		if self.state is self.states.PULL_DATA:
-			max_block_size = self.recv_max_bytes
-			block_size = min(max_block_size, self.data_size-self.bytes_read)
-			#
-			if self.use_accelerated:
+			if self.use_acceleration:
 				if not block:
 					logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
 				#
-				(ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, block_size)
+				(ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, self.recv_buffer_len)
 				if ret_val < 0:
 					raise ProtocolException('Error while pulling data.')
 				#
 				self.bytes_read = self.data_size
 				self.elapsed_time = elapsed_time
 			else:
+				bytes_remaining = self.data_size-self.bytes_read
+				block_size = min(self.recv_buffer_len, bytes_remaining)
+				#
 				data = self.socket.recv(block_size)
 				self.bytes_read += len(data)
+				#
 				if self.bytes_read != 0 and self._time_of_first_byte is None:
 					self._time_of_first_byte = time.time()
 				#

+ 7 - 1
src/throughput_client.py

@@ -17,6 +17,9 @@ if __name__ == '__main__':
 	parser.add_argument('--proxy', type=str, help='proxy ip address and port', metavar=('ip','port'), nargs=2)
 	parser.add_argument('--wait', type=int,
 	                    help='wait until the given time before pushing data (time in seconds since epoch)', metavar='time')
+	parser.add_argument('--buffer-len', type=useful.parse_bytes,
+	                    help='size of the send and receive buffers (can also end with \'B\', \'KiB\', \'MiB\', or \'GiB\')', metavar='bytes')
+	parser.add_argument('--no-accel', action='store_true', help='don\'t use C acceleration (use pure Python)')
 	args = parser.parse_args()
 	#
 	endpoint = (args.ip, args.port)
@@ -28,6 +31,9 @@ if __name__ == '__main__':
 	username = bytes([x for x in os.urandom(12) if x != 0])
 	#username = None
 	#
-	client = throughput_protocols.ClientProtocol(endpoint, args.num_bytes, proxy=proxy, username=username, wait_until=args.wait)
+	client = throughput_protocols.ClientProtocol(endpoint, args.num_bytes, proxy=proxy,
+	                                             username=username, wait_until=args.wait,
+	                                             send_buffer_len=args.buffer_len,
+	                                             use_acceleration=(not args.no_accel))
 	client.run()
 #

+ 8 - 7
src/throughput_protocols.py

@@ -7,13 +7,14 @@ import time
 import socket
 #
 class ClientProtocol(basic_protocols.Protocol):
-	def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None, wait_until=None):
+	def __init__(self, endpoint, total_bytes, proxy=None, username=None, wait_until=None, send_buffer_len=None, use_acceleration=None):
 		self.endpoint = endpoint
-		self.data_generator = data_generator
 		self.total_bytes = total_bytes
 		self.proxy = proxy
 		self.username = username
 		self.wait_until = wait_until
+		self.send_buffer_len = send_buffer_len
+		self.use_acceleration = use_acceleration
 		#
 		self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY SEND_GROUP_ID PUSH_DATA DONE')
 		self.state = self.states.READY_TO_BEGIN
@@ -47,7 +48,6 @@ class ClientProtocol(basic_protocols.Protocol):
 				group_id_bytes = self.group_id.to_bytes(8, byteorder='big', signed=False)
 				self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, group_id_bytes)
 				self.state = self.states.SEND_GROUP_ID
-				#logging.debug('Sent group ID.')
 			#
 		#
 		if self.state is self.states.SEND_GROUP_ID:
@@ -56,8 +56,8 @@ class ClientProtocol(basic_protocols.Protocol):
 			#
 			if (self.wait_until is None or time.time() >= self.wait_until) and self.sub_protocol.run(block=block):
 				self.sub_protocol = basic_protocols.PushDataProtocol(self.socket, self.total_bytes,
-				                                                     data_generator=self.data_generator,
-				                                                     send_max_bytes=1024*512)
+				                                                     send_buffer_len=self.send_buffer_len,
+				                                                     use_acceleration=self.use_acceleration)
 				self.state = self.states.PUSH_DATA
 			#
 		#
@@ -71,11 +71,12 @@ class ClientProtocol(basic_protocols.Protocol):
 	#
 #
 class ServerProtocol(basic_protocols.Protocol):
-	def __init__(self, socket, conn_id, group_id_callback=None, bandwidth_callback=None):
+	def __init__(self, socket, conn_id, group_id_callback=None, bandwidth_callback=None, use_acceleration=None):
 		self.socket = socket
 		self.conn_id = conn_id
 		self.group_id_callback = group_id_callback
 		self.bandwidth_callback = bandwidth_callback
+		self.use_acceleration = use_acceleration
 		#
 		self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN RECV_GROUP_ID PULL_DATA DONE')
 		self.state = self.states.READY_TO_BEGIN
@@ -95,7 +96,7 @@ class ServerProtocol(basic_protocols.Protocol):
 					group_id = None
 				#
 				self.group_id_callback(self.conn_id, group_id)
-				self.sub_protocol = basic_protocols.PullDataProtocol(self.socket)
+				self.sub_protocol = basic_protocols.PullDataProtocol(self.socket, use_acceleration=self.use_acceleration)
 				self.state = self.states.PULL_DATA
 			#
 		#

+ 3 - 1
src/throughput_server.py

@@ -13,6 +13,7 @@ if __name__ == '__main__':
 	#
 	parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
 	parser.add_argument('port', type=int, help='listen on port')
+	parser.add_argument('--no-accel', action='store_true', help='don\'t use C acceleration (use pure Python)')
 	args = parser.parse_args()
 	#
 	endpoint = ('127.0.0.1', args.port)
@@ -34,7 +35,8 @@ if __name__ == '__main__':
 		bw_queue.put({'conn_id':conn_id, 'data_size':data_size, 'transfer_rate':transfer_rate})
 	#
 	def start_server_conn(socket, conn_id):
-		server = throughput_protocols.ServerProtocol(socket, conn_id, group_id_callback=group_id_callback, bandwidth_callback=bw_callback)
+		server = throughput_protocols.ServerProtocol(socket, conn_id, group_id_callback=group_id_callback,
+		                                             bandwidth_callback=bw_callback, use_acceleration=(not args.no_accel))
 		try:
 			server.run()
 		except KeyboardInterrupt: