#!/usr/bin/python3
"""Protocol state machines for proxy handshakes and bulk data transfer over sockets.

Every protocol class exposes ``run()``, which repeatedly calls the class's
``_run_iteration()`` until that method reports completion.
"""

import socket
import struct
import logging
import time
import enum
import select
import os

try:
    import accelerated_functions
except ModuleNotFoundError as _import_error:
    if _import_error.name != 'accelerated_functions':
        # A dependency *of* the helper module is missing; do not hide that.
        raise
    # The pure-Python transfer paths below do not need the helper module, so
    # only the accelerated paths become unavailable when it is not installed.
    accelerated_functions = None


class ProtocolException(Exception):
    """Raised when a peer violates a protocol or the connection is lost."""


def _resolve_acceleration(use_acceleration):
    """Return the effective acceleration flag for a push/pull protocol.

    ``None`` selects acceleration whenever ``accelerated_functions`` could be
    imported.  Explicitly requesting acceleration without the module raises
    ``ProtocolException`` up front instead of failing later on a ``None``
    module object.
    """
    available = accelerated_functions is not None
    if use_acceleration is None:
        return available
    if use_acceleration and not available:
        raise ProtocolException('Acceleration was requested but accelerated_functions is not available.')
    return use_acceleration


class ProtocolHelper():
    """Byte buffer that can be sent or filled across several socket calls."""

    def __init__(self):
        self._buffer = b''

    def set_buffer(self, data):
        """Set the buffer contents to the data that you wish to send."""
        self._buffer = data

    def get_buffer(self):
        """Return the current buffer contents."""
        return self._buffer

    def recv(self, socket, num_bytes):
        """Try to fill up the buffer to a max of 'num_bytes'.

        Returns True once the buffer holds exactly 'num_bytes' bytes and
        False otherwise.  Raises ProtocolException if the socket returns no
        data (the peer closed the connection).
        """
        data = socket.recv(num_bytes - len(self._buffer))
        if len(data) == 0:
            raise ProtocolException('The socket was closed.')

        self._buffer += data
        return len(self._buffer) == num_bytes

    def send(self, socket):
        """Try to send the remainder of the buffer.

        Returns True once the entire buffer has been sent and False
        otherwise.
        """
        n = socket.send(self._buffer)
        self._buffer = self._buffer[n:]
        return len(self._buffer) == 0


class Protocol():
    """Base class: a state machine driven by repeated '_run_iteration' calls."""

    def _run_iteration(self):
        """Run a single iteration of the protocol.

        This function should be overridden.  Return a true value once the
        protocol has finished.
        """
        pass

    def run(self):
        """Call '_run_iteration' until it reports completion, then return True."""
        while True:
            finished = self._run_iteration()
            if finished:
                # protocol is done
                return True

    def get_desc(self):
        """Return a description of the protocol.

        This function can be overridden; the base implementation returns
        None.
        """
        return None


class FakeProxyProtocol(Protocol):
    """Send a 6-byte destination header (address bytes then port) to a proxy."""

    def __init__(self, socket, addr_port):
        self.socket = socket
        self.addr_port = addr_port

        self.states = enum.Enum('PROXY_STATES', 'READY_TO_BEGIN CONNECTING_TO_PROXY DONE')
        self.state = self.states.READY_TO_BEGIN

        self.protocol_helper = None

    def _run_iteration(self):
        if self.state is self.states.READY_TO_BEGIN:
            self.protocol_helper = ProtocolHelper()
            host, port = self.addr_port
            # The four packed address bytes are sent in reversed order.
            # NOTE(review): presumably the byte order the fake proxy expects;
            # confirm against the proxy implementation.
            addr = socket.inet_aton(host)[::-1]
            self.protocol_helper.set_buffer(addr + struct.pack('!H', port))
            self.state = self.states.CONNECTING_TO_PROXY

        if self.state is self.states.CONNECTING_TO_PROXY:
            if self.protocol_helper.send(self.socket):
                self.protocol_helper = ProtocolHelper()
                self.state = self.states.DONE

        return self.state is self.states.DONE


class ChainedProtocol(Protocol):
    """Run a sequence of protocols one after another.

    Entries that are None are skipped.
    """

    def __init__(self, protocols):
        self.protocols = protocols
        self.current_protocol = 0

        self.states = enum.Enum('CHAIN_STATES', 'READY_TO_BEGIN RUNNING DONE')
        self.state = self.states.READY_TO_BEGIN

    def _run_iteration(self):
        if self.state is self.states.READY_TO_BEGIN:
            # An empty chain has nothing to run; indexing it would raise
            # IndexError, so finish immediately instead.
            if len(self.protocols) == 0:
                self.state = self.states.DONE
            else:
                self.state = self.states.RUNNING

        if self.state is self.states.RUNNING:
            current = self.protocols[self.current_protocol]
            if current is None or current.run():
                self.current_protocol += 1

            if self.current_protocol >= len(self.protocols):
                self.state = self.states.DONE

        return self.state is self.states.DONE


class Socks4Protocol(Protocol):
    """Perform a SOCKS4 CONNECT handshake over an already-connected socket."""

    def __init__(self, socket, addr_port, username=None):
        self.socket = socket
        self.addr_port = addr_port
        self.username = username

        self.states = enum.Enum('SOCKS_4_STATES', 'READY_TO_BEGIN CONNECTING_TO_PROXY WAITING_FOR_PROXY DONE')
        self.state = self.states.READY_TO_BEGIN

        self.protocol_helper = None

    def _run_iteration(self):
        if self.state is self.states.READY_TO_BEGIN:
            self.protocol_helper = ProtocolHelper()
            self.protocol_helper.set_buffer(self.socks_cmd(self.addr_port, self.username))
            self.state = self.states.CONNECTING_TO_PROXY

        if self.state is self.states.CONNECTING_TO_PROXY:
            if self.protocol_helper.send(self.socket):
                self.protocol_helper = ProtocolHelper()
                self.state = self.states.WAITING_FOR_PROXY

        if self.state is self.states.WAITING_FOR_PROXY:
            response_size = 8
            if self.protocol_helper.recv(self.socket, response_size):
                response = self.protocol_helper.get_buffer()
                # The second reply byte must be 0x5a; anything else is
                # treated as a refused request.
                if response[1] != 0x5a:
                    raise ProtocolException('Could not connect to SOCKS proxy, msg: %x'%(response[1],))

                self.state = self.states.DONE

        return self.state is self.states.DONE

    def socks_cmd(self, addr_port, username=None):
        """Build the SOCKS4 request bytes for connecting to 'addr_port'.

        'addr_port' is a (host, port) pair.  If 'host' is not a dotted-quad
        IPv4 address, the address field is set to 0.0.0.1 and the host name
        is appended NUL-terminated after the username.  'username' may be
        str, bytes or None.

        Raises ProtocolException if the username contains a NUL byte.
        """
        socks_version = 4
        command = 1
        dnsname = b''
        host, port = addr_port

        if isinstance(username, str):
            username = username.encode('utf8')

        if username is None:
            username = b''
        elif b'\x00' in username:
            raise ProtocolException('Username cannot contain a NUL character.')

        username = username + b'\x00'

        try:
            addr = socket.inet_aton(host)
        except socket.error:
            # Not an IPv4 literal: send the host name for the proxy to resolve.
            addr = b'\x00\x00\x00\x01'
            dnsname = bytes(host, 'utf8') + b'\x00'

        return struct.pack('!BBH', socks_version, command, port) + addr + username + dnsname


class PushDataProtocol(Protocol):
    """Send 'total_bytes' of random data and wait for a b'RECEIVED' reply.

    A 16-byte header is sent first: the total size and the send buffer
    length, each as an 8-byte big-endian unsigned integer.

    'use_acceleration' selects 'accelerated_functions.push_data' instead of
    the pure-Python send loop; None enables it whenever that module could be
    imported.  'push_start_cb' and 'push_done_cb', when given, are called
    with no arguments before the payload is sent and after the confirmation
    arrives.
    """

    def __init__(self, socket, total_bytes, send_buffer_len=None, use_acceleration=None,
                 push_start_cb=None, push_done_cb=None):
        if send_buffer_len is None:
            send_buffer_len = 1024*512

        self.socket = socket
        self.total_bytes = total_bytes
        self.use_acceleration = _resolve_acceleration(use_acceleration)
        self.push_start_cb = push_start_cb
        self.push_done_cb = push_done_cb

        self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO START_CALLBACK PUSH_DATA RECV_CONFIRMATION DONE_CALLBACK DONE')
        self.state = self.states.READY_TO_BEGIN

        # The same random block is sent repeatedly until 'total_bytes' is reached.
        self.byte_buffer = os.urandom(send_buffer_len)
        self.bytes_written = 0
        self.time_started_push = None
        self.protocol_helper = None

    def _run_iteration(self):
        if self.state is self.states.READY_TO_BEGIN:
            info = self.total_bytes.to_bytes(8, byteorder='big', signed=False)
            info += len(self.byte_buffer).to_bytes(8, byteorder='big', signed=False)
            self.protocol_helper = ProtocolHelper()
            self.protocol_helper.set_buffer(info)
            self.state = self.states.SEND_INFO

        if self.state is self.states.SEND_INFO:
            if self.protocol_helper.send(self.socket):
                self.state = self.states.START_CALLBACK

        if self.state is self.states.START_CALLBACK:
            if self.push_start_cb is not None:
                self.push_start_cb()

            self.state = self.states.PUSH_DATA
            self.time_started_push = time.time()

        if self.state is self.states.PUSH_DATA:
            if self.use_acceleration:
                ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, self.byte_buffer)
                if ret_val < 0:
                    raise ProtocolException('Error while pushing data.')

                self.bytes_written = self.total_bytes
            else:
                bytes_remaining = self.total_bytes - self.bytes_written
                data_size = min(len(self.byte_buffer), bytes_remaining)
                if data_size != len(self.byte_buffer):
                    data = self.byte_buffer[:data_size]
                else:
                    # don't make a copy of the byte string each time if we don't need to
                    data = self.byte_buffer

                n = self.socket.send(data)
                self.bytes_written += n

            if self.bytes_written == self.total_bytes:
                # finished sending the data
                logging.debug('Finished sending the data (%d bytes).', self.bytes_written)
                self.protocol_helper = ProtocolHelper()
                self.state = self.states.RECV_CONFIRMATION

        if self.state is self.states.RECV_CONFIRMATION:
            response_size = 8
            if self.protocol_helper.recv(self.socket, response_size):
                response = self.protocol_helper.get_buffer()
                if response != b'RECEIVED':
                    raise ProtocolException('Did not receive the expected message: {}'.format(response))

                self.state = self.states.DONE_CALLBACK

        if self.state is self.states.DONE_CALLBACK:
            if self.push_done_cb is not None:
                self.push_done_cb()

            self.state = self.states.DONE

        return self.state is self.states.DONE


class PullDataProtocol(Protocol):
    """Receive the data sent by a PushDataProtocol peer and confirm it.

    Reads the 16-byte header (payload size and buffer length, each an 8-byte
    big-endian unsigned integer), reads and discards that many payload
    bytes, then replies with b'RECEIVED'.

    'use_acceleration' selects 'accelerated_functions.pull_data' instead of
    the pure-Python receive loop; None enables it whenever that module could
    be imported.
    """

    # Don't use a buffer size larger than 10 MiB to avoid using up all memory.
    MAX_RECV_BUFFER_LEN = 10*1024*1024

    def __init__(self, socket, use_acceleration=None):
        self.socket = socket
        self.use_acceleration = _resolve_acceleration(use_acceleration)

        self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE')
        self.state = self.states.READY_TO_BEGIN

        self.data_size = None
        self.recv_buffer_len = None
        self.bytes_read = 0
        self.protocol_helper = None
        self.time_of_first_byte = None
        self.time_of_last_byte = None
        self.deltas = None

    def _run_iteration(self):
        if self.state is self.states.READY_TO_BEGIN:
            self.protocol_helper = ProtocolHelper()
            self.state = self.states.RECV_INFO

        if self.state is self.states.RECV_INFO:
            info_size = 16
            if self.protocol_helper.recv(self.socket, info_size):
                response = self.protocol_helper.get_buffer()
                self.data_size = int.from_bytes(response[0:8], byteorder='big', signed=False)
                self.recv_buffer_len = int.from_bytes(response[8:16], byteorder='big', signed=False)
                # The length comes from the peer, so it must be rejected with
                # a real exception: an 'assert' disappears under 'python -O'.
                if self.recv_buffer_len > self.MAX_RECV_BUFFER_LEN:
                    raise ProtocolException('Requested receive buffer length ({} bytes) exceeds the limit ({} bytes).'.format(self.recv_buffer_len, self.MAX_RECV_BUFFER_LEN))

                self.state = self.states.PULL_DATA

        if self.state is self.states.PULL_DATA:
            if self.use_acceleration:
                (ret_val, time_of_first_byte, time_of_last_byte, deltas) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, self.recv_buffer_len)
                if ret_val < 0:
                    raise ProtocolException('Error while pulling data.')

                if sum(deltas['bytes']) != self.data_size:
                    logging.warning('Lost some history data ({} != {}).'.format(sum(deltas['bytes']), self.data_size))

                self.bytes_read = self.data_size
                self.time_of_first_byte = time_of_first_byte
                self.time_of_last_byte = time_of_last_byte
                self.deltas = deltas
            else:
                bytes_remaining = self.data_size - self.bytes_read
                # For a zero-byte transfer there is nothing to read, and
                # 'recv(0)' returns b'', which would be misreported as a
                # closed socket.
                if bytes_remaining > 0:
                    block_size = min(self.recv_buffer_len, bytes_remaining)
                    data = self.socket.recv(block_size)
                    if len(data) == 0:
                        raise ProtocolException('The socket was closed.')

                    self.bytes_read += len(data)

                # A zero-byte transfer still gets both timestamps so that
                # 'calc_transfer_rate' remains callable afterwards.
                if self.time_of_first_byte is None and (self.bytes_read != 0 or self.data_size == 0):
                    self.time_of_first_byte = time.time()

                if self.bytes_read == self.data_size and self.time_of_last_byte is None:
                    self.time_of_last_byte = time.time()

            if self.bytes_read == self.data_size:
                # finished receiving the data
                logging.debug('Finished receiving the data.')
                self.protocol_helper = ProtocolHelper()
                self.protocol_helper.set_buffer(b'RECEIVED')
                self.state = self.states.SEND_CONFIRMATION

        if self.state is self.states.SEND_CONFIRMATION:
            if self.protocol_helper.send(self.socket):
                self.state = self.states.DONE

        return self.state is self.states.DONE

    def calc_transfer_rate(self):
        """Return the transfer rate in bytes/s.

        Must only be called after the transfer has recorded its size and
        both timestamps.  Returns NaN when the first and last byte share the
        same timestamp.
        """
        assert self.data_size is not None and self.time_of_first_byte is not None and self.time_of_last_byte is not None
        try:
            return self.data_size/(self.time_of_last_byte - self.time_of_first_byte)
        except ZeroDivisionError:
            return float('nan')


class SendDataProtocol(Protocol):
    """Send a length-prefixed byte string and wait for a b'RECEIVED' reply.

    The length is sent as a 20-byte big-endian unsigned integer.
    """

    def __init__(self, socket, data):
        self.socket = socket
        self.send_data = data

        self.states = enum.Enum('SEND_DATA_STATES', 'READY_TO_BEGIN SEND_INFO SEND_DATA RECV_CONFIRMATION DONE')
        self.state = self.states.READY_TO_BEGIN

        self.protocol_helper = None

    def _run_iteration(self):
        if self.state is self.states.READY_TO_BEGIN:
            info_size = 20
            info = len(self.send_data).to_bytes(info_size, byteorder='big', signed=False)
            self.protocol_helper = ProtocolHelper()
            self.protocol_helper.set_buffer(info)
            self.state = self.states.SEND_INFO

        if self.state is self.states.SEND_INFO:
            if self.protocol_helper.send(self.socket):
                self.protocol_helper = ProtocolHelper()
                if len(self.send_data) > 0:
                    self.protocol_helper.set_buffer(self.send_data)
                    self.state = self.states.SEND_DATA
                else:
                    # Nothing to send; go straight to waiting for the reply.
                    self.state = self.states.RECV_CONFIRMATION

        if self.state is self.states.SEND_DATA:
            if self.protocol_helper.send(self.socket):
                self.protocol_helper = ProtocolHelper()
                self.state = self.states.RECV_CONFIRMATION

        if self.state is self.states.RECV_CONFIRMATION:
            response_size = 8
            if self.protocol_helper.recv(self.socket, response_size):
                response = self.protocol_helper.get_buffer()
                if response != b'RECEIVED':
                    raise ProtocolException('Did not receive the expected message: {}'.format(response))

                self.state = self.states.DONE

        return self.state is self.states.DONE


class ReceiveDataProtocol(Protocol):
    """Receive a length-prefixed byte string and reply with b'RECEIVED'.

    The payload is stored in 'received_data' once the protocol has finished.
    """

    def __init__(self, socket):
        self.socket = socket

        self.states = enum.Enum('RECV_DATA_STATES', 'READY_TO_BEGIN RECV_INFO RECV_DATA SEND_CONFIRMATION DONE')
        self.state = self.states.READY_TO_BEGIN

        self.protocol_helper = None
        self.data_size = None
        self.received_data = None

    def _run_iteration(self):
        if self.state is self.states.READY_TO_BEGIN:
            self.protocol_helper = ProtocolHelper()
            self.state = self.states.RECV_INFO

        if self.state is self.states.RECV_INFO:
            info_size = 20
            if self.protocol_helper.recv(self.socket, info_size):
                response = self.protocol_helper.get_buffer()
                self.data_size = int.from_bytes(response, byteorder='big', signed=False)
                self.protocol_helper = ProtocolHelper()
                if self.data_size > 0:
                    self.state = self.states.RECV_DATA
                else:
                    # Empty payload: skip the read and confirm right away.
                    self.received_data = b''
                    self.protocol_helper.set_buffer(b'RECEIVED')
                    self.state = self.states.SEND_CONFIRMATION

        if self.state is self.states.RECV_DATA:
            if self.protocol_helper.recv(self.socket, self.data_size):
                self.received_data = self.protocol_helper.get_buffer()
                self.protocol_helper = ProtocolHelper()
                self.protocol_helper.set_buffer(b'RECEIVED')
                self.state = self.states.SEND_CONFIRMATION

        if self.state is self.states.SEND_CONFIRMATION:
            if self.protocol_helper.send(self.socket):
                self.state = self.states.DONE

        return self.state is self.states.DONE


class ServerListener():
    """Non-blocking listening socket that hands accepted sockets to a callback."""

    def __init__(self, bind_endpoint, accept_callback):
        self.callback = accept_callback

        self.s = socket.socket()
        self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.s.setblocking(0)
        self.s.bind(bind_endpoint)
        self.s.listen(1000)

    def accept(self, block=True):
        """Accept one pending connection and pass its socket to the callback.

        With 'block' true, waits in 'select' until the listening socket is
        readable.  Returns True if a connection was accepted and False if
        none was pending.
        """
        if block:
            select.select([self.s], [], [])

        # Only the accept call itself is guarded: a BlockingIOError raised by
        # the callback must not be mistaken for "no pending connection".
        try:
            (newsock, endpoint) = self.s.accept()
        except BlockingIOError:
            return False

        logging.debug("New client from %s:%d (fd=%d)", endpoint[0], endpoint[1], newsock.fileno())
        self.callback(newsock)
        return True

    def stop(self):
        """Shut down the listening socket."""
        # use 'shutdown' rather than 'close' since 'close' won't stop a blocking 'accept' call
        self.s.shutdown(socket.SHUT_RDWR)


class SimpleClientConnectionProtocol(Protocol):
    """Connect to an endpoint (optionally through a SOCKS4 proxy) and push data.

    The socket is created and connected in the constructor.
    """

    def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None):
        self.endpoint = endpoint
        self.data_generator = data_generator
        self.total_bytes = total_bytes
        self.proxy = proxy
        self.username = username

        self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY PUSH_DATA DONE')
        self.state = self.states.READY_TO_BEGIN

        self.socket = socket.socket()
        self.sub_protocol = None

        if self.proxy is None:
            logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
            self.socket.connect(self.endpoint)
        else:
            logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
            self.socket.connect(self.proxy)

    def _new_push_protocol(self):
        """Create the PushDataProtocol used for the data transfer stage."""
        # NOTE(review): 'data_generator' has always been passed in the
        # 'send_buffer_len' position of PushDataProtocol. The keyword makes
        # that visible without changing behaviour; confirm what was intended.
        return PushDataProtocol(self.socket, self.total_bytes, send_buffer_len=self.data_generator)

    def _run_iteration(self):
        if self.state is self.states.READY_TO_BEGIN:
            if self.proxy is None:
                self.sub_protocol = self._new_push_protocol()
                self.state = self.states.PUSH_DATA
            else:
                self.sub_protocol = Socks4Protocol(self.socket, self.endpoint, username=self.username)
                self.state = self.states.CONNECT_TO_PROXY

        if self.state is self.states.CONNECT_TO_PROXY:
            if self.sub_protocol.run():
                self.sub_protocol = self._new_push_protocol()
                self.state = self.states.PUSH_DATA

        if self.state is self.states.PUSH_DATA:
            if self.sub_protocol.run():
                self.state = self.states.DONE

        return self.state is self.states.DONE


class SimpleServerConnectionProtocol(Protocol):
    """Pull data from one accepted connection and report its transfer rate.

    'bandwidth_callback', when given, is called with the connection id, the
    number of payload bytes and the transfer rate in bytes/s.
    """

    def __init__(self, socket, conn_id, bandwidth_callback=None):
        self.socket = socket
        self.conn_id = conn_id
        self.bandwidth_callback = bandwidth_callback

        self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN PULL_DATA DONE')
        self.state = self.states.READY_TO_BEGIN

        self.sub_protocol = None

    def _run_iteration(self):
        if self.state is self.states.READY_TO_BEGIN:
            self.sub_protocol = PullDataProtocol(self.socket)
            self.state = self.states.PULL_DATA

        if self.state is self.states.PULL_DATA:
            if self.sub_protocol.run():
                if self.bandwidth_callback:
                    self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size, self.sub_protocol.calc_transfer_rate())

                self.state = self.states.DONE

        return self.state is self.states.DONE


if __name__ == '__main__':
    import sys
    logging.basicConfig(level=logging.DEBUG)

    mode = sys.argv[1] if len(sys.argv) > 1 else None
    if mode not in ('client', 'server'):
        # A missing argument used to raise IndexError and an unknown one did
        # nothing at all; say what is expected instead.
        sys.exit('usage: {} client|server'.format(sys.argv[0]))

    if mode == 'client':
        endpoint = ('127.0.0.1', 4747)
        #proxy = ('127.0.0.1', 9003)
        proxy = None
        # Random SOCKS username without NUL bytes (only used with a proxy).
        username = bytes(x for x in os.urandom(12) if x != 0)
        #username = None
        data_MB = 4000

        client = SimpleClientConnectionProtocol(endpoint, data_MB*2**20, proxy=proxy, username=username)
        client.run()
    else:
        import multiprocessing

        endpoint = ('127.0.0.1', 4747)
        processes = []
        conn_counter = [0]

        def bw_callback(conn_id, data_size, transfer_rate):
            logging.info('Avg Transferred (MB): %.4f', data_size/(1024**2))
            logging.info('Avg Transfer rate (MB/s): %.4f', transfer_rate/(1024**2))

        def start_server_conn(sock, conn_id):
            server = SimpleServerConnectionProtocol(sock, conn_id, bandwidth_callback=bw_callback)
            try:
                server.run()
            except KeyboardInterrupt:
                sock.close()

        def accept_callback(sock):
            # One process per accepted connection.
            conn_id = conn_counter[0]
            conn_counter[0] += 1

            p = multiprocessing.Process(target=start_server_conn, args=(sock, conn_id))
            processes.append(p)
            p.start()

        listener = ServerListener(endpoint, accept_callback)

        try:
            while True:
                listener.accept()
        except KeyboardInterrupt:
            print()

        for p in processes:
            p.join()