Преглед на файлове

Experiment code progress.

Steven Engler преди 5 години
родител
ревизия
89d85b03d0
променени са 9 файла, в които са добавени 883 реда и са изтрити 468 реда
  1. 14 6
      src/chutney_manager.py
  2. 368 0
      src/experiment.py
  3. 4 46
      src/numa.py
  4. 35 0
      src/temp_server.py
  5. 1 1
      src/throughput_protocols.py
  6. 0 174
      src/throughput_server.new.py
  7. 274 0
      src/throughput_server.old.py
  8. 118 241
      src/throughput_server.py
  9. 69 0
      src/useful.py

+ 14 - 6
src/chutney_manager.py

@@ -123,6 +123,7 @@ if __name__ == '__main__':
 	import time
 	import tempfile
 	import numa
+	import useful
 	#
 	logging.basicConfig(level=logging.DEBUG)
 	#
@@ -154,12 +155,12 @@ if __name__ == '__main__':
 		node.options['numa_settings'] = (numa_node, processors)
 		numa_sets.append((numa_node, processors))
 	#
-	print(numa_sets)
-	unused_processors = numa.generate_range_list([z for node in numa_remaining for y in numa_remaining[node]['physical_cores'] for z in y])
-	print(unused_processors)
+	print('Used processors: {}'.format(numa_sets))
+	unused_processors = useful.generate_range_list([z for node in numa_remaining for y in numa_remaining[node]['physical_cores'] for z in y])
+	print('Unused processors: {}'.format(unused_processors))
 	#
 	nicknames = [nodes[x].guess_nickname(x) for x in range(len(nodes))]
-	print(nicknames)
+	print('Nicknames: {}'.format(nicknames))
 	#
 	(fd, tmp_network_file) = tempfile.mkstemp(prefix='chutney-network-')
 	try:
@@ -173,8 +174,15 @@ if __name__ == '__main__':
 			for nick in nicknames:
 				fingerprints.append(read_fingerprint(nick, chutney_path))
 			#
-			print(fingerprints)
-			time.sleep(5)
+			print('Fingerprints: {}'.format(fingerprints))
+			print('Press Ctrl-C to stop.')
+			try:
+				while True:
+					time.sleep(60)
+				#
+			except KeyboardInterrupt:
+				print()
+			#
 		#
 	finally:
 		os.remove(tmp_network_file)

+ 368 - 0
src/experiment.py

@@ -1,2 +1,370 @@
 #!/usr/bin/python3
 #
+import stem.control
+import stem.descriptor.remote
+import stem.process
+import socket
+import logging
+import multiprocessing
+import queue
+import random
+import time
+import json
+import os
+#
+import basic_protocols
+import throughput_protocols
+import useful
+#
+def get_socks_port(control_port):
+	with stem.control.Controller.from_port(port=control_port) as controller:
+		controller.authenticate()
+		#
+		socks_addresses = controller.get_listeners(stem.control.Listener.SOCKS)
+		assert(len(socks_addresses) == 1)
+		assert(socks_addresses[0][0] == '127.0.0.1')
+		#
+		return socks_addresses[0][1]
+	#
+#
+def wait_then_sleep(event, duration):
+	event.wait()
+	time.sleep(duration)
+#
+def send_measureme(stem_controller, circuit_id, measureme_id, hop):
+	response = stem_controller.msg('SENDMEASUREME %s ID=%s HOP=%s' % (circuit_id, measureme_id, hop))
+	stem.response.convert('SINGLELINE', response)
+	#
+	if not response.is_ok():
+		if response.code in ('512', '552'):
+			if response.message.startswith('Unknown circuit '):
+				raise stem.InvalidArguments(response.code, response.message, [circuit_id])
+			#
+			raise stem.InvalidRequest(response.code, response.message)
+		else:
+			raise stem.ProtocolError('MEASUREME returned unexpected response code: %s' % response.code)
+		#
+	#
+#
+def send_measureme_cells(control_address, circuit_id, measureme_id, hops):
+	logging.debug('Sending measuremes to control address {}, then sleeping'.format(control_address))
+	with stem.control.Controller.from_port(address=control_address[0], port=control_address[1]) as controller:
+		controller.authenticate()
+		for hop in hops:
+			send_measureme(controller, circuit_id, measureme_id, hop)
+		#
+	#
+#
+def send_measureme_cells_and_wait(control_port, circuit_id, measureme_id, hops, wait_event, wait_offset):
+	send_measureme_cells(control_port, circuit_id, measureme_id, hops)
+	wait_then_sleep(wait_event, wait_offset)
+#
+def get_fingerprints(consensus):
+	"""
+	Get the fingerprints of all relays.
+	"""
+	#
+	return [desc.fingerprint for desc in consensus]
+#
+def get_exit_fingerprints(consensus, endpoint):
+	"""
+	Get the fingerprints of relays that can exit to the endpoint.
+	"""
+	#
+	return [desc.fingerprint for desc in consensus if desc.exit_policy.can_exit_to(*endpoint)]
+#
+class ExperimentController:
+	def __init__(self, control_address):
+		self.control_address = control_address
+		self.connection = None
+		self.circuits = {}
+		self.unassigned_circuit_ids = []
+		self.assigned_streams = {}
+	#
+	def connect(self):
+		self.connection = stem.control.Controller.from_port(address=self.control_address[0], port=self.control_address[1])
+		self.connection.authenticate()
+		#
+		self.connection.add_event_listener(self._attach_stream, stem.control.EventType.STREAM)
+		self.connection.set_conf('__LeaveStreamsUnattached', '1')
+	#
+	def disconnect(self):
+		#if len(self.unused_circuit_ids) > 0:
+		#	logging.warning('Closed stem controller before all circuits were used')
+		#
+		self.connection.close()
+	#
+	def assign_stream(self, from_address):
+		circuit_id = self.unassigned_circuit_ids.pop(0)
+		self.assigned_streams[from_address] = circuit_id
+		return circuit_id
+	#
+	def _attach_stream(self, stream):
+		try:
+			if stream.status == 'NEW':
+				# by default, let tor handle new streams
+				circuit_id = 0
+				#
+				if stream.purpose == 'USER':
+					# this is probably one of our streams (although not guaranteed)
+					circuit_id = self.assigned_streams[(stream.source_address, stream.source_port)]
+				#
+				try:
+					self.connection.attach_stream(stream.id, circuit_id)
+					#logging.debug('Attaching to circuit {}'.format(circuit_id))
+				except stem.UnsatisfiableRequest:
+					if stream.purpose != 'USER':
+						# could not attach a non-user stream, so probably raised:
+						# stem.UnsatisfiableRequest: Connection is not managed by controller.
+						# therefore we should ignore this exception
+						pass
+					else:
+						raise
+					#
+				#
+			#
+		except:
+			logging.exception('Error while attaching the stream (control_port={}, circuit_id={}).'.format(self.control_port, circuit_id))
+			raise
+		#
+	#
+	def build_circuit(self, circuit_generator):
+		circuit_id = None
+		#
+		while circuit_id is None:
+			try:
+				circuit = circuit_generator()
+				circuit_id = self.connection.new_circuit(circuit, await_build=True)
+				logging.debug('New circuit (id={}): {}'.format(circuit_id, circuit))
+			except stem.CircuitExtensionFailed:
+				logging.debug('Failed circuit: {}'.format(circuit))
+				logging.warning('Circuit creation failed. Retrying...')
+			#
+		#
+		self.unassigned_circuit_ids.append(circuit_id)
+		self.circuits[circuit_id] = circuit
+	#
+#
+class ExperimentProtocol(basic_protocols.ChainedProtocol):
+	def __init__(self, socket, endpoint, num_bytes, custom_data=None, send_buffer_len=None, push_start_cb=None):
+		proxy_username = bytes([z for z in os.urandom(12) if z != 0])
+		proxy_protocol = basic_protocols.Socks4Protocol(socket, endpoint, username=proxy_username)
+		#
+		throughput_protocol = throughput_protocols.ClientProtocol(socket, num_bytes,
+		                                                          custom_data=custom_data,
+		                                                          send_buffer_len=send_buffer_len,
+		                                                          use_acceleration=True,
+		                                                          push_start_cb=push_start_cb)
+		#
+		super().__init__([proxy_protocol, throughput_protocol])
+	#
+#
+class ExperimentProtocolManager():
+	def __init__(self):
+		self.stopped = False
+		self.process_counter = 0
+		self.used_ids = []
+		self.running_processes = {}
+		self.global_finished_process_queue = multiprocessing.Queue()
+		self.local_finished_process_queue = queue.Queue()
+		self.queue_getter = useful.QueueGetter(self.global_finished_process_queue,
+		                                       self.local_finished_process_queue.put)
+	#
+	def _run_client(self, protocol, protocol_id):
+		had_error = False
+		try:
+			logging.debug('Starting protocol (id: {})'.format(protocol_id))
+			protocol.run()
+			logging.debug('Done protocol (id: {})'.format(protocol_id))
+		except:
+			had_error = True
+			logging.warning('Protocol error')
+			logging.exception('Protocol id: {} had an error'.format(protocol_id))
+		finally:
+			self.global_finished_process_queue.put((protocol_id, had_error))
+		#
+	#
+	def start_experiment_protocol(self, protocol, protocol_id=None):
+		if protocol_id is None:
+			protocol_id = self.process_counter
+		#
+		assert not self.stopped
+		assert protocol_id not in self.used_ids, 'Protocol ID already used'
+		#
+		p = multiprocessing.Process(target=self._run_client, args=(protocol, protocol_id))
+		self.running_processes[protocol_id] = p
+		self.used_ids.append(protocol_id)
+		#
+		p.start()
+		self.process_counter += 1
+		#
+		#protocol.socket.close()
+	#
+	def wait(self, finished_protocol_cb=None):
+		while len(self.running_processes) > 0:
+			logging.debug('Waiting for processes ({} left)'.format(len(self.running_processes)))
+			#
+			(protocol_id, had_error) = self.local_finished_process_queue.get()
+			p = self.running_processes[protocol_id]
+			p.join()
+			self.running_processes.pop(protocol_id)
+			finished_protocol_cb(protocol_id, had_error)
+		#
+	#
+	def stop(self):
+		self.wait()
+		self.queue_getter.stop()
+		self.queue_getter.join()
+		self.stopped = True
+	#
+#
+def build_client_protocol(endpoint, socks_address, control_address, controller, wait_duration=0, measureme_id=None, num_bytes=None, buffer_len=None):
+	client_socket = socket.socket()
+	#
+	logging.debug('Socket %d connecting to proxy %r...', client_socket.fileno(), socks_address)
+	client_socket.connect(socks_address)
+	logging.debug('Socket %d connected', client_socket.fileno())
+	#
+	custom_data = {}
+	#
+	circuit_id = controller.assign_stream(client_socket.getsockname())
+	custom_data['circuit'] = (circuit_id, controller.circuits[circuit_id])
+	#
+	if measureme_id is not None:
+		custom_data['measureme_id'] = measureme_id
+		#
+		hops = list(range(len(controller.circuits[circuit_id])+1))[::-1]
+		# send the measureme cells to the last relay first
+		start_cb = lambda control_address=control_address, circuit_id=circuit_id, measureme_id=measureme_id, \
+						  hops=hops, event=start_event, wait_duration=wait_duration: \
+						  send_measureme_cells_and_wait(control_address, circuit_id, measureme_id, hops, event, wait_duration)
+	else:
+		start_cb = lambda event=start_event, duration=wait_duration: wait_then_sleep(event, duration)
+	#
+	custom_data = json.dumps(custom_data).encode('utf-8')
+	protocol = ExperimentProtocol(client_socket, endpoint, args.num_bytes,
+								  custom_data=custom_data,
+								  send_buffer_len=args.buffer_len,
+								  push_start_cb=start_cb)
+	return protocol
+#
+if __name__ == '__main__':
+	import argparse
+	#
+	logging.basicConfig(level=logging.DEBUG)
+	logging.getLogger('stem').setLevel(logging.WARNING)
+	#
+	parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
+	parser.add_argument('ip', type=str, help='destination ip address')
+	parser.add_argument('port', type=int, help='destination port')
+	parser.add_argument('num_bytes', type=useful.parse_bytes,
+	                    help='number of bytes to send per connection (can also end with \'B\', \'KiB\', \'MiB\', or \'GiB\')', metavar='num-bytes')
+	parser.add_argument('num_streams_per_client', type=int, help='number of streams per Tor client', metavar='num-streams-per-client')
+	parser.add_argument('--buffer-len', type=useful.parse_bytes,
+	                    help='size of the send and receive buffers (can also end with \'B\', \'KiB\', \'MiB\', or \'GiB\')', metavar='bytes')
+	parser.add_argument('--wait-range', type=int, default=0,
+	                    help='add a random wait time to each connection so that they don\'t all start at the same time (default is 0)', metavar='time')
+	parser.add_argument('--proxy-control-ports', type=useful.parse_range_list, help='range of ports for the control ports', metavar='control-ports')
+	parser.add_argument('--measureme', action='store_true', help='send measureme cells to the exit')
+	args = parser.parse_args()
+	#
+	endpoint = (args.ip, args.port)
+	#
+	logging.debug('Getting consensus')
+	try:
+		consensus = stem.descriptor.remote.get_consensus(endpoints=(stem.DirPort('127.0.0.1', 7000),))
+	except Exception as e:
+		raise Exception('Unable to retrieve the consensus') from e
+	#
+	fingerprints = get_fingerprints(consensus)
+	exit_fingerprints = get_exit_fingerprints(consensus, endpoint)
+	non_exit_fingerprints = list(set(fingerprints)-set(exit_fingerprints))
+	#
+	assert len(exit_fingerprints) == 1, 'Need exactly one exit relay'
+	assert len(non_exit_fingerprints) >= 1, 'Need at least one non-exit relay'
+	#
+	circuit_generator = lambda: [random.choice(non_exit_fingerprints), exit_fingerprints[0]]
+	#
+	proxy_addresses = []
+	for control_port in args.proxy_control_ports:
+		proxy = {}
+		proxy['control'] = ('127.0.0.1', control_port)
+		proxy['socks'] = ('127.0.0.1', get_socks_port(control_port))
+		proxy_addresses.append(proxy)
+	#
+	controllers = []
+	protocol_manager = ExperimentProtocolManager()
+	#
+	try:
+		for proxy_address in proxy_addresses:
+			controller = ExperimentController(proxy_address['control'])
+			controller.connect()
+			# the controller has to attach new streams to circuits, so the
+			# connection has to stay open until we're done creating streams
+			#
+			for _ in range(args.num_streams_per_client):
+				# make a circuit for each stream
+				controller.build_circuit(circuit_generator)
+				time.sleep(0.5)
+			#
+			controllers.append(controller)
+		#
+		start_event = multiprocessing.Event()
+		#
+		for stream_index in range(args.num_streams_per_client):
+			for (controller_index, proxy_address, controller) in zip(range(len(controllers)), proxy_addresses, controllers):
+				'''
+				client_socket = socket.socket()
+				#
+				logging.debug('Socket %d connecting to proxy %r...', client_socket.fileno(), proxy_address['socks'])
+				client_socket.connect(proxy_address['socks'])
+				logging.debug('Socket %d connected', client_socket.fileno())
+				#
+				wait_offset = random.randint(0, args.wait_range)
+				custom_data = {}
+				#
+				circuit_id = controller.assign_stream(client_socket.getsockname())
+				custom_data['circuit'] = (circuit_id, controller.circuits[circuit_id])
+				#
+				if args.measureme:
+					measureme_id = stream_index*args.num_streams_per_client + controllers.index(controller) + 1
+					custom_data['measureme_id'] = measureme_id
+					#
+					hops = list(range(len(controller.circuits[circuit_id])+1))[::-1]
+					# send the measureme cells to the last relay first
+					start_cb = lambda control_address=proxy_address['control'], circuit_id=circuit_id, measureme_id=measureme_id, \
+				                      hops=hops, event=start_event, wait_offset=wait_offset: \
+				                      send_measureme_cells_and_wait(control_address, circuit_id, measureme_id, hops, event, wait_offset)
+				else:
+					start_cb = lambda event=start_event, duration=wait_offset: wait_then_sleep(event, duration)
+				#
+				custom_data = json.dumps(custom_data).encode('utf-8')
+				protocol = ExperimentProtocol(client_socket, endpoint, args.num_bytes,
+				                              custom_data=custom_data,
+				                              send_buffer_len=args.buffer_len,
+				                              push_start_cb=start_cb)
+				#
+				'''
+				if args.measureme:
+					measureme_id = stream_index*args.num_streams_per_client + controller_index + 1
+				else:
+					measureme_id = None
+				#
+				wait_duration = random.randint(0, args.wait_range)
+				protocol = build_client_protocol(endpoint, proxy_address['socks'], proxy_address['control'], controller,
+				                                 wait_duration=wait_duration, measureme_id=measureme_id,
+				                                 num_bytes=args.num_bytes, buffer_len=args.buffer_len)
+				protocol_manager.start_experiment_protocol(protocol, protocol_id=None)
+			#
+		#
+		time.sleep(2)
+		start_event.set()
+		#
+		protocol_manager.wait(finished_protocol_cb=lambda protocol_id,had_error: logging.info('Finished {} (had_error={})'.format(protocol_id,had_error)))
+	finally:
+		for controller in controllers:
+			controller.disconnect()
+		#
+		protocol_manager.stop()
+	#
+#

+ 4 - 46
src/numa.py

@@ -1,47 +1,5 @@
 import os
-#
-def parse_range_list(range_list_str):
-	"""
-	Take an input like '1-3,5,7-10' and return a list like [1,2,3,5,7,8,9,10].
-	"""
-	#
-	range_strings = range_list_str.split(',')
-	all_items = []
-	for range_str in range_strings:
-		if '-' in range_str:
-			# in the form '12-34'
-			range_ends = [int(x) for x in range_str.split('-')]
-			assert(len(range_ends) == 2)
-			all_items.extend(range(range_ends[0], range_ends[1]+1))
-		else:
-			# just a number
-			all_items.append(int(range_str))
-		#
-	#
-	return all_items
-#
-def generate_range_list(l):
-	"""
-	Take a list like [1,2,3,5,7,8,9,10] and return a string like '1-3,5,7-10'.
-	"""
-	#
-	l = list(set(sorted(l)))
-	ranges = []
-	current_range_start = None
-	#
-	for index in range(len(l)):
-		if current_range_start is None:
-			current_range_start = l[index]
-		else:
-			if l[index] != l[index-1]+1:
-				ranges.append((current_range_start, l[index-1]))
-				current_range_start = l[index]
-			#
-		#
-	#
-	ranges.append((current_range_start, l[-1]))
-	#
-	return ','.join(['-'.join([str(y) for y in x]) if x[0] != x[1] else str(x[0]) for x in ranges])
+import useful
 #
 def check_path_traversal(path, base_path):
 	# this is not guarenteed to be secure
@@ -52,21 +10,21 @@ def check_path_traversal(path, base_path):
 def get_online_numa_nodes():
 	with open('/sys/devices/system/node/online', 'r') as f:
 		online_nodes = f.read().strip()
-		return parse_range_list(online_nodes)
+		return useful.parse_range_list(online_nodes)
 	#
 #
 def get_processors_in_numa_node(node_id):
 	path = '/sys/devices/system/node/node{}/cpulist'.format(node_id)
 	check_path_traversal(path, '/sys/devices/system/node')
 	with open(path, 'r') as f:
-		return parse_range_list(f.read().strip())
+		return useful.parse_range_list(f.read().strip())
 	#
 #
 def get_thread_siblings(processor_id):
 	path = '/sys/devices/system/cpu/cpu{}/topology/thread_siblings_list'.format(processor_id)
 	check_path_traversal(path, '/sys/devices/system/cpu')
 	with open(path, 'r') as f:
-		return parse_range_list(f.read().strip())
+		return useful.parse_range_list(f.read().strip())
 	#
 #
 def get_numa_overview():

+ 35 - 0
src/temp_server.py

@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
+#
+import logging
+import argparse
+import multiprocessing
+#
+import throughput_server
+#
+if __name__ == '__main__':
+	logging.basicConfig(level=logging.DEBUG)
+	#
+	parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
+	parser.add_argument('port', type=int, help='listen on port')
+	parser.add_argument('--localhost', action='store_true', help='bind to 127.0.0.1 instead of 0.0.0.0')
+	args = parser.parse_args()
+	#
+	if args.localhost:
+		bind_to = ('127.0.0.1', args.port)
+	else:
+		bind_to = ('0.0.0.0', args.port)
+	#
+	stop_event = multiprocessing.Event()
+	server = throughput_server.ThroughputServer(bind_to, None)
+	try:
+		server.run()
+	except KeyboardInterrupt:
+		print('')
+		logging.debug('Server stopped (KeyboardInterrupt).')
+	#
+	results = server.results
+	#
+	for x in results:
+		logging.info('{}'.format(x['results']['custom_data']))
+	#
+#

+ 1 - 1
src/throughput_protocols.py

@@ -16,7 +16,7 @@ class ClientProtocol(basic_protocols.Protocol):
 		self.push_start_cb = push_start_cb
 		self.push_done_cb = push_done_cb
 		#
-		self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN SEND_CUSTOM_DATA PUSH_DATA DONE') #WAIT
+		self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN SEND_CUSTOM_DATA PUSH_DATA DONE')
 		self.state = self.states.READY_TO_BEGIN
 		#
 		self.sub_protocol = None

+ 0 - 174
src/throughput_server.new.py

@@ -1,174 +0,0 @@
-#!/usr/bin/python3
-#
-import throughput_protocols
-import basic_protocols
-import os
-import multiprocessing
-import threading
-import queue
-import logging
-import argparse
-#
-class QueueGetter:
-	def __init__(self, queue, unqueued):
-		self.queue = queue
-		self.unqueued = unqueued
-		#
-		self.t = threading.Thread(target=self._unqueue)
-		self.t.start()
-	#
-	def _unqueue(self):
-		while True:
-			val = self.queue.get()
-			if val is None:
-				break
-			#
-			self.unqueued.append(val)
-		#
-	#
-	def stop(self):
-		self.queue.put(None)
-	#
-	def join(self):
-		self.t.join()
-	#
-#
-class ThroughputServer:
-	def __init__(self, bind_endpoint, stop_event=None):
-		self.bind = bind_endpoint
-		self.stop_event = stop_event
-		#
-		self.server_listener = basic_protocols.ServerListener(bind_endpoint, self._accept_callback)
-		#
-		self.processes = []
-		self.process_counter = 0
-		#
-		self.results_queue = multiprocessing.Queue()
-		self.results = []
-		self.results_getter = QueueGetter(self.results_queue, self.results)
-		#
-		if self.stop_event is not None:
-			self.event_thread = threading.Thread(target=self._wait_for_event, args=(self.stop_event, self._stop))
-			self.event_thread.start()
-		else:
-			self.event_thread = None
-		#
-	#
-	def _accept_callback(self, socket):
-		conn_id = self.process_counter
-		self.process_counter += 1
-		#
-		p = multiprocessing.Process(target=self._start_server_conn, args=(socket, conn_id))
-		self.processes.append(p)
-		p.start()
-		#
-		# close this process' copy of the socket
-		socket.close()
-	#
-	def _start_server_conn(self, socket, conn_id):
-		results_callback = lambda results: self.results_queue.put({'conn_id':conn_id, 'results':results})
-		protocol = throughput_protocols.ServerProtocol(socket, results_callback=results_callback,
-		                                               use_acceleration=True)
-		try:
-			protocol.run()
-		finally:
-			socket.close()
-		#
-	#
-	def _wait_for_event(self, event, callback):
-		event.wait()
-		callback()
-	#
-	def _stop(self):
-		self.server_listener.stop()
-	#
-	def run(self):
-		try:
-			while True:
-				self.server_listener.accept()
-			#
-		except OSError as e:
-			if e.errno == 22 and self.stop_event is not None and self.stop_event.is_set():
-				# we closed the socket on purpose
-				pass
-			else:
-				raise
-			#
-		finally:
-			if self.stop_event is not None:
-				# set the event to stop the thread
-				self.stop_event.set()
-			#
-			if self.event_thread is not None:
-				# make sure the event thread is stopped
-				self.event_thread.join()
-			#
-			for p in self.processes:
-				# wait for all processes to finish
-				p.join()
-			#
-			# finish reading from the results queue
-			self.results_getter.stop()
-			self.results_getter.join()
-		#
-	#
-#
-if __name__ == '__main__':
-	logging.basicConfig(level=logging.DEBUG)
-	#
-	parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
-	parser.add_argument('port', type=int, help='listen on port')
-	parser.add_argument('--localhost', action='store_true', help='bind to 127.0.0.1 instead of 0.0.0.0')
-	args = parser.parse_args()
-	#
-	if args.localhost:
-		bind_to = ('127.0.0.1', args.port)
-	else:
-		bind_to = ('0.0.0.0', args.port)
-	#
-	stop_event = multiprocessing.Event()
-	server = ThroughputServer(bind_to, None)
-	try:
-		server.run()
-	except KeyboardInterrupt:
-		print('')
-		logging.debug('Server stopped (KeyboardInterrupt).')
-	#
-
-	#t = threading.Thread(target=server.run)
-	#t.start()
-	#try:
-	#	t.join()
-	#except KeyboardInterrupt:
-	#	print('')
-	#	stop_event.set()
-	#
-	
-	#p = multiprocessing.Process(target=server.run)
-	#p.start()
-	#try:
-	#	p.join()
-	#except KeyboardInterrupt:
-	#	print('')
-	#	stop_event.set()
-	#
-
-	results = server.results
-	#
-	group_ids = list(set([x['results']['custom_data'] for x in results]))
-	groups = [(g, [r['results'] for r in results if r['results']['custom_data'] == g]) for g in group_ids]
-	#
-	for (group_id, group) in groups:
-		avg_data_size = sum([x['data_size'] for x in group])/len(group)
-		avg_transfer_rate = sum([x['transfer_rate'] for x in group])/len(group)
-		time_of_first_byte = min([x['time_of_first_byte'] for x in group])
-		time_of_last_byte = max([x['time_of_last_byte'] for x in group])
-		total_transfer_rate = sum([x['data_size'] for x in group])/(time_of_last_byte-time_of_first_byte)
-		#
-		logging.info('Group id: %s', int.from_bytes(group_id, byteorder='big') if len(group_id)!=0 else None)
-		logging.info('  Group size: %d', len(group))
-		logging.info('  Avg Transferred (MiB): %.4f', avg_data_size/(1024**2))
-		logging.info('  Avg Transfer rate (MiB/s): %.4f', avg_transfer_rate/(1024**2))
-		logging.info('  Total Transfer rate (MiB/s): %.4f', total_transfer_rate/(1024**2))
-	#
-#

+ 274 - 0
src/throughput_server.old.py

@@ -0,0 +1,274 @@
+#!/usr/bin/python3
+#
+import throughput_protocols
+import basic_protocols
+import os
+import multiprocessing
+import threading
+import queue
+import logging
+import argparse
+#
+def overlap_byte_counters(byte_counters):
+	start_time = None
+	finish_time = None
+	for x in byte_counters:
+		if start_time is None or x['start_time'] < start_time:
+			start_time = x['start_time']
+		#
+		if finish_time is None or x['start_time']+len(x['history']) > finish_time:
+			finish_time = x['start_time']+len(x['history'])
+		#
+	#
+	total_history = [0]*(finish_time-start_time)
+	#
+	for x in byte_counters:
+		for y in range(len(x['history'])):
+			total_history[(x['start_time']-start_time)+y] += x['history'][y]
+		#
+	#
+	return total_history
+#
+if __name__ == '__main__':
+	logging.basicConfig(level=logging.DEBUG)
+	#
+	parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
+	parser.add_argument('port', type=int, help='listen on port')
+	parser.add_argument('--no-accel', action='store_true', help='don\'t use C acceleration (use pure Python)')
+	parser.add_argument('--localhost', action='store_true', help='bind to 127.0.0.1 instead of 0.0.0.0')
+	args = parser.parse_args()
+	#
+	if args.localhost:
+		endpoint = ('127.0.0.1', args.port)
+	else:
+		endpoint = ('0.0.0.0', args.port)
+	#
+	processes = []
+	processes_map = {}
+	joinable_connections = multiprocessing.Queue()
+	joinable_connections_list = []
+	conn_counter = [0]
+	group_queue = multiprocessing.Queue()
+	group_queue_list = []
+	bw_queue = multiprocessing.Queue()
+	bw_queue_list = []
+	#
+	def group_id_callback(conn_id, group_id):
+		# put them in a queue to display later
+		#logging.debug('For conn %d Received group id: %d', conn_id, group_id)
+		group_queue.put({'conn_id':conn_id, 'group_id':group_id})
+	#
+	#def bw_callback(conn_id, data_size, time_first_byte, time_last_byte, transfer_rate, byte_counter, byte_counter_start_time):
+	def bw_callback(conn_id, custom_data, data_size, time_first_byte, time_last_byte, transfer_rate, deltas):
+		# put them in a queue to display later
+		#bw_queue.put({'conn_id':conn_id, 'data_size':data_size, 'time_of_first_byte':time_first_byte, 'time_of_last_byte':time_last_byte, 'transfer_rate':transfer_rate, 'byte_counter':byte_counter, 'byte_counter_start_time':byte_counter_start_time})
+		bw_queue.put({'conn_id':conn_id, 'custom_data':custom_data, 'data_size':data_size, 'time_of_first_byte':time_first_byte, 'time_of_last_byte':time_last_byte, 'transfer_rate':transfer_rate, 'deltas':deltas})
+	#
+	def start_server_conn(socket, conn_id):
+		server = throughput_protocols.ServerProtocol(socket, conn_id, group_id_callback=group_id_callback,
+		                                             bandwidth_callback=bw_callback, use_acceleration=(not args.no_accel))
+		try:
+			server.run()
+		except KeyboardInterrupt:
+			socket.close()
+		finally:
+			joinable_connections.put(conn_id)
+			'''
+			while True:
+				# while we're waiting to join, we might get a KeyboardInterrupt,
+				# in which case we cannot let the process end since it will kill
+				# the queue threads, which may be waiting to push data to the pipe
+				try:
+					joinable_connections.close()
+					group_queue.close()
+					bw_queue.close()
+					#
+					group_queue.join_thread()
+					bw_queue.join_thread()
+					joinable_connections.join_thread()
+					#
+					break
+				except KeyboardInterrupt:
+					pass
+				#
+			#
+			'''
+		#
+	#
+	def accept_callback(socket):
+		conn_id = conn_counter[0]
+		conn_counter[0] += 1
+		#logging.debug('Adding connection %d', conn_id)
+		p = multiprocessing.Process(target=start_server_conn, args=(socket, conn_id))
+		processes.append(p)
+		processes_map[conn_id] = p
+		p.start()
+		socket.close()
+		# close this process' copy of the socket
+	#
+	def unqueue(q, l, print_len=False):
+		while True:
+			val = q.get()
+			if val is None:
+				break
+			#
+			l.append(val)
+			if print_len:
+				print('Queue length: {}'.format(len(l)), end='\r')
+			#
+		#
+	#
+	l = basic_protocols.ServerListener(endpoint, accept_callback)
+	#
+	t_joinable_connections = threading.Thread(target=unqueue, args=(joinable_connections, joinable_connections_list))
+	t_group_queue = threading.Thread(target=unqueue, args=(group_queue, group_queue_list))
+	t_bw_queue = threading.Thread(target=unqueue, args=(bw_queue, bw_queue_list, True))
+	#
+	t_joinable_connections.start()
+	t_group_queue.start()
+	t_bw_queue.start()
+	#
+	try:
+		while True:
+			l.accept()
+			'''
+			try:
+				while True:
+					conn_id = joinable_connections.get(False)
+					p = processes_map[conn_id]
+					p.join()
+				#
+			except queue.Empty:
+				pass
+			#
+			'''
+		#
+	except KeyboardInterrupt:
+		print()
+		#
+		try:
+			for p in processes:
+				p.join()
+			#
+		except KeyboardInterrupt:
+			pass
+		#
+		joinable_connections.put(None)
+		group_queue.put(None)
+		bw_queue.put(None)
+		t_joinable_connections.join()
+		t_group_queue.join()
+		t_bw_queue.join()
+		#
+		bw_values = {}
+		group_values = {}
+		#
+		'''
+		logging.info('BW queue length: {}'.format(bw_queue.qsize()))
+		logging.info('Group queue length: {}'.format(group_queue.qsize()))
+		#
+		temp_counter = 0
+		try:
+			while True:
+				bw_val = bw_queue.get(False)
+				bw_values[bw_val['conn_id']] = bw_val
+				temp_counter += 1
+			#
+		except queue.Empty:
+			pass
+		#
+		logging.info('temp counter: {}'.format(temp_counter))
+		import time
+		time.sleep(2)
+		try:
+			while True:
+				bw_val = bw_queue.get(False)
+				bw_values[bw_val['conn_id']] = bw_val
+				temp_counter += 1
+			#
+		except queue.Empty:
+			pass
+		#
+		logging.info('temp counter: {}'.format(temp_counter))
+		
+		#
+		try:
+			while True:
+				group_val = group_queue.get(False)
+				group_values[group_val['conn_id']] = group_val
+			#
+		except queue.Empty:
+			pass
+		#
+		logging.info('bw_values length: {}'.format(len(bw_values)))
+		logging.info('group_values length: {}'.format(len(group_values)))
+		logging.info('group_values set: {}'.format(list(set([x['group_id'] for x in group_values.values()]))))
+		#
+		'''
+		#
+		#logging.info('BW list length: {}'.format(len(bw_queue_list)))
+		#logging.info('Group list length: {}'.format(len(group_queue_list)))
+		#
+		for x in bw_queue_list:
+			bw_values[x['conn_id']] = x
+		#
+		for x in group_queue_list:
+			group_values[x['conn_id']] = x
+		#
+		group_set = set([x['group_id'] for x in group_values.values()])
+		for group in group_set:
+			# doesn't handle group == None
+			conns_in_group = [x[0] for x in group_values.items() if x[1]['group_id'] == group]
+			in_group = [x for x in bw_values.values() if x['conn_id'] in conns_in_group]
+			if len(in_group) > 0:
+				avg_data_size = sum([x['data_size'] for x in in_group])/len(in_group)
+				avg_transfer_rate = sum([x['transfer_rate'] for x in in_group])/len(in_group)
+				total_transfer_rate = sum([x['data_size'] for x in in_group])/(max([x['time_of_last_byte'] for x in in_group])-min([x['time_of_first_byte'] for x in in_group]))
+				#
+				logging.info('Group size: %d', len(in_group))
+				logging.info('Avg Transferred (MiB): %.4f', avg_data_size/(1024**2))
+				logging.info('Avg Transfer rate (MiB/s): %.4f', avg_transfer_rate/(1024**2))
+				logging.info('Total Transfer rate (MiB/s): %.4f', total_transfer_rate/(1024**2))
+				#
+				'''
+				import math
+				histories = [{'start_time':x['byte_counter_start_time'], 'history':x['byte_counter']} for x in in_group]
+				total_history = overlap_byte_counters(histories)
+				#
+				logging.info('Max Transfer rate (MiB/s): %.4f', max(total_history)/(1024**2))
+				if sum(total_history) != sum([x['data_size'] for x in in_group]):
+					logging.warning('History doesn\'t add up ({} != {}).'.format(sum(total_history), sum([x['data_size'] for x in in_group])))
+				#
+				import json
+				with open('/tmp/group-{}.json'.format(group), 'w') as f:
+					json.dump({'id':group, 'history':total_history, 'individual_histories':histories, 'size':len(in_group), 'avg_transferred':avg_data_size,
+					           'avg_transfer_rate':avg_transfer_rate, 'total_transfer_rate':total_transfer_rate}, f)
+				#
+				'''
+				custom_data = [x['custom_data'].decode('utf-8') for x in in_group]
+				#
+				histories = [x['deltas'] for x in in_group]
+				combined_timestamps, combined_bytes = zip(*sorted(zip([x for y in histories for x in y['timestamps']],
+				                                                      [x for y in histories for x in y['bytes']])))
+				combined_history = {'bytes':combined_bytes, 'timestamps':combined_timestamps}
+				#combined_history = sorted([item for sublist in histories for item in sublist['deltas']], key=lambda x: x['timestamp'])
+				#
+				sum_history_bytes = sum(combined_history['bytes'])
+				sum_data_bytes = sum([x['data_size'] for x in in_group])
+				if sum_history_bytes != sum_data_bytes:
+					logging.warning('History doesn\'t add up ({} != {}).'.format(sum_history_bytes, sum_data_bytes))
+				#
+				import json
+				import gzip
+				with gzip.GzipFile('/tmp/group-{}.json.gz'.format(group), 'w') as f:
+					f.write(json.dumps({'id':group, 'history':combined_history, 'individual_histories':histories, 'size':len(in_group),
+					                    'avg_transferred':avg_data_size, 'avg_transfer_rate':avg_transfer_rate,
+					                    'total_transfer_rate':total_transfer_rate, 'custom_data':custom_data}, f).encode('utf-8'))
+				#
+			#
+		#
+	#
+	for p in processes:
+		p.join()
+	#
+#

+ 118 - 241
src/throughput_server.py

@@ -8,267 +8,144 @@ import threading
 import queue
 import logging
 import argparse
+import useful
 #
-def overlap_byte_counters(byte_counters):
-	start_time = None
-	finish_time = None
-	for x in byte_counters:
-		if start_time is None or x['start_time'] < start_time:
-			start_time = x['start_time']
+class ThroughputServer:
+	def __init__(self, bind_endpoint, stop_event=None):
+		self.bind = bind_endpoint
+		self.stop_event = stop_event
 		#
-		if finish_time is None or x['start_time']+len(x['history']) > finish_time:
-			finish_time = x['start_time']+len(x['history'])
+		self.server_listener = basic_protocols.ServerListener(bind_endpoint, self._accept_callback)
 		#
-	#
-	total_history = [0]*(finish_time-start_time)
-	#
-	for x in byte_counters:
-		for y in range(len(x['history'])):
-			total_history[(x['start_time']-start_time)+y] += x['history'][y]
+		self.processes = []
+		self.process_counter = 0
 		#
-	#
-	return total_history
-#
-if __name__ == '__main__':
-	logging.basicConfig(level=logging.DEBUG)
-	#
-	parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
-	parser.add_argument('port', type=int, help='listen on port')
-	parser.add_argument('--no-accel', action='store_true', help='don\'t use C acceleration (use pure Python)')
-	parser.add_argument('--localhost', action='store_true', help='bind to 127.0.0.1 instead of 0.0.0.0')
-	args = parser.parse_args()
-	#
-	if args.localhost:
-		endpoint = ('127.0.0.1', args.port)
-	else:
-		endpoint = ('0.0.0.0', args.port)
-	#
-	processes = []
-	processes_map = {}
-	joinable_connections = multiprocessing.Queue()
-	joinable_connections_list = []
-	conn_counter = [0]
-	group_queue = multiprocessing.Queue()
-	group_queue_list = []
-	bw_queue = multiprocessing.Queue()
-	bw_queue_list = []
-	#
-	def group_id_callback(conn_id, group_id):
-		# put them in a queue to display later
-		#logging.debug('For conn %d Received group id: %d', conn_id, group_id)
-		group_queue.put({'conn_id':conn_id, 'group_id':group_id})
-	#
-	#def bw_callback(conn_id, data_size, time_first_byte, time_last_byte, transfer_rate, byte_counter, byte_counter_start_time):
-	def bw_callback(conn_id, custom_data, data_size, time_first_byte, time_last_byte, transfer_rate, deltas):
-		# put them in a queue to display later
-		#bw_queue.put({'conn_id':conn_id, 'data_size':data_size, 'time_of_first_byte':time_first_byte, 'time_of_last_byte':time_last_byte, 'transfer_rate':transfer_rate, 'byte_counter':byte_counter, 'byte_counter_start_time':byte_counter_start_time})
-		bw_queue.put({'conn_id':conn_id, 'custom_data':custom_data, 'data_size':data_size, 'time_of_first_byte':time_first_byte, 'time_of_last_byte':time_last_byte, 'transfer_rate':transfer_rate, 'deltas':deltas})
-	#
-	def start_server_conn(socket, conn_id):
-		server = throughput_protocols.ServerProtocol(socket, conn_id, group_id_callback=group_id_callback,
-		                                             bandwidth_callback=bw_callback, use_acceleration=(not args.no_accel))
-		try:
-			server.run()
-		except KeyboardInterrupt:
-			socket.close()
-		finally:
-			joinable_connections.put(conn_id)
-			'''
-			while True:
-				# while we're waiting to join, we might get a KeyboardInterrupt,
-				# in which case we cannot let the process end since it will kill
-				# the queue threads, which may be waiting to push data to the pipe
-				try:
-					joinable_connections.close()
-					group_queue.close()
-					bw_queue.close()
-					#
-					group_queue.join_thread()
-					bw_queue.join_thread()
-					joinable_connections.join_thread()
-					#
-					break
-				except KeyboardInterrupt:
-					pass
-				#
-			#
-			'''
+		self.results_queue = multiprocessing.Queue()
+		self.results = []
+		self.results_getter = useful.QueueGetter(self.results_queue, self.results.append)
+		#
+		if self.stop_event is not None:
+			self.event_thread = threading.Thread(target=self._wait_for_event, args=(self.stop_event, self._stop))
+			self.event_thread.start()
+		else:
+			self.event_thread = None
 		#
 	#
-	def accept_callback(socket):
-		conn_id = conn_counter[0]
-		conn_counter[0] += 1
-		#logging.debug('Adding connection %d', conn_id)
-		p = multiprocessing.Process(target=start_server_conn, args=(socket, conn_id))
-		processes.append(p)
-		processes_map[conn_id] = p
+	def _accept_callback(self, socket):
+		conn_id = self.process_counter
+		self.process_counter += 1
+		#
+		p = multiprocessing.Process(target=self._start_server_conn, args=(socket, conn_id))
+		self.processes.append(p)
 		p.start()
-		socket.close()
+		#
 		# close this process' copy of the socket
+		socket.close()
 	#
-	def unqueue(q, l, print_len=False):
-		while True:
-			val = q.get()
-			if val is None:
-				break
-			#
-			l.append(val)
-			if print_len:
-				print('Queue length: {}'.format(len(l)), end='\r')
-			#
+	def _start_server_conn(self, socket, conn_id):
+		results_callback = lambda results: self.results_queue.put({'conn_id':conn_id, 'results':results})
+		protocol = throughput_protocols.ServerProtocol(socket, results_callback=results_callback,
+		                                               use_acceleration=True)
+		try:
+			protocol.run()
+		finally:
+			socket.close()
 		#
 	#
-	l = basic_protocols.ServerListener(endpoint, accept_callback)
+	def _wait_for_event(self, event, callback):
+		event.wait()
+		callback()
 	#
-	t_joinable_connections = threading.Thread(target=unqueue, args=(joinable_connections, joinable_connections_list))
-	t_group_queue = threading.Thread(target=unqueue, args=(group_queue, group_queue_list))
-	t_bw_queue = threading.Thread(target=unqueue, args=(bw_queue, bw_queue_list, True))
+	def _stop(self):
+		self.server_listener.stop()
 	#
-	t_joinable_connections.start()
-	t_group_queue.start()
-	t_bw_queue.start()
-	#
-	try:
-		while True:
-			l.accept()
-			'''
-			try:
-				while True:
-					conn_id = joinable_connections.get(False)
-					p = processes_map[conn_id]
-					p.join()
-				#
-			except queue.Empty:
-				pass
-			#
-			'''
-		#
-	except KeyboardInterrupt:
-		print()
-		#
-		try:
-			for p in processes:
-				p.join()
-			#
-		except KeyboardInterrupt:
-			pass
-		#
-		joinable_connections.put(None)
-		group_queue.put(None)
-		bw_queue.put(None)
-		t_joinable_connections.join()
-		t_group_queue.join()
-		t_bw_queue.join()
-		#
-		bw_values = {}
-		group_values = {}
-		#
-		'''
-		logging.info('BW queue length: {}'.format(bw_queue.qsize()))
-		logging.info('Group queue length: {}'.format(group_queue.qsize()))
-		#
-		temp_counter = 0
+	def run(self):
 		try:
 			while True:
-				bw_val = bw_queue.get(False)
-				bw_values[bw_val['conn_id']] = bw_val
-				temp_counter += 1
+				self.server_listener.accept()
 			#
-		except queue.Empty:
-			pass
-		#
-		logging.info('temp counter: {}'.format(temp_counter))
-		import time
-		time.sleep(2)
-		try:
-			while True:
-				bw_val = bw_queue.get(False)
-				bw_values[bw_val['conn_id']] = bw_val
-				temp_counter += 1
+		except OSError as e:
+			if e.errno == 22 and self.stop_event is not None and self.stop_event.is_set():
+				# we closed the socket on purpose
+				pass
+			else:
+				raise
 			#
-		except queue.Empty:
-			pass
-		#
-		logging.info('temp counter: {}'.format(temp_counter))
-		
-		#
-		try:
-			while True:
-				group_val = group_queue.get(False)
-				group_values[group_val['conn_id']] = group_val
+		finally:
+			if self.stop_event is not None:
+				# set the event to stop the thread
+				self.stop_event.set()
 			#
-		except queue.Empty:
-			pass
-		#
-		logging.info('bw_values length: {}'.format(len(bw_values)))
-		logging.info('group_values length: {}'.format(len(group_values)))
-		logging.info('group_values set: {}'.format(list(set([x['group_id'] for x in group_values.values()]))))
-		#
-		'''
-		#
-		#logging.info('BW list length: {}'.format(len(bw_queue_list)))
-		#logging.info('Group list length: {}'.format(len(group_queue_list)))
-		#
-		for x in bw_queue_list:
-			bw_values[x['conn_id']] = x
-		#
-		for x in group_queue_list:
-			group_values[x['conn_id']] = x
-		#
-		group_set = set([x['group_id'] for x in group_values.values()])
-		for group in group_set:
-			# doesn't handle group == None
-			conns_in_group = [x[0] for x in group_values.items() if x[1]['group_id'] == group]
-			in_group = [x for x in bw_values.values() if x['conn_id'] in conns_in_group]
-			if len(in_group) > 0:
-				avg_data_size = sum([x['data_size'] for x in in_group])/len(in_group)
-				avg_transfer_rate = sum([x['transfer_rate'] for x in in_group])/len(in_group)
-				total_transfer_rate = sum([x['data_size'] for x in in_group])/(max([x['time_of_last_byte'] for x in in_group])-min([x['time_of_first_byte'] for x in in_group]))
-				#
-				logging.info('Group size: %d', len(in_group))
-				logging.info('Avg Transferred (MiB): %.4f', avg_data_size/(1024**2))
-				logging.info('Avg Transfer rate (MiB/s): %.4f', avg_transfer_rate/(1024**2))
-				logging.info('Total Transfer rate (MiB/s): %.4f', total_transfer_rate/(1024**2))
-				#
-				'''
-				import math
-				histories = [{'start_time':x['byte_counter_start_time'], 'history':x['byte_counter']} for x in in_group]
-				total_history = overlap_byte_counters(histories)
-				#
-				logging.info('Max Transfer rate (MiB/s): %.4f', max(total_history)/(1024**2))
-				if sum(total_history) != sum([x['data_size'] for x in in_group]):
-					logging.warning('History doesn\'t add up ({} != {}).'.format(sum(total_history), sum([x['data_size'] for x in in_group])))
-				#
-				import json
-				with open('/tmp/group-{}.json'.format(group), 'w') as f:
-					json.dump({'id':group, 'history':total_history, 'individual_histories':histories, 'size':len(in_group), 'avg_transferred':avg_data_size,
-					           'avg_transfer_rate':avg_transfer_rate, 'total_transfer_rate':total_transfer_rate}, f)
-				#
-				'''
-				custom_data = [x['custom_data'].decode('utf-8') for x in in_group]
-				#
-				histories = [x['deltas'] for x in in_group]
-				combined_timestamps, combined_bytes = zip(*sorted(zip([x for y in histories for x in y['timestamps']],
-				                                                      [x for y in histories for x in y['bytes']])))
-				combined_history = {'bytes':combined_bytes, 'timestamps':combined_timestamps}
-				#combined_history = sorted([item for sublist in histories for item in sublist['deltas']], key=lambda x: x['timestamp'])
-				#
-				sum_history_bytes = sum(combined_history['bytes'])
-				sum_data_bytes = sum([x['data_size'] for x in in_group])
-				if sum_history_bytes != sum_data_bytes:
-					logging.warning('History doesn\'t add up ({} != {}).'.format(sum_history_bytes, sum_data_bytes))
-				#
-				import json
-				import gzip
-				with gzip.GzipFile('/tmp/group-{}.json.gz'.format(group), 'w') as f:
-					f.write(json.dumps({'id':group, 'history':combined_history, 'individual_histories':histories, 'size':len(in_group),
-					                    'avg_transferred':avg_data_size, 'avg_transfer_rate':avg_transfer_rate,
-					                    'total_transfer_rate':total_transfer_rate, 'custom_data':custom_data}, f).encode('utf-8'))
-				#
+			if self.event_thread is not None:
+				# make sure the event thread is stopped
+				self.event_thread.join()
+			#
+			for p in self.processes:
+				# wait for all processes to finish
+				p.join()
 			#
+			# finish reading from the results queue
+			self.results_getter.stop()
+			self.results_getter.join()
 		#
 	#
-	for p in processes:
-		p.join()
+#
+if __name__ == '__main__':
+	logging.basicConfig(level=logging.DEBUG)
+	#
+	parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
+	parser.add_argument('port', type=int, help='listen on port')
+	parser.add_argument('--localhost', action='store_true', help='bind to 127.0.0.1 instead of 0.0.0.0')
+	args = parser.parse_args()
+	#
+	if args.localhost:
+		bind_to = ('127.0.0.1', args.port)
+	else:
+		bind_to = ('0.0.0.0', args.port)
+	#
+	stop_event = multiprocessing.Event()
+	server = ThroughputServer(bind_to, None)
+	try:
+		server.run()
+	except KeyboardInterrupt:
+		print('')
+		logging.debug('Server stopped (KeyboardInterrupt).')
+	#
+
+	#t = threading.Thread(target=server.run)
+	#t.start()
+	#try:
+	#	t.join()
+	#except KeyboardInterrupt:
+	#	print('')
+	#	stop_event.set()
+	#
+	
+	#p = multiprocessing.Process(target=server.run)
+	#p.start()
+	#try:
+	#	p.join()
+	#except KeyboardInterrupt:
+	#	print('')
+	#	stop_event.set()
+	#
+
+	results = server.results
+	#
+	group_ids = list(set([x['results']['custom_data'] for x in results]))
+	groups = [(g, [r['results'] for r in results if r['results']['custom_data'] == g]) for g in group_ids]
+	#
+	for (group_id, group) in groups:
+		avg_data_size = sum([x['data_size'] for x in group])/len(group)
+		avg_transfer_rate = sum([x['transfer_rate'] for x in group])/len(group)
+		time_of_first_byte = min([x['time_of_first_byte'] for x in group])
+		time_of_last_byte = max([x['time_of_last_byte'] for x in group])
+		total_transfer_rate = sum([x['data_size'] for x in group])/(time_of_last_byte-time_of_first_byte)
+		#
+		logging.info('Group id: %s', int.from_bytes(group_id, byteorder='big') if len(group_id)!=0 else None)
+		logging.info('  Group size: %d', len(group))
+		logging.info('  Avg Transferred (MiB): %.4f', avg_data_size/(1024**2))
+		logging.info('  Avg Transfer rate (MiB/s): %.4f', avg_transfer_rate/(1024**2))
+		logging.info('  Total Transfer rate (MiB/s): %.4f', total_transfer_rate/(1024**2))
 	#
 #

+ 69 - 0
src/useful.py

@@ -1,3 +1,5 @@
+import threading
+#
 def parse_bytes(bytes_str):
 	conversions = {'B':1, 'KiB':1024, 'MiB':1024**2, 'GiB':1024**3, 'TiB':1024**4}
 	#
@@ -10,3 +12,70 @@ def parse_bytes(bytes_str):
 	#
 	return int(bytes_str)
 #
+def parse_range_list(range_list_str):
+	"""
+	Take an input like '1-3,5,7-10' and return a list like [1,2,3,5,7,8,9,10].
+	"""
+	#
+	range_strings = range_list_str.split(',')
+	all_items = []
+	for range_str in range_strings:
+		if '-' in range_str:
+			# in the form '12-34'
+			range_ends = [int(x) for x in range_str.split('-')]
+			assert(len(range_ends) == 2)
+			all_items.extend(range(range_ends[0], range_ends[1]+1))
+		else:
+			# just a number
+			all_items.append(int(range_str))
+		#
+	#
+	return all_items
+#
+def generate_range_list(l):
+	"""
+	Take a list like [1,2,3,5,7,8,9,10] and return a string like '1-3,5,7-10'.
+	"""
+	#
+	l = list(set(sorted(l)))
+	ranges = []
+	current_range_start = None
+	#
+	for index in range(len(l)):
+		if current_range_start is None:
+			current_range_start = l[index]
+		else:
+			if l[index] != l[index-1]+1:
+				ranges.append((current_range_start, l[index-1]))
+				current_range_start = l[index]
+			#
+		#
+	#
+	ranges.append((current_range_start, l[-1]))
+	#
+	return ','.join(['-'.join([str(y) for y in x]) if x[0] != x[1] else str(x[0]) for x in ranges])
+#
+class QueueGetter:
+	def __init__(self, queue, callback):
+		self.queue = queue
+		self.callback = callback
+		#
+		self.t = threading.Thread(target=self._unqueue)
+		self.t.start()
+	#
+	def _unqueue(self):
+		while True:
+			val = self.queue.get()
+			if val is None:
+				break
+			#
+			self.callback(val)
+		#
+	#
+	def stop(self):
+		self.queue.put(None)
+	#
+	def join(self):
+		self.t.join()
+	#
+#