123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666 |
- #!/usr/bin/python3
- #
- import socket
- import struct
- import logging
- import time
- import enum
- import select
- import os
- #
- import accelerated_functions
- #
class ProtocolException(Exception):
    """
    Raised by the protocols in this module when an exchange cannot continue,
    e.g. the socket was closed or the peer sent an unexpected reply.
    """
#
class ProtocolHelper():
    """
    Buffer used to receive a fixed-size message from, or send a whole message
    to, a socket across several calls, so that partial recv()/send() results
    are handled by the caller simply calling again.
    """
    # Upper bound on the size requested from one socket.recv() call. CPython's
    # socket.recv(n) allocates an n-byte buffer up front, so requesting the
    # whole remaining message at once could attempt a huge allocation (or raise
    # OverflowError) when 'num_bytes' is large.
    _MAX_RECV_CHUNK = 1024*1024
    #
    def __init__(self):
        # A bytearray makes appends (recv) and front deletions (send) cheap in
        # CPython; with immutable bytes each call copied the whole buffer, which
        # is quadratic for large messages.
        self._buffer = bytearray()
    #
    def set_buffer(self, data):
        """
        Set the buffer contents to the data that you wish to send.
        """
        #
        self._buffer = bytearray(data)
    #
    def get_buffer(self):
        """
        Return the current buffer contents as bytes (received data, or the
        not-yet-sent remainder of the data to send).
        """
        #
        return bytes(self._buffer)
    #
    def recv(self, socket, num_bytes):
        """
        Try to fill up the buffer to a max of 'num_bytes'. If the buffer is filled,
        return True, otherwise return False.

        Raises ProtocolException if the peer closed the socket.
        """
        #
        remaining = num_bytes-len(self._buffer)
        if remaining <= 0:
            # Already full: calling recv(0) would return b'', which is
            # indistinguishable from the peer closing the socket.
            return True
        #
        data = socket.recv(min(remaining, self._MAX_RECV_CHUNK))
        #
        if len(data) == 0:
            raise ProtocolException('The socket was closed.')
        #
        self._buffer += data
        return len(self._buffer) == num_bytes
    #
    def send(self, socket):
        """
        Try to send the remainder of the buffer. If the entire buffer has been sent,
        return True, otherwise return False.
        """
        #
        n = socket.send(self._buffer)
        del self._buffer[:n]
        return len(self._buffer) == 0
    #
#
class Protocol():
    """
    Base class for the step-wise protocol state machines in this module.

    Subclasses implement '_run_iteration()', which advances the protocol by one
    step and returns True once the protocol has finished; 'run()' calls it
    repeatedly until then.
    """
    def _run_iteration(self):
        """
        Run a single iteration of the protocol; return True once it is finished.

        Subclasses must override this. The base version raises
        NotImplementedError instead of returning None, because a None result
        would make 'run()' loop forever.
        """
        #
        raise NotImplementedError('{} must override _run_iteration()'.format(type(self).__name__))
    #
    def run(self):
        """
        Call '_run_iteration()' until it reports completion, then return True.

        There is no sleep or wait between iterations; any blocking happens
        inside the iteration itself (e.g. in socket calls). Exceptions raised by
        an iteration propagate to the caller.
        """
        #
        while not self._run_iteration():
            pass
        #
        # protocol is done
        return True
    #
    def get_desc(self):
        """
        Return a description of the protocol. This function can be overridden;
        the base version returns None.
        """
        #
        return None
    #
#
class FakeProxyProtocol(Protocol):
    """
    Send a 6-byte destination header over 'socket' to a fake proxy.

    The header is the IPv4 address from 'addr_port' with its four bytes in
    reverse order, followed by the port as a 2-byte big-endian integer. No
    reply is read.
    """
    def __init__(self, socket, addr_port):
        self.socket = socket
        self.addr_port = addr_port
        #
        self.states = enum.Enum('PROXY_STATES', 'READY_TO_BEGIN CONNECTING_TO_PROXY DONE')
        self.state = self.states.READY_TO_BEGIN
        self.protocol_helper = None
    #
    def _run_iteration(self):
        states = self.states
        #
        if self.state is states.READY_TO_BEGIN:
            self.protocol_helper = ProtocolHelper()
            host, port = self.addr_port
            # inet_aton() yields network byte order; it is reversed here
            # (presumably the order the fake proxy expects -- confirm)
            header = socket.inet_aton(host)[::-1] + struct.pack('!H', port)
            self.protocol_helper.set_buffer(header)
            self.state = states.CONNECTING_TO_PROXY
        #
        if self.state is states.CONNECTING_TO_PROXY and self.protocol_helper.send(self.socket):
            self.protocol_helper = ProtocolHelper()
            self.state = states.DONE
        #
        return self.state is states.DONE
    #
#
class ChainedProtocol(Protocol):
    """
    Run a sequence of protocols one after another.

    Each iteration calls 'run()' on the current sub-protocol and moves on to the
    next one when it returns True. Entries of 'protocols' that are None are
    skipped. An empty list finishes immediately.
    """
    def __init__(self, protocols):
        self.protocols = protocols
        self.current_protocol = 0
        #
        self.states = enum.Enum('CHAIN_STATES', 'READY_TO_BEGIN RUNNING DONE')
        self.state = self.states.READY_TO_BEGIN
    #
    def _run_iteration(self):
        if self.state is self.states.READY_TO_BEGIN:
            self.state = self.states.RUNNING
        #
        if self.state is self.states.RUNNING:
            # Check the bounds before indexing so that an empty chain completes
            # instead of raising IndexError.
            if self.current_protocol < len(self.protocols):
                protocol = self.protocols[self.current_protocol]
                if protocol is None or protocol.run():
                    self.current_protocol += 1
            #
            if self.current_protocol >= len(self.protocols):
                self.state = self.states.DONE
        #
        if self.state is self.states.DONE:
            return True
        #
        return False
    #
#
class Socks4Protocol(Protocol):
    """
    Ask a SOCKS4/SOCKS4a proxy on 'socket' to CONNECT to 'addr_port'.

    The request is sent first, then the 8-byte reply is read. A
    ProtocolException is raised unless the reply's second byte is 0x5a (the
    SOCKS4 "request granted" code).
    """
    def __init__(self, socket, addr_port, username=None):
        self.socket = socket
        self.addr_port = addr_port
        self.username = username
        #
        self.states = enum.Enum('SOCKS_4_STATES', 'READY_TO_BEGIN CONNECTING_TO_PROXY WAITING_FOR_PROXY DONE')
        self.state = self.states.READY_TO_BEGIN
        #
        self.protocol_helper = None
    #
    def _run_iteration(self):
        states = self.states
        #
        if self.state is states.READY_TO_BEGIN:
            self.protocol_helper = ProtocolHelper()
            self.protocol_helper.set_buffer(self.socks_cmd(self.addr_port, self.username))
            self.state = states.CONNECTING_TO_PROXY
        #
        if self.state is states.CONNECTING_TO_PROXY and self.protocol_helper.send(self.socket):
            # request fully sent; start collecting the proxy's reply
            self.protocol_helper = ProtocolHelper()
            self.state = states.WAITING_FOR_PROXY
        #
        if self.state is states.WAITING_FOR_PROXY and self.protocol_helper.recv(self.socket, 8):
            status = self.protocol_helper.get_buffer()[1]
            if status != 0x5a:
                raise ProtocolException('Could not connect to SOCKS proxy, msg: %x'%(status,))
            self.state = states.DONE
        #
        return self.state is states.DONE
    #
    def socks_cmd(self, addr_port, username=None):
        """
        Build a SOCKS4 CONNECT request for 'addr_port', a (host, port) pair.

        If socket.inet_aton() accepts 'host', the packed address is sent.
        Otherwise the SOCKS4a form is used: address 0.0.0.1 followed by the
        NUL-terminated UTF-8 host name. 'username' may be a str (UTF-8 encoded),
        bytes, or None, and must not contain a NUL byte.
        """
        host, port = addr_port
        #
        if isinstance(username, str):
            username = bytes(username, 'utf8')
        #
        if username is None:
            username = b''
        elif b'\x00' in username:
            raise ProtocolException('Username cannot contain a NUL character.')
        username += b'\x00'
        #
        try:
            address = socket.inet_aton(host)
            hostname = b''
        except socket.error:
            # not an IPv4 literal: let the proxy resolve the name (SOCKS4a)
            address = b'\x00\x00\x00\x01'
            hostname = bytes(host, 'utf8')+b'\x00'
        #
        header = struct.pack('!BBH', 4, 1, port)  # version 4, command 1 (CONNECT)
        return b''.join((header, address, username, hostname))
    #
#
class PushDataProtocol(Protocol):
    """
    Sender side of a bulk transfer of 'total_bytes' random bytes.

    Steps:
    1. Send a 16-byte header: the total size and the send-buffer length, each as
       an 8-byte big-endian unsigned integer.
    2. Call 'push_start_cb' if given.
    3. Stream the data, either with accelerated_functions.push_data() or with
       plain socket.send() calls.
    4. Wait for the 8-byte b'RECEIVED' confirmation.
    5. Call 'push_done_cb' if given.
    """
    def __init__(self, socket, total_bytes, send_buffer_len=None, use_acceleration=None, push_start_cb=None, push_done_cb=None):
        self.socket = socket
        self.total_bytes = total_bytes
        self.use_acceleration = True if use_acceleration is None else use_acceleration
        self.push_start_cb = push_start_cb
        self.push_done_cb = push_done_cb
        #
        self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO START_CALLBACK PUSH_DATA RECV_CONFIRMATION DONE_CALLBACK DONE')
        self.state = self.states.READY_TO_BEGIN
        #
        # Random payload block (512 KiB unless overridden). The non-accelerated
        # path sends this same block, or a prefix of it, over and over.
        self.byte_buffer = os.urandom(1024*512 if send_buffer_len is None else send_buffer_len)
        self.bytes_written = 0
        self.time_started_push = None
        self.protocol_helper = None
    #
    def _push_accelerated(self):
        """Hand the whole transfer ('total_bytes') to accelerated_functions.push_data()."""
        ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, self.byte_buffer)
        if ret_val < 0:
            raise ProtocolException('Error while pushing data.')
        self.bytes_written = self.total_bytes
    #
    def _push_chunk(self):
        """Send at most one buffer's worth of the remaining data with socket.send()."""
        remaining = self.total_bytes-self.bytes_written
        chunk = self.byte_buffer
        if remaining < len(chunk):
            # only slice (which copies) when the tail is shorter than the buffer
            chunk = chunk[:remaining]
        self.bytes_written += self.socket.send(chunk)
    #
    def _run_iteration(self):
        states = self.states
        #
        if self.state is states.READY_TO_BEGIN:
            header = b''.join(
                value.to_bytes(8, byteorder='big', signed=False)
                for value in (self.total_bytes, len(self.byte_buffer))
            )
            self.protocol_helper = ProtocolHelper()
            self.protocol_helper.set_buffer(header)
            self.state = states.SEND_INFO
        #
        if self.state is states.SEND_INFO and self.protocol_helper.send(self.socket):
            self.state = states.START_CALLBACK
        #
        if self.state is states.START_CALLBACK:
            if self.push_start_cb is not None:
                self.push_start_cb()
            self.state = states.PUSH_DATA
            self.time_started_push = time.time()
        #
        if self.state is states.PUSH_DATA:
            if self.use_acceleration:
                self._push_accelerated()
            else:
                self._push_chunk()
            #
            if self.bytes_written == self.total_bytes:
                logging.debug('Finished sending the data (%d bytes).', self.bytes_written)
                self.protocol_helper = ProtocolHelper()
                self.state = states.RECV_CONFIRMATION
        #
        if self.state is states.RECV_CONFIRMATION and self.protocol_helper.recv(self.socket, 8):
            response = self.protocol_helper.get_buffer()
            if response != b'RECEIVED':
                raise ProtocolException('Did not receive the expected message: {}'.format(response))
            self.state = states.DONE_CALLBACK
        #
        if self.state is states.DONE_CALLBACK:
            if self.push_done_cb is not None:
                self.push_done_cb()
            self.state = states.DONE
        #
        return self.state is states.DONE
    #
#
class PullDataProtocol(Protocol):
    """
    Receiver side of a bulk transfer (counterpart of PushDataProtocol).

    Reads a 16-byte header holding the total data size and the sender's buffer
    length, each as an 8-byte big-endian unsigned integer. It then reads that
    many bytes of data, either with accelerated_functions.pull_data() or with
    plain socket.recv() calls, and replies with b'RECEIVED'. Timing is recorded
    in 'time_of_first_byte' / 'time_of_last_byte', plus 'deltas' when
    accelerated.
    """
    # Largest buffer length accepted from the peer: don't use a buffer larger
    # than 10 MiB, to avoid using up all memory.
    MAX_RECV_BUFFER_LEN = 10*1024*1024
    #
    def __init__(self, socket, use_acceleration=None):
        if use_acceleration is None:
            use_acceleration = True
        #
        self.socket = socket
        self.use_acceleration = use_acceleration
        #
        self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE')
        self.state = self.states.READY_TO_BEGIN
        #
        self.data_size = None
        self.recv_buffer_len = None
        self.bytes_read = 0
        self.protocol_helper = None
        self.time_of_first_byte = None
        self.time_of_last_byte = None
        self.deltas = None
    #
    def _pull_accelerated(self):
        """Receive all 'data_size' bytes with one accelerated_functions.pull_data() call."""
        (ret_val, time_of_first_byte, time_of_last_byte, deltas) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, self.recv_buffer_len)
        if ret_val < 0:
            raise ProtocolException('Error while pulling data.')
        #
        recorded_bytes = sum(deltas['bytes'])
        if recorded_bytes != self.data_size:
            logging.warning('Lost some history data (%s != %s).', recorded_bytes, self.data_size)
        #
        self.bytes_read = self.data_size
        self.time_of_first_byte = time_of_first_byte
        self.time_of_last_byte = time_of_last_byte
        self.deltas = deltas
    #
    def _pull_chunk(self):
        """Receive the next block of data with socket.recv() and update the timestamps."""
        bytes_remaining = self.data_size-self.bytes_read
        if bytes_remaining == 0:
            # Only reachable when the peer announced zero bytes. recv(0) would
            # return b'' and be misreported as a closed socket, so record an
            # empty transfer window instead. calc_transfer_rate() then returns
            # nan rather than failing its assertion.
            now = time.time()
            self.time_of_first_byte = now
            self.time_of_last_byte = now
            return
        #
        block_size = min(self.recv_buffer_len, bytes_remaining)
        data = self.socket.recv(block_size)
        #
        if len(data) == 0:
            raise ProtocolException('The socket was closed.')
        #
        self.bytes_read += len(data)
        #
        if self.time_of_first_byte is None:
            self.time_of_first_byte = time.time()
        #
        if self.bytes_read == self.data_size and self.time_of_last_byte is None:
            self.time_of_last_byte = time.time()
    #
    def _run_iteration(self):
        if self.state is self.states.READY_TO_BEGIN:
            self.protocol_helper = ProtocolHelper()
            self.state = self.states.RECV_INFO
        #
        if self.state is self.states.RECV_INFO:
            info_size = 16
            if self.protocol_helper.recv(self.socket, info_size):
                response = self.protocol_helper.get_buffer()
                self.data_size = int.from_bytes(response[0:8], byteorder='big', signed=False)
                self.recv_buffer_len = int.from_bytes(response[8:16], byteorder='big', signed=False)
                # Validate peer-supplied values with real exceptions; 'assert'
                # is stripped when Python runs with -O.
                if self.recv_buffer_len > self.MAX_RECV_BUFFER_LEN:
                    raise ProtocolException('Receive buffer length too large ({} > {} bytes).'.format(self.recv_buffer_len, self.MAX_RECV_BUFFER_LEN))
                if self.recv_buffer_len == 0 and self.data_size > 0:
                    # a zero-sized buffer could never make progress
                    raise ProtocolException('Receive buffer length cannot be zero.')
                self.state = self.states.PULL_DATA
        #
        if self.state is self.states.PULL_DATA:
            if self.use_acceleration:
                self._pull_accelerated()
            else:
                self._pull_chunk()
            #
            if self.bytes_read == self.data_size:
                # finished receiving the data
                logging.debug('Finished receiving the data.')
                self.protocol_helper = ProtocolHelper()
                self.protocol_helper.set_buffer(b'RECEIVED')
                self.state = self.states.SEND_CONFIRMATION
        #
        if self.state is self.states.SEND_CONFIRMATION:
            if self.protocol_helper.send(self.socket):
                self.state = self.states.DONE
        #
        if self.state is self.states.DONE:
            return True
        #
        return False
    #
    def calc_transfer_rate(self):
        """ Returns bytes/s (nan if the first and last byte times are equal). """
        assert self.data_size is not None and self.time_of_first_byte is not None and self.time_of_last_byte is not None
        try:
            return self.data_size/(self.time_of_last_byte-self.time_of_first_byte)
        except ZeroDivisionError:
            return float('nan')
    #
#
class SendDataProtocol(Protocol):
    """
    Send 'data' prefixed by its length, then wait for a b'RECEIVED' reply.

    The length is sent as a 20-byte big-endian unsigned integer. For empty data
    only the length header is sent before waiting for the reply.
    """
    def __init__(self, socket, data):
        self.socket = socket
        self.send_data = data
        #
        self.states = enum.Enum('SEND_DATA_STATES', 'READY_TO_BEGIN SEND_INFO SEND_DATA RECV_CONFIRMATION DONE')
        self.state = self.states.READY_TO_BEGIN
        #
        self.protocol_helper = None
    #
    def _run_iteration(self):
        states = self.states
        #
        if self.state is states.READY_TO_BEGIN:
            header = len(self.send_data).to_bytes(20, byteorder='big', signed=False)
            self.protocol_helper = ProtocolHelper()
            self.protocol_helper.set_buffer(header)
            self.state = states.SEND_INFO
        #
        if self.state is states.SEND_INFO and self.protocol_helper.send(self.socket):
            self.protocol_helper = ProtocolHelper()
            if len(self.send_data) == 0:
                # nothing to send: go straight to waiting for the confirmation
                self.state = states.RECV_CONFIRMATION
            else:
                self.protocol_helper.set_buffer(self.send_data)
                self.state = states.SEND_DATA
        #
        if self.state is states.SEND_DATA and self.protocol_helper.send(self.socket):
            self.protocol_helper = ProtocolHelper()
            self.state = states.RECV_CONFIRMATION
        #
        if self.state is states.RECV_CONFIRMATION and self.protocol_helper.recv(self.socket, 8):
            reply = self.protocol_helper.get_buffer()
            if reply != b'RECEIVED':
                raise ProtocolException('Did not receive the expected message: {}'.format(reply))
            self.state = states.DONE
        #
        return self.state is states.DONE
    #
#
class ReceiveDataProtocol(Protocol):
    """
    Receive a length-prefixed blob into 'received_data' and reply b'RECEIVED'.

    The length is read as a 20-byte big-endian unsigned integer (stored in
    'data_size') and is followed by that many bytes of payload.
    """
    def __init__(self, socket):
        self.socket = socket
        #
        self.states = enum.Enum('RECV_DATA_STATES', 'READY_TO_BEGIN RECV_INFO RECV_DATA SEND_CONFIRMATION DONE')
        self.state = self.states.READY_TO_BEGIN
        #
        self.protocol_helper = None
        self.data_size = None
        self.received_data = None
    #
    def _start_confirmation(self):
        """Queue the b'RECEIVED' reply and switch to sending it."""
        self.protocol_helper = ProtocolHelper()
        self.protocol_helper.set_buffer(b'RECEIVED')
        self.state = self.states.SEND_CONFIRMATION
    #
    def _run_iteration(self):
        states = self.states
        #
        if self.state is states.READY_TO_BEGIN:
            self.protocol_helper = ProtocolHelper()
            self.state = states.RECV_INFO
        #
        if self.state is states.RECV_INFO and self.protocol_helper.recv(self.socket, 20):
            self.data_size = int.from_bytes(self.protocol_helper.get_buffer(), byteorder='big', signed=False)
            self.protocol_helper = ProtocolHelper()
            if self.data_size == 0:
                # empty payload: acknowledge immediately
                self.received_data = b''
                self._start_confirmation()
            else:
                self.state = states.RECV_DATA
        #
        if self.state is states.RECV_DATA and self.protocol_helper.recv(self.socket, self.data_size):
            self.received_data = self.protocol_helper.get_buffer()
            self._start_confirmation()
        #
        if self.state is states.SEND_CONFIRMATION and self.protocol_helper.send(self.socket):
            self.state = states.DONE
        #
        return self.state is states.DONE
    #
#
class ServerListener():
    """
    Non-blocking listening TCP socket that passes every accepted connection
    (the new socket object) to 'accept_callback'.
    """
    def __init__(self, bind_endpoint, accept_callback):
        self.callback = accept_callback
        #
        listening_sock = socket.socket()
        self.s = listening_sock
        listening_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listening_sock.setblocking(0)
        listening_sock.bind(bind_endpoint)
        listening_sock.listen(1000)
    #
    def accept(self, block=True):
        """
        Accept at most one pending connection and hand it to the callback.

        With block=True, first wait (via select) until the listening socket is
        readable. Returns True once the callback has been called. Returns False
        if a BlockingIOError occurs while accepting, or inside the callback.
        """
        if block:
            select.select([self.s], [], [])
        #
        try:
            newsock, endpoint = self.s.accept()
            logging.debug("New client from %s:%d (fd=%d)",
                          endpoint[0], endpoint[1], newsock.fileno())
            self.callback(newsock)
        except BlockingIOError:
            return False
        return True
    #
    def stop(self):
        """Shut the listening socket down so that pending/blocked accepts stop."""
        # use 'shutdown' rather than 'close' since 'close' won't stop a blocking 'accept' call
        self.s.shutdown(socket.SHUT_RDWR)
    #
#
class SimpleClientConnectionProtocol(Protocol):
    """
    Client side of one benchmark connection.

    The TCP connection is made in the constructor. If 'proxy' is None the client
    connects to 'endpoint' directly. Otherwise it connects to 'proxy' and asks it
    (via Socks4Protocol, with 'username') to connect to 'endpoint'. It then
    pushes 'total_bytes' bytes with PushDataProtocol.
    """
    def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None):
        self.endpoint = endpoint
        # NOTE: 'data_generator' is kept only for backward compatibility and is
        # not used. It must not be forwarded to PushDataProtocol: that class's
        # third positional parameter is 'send_buffer_len', so any non-None value
        # would be used as the os.urandom() size there.
        self.data_generator = data_generator
        self.total_bytes = total_bytes
        self.proxy = proxy
        self.username = username
        #
        self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY PUSH_DATA DONE')
        self.state = self.states.READY_TO_BEGIN
        #
        self.socket = socket.socket()
        self.sub_protocol = None
        #
        try:
            if self.proxy is None:
                logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
                self.socket.connect(self.endpoint)
            else:
                logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
                self.socket.connect(self.proxy)
        except BaseException:
            # The caller never receives this object, so close the socket here
            # rather than leaking the file descriptor.
            self.socket.close()
            raise
    #
    def _new_push_protocol(self):
        """Create the PushDataProtocol that sends 'total_bytes' over self.socket."""
        return PushDataProtocol(self.socket, self.total_bytes)
    #
    def _run_iteration(self):
        if self.state is self.states.READY_TO_BEGIN:
            if self.proxy is None:
                self.sub_protocol = self._new_push_protocol()
                self.state = self.states.PUSH_DATA
            else:
                self.sub_protocol = Socks4Protocol(self.socket, self.endpoint, username=self.username)
                self.state = self.states.CONNECT_TO_PROXY
        #
        if self.state is self.states.CONNECT_TO_PROXY:
            if self.sub_protocol.run():
                self.sub_protocol = self._new_push_protocol()
                self.state = self.states.PUSH_DATA
        #
        if self.state is self.states.PUSH_DATA:
            if self.sub_protocol.run():
                self.state = self.states.DONE
        #
        if self.state is self.states.DONE:
            return True
        #
        return False
    #
#
class SimpleServerConnectionProtocol(Protocol):
    """
    Server side of one benchmark connection.

    Pulls the pushed data with PullDataProtocol. Afterwards, if
    'bandwidth_callback' is set, it calls
    bandwidth_callback(conn_id, data_size, transfer_rate_in_bytes_per_second).
    """
    def __init__(self, socket, conn_id, bandwidth_callback=None):
        self.socket = socket
        self.conn_id = conn_id
        self.bandwidth_callback = bandwidth_callback
        #
        self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN PULL_DATA DONE')
        self.state = self.states.READY_TO_BEGIN
        #
        self.sub_protocol = None
    #
    def _run_iteration(self):
        states = self.states
        #
        if self.state is states.READY_TO_BEGIN:
            self.sub_protocol = PullDataProtocol(self.socket)
            self.state = states.PULL_DATA
        #
        if self.state is states.PULL_DATA and self.sub_protocol.run():
            callback = self.bandwidth_callback
            if callback:
                pull = self.sub_protocol
                callback(self.conn_id, pull.data_size, pull.calc_transfer_rate())
            self.state = states.DONE
        #
        return self.state is states.DONE
    #
#
def _bw_callback(conn_id, data_size, transfer_rate):
    """Demo-server bandwidth callback: log the transferred amount and rate in MiB."""
    logging.info('Avg Transferred (MB): %.4f', data_size/(1024**2))
    logging.info('Avg Transfer rate (MB/s): %.4f', transfer_rate/(1024**2))
#
def _start_server_conn(sock, conn_id):
    """
    Demo-server child-process entry point: serve one accepted connection.

    Defined at module level, not inside the __main__ guard, so that a child
    started with the 'spawn' multiprocessing start method can find it. Such a
    child re-imports this module without running the guard.
    """
    server = SimpleServerConnectionProtocol(sock, conn_id, bandwidth_callback=_bw_callback)
    try:
        server.run()
    except KeyboardInterrupt:
        sock.close()
#
if __name__ == '__main__':
    import sys
    import multiprocessing
    #
    logging.basicConfig(level=logging.DEBUG)
    #
    mode = sys.argv[1] if len(sys.argv) > 1 else None
    #
    if mode == 'client':
        endpoint = ('127.0.0.1', 4747)
        proxy = None  # e.g. ('127.0.0.1', 9003) to connect through a SOCKS proxy
        # random username without NUL bytes (socks_cmd() rejects NULs)
        username = bytes([x for x in os.urandom(12) if x != 0])
        data_MB = 4000
        #
        client = SimpleClientConnectionProtocol(endpoint, data_MB*2**20, proxy=proxy, username=username)
        client.run()
    elif mode == 'server':
        endpoint = ('127.0.0.1', 4747)
        processes = []
        conn_counter = [0]
        #
        def accept_callback(sock):
            conn_id = conn_counter[0]
            conn_counter[0] += 1
            #
            p = multiprocessing.Process(target=_start_server_conn, args=(sock, conn_id))
            processes.append(p)
            p.start()
            # p.start() has handed the connection to the child; close the
            # parent's copy so descriptors don't pile up in the listener process.
            sock.close()
        #
        listener = ServerListener(endpoint, accept_callback)
        #
        try:
            while True:
                listener.accept()
        except KeyboardInterrupt:
            print()
        finally:
            listener.s.close()
        #
        for p in processes:
            p.join()
    else:
        sys.exit('Usage: {} client|server'.format(sys.argv[0]))
#
|