basic_protocols.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. #!/usr/bin/python3
  2. #
  3. import socket
  4. import struct
  5. import logging
  6. import time
  7. import enum
  8. import select
  9. import os
  10. #
  11. import accelerated_functions
  12. #
  13. class ProtocolException(Exception):
  14. pass
  15. #
  16. class ProtocolHelper():
  17. def __init__(self):
  18. self._buffer = b''
  19. #
  20. def set_buffer(self, data):
  21. """
  22. Set the buffer contents to the data that you wish to send.
  23. """
  24. #
  25. self._buffer = data
  26. #
  27. def get_buffer(self):
  28. return self._buffer
  29. #
  30. def recv(self, socket, num_bytes):
  31. """
  32. Try to fill up the buffer to a max of 'num_bytes'. If the buffer is filled,
  33. return True, otherwise return False.
  34. """
  35. #
  36. data = socket.recv(num_bytes-len(self._buffer))
  37. self._buffer += data
  38. if len(self._buffer) == num_bytes:
  39. return True
  40. #
  41. return False
  42. #
  43. def send(self, socket):
  44. """
  45. Try to send the remainder of the buffer. If the entire buffer has been sent,
  46. return True, otherwise return False.
  47. """
  48. #
  49. n = socket.send(self._buffer)
  50. self._buffer = self._buffer[n:]
  51. if len(self._buffer) == 0:
  52. return True
  53. #
  54. return False
  55. #
  56. #
  57. class Protocol():
  58. def _run_iteration(self, block=True):
  59. """
  60. This function should be overridden. It runs a single iteration of the protocol.
  61. """
  62. #
  63. pass
  64. #
  65. def run(self, block=True):
  66. while True:
  67. finished = self._run_iteration(block=block)
  68. #
  69. if finished:
  70. # protocol is done
  71. return True
  72. elif not block:
  73. # not done the protocol yet, but don't block
  74. return False
  75. #
  76. #
  77. #
  78. #
  79. class ChainedProtocol(Protocol):
  80. def __init__(self, protocols):
  81. self.protocols = protocols
  82. self.current_protocol = 0
  83. #
  84. self.states = enum.Enum('CHAIN_STATES', 'READY_TO_BEGIN RUNNING DONE')
  85. self.state = self.states.READY_TO_BEGIN
  86. #
  87. def _run_iteration(self, block=True):
  88. if self.state is self.states.READY_TO_BEGIN:
  89. self.state = self.states.RUNNING
  90. #
  91. if self.state is self.states.RUNNING:
  92. if self.protocols[self.current_protocol] is None or self.protocols[self.current_protocol].run(block=block):
  93. self.current_protocol += 1
  94. #
  95. if self.current_protocol >= len(self.protocols):
  96. self.state = self.states.DONE
  97. #
  98. #
  99. if self.state is self.states.DONE:
  100. return True
  101. #
  102. return False
  103. #
  104. #
  105. class Socks4Protocol(Protocol):
  106. def __init__(self, socket, addr_port, username=None):
  107. self.socket = socket
  108. self.addr_port = addr_port
  109. self.username = username
  110. #
  111. self.states = enum.Enum('SOCKS_4_STATES', 'READY_TO_BEGIN CONNECTING_TO_PROXY WAITING_FOR_PROXY DONE')
  112. self.state = self.states.READY_TO_BEGIN
  113. #
  114. self.protocol_helper = None
  115. #
  116. def _run_iteration(self, block=True):
  117. if self.state is self.states.READY_TO_BEGIN:
  118. self.protocol_helper = ProtocolHelper()
  119. self.protocol_helper.set_buffer(self.socks_cmd(self.addr_port, self.username))
  120. self.state = self.states.CONNECTING_TO_PROXY
  121. #
  122. if self.state is self.states.CONNECTING_TO_PROXY:
  123. if self.protocol_helper.send(self.socket):
  124. self.protocol_helper = ProtocolHelper()
  125. self.state = self.states.WAITING_FOR_PROXY
  126. #logging.debug('Waiting for reply from proxy')
  127. #
  128. #
  129. if self.state is self.states.WAITING_FOR_PROXY:
  130. response_size = 8
  131. if self.protocol_helper.recv(self.socket, response_size):
  132. response = self.protocol_helper.get_buffer()
  133. if response[1] != 0x5a:
  134. raise ProtocolException('Could not connect to SOCKS proxy, msg: %x'%(response[1],))
  135. #
  136. self.state = self.states.DONE
  137. #
  138. #
  139. if self.state is self.states.DONE:
  140. return True
  141. #
  142. return False
  143. #
  144. def socks_cmd(self, addr_port, username=None):
  145. socks_version = 4
  146. command = 1
  147. dnsname = b''
  148. host, port = addr_port
  149. #
  150. try:
  151. username = bytes(username, 'utf8')
  152. except TypeError:
  153. pass
  154. #
  155. if username is None:
  156. username = b''
  157. elif b'\x00' in username:
  158. raise ProtocolException('Username cannot contain a NUL character.')
  159. #
  160. username = username+b'\x00'
  161. #
  162. try:
  163. addr = socket.inet_aton(host)
  164. except socket.error:
  165. addr = b'\x00\x00\x00\x01'
  166. dnsname = bytes(host, 'utf8')+b'\x00'
  167. #
  168. return struct.pack('!BBH', socks_version, command, port) + addr + username + dnsname
  169. #
  170. #
  171. class PushDataProtocol(Protocol):
  172. def __init__(self, socket, total_bytes, send_buffer_len=None, use_acceleration=None):
  173. if send_buffer_len is None:
  174. send_buffer_len = 1024*512
  175. #
  176. if use_acceleration is None:
  177. use_acceleration = True
  178. #
  179. self.socket = socket
  180. self.total_bytes = total_bytes
  181. self.use_acceleration = use_acceleration
  182. #
  183. self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO PUSH_DATA RECV_CONFIRMATION DONE')
  184. self.state = self.states.READY_TO_BEGIN
  185. #
  186. self.byte_buffer = os.urandom(send_buffer_len)
  187. self.bytes_written = 0
  188. self.protocol_helper = None
  189. #
  190. def _run_iteration(self, block=True):
  191. if self.state is self.states.READY_TO_BEGIN:
  192. info = self.total_bytes.to_bytes(8, byteorder='big', signed=False)
  193. info += len(self.byte_buffer).to_bytes(8, byteorder='big', signed=False)
  194. self.protocol_helper = ProtocolHelper()
  195. self.protocol_helper.set_buffer(info)
  196. self.state = self.states.SEND_INFO
  197. #
  198. if self.state is self.states.SEND_INFO:
  199. if self.protocol_helper.send(self.socket):
  200. self.state = self.states.PUSH_DATA
  201. #
  202. #
  203. if self.state is self.states.PUSH_DATA:
  204. if self.use_acceleration:
  205. if not block:
  206. logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
  207. #
  208. ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, self.byte_buffer)
  209. if ret_val < 0:
  210. raise ProtocolException('Error while pushing data.')
  211. #
  212. self.bytes_written = self.total_bytes
  213. else:
  214. bytes_remaining = self.total_bytes-self.bytes_written
  215. data_size = min(len(self.byte_buffer), bytes_remaining)
  216. if data_size != len(self.byte_buffer):
  217. data = self.byte_buffer[:data_size]
  218. else:
  219. data = self.byte_buffer
  220. # don't make a copy of the byte string each time if we don't need to
  221. #
  222. n = self.socket.send(data)
  223. self.bytes_written += n
  224. #
  225. if self.bytes_written == self.total_bytes:
  226. # finished sending the data
  227. logging.debug('Finished sending the data (%d bytes).', self.bytes_written)
  228. self.protocol_helper = ProtocolHelper()
  229. self.state = self.states.RECV_CONFIRMATION
  230. #
  231. #
  232. if self.state is self.states.RECV_CONFIRMATION:
  233. response_size = 8
  234. if self.protocol_helper.recv(self.socket, response_size):
  235. response = self.protocol_helper.get_buffer()
  236. if response != b'RECEIVED':
  237. raise ProtocolException('Did not receive the expected message: {}'.format(response))
  238. #
  239. self.state = self.states.DONE
  240. #
  241. #
  242. if self.state is self.states.DONE:
  243. return True
  244. #
  245. return False
  246. #
  247. #
  248. class PullDataProtocol(Protocol):
  249. def __init__(self, socket, use_acceleration=None):
  250. if use_acceleration is None:
  251. use_acceleration = True
  252. #
  253. self.socket = socket
  254. self.use_acceleration = use_acceleration
  255. #
  256. self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE')
  257. self.state = self.states.READY_TO_BEGIN
  258. #
  259. self.data_size = None
  260. self.recv_buffer_len = None
  261. self.bytes_read = 0
  262. self.protocol_helper = None
  263. self._time_of_first_byte = None
  264. self.elapsed_time = None
  265. #
  266. def _run_iteration(self, block=True):
  267. if self.state is self.states.READY_TO_BEGIN:
  268. self.protocol_helper = ProtocolHelper()
  269. self.state = self.states.RECV_INFO
  270. #
  271. if self.state is self.states.RECV_INFO:
  272. info_size = 16
  273. if self.protocol_helper.recv(self.socket, info_size):
  274. response = self.protocol_helper.get_buffer()
  275. self.data_size = int.from_bytes(response[0:8], byteorder='big', signed=False)
  276. self.recv_buffer_len = int.from_bytes(response[8:16], byteorder='big', signed=False)
  277. self.state = self.states.PULL_DATA
  278. #
  279. #
  280. if self.state is self.states.PULL_DATA:
  281. if self.use_acceleration:
  282. if not block:
  283. logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
  284. #
  285. (ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, self.recv_buffer_len)
  286. if ret_val < 0:
  287. raise ProtocolException('Error while pulling data.')
  288. #
  289. self.bytes_read = self.data_size
  290. self.elapsed_time = elapsed_time
  291. else:
  292. bytes_remaining = self.data_size-self.bytes_read
  293. block_size = min(self.recv_buffer_len, bytes_remaining)
  294. #
  295. data = self.socket.recv(block_size)
  296. self.bytes_read += len(data)
  297. #
  298. if self.bytes_read != 0 and self._time_of_first_byte is None:
  299. self._time_of_first_byte = time.time()
  300. #
  301. if self.bytes_read == self.data_size and self.elapsed_time is None:
  302. self.elapsed_time = time.time()-self._time_of_first_byte
  303. #
  304. #
  305. if self.bytes_read == self.data_size:
  306. # finished receiving the data
  307. logging.debug('Finished receiving the data.')
  308. self.protocol_helper = ProtocolHelper()
  309. self.protocol_helper.set_buffer(b'RECEIVED')
  310. self.state = self.states.SEND_CONFIRMATION
  311. #
  312. #
  313. if self.state is self.states.SEND_CONFIRMATION:
  314. if self.protocol_helper.send(self.socket):
  315. self.state = self.states.DONE
  316. #
  317. #
  318. if self.state is self.states.DONE:
  319. return True
  320. #
  321. return False
  322. #
  323. def calc_transfer_rate(self):
  324. """ Returns bytes/s. """
  325. assert self.data_size is not None and self.elapsed_time is not None
  326. return self.data_size/self.elapsed_time
  327. #
  328. #
  329. class SendDataProtocol(Protocol):
  330. def __init__(self, socket, data):
  331. self.socket = socket
  332. self.send_data = data
  333. #
  334. self.states = enum.Enum('SEND_DATA_STATES', 'READY_TO_BEGIN SEND_INFO SEND_DATA RECV_CONFIRMATION DONE')
  335. self.state = self.states.READY_TO_BEGIN
  336. #
  337. self.protocol_helper = None
  338. #
  339. def _run_iteration(self, block=True):
  340. if self.state is self.states.READY_TO_BEGIN:
  341. info_size = 20
  342. info = len(self.send_data).to_bytes(info_size, byteorder='big', signed=False)
  343. self.protocol_helper = ProtocolHelper()
  344. self.protocol_helper.set_buffer(info)
  345. self.state = self.states.SEND_INFO
  346. #
  347. if self.state is self.states.SEND_INFO:
  348. if self.protocol_helper.send(self.socket):
  349. self.protocol_helper = ProtocolHelper()
  350. self.protocol_helper.set_buffer(self.send_data)
  351. self.state = self.states.SEND_DATA
  352. #
  353. #
  354. if self.state is self.states.SEND_DATA:
  355. if self.protocol_helper.send(self.socket):
  356. self.protocol_helper = ProtocolHelper()
  357. self.state = self.states.RECV_CONFIRMATION
  358. #
  359. #
  360. if self.state is self.states.RECV_CONFIRMATION:
  361. response_size = 8
  362. if self.protocol_helper.recv(self.socket, response_size):
  363. response = self.protocol_helper.get_buffer()
  364. if response != b'RECEIVED':
  365. raise ProtocolException('Did not receive the expected message: {}'.format(response))
  366. #
  367. self.state = self.states.DONE
  368. #
  369. #
  370. if self.state is self.states.DONE:
  371. return True
  372. #
  373. return False
  374. #
  375. #
  376. class ReceiveDataProtocol(Protocol):
  377. def __init__(self, socket):
  378. self.socket = socket
  379. #
  380. self.states = enum.Enum('RECV_DATA_STATES', 'READY_TO_BEGIN RECV_INFO RECV_DATA SEND_CONFIRMATION DONE')
  381. self.state = self.states.READY_TO_BEGIN
  382. #
  383. self.protocol_helper = None
  384. self.data_size = None
  385. self.received_data = None
  386. #
  387. def _run_iteration(self, block=True):
  388. if self.state is self.states.READY_TO_BEGIN:
  389. self.protocol_helper = ProtocolHelper()
  390. self.state = self.states.RECV_INFO
  391. #
  392. if self.state is self.states.RECV_INFO:
  393. info_size = 20
  394. if self.protocol_helper.recv(self.socket, info_size):
  395. response = self.protocol_helper.get_buffer()
  396. self.data_size = int.from_bytes(response, byteorder='big', signed=False)
  397. self.protocol_helper = ProtocolHelper()
  398. self.state = self.states.RECV_DATA
  399. #
  400. #
  401. if self.state is self.states.RECV_DATA:
  402. if self.protocol_helper.recv(self.socket, self.data_size):
  403. response = self.protocol_helper.get_buffer()
  404. self.received_data = response
  405. self.protocol_helper = ProtocolHelper()
  406. self.protocol_helper.set_buffer(b'RECEIVED')
  407. self.state = self.states.SEND_CONFIRMATION
  408. #
  409. #
  410. if self.state is self.states.SEND_CONFIRMATION:
  411. if self.protocol_helper.send(self.socket):
  412. self.state = self.states.DONE
  413. #
  414. #
  415. if self.state is self.states.DONE:
  416. return True
  417. #
  418. return False
  419. #
  420. #
  421. class ServerListener():
  422. def __init__(self, endpoint, accept_callback):
  423. self.callback = accept_callback
  424. #
  425. self.s = socket.socket()
  426. self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  427. self.s.bind(endpoint)
  428. self.s.listen(0)
  429. #
  430. def accept(self):
  431. newsock, endpoint = self.s.accept()
  432. logging.debug("New client from %s:%d (fd=%d)",
  433. endpoint[0], endpoint[1], newsock.fileno())
  434. self.callback(newsock)
  435. #
  436. #
  437. class SimpleClientConnectionProtocol(Protocol):
  438. def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None):
  439. self.endpoint = endpoint
  440. self.data_generator = data_generator
  441. self.total_bytes = total_bytes
  442. self.proxy = proxy
  443. self.username = username
  444. #
  445. self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY PUSH_DATA DONE')
  446. self.state = self.states.READY_TO_BEGIN
  447. #
  448. self.socket = socket.socket()
  449. self.sub_protocol = None
  450. #
  451. if self.proxy is None:
  452. logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
  453. self.socket.connect(self.endpoint)
  454. else:
  455. logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
  456. self.socket.connect(self.proxy)
  457. #
  458. #
  459. def _run_iteration(self, block=True):
  460. if self.state is self.states.READY_TO_BEGIN:
  461. if self.proxy is None:
  462. self.sub_protocol = PushDataProtocol(self.socket, self.total_bytes, self.data_generator)
  463. self.state = self.states.PUSH_DATA
  464. else:
  465. self.sub_protocol = Socks4Protocol(self.socket, self.endpoint, username=self.username)
  466. self.state = self.states.CONNECT_TO_PROXY
  467. #
  468. #
  469. if self.state is self.states.CONNECT_TO_PROXY:
  470. if self.sub_protocol.run(block=block):
  471. self.sub_protocol = PushDataProtocol(self.socket, self.total_bytes, self.data_generator)
  472. self.state = self.states.PUSH_DATA
  473. #
  474. #
  475. if self.state is self.states.PUSH_DATA:
  476. if self.sub_protocol.run(block=block):
  477. self.state = self.states.DONE
  478. #
  479. #
  480. if self.state is self.states.DONE:
  481. return True
  482. #
  483. return False
  484. #
  485. #
  486. class SimpleServerConnectionProtocol(Protocol):
  487. def __init__(self, socket, conn_id, bandwidth_callback=None):
  488. self.socket = socket
  489. self.conn_id = conn_id
  490. self.bandwidth_callback = bandwidth_callback
  491. #
  492. self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN PULL_DATA DONE')
  493. self.state = self.states.READY_TO_BEGIN
  494. #
  495. self.sub_protocol = None
  496. #
  497. def _run_iteration(self, block=True):
  498. if self.state is self.states.READY_TO_BEGIN:
  499. self.sub_protocol = PullDataProtocol(self.socket)
  500. self.state = self.states.PULL_DATA
  501. #
  502. if self.state is self.states.PULL_DATA:
  503. if self.sub_protocol.run(block=block):
  504. if self.bandwidth_callback:
  505. self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size, self.sub_protocol.calc_transfer_rate())
  506. #
  507. self.state = self.states.DONE
  508. #
  509. #
  510. if self.state is self.states.DONE:
  511. return True
  512. #
  513. return False
  514. #
  515. #
  516. if __name__ == '__main__':
  517. import sys
  518. logging.basicConfig(level=logging.DEBUG)
  519. #
  520. if sys.argv[1] == 'client':
  521. endpoint = ('127.0.0.1', 4747)
  522. proxy = ('127.0.0.1', 9003)
  523. #proxy = None
  524. username = bytes([x for x in os.urandom(12) if x != 0])
  525. #username = None
  526. data_MB = 40
  527. #
  528. client = SimpleClientConnectionProtocol(endpoint, data_MB*2**20, proxy=proxy, username=username)
  529. client.run()
  530. elif sys.argv[1] == 'server':
  531. import multiprocessing
  532. import queue
  533. #
  534. endpoint = ('127.0.0.1', 4747)
  535. processes = []
  536. conn_counter = [0]
  537. #
  538. def bw_callback(conn_id, data_size, transfer_rate):
  539. logging.info('Avg Transferred (MB): %.4f', data_size/(1024**2))
  540. logging.info('Avg Transfer rate (MB/s): %.4f', transfer_rate/(1024**2))
  541. #
  542. def start_server_conn(socket, conn_id):
  543. server = SimpleServerConnectionProtocol(socket, conn_id, bandwidth_callback=bw_callback)
  544. try:
  545. server.run()
  546. except KeyboardInterrupt:
  547. socket.close()
  548. #
  549. #
  550. def accept_callback(socket):
  551. conn_id = conn_counter[0]
  552. conn_counter[0] += 1
  553. #
  554. p = multiprocessing.Process(target=start_server_conn, args=(socket, conn_id))
  555. processes.append(p)
  556. p.start()
  557. #
  558. l = ServerListener(endpoint, accept_callback)
  559. #
  560. try:
  561. while True:
  562. l.accept()
  563. #
  564. except KeyboardInterrupt:
  565. print()
  566. #
  567. for p in processes:
  568. p.join()
  569. #
  570. #
  571. #