basic_protocols.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579
  1. #!/usr/bin/python3
  2. #
  3. import socket
  4. import struct
  5. import logging
  6. import time
  7. import enum
  8. import select
  9. import os
  10. #
  11. import accelerated_functions
  12. #
  13. class ProtocolException(Exception):
  14. pass
  15. #
  16. class ProtocolHelper():
  17. def __init__(self):
  18. self._buffer = b''
  19. #
  20. def set_buffer(self, data):
  21. """
  22. Set the buffer contents to the data that you wish to send.
  23. """
  24. #
  25. self._buffer = data
  26. #
  27. def get_buffer(self):
  28. return self._buffer
  29. #
  30. def recv(self, socket, num_bytes):
  31. """
  32. Try to fill up the buffer to a max of 'num_bytes'. If the buffer is filled,
  33. return True, otherwise return False.
  34. """
  35. #
  36. data = socket.recv(num_bytes-len(self._buffer))
  37. #
  38. if len(data) == 0:
  39. raise ProtocolException('The socket was closed.')
  40. #
  41. self._buffer += data
  42. if len(self._buffer) == num_bytes:
  43. return True
  44. #
  45. return False
  46. #
  47. def send(self, socket):
  48. """
  49. Try to send the remainder of the buffer. If the entire buffer has been sent,
  50. return True, otherwise return False.
  51. """
  52. #
  53. n = socket.send(self._buffer)
  54. self._buffer = self._buffer[n:]
  55. if len(self._buffer) == 0:
  56. return True
  57. #
  58. return False
  59. #
  60. #
  61. class Protocol():
  62. def _run_iteration(self, block=True):
  63. """
  64. This function should be overridden. It runs a single iteration of the protocol.
  65. """
  66. #
  67. pass
  68. #
  69. def run(self, block=True):
  70. while True:
  71. finished = self._run_iteration(block=block)
  72. #
  73. if finished:
  74. # protocol is done
  75. return True
  76. elif not block:
  77. # not done the protocol yet, but don't block
  78. return False
  79. #
  80. #
  81. #
  82. #
  83. class ChainedProtocol(Protocol):
  84. def __init__(self, protocols):
  85. self.protocols = protocols
  86. self.current_protocol = 0
  87. #
  88. self.states = enum.Enum('CHAIN_STATES', 'READY_TO_BEGIN RUNNING DONE')
  89. self.state = self.states.READY_TO_BEGIN
  90. #
  91. def _run_iteration(self, block=True):
  92. if self.state is self.states.READY_TO_BEGIN:
  93. self.state = self.states.RUNNING
  94. #
  95. if self.state is self.states.RUNNING:
  96. if self.protocols[self.current_protocol] is None or self.protocols[self.current_protocol].run(block=block):
  97. self.current_protocol += 1
  98. #
  99. if self.current_protocol >= len(self.protocols):
  100. self.state = self.states.DONE
  101. #
  102. #
  103. if self.state is self.states.DONE:
  104. return True
  105. #
  106. return False
  107. #
  108. #
  109. class Socks4Protocol(Protocol):
  110. def __init__(self, socket, addr_port, username=None):
  111. self.socket = socket
  112. self.addr_port = addr_port
  113. self.username = username
  114. #
  115. self.states = enum.Enum('SOCKS_4_STATES', 'READY_TO_BEGIN CONNECTING_TO_PROXY WAITING_FOR_PROXY DONE')
  116. self.state = self.states.READY_TO_BEGIN
  117. #
  118. self.protocol_helper = None
  119. #
  120. def _run_iteration(self, block=True):
  121. if self.state is self.states.READY_TO_BEGIN:
  122. self.protocol_helper = ProtocolHelper()
  123. self.protocol_helper.set_buffer(self.socks_cmd(self.addr_port, self.username))
  124. self.state = self.states.CONNECTING_TO_PROXY
  125. #
  126. if self.state is self.states.CONNECTING_TO_PROXY:
  127. if self.protocol_helper.send(self.socket):
  128. self.protocol_helper = ProtocolHelper()
  129. self.state = self.states.WAITING_FOR_PROXY
  130. #logging.debug('Waiting for reply from proxy')
  131. #
  132. #
  133. if self.state is self.states.WAITING_FOR_PROXY:
  134. response_size = 8
  135. if self.protocol_helper.recv(self.socket, response_size):
  136. response = self.protocol_helper.get_buffer()
  137. if response[1] != 0x5a:
  138. raise ProtocolException('Could not connect to SOCKS proxy, msg: %x'%(response[1],))
  139. #
  140. self.state = self.states.DONE
  141. #
  142. #
  143. if self.state is self.states.DONE:
  144. return True
  145. #
  146. return False
  147. #
  148. def socks_cmd(self, addr_port, username=None):
  149. socks_version = 4
  150. command = 1
  151. dnsname = b''
  152. host, port = addr_port
  153. #
  154. try:
  155. username = bytes(username, 'utf8')
  156. except TypeError:
  157. pass
  158. #
  159. if username is None:
  160. username = b''
  161. elif b'\x00' in username:
  162. raise ProtocolException('Username cannot contain a NUL character.')
  163. #
  164. username = username+b'\x00'
  165. #
  166. try:
  167. addr = socket.inet_aton(host)
  168. except socket.error:
  169. addr = b'\x00\x00\x00\x01'
  170. dnsname = bytes(host, 'utf8')+b'\x00'
  171. #
  172. return struct.pack('!BBH', socks_version, command, port) + addr + username + dnsname
  173. #
  174. #
  175. class PushDataProtocol(Protocol):
  176. def __init__(self, socket, total_bytes, send_buffer_len=None, use_acceleration=None):
  177. if send_buffer_len is None:
  178. send_buffer_len = 1024*512
  179. #
  180. if use_acceleration is None:
  181. use_acceleration = True
  182. #
  183. self.socket = socket
  184. self.total_bytes = total_bytes
  185. self.use_acceleration = use_acceleration
  186. #
  187. self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO PUSH_DATA RECV_CONFIRMATION DONE')
  188. self.state = self.states.READY_TO_BEGIN
  189. #
  190. self.byte_buffer = os.urandom(send_buffer_len)
  191. self.bytes_written = 0
  192. self.protocol_helper = None
  193. #
  194. def _run_iteration(self, block=True):
  195. if self.state is self.states.READY_TO_BEGIN:
  196. info = self.total_bytes.to_bytes(8, byteorder='big', signed=False)
  197. info += len(self.byte_buffer).to_bytes(8, byteorder='big', signed=False)
  198. self.protocol_helper = ProtocolHelper()
  199. self.protocol_helper.set_buffer(info)
  200. self.state = self.states.SEND_INFO
  201. #
  202. if self.state is self.states.SEND_INFO:
  203. if self.protocol_helper.send(self.socket):
  204. self.state = self.states.PUSH_DATA
  205. #
  206. #
  207. if self.state is self.states.PUSH_DATA:
  208. if self.use_acceleration:
  209. if not block:
  210. logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
  211. #
  212. ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, self.byte_buffer)
  213. if ret_val < 0:
  214. raise ProtocolException('Error while pushing data.')
  215. #
  216. self.bytes_written = self.total_bytes
  217. else:
  218. bytes_remaining = self.total_bytes-self.bytes_written
  219. data_size = min(len(self.byte_buffer), bytes_remaining)
  220. if data_size != len(self.byte_buffer):
  221. data = self.byte_buffer[:data_size]
  222. else:
  223. data = self.byte_buffer
  224. # don't make a copy of the byte string each time if we don't need to
  225. #
  226. n = self.socket.send(data)
  227. self.bytes_written += n
  228. #
  229. if self.bytes_written == self.total_bytes:
  230. # finished sending the data
  231. logging.debug('Finished sending the data (%d bytes).', self.bytes_written)
  232. self.protocol_helper = ProtocolHelper()
  233. self.state = self.states.RECV_CONFIRMATION
  234. #
  235. #
  236. if self.state is self.states.RECV_CONFIRMATION:
  237. response_size = 8
  238. if self.protocol_helper.recv(self.socket, response_size):
  239. response = self.protocol_helper.get_buffer()
  240. if response != b'RECEIVED':
  241. raise ProtocolException('Did not receive the expected message: {}'.format(response))
  242. #
  243. self.state = self.states.DONE
  244. #
  245. #
  246. if self.state is self.states.DONE:
  247. return True
  248. #
  249. return False
  250. #
  251. #
  252. class PullDataProtocol(Protocol):
  253. def __init__(self, socket, use_acceleration=None):
  254. if use_acceleration is None:
  255. use_acceleration = True
  256. #
  257. self.socket = socket
  258. self.use_acceleration = use_acceleration
  259. #
  260. self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE')
  261. self.state = self.states.READY_TO_BEGIN
  262. #
  263. self.data_size = None
  264. self.recv_buffer_len = None
  265. self.bytes_read = 0
  266. self.protocol_helper = None
  267. self._time_of_first_byte = None
  268. self.elapsed_time = None
  269. #
  270. def _run_iteration(self, block=True):
  271. if self.state is self.states.READY_TO_BEGIN:
  272. self.protocol_helper = ProtocolHelper()
  273. self.state = self.states.RECV_INFO
  274. #
  275. if self.state is self.states.RECV_INFO:
  276. info_size = 16
  277. if self.protocol_helper.recv(self.socket, info_size):
  278. response = self.protocol_helper.get_buffer()
  279. self.data_size = int.from_bytes(response[0:8], byteorder='big', signed=False)
  280. self.recv_buffer_len = int.from_bytes(response[8:16], byteorder='big', signed=False)
  281. self.state = self.states.PULL_DATA
  282. #
  283. #
  284. if self.state is self.states.PULL_DATA:
  285. if self.use_acceleration:
  286. if not block:
  287. logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
  288. #
  289. (ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, self.recv_buffer_len)
  290. if ret_val < 0:
  291. raise ProtocolException('Error while pulling data.')
  292. #
  293. self.bytes_read = self.data_size
  294. self.elapsed_time = elapsed_time
  295. else:
  296. bytes_remaining = self.data_size-self.bytes_read
  297. block_size = min(self.recv_buffer_len, bytes_remaining)
  298. #
  299. data = self.socket.recv(block_size)
  300. #
  301. if len(data) == 0:
  302. raise ProtocolException('The socket was closed.')
  303. #
  304. self.bytes_read += len(data)
  305. #
  306. if self.bytes_read != 0 and self._time_of_first_byte is None:
  307. self._time_of_first_byte = time.time()
  308. #
  309. if self.bytes_read == self.data_size and self.elapsed_time is None:
  310. self.elapsed_time = time.time()-self._time_of_first_byte
  311. #
  312. #
  313. if self.bytes_read == self.data_size:
  314. # finished receiving the data
  315. logging.debug('Finished receiving the data.')
  316. self.protocol_helper = ProtocolHelper()
  317. self.protocol_helper.set_buffer(b'RECEIVED')
  318. self.state = self.states.SEND_CONFIRMATION
  319. #
  320. #
  321. if self.state is self.states.SEND_CONFIRMATION:
  322. if self.protocol_helper.send(self.socket):
  323. self.state = self.states.DONE
  324. #
  325. #
  326. if self.state is self.states.DONE:
  327. return True
  328. #
  329. return False
  330. #
  331. def calc_transfer_rate(self):
  332. """ Returns bytes/s. """
  333. assert self.data_size is not None and self.elapsed_time is not None
  334. return self.data_size/self.elapsed_time
  335. #
  336. #
  337. class SendDataProtocol(Protocol):
  338. def __init__(self, socket, data):
  339. self.socket = socket
  340. self.send_data = data
  341. #
  342. self.states = enum.Enum('SEND_DATA_STATES', 'READY_TO_BEGIN SEND_INFO SEND_DATA RECV_CONFIRMATION DONE')
  343. self.state = self.states.READY_TO_BEGIN
  344. #
  345. self.protocol_helper = None
  346. #
  347. def _run_iteration(self, block=True):
  348. if self.state is self.states.READY_TO_BEGIN:
  349. info_size = 20
  350. info = len(self.send_data).to_bytes(info_size, byteorder='big', signed=False)
  351. self.protocol_helper = ProtocolHelper()
  352. self.protocol_helper.set_buffer(info)
  353. self.state = self.states.SEND_INFO
  354. #
  355. if self.state is self.states.SEND_INFO:
  356. if self.protocol_helper.send(self.socket):
  357. self.protocol_helper = ProtocolHelper()
  358. self.protocol_helper.set_buffer(self.send_data)
  359. self.state = self.states.SEND_DATA
  360. #
  361. #
  362. if self.state is self.states.SEND_DATA:
  363. if self.protocol_helper.send(self.socket):
  364. self.protocol_helper = ProtocolHelper()
  365. self.state = self.states.RECV_CONFIRMATION
  366. #
  367. #
  368. if self.state is self.states.RECV_CONFIRMATION:
  369. response_size = 8
  370. if self.protocol_helper.recv(self.socket, response_size):
  371. response = self.protocol_helper.get_buffer()
  372. if response != b'RECEIVED':
  373. raise ProtocolException('Did not receive the expected message: {}'.format(response))
  374. #
  375. self.state = self.states.DONE
  376. #
  377. #
  378. if self.state is self.states.DONE:
  379. return True
  380. #
  381. return False
  382. #
  383. #
  384. class ReceiveDataProtocol(Protocol):
  385. def __init__(self, socket):
  386. self.socket = socket
  387. #
  388. self.states = enum.Enum('RECV_DATA_STATES', 'READY_TO_BEGIN RECV_INFO RECV_DATA SEND_CONFIRMATION DONE')
  389. self.state = self.states.READY_TO_BEGIN
  390. #
  391. self.protocol_helper = None
  392. self.data_size = None
  393. self.received_data = None
  394. #
  395. def _run_iteration(self, block=True):
  396. if self.state is self.states.READY_TO_BEGIN:
  397. self.protocol_helper = ProtocolHelper()
  398. self.state = self.states.RECV_INFO
  399. #
  400. if self.state is self.states.RECV_INFO:
  401. info_size = 20
  402. if self.protocol_helper.recv(self.socket, info_size):
  403. response = self.protocol_helper.get_buffer()
  404. self.data_size = int.from_bytes(response, byteorder='big', signed=False)
  405. self.protocol_helper = ProtocolHelper()
  406. self.state = self.states.RECV_DATA
  407. #
  408. #
  409. if self.state is self.states.RECV_DATA:
  410. if self.protocol_helper.recv(self.socket, self.data_size):
  411. response = self.protocol_helper.get_buffer()
  412. self.received_data = response
  413. self.protocol_helper = ProtocolHelper()
  414. self.protocol_helper.set_buffer(b'RECEIVED')
  415. self.state = self.states.SEND_CONFIRMATION
  416. #
  417. #
  418. if self.state is self.states.SEND_CONFIRMATION:
  419. if self.protocol_helper.send(self.socket):
  420. self.state = self.states.DONE
  421. #
  422. #
  423. if self.state is self.states.DONE:
  424. return True
  425. #
  426. return False
  427. #
  428. #
  429. class ServerListener():
  430. def __init__(self, endpoint, accept_callback):
  431. self.callback = accept_callback
  432. #
  433. self.s = socket.socket()
  434. self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  435. self.s.bind(endpoint)
  436. self.s.listen(0)
  437. #
  438. def accept(self):
  439. newsock, endpoint = self.s.accept()
  440. logging.debug("New client from %s:%d (fd=%d)",
  441. endpoint[0], endpoint[1], newsock.fileno())
  442. self.callback(newsock)
  443. #
  444. #
  445. class SimpleClientConnectionProtocol(Protocol):
  446. def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None):
  447. self.endpoint = endpoint
  448. self.data_generator = data_generator
  449. self.total_bytes = total_bytes
  450. self.proxy = proxy
  451. self.username = username
  452. #
  453. self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY PUSH_DATA DONE')
  454. self.state = self.states.READY_TO_BEGIN
  455. #
  456. self.socket = socket.socket()
  457. self.sub_protocol = None
  458. #
  459. if self.proxy is None:
  460. logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
  461. self.socket.connect(self.endpoint)
  462. else:
  463. logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
  464. self.socket.connect(self.proxy)
  465. #
  466. #
  467. def _run_iteration(self, block=True):
  468. if self.state is self.states.READY_TO_BEGIN:
  469. if self.proxy is None:
  470. self.sub_protocol = PushDataProtocol(self.socket, self.total_bytes, self.data_generator)
  471. self.state = self.states.PUSH_DATA
  472. else:
  473. self.sub_protocol = Socks4Protocol(self.socket, self.endpoint, username=self.username)
  474. self.state = self.states.CONNECT_TO_PROXY
  475. #
  476. #
  477. if self.state is self.states.CONNECT_TO_PROXY:
  478. if self.sub_protocol.run(block=block):
  479. self.sub_protocol = PushDataProtocol(self.socket, self.total_bytes, self.data_generator)
  480. self.state = self.states.PUSH_DATA
  481. #
  482. #
  483. if self.state is self.states.PUSH_DATA:
  484. if self.sub_protocol.run(block=block):
  485. self.state = self.states.DONE
  486. #
  487. #
  488. if self.state is self.states.DONE:
  489. return True
  490. #
  491. return False
  492. #
  493. #
  494. class SimpleServerConnectionProtocol(Protocol):
  495. def __init__(self, socket, conn_id, bandwidth_callback=None):
  496. self.socket = socket
  497. self.conn_id = conn_id
  498. self.bandwidth_callback = bandwidth_callback
  499. #
  500. self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN PULL_DATA DONE')
  501. self.state = self.states.READY_TO_BEGIN
  502. #
  503. self.sub_protocol = None
  504. #
  505. def _run_iteration(self, block=True):
  506. if self.state is self.states.READY_TO_BEGIN:
  507. self.sub_protocol = PullDataProtocol(self.socket)
  508. self.state = self.states.PULL_DATA
  509. #
  510. if self.state is self.states.PULL_DATA:
  511. if self.sub_protocol.run(block=block):
  512. if self.bandwidth_callback:
  513. self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size, self.sub_protocol.calc_transfer_rate())
  514. #
  515. self.state = self.states.DONE
  516. #
  517. #
  518. if self.state is self.states.DONE:
  519. return True
  520. #
  521. return False
  522. #
  523. #
  524. if __name__ == '__main__':
  525. import sys
  526. logging.basicConfig(level=logging.DEBUG)
  527. #
  528. if sys.argv[1] == 'client':
  529. endpoint = ('127.0.0.1', 4747)
  530. proxy = ('127.0.0.1', 9003)
  531. #proxy = None
  532. username = bytes([x for x in os.urandom(12) if x != 0])
  533. #username = None
  534. data_MB = 40
  535. #
  536. client = SimpleClientConnectionProtocol(endpoint, data_MB*2**20, proxy=proxy, username=username)
  537. client.run()
  538. elif sys.argv[1] == 'server':
  539. import multiprocessing
  540. import queue
  541. #
  542. endpoint = ('127.0.0.1', 4747)
  543. processes = []
  544. conn_counter = [0]
  545. #
  546. def bw_callback(conn_id, data_size, transfer_rate):
  547. logging.info('Avg Transferred (MB): %.4f', data_size/(1024**2))
  548. logging.info('Avg Transfer rate (MB/s): %.4f', transfer_rate/(1024**2))
  549. #
  550. def start_server_conn(socket, conn_id):
  551. server = SimpleServerConnectionProtocol(socket, conn_id, bandwidth_callback=bw_callback)
  552. try:
  553. server.run()
  554. except KeyboardInterrupt:
  555. socket.close()
  556. #
  557. #
  558. def accept_callback(socket):
  559. conn_id = conn_counter[0]
  560. conn_counter[0] += 1
  561. #
  562. p = multiprocessing.Process(target=start_server_conn, args=(socket, conn_id))
  563. processes.append(p)
  564. p.start()
  565. #
  566. l = ServerListener(endpoint, accept_callback)
  567. #
  568. try:
  569. while True:
  570. l.accept()
  571. #
  572. except KeyboardInterrupt:
  573. print()
  574. #
  575. for p in processes:
  576. p.join()
  577. #
  578. #
  579. #