basic_protocols.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. #!/usr/bin/python3
  2. #
  3. import socket
  4. import struct
  5. import logging
  6. import time
  7. import enum
  8. import select
  9. import os
  10. #
  11. class ProtocolException(Exception):
  12. pass
  13. #
  14. class ProtocolHelper():
  15. def __init__(self):
  16. self._buffer = b''
  17. #
  18. def set_buffer(self, data):
  19. self._buffer = data
  20. #
  21. def get_buffer(self):
  22. return self._buffer
  23. #
  24. def recv(self, socket, num_bytes):
  25. data = socket.recv(num_bytes-len(self._buffer))
  26. self._buffer += data
  27. if len(self._buffer) == num_bytes:
  28. return True
  29. #
  30. return False
  31. #
  32. def send(self, socket):
  33. n = socket.send(self._buffer)
  34. self._buffer = self._buffer[n:]
  35. if len(self._buffer) == 0:
  36. return True
  37. #
  38. return False
  39. #
  40. #
  41. class Protocol():
  42. def _run_iteration(self, block=True):
  43. pass
  44. #
  45. def run(self, block=True):
  46. while True:
  47. finished = self._run_iteration(block=block)
  48. #
  49. if finished:
  50. # protocol is done
  51. return True
  52. elif not block:
  53. # not done the protocol yet, but don't block
  54. return False
  55. #
  56. #
  57. #
  58. #
  59. class Socks4Protocol(Protocol):
  60. def __init__(self, socket, addr_port, username=None):
  61. self.socket = socket
  62. self.addr_port = addr_port
  63. self.username = username
  64. #
  65. self.states = enum.Enum('SOCKS_4_STATES', 'READY_TO_BEGIN CONNECTING_TO_PROXY WAITING_FOR_PROXY DONE')
  66. self.state = self.states.READY_TO_BEGIN
  67. #
  68. self.protocol_helper = None
  69. #
  70. def _run_iteration(self, block=True):
  71. if self.state is self.states.READY_TO_BEGIN:
  72. self.protocol_helper = ProtocolHelper()
  73. self.protocol_helper.set_buffer(self.socks_cmd(self.addr_port, self.username))
  74. self.state = self.states.CONNECTING_TO_PROXY
  75. #
  76. if self.state is self.states.CONNECTING_TO_PROXY:
  77. if self.protocol_helper.send(self.socket):
  78. self.protocol_helper = ProtocolHelper()
  79. self.state = self.states.WAITING_FOR_PROXY
  80. #logging.debug('Waiting for reply from proxy')
  81. #
  82. #
  83. if self.state is self.states.WAITING_FOR_PROXY:
  84. response_size = 8
  85. if self.protocol_helper.recv(self.socket, response_size):
  86. response = self.protocol_helper.get_buffer()
  87. if response[1] != 0x5a:
  88. raise ProtocolException('Could not connect to SOCKS proxy, msg: %x'%(response[1],))
  89. #
  90. self.state = self.states.DONE
  91. return True
  92. #
  93. #
  94. return False
  95. #
  96. def socks_cmd(self, addr_port, username=None):
  97. socks_version = 4
  98. command = 1
  99. dnsname = b''
  100. host, port = addr_port
  101. #
  102. try:
  103. username = bytes(username, 'utf8')
  104. except TypeError:
  105. pass
  106. #
  107. if username is None:
  108. username = b''
  109. elif b'\x00' in username:
  110. raise ProtocolException('Username cannot contain a NUL character.')
  111. #
  112. username = username+b'\x00'
  113. #
  114. try:
  115. addr = socket.inet_aton(host)
  116. except socket.error:
  117. addr = b'\x00\x00\x00\x01'
  118. dnsname = bytes(host, 'utf8')+b'\x00'
  119. #
  120. return struct.pack('!BBH', socks_version, command, port) + addr + username + dnsname
  121. #
  122. #
  123. class PushDataProtocol(Protocol):
  124. def __init__(self, socket, total_bytes, data_generator=None, send_max_bytes=1024*512):
  125. if data_generator is None:
  126. data_generator = self._default_data_generator
  127. #
  128. self.socket = socket
  129. self.data_generator = data_generator
  130. self.total_bytes = total_bytes
  131. self.send_max_bytes = send_max_bytes
  132. #
  133. self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO PUSH_DATA RECV_CONFIRMATION DONE')
  134. self.state = self.states.READY_TO_BEGIN
  135. #
  136. self.bytes_written = 0
  137. self.protocol_helper = None
  138. #
  139. def _run_iteration(self, block=True):
  140. if self.state is self.states.READY_TO_BEGIN:
  141. info = self.total_bytes.to_bytes(8, byteorder='big', signed=False)
  142. info += self.send_max_bytes.to_bytes(8, byteorder='big', signed=False)
  143. self.protocol_helper = ProtocolHelper()
  144. self.protocol_helper.set_buffer(info)
  145. self.state = self.states.SEND_INFO
  146. #
  147. if self.state is self.states.SEND_INFO:
  148. if self.protocol_helper.send(self.socket):
  149. self.state = self.states.PUSH_DATA
  150. #
  151. #
  152. if self.state is self.states.PUSH_DATA:
  153. max_block_size = self.send_max_bytes
  154. bytes_needed = min(max_block_size, self.total_bytes-self.bytes_written)
  155. data = self.data_generator(self.bytes_written, bytes_needed)
  156. n = self.socket.send(data)
  157. self.bytes_written += n
  158. if self.bytes_written >= self.total_bytes:
  159. # finished sending the data
  160. logging.debug('Finished sending the data (%d bytes).', self.bytes_written)
  161. self.protocol_helper = ProtocolHelper()
  162. self.state = self.states.RECV_CONFIRMATION
  163. #
  164. #
  165. if self.state is self.states.RECV_CONFIRMATION:
  166. response_size = 8
  167. if self.protocol_helper.recv(self.socket, response_size):
  168. response = self.protocol_helper.get_buffer()
  169. if response != b'RECEIVED':
  170. raise ProtocolException('Did not receive the expected message: {}'.format(response))
  171. #
  172. self.state = self.states.DONE
  173. return True
  174. #
  175. #
  176. return False
  177. #
  178. def _default_data_generator(self, index, bytes_needed):
  179. return b'0'*bytes_needed
  180. #
  181. #
  182. class PullDataProtocol(Protocol):
  183. def __init__(self, socket):
  184. self.socket = socket
  185. #
  186. self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE')
  187. self.state = self.states.READY_TO_BEGIN
  188. #
  189. self.data_size = None
  190. self.recv_max_bytes = None
  191. self.bytes_read = 0
  192. self.protocol_helper = None
  193. #
  194. def _run_iteration(self, block=True):
  195. if self.state is self.states.READY_TO_BEGIN:
  196. self.protocol_helper = ProtocolHelper()
  197. self.state = self.states.RECV_INFO
  198. #
  199. if self.state is self.states.RECV_INFO:
  200. info_size = 16
  201. if self.protocol_helper.recv(self.socket, info_size):
  202. response = self.protocol_helper.get_buffer()
  203. self.data_size = int.from_bytes(response[0:8], byteorder='big', signed=False)
  204. self.recv_max_bytes = int.from_bytes(response[8:16], byteorder='big', signed=False)
  205. self.state = self.states.PULL_DATA
  206. #
  207. #
  208. if self.state is self.states.PULL_DATA:
  209. max_block_size = self.recv_max_bytes
  210. bytes_needed = min(max_block_size, self.data_size-self.bytes_read)
  211. data = self.socket.recv(bytes_needed)
  212. self.bytes_read += len(data)
  213. #logging.debug('Read %d bytes', self.bytes_read)
  214. if self.bytes_read == self.data_size:
  215. # finished receiving the data
  216. logging.debug('Finished receiving the data.')
  217. self.protocol_helper = ProtocolHelper()
  218. self.protocol_helper.set_buffer(b'RECEIVED')
  219. self.state = self.states.SEND_CONFIRMATION
  220. #
  221. #
  222. if self.state is self.states.SEND_CONFIRMATION:
  223. if self.protocol_helper.send(self.socket):
  224. self.state = self.states.DONE
  225. return True
  226. #
  227. #
  228. return False
  229. #
  230. #
  231. class PullDataProtocolWithMetrics(PullDataProtocol):
  232. def __init__(self, *args, **kwargs):
  233. super().__init__(*args, **kwargs)
  234. #
  235. self.time_of_first_byte = None
  236. self.time_of_last_byte = None
  237. #
  238. def _run_iteration(self, *args, **kwargs):
  239. data = super()._run_iteration(*args, **kwargs)
  240. #
  241. if self.bytes_read != 0 and self.time_of_first_byte is None:
  242. self.time_of_first_byte = time.time()
  243. #
  244. if self.bytes_read == self.data_size and self.time_of_last_byte is None:
  245. self.time_of_last_byte = time.time()
  246. #
  247. return data
  248. #
  249. def calc_transfer_rate(self):
  250. """ Returns bytes/s. """
  251. assert self.data_size is not None and self.time_of_first_byte is not None and self.time_of_last_byte is not None
  252. return self.data_size/(self.time_of_last_byte-self.time_of_first_byte)
  253. #
  254. #
  255. class SendDataProtocol(Protocol):
  256. def __init__(self, socket, data):
  257. self.socket = socket
  258. self.send_data = data
  259. #
  260. self.states = enum.Enum('SEND_DATA_STATES', 'READY_TO_BEGIN SEND_INFO SEND_DATA RECV_CONFIRMATION DONE')
  261. self.state = self.states.READY_TO_BEGIN
  262. #
  263. self.protocol_helper = None
  264. #
  265. def _run_iteration(self, block=True):
  266. if self.state is self.states.READY_TO_BEGIN:
  267. info_size = 20
  268. info = len(self.send_data).to_bytes(info_size, byteorder='big', signed=False)
  269. self.protocol_helper = ProtocolHelper()
  270. self.protocol_helper.set_buffer(info)
  271. self.state = self.states.SEND_INFO
  272. #
  273. if self.state is self.states.SEND_INFO:
  274. if self.protocol_helper.send(self.socket):
  275. self.protocol_helper = ProtocolHelper()
  276. self.protocol_helper.set_buffer(self.send_data)
  277. self.state = self.states.SEND_DATA
  278. #
  279. #
  280. if self.state is self.states.SEND_DATA:
  281. if self.protocol_helper.send(self.socket):
  282. self.protocol_helper = ProtocolHelper()
  283. self.state = self.states.RECV_CONFIRMATION
  284. #
  285. #
  286. if self.state is self.states.RECV_CONFIRMATION:
  287. response_size = 8
  288. if self.protocol_helper.recv(self.socket, response_size):
  289. response = self.protocol_helper.get_buffer()
  290. if response != b'RECEIVED':
  291. raise ProtocolException('Did not receive the expected message: {}'.format(response))
  292. #
  293. self.state = self.states.DONE
  294. return True
  295. #
  296. #
  297. return False
  298. #
  299. #
  300. class ReceiveDataProtocol(Protocol):
  301. def __init__(self, socket):
  302. self.socket = socket
  303. #
  304. self.states = enum.Enum('RECV_DATA_STATES', 'READY_TO_BEGIN RECV_INFO RECV_DATA SEND_CONFIRMATION DONE')
  305. self.state = self.states.READY_TO_BEGIN
  306. #
  307. self.protocol_helper = None
  308. self.data_size = None
  309. self.received_data = None
  310. #
  311. def _run_iteration(self, block=True):
  312. if self.state is self.states.READY_TO_BEGIN:
  313. self.protocol_helper = ProtocolHelper()
  314. self.state = self.states.RECV_INFO
  315. #
  316. if self.state is self.states.RECV_INFO:
  317. info_size = 20
  318. if self.protocol_helper.recv(self.socket, info_size):
  319. response = self.protocol_helper.get_buffer()
  320. self.data_size = int.from_bytes(response, byteorder='big', signed=False)
  321. self.protocol_helper = ProtocolHelper()
  322. self.state = self.states.RECV_DATA
  323. #
  324. #
  325. if self.state is self.states.RECV_DATA:
  326. if self.protocol_helper.recv(self.socket, self.data_size):
  327. response = self.protocol_helper.get_buffer()
  328. self.received_data = response
  329. self.protocol_helper = ProtocolHelper()
  330. self.protocol_helper.set_buffer(b'RECEIVED')
  331. self.state = self.states.SEND_CONFIRMATION
  332. #
  333. #
  334. if self.state is self.states.SEND_CONFIRMATION:
  335. if self.protocol_helper.send(self.socket):
  336. self.state = self.states.DONE
  337. return True
  338. #
  339. #
  340. return False
  341. #
  342. #
  343. class ServerListener():
  344. "A TCP listener, binding, listening and accepting new connections."
  345. def __init__(self, endpoint, accept_callback):
  346. self.callback = accept_callback
  347. #
  348. self.s = socket.socket()
  349. self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  350. self.s.bind(endpoint)
  351. self.s.listen(0)
  352. #
  353. def accept(self):
  354. newsock, endpoint = self.s.accept()
  355. logging.debug("New client from %s:%d (fd=%d)",
  356. endpoint[0], endpoint[1], newsock.fileno())
  357. self.callback(newsock)
  358. #
  359. #
  360. class SimpleClientConnectionProtocol(Protocol):
  361. def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None):
  362. self.endpoint = endpoint
  363. self.data_generator = data_generator
  364. self.total_bytes = total_bytes
  365. self.proxy = proxy
  366. self.username = username
  367. #
  368. self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY PUSH_DATA DONE')
  369. self.state = self.states.READY_TO_BEGIN
  370. #
  371. self.socket = socket.socket()
  372. self.sub_protocol = None
  373. #
  374. if self.proxy is None:
  375. logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
  376. self.socket.connect(self.endpoint)
  377. else:
  378. logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
  379. self.socket.connect(self.proxy)
  380. #
  381. #
  382. def _run_iteration(self, block=True):
  383. if self.state is self.states.READY_TO_BEGIN:
  384. if self.proxy is None:
  385. self.sub_protocol = PushDataProtocol(self.socket, self.total_bytes, self.data_generator)
  386. self.state = self.states.PUSH_DATA
  387. else:
  388. self.sub_protocol = Socks4Protocol(self.socket, self.endpoint, username=self.username)
  389. self.state = self.states.CONNECT_TO_PROXY
  390. #
  391. #
  392. if self.state is self.states.CONNECT_TO_PROXY:
  393. if self.sub_protocol.run(block=block):
  394. self.sub_protocol = PushDataProtocol(self.socket, self.total_bytes, self.data_generator)
  395. self.state = self.states.PUSH_DATA
  396. #
  397. #
  398. if self.state is self.states.PUSH_DATA:
  399. if self.sub_protocol.run(block=block):
  400. self.state = self.states.DONE
  401. return True
  402. #
  403. #
  404. return False
  405. #
  406. #
  407. class SimpleServerConnectionProtocol(Protocol):
  408. def __init__(self, socket, conn_id, bandwidth_callback=None):
  409. self.socket = socket
  410. self.conn_id = conn_id
  411. self.bandwidth_callback = bandwidth_callback
  412. #
  413. self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN PULL_DATA DONE')
  414. self.state = self.states.READY_TO_BEGIN
  415. #
  416. self.sub_protocol = None
  417. #
  418. def _run_iteration(self, block=True):
  419. if self.state is self.states.READY_TO_BEGIN:
  420. self.sub_protocol = PullDataProtocolWithMetrics(self.socket)
  421. self.state = self.states.PULL_DATA
  422. #
  423. if self.state is self.states.PULL_DATA:
  424. if self.sub_protocol.run(block=block):
  425. self.state = self.states.DONE
  426. if self.bandwidth_callback:
  427. self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size, self.sub_protocol.calc_transfer_rate())
  428. #
  429. return True
  430. #
  431. #
  432. return False
  433. #
  434. #
  435. if __name__ == '__main__':
  436. import sys
  437. logging.basicConfig(level=logging.DEBUG)
  438. #
  439. if sys.argv[1] == 'client':
  440. endpoint = ('127.0.0.1', 4747)
  441. proxy = ('127.0.0.1', 9003)
  442. #proxy = None
  443. username = bytes([x for x in os.urandom(12) if x != 0])
  444. #username = None
  445. data_MB = 40
  446. #
  447. client = SimpleClientConnectionProtocol(endpoint, data_MB*2**20, proxy=proxy, username=username)
  448. client.run()
  449. elif sys.argv[1] == 'server':
  450. import multiprocessing
  451. import queue
  452. #
  453. endpoint = ('127.0.0.1', 4747)
  454. processes = []
  455. conn_counter = [0]
  456. #
  457. def bw_callback(conn_id, data_size, transfer_rate):
  458. logging.info('Avg Transferred (MB): %.4f', data_size/(1024**2))
  459. logging.info('Avg Transfer rate (MB/s): %.4f', transfer_rate/(1024**2))
  460. #
  461. def start_server_conn(socket, conn_id):
  462. server = SimpleServerConnectionProtocol(socket, conn_id, bandwidth_callback=bw_callback)
  463. try:
  464. server.run()
  465. except KeyboardInterrupt:
  466. socket.close()
  467. #
  468. #
  469. def accept_callback(socket):
  470. conn_id = conn_counter[0]
  471. conn_counter[0] += 1
  472. #
  473. p = multiprocessing.Process(target=start_server_conn, args=(socket, conn_id))
  474. processes.append(p)
  475. p.start()
  476. #
  477. l = ServerListener(endpoint, accept_callback)
  478. #
  479. try:
  480. while True:
  481. l.accept()
  482. #
  483. except KeyboardInterrupt:
  484. print()
  485. #
  486. for p in processes:
  487. p.join()
  488. #
  489. #
  490. #