123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516 |
- #!/usr/bin/python3
- #
- import stem.control
- import stem.descriptor.remote
- import stem.process
- import socket
- import logging
- import multiprocessing
- import queue
- import random
- import time
- import json
- import os
- import datetime
- #
- import basic_protocols
- import throughput_protocols
- import useful
- #
def get_socks_port(control_port):
    """
    Look up the SOCKS listener port of the tor instance behind a control port.

    Connects to `control_port`, authenticates, and asks tor for its SOCKS
    listeners. Asserts that exactly one listener is reported and that its
    address is the empty string, then returns that listener's port.
    """
    with stem.control.Controller.from_port(port=control_port) as controller:
        controller.authenticate()

        listeners = controller.get_listeners(stem.control.Listener.SOCKS)
        assert(len(listeners) == 1)

        only_listener = listeners[0]
        assert(only_listener[0] == '')

        return only_listener[1]
- #
- #
def wait_then_sleep(event, duration):
    """
    Block until `event` is set, then sleep for a further `duration` seconds.
    """
    # first wait for the shared start signal ...
    event.wait()
    # ... then apply this caller's own delay
    time.sleep(duration)
- #
def send_measureme(stem_controller, circuit_id, measureme_id, hop):
    """
    Send one SENDMEASUREME control command for a single hop of a circuit.

    Raises:
      stem.InvalidArguments if tor answers 512/552 with 'Unknown circuit ...'
      stem.InvalidRequest for any other 512/552 answer
      stem.ProtocolError for any other non-OK response code
    """
    command = 'SENDMEASUREME %s ID=%s HOP=%s' % (circuit_id, measureme_id, hop)
    reply = stem_controller.msg(command)
    stem.response.convert('SINGLELINE', reply)

    if reply.is_ok():
        return

    if reply.code not in ('512', '552'):
        raise stem.ProtocolError('MEASUREME returned unexpected response code: %s' % reply.code)

    if reply.message.startswith('Unknown circuit '):
        raise stem.InvalidArguments(reply.code, reply.message, [circuit_id])

    raise stem.InvalidRequest(reply.code, reply.message)
- #
- #
- #
def send_measureme_cells(control_address, circuit_id, measureme_id, hops):
    """
    Send a measureme cell to every hop in `hops`, in the order given.

    Opens its own authenticated control connection to `control_address`
    (an (address, port) pair) and closes it again when done.
    """
    logging.debug('Sending measuremes to control address {}, then sleeping'.format(control_address))

    host = control_address[0]
    port = control_address[1]

    with stem.control.Controller.from_port(address=host, port=port) as controller:
        controller.authenticate()

        for hop in hops:
            send_measureme(controller, circuit_id, measureme_id, hop)
- #
- #
- #
def send_measureme_cells_and_wait(control_port, circuit_id, measureme_id, hops, wait_event, wait_offset):
    """
    Send the measureme cells, then wait for `wait_event` and sleep `wait_offset` seconds.

    NOTE: despite its name, `control_port` is passed on unchanged as the
    `control_address` argument of send_measureme_cells().
    """
    send_measureme_cells(control_address=control_port,
                         circuit_id=circuit_id,
                         measureme_id=measureme_id,
                         hops=hops)
    wait_then_sleep(event=wait_event, duration=wait_offset)
- #
def get_fingerprints(consensus):
    """
    Collect the fingerprint of every relay descriptor in the consensus,
    in iteration order.
    """
    collected = []
    for relay in consensus:
        collected.append(relay.fingerprint)

    return collected
- #
def get_exit_fingerprints(consensus, endpoint):
    """
    Collect the fingerprints of the relays whose exit policy allows exiting
    to `endpoint`. The endpoint is unpacked into the arguments of
    exit_policy.can_exit_to().
    """
    collected = []
    for relay in consensus:
        if relay.exit_policy.can_exit_to(*endpoint):
            collected.append(relay.fingerprint)

    return collected
- #
class ExperimentController:
    """
    Drives one tor client over its control port for an experiment.

    Builds controller-purpose circuits, reserves one circuit per stream, and
    attaches each new user stream to the circuit reserved for the stream's
    source address.
    """

    def __init__(self, control_address):
        """
        control_address: (address, port) pair of the tor control port.
        """
        self.control_address = control_address
        # stem controller; None until connect() is called
        self.connection = None
        # circuit id -> the path that the circuit generator returned for it
        self.circuits = {}
        # ids of built circuits not yet reserved for a stream (handed out oldest first)
        self.unassigned_circuit_ids = []
        # key passed to assign_stream() -> circuit id reserved for that stream
        self.assigned_streams = {}

    def connect(self):
        """
        Open and authenticate the control connection, register the stream and
        circuit event listeners, and tell tor to leave streams unattached.
        """
        self.connection = stem.control.Controller.from_port(address=self.control_address[0],
                                                            port=self.control_address[1])
        self.connection.authenticate()

        self.connection.add_event_listener(self.stream_event, stem.control.EventType.STREAM)
        self.connection.add_event_listener(self.circuit_event, stem.control.EventType.CIRC)
        self.connection.set_conf('__LeaveStreamsUnattached', '1')
        # '__DisablePredictedCircuits' is deliberately not set: we still need tor
        # to generate circuits for things like directory fetches

    def disconnect(self):
        """
        Close the control connection. Does nothing if connect() never
        created a connection.
        """
        if self.connection is not None:
            self.connection.close()

    def assign_stream(self, from_address):
        """
        Reserve the oldest unassigned circuit for the stream that will
        originate from `from_address`, and return that circuit's id.

        Should run this function before starting the protocol, and therefore before telling
        the SOCKS proxy where we're connecting to and before the stream is created.

        Raises IndexError when no unassigned circuit is left.
        """
        circuit_id = self.unassigned_circuit_ids.pop(0)
        self.assigned_streams[from_address] = circuit_id
        return circuit_id

    def stream_event(self, stream):
        """
        STREAM event listener: attach new user streams to their reserved circuit.

        Any exception is logged and then re-raised.
        """
        try:
            if stream.status == 'NEW':
                # by default, let tor handle new streams
                circuit_id = 0

                if stream.purpose == 'USER':
                    # NOTE: we used to try to attach all streams (including non-user streams,
                    # which we attached to circuit 0, but Stem was found to hang sometimes
                    # when attaching DIR_FETCH streams, so now we only attach user streams
                    # and let tor take care of other streams
                    #
                    # this is probably one of our streams (although not guaranteed);
                    # a stream that was never registered via assign_stream() raises KeyError here
                    circuit_id = self.assigned_streams[(stream.source_address, stream.source_port)]

                    try:
                        logging.debug('Attaching to circuit {}'.format(circuit_id))
                        self.connection.attach_stream(stream.id, circuit_id)
                        logging.debug('Attached to circuit {}'.format(circuit_id))
                    except stem.InvalidRequest:
                        if stream.purpose != 'USER':
                            # could not attach a non-user stream, ignoring
                            pass
                        else:
                            raise
                    except stem.UnsatisfiableRequest:
                        if stream.purpose != 'USER':
                            # could not attach a non-user stream, so probably raised:
                            #   stem.UnsatisfiableRequest: Connection is not managed by controller.
                            # therefore we should ignore this exception
                            pass
                        else:
                            raise
                    except stem.SocketClosed:
                        logging.debug('Stream {} ({}, controller={}) {}: socket closed while attaching'.format(
                                      stream.id, stream.purpose, self.control_address, stream.status))
                        raise

            if stream.status == 'DETACHED' or stream.status == 'FAILED':
                logging.debug('Stream {} ({}, controller={}) {}: {}; {}'.format(
                              stream.id, stream.purpose, self.control_address,
                              stream.status, stream.reason, stream.remote_reason))
        except:
            # log with the traceback, then re-raise unchanged
            logging.exception('Error while attaching the stream.')
            raise

    def circuit_event(self, circuit):
        """
        CIRC event listener: log controller-purpose circuits that failed or closed.
        """
        if circuit.purpose == 'CONTROLLER' and (circuit.status == 'FAILED' or circuit.status == 'CLOSED'):
            logging.debug('Circuit {} ({}, controller={}) {}: {}; {}'.format(
                          circuit.id, circuit.purpose, self.control_address,
                          circuit.status, circuit.reason, circuit.remote_reason))

    def _retry_or_give_up(self, circuit, tries_remaining, retry_message, error, wait_seconds):
        """
        Shared tail of the build_circuit() exception handlers.

        Returns False when no tries remain, so that the handler can re-raise.
        Otherwise logs `retry_message`, sleeps `wait_seconds` and returns True.
        """
        logging.debug('Failed circuit: {}'.format(circuit))
        if tries_remaining == 0:
            logging.warning('Tried too many times')
            return False

        logging.warning(retry_message.format(str(error), wait_seconds, 's' if wait_seconds != 1 else ''))
        time.sleep(wait_seconds)
        return True

    def build_circuit(self, circuit_generator, gen_id=None):
        """
        Build one controller-purpose circuit and queue it for assign_stream().

        circuit_generator: callable that takes `gen_id` and returns the path
                           handed to new_circuit(); it is called again for
                           every attempt.
        gen_id: value passed through to `circuit_generator` (default None, so
                that callers with nothing to pass can omit it).

        Makes up to 5 attempts. After a failed attempt it waits 1 second
        (CircuitExtensionFailed), 15 seconds (InvalidRequest) or 5 seconds
        (Timeout) before trying again; the exception of the last failed
        attempt is re-raised.
        """
        circuit = None
        circuit_id = None
        tries_remaining = 5

        while circuit_id is None and tries_remaining > 0:
            try:
                circuit = circuit_generator(gen_id)
                tries_remaining -= 1
                circuit_id = self.connection.new_circuit(circuit, await_build=True, purpose='controller', timeout=10)
                logging.debug('New circuit (circ_id={}, controller={}): {}'.format(circuit_id, self.control_address, circuit))
            except stem.CircuitExtensionFailed as e:
                if not self._retry_or_give_up(circuit, tries_remaining,
                                              'Circuit creation failed (CircuitExtensionFailed: {}). Retrying in {} second{}...',
                                              e, 1):
                    raise
            except stem.InvalidRequest as e:
                if not self._retry_or_give_up(circuit, tries_remaining,
                                              'Circuit creation failed (InvalidRequest: {}). Retrying in {} second{}...',
                                              e, 15):
                    raise
            except stem.Timeout as e:
                if not self._retry_or_give_up(circuit, tries_remaining,
                                              'Circuit creation timed out (Timeout: {}). Retrying in {} second{}...',
                                              e, 5):
                    raise

        self.unassigned_circuit_ids.append(circuit_id)
        self.circuits[circuit_id] = circuit
- #
- #
class ExperimentProtocol(basic_protocols.ChainedProtocol):
    """
    A SOCKS4 handshake followed by a throughput client run, chained on the
    same socket.
    """

    def __init__(self, socket, endpoint, num_bytes, circuit_info, custom_data=None, send_buffer_len=None, push_start_cb=None):
        """
        socket: socket that both chained protocols are given.
        endpoint: destination handed to the SOCKS4 protocol.
        num_bytes, custom_data, send_buffer_len, push_start_cb: passed through
            to the throughput client protocol.
        circuit_info: description of the circuit, only used by get_desc().
        """
        # random username with the zero bytes filtered out
        # (presumably because the SOCKS4 user id field is NUL-terminated -- TODO confirm)
        random_username = bytes(value for value in os.urandom(12) if value != 0)
        socks_stage = basic_protocols.Socks4Protocol(socket, endpoint, username=random_username)

        self.proxy_info = socket.getpeername()
        self.circuit_info = circuit_info

        throughput_options = {
            'custom_data': custom_data,
            'send_buffer_len': send_buffer_len,
            'use_acceleration': True,
            'push_start_cb': push_start_cb,
        }
        throughput_stage = throughput_protocols.ClientProtocol(socket, num_bytes, **throughput_options)

        super().__init__([socks_stage, throughput_stage])

    def get_desc(self):
        """
        Describe this protocol as '<proxy> -> <circuit>', followed by the
        parent class's description when it has one.
        """
        parent_desc = super().get_desc()

        if parent_desc is None:
            return '{} -> {}'.format(self.proxy_info, self.circuit_info)

        return '{} -> {} - {}'.format(self.proxy_info, self.circuit_info, parent_desc)
- #
- #
class ExperimentProtocolManager():
    """
    Runs each experiment protocol in its own process and tracks them until
    they finish, die, or are terminated.

    A child reports completion by putting (protocol_id, had_error) on the
    multiprocessing queue; _wait() consumes those reports from the local queue.
    """

    def __init__(self):
        self.stopped = False
        # number of processes started so far; doubles as the default protocol id
        self.process_counter = 0
        self.used_ids = []
        # protocol id -> multiprocessing.Process that has not been cleaned up yet
        self.running_processes = {}
        # protocol id -> whether the child process got as far as _run_client()
        self.checked_in = multiprocessing.Manager().dict()
        self.global_finished_process_queue = multiprocessing.Queue()
        self.local_finished_process_queue = queue.Queue()
        # NOTE(review): presumably forwards items from the multiprocessing queue to
        # the local queue in the background (it is stopped and joined in stop()) --
        # confirm against useful.QueueGetter
        self.queue_getter = useful.QueueGetter(self.global_finished_process_queue,
                                               self.local_finished_process_queue.put)

    def _run_client(self, protocol, protocol_id):
        """
        Child process entry point: check in, run the protocol, and always
        report (protocol_id, had_error) on the global finished queue.
        """
        had_error = False
        try:
            logging.debug('Starting client protocol (id: {}, desc: {})'.format(protocol_id, protocol.get_desc()))
            self.checked_in[protocol_id] = True
            protocol.run()
            logging.debug('Done client protocol (id: {})'.format(protocol_id))
        except KeyboardInterrupt:
            had_error = True
            logging.info('Client protocol id: {} stopped (KeyboardInterrupt)'.format(protocol_id))
        except BaseException:
            # last line of defence in the child: log everything, report it as an error
            had_error = True
            logging.warning('Client protocol error')
            logging.exception('Client protocol id: {} had an error ({})'.format(protocol_id, datetime.datetime.now().time()))
        finally:
            self.global_finished_process_queue.put((protocol_id, had_error))
            if had_error:
                logging.warning('Client protocol with error successfully added self to global queue')

    def start_experiment_protocol(self, protocol, protocol_id=None):
        """
        Start `protocol` in a new process.

        protocol_id: unique id for the protocol; defaults to the number of
                     processes started so far.
        """
        if protocol_id is None:
            protocol_id = self.process_counter

        assert not self.stopped
        assert protocol_id not in self.used_ids, 'Protocol ID already used'

        p = multiprocessing.Process(target=self._run_client, args=(protocol, protocol_id))
        self.running_processes[protocol_id] = p
        # because of Python multiprocessing bugs, the process may deadlock when it starts,
        # so the child flips this to True itself once it is actually running
        self.checked_in[protocol_id] = False
        self.used_ids.append(protocol_id)

        p.start()
        self.process_counter += 1

    def _get_not_checked_in(self):
        """Return the protocol ids whose process has not checked in."""
        snapshot = self.checked_in.copy()
        return [protocol_id for protocol_id in snapshot if not snapshot[protocol_id]]

    def _get_dead_processes(self):
        """Return (protocol_id, process) for every tracked process that is not alive."""
        return [(protocol_id, p) for (protocol_id, p) in self.running_processes.items() if not p.is_alive()]

    def _cleanup_process(self, p, protocol_id, had_error, finished_protocol_cb):
        """Join the process, stop tracking it, and notify the callback (if any)."""
        p.join()
        self.running_processes.pop(protocol_id)
        if finished_protocol_cb is not None:
            finished_protocol_cb(protocol_id, had_error)

    def _wait(self, timeout=None, finished_protocol_cb=None):
        """
        Clean up processes as their completion reports arrive.

        timeout: when given, return as soon as no report arrives for that many
                 seconds. When None, poll with a 10 second timeout and, on each
                 timeout, clean up the processes that were already dead when the
                 current round started but never reported.

        Also returns early when none of the remaining processes has checked in.
        """
        return_on_timeout = timeout is not None
        timeout = timeout if timeout is not None else 10
        last_waiting_message = None

        while len(self.running_processes) > 0:
            dead_processes = self._get_dead_processes()

            while len(self.running_processes) > 0:
                not_checked_in = self._get_not_checked_in()

                if last_waiting_message is None or last_waiting_message != len(self.running_processes):
                    logging.debug('Waiting for processes ({} left, {} not checked in)'.format(len(self.running_processes),
                                                                                              len(not_checked_in)))
                    last_waiting_message = len(self.running_processes)

                if len(self.running_processes) <= len(not_checked_in):
                    running_not_checked_in = [protocol_id for protocol_id in self.running_processes if protocol_id in not_checked_in]
                    if len(self.running_processes) == len(running_not_checked_in):
                        logging.debug('The remaining processes have not checked in, so stopping the wait')
                        return

                try:
                    (protocol_id, had_error) = self.local_finished_process_queue.get(timeout=timeout)
                except queue.Empty:
                    if return_on_timeout:
                        return
                    break

                p = self.running_processes.get(protocol_id)
                if p is None:
                    # the report is for a process that is no longer tracked (e.g. it was
                    # already cleaned up as a dead process); nothing is left to clean up
                    logging.debug('Ignoring finished report for untracked protocol (id: {})'.format(protocol_id))
                    continue

                self._cleanup_process(p, protocol_id, had_error, finished_protocol_cb)
                if (protocol_id, p) in dead_processes:
                    dead_processes.remove((protocol_id, p))

                logging.debug('Completed protocol (id: {}, checked_in={})'.format(protocol_id,
                                                                                  self.checked_in[protocol_id]))

            for (protocol_id, p) in dead_processes:
                # these processes were dead but didn't add themselves to the finished queue
                logging.debug('Found a dead process (id: {})'.format(protocol_id))
                self._cleanup_process(p, protocol_id, True, finished_protocol_cb)

    def wait(self, finished_protocol_cb=None, kill_timeout=None):
        """
        Wait for the processes to finish, then terminate whatever is left.

        finished_protocol_cb: called as cb(protocol_id, had_error) for every
                              process that is cleaned up; terminated processes
                              are reported with had_error=True.
        kill_timeout: passed to _wait() as its timeout.
        """
        self._wait(kill_timeout, finished_protocol_cb)

        if len(self.running_processes) > 0:
            logging.warning('Timed out ({} seconds) waiting for processes to finish, will terminate remaining processes'.format(kill_timeout))

        while len(self.running_processes) > 0:
            # just get any process and kill it
            (protocol_id, p) = next(iter(self.running_processes.items()))
            was_alive = p.is_alive()
            p.terminate()
            logging.debug('Terminated protocol (id: {}, was_dead={}, checked_in={})'.format(protocol_id,
                                                                                            (not was_alive),
                                                                                            self.checked_in[protocol_id]))

            self._cleanup_process(p, protocol_id, True, finished_protocol_cb)

    def stop(self):
        """
        Give the remaining processes 1.5 seconds, terminate the rest, and shut
        down the queue getter. No protocol may be started afterwards.
        """
        self.wait(kill_timeout=1.5)
        self.queue_getter.stop()
        self.queue_getter.join(timeout=10)
        self.stopped = True
- #
- #
def build_client_protocol(endpoint, socks_address, control_address, controller, start_event, send_measureme, wait_duration=0, measureme_id=None, num_bytes=None, buffer_len=None):
    """
    Connect to a SOCKS proxy and build the ExperimentProtocol for one stream.

    endpoint: destination handed to the protocol.
    socks_address: address of the tor SOCKS port to connect to.
    control_address: control port address, used only for sending measureme cells.
    controller: object providing assign_stream() and a `circuits` mapping
                (an ExperimentController in this file).
    start_event: event the protocol's start callback waits on.
    send_measureme: if true, the start callback first sends measureme cells
                    to every hop (requires `measureme_id`).
    wait_duration: seconds the start callback sleeps after `start_event` is set.
    measureme_id: id included in the custom data and in the measureme cells.
    num_bytes, buffer_len: passed through to the protocol.

    Returns the protocol. If anything fails after the socket was created, the
    socket is closed before the exception propagates.
    """
    client_socket = socket.socket()

    try:
        logging.debug('Socket %d connecting to proxy %r...', client_socket.fileno(), socks_address)
        client_socket.connect(socks_address)
        logging.debug('Socket %d connected', client_socket.fileno())

        custom_data = {}

        # the stream is identified by the local address of the proxy connection
        circuit_id = controller.assign_stream(client_socket.getsockname())
        circuit_path = controller.circuits[circuit_id]
        custom_data['circuit'] = (circuit_id, circuit_path)

        if measureme_id is not None:
            custom_data['measureme_id'] = measureme_id

        if send_measureme:
            assert measureme_id is not None
            # send the measureme cells to the last relay first
            hops = list(range(len(circuit_path), -1, -1))
            # everything is bound as default arguments, so the callback does not
            # depend on the variables of this call
            start_cb = lambda control_address=control_address, circuit_id=circuit_id, measureme_id=measureme_id, \
                              hops=hops, event=start_event, wait_duration=wait_duration: \
                              send_measureme_cells_and_wait(control_address, circuit_id, measureme_id, hops, event, wait_duration)
        else:
            start_cb = lambda event=start_event, duration=wait_duration: wait_then_sleep(event, duration)

        custom_data = json.dumps(custom_data).encode('utf-8')
        protocol = ExperimentProtocol(client_socket, endpoint, num_bytes,
                                      '{}: {}'.format(circuit_id, circuit_path),
                                      custom_data=custom_data,
                                      send_buffer_len=buffer_len,
                                      push_start_cb=start_cb)
    except BaseException:
        # no protocol took ownership of the socket, so don't leak it
        client_socket.close()
        raise

    return protocol
- #
if __name__ == '__main__':
    # Command-line entry point: build one circuit per stream on every tor
    # client, start all the streams, and wait for them to finish.
    import argparse

    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger('stem').setLevel(logging.WARNING)

    parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
    parser.add_argument('ip', type=str, help='destination ip address')
    parser.add_argument('port', type=int, help='destination port')
    parser.add_argument('num_bytes', type=useful.parse_bytes,
                        help='number of bytes to send per connection (can also end with \'B\', \'KiB\', \'MiB\', or \'GiB\')', metavar='num-bytes')
    parser.add_argument('num_streams_per_client', type=int, help='number of streams per Tor client', metavar='num-streams-per-client')
    parser.add_argument('--buffer-len', type=useful.parse_bytes,
                        help='size of the send and receive buffers (can also end with \'B\', \'KiB\', \'MiB\', or \'GiB\')', metavar='bytes')
    parser.add_argument('--wait-range', type=int, default=0,
                        help='add a random wait time to each connection so that they don\'t all start at the same time (default is 0)', metavar='time')
    parser.add_argument('--proxy-control-ports', type=useful.parse_range_list, help='range of ports for the control ports', metavar='control-ports')
    parser.add_argument('--measureme', action='store_true', help='send measureme cells to the exit')
    args = parser.parse_args()

    if args.proxy_control_ports is None:
        # the loop over the control ports below cannot run without them
        parser.error('--proxy-control-ports is required')

    endpoint = (args.ip, args.port)

    logging.debug('Getting consensus')
    try:
        consensus = stem.descriptor.remote.get_consensus(endpoints=(stem.DirPort('', 7000),))
    except Exception as e:
        raise Exception('Unable to retrieve the consensus') from e

    fingerprints = get_fingerprints(consensus)
    exit_fingerprints = get_exit_fingerprints(consensus, endpoint)
    non_exit_fingerprints = list(set(fingerprints)-set(exit_fingerprints))

    assert len(exit_fingerprints) == 1, 'Need exactly one exit relay'
    assert len(non_exit_fingerprints) >= 1, 'Need at least one non-exit relay'

    # two-hop path: a random non-exit relay followed by the only exit relay
    circuit_generator = lambda gen_id=None: [random.choice(non_exit_fingerprints), exit_fingerprints[0]]

    proxy_addresses = []
    for control_port in args.proxy_control_ports:
        proxy = {}
        proxy['control'] = ('', control_port)
        proxy['socks'] = ('', get_socks_port(control_port))
        proxy_addresses.append(proxy)

    controllers = []
    protocol_manager = ExperimentProtocolManager()

    try:
        for proxy_address in proxy_addresses:
            controller = ExperimentController(proxy_address['control'])
            controller.connect()
            # the controller has to attach new streams to circuits, so the
            # connection has to stay open until we're done creating streams;
            # register it right away so that it is disconnected in the
            # 'finally' below even if building a circuit fails
            controllers.append(controller)

            for _ in range(args.num_streams_per_client):
                # make a circuit for each stream (the generator ignores the id)
                controller.build_circuit(circuit_generator, None)
                time.sleep(0.5)

        start_event = multiprocessing.Event()

        for stream_index in range(args.num_streams_per_client):
            for (controller_index, (proxy_address, controller)) in enumerate(zip(proxy_addresses, controllers)):
                if args.measureme:
                    # striding by the number of controllers keeps the ids unique for
                    # every (stream, controller) pair
                    measureme_id = stream_index*len(controllers) + controller_index + 1
                else:
                    measureme_id = None

                wait_duration = random.randint(0, args.wait_range)
                protocol = build_client_protocol(endpoint, proxy_address['socks'], proxy_address['control'],
                                                 controller, start_event, args.measureme,
                                                 wait_duration=wait_duration, measureme_id=measureme_id,
                                                 num_bytes=args.num_bytes, buffer_len=args.buffer_len)
                protocol_manager.start_experiment_protocol(protocol, protocol_id=None)

        time.sleep(2)
        start_event.set()

        protocol_manager.wait(finished_protocol_cb=lambda protocol_id,had_error: logging.info('Finished {} (had_error={})'.format(protocol_id,had_error)))
    finally:
        for controller in controllers:
            controller.disconnect()

        protocol_manager.stop()
- #
- #