correctness_tester.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. #!/usr/bin/python3
  2. #
  3. import basic_protocols
  4. import logging
  5. import enum
  6. import time
  7. import socket
  8. #
  9. class ClientConnectionProtocol(basic_protocols.Protocol):
  10. def __init__(self, endpoint, data, proxy=None, username=None):
  11. self.endpoint = endpoint
  12. self.data = data
  13. self.proxy = proxy
  14. self.username = username
  15. #
  16. self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY SEND_DATA DONE')
  17. self.state = self.states.READY_TO_BEGIN
  18. #
  19. self.socket = socket.socket()
  20. self.sub_protocol = None
  21. #
  22. if self.proxy is None:
  23. logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
  24. self.socket.connect(self.endpoint)
  25. else:
  26. logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
  27. self.socket.connect(self.proxy)
  28. #
  29. #
  30. def _run_iteration(self, block=True):
  31. if self.state is self.states.READY_TO_BEGIN:
  32. if self.proxy is None:
  33. self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, self.data)
  34. self.state = self.states.SEND_DATA
  35. else:
  36. #self.sub_protocol = basic_protocols.Socks4Protocol(self.socket, self.endpoint, username=self.username)
  37. self.sub_protocol = basic_protocols.WeirdProtocol(self.socket, self.endpoint)
  38. self.state = self.states.CONNECT_TO_PROXY
  39. #
  40. #
  41. if self.state is self.states.CONNECT_TO_PROXY:
  42. if self.sub_protocol.run(block=block):
  43. self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, self.data)
  44. self.state = self.states.SEND_DATA
  45. #
  46. #
  47. if self.state is self.states.SEND_DATA:
  48. if self.sub_protocol.run(block=block):
  49. self.state = self.states.DONE
  50. return True
  51. #
  52. #
  53. return False
  54. #
  55. #
  56. class ServerConnectionProtocol(basic_protocols.Protocol):
  57. def __init__(self, socket, conn_id, data_callback=None):
  58. self.socket = socket
  59. self.conn_id = conn_id
  60. self.data_callback = data_callback
  61. #
  62. self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN RECV_DATA DONE')
  63. self.state = self.states.READY_TO_BEGIN
  64. #
  65. self.sub_protocol = None
  66. #
  67. def _run_iteration(self, block=True):
  68. if self.state is self.states.READY_TO_BEGIN:
  69. self.sub_protocol = basic_protocols.ReceiveDataProtocol(self.socket)
  70. self.state = self.states.RECV_DATA
  71. #
  72. if self.state is self.states.RECV_DATA:
  73. if self.sub_protocol.run(block=block):
  74. self.data_callback(self.conn_id, self.sub_protocol.received_data)
  75. self.state = self.states.DONE
  76. return True
  77. #
  78. #
  79. return False
  80. #
  81. #
  82. if __name__ == '__main__':
  83. import sys
  84. logging.basicConfig(level=logging.DEBUG)
  85. #
  86. import random
  87. random.seed(10)
  88. data_to_send = bytearray(random.getrandbits(8) for _ in range(1024*1024*100))
  89. #
  90. print('Generated bytes')
  91. #
  92. if sys.argv[1] == 'client':
  93. import os
  94. #
  95. endpoint = ('127.0.0.1', 4747)
  96. #endpoint = ('127.0.0.1', 8627)
  97. #proxy = ('127.0.0.1', 9003+int(sys.argv[3])-1)
  98. #proxy = ('127.0.0.1', 9003)
  99. proxy = ('127.0.0.1', 12849)
  100. #proxy = None
  101. username = bytes([x for x in os.urandom(12) if x != 0])
  102. #username = None
  103. #
  104. client = ClientConnectionProtocol(endpoint, data_to_send, proxy=proxy, username=username)
  105. client.run()
  106. #
  107. elif sys.argv[1] == 'server':
  108. import multiprocessing
  109. import queue
  110. #
  111. endpoint = ('127.0.0.1', 4747)
  112. processes = []
  113. processes_map = {}
  114. joinable_connections = multiprocessing.Queue()
  115. conn_counter = [0]
  116. group_queue = multiprocessing.Queue()
  117. bw_queue = multiprocessing.Queue()
  118. #
  119. def data_callback(conn_id, data):
  120. # check data here
  121. print('Received {} MB'.format(len(data)/(1024**2)))
  122. print('Data matches: {}'.format(data==data_to_send))
  123. #
  124. def start_server_conn(socket, conn_id):
  125. server = ServerConnectionProtocol(socket, conn_id, data_callback=data_callback)
  126. try:
  127. server.run()
  128. except KeyboardInterrupt:
  129. socket.close()
  130. finally:
  131. joinable_connections.put(conn_id)
  132. #
  133. #
  134. def accept_callback(socket):
  135. conn_id = conn_counter[0]
  136. conn_counter[0] += 1
  137. #logging.debug('Adding connection %d', conn_id)
  138. p = multiprocessing.Process(target=start_server_conn, args=(socket, conn_id))
  139. processes.append(p)
  140. processes_map[conn_id] = p
  141. p.start()
  142. socket.close()
  143. # close this process' copy of the socket
  144. #
  145. l = basic_protocols.ServerListener(endpoint, accept_callback)
  146. #
  147. try:
  148. while True:
  149. l.accept()
  150. try:
  151. while True:
  152. conn_id = joinable_connections.get(False)
  153. p = processes_map[conn_id]
  154. p.join()
  155. #
  156. except queue.Empty:
  157. pass
  158. #
  159. #
  160. except KeyboardInterrupt:
  161. print()
  162. #
  163. for p in processes:
  164. p.join()
  165. #
  166. #
  167. #