basic_protocols.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. #!/usr/bin/python3
  2. #
  3. import socket
  4. import struct
  5. import logging
  6. import time
  7. import enum
  8. import select
  9. import os
  10. #
  11. import accelerated_functions
  12. #
  13. class ProtocolException(Exception):
  14. pass
  15. #
  16. class ProtocolHelper():
  17. def __init__(self):
  18. self._buffer = b''
  19. #
  20. def set_buffer(self, data):
  21. self._buffer = data
  22. #
  23. def get_buffer(self):
  24. return self._buffer
  25. #
  26. def recv(self, socket, num_bytes):
  27. data = socket.recv(num_bytes-len(self._buffer))
  28. self._buffer += data
  29. if len(self._buffer) == num_bytes:
  30. return True
  31. #
  32. return False
  33. #
  34. def send(self, socket):
  35. n = socket.send(self._buffer)
  36. self._buffer = self._buffer[n:]
  37. if len(self._buffer) == 0:
  38. return True
  39. #
  40. return False
  41. #
  42. #
  43. class Protocol():
  44. def _run_iteration(self, block=True):
  45. pass
  46. #
  47. def run(self, block=True):
  48. while True:
  49. finished = self._run_iteration(block=block)
  50. #
  51. if finished:
  52. # protocol is done
  53. return True
  54. elif not block:
  55. # not done the protocol yet, but don't block
  56. return False
  57. #
  58. #
  59. #
  60. #
  61. class ChainedProtocol(Protocol):
  62. def __init__(self, protocols):
  63. self.protocols = protocols
  64. self.current_protocol = 0
  65. #
  66. self.states = enum.Enum('CHAIN_STATES', 'READY_TO_BEGIN RUNNING DONE')
  67. self.state = self.states.READY_TO_BEGIN
  68. #
  69. def _run_iteration(self, block=True):
  70. if self.state is self.states.READY_TO_BEGIN:
  71. self.state = self.states.RUNNING
  72. #
  73. if self.state is self.states.RUNNING:
  74. if self.protocols[self.current_protocol] is None or self.protocols[self.current_protocol].run(block=block):
  75. self.current_protocol += 1
  76. #
  77. if self.current_protocol >= len(self.protocols):
  78. self.state = self.states.DONE
  79. #
  80. #
  81. if self.state is self.states.DONE:
  82. return True
  83. #
  84. return False
  85. #
  86. #
  87. class Socks4Protocol(Protocol):
  88. def __init__(self, socket, addr_port, username=None):
  89. self.socket = socket
  90. self.addr_port = addr_port
  91. self.username = username
  92. #
  93. self.states = enum.Enum('SOCKS_4_STATES', 'READY_TO_BEGIN CONNECTING_TO_PROXY WAITING_FOR_PROXY DONE')
  94. self.state = self.states.READY_TO_BEGIN
  95. #
  96. self.protocol_helper = None
  97. #
  98. def _run_iteration(self, block=True):
  99. if self.state is self.states.READY_TO_BEGIN:
  100. self.protocol_helper = ProtocolHelper()
  101. self.protocol_helper.set_buffer(self.socks_cmd(self.addr_port, self.username))
  102. self.state = self.states.CONNECTING_TO_PROXY
  103. #
  104. if self.state is self.states.CONNECTING_TO_PROXY:
  105. if self.protocol_helper.send(self.socket):
  106. self.protocol_helper = ProtocolHelper()
  107. self.state = self.states.WAITING_FOR_PROXY
  108. #logging.debug('Waiting for reply from proxy')
  109. #
  110. #
  111. if self.state is self.states.WAITING_FOR_PROXY:
  112. response_size = 8
  113. if self.protocol_helper.recv(self.socket, response_size):
  114. response = self.protocol_helper.get_buffer()
  115. if response[1] != 0x5a:
  116. raise ProtocolException('Could not connect to SOCKS proxy, msg: %x'%(response[1],))
  117. #
  118. self.state = self.states.DONE
  119. #
  120. #
  121. if self.state is self.states.DONE:
  122. return True
  123. #
  124. return False
  125. #
  126. def socks_cmd(self, addr_port, username=None):
  127. socks_version = 4
  128. command = 1
  129. dnsname = b''
  130. host, port = addr_port
  131. #
  132. try:
  133. username = bytes(username, 'utf8')
  134. except TypeError:
  135. pass
  136. #
  137. if username is None:
  138. username = b''
  139. elif b'\x00' in username:
  140. raise ProtocolException('Username cannot contain a NUL character.')
  141. #
  142. username = username+b'\x00'
  143. #
  144. try:
  145. addr = socket.inet_aton(host)
  146. except socket.error:
  147. addr = b'\x00\x00\x00\x01'
  148. dnsname = bytes(host, 'utf8')+b'\x00'
  149. #
  150. return struct.pack('!BBH', socks_version, command, port) + addr + username + dnsname
  151. #
  152. #
  153. class PushDataProtocol(Protocol):
  154. def __init__(self, socket, total_bytes, send_buffer_len=None, use_acceleration=None):
  155. if send_buffer_len is None:
  156. send_buffer_len = 1024*512
  157. #
  158. if use_acceleration is None:
  159. use_acceleration = True
  160. #
  161. self.socket = socket
  162. self.total_bytes = total_bytes
  163. self.use_acceleration = use_acceleration
  164. #
  165. self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO PUSH_DATA RECV_CONFIRMATION DONE')
  166. self.state = self.states.READY_TO_BEGIN
  167. #
  168. self.byte_buffer = os.urandom(send_buffer_len)
  169. self.bytes_written = 0
  170. self.protocol_helper = None
  171. #
  172. def _run_iteration(self, block=True):
  173. if self.state is self.states.READY_TO_BEGIN:
  174. info = self.total_bytes.to_bytes(8, byteorder='big', signed=False)
  175. info += len(self.byte_buffer).to_bytes(8, byteorder='big', signed=False)
  176. self.protocol_helper = ProtocolHelper()
  177. self.protocol_helper.set_buffer(info)
  178. self.state = self.states.SEND_INFO
  179. #
  180. if self.state is self.states.SEND_INFO:
  181. if self.protocol_helper.send(self.socket):
  182. self.state = self.states.PUSH_DATA
  183. #
  184. #
  185. if self.state is self.states.PUSH_DATA:
  186. if self.use_acceleration:
  187. if not block:
  188. logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
  189. #
  190. ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, self.byte_buffer)
  191. if ret_val < 0:
  192. raise ProtocolException('Error while pushing data.')
  193. #
  194. self.bytes_written = self.total_bytes
  195. else:
  196. bytes_remaining = self.total_bytes-self.bytes_written
  197. data_size = min(len(self.byte_buffer), bytes_remaining)
  198. if data_size != len(self.byte_buffer):
  199. data = self.byte_buffer[:data_size]
  200. else:
  201. data = self.byte_buffer
  202. # don't make a copy of the byte string each time if we don't need to
  203. #
  204. n = self.socket.send(data)
  205. self.bytes_written += n
  206. #
  207. if self.bytes_written == self.total_bytes:
  208. # finished sending the data
  209. logging.debug('Finished sending the data (%d bytes).', self.bytes_written)
  210. self.protocol_helper = ProtocolHelper()
  211. self.state = self.states.RECV_CONFIRMATION
  212. #
  213. #
  214. if self.state is self.states.RECV_CONFIRMATION:
  215. response_size = 8
  216. if self.protocol_helper.recv(self.socket, response_size):
  217. response = self.protocol_helper.get_buffer()
  218. if response != b'RECEIVED':
  219. raise ProtocolException('Did not receive the expected message: {}'.format(response))
  220. #
  221. self.state = self.states.DONE
  222. #
  223. #
  224. if self.state is self.states.DONE:
  225. return True
  226. #
  227. return False
  228. #
  229. #
  230. class PullDataProtocol(Protocol):
  231. def __init__(self, socket, use_acceleration=None):
  232. if use_acceleration is None:
  233. use_acceleration = True
  234. #
  235. self.socket = socket
  236. self.use_acceleration = use_acceleration
  237. #
  238. self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE')
  239. self.state = self.states.READY_TO_BEGIN
  240. #
  241. self.data_size = None
  242. self.recv_buffer_len = None
  243. self.bytes_read = 0
  244. self.protocol_helper = None
  245. self._time_of_first_byte = None
  246. self.elapsed_time = None
  247. #
  248. def _run_iteration(self, block=True):
  249. if self.state is self.states.READY_TO_BEGIN:
  250. self.protocol_helper = ProtocolHelper()
  251. self.state = self.states.RECV_INFO
  252. #
  253. if self.state is self.states.RECV_INFO:
  254. info_size = 16
  255. if self.protocol_helper.recv(self.socket, info_size):
  256. response = self.protocol_helper.get_buffer()
  257. self.data_size = int.from_bytes(response[0:8], byteorder='big', signed=False)
  258. self.recv_buffer_len = int.from_bytes(response[8:16], byteorder='big', signed=False)
  259. self.state = self.states.PULL_DATA
  260. #
  261. #
  262. if self.state is self.states.PULL_DATA:
  263. if self.use_acceleration:
  264. if not block:
  265. logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
  266. #
  267. (ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, self.recv_buffer_len)
  268. if ret_val < 0:
  269. raise ProtocolException('Error while pulling data.')
  270. #
  271. self.bytes_read = self.data_size
  272. self.elapsed_time = elapsed_time
  273. else:
  274. bytes_remaining = self.data_size-self.bytes_read
  275. block_size = min(self.recv_buffer_len, bytes_remaining)
  276. #
  277. data = self.socket.recv(block_size)
  278. self.bytes_read += len(data)
  279. #
  280. if self.bytes_read != 0 and self._time_of_first_byte is None:
  281. self._time_of_first_byte = time.time()
  282. #
  283. if self.bytes_read == self.data_size and self.elapsed_time is None:
  284. self.elapsed_time = time.time()-self._time_of_first_byte
  285. #
  286. #
  287. if self.bytes_read == self.data_size:
  288. # finished receiving the data
  289. logging.debug('Finished receiving the data.')
  290. self.protocol_helper = ProtocolHelper()
  291. self.protocol_helper.set_buffer(b'RECEIVED')
  292. self.state = self.states.SEND_CONFIRMATION
  293. #
  294. #
  295. if self.state is self.states.SEND_CONFIRMATION:
  296. if self.protocol_helper.send(self.socket):
  297. self.state = self.states.DONE
  298. #
  299. #
  300. if self.state is self.states.DONE:
  301. return True
  302. #
  303. return False
  304. #
  305. def calc_transfer_rate(self):
  306. """ Returns bytes/s. """
  307. assert self.data_size is not None and self.elapsed_time is not None
  308. return self.data_size/self.elapsed_time
  309. #
  310. #
  311. class SendDataProtocol(Protocol):
  312. def __init__(self, socket, data):
  313. self.socket = socket
  314. self.send_data = data
  315. #
  316. self.states = enum.Enum('SEND_DATA_STATES', 'READY_TO_BEGIN SEND_INFO SEND_DATA RECV_CONFIRMATION DONE')
  317. self.state = self.states.READY_TO_BEGIN
  318. #
  319. self.protocol_helper = None
  320. #
  321. def _run_iteration(self, block=True):
  322. if self.state is self.states.READY_TO_BEGIN:
  323. info_size = 20
  324. info = len(self.send_data).to_bytes(info_size, byteorder='big', signed=False)
  325. self.protocol_helper = ProtocolHelper()
  326. self.protocol_helper.set_buffer(info)
  327. self.state = self.states.SEND_INFO
  328. #
  329. if self.state is self.states.SEND_INFO:
  330. if self.protocol_helper.send(self.socket):
  331. self.protocol_helper = ProtocolHelper()
  332. self.protocol_helper.set_buffer(self.send_data)
  333. self.state = self.states.SEND_DATA
  334. #
  335. #
  336. if self.state is self.states.SEND_DATA:
  337. if self.protocol_helper.send(self.socket):
  338. self.protocol_helper = ProtocolHelper()
  339. self.state = self.states.RECV_CONFIRMATION
  340. #
  341. #
  342. if self.state is self.states.RECV_CONFIRMATION:
  343. response_size = 8
  344. if self.protocol_helper.recv(self.socket, response_size):
  345. response = self.protocol_helper.get_buffer()
  346. if response != b'RECEIVED':
  347. raise ProtocolException('Did not receive the expected message: {}'.format(response))
  348. #
  349. self.state = self.states.DONE
  350. #
  351. #
  352. if self.state is self.states.DONE:
  353. return True
  354. #
  355. return False
  356. #
  357. #
  358. class ReceiveDataProtocol(Protocol):
  359. def __init__(self, socket):
  360. self.socket = socket
  361. #
  362. self.states = enum.Enum('RECV_DATA_STATES', 'READY_TO_BEGIN RECV_INFO RECV_DATA SEND_CONFIRMATION DONE')
  363. self.state = self.states.READY_TO_BEGIN
  364. #
  365. self.protocol_helper = None
  366. self.data_size = None
  367. self.received_data = None
  368. #
  369. def _run_iteration(self, block=True):
  370. if self.state is self.states.READY_TO_BEGIN:
  371. self.protocol_helper = ProtocolHelper()
  372. self.state = self.states.RECV_INFO
  373. #
  374. if self.state is self.states.RECV_INFO:
  375. info_size = 20
  376. if self.protocol_helper.recv(self.socket, info_size):
  377. response = self.protocol_helper.get_buffer()
  378. self.data_size = int.from_bytes(response, byteorder='big', signed=False)
  379. self.protocol_helper = ProtocolHelper()
  380. self.state = self.states.RECV_DATA
  381. #
  382. #
  383. if self.state is self.states.RECV_DATA:
  384. if self.protocol_helper.recv(self.socket, self.data_size):
  385. response = self.protocol_helper.get_buffer()
  386. self.received_data = response
  387. self.protocol_helper = ProtocolHelper()
  388. self.protocol_helper.set_buffer(b'RECEIVED')
  389. self.state = self.states.SEND_CONFIRMATION
  390. #
  391. #
  392. if self.state is self.states.SEND_CONFIRMATION:
  393. if self.protocol_helper.send(self.socket):
  394. self.state = self.states.DONE
  395. #
  396. #
  397. if self.state is self.states.DONE:
  398. return True
  399. #
  400. return False
  401. #
  402. #
  403. class ServerListener():
  404. "A TCP listener, binding, listening and accepting new connections."
  405. def __init__(self, endpoint, accept_callback):
  406. self.callback = accept_callback
  407. #
  408. self.s = socket.socket()
  409. self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  410. self.s.bind(endpoint)
  411. self.s.listen(0)
  412. #
  413. def accept(self):
  414. newsock, endpoint = self.s.accept()
  415. logging.debug("New client from %s:%d (fd=%d)",
  416. endpoint[0], endpoint[1], newsock.fileno())
  417. self.callback(newsock)
  418. #
  419. #
  420. class SimpleClientConnectionProtocol(Protocol):
  421. def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None):
  422. self.endpoint = endpoint
  423. self.data_generator = data_generator
  424. self.total_bytes = total_bytes
  425. self.proxy = proxy
  426. self.username = username
  427. #
  428. self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY PUSH_DATA DONE')
  429. self.state = self.states.READY_TO_BEGIN
  430. #
  431. self.socket = socket.socket()
  432. self.sub_protocol = None
  433. #
  434. if self.proxy is None:
  435. logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
  436. self.socket.connect(self.endpoint)
  437. else:
  438. logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
  439. self.socket.connect(self.proxy)
  440. #
  441. #
  442. def _run_iteration(self, block=True):
  443. if self.state is self.states.READY_TO_BEGIN:
  444. if self.proxy is None:
  445. self.sub_protocol = PushDataProtocol(self.socket, self.total_bytes, self.data_generator)
  446. self.state = self.states.PUSH_DATA
  447. else:
  448. self.sub_protocol = Socks4Protocol(self.socket, self.endpoint, username=self.username)
  449. self.state = self.states.CONNECT_TO_PROXY
  450. #
  451. #
  452. if self.state is self.states.CONNECT_TO_PROXY:
  453. if self.sub_protocol.run(block=block):
  454. self.sub_protocol = PushDataProtocol(self.socket, self.total_bytes, self.data_generator)
  455. self.state = self.states.PUSH_DATA
  456. #
  457. #
  458. if self.state is self.states.PUSH_DATA:
  459. if self.sub_protocol.run(block=block):
  460. self.state = self.states.DONE
  461. #
  462. #
  463. if self.state is self.states.DONE:
  464. return True
  465. #
  466. return False
  467. #
  468. #
  469. class SimpleServerConnectionProtocol(Protocol):
  470. def __init__(self, socket, conn_id, bandwidth_callback=None):
  471. self.socket = socket
  472. self.conn_id = conn_id
  473. self.bandwidth_callback = bandwidth_callback
  474. #
  475. self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN PULL_DATA DONE')
  476. self.state = self.states.READY_TO_BEGIN
  477. #
  478. self.sub_protocol = None
  479. #
  480. def _run_iteration(self, block=True):
  481. if self.state is self.states.READY_TO_BEGIN:
  482. self.sub_protocol = PullDataProtocol(self.socket)
  483. self.state = self.states.PULL_DATA
  484. #
  485. if self.state is self.states.PULL_DATA:
  486. if self.sub_protocol.run(block=block):
  487. if self.bandwidth_callback:
  488. self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size, self.sub_protocol.calc_transfer_rate())
  489. #
  490. self.state = self.states.DONE
  491. #
  492. #
  493. if self.state is self.states.DONE:
  494. return True
  495. #
  496. return False
  497. #
  498. #
  499. if __name__ == '__main__':
  500. import sys
  501. logging.basicConfig(level=logging.DEBUG)
  502. #
  503. if sys.argv[1] == 'client':
  504. endpoint = ('127.0.0.1', 4747)
  505. proxy = ('127.0.0.1', 9003)
  506. #proxy = None
  507. username = bytes([x for x in os.urandom(12) if x != 0])
  508. #username = None
  509. data_MB = 40
  510. #
  511. client = SimpleClientConnectionProtocol(endpoint, data_MB*2**20, proxy=proxy, username=username)
  512. client.run()
  513. elif sys.argv[1] == 'server':
  514. import multiprocessing
  515. import queue
  516. #
  517. endpoint = ('127.0.0.1', 4747)
  518. processes = []
  519. conn_counter = [0]
  520. #
  521. def bw_callback(conn_id, data_size, transfer_rate):
  522. logging.info('Avg Transferred (MB): %.4f', data_size/(1024**2))
  523. logging.info('Avg Transfer rate (MB/s): %.4f', transfer_rate/(1024**2))
  524. #
  525. def start_server_conn(socket, conn_id):
  526. server = SimpleServerConnectionProtocol(socket, conn_id, bandwidth_callback=bw_callback)
  527. try:
  528. server.run()
  529. except KeyboardInterrupt:
  530. socket.close()
  531. #
  532. #
  533. def accept_callback(socket):
  534. conn_id = conn_counter[0]
  535. conn_counter[0] += 1
  536. #
  537. p = multiprocessing.Process(target=start_server_conn, args=(socket, conn_id))
  538. processes.append(p)
  539. p.start()
  540. #
  541. l = ServerListener(endpoint, accept_callback)
  542. #
  543. try:
  544. while True:
  545. l.accept()
  546. #
  547. except KeyboardInterrupt:
  548. print()
  549. #
  550. for p in processes:
  551. p.join()
  552. #
  553. #
  554. #