experiment_client.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. #!/usr/bin/python3
  2. #
  3. import stem.control
  4. import stem.descriptor.remote
  5. import stem.process
  6. import socket
  7. import logging
  8. import multiprocessing
  9. import queue
  10. import random
  11. import time
  12. import json
  13. import os
  14. import datetime
  15. #
  16. import basic_protocols
  17. import throughput_protocols
  18. import useful
  19. #
  20. def get_socks_port(control_port):
  21. with stem.control.Controller.from_port(port=control_port) as controller:
  22. controller.authenticate()
  23. #
  24. socks_addresses = controller.get_listeners(stem.control.Listener.SOCKS)
  25. assert(len(socks_addresses) == 1)
  26. assert(socks_addresses[0][0] == '127.0.0.1')
  27. #
  28. return socks_addresses[0][1]
  29. #
  30. #
  31. def wait_then_sleep(event, duration):
  32. event.wait()
  33. time.sleep(duration)
  34. #
  35. def send_measureme(stem_controller, circuit_id, measureme_id, hop):
  36. response = stem_controller.msg('SENDMEASUREME %s ID=%s HOP=%s' % (circuit_id, measureme_id, hop))
  37. stem.response.convert('SINGLELINE', response)
  38. #
  39. if not response.is_ok():
  40. if response.code in ('512', '552'):
  41. if response.message.startswith('Unknown circuit '):
  42. raise stem.InvalidArguments(response.code, response.message, [circuit_id])
  43. #
  44. raise stem.InvalidRequest(response.code, response.message)
  45. else:
  46. raise stem.ProtocolError('MEASUREME returned unexpected response code: %s' % response.code)
  47. #
  48. #
  49. #
  50. def send_measureme_cells(control_address, circuit_id, measureme_id, hops):
  51. logging.debug('Sending measuremes to control address {}, then sleeping'.format(control_address))
  52. with stem.control.Controller.from_port(address=control_address[0], port=control_address[1]) as controller:
  53. controller.authenticate()
  54. for hop in hops:
  55. send_measureme(controller, circuit_id, measureme_id, hop)
  56. #
  57. #
  58. #
  59. def send_measureme_cells_and_wait(control_port, circuit_id, measureme_id, hops, wait_event, wait_offset):
  60. send_measureme_cells(control_port, circuit_id, measureme_id, hops)
  61. wait_then_sleep(wait_event, wait_offset)
  62. #
  63. def get_fingerprints(consensus):
  64. """
  65. Get the fingerprints of all relays.
  66. """
  67. #
  68. return [desc.fingerprint for desc in consensus]
  69. #
  70. def get_exit_fingerprints(consensus, endpoint):
  71. """
  72. Get the fingerprints of relays that can exit to the endpoint.
  73. """
  74. #
  75. return [desc.fingerprint for desc in consensus if desc.exit_policy.can_exit_to(*endpoint)]
  76. #
  77. class ExperimentController:
  78. def __init__(self, control_address):
  79. self.control_address = control_address
  80. self.connection = None
  81. self.circuits = {}
  82. self.unassigned_circuit_ids = []
  83. self.assigned_streams = {}
  84. #
  85. def connect(self):
  86. self.connection = stem.control.Controller.from_port(address=self.control_address[0], port=self.control_address[1])
  87. self.connection.authenticate()
  88. #
  89. self.connection.add_event_listener(self._attach_stream, stem.control.EventType.STREAM)
  90. self.connection.set_conf('__LeaveStreamsUnattached', '1')
  91. #
  92. def disconnect(self):
  93. #if len(self.unused_circuit_ids) > 0:
  94. # logging.warning('Closed stem controller before all circuits were used')
  95. #
  96. self.connection.close()
  97. #
  98. def assign_stream(self, from_address):
  99. """
  100. Should run this function before starting the protocol, and therefore before telling
  101. the SOCKS proxy where we're connecting to and before the stream is created.
  102. """
  103. circuit_id = self.unassigned_circuit_ids.pop(0)
  104. self.assigned_streams[from_address] = circuit_id
  105. return circuit_id
  106. #
  107. def _attach_stream(self, stream):
  108. try:
  109. if stream.status == 'NEW':
  110. # by default, let tor handle new streams
  111. circuit_id = 0
  112. #
  113. if stream.purpose == 'USER':
  114. # this is probably one of our streams (although not guaranteed)
  115. circuit_id = self.assigned_streams[(stream.source_address, stream.source_port)]
  116. #
  117. try:
  118. self.connection.attach_stream(stream.id, circuit_id)
  119. #logging.debug('Attaching to circuit {}'.format(circuit_id))
  120. except stem.InvalidRequest:
  121. if stream.purpose != 'USER':
  122. # could not attach a non-user stream, ignoring
  123. pass
  124. else:
  125. raise
  126. #
  127. except stem.UnsatisfiableRequest:
  128. if stream.purpose != 'USER':
  129. # could not attach a non-user stream, so probably raised:
  130. # stem.UnsatisfiableRequest: Connection is not managed by controller.
  131. # therefore we should ignore this exception
  132. pass
  133. else:
  134. raise
  135. #
  136. #
  137. #
  138. except:
  139. logging.exception('Error while attaching the stream.')
  140. raise
  141. #
  142. #
  143. def build_circuit(self, circuit_generator, gen_id):
  144. circuit_id = None
  145. #
  146. while circuit_id is None:
  147. try:
  148. circuit = circuit_generator(gen_id)
  149. circuit_id = self.connection.new_circuit(circuit, await_build=True)
  150. logging.debug('New circuit (id={}, controller={}): {}'.format(circuit_id, self.control_address, circuit))
  151. except stem.CircuitExtensionFailed as e:
  152. wait_seconds = 1
  153. logging.debug('Failed circuit: {}'.format(circuit))
  154. logging.warning('Circuit creation failed (CircuitExtensionFailed: {}). Retrying in {} second{}...'.format(str(e),
  155. wait_seconds,
  156. 's' if wait_seconds != 1 else ''))
  157. time.sleep(wait_seconds)
  158. except stem.InvalidRequest as e:
  159. wait_seconds = 15
  160. logging.debug('Failed circuit: {}'.format(circuit))
  161. logging.warning('Circuit creation failed (InvalidRequest: {}). Retrying in {} second{}...'.format(str(e),
  162. wait_seconds,
  163. 's' if wait_seconds != 1 else ''))
  164. time.sleep(wait_seconds)
  165. #
  166. #
  167. self.unassigned_circuit_ids.append(circuit_id)
  168. self.circuits[circuit_id] = circuit
  169. #
  170. #
  171. class ExperimentProtocol(basic_protocols.ChainedProtocol):
  172. def __init__(self, socket, endpoint, num_bytes, custom_data=None, send_buffer_len=None, push_start_cb=None):
  173. proxy_username = bytes([z for z in os.urandom(12) if z != 0])
  174. proxy_protocol = basic_protocols.Socks4Protocol(socket, endpoint, username=proxy_username)
  175. #
  176. throughput_protocol = throughput_protocols.ClientProtocol(socket, num_bytes,
  177. custom_data=custom_data,
  178. send_buffer_len=send_buffer_len,
  179. use_acceleration=True,
  180. push_start_cb=push_start_cb)
  181. #
  182. super().__init__([proxy_protocol, throughput_protocol])
  183. #
  184. #
  185. class ExperimentProtocolManager():
  186. def __init__(self):
  187. self.stopped = False
  188. self.process_counter = 0
  189. self.used_ids = []
  190. self.running_processes = {}
  191. self.global_finished_process_queue = multiprocessing.Queue()
  192. self.local_finished_process_queue = queue.Queue()
  193. self.queue_getter = useful.QueueGetter(self.global_finished_process_queue,
  194. self.local_finished_process_queue.put)
  195. #
  196. def _run_client(self, protocol, protocol_id):
  197. had_error = False
  198. try:
  199. logging.debug('Starting client protocol (id: {})'.format(protocol_id))
  200. protocol.run()
  201. logging.debug('Done client protocol (id: {})'.format(protocol_id))
  202. except KeyboardInterrupt:
  203. had_error = True
  204. logging.info('Client protocol id: {} stopped (KeyboardInterrupt)'.format(protocol_id))
  205. except:
  206. had_error = True
  207. logging.warning('Client protocol error')
  208. logging.exception('Client protocol id: {} had an error ({})'.format(protocol_id, datetime.datetime.now().time()))
  209. finally:
  210. self.global_finished_process_queue.put((protocol_id, had_error))
  211. if had_error:
  212. logging.warning('Client protocol with error successfully added self to global queue')
  213. #
  214. #
  215. #
  216. def start_experiment_protocol(self, protocol, protocol_id=None):
  217. if protocol_id is None:
  218. protocol_id = self.process_counter
  219. #
  220. assert not self.stopped
  221. assert protocol_id not in self.used_ids, 'Protocol ID already used'
  222. #
  223. p = multiprocessing.Process(target=self._run_client, args=(protocol, protocol_id))
  224. self.running_processes[protocol_id] = p
  225. self.used_ids.append(protocol_id)
  226. #
  227. p.start()
  228. self.process_counter += 1
  229. #
  230. #protocol.socket.close()
  231. #
  232. def wait(self, finished_protocol_cb=None, kill_timeout=None):
  233. timed_out = False
  234. #
  235. while len(self.running_processes) > 0:
  236. logging.debug('Waiting for processes ({} left)'.format(len(self.running_processes)))
  237. #
  238. if not timed_out:
  239. try:
  240. (protocol_id, had_error) = self.local_finished_process_queue.get(timeout=kill_timeout)
  241. p = self.running_processes[protocol_id]
  242. except queue.Empty:
  243. if kill_timeout is None:
  244. raise
  245. #
  246. logging.warning('Timed out waiting for processes to finish, will terminate remaining processes')
  247. timed_out = True
  248. #
  249. #
  250. if timed_out:
  251. (protocol_id, p) = next(iter(self.running_processes.items()))
  252. # just get any process and kill it
  253. had_error = True
  254. p.terminate()
  255. logging.debug('Terminated protocol {}'.format(protocol_id))
  256. #
  257. p.join()
  258. self.running_processes.pop(protocol_id)
  259. if finished_protocol_cb is not None:
  260. finished_protocol_cb(protocol_id, had_error)
  261. #
  262. #
  263. #
  264. def stop(self):
  265. self.wait(kill_timeout=1.5)
  266. self.queue_getter.stop()
  267. self.queue_getter.join()
  268. self.stopped = True
  269. #
  270. #
  271. def build_client_protocol(endpoint, socks_address, control_address, controller, start_event, wait_duration=0, measureme_id=None, num_bytes=None, buffer_len=None):
  272. client_socket = socket.socket()
  273. #
  274. logging.debug('Socket %d connecting to proxy %r...', client_socket.fileno(), socks_address)
  275. client_socket.connect(socks_address)
  276. logging.debug('Socket %d connected', client_socket.fileno())
  277. #
  278. custom_data = {}
  279. #
  280. circuit_id = controller.assign_stream(client_socket.getsockname())
  281. custom_data['circuit'] = (circuit_id, controller.circuits[circuit_id])
  282. #
  283. if measureme_id is not None:
  284. custom_data['measureme_id'] = measureme_id
  285. #
  286. hops = list(range(len(controller.circuits[circuit_id])+1))[::-1]
  287. # send the measureme cells to the last relay first
  288. start_cb = lambda control_address=control_address, circuit_id=circuit_id, measureme_id=measureme_id, \
  289. hops=hops, event=start_event, wait_duration=wait_duration: \
  290. send_measureme_cells_and_wait(control_address, circuit_id, measureme_id, hops, event, wait_duration)
  291. else:
  292. start_cb = lambda event=start_event, duration=wait_duration: wait_then_sleep(event, duration)
  293. #
  294. custom_data = json.dumps(custom_data).encode('utf-8')
  295. protocol = ExperimentProtocol(client_socket, endpoint, num_bytes,
  296. custom_data=custom_data,
  297. send_buffer_len=buffer_len,
  298. push_start_cb=start_cb)
  299. return protocol
  300. #
  301. if __name__ == '__main__':
  302. import argparse
  303. #
  304. logging.basicConfig(level=logging.DEBUG)
  305. logging.getLogger('stem').setLevel(logging.WARNING)
  306. #
  307. parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
  308. parser.add_argument('ip', type=str, help='destination ip address')
  309. parser.add_argument('port', type=int, help='destination port')
  310. parser.add_argument('num_bytes', type=useful.parse_bytes,
  311. help='number of bytes to send per connection (can also end with \'B\', \'KiB\', \'MiB\', or \'GiB\')', metavar='num-bytes')
  312. parser.add_argument('num_streams_per_client', type=int, help='number of streams per Tor client', metavar='num-streams-per-client')
  313. parser.add_argument('--buffer-len', type=useful.parse_bytes,
  314. help='size of the send and receive buffers (can also end with \'B\', \'KiB\', \'MiB\', or \'GiB\')', metavar='bytes')
  315. parser.add_argument('--wait-range', type=int, default=0,
  316. help='add a random wait time to each connection so that they don\'t all start at the same time (default is 0)', metavar='time')
  317. parser.add_argument('--proxy-control-ports', type=useful.parse_range_list, help='range of ports for the control ports', metavar='control-ports')
  318. parser.add_argument('--measureme', action='store_true', help='send measureme cells to the exit')
  319. args = parser.parse_args()
  320. #
  321. endpoint = (args.ip, args.port)
  322. #
  323. logging.debug('Getting consensus')
  324. try:
  325. consensus = stem.descriptor.remote.get_consensus(endpoints=(stem.DirPort('127.0.0.1', 7000),))
  326. except Exception as e:
  327. raise Exception('Unable to retrieve the consensus') from e
  328. #
  329. fingerprints = get_fingerprints(consensus)
  330. exit_fingerprints = get_exit_fingerprints(consensus, endpoint)
  331. non_exit_fingerprints = list(set(fingerprints)-set(exit_fingerprints))
  332. #
  333. assert len(exit_fingerprints) == 1, 'Need exactly one exit relay'
  334. assert len(non_exit_fingerprints) >= 1, 'Need at least one non-exit relay'
  335. #
  336. circuit_generator = lambda gen_id=None: [random.choice(non_exit_fingerprints), exit_fingerprints[0]]
  337. #
  338. proxy_addresses = []
  339. for control_port in args.proxy_control_ports:
  340. proxy = {}
  341. proxy['control'] = ('127.0.0.1', control_port)
  342. proxy['socks'] = ('127.0.0.1', get_socks_port(control_port))
  343. proxy_addresses.append(proxy)
  344. #
  345. controllers = []
  346. protocol_manager = ExperimentProtocolManager()
  347. #
  348. try:
  349. for proxy_address in proxy_addresses:
  350. controller = ExperimentController(proxy_address['control'])
  351. controller.connect()
  352. # the controller has to attach new streams to circuits, so the
  353. # connection has to stay open until we're done creating streams
  354. #
  355. for _ in range(args.num_streams_per_client):
  356. # make a circuit for each stream
  357. controller.build_circuit(circuit_generator)
  358. time.sleep(0.5)
  359. #
  360. controllers.append(controller)
  361. #
  362. start_event = multiprocessing.Event()
  363. #
  364. for stream_index in range(args.num_streams_per_client):
  365. for (controller_index, proxy_address, controller) in zip(range(len(controllers)), proxy_addresses, controllers):
  366. if args.measureme:
  367. measureme_id = stream_index*args.num_streams_per_client + controller_index + 1
  368. else:
  369. measureme_id = None
  370. #
  371. wait_duration = random.randint(0, args.wait_range)
  372. protocol = build_client_protocol(endpoint, proxy_address['socks'], proxy_address['control'], controller, start_event,
  373. wait_duration=wait_duration, measureme_id=measureme_id,
  374. num_bytes=args.num_bytes, buffer_len=args.buffer_len)
  375. protocol_manager.start_experiment_protocol(protocol, protocol_id=None)
  376. #
  377. #
  378. time.sleep(2)
  379. start_event.set()
  380. #
  381. protocol_manager.wait(finished_protocol_cb=lambda protocol_id,had_error: logging.info('Finished {} (had_error={})'.format(protocol_id,had_error)))
  382. finally:
  383. for controller in controllers:
  384. controller.disconnect()
  385. #
  386. protocol_manager.stop()
  387. #
  388. #