Pārlūkot izejas kodu

Rewriting files.

Steven Engler 5 gadi atpakaļ
vecāks
revīzija
242eaa8a5d

+ 7 - 6
Makefile

@@ -1,6 +1,6 @@
 CC=gcc
-CFLAGS=-O3 -std=c99 -D_BSD_SOURCE
-PYTHON_INC=/usr/include/python3.4
+CFLAGS=-O3 -std=c99 -D_DEFAULT_SOURCE
+PYTHON_INC=/usr/include/python3.6
 
 PY_BIN_FILES:=$(patsubst src/%.py,bin/%.py,$(wildcard src/*.py))
 PY_DEV_FILES:=$(patsubst src/%.py,dev/%.py,$(wildcard src/*.py))
@@ -9,7 +9,8 @@ all: bin_dir $(PY_BIN_FILES) bin/accelerated_functions.so
 dev: dev_dir $(PY_DEV_FILES) dev/accelerated_functions.so
 
 clean:
-	@rm -r bin
+	@if [ -d bin ]; then rm -r bin; fi
+	@if [ -d dev ]; then rm -r dev; fi
 
 bin/accelerated_functions.so: src/accelerated_functions.c
 dev/accelerated_functions.so: src/accelerated_functions.c
@@ -27,12 +28,12 @@ bin_dir:
 
 #######
 
+dev/%.so: src/%.c
+	$(CC) $(CFLAGS) -I $(PYTHON_INC) -shared -fPIC $^ -o $@
+
 dev/%.py: src/%.py
 	rm -f $@
 	ln $< $@
 
-dev/%.so: src/%.c
-	$(CC) $(CFLAGS) -I $(PYTHON_INC) -shared -fPIC $^ -o $@
-
 dev_dir:
 	@mkdir -p dev

+ 6 - 2
src/basic_protocols.py

@@ -490,12 +490,12 @@ class ReceiveDataProtocol(Protocol):
 	#
 #
 class ServerListener():
-	def __init__(self, endpoint, accept_callback):
+	def __init__(self, bind_endpoint, accept_callback):
 		self.callback = accept_callback
 		#
 		self.s = socket.socket()
 		self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-		self.s.bind(endpoint)
+		self.s.bind(bind_endpoint)
 		self.s.listen(0)
 	#
 	def accept(self):
@@ -504,6 +504,10 @@ class ServerListener():
 					  endpoint[0], endpoint[1], newsock.fileno())
 		self.callback(newsock)
 	#
+	def stop(self):
+		self.s.shutdown(socket.SHUT_RDWR)
+		# use 'shutdown' rather than 'close' since 'close' won't stop a blocking 'accept' call
+	#
 #
 class SimpleClientConnectionProtocol(Protocol):
 	def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None):

+ 182 - 0
src/chutney_manager.py

@@ -0,0 +1,182 @@
+#!/usr/bin/env python3
+#
+import subprocess
+import logging
+import os
+import sys
+#
+def start_chutney_network(chutney_path, tor_path, network_file, controlling_pid=None):
+	args = [os.path.join(chutney_path, 'tools/test-network.sh'), '--chutney-path', chutney_path,
+	        '--tor-path', tor_path, '--stop-time', '-1', '--network', network_file]
+	if controlling_pid is not None:
+		args.extend(['--controlling-pid', str(controlling_pid)])
+	#
+	try:
+		subprocess.check_output(args, stderr=subprocess.STDOUT)
+	except subprocess.CalledProcessError as e:
+		logging.error('Chutney error:\n' + e.output.decode(sys.stdout.encoding))
+		raise
+	#
+#
+def stop_chutney_network(chutney_path, network_file):
+	args = [os.path.join(chutney_path, 'chutney'), 'stop', network_file]
+	try:
+		subprocess.check_output(args, stderr=subprocess.STDOUT)
+	except subprocess.CalledProcessError as e:
+		logging.error('Chutney error:\n' + e.output.decode(sys.stdout.encoding))
+		raise
+	#
+#
+class ChutneyNetwork:
+	def __init__(self, chutney_path, tor_path, network_file, controlling_pid=None):
+		self.chutney_path = chutney_path
+		self.network_file = network_file
+		#
+		start_chutney_network(chutney_path, tor_path, network_file, controlling_pid=controlling_pid)
+	#
+	def stop(self):
+		stop_chutney_network(self.chutney_path, self.network_file)
+	#
+	def __enter__(self):
+		return self
+	#
+	def __exit__(self, exc_type, exc_val, exc_tb):
+		self.stop()
+	#
+#
+class Node:
+	def __init__(self, **kwargs):
+		self.options = kwargs
+	#
+	def guess_nickname(self, index):
+		"""
+		This guesses the nickname based on the format Chutney uses. There is
+		no good way to get the actual value.
+		"""
+		#
+		return '{:03}{}'.format(index, self.options['tag'])
+	#
+	def _value_formatter(self, value):
+		if type(value) == str:
+			return "'{}'".format(value)
+		#
+		return value
+	#
+	def __str__(self):
+		arg_value_pairs = ['{}={}'.format(x, self._value_formatter(self.options[x])) for x in self.options]
+		return 'Node({})'.format(', '.join(arg_value_pairs))
+	#
+#
+def create_compact_chutney_config(nodes):
+	if len(nodes) == 0:
+		return None
+	#
+	config = ''
+	for (name, count, options) in nodes:
+		config += '{} = {}\n'.format(name, str(options))
+	#
+	config += '\n'
+	config += 'NODES = {}\n'.format(' + '.join(['{}.getN({})'.format(name, count) for (name, count, options) in nodes]))
+	config += '\n'
+	config += 'ConfigureNodes(NODES)'
+	#
+	return config
+#
+def create_chutney_config(nodes):
+	if len(nodes) == 0:
+		return None
+	#
+	config = ''
+	config += 'NODES = [{}]\n'.format(', \n'.join([str(node) for node in nodes]))
+	config += '\n'
+	config += 'ConfigureNodes(NODES)'
+	#
+	return config
+#
+def read_fingerprint(nickname, chutney_path):
+	try:
+		with open(os.path.join(chutney_path, 'net', 'nodes', nickname, 'fingerprint'), 'r') as f:
+			return f.read().strip().split(' ')[1]
+		#
+	except IOError as e:
+		return None
+	#
+#
+def numa_scheduler(num_processors_needed, numa_nodes):
+	"""
+	Finds the numa node with the most physical cores remaining and
+	assigns physical cores (typically 2 virtual processors) until
+	the process has enough processors.
+	"""
+	#
+	chosen_processors = []
+	num_physical_cores = {x:len(numa_nodes[x]['physical_cores']) for x in numa_nodes}
+	node_with_most_physical_cores = max(num_physical_cores, key=lambda x: (num_physical_cores.get(x), -x))
+	while len(chosen_processors) < num_processors_needed:
+		chosen_processors.extend(numa_nodes[node_with_most_physical_cores]['physical_cores'][0])
+		# note: this may assign more processors than requested
+		numa_nodes[node_with_most_physical_cores]['physical_cores'] = numa_nodes[node_with_most_physical_cores]['physical_cores'][1:]
+	#
+	return (node_with_most_physical_cores, chosen_processors)
+#
+if __name__ == '__main__':
+	import time
+	import tempfile
+	import numa
+	#
+	logging.basicConfig(level=logging.DEBUG)
+	#
+	chutney_path = '/home/sengler/code/measureme/chutney'
+	tor_path = '/home/sengler/code/measureme/tor'
+	#
+	#nodes = [('authority', 2, Node(tag='a', relay=1, authority=1, torrc='authority.tmpl')),
+	#         ('other_relay', 14, Node(tag='r', relay=1, torrc='relay-non-exit.tmpl')),
+	#         ('exit_relay', 1, Node(tag='r', exit=1, torrc='relay.tmpl')),
+	#         ('client', 16, Node(tag='c', client=1, torrc='client.tmpl'))]
+	#nodes = [('authority', 2, Node(tag='a', relay=1, num_cpus=2, authority=1, torrc='authority.tmpl')),
+	#         ('other_relay', 2, Node(tag='r', relay=1, num_cpus=2, torrc='relay-non-exit.tmpl')),
+	#         ('exit_relay', 1, Node(tag='r', exit=1, num_cpus=2, torrc='relay.tmpl')),
+	#         ('client', 2, Node(tag='c', client=1, num_cpus=1, torrc='client.tmpl'))]
+	#
+	nodes = [Node(tag='a', relay=1, num_cpus=2, authority=1, torrc='authority.tmpl') for _ in range(2)] + \
+	        [Node(tag='r', relay=1, num_cpus=2, torrc='relay-non-exit.tmpl') for _ in range(2)] + \
+	        [Node(tag='e', exit=1, num_cpus=2, torrc='relay.tmpl') for _ in range(1)] + \
+	        [Node(tag='c', client=1, num_cpus=1, torrc='client.tmpl') for _ in range(2)]
+	#
+	numa_remaining = numa.get_numa_overview()
+	numa_sets = []
+	for node in nodes:
+		num_cpus = node.options['num_cpus']
+		if num_cpus%2 != 0:
+			num_cpus += 1
+		#
+		(numa_node, processors) = numa_scheduler(num_cpus, numa_remaining)
+		node.options['numa_settings'] = (numa_node, processors)
+		numa_sets.append((numa_node, processors))
+	#
+	print(numa_sets)
+	unused_processors = numa.generate_range_list([z for node in numa_remaining for y in numa_remaining[node]['physical_cores'] for z in y])
+	print(unused_processors)
+	#
+	nicknames = [nodes[x].guess_nickname(x) for x in range(len(nodes))]
+	print(nicknames)
+	#
+	(fd, tmp_network_file) = tempfile.mkstemp(prefix='chutney-network-')
+	try:
+		with os.fdopen(fd, mode='w') as f:
+			#f.write(create_compact_chutney_config(nodes))
+			f.write(create_chutney_config(nodes))
+		#
+		with ChutneyNetwork(chutney_path, tor_path, tmp_network_file) as net:
+			# do stuff here
+			fingerprints = []
+			for nick in nicknames:
+				fingerprints.append(read_fingerprint(nick, chutney_path))
+			#
+			print(fingerprints)
+			time.sleep(5)
+		#
+	finally:
+		os.remove(tmp_network_file)
+	#
+#

+ 2 - 0
src/cpu_graph.py

@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+#
 import sys
 import json
 import gzip

+ 2 - 0
src/experiment.py

@@ -0,0 +1,2 @@
+#!/usr/bin/python3
+#

+ 94 - 0
src/numa.py

@@ -0,0 +1,94 @@
+import os
+#
+def parse_range_list(range_list_str):
+	"""
+	Take an input like '1-3,5,7-10' and return a list like [1,2,3,5,7,8,9,10].
+	"""
+	#
+	range_strings = range_list_str.split(',')
+	all_items = []
+	for range_str in range_strings:
+		if '-' in range_str:
+			# in the form '12-34'
+			range_ends = [int(x) for x in range_str.split('-')]
+			assert(len(range_ends) == 2)
+			all_items.extend(range(range_ends[0], range_ends[1]+1))
+		else:
+			# just a number
+			all_items.append(int(range_str))
+		#
+	#
+	return all_items
+#
+def generate_range_list(l):
+	"""
+	Take a list like [1,2,3,5,7,8,9,10] and return a string like '1-3,5,7-10'.
+	"""
+	#
+	l = list(set(sorted(l)))
+	ranges = []
+	current_range_start = None
+	#
+	for index in range(len(l)):
+		if current_range_start is None:
+			current_range_start = l[index]
+		else:
+			if l[index] != l[index-1]+1:
+				ranges.append((current_range_start, l[index-1]))
+				current_range_start = l[index]
+			#
+		#
+	#
+	ranges.append((current_range_start, l[-1]))
+	#
+	return ','.join(['-'.join([str(y) for y in x]) if x[0] != x[1] else str(x[0]) for x in ranges])
+#
+def check_path_traversal(path, base_path):
+	# this is not guarenteed to be secure
+	if os.path.commonprefix([os.path.realpath(path), base_path]) != base_path:
+		raise Exception('The path \'{}\' is not in the base path \'{}.\''.format(os.path.realpath(path), base_path))
+	#
+#
+def get_online_numa_nodes():
+	with open('/sys/devices/system/node/online', 'r') as f:
+		online_nodes = f.read().strip()
+		return parse_range_list(online_nodes)
+	#
+#
+def get_processors_in_numa_node(node_id):
+	path = '/sys/devices/system/node/node{}/cpulist'.format(node_id)
+	check_path_traversal(path, '/sys/devices/system/node')
+	with open(path, 'r') as f:
+		return parse_range_list(f.read().strip())
+	#
+#
+def get_thread_siblings(processor_id):
+	path = '/sys/devices/system/cpu/cpu{}/topology/thread_siblings_list'.format(processor_id)
+	check_path_traversal(path, '/sys/devices/system/cpu')
+	with open(path, 'r') as f:
+		return parse_range_list(f.read().strip())
+	#
+#
+def get_numa_overview():
+	numa_nodes = {}
+	#
+	for node_id in get_online_numa_nodes():
+		numa_nodes[node_id] = {}
+	#
+	for node_id in numa_nodes:
+		processors = get_processors_in_numa_node(node_id)
+		#
+		numa_nodes[node_id]['physical_cores'] = []
+		for processor_id in processors:
+			thread_siblings = sorted(get_thread_siblings(processor_id))
+			#
+			if thread_siblings not in numa_nodes[node_id]['physical_cores']:
+				numa_nodes[node_id]['physical_cores'].append(thread_siblings)
+			#
+		#
+	#
+	return numa_nodes
+#
+if __name__ == '__main__':
+	print(get_numa_overview())
+#

+ 2 - 0
src/plot_streams.py

@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+#
 import gzip
 import pickle
 import matplotlib.pylab as plt

+ 10 - 4
src/throughput_client.py

@@ -7,6 +7,7 @@ import os
 import argparse
 import logging
 import socket
+import time
 #
 if __name__ == '__main__':
 	logging.basicConfig(level=logging.DEBUG)
@@ -60,12 +61,17 @@ if __name__ == '__main__':
 		proxy_protocol = basic_protocols.FakeProxyProtocol(client_socket, endpoint)
 		protocols.append(proxy_protocol)
 	#
-	group_id = int(args.wait*1000) if args.wait is not None else None
+	group_id_bytes = args.wait.to_bytes(8, byteorder='big') if args.wait is not None else b''
+	if args.wait is not None:
+		push_start_cb = lambda: time.sleep(args.wait-time.time())
+	else:
+		push_start_cb = None
+	#
 	throughput_protocol = throughput_protocols.ClientProtocol(client_socket, args.num_bytes,
-	                                             wait_until=args.wait,
-	                                             group_id=group_id,
+	                                             custom_data=group_id_bytes,
 	                                             send_buffer_len=args.buffer_len,
-	                                             use_acceleration=(not args.no_accel))
+	                                             use_acceleration=(not args.no_accel),
+	                                             push_start_cb=push_start_cb)
 	protocols.append(throughput_protocol)
 	#
 	combined_protocol = basic_protocols.ChainedProtocol(protocols)

+ 18 - 51
src/throughput_protocols.py

@@ -7,52 +7,27 @@ import time
 import socket
 #
 class ClientProtocol(basic_protocols.Protocol):
-	def __init__(self, socket, total_bytes, group_id=None, send_buffer_len=None, use_acceleration=None, custom_data=b'', push_start_cb=None, push_done_cb=None): #wait_until=None
+	def __init__(self, socket, total_bytes, send_buffer_len=None, use_acceleration=None, custom_data=b'', push_start_cb=None, push_done_cb=None):
 		self.socket = socket
 		self.total_bytes = total_bytes
-		#self.wait_until = wait_until
 		self.send_buffer_len = send_buffer_len
 		self.use_acceleration = use_acceleration
 		self.custom_data = custom_data
 		self.push_start_cb = push_start_cb
 		self.push_done_cb = push_done_cb
-		self.group_id = group_id if group_id is not None else 0
-		# a group id of 0 means no group
 		#
-		self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN SEND_GROUP_ID SEND_CUSTOM_DATA PUSH_DATA DONE') #WAIT
+		self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN SEND_CUSTOM_DATA PUSH_DATA DONE') #WAIT
 		self.state = self.states.READY_TO_BEGIN
 		#
 		self.sub_protocol = None
 	#
 	def _run_iteration(self):
 		if self.state is self.states.READY_TO_BEGIN:
-			group_id_bytes = self.group_id.to_bytes(8, byteorder='big', signed=False)
-			self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, group_id_bytes)
-			self.state = self.states.SEND_GROUP_ID
-		#
-		if self.state is self.states.SEND_GROUP_ID:
-			if self.sub_protocol.run():
-				self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, self.custom_data)
-				self.state = self.states.SEND_CUSTOM_DATA
-			#
+			self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, self.custom_data)
+			self.state = self.states.SEND_CUSTOM_DATA
 		#
 		if self.state is self.states.SEND_CUSTOM_DATA:
 			if self.sub_protocol.run():
-				#self.state = self.states.WAIT
-				self.sub_protocol = basic_protocols.PushDataProtocol(self.socket, self.total_bytes,
-				                                                     send_buffer_len=self.send_buffer_len,
-				                                                     use_acceleration=self.use_acceleration,
-				                                                     push_start_cb=self.push_start_cb,
-				                                                     push_done_cb=self.push_done_cb)
-				self.state = self.states.PUSH_DATA
-			#
-		#
-		'''
-		if self.state is self.states.WAIT:
-			if self.wait_until is not None:
-				time.sleep(self.wait_until-time.time())
-			#
-			if self.wait_until is None or time.time() >= self.wait_until:
 				self.sub_protocol = basic_protocols.PushDataProtocol(self.socket, self.total_bytes,
 				                                                     send_buffer_len=self.send_buffer_len,
 				                                                     use_acceleration=self.use_acceleration,
@@ -61,7 +36,6 @@ class ClientProtocol(basic_protocols.Protocol):
 				self.state = self.states.PUSH_DATA
 			#
 		#
-		'''
 		if self.state is self.states.PUSH_DATA:
 			if self.sub_protocol.run():
 				self.state = self.states.DONE
@@ -74,14 +48,12 @@ class ClientProtocol(basic_protocols.Protocol):
 	#
 #
 class ServerProtocol(basic_protocols.Protocol):
-	def __init__(self, socket, conn_id, group_id_callback=None, bandwidth_callback=None, use_acceleration=None):
+	def __init__(self, socket, results_callback=None, use_acceleration=None):
 		self.socket = socket
-		self.conn_id = conn_id
-		self.group_id_callback = group_id_callback
-		self.bandwidth_callback = bandwidth_callback
+		self.results_callback = results_callback
 		self.use_acceleration = use_acceleration
 		#
-		self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN RECV_GROUP_ID RECV_CUSTOM_DATA PULL_DATA DONE')
+		self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN RECV_CUSTOM_DATA PULL_DATA DONE')
 		self.state = self.states.READY_TO_BEGIN
 		#
 		self.sub_protocol = None
@@ -90,19 +62,7 @@ class ServerProtocol(basic_protocols.Protocol):
 	def _run_iteration(self):
 		if self.state is self.states.READY_TO_BEGIN:
 			self.sub_protocol = basic_protocols.ReceiveDataProtocol(self.socket)
-			self.state = self.states.RECV_GROUP_ID
-		#
-		if self.state is self.states.RECV_GROUP_ID:
-			if self.sub_protocol.run():
-				group_id = int.from_bytes(self.sub_protocol.received_data, byteorder='big', signed=False)
-				if group_id == 0:
-					# a group of 0 means no group
-					group_id = None
-				#
-				self.group_id_callback(self.conn_id, group_id)
-				self.sub_protocol = basic_protocols.ReceiveDataProtocol(self.socket)
-				self.state = self.states.RECV_CUSTOM_DATA
-			#
+			self.state = self.states.RECV_CUSTOM_DATA
 		#
 		if self.state is self.states.RECV_CUSTOM_DATA:
 			if self.sub_protocol.run():
@@ -114,9 +74,16 @@ class ServerProtocol(basic_protocols.Protocol):
 		#
 		if self.state is self.states.PULL_DATA:
 			if self.sub_protocol.run():
-				if self.bandwidth_callback:
-					#self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size, self.sub_protocol.time_of_first_byte, self.sub_protocol.time_of_last_byte, self.sub_protocol.calc_transfer_rate(), self.sub_protocol.byte_counter, self.sub_protocol.byte_counter_start_time)
-					self.bandwidth_callback(self.conn_id, self.custom_data, self.sub_protocol.data_size, self.sub_protocol.time_of_first_byte, self.sub_protocol.time_of_last_byte, self.sub_protocol.calc_transfer_rate(), self.sub_protocol.deltas)
+				if self.results_callback:
+					results = {}
+					results['custom_data'] = self.custom_data
+					results['data_size'] = self.sub_protocol.data_size
+					results['time_of_first_byte'] = self.sub_protocol.time_of_first_byte
+					results['time_of_last_byte'] = self.sub_protocol.time_of_last_byte
+					results['transfer_rate'] = self.sub_protocol.calc_transfer_rate()
+					results['deltas'] = self.sub_protocol.deltas
+					#
+					self.results_callback(results)
 				#
 				self.state = self.states.DONE
 			#

+ 174 - 0
src/throughput_server.new.py

@@ -0,0 +1,174 @@
+#!/usr/bin/python3
+#
+import throughput_protocols
+import basic_protocols
+import os
+import multiprocessing
+import threading
+import queue
+import logging
+import argparse
+#
+class QueueGetter:
+	def __init__(self, queue, unqueued):
+		self.queue = queue
+		self.unqueued = unqueued
+		#
+		self.t = threading.Thread(target=self._unqueue)
+		self.t.start()
+	#
+	def _unqueue(self):
+		while True:
+			val = self.queue.get()
+			if val is None:
+				break
+			#
+			self.unqueued.append(val)
+		#
+	#
+	def stop(self):
+		self.queue.put(None)
+	#
+	def join(self):
+		self.t.join()
+	#
+#
+class ThroughputServer:
+	def __init__(self, bind_endpoint, stop_event=None):
+		self.bind = bind_endpoint
+		self.stop_event = stop_event
+		#
+		self.server_listener = basic_protocols.ServerListener(bind_endpoint, self._accept_callback)
+		#
+		self.processes = []
+		self.process_counter = 0
+		#
+		self.results_queue = multiprocessing.Queue()
+		self.results = []
+		self.results_getter = QueueGetter(self.results_queue, self.results)
+		#
+		if self.stop_event is not None:
+			self.event_thread = threading.Thread(target=self._wait_for_event, args=(self.stop_event, self._stop))
+			self.event_thread.start()
+		else:
+			self.event_thread = None
+		#
+	#
+	def _accept_callback(self, socket):
+		conn_id = self.process_counter
+		self.process_counter += 1
+		#
+		p = multiprocessing.Process(target=self._start_server_conn, args=(socket, conn_id))
+		self.processes.append(p)
+		p.start()
+		#
+		# close this process' copy of the socket
+		socket.close()
+	#
+	def _start_server_conn(self, socket, conn_id):
+		results_callback = lambda results: self.results_queue.put({'conn_id':conn_id, 'results':results})
+		protocol = throughput_protocols.ServerProtocol(socket, results_callback=results_callback,
+		                                               use_acceleration=True)
+		try:
+			protocol.run()
+		finally:
+			socket.close()
+		#
+	#
+	def _wait_for_event(self, event, callback):
+		event.wait()
+		callback()
+	#
+	def _stop(self):
+		self.server_listener.stop()
+	#
+	def run(self):
+		try:
+			while True:
+				self.server_listener.accept()
+			#
+		except OSError as e:
+			if e.errno == 22 and self.stop_event is not None and self.stop_event.is_set():
+				# we closed the socket on purpose
+				pass
+			else:
+				raise
+			#
+		finally:
+			if self.stop_event is not None:
+				# set the event to stop the thread
+				self.stop_event.set()
+			#
+			if self.event_thread is not None:
+				# make sure the event thread is stopped
+				self.event_thread.join()
+			#
+			for p in self.processes:
+				# wait for all processes to finish
+				p.join()
+			#
+			# finish reading from the results queue
+			self.results_getter.stop()
+			self.results_getter.join()
+		#
+	#
+#
+if __name__ == '__main__':
+	logging.basicConfig(level=logging.DEBUG)
+	#
+	parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
+	parser.add_argument('port', type=int, help='listen on port')
+	parser.add_argument('--localhost', action='store_true', help='bind to 127.0.0.1 instead of 0.0.0.0')
+	args = parser.parse_args()
+	#
+	if args.localhost:
+		bind_to = ('127.0.0.1', args.port)
+	else:
+		bind_to = ('0.0.0.0', args.port)
+	#
+	stop_event = multiprocessing.Event()
+	server = ThroughputServer(bind_to, None)
+	try:
+		server.run()
+	except KeyboardInterrupt:
+		print('')
+		logging.debug('Server stopped (KeyboardInterrupt).')
+	#
+
+	#t = threading.Thread(target=server.run)
+	#t.start()
+	#try:
+	#	t.join()
+	#except KeyboardInterrupt:
+	#	print('')
+	#	stop_event.set()
+	#
+	
+	#p = multiprocessing.Process(target=server.run)
+	#p.start()
+	#try:
+	#	p.join()
+	#except KeyboardInterrupt:
+	#	print('')
+	#	stop_event.set()
+	#
+
+	results = server.results
+	#
+	group_ids = list(set([x['results']['custom_data'] for x in results]))
+	groups = [(g, [r['results'] for r in results if r['results']['custom_data'] == g]) for g in group_ids]
+	#
+	for (group_id, group) in groups:
+		avg_data_size = sum([x['data_size'] for x in group])/len(group)
+		avg_transfer_rate = sum([x['transfer_rate'] for x in group])/len(group)
+		time_of_first_byte = min([x['time_of_first_byte'] for x in group])
+		time_of_last_byte = max([x['time_of_last_byte'] for x in group])
+		total_transfer_rate = sum([x['data_size'] for x in group])/(time_of_last_byte-time_of_first_byte)
+		#
+		logging.info('Group id: %s', int.from_bytes(group_id, byteorder='big') if len(group_id)!=0 else None)
+		logging.info('  Group size: %d', len(group))
+		logging.info('  Avg Transferred (MiB): %.4f', avg_data_size/(1024**2))
+		logging.info('  Avg Transfer rate (MiB/s): %.4f', avg_transfer_rate/(1024**2))
+		logging.info('  Total Transfer rate (MiB/s): %.4f', total_transfer_rate/(1024**2))
+	#
+#

+ 1 - 1
src/useful.py

@@ -1,5 +1,5 @@
 def parse_bytes(bytes_str):
-	conversions = {'B':1, 'KiB':1024, 'MiB':1024**2, 'GiB':1024**3}
+	conversions = {'B':1, 'KiB':1024, 'MiB':1024**2, 'GiB':1024**3, 'TiB':1024**4}
 	#
 	matching_conversions = [x for x in conversions if bytes_str.endswith(x)]
 	if len(matching_conversions) > 0: