bandwidth_tester.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. #!/usr/bin/python3
  2. #
  3. import basic_protocols
  4. import logging
  5. import enum
  6. import time
  7. import socket
  8. #
  9. class ClientConnectionProtocol(basic_protocols.Protocol):
  10. def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None, wait_until=None):
  11. self.endpoint = endpoint
  12. self.data_generator = data_generator
  13. self.total_bytes = total_bytes
  14. self.proxy = proxy
  15. self.username = username
  16. self.wait_until = wait_until
  17. #
  18. self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY SEND_GROUP_ID PUSH_DATA DONE')
  19. self.state = self.states.READY_TO_BEGIN
  20. #
  21. self.socket = socket.socket()
  22. self.sub_protocol = None
  23. self.group_id = int(self.wait_until*1000) if self.wait_until is not None else 0
  24. # a group id of 0 means no group
  25. #
  26. if self.proxy is None:
  27. logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
  28. self.socket.connect(self.endpoint)
  29. else:
  30. logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
  31. self.socket.connect(self.proxy)
  32. #
  33. #
  34. def _run_iteration(self, block=True):
  35. if self.state is self.states.READY_TO_BEGIN:
  36. if self.proxy is None:
  37. group_id_bytes = self.group_id.to_bytes(8, byteorder='big', signed=False)
  38. self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, group_id_bytes)
  39. self.state = self.states.SEND_GROUP_ID
  40. else:
  41. self.sub_protocol = basic_protocols.Socks4Protocol(self.socket, self.endpoint, username=self.username)
  42. self.state = self.states.CONNECT_TO_PROXY
  43. #
  44. #
  45. if self.state is self.states.CONNECT_TO_PROXY:
  46. if self.sub_protocol.run(block=block):
  47. group_id_bytes = self.group_id.to_bytes(8, byteorder='big', signed=False)
  48. self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, group_id_bytes)
  49. self.state = self.states.SEND_GROUP_ID
  50. #logging.debug('Sent group ID.')
  51. #
  52. #
  53. if self.state is self.states.SEND_GROUP_ID:
  54. if block:
  55. time.sleep(self.wait_until-time.time())
  56. #
  57. if time.time() >= self.wait_until and self.sub_protocol.run(block=block):
  58. self.sub_protocol = basic_protocols.PushDataProtocol(self.socket, self.total_bytes,
  59. data_generator=self.data_generator,
  60. send_max_bytes=1024*512)
  61. self.state = self.states.PUSH_DATA
  62. #
  63. #
  64. if self.state is self.states.PUSH_DATA:
  65. if self.sub_protocol.run(block=block):
  66. self.state = self.states.DONE
  67. return True
  68. #
  69. #
  70. return False
  71. #
  72. #
  73. class ServerConnectionProtocol(basic_protocols.Protocol):
  74. def __init__(self, socket, conn_id, group_id_callback=None, bandwidth_callback=None):
  75. self.socket = socket
  76. self.conn_id = conn_id
  77. self.group_id_callback = group_id_callback
  78. self.bandwidth_callback = bandwidth_callback
  79. #
  80. self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN RECV_GROUP_ID PULL_DATA DONE')
  81. self.state = self.states.READY_TO_BEGIN
  82. #
  83. self.sub_protocol = None
  84. #
  85. def _run_iteration(self, block=True):
  86. if self.state is self.states.READY_TO_BEGIN:
  87. self.sub_protocol = basic_protocols.ReceiveDataProtocol(self.socket)
  88. self.state = self.states.RECV_GROUP_ID
  89. #
  90. if self.state is self.states.RECV_GROUP_ID:
  91. if self.sub_protocol.run(block=block):
  92. group_id = int.from_bytes(self.sub_protocol.received_data, byteorder='big', signed=False)
  93. if group_id == 0:
  94. # a group of 0 means no group
  95. group_id = None
  96. #
  97. self.group_id_callback(self.conn_id, group_id)
  98. self.sub_protocol = basic_protocols.PullDataProtocolWithMetrics(self.socket)
  99. self.state = self.states.PULL_DATA
  100. #
  101. #
  102. if self.state is self.states.PULL_DATA:
  103. if self.sub_protocol.run(block=block):
  104. self.state = self.states.DONE
  105. if self.bandwidth_callback:
  106. self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size, self.sub_protocol.calc_transfer_rate())
  107. #
  108. return True
  109. #
  110. #
  111. return False
  112. #
  113. #
  114. if __name__ == '__main__':
  115. import sys
  116. logging.basicConfig(level=logging.DEBUG)
  117. #
  118. if sys.argv[1] == 'client':
  119. import os
  120. #
  121. endpoint = ('127.0.0.1', 4747)
  122. #endpoint = ('127.0.0.1', 8627)
  123. #proxy = ('127.0.0.1', 9003+int(sys.argv[3])-1)
  124. #proxy = ('127.0.0.1', 9003)
  125. proxy = None
  126. username = bytes([x for x in os.urandom(12) if x != 0])
  127. #username = None
  128. data_MB = 500000
  129. #
  130. if len(sys.argv) > 2:
  131. wait_until = int(sys.argv[2])
  132. else:
  133. wait_until = None
  134. #
  135. client = ClientConnectionProtocol(endpoint, data_MB*2**20, proxy=proxy, username=username, wait_until=wait_until)
  136. client.run()
  137. #
  138. elif sys.argv[1] == 'server':
  139. import multiprocessing
  140. import queue
  141. #
  142. endpoint = ('127.0.0.1', 4747)
  143. processes = []
  144. processes_map = {}
  145. joinable_connections = multiprocessing.Queue()
  146. conn_counter = [0]
  147. group_queue = multiprocessing.Queue()
  148. bw_queue = multiprocessing.Queue()
  149. #
  150. def group_id_callback(conn_id, group_id):
  151. # put them in a queue to display later
  152. #logging.debug('For conn %d Received group id: %d', conn_id, group_id)
  153. group_queue.put({'conn_id':conn_id, 'group_id':group_id})
  154. #
  155. def bw_callback(conn_id, data_size, transfer_rate):
  156. # put them in a queue to display later
  157. bw_queue.put({'conn_id':conn_id, 'data_size':data_size, 'transfer_rate':transfer_rate})
  158. #
  159. def start_server_conn(socket, conn_id):
  160. server = ServerConnectionProtocol(socket, conn_id, group_id_callback=group_id_callback, bandwidth_callback=bw_callback)
  161. try:
  162. server.run()
  163. except KeyboardInterrupt:
  164. socket.close()
  165. finally:
  166. joinable_connections.put(conn_id)
  167. #
  168. #
  169. def accept_callback(socket):
  170. conn_id = conn_counter[0]
  171. conn_counter[0] += 1
  172. #logging.debug('Adding connection %d', conn_id)
  173. p = multiprocessing.Process(target=start_server_conn, args=(socket, conn_id))
  174. processes.append(p)
  175. processes_map[conn_id] = p
  176. p.start()
  177. #
  178. l = basic_protocols.ServerListener(endpoint, accept_callback)
  179. #
  180. try:
  181. while True:
  182. l.accept()
  183. try:
  184. while True:
  185. conn_id = joinable_connections.get(False)
  186. p = processes_map[conn_id]
  187. p.join()
  188. #
  189. except queue.Empty:
  190. pass
  191. #
  192. #
  193. except KeyboardInterrupt:
  194. print()
  195. #
  196. bw_values = {}
  197. group_values = {}
  198. #
  199. try:
  200. while True:
  201. bw_val = bw_queue.get(False)
  202. bw_values[bw_val['conn_id']] = bw_val
  203. #
  204. except queue.Empty:
  205. pass
  206. #
  207. try:
  208. while True:
  209. group_val = group_queue.get(False)
  210. group_values[group_val['conn_id']] = group_val
  211. #
  212. except queue.Empty:
  213. pass
  214. #
  215. group_set = set([x['group_id'] for x in group_values.values()])
  216. for group in group_set:
  217. # doesn't handle group == None
  218. conns_in_group = [x[0] for x in group_values.items() if x[1]['group_id'] == group]
  219. in_group = [x for x in bw_values.values() if x['conn_id'] in conns_in_group]
  220. if len(in_group) > 0:
  221. avg_data_size = sum([x['data_size'] for x in in_group])/len(in_group)
  222. avg_transfer_rate = sum([x['transfer_rate'] for x in in_group])/len(in_group)
  223. #
  224. logging.info('Group size: %d', len(in_group))
  225. logging.info('Avg Transferred (MB): %.4f', avg_data_size/(1024**2))
  226. logging.info('Avg Transfer rate (MB/s): %.4f', avg_transfer_rate/(1024**2))
  227. #
  228. #
  229. #
  230. for p in processes:
  231. p.join()
  232. #
  233. #
  234. #