Browse Source

Reorganized bandwidth_tester code.

Split the code into separate client/server/protocols files, and added argument parsing instead of hard-coded values.
Steven Engler 5 years ago
parent
commit
58322789e9
4 changed files with 159 additions and 125 deletions
  1. 33 0
      src/throughput_client.py
  2. 2 125
      src/throughput_protocols.py
  3. 112 0
      src/throughput_server.py
  4. 12 0
      src/useful.py

+ 33 - 0
src/throughput_client.py

@@ -0,0 +1,33 @@
+#!/usr/bin/python3
+#
+import throughput_protocols
+import useful
+import os
+import argparse
+import logging
+#
+if __name__ == '__main__':
+	logging.basicConfig(level=logging.DEBUG)
+	#
+	parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
+	parser.add_argument('ip', type=str, help='destination ip address')
+	parser.add_argument('port', type=int, help='destination port')
+	parser.add_argument('num_bytes', type=useful.parse_bytes,
+	                    help='number of bytes to send (can also end with \'B\', \'KiB\', \'MiB\', or \'GiB\')', metavar='num-bytes')
+	parser.add_argument('--proxy', type=str, help='proxy ip address and port', metavar=('ip','port'), nargs=2)
+	parser.add_argument('--wait', type=int,
+	                    help='wait until the given time before pushing data (time in seconds since epoch)', metavar='time')
+	args = parser.parse_args()
+	#
+	endpoint = (args.ip, args.port)
+	proxy = None
+	#
+	if args.proxy is not None:
+		proxy = (args.proxy[0], int(args.proxy[1]))
+	#
+	username = bytes([x for x in os.urandom(12) if x != 0])
+	#username = None
+	#
+	client = throughput_protocols.ClientProtocol(endpoint, args.num_bytes, proxy=proxy, username=username, wait_until=args.wait)
+	client.run()
+#

+ 2 - 125
src/bandwidth_tester.py → src/throughput_protocols.py

@@ -6,7 +6,7 @@ import enum
 import time
 import socket
 #
-class ClientConnectionProtocol(basic_protocols.Protocol):
+class ClientProtocol(basic_protocols.Protocol):
 	def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None, wait_until=None):
 		self.endpoint = endpoint
 		self.data_generator = data_generator
@@ -70,7 +70,7 @@ class ClientConnectionProtocol(basic_protocols.Protocol):
 		return False
 	#
 #
-class ServerConnectionProtocol(basic_protocols.Protocol):
+class ServerProtocol(basic_protocols.Protocol):
 	def __init__(self, socket, conn_id, group_id_callback=None, bandwidth_callback=None):
 		self.socket = socket
 		self.conn_id = conn_id
@@ -111,126 +111,3 @@ class ServerConnectionProtocol(basic_protocols.Protocol):
 		return False
 	#
 #
-if __name__ == '__main__':
-	import sys
-	logging.basicConfig(level=logging.DEBUG)
-	#
-	if sys.argv[1] == 'client':
-		import os
-		#
-		endpoint = ('127.0.0.1', 4747)
-		#endpoint = ('127.0.0.1', 8627)
-		#proxy = ('127.0.0.1', 9003+int(sys.argv[3])-1)
-		#proxy = ('127.0.0.1', 9003)
-		proxy = None
-		username = bytes([x for x in os.urandom(12) if x != 0])
-		#username = None
-		data_MB = 500000
-		#
-		if len(sys.argv) > 2:
-			wait_until = int(sys.argv[2])
-		else:
-			wait_until = None
-		#
-		client = ClientConnectionProtocol(endpoint, data_MB*2**20, proxy=proxy, username=username, wait_until=wait_until)
-		client.run()
-		#
-	elif sys.argv[1] == 'server':
-		import multiprocessing
-		import queue
-		#
-		endpoint = ('127.0.0.1', 4747)
-		processes = []
-		processes_map = {}
-		joinable_connections = multiprocessing.Queue()
-		conn_counter = [0]
-		group_queue = multiprocessing.Queue()
-		bw_queue = multiprocessing.Queue()
-		#
-		def group_id_callback(conn_id, group_id):
-			# put them in a queue to display later
-			#logging.debug('For conn %d Received group id: %d', conn_id, group_id)
-			group_queue.put({'conn_id':conn_id, 'group_id':group_id})
-		#
-		def bw_callback(conn_id, data_size, transfer_rate):
-			# put them in a queue to display later
-			bw_queue.put({'conn_id':conn_id, 'data_size':data_size, 'transfer_rate':transfer_rate})
-		#
-		def start_server_conn(socket, conn_id):
-			server = ServerConnectionProtocol(socket, conn_id, group_id_callback=group_id_callback, bandwidth_callback=bw_callback)
-			try:
-				server.run()
-			except KeyboardInterrupt:
-				socket.close()
-			finally:
-				joinable_connections.put(conn_id)
-			#
-		#
-		def accept_callback(socket):
-			conn_id = conn_counter[0]
-			conn_counter[0] += 1
-			#logging.debug('Adding connection %d', conn_id)
-			p = multiprocessing.Process(target=start_server_conn, args=(socket, conn_id))
-			processes.append(p)
-			processes_map[conn_id] = p
-			p.start()
-			socket.close()
-			# close this process' copy of the socket
-		#
-		l = basic_protocols.ServerListener(endpoint, accept_callback)
-		#
-		try:
-			while True:
-				l.accept()
-				try:
-					while True:
-						conn_id = joinable_connections.get(False)
-						p = processes_map[conn_id]
-						p.join()
-					#
-				except queue.Empty:
-					pass
-				#
-			#
-		except KeyboardInterrupt:
-			print()
-			#
-			bw_values = {}
-			group_values = {}
-			#
-			try:
-				while True:
-					bw_val = bw_queue.get(False)
-					bw_values[bw_val['conn_id']] = bw_val
-				#
-			except queue.Empty:
-				pass
-			#
-			try:
-				while True:
-					group_val = group_queue.get(False)
-					group_values[group_val['conn_id']] = group_val
-				#
-			except queue.Empty:
-				pass
-			#
-			group_set = set([x['group_id'] for x in group_values.values()])
-			for group in group_set:
-				# doesn't handle group == None
-				conns_in_group = [x[0] for x in group_values.items() if x[1]['group_id'] == group]
-				in_group = [x for x in bw_values.values() if x['conn_id'] in conns_in_group]
-				if len(in_group) > 0:
-					avg_data_size = sum([x['data_size'] for x in in_group])/len(in_group)
-					avg_transfer_rate = sum([x['transfer_rate'] for x in in_group])/len(in_group)
-					#
-					logging.info('Group size: %d', len(in_group))
-					logging.info('Avg Transferred (MB): %.4f', avg_data_size/(1024**2))
-					logging.info('Avg Transfer rate (MB/s): %.4f', avg_transfer_rate/(1024**2))
-				#
-			#
-		#
-		for p in processes:
-			p.join()
-		#
-	#
-#

+ 112 - 0
src/throughput_server.py

@@ -0,0 +1,112 @@
+#!/usr/bin/python3
+#
+import throughput_protocols
+import basic_protocols
+import os
+import multiprocessing
+import queue
+import logging
+import argparse
+#
+if __name__ == '__main__':
+	logging.basicConfig(level=logging.DEBUG)
+	#
+	parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
+	parser.add_argument('port', type=int, help='listen on port')
+	args = parser.parse_args()
+	#
+	endpoint = ('127.0.0.1', args.port)
+	#
+	processes = []
+	processes_map = {}
+	joinable_connections = multiprocessing.Queue()
+	conn_counter = [0]
+	group_queue = multiprocessing.Queue()
+	bw_queue = multiprocessing.Queue()
+	#
+	def group_id_callback(conn_id, group_id):
+		# put them in a queue to display later
+		#logging.debug('For conn %d Received group id: %d', conn_id, group_id)
+		group_queue.put({'conn_id':conn_id, 'group_id':group_id})
+	#
+	def bw_callback(conn_id, data_size, transfer_rate):
+		# put them in a queue to display later
+		bw_queue.put({'conn_id':conn_id, 'data_size':data_size, 'transfer_rate':transfer_rate})
+	#
+	def start_server_conn(socket, conn_id):
+		server = throughput_protocols.ServerProtocol(socket, conn_id, group_id_callback=group_id_callback, bandwidth_callback=bw_callback)
+		try:
+			server.run()
+		except KeyboardInterrupt:
+			socket.close()
+		finally:
+			joinable_connections.put(conn_id)
+		#
+	#
+	def accept_callback(socket):
+		conn_id = conn_counter[0]
+		conn_counter[0] += 1
+		#logging.debug('Adding connection %d', conn_id)
+		p = multiprocessing.Process(target=start_server_conn, args=(socket, conn_id))
+		processes.append(p)
+		processes_map[conn_id] = p
+		p.start()
+		socket.close()
+		# close this process' copy of the socket
+	#
+	l = basic_protocols.ServerListener(endpoint, accept_callback)
+	#
+	try:
+		while True:
+			l.accept()
+			try:
+				while True:
+					conn_id = joinable_connections.get(False)
+					p = processes_map[conn_id]
+					p.join()
+				#
+			except queue.Empty:
+				pass
+			#
+		#
+	except KeyboardInterrupt:
+		print()
+		#
+		bw_values = {}
+		group_values = {}
+		#
+		try:
+			while True:
+				bw_val = bw_queue.get(False)
+				bw_values[bw_val['conn_id']] = bw_val
+			#
+		except queue.Empty:
+			pass
+		#
+		try:
+			while True:
+				group_val = group_queue.get(False)
+				group_values[group_val['conn_id']] = group_val
+			#
+		except queue.Empty:
+			pass
+		#
+		group_set = set([x['group_id'] for x in group_values.values()])
+		for group in group_set:
+			# doesn't handle group == None
+			conns_in_group = [x[0] for x in group_values.items() if x[1]['group_id'] == group]
+			in_group = [x for x in bw_values.values() if x['conn_id'] in conns_in_group]
+			if len(in_group) > 0:
+				avg_data_size = sum([x['data_size'] for x in in_group])/len(in_group)
+				avg_transfer_rate = sum([x['transfer_rate'] for x in in_group])/len(in_group)
+				#
+				logging.info('Group size: %d', len(in_group))
+				logging.info('Avg Transferred (MB): %.4f', avg_data_size/(1024**2))
+				logging.info('Avg Transfer rate (MB/s): %.4f', avg_transfer_rate/(1024**2))
+			#
+		#
+	#
+	for p in processes:
+		p.join()
+	#
+#

+ 12 - 0
src/useful.py

@@ -0,0 +1,12 @@
+def parse_bytes(bytes_str):
+	conversions = {'B':1, 'KiB':1024, 'MiB':1024**2, 'GiB':1024**3}
+	#
+	matching_conversions = [x for x in conversions if bytes_str.endswith(x)]
+	if len(matching_conversions) > 0:
+		# if any conversion suffix matched
+		most_precise_match = max(matching_conversions, key=len)
+		number = int(bytes_str[:-len(most_precise_match)])
+		return number*conversions[most_precise_match]
+	#
+	return int(bytes_str)
+#