tmp.patch 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. diff --git a/src/basic_protocols.py b/src/basic_protocols.py
  2. index ba8b847..0dcecc8 100755
  3. --- a/src/basic_protocols.py
  4. +++ b/src/basic_protocols.py
  5. @@ -123,26 +123,28 @@ class Socks4Protocol(Protocol):
  6. #
  7. #
  8. class PushDataProtocol(Protocol):
  9. - def __init__(self, socket, total_bytes, data_generator=None, send_max_bytes=1024*512, use_accelerated=True):
  10. - if data_generator is None:
  11. - data_generator = self._default_data_generator
  12. + def __init__(self, socket, total_bytes, send_buffer_len=None, use_acceleration=None):
  13. + if send_buffer_len is None:
  14. + send_buffer_len = 1024*512
  15. + #
  16. + if use_acceleration is None:
  17. + use_acceleration = True
  18. #
  19. self.socket = socket
  20. - self.data_generator = data_generator
  21. self.total_bytes = total_bytes
  22. - self.send_max_bytes = send_max_bytes
  23. - self.use_accelerated = use_accelerated
  24. + self.use_acceleration = use_acceleration
  25. #
  26. self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO PUSH_DATA RECV_CONFIRMATION DONE')
  27. self.state = self.states.READY_TO_BEGIN
  28. #
  29. + self.byte_buffer = os.urandom(send_buffer_len)
  30. self.bytes_written = 0
  31. self.protocol_helper = None
  32. #
  33. def _run_iteration(self, block=True):
  34. if self.state is self.states.READY_TO_BEGIN:
  35. info = self.total_bytes.to_bytes(8, byteorder='big', signed=False)
  36. - info += self.send_max_bytes.to_bytes(8, byteorder='big', signed=False)
  37. + info += len(self.byte_buffer).to_bytes(8, byteorder='big', signed=False)
  38. self.protocol_helper = ProtocolHelper()
  39. self.protocol_helper.set_buffer(info)
  40. self.state = self.states.SEND_INFO
  41. @@ -153,24 +155,28 @@ class PushDataProtocol(Protocol):
  42. #
  43. #
  44. if self.state is self.states.PUSH_DATA:
  45. - max_block_size = self.send_max_bytes
  46. - block_size = min(max_block_size, self.total_bytes-self.bytes_written)
  47. - data = self.data_generator(self.bytes_written, block_size)
  48. - #
  49. - if self.use_accelerated:
  50. + if self.use_acceleration:
  51. if not block:
  52. logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
  53. #
  54. - ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, data)
  55. + ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, self.byte_buffer)
  56. if ret_val < 0:
  57. raise ProtocolException('Error while pushing data.')
  58. #
  59. self.bytes_written = self.total_bytes
  60. else:
  61. + bytes_remaining = self.total_bytes-self.bytes_written
  62. + data_size = min(len(self.byte_buffer), bytes_remaining)
  63. + if data_size != len(self.byte_buffer):
  64. + data = self.byte_buffer[:data_size]
  65. + else:
  66. + data = self.byte_buffer
  67. + # don't make a copy of the byte string each time if we don't need to
  68. + #
  69. n = self.socket.send(data)
  70. self.bytes_written += n
  71. #
  72. - if self.bytes_written >= self.total_bytes:
  73. + if self.bytes_written == self.total_bytes:
  74. # finished sending the data
  75. logging.debug('Finished sending the data (%d bytes).', self.bytes_written)
  76. self.protocol_helper = ProtocolHelper()
  77. @@ -190,20 +196,20 @@ class PushDataProtocol(Protocol):
  78. #
  79. return False
  80. #
  81. - def _default_data_generator(self, index, bytes_needed):
  82. - return os.urandom(bytes_needed)
  83. - #
  84. #
  85. class PullDataProtocol(Protocol):
  86. - def __init__(self, socket, use_accelerated=True):
  87. + def __init__(self, socket, use_acceleration=None):
  88. + if use_acceleration is None:
  89. + use_acceleration = True
  90. + #
  91. self.socket = socket
  92. - self.use_accelerated = use_accelerated
  93. + self.use_acceleration = use_acceleration
  94. #
  95. self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE')
  96. self.state = self.states.READY_TO_BEGIN
  97. #
  98. self.data_size = None
  99. - self.recv_max_bytes = None
  100. + self.recv_buffer_len = None
  101. self.bytes_read = 0
  102. self.protocol_helper = None
  103. self._time_of_first_byte = None
  104. @@ -219,27 +225,28 @@ class PullDataProtocol(Protocol):
  105. if self.protocol_helper.recv(self.socket, info_size):
  106. response = self.protocol_helper.get_buffer()
  107. self.data_size = int.from_bytes(response[0:8], byteorder='big', signed=False)
  108. - self.recv_max_bytes = int.from_bytes(response[8:16], byteorder='big', signed=False)
  109. + self.recv_buffer_len = int.from_bytes(response[8:16], byteorder='big', signed=False)
  110. self.state = self.states.PULL_DATA
  111. #
  112. #
  113. if self.state is self.states.PULL_DATA:
  114. - max_block_size = self.recv_max_bytes
  115. - block_size = min(max_block_size, self.data_size-self.bytes_read)
  116. - #
  117. - if self.use_accelerated:
  118. + if self.use_acceleration:
  119. if not block:
  120. logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
  121. #
  122. - (ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, block_size)
  123. + (ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, self.recv_buffer_len)
  124. if ret_val < 0:
  125. raise ProtocolException('Error while pulling data.')
  126. #
  127. self.bytes_read = self.data_size
  128. self.elapsed_time = elapsed_time
  129. else:
  130. + bytes_remaining = self.data_size-self.bytes_read
  131. + block_size = min(self.recv_buffer_len, bytes_remaining)
  132. + #
  133. data = self.socket.recv(block_size)
  134. self.bytes_read += len(data)
  135. + #
  136. if self.bytes_read != 0 and self._time_of_first_byte is None:
  137. self._time_of_first_byte = time.time()
  138. #
  139. diff --git a/src/throughput_client.py b/src/throughput_client.py
  140. index d45dffe..0be8289 100644
  141. --- a/src/throughput_client.py
  142. +++ b/src/throughput_client.py
  143. @@ -17,6 +17,8 @@ if __name__ == '__main__':
  144. parser.add_argument('--proxy', type=str, help='proxy ip address and port', metavar=('ip','port'), nargs=2)
  145. parser.add_argument('--wait', type=int,
  146. help='wait until the given time before pushing data (time in seconds since epoch)', metavar='time')
  147. + parser.add_argument('--buffer-len', type=useful.parse_bytes, help='size of the send and receive buffers (can also end with \'B\', \'KiB\', \'MiB\', or \'GiB\')', metavar='bytes')
  148. + parser.add_argument('--no-accel', action='store_true', help='don\'t use C acceleration (use pure Python)')
  149. args = parser.parse_args()
  150. #
  151. endpoint = (args.ip, args.port)
  152. @@ -27,7 +29,20 @@ if __name__ == '__main__':
  153. #
  154. username = bytes([x for x in os.urandom(12) if x != 0])
  155. #username = None
  156. + '''
  157. + data_MB = 200 #20000
  158. + data_B = data_MB*2**20
  159. #
  160. - client = throughput_protocols.ClientProtocol(endpoint, args.num_bytes, proxy=proxy, username=username, wait_until=args.wait)
  161. + if len(sys.argv) > 2:
  162. + wait_until = int(sys.argv[2])
  163. + else:
  164. + wait_until = None
  165. + #
  166. + '''
  167. + #
  168. + client = throughput_protocols.ClientProtocol(endpoint, args.num_bytes, proxy=proxy,
  169. + username=username, wait_until=args.wait,
  170. + send_buffer_len=args.buffer_len,
  171. + use_acceleration=(not args.no_accel))
  172. client.run()
  173. #
  174. diff --git a/src/throughput_protocols.py b/src/throughput_protocols.py
  175. index 5dec4b6..3eb3d60 100755
  176. --- a/src/throughput_protocols.py
  177. +++ b/src/throughput_protocols.py
  178. @@ -7,13 +7,14 @@ import time
  179. import socket
  180. #
  181. class ClientProtocol(basic_protocols.Protocol):
  182. - def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None, wait_until=None):
  183. + def __init__(self, endpoint, total_bytes, proxy=None, username=None, wait_until=None, send_buffer_len=None, use_acceleration=None):
  184. self.endpoint = endpoint
  185. - self.data_generator = data_generator
  186. self.total_bytes = total_bytes
  187. self.proxy = proxy
  188. self.username = username
  189. self.wait_until = wait_until
  190. + self.send_buffer_len = send_buffer_len
  191. + self.use_acceleration = use_acceleration
  192. #
  193. self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY SEND_GROUP_ID PUSH_DATA DONE')
  194. self.state = self.states.READY_TO_BEGIN
  195. @@ -47,7 +48,6 @@ class ClientProtocol(basic_protocols.Protocol):
  196. group_id_bytes = self.group_id.to_bytes(8, byteorder='big', signed=False)
  197. self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, group_id_bytes)
  198. self.state = self.states.SEND_GROUP_ID
  199. - #logging.debug('Sent group ID.')
  200. #
  201. #
  202. if self.state is self.states.SEND_GROUP_ID:
  203. @@ -56,8 +56,8 @@ class ClientProtocol(basic_protocols.Protocol):
  204. #
  205. if (self.wait_until is None or time.time() >= self.wait_until) and self.sub_protocol.run(block=block):
  206. self.sub_protocol = basic_protocols.PushDataProtocol(self.socket, self.total_bytes,
  207. - data_generator=self.data_generator,
  208. - send_max_bytes=1024*512)
  209. + send_buffer_len=self.send_buffer_len,
  210. + use_acceleration=self.use_acceleration)
  211. self.state = self.states.PUSH_DATA
  212. #
  213. #
  214. @@ -71,11 +71,12 @@ class ClientProtocol(basic_protocols.Protocol):
  215. #
  216. #
  217. class ServerProtocol(basic_protocols.Protocol):
  218. - def __init__(self, socket, conn_id, group_id_callback=None, bandwidth_callback=None):
  219. + def __init__(self, socket, conn_id, group_id_callback=None, bandwidth_callback=None, use_acceleration=None):
  220. self.socket = socket
  221. self.conn_id = conn_id
  222. self.group_id_callback = group_id_callback
  223. self.bandwidth_callback = bandwidth_callback
  224. + self.use_acceleration = use_acceleration
  225. #
  226. self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN RECV_GROUP_ID PULL_DATA DONE')
  227. self.state = self.states.READY_TO_BEGIN
  228. @@ -95,7 +96,7 @@ class ServerProtocol(basic_protocols.Protocol):
  229. group_id = None
  230. #
  231. self.group_id_callback(self.conn_id, group_id)
  232. - self.sub_protocol = basic_protocols.PullDataProtocol(self.socket)
  233. + self.sub_protocol = basic_protocols.PullDataProtocol(self.socket, use_acceleration=self.use_acceleration)
  234. self.state = self.states.PULL_DATA
  235. #
  236. #
  237. diff --git a/src/throughput_server.py b/src/throughput_server.py
  238. index a22ed8f..0217d14 100644
  239. --- a/src/throughput_server.py
  240. +++ b/src/throughput_server.py
  241. @@ -13,6 +13,7 @@ if __name__ == '__main__':
  242. #
  243. parser = argparse.ArgumentParser(description='Test the network throughput (optionally through a proxy).')
  244. parser.add_argument('port', type=int, help='listen on port')
  245. + parser.add_argument('--no-accel', action='store_true', help='don\'t use C acceleration (use pure Python)')
  246. args = parser.parse_args()
  247. #
  248. endpoint = ('127.0.0.1', args.port)
  249. @@ -34,7 +35,8 @@ if __name__ == '__main__':
  250. bw_queue.put({'conn_id':conn_id, 'data_size':data_size, 'transfer_rate':transfer_rate})
  251. #
  252. def start_server_conn(socket, conn_id):
  253. - server = throughput_protocols.ServerProtocol(socket, conn_id, group_id_callback=group_id_callback, bandwidth_callback=bw_callback)
  254. + server = throughput_protocols.ServerProtocol(socket, conn_id, group_id_callback=group_id_callback,
  255. + bandwidth_callback=bw_callback, use_acceleration=(not args.no_accel))
  256. try:
  257. server.run()
  258. except KeyboardInterrupt: