#!/usr/bin/python3
# experiment_client.py
#
import stem.control
import stem.descriptor.remote
import stem.process
import stem.response
import socket
import logging
import multiprocessing
import queue
import random
import time
import json
import os
import datetime
#
import basic_protocols
import throughput_protocols
import useful
#
def get_socks_port(control_port):
    with stem.control.Controller.from_port(port=control_port) as controller:
        controller.authenticate()
        #
        socks_addresses = controller.get_listeners(stem.control.Listener.SOCKS)
        assert(len(socks_addresses) == 1)
        assert(socks_addresses[0][0] == '127.0.0.1')
        #
        return socks_addresses[0][1]
    #
#
def wait_then_sleep(event, duration):
    event.wait()
    time.sleep(duration)
#
def send_measureme(stem_controller, circuit_id, measureme_id, hop):
    response = stem_controller.msg('SENDMEASUREME %s ID=%s HOP=%s' % (circuit_id, measureme_id, hop))
    stem.response.convert('SINGLELINE', response)
    #
    if not response.is_ok():
        if response.code in ('512', '552'):
            if response.message.startswith('Unknown circuit '):
                raise stem.InvalidArguments(response.code, response.message, [circuit_id])
            #
            raise stem.InvalidRequest(response.code, response.message)
        else:
            raise stem.ProtocolError('MEASUREME returned unexpected response code: %s' % response.code)
        #
    #
#
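# NOTE: 'SENDMEASUREME' is not a stock tor control-port command; the measureme
# helpers here assume a tor build patched to accept it and to emit MEASUREME
# cells, which is an assumption about the experiment environment.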
def send_measureme_cells(control_address, circuit_id, measureme_id, hops):
    logging.debug('Sending measuremes to control address {}, then sleeping'.format(control_address))
    with stem.control.Controller.from_port(address=control_address[0], port=control_address[1]) as controller:
        controller.authenticate()
        for hop in hops:
            send_measureme(controller, circuit_id, measureme_id, hop)
        #
    #
#
def send_measureme_cells_and_wait(control_address, circuit_id, measureme_id, hops, wait_event, wait_offset):
    send_measureme_cells(control_address, circuit_id, measureme_id, hops)
    wait_then_sleep(wait_event, wait_offset)
#
def get_fingerprints(consensus):
    """
    Get the fingerprints of all relays.
    """
    #
    return [desc.fingerprint for desc in consensus]
#
def get_exit_fingerprints(consensus, endpoint):
    """
    Get the fingerprints of relays that can exit to the endpoint.
    """
    #
    return [desc.fingerprint for desc in consensus if desc.exit_policy.can_exit_to(*endpoint)]
#
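# ExperimentController wraps the control-port connection for one tor client: it
# pre-builds circuits, tells tor to leave new streams unattached
# ('__LeaveStreamsUnattached'), and attaches each USER stream to the circuit
# reserved for that stream's source address.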
class ExperimentController:
    def __init__(self, control_address):
        self.control_address = control_address
        self.connection = None
        self.circuits = {}
        self.unassigned_circuit_ids = []
        self.assigned_streams = {}
    #
    def connect(self):
        self.connection = stem.control.Controller.from_port(address=self.control_address[0], port=self.control_address[1])
        self.connection.authenticate()
        #
        self.connection.add_event_listener(self.stream_event, stem.control.EventType.STREAM)
        self.connection.add_event_listener(self.circuit_event, stem.control.EventType.CIRC)
        self.connection.set_conf('__LeaveStreamsUnattached', '1')
        #self.connection.set_conf('__DisablePredictedCircuits', '1')
        # we still need to generate circuits for things like directory fetches
    #
    def disconnect(self):
        #if len(self.unused_circuit_ids) > 0:
        #    logging.warning('Closed stem controller before all circuits were used')
        #
        self.connection.close()
    #
    def assign_stream(self, from_address):
        """
        Run this before starting the protocol, and therefore before telling the
        SOCKS proxy where we're connecting to and before the stream is created.
        """
        circuit_id = self.unassigned_circuit_ids.pop(0)
        self.assigned_streams[from_address] = circuit_id
        return circuit_id
    #
    def stream_event(self, stream):
        try:
            if stream.status == 'NEW':
                # by default, let tor handle new streams
                circuit_id = 0
                #
                if stream.purpose == 'USER':
                    # NOTE: we used to try to attach all streams (including non-user streams,
                    # which we attached to circuit 0), but Stem was found to hang sometimes
                    # when attaching DIR_FETCH streams, so now we only attach user streams
                    # and let tor take care of other streams
                    #
                    # this is probably one of our streams (although not guaranteed)
                    circuit_id = self.assigned_streams[(stream.source_address, stream.source_port)]
                #
                try:
                    logging.debug('Attaching to circuit {}'.format(circuit_id))
                    self.connection.attach_stream(stream.id, circuit_id)
                    logging.debug('Attached to circuit {}'.format(circuit_id))
                except stem.InvalidRequest:
                    if stream.purpose != 'USER':
                        # could not attach a non-user stream, ignoring
                        pass
                    else:
                        raise
                    #
                except stem.UnsatisfiableRequest:
                    if stream.purpose != 'USER':
                        # could not attach a non-user stream, so probably raised:
                        #   stem.UnsatisfiableRequest: Connection is not managed by controller.
                        # therefore we should ignore this exception
                        pass
                    else:
                        raise
                    #
                except stem.SocketClosed:
                    logging.debug('Stream {} ({}, controller={}) {}: socket closed while attaching'.format(stream.id,
                                  stream.purpose, self.control_address, stream.status))
                    raise
                #
            #
            #
            if stream.status == 'DETACHED' or stream.status == 'FAILED':
                logging.debug('Stream {} ({}, controller={}) {}: {}; {}'.format(stream.id, stream.purpose, self.control_address,
                              stream.status, stream.reason, stream.remote_reason))
            #
        except:
            logging.exception('Error while attaching the stream.')
            raise
        #
    #
    def circuit_event(self, circuit):
        if circuit.purpose == 'CONTROLLER' and (circuit.status == 'FAILED' or circuit.status == 'CLOSED'):
            logging.debug('Circuit {} ({}, controller={}) {}: {}; {}'.format(circuit.id, circuit.purpose, self.control_address,
                          circuit.status, circuit.reason, circuit.remote_reason))
        #
    #
    def build_circuit(self, circuit_generator, gen_id=None):
        circuit_id = None
        tries_remaining = 5
        #
        while circuit_id is None and tries_remaining > 0:
            try:
                circuit = circuit_generator(gen_id)
                tries_remaining -= 1
                circuit_id = self.connection.new_circuit(circuit, await_build=True, purpose='controller', timeout=10)
                logging.debug('New circuit (circ_id={}, controller={}): {}'.format(circuit_id, self.control_address, circuit))
            except stem.CircuitExtensionFailed as e:
                wait_seconds = 1
                logging.debug('Failed circuit: {}'.format(circuit))
                if tries_remaining == 0:
                    logging.warning('Tried too many times')
                    raise
                #
                logging.warning('Circuit creation failed (CircuitExtensionFailed: {}). Retrying in {} second{}...'.format(str(e),
                                wait_seconds,
                                's' if wait_seconds != 1 else ''))
                time.sleep(wait_seconds)
            except stem.InvalidRequest as e:
                wait_seconds = 15
                logging.debug('Failed circuit: {}'.format(circuit))
                if tries_remaining == 0:
                    logging.warning('Tried too many times')
                    raise
                #
                logging.warning('Circuit creation failed (InvalidRequest: {}). Retrying in {} second{}...'.format(str(e),
                                wait_seconds,
                                's' if wait_seconds != 1 else ''))
                time.sleep(wait_seconds)
            except stem.Timeout as e:
                wait_seconds = 5
                logging.debug('Failed circuit: {}'.format(circuit))
                if tries_remaining == 0:
                    logging.warning('Tried too many times')
                    raise
                #
                logging.warning('Circuit creation timed out (Timeout: {}). Retrying in {} second{}...'.format(str(e),
                                wait_seconds,
                                's' if wait_seconds != 1 else ''))
                time.sleep(wait_seconds)
            #
        #
        self.unassigned_circuit_ids.append(circuit_id)
        self.circuits[circuit_id] = circuit
    #
#
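# ExperimentProtocol chains a SOCKS4 handshake with the throughput client
# protocol. The SOCKS username is random, NUL-free bytes (SOCKS4 user IDs are
# NUL-terminated); presumably this also gives tor a distinct isolation field
# per connection.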
class ExperimentProtocol(basic_protocols.ChainedProtocol):
    def __init__(self, socket, endpoint, num_bytes, circuit_info, custom_data=None, send_buffer_len=None, push_start_cb=None):
        proxy_username = bytes([z for z in os.urandom(12) if z != 0])
        proxy_protocol = basic_protocols.Socks4Protocol(socket, endpoint, username=proxy_username)
        #
        self.proxy_info = socket.getpeername()
        self.circuit_info = circuit_info
        #
        throughput_protocol = throughput_protocols.ClientProtocol(socket, num_bytes,
                                                                  custom_data=custom_data,
                                                                  send_buffer_len=send_buffer_len,
                                                                  use_acceleration=True,
                                                                  push_start_cb=push_start_cb)
        #
        super().__init__([proxy_protocol, throughput_protocol])
    #
    def get_desc(self):
        super_desc = super().get_desc()
        if super_desc is not None:
            return '{} -> {} - {}'.format(self.proxy_info, self.circuit_info, super_desc)
        else:
            return '{} -> {}'.format(self.proxy_info, self.circuit_info)
    #
#
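# ExperimentProtocolManager runs each protocol in its own multiprocessing
# process. A child "checks in" via a shared Manager dict once it starts, and
# reports (protocol_id, had_error) on a multiprocessing.Queue that a
# useful.QueueGetter thread drains into a local queue for the parent to read.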
class ExperimentProtocolManager:
    def __init__(self):
        self.stopped = False
        self.process_counter = 0
        self.used_ids = []
        self.running_processes = {}
        self.checked_in = multiprocessing.Manager().dict()
        self.global_finished_process_queue = multiprocessing.Queue()
        self.local_finished_process_queue = queue.Queue()
        self.queue_getter = useful.QueueGetter(self.global_finished_process_queue,
                                               self.local_finished_process_queue.put)
    #
    def _run_client(self, protocol, protocol_id):
        had_error = False
        try:
            logging.debug('Starting client protocol (id: {}, desc: {})'.format(protocol_id, protocol.get_desc()))
            self.checked_in[protocol_id] = True
            protocol.run()
            logging.debug('Done client protocol (id: {})'.format(protocol_id))
        except KeyboardInterrupt:
            had_error = True
            logging.info('Client protocol id: {} stopped (KeyboardInterrupt)'.format(protocol_id))
        except:
            had_error = True
            logging.warning('Client protocol error')
            logging.exception('Client protocol id: {} had an error ({})'.format(protocol_id, datetime.datetime.now().time()))
        finally:
            self.global_finished_process_queue.put((protocol_id, had_error))
            if had_error:
                logging.warning('Client protocol with error successfully added self to global queue')
            #
        #
    #
    def start_experiment_protocol(self, protocol, protocol_id=None):
        if protocol_id is None:
            protocol_id = self.process_counter
        #
        assert not self.stopped
        assert protocol_id not in self.used_ids, 'Protocol ID already used'
        #
        #logging.debug('Launching client protocol (id: {})'.format(protocol_id))
        p = multiprocessing.Process(target=self._run_client, args=(protocol, protocol_id))
        self.running_processes[protocol_id] = p
        self.checked_in[protocol_id] = False
        # because of Python multiprocessing bugs, the process may deadlock when it starts
        self.used_ids.append(protocol_id)
        #
        p.start()
        self.process_counter += 1
        #
        #protocol.socket.close()
    #
    def _get_not_checked_in(self):
        temp = self.checked_in.copy()
        return [x for x in temp if not temp[x]]
    #
    #def _count_checked_in(self):
    #    temp = self.checked_in.copy()
    #    only_checked_in = [True for x in temp if temp is True]
    #    return (len(only_checked_in), len(self.checked_in))
    #
    def _get_dead_processes(self):
        dead_processes = []
        for (protocol_id, p) in self.running_processes.items():
            if not p.is_alive():
                dead_processes.append((protocol_id, p))
            #
        #
        return dead_processes
    #
    def _cleanup_process(self, p, protocol_id, had_error, finished_protocol_cb):
        p.join()
        self.running_processes.pop(protocol_id)
        if finished_protocol_cb is not None:
            finished_protocol_cb(protocol_id, had_error)
        #
    #
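    # _wait() drains the finished-process queue until it is empty or times out.
    # If every remaining process never checked in (see the deadlock note in
    # start_experiment_protocol), it gives up early; processes that died without
    # reporting are cleaned up with had_error=True.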
    def _wait(self, timeout=None, finished_protocol_cb=None):
        return_on_timeout = True if timeout is not None else False
        timeout = timeout if timeout is not None else 10
        last_waiting_message = None
        #
        while len(self.running_processes) > 0:
            dead_processes = self._get_dead_processes()
            #
            while len(self.running_processes) > 0:
                #checked_in_count = self._count_checked_in()
                #not_checked_in = checked_in_count[1]-checked_in_count[0]
                not_checked_in = self._get_not_checked_in()
                #
                if last_waiting_message is None or last_waiting_message != len(self.running_processes):
                    logging.debug('Waiting for processes ({} left, {} not checked in)'.format(len(self.running_processes),
                                  len(not_checked_in)))
                    last_waiting_message = len(self.running_processes)
                #
                if len(self.running_processes) <= len(not_checked_in):
                    running_not_checked_in = [protocol_id for protocol_id in self.running_processes if protocol_id in not_checked_in]
                    if len(self.running_processes) == len(running_not_checked_in):
                        logging.debug('The remaining processes have not checked in, so stopping the wait')
                        return
                    #
                #
                try:
                    (protocol_id, had_error) = self.local_finished_process_queue.get(timeout=timeout)
                    p = self.running_processes[protocol_id]
                    self._cleanup_process(p, protocol_id, had_error, finished_protocol_cb)
                    if (protocol_id, p) in dead_processes:
                        dead_processes.remove((protocol_id, p))
                    #
                    logging.debug('Completed protocol (id: {}, checked_in={})'.format(protocol_id,
                                  self.checked_in[protocol_id]))
                except queue.Empty:
                    if return_on_timeout:
                        return
                    else:
                        break
                    #
                    #if kill_timeout is not None:
                    #    logging.warning('Timed out waiting for processes to finish, will terminate remaining processes')
                    #    kill_remaining = True
                    #
                #
            #
            for (protocol_id, p) in dead_processes:
                # these processes were dead but didn't add themselves to the finished queue
                logging.debug('Found a dead process (id: {})'.format(protocol_id))
                self._cleanup_process(p, protocol_id, True, finished_protocol_cb)
            #
        #
    #
    def wait(self, finished_protocol_cb=None, kill_timeout=None):
        self._wait(kill_timeout, finished_protocol_cb)
        #
        if len(self.running_processes) > 0:
            logging.warning('Timed out ({} seconds) waiting for processes to finish, will terminate remaining processes'.format(kill_timeout))
        #
        while len(self.running_processes) > 0:
            (protocol_id, p) = next(iter(self.running_processes.items()))
            # just get any process and kill it
            was_alive = p.is_alive()
            p.terminate()
            logging.debug('Terminated protocol (id: {}, was_dead={}, checked_in={})'.format(protocol_id,
                          (not was_alive),
                          self.checked_in[protocol_id]))
            #
            self._cleanup_process(p, protocol_id, True, finished_protocol_cb)
        #
    #
    def stop(self):
        self.wait(kill_timeout=1.5)
        self.queue_getter.stop()
        self.queue_getter.join(timeout=10)
        self.stopped = True
    #
#
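# build_client_protocol() opens a socket to the tor SOCKS port, reserves one of
# the controller's pre-built circuits for that socket's source address, and
# wraps everything in an ExperimentProtocol. When send_measureme is set, the
# start callback first sends MEASUREME cells for every hop (last relay first).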
def build_client_protocol(endpoint, socks_address, control_address, controller, start_event, send_measureme, wait_duration=0, measureme_id=None, num_bytes=None, buffer_len=None):
    client_socket = socket.socket()
    #
    logging.debug('Socket %d connecting to proxy %r...', client_socket.fileno(), socks_address)
    client_socket.connect(socks_address)
    logging.debug('Socket %d connected', client_socket.fileno())
    #
    custom_data = {}
    #
    circuit_id = controller.assign_stream(client_socket.getsockname())
    custom_data['circuit'] = (circuit_id, controller.circuits[circuit_id])
    #
    if measureme_id is not None:
        custom_data['measureme_id'] = measureme_id
    #
    if send_measureme:
        assert measureme_id is not None
        hops = list(range(len(controller.circuits[circuit_id])+1))[::-1]
        # send the measureme cells to the last relay first
        start_cb = lambda control_address=control_address, circuit_id=circuit_id, measureme_id=measureme_id, \
                          hops=hops, event=start_event, wait_duration=wait_duration: \
                          send_measureme_cells_and_wait(control_address, circuit_id, measureme_id, hops, event, wait_duration)
    else:
        start_cb = lambda event=start_event, duration=wait_duration: wait_then_sleep(event, duration)
    #
    custom_data = json.dumps(custom_data).encode('utf-8')
    protocol = ExperimentProtocol(client_socket, endpoint, num_bytes,
                                  '{}: {}'.format(circuit_id, controller.circuits[circuit_id]),
                                  custom_data=custom_data,
                                  send_buffer_len=buffer_len,
                                  push_start_cb=start_cb)
    return protocol
#
if __name__ == '__main__':
    import argparse
    #
    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger('stem').setLevel(logging.WARNING)
    #
    parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
    parser.add_argument('ip', type=str, help='destination ip address')
    parser.add_argument('port', type=int, help='destination port')
    parser.add_argument('num_bytes', type=useful.parse_bytes,
                        help='number of bytes to send per connection (can also end with \'B\', \'KiB\', \'MiB\', or \'GiB\')', metavar='num-bytes')
    parser.add_argument('num_streams_per_client', type=int, help='number of streams per Tor client', metavar='num-streams-per-client')
    parser.add_argument('--buffer-len', type=useful.parse_bytes,
                        help='size of the send and receive buffers (can also end with \'B\', \'KiB\', \'MiB\', or \'GiB\')', metavar='bytes')
    parser.add_argument('--wait-range', type=int, default=0,
                        help='add a random wait time to each connection so that they don\'t all start at the same time (default is 0)', metavar='time')
    parser.add_argument('--proxy-control-ports', type=useful.parse_range_list, help='range of ports for the control ports', metavar='control-ports')
    parser.add_argument('--measureme', action='store_true', help='send measureme cells to the exit')
    args = parser.parse_args()
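    # Example invocation (hypothetical values; the exact --proxy-control-ports
    # syntax depends on useful.parse_range_list):
    #   python3 experiment_client.py 127.0.0.1 4433 10MiB 2 --proxy-control-ports 8720-8729 --measureme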
    #
    endpoint = (args.ip, args.port)
    #
    logging.debug('Getting consensus')
    try:
        consensus = stem.descriptor.remote.get_consensus(endpoints=(stem.DirPort('127.0.0.1', 7000),))
    except Exception as e:
        raise Exception('Unable to retrieve the consensus') from e
    #
    fingerprints = get_fingerprints(consensus)
    exit_fingerprints = get_exit_fingerprints(consensus, endpoint)
    non_exit_fingerprints = list(set(fingerprints)-set(exit_fingerprints))
    #
    assert len(exit_fingerprints) == 1, 'Need exactly one exit relay'
    assert len(non_exit_fingerprints) >= 1, 'Need at least one non-exit relay'
    #
    circuit_generator = lambda gen_id=None: [random.choice(non_exit_fingerprints), exit_fingerprints[0]]
    #
    proxy_addresses = []
    for control_port in args.proxy_control_ports:
        proxy = {}
        proxy['control'] = ('127.0.0.1', control_port)
        proxy['socks'] = ('127.0.0.1', get_socks_port(control_port))
        proxy_addresses.append(proxy)
    #
    controllers = []
    protocol_manager = ExperimentProtocolManager()
    #
    try:
        for proxy_address in proxy_addresses:
            controller = ExperimentController(proxy_address['control'])
            controller.connect()
            # the controller has to attach new streams to circuits, so the
            # connection has to stay open until we're done creating streams
            #
            for _ in range(args.num_streams_per_client):
                # make a circuit for each stream
                controller.build_circuit(circuit_generator)
                time.sleep(0.5)
            #
            controllers.append(controller)
        #
        start_event = multiprocessing.Event()
        #
        for stream_index in range(args.num_streams_per_client):
            for (controller_index, proxy_address, controller) in zip(range(len(controllers)), proxy_addresses, controllers):
                if args.measureme:
                    measureme_id = stream_index*args.num_streams_per_client + controller_index + 1
                else:
                    measureme_id = None
                #
                wait_duration = random.randint(0, args.wait_range)
                protocol = build_client_protocol(endpoint, proxy_address['socks'], proxy_address['control'],
                                                 controller, start_event, args.measureme,
                                                 wait_duration=wait_duration, measureme_id=measureme_id,
                                                 num_bytes=args.num_bytes, buffer_len=args.buffer_len)
                protocol_manager.start_experiment_protocol(protocol, protocol_id=None)
            #
        #
        time.sleep(2)
        start_event.set()
        #
        protocol_manager.wait(finished_protocol_cb=lambda protocol_id, had_error: logging.info('Finished {} (had_error={})'.format(protocol_id, had_error)))
    finally:
        for controller in controllers:
            controller.disconnect()
        #
        protocol_manager.stop()
    #
#