#!/usr/bin/python3
"""Step-wise socket protocols for pushing/pulling bulk data, optionally via a SOCKS4 proxy.

Each protocol is a small state machine advanced by ``_run_iteration``;
``Protocol.run`` drives it either to completion (``block=True``) or for a
single step (``block=False``).
"""

import socket
import struct
import logging
import time
import enum
import select
import os

try:
    import accelerated_functions
except ImportError:
    # The C helper module is optional: without it, PushDataProtocol and
    # PullDataProtocol default to their pure-Python send/recv loops.
    accelerated_functions = None

# Width of the big-endian length prefix used by SendDataProtocol/ReceiveDataProtocol.
_DATA_LEN_FIELD_SIZE = 20


class ProtocolException(Exception):
    """Raised when the peer misbehaves, a transfer fails, or the connection closes early."""
    pass


class ProtocolHelper():
    """Buffer one fixed-size receive, or one outgoing message, across partial recv/send calls."""

    def __init__(self):
        self._buffer = b''

    def set_buffer(self, data):
        """Set the bytes that subsequent ``send`` calls will transmit."""
        self._buffer = data

    def get_buffer(self):
        """Return the buffered bytes (received so far, or not yet sent) as ``bytes``."""
        # recv() accumulates into a bytearray and send() keeps a memoryview of
        # the unsent tail, so normalize to bytes for callers.
        return bytes(self._buffer)

    def recv(self, socket, num_bytes):
        """Read from *socket* towards a total of exactly *num_bytes* buffered bytes.

        Performs at most one ``socket.recv`` call per invocation.

        Returns:
            True once the buffer holds *num_bytes* bytes, False if more data is needed.

        Raises:
            ProtocolException: if the peer closed the connection before
                *num_bytes* bytes arrived.
        """
        remaining = num_bytes - len(self._buffer)
        if remaining > 0:
            data = socket.recv(remaining)
            if not data:
                # recv() returning b'' means EOF; without this check a blocking
                # run() would loop forever waiting for bytes that never come.
                raise ProtocolException('Connection closed after receiving {} of {} bytes.'.format(
                    len(self._buffer), num_bytes))
            if not isinstance(self._buffer, bytearray):
                self._buffer = bytearray(self._buffer)
            # bytearray appends are amortized O(1); ``bytes += bytes`` copies the
            # whole buffer every time, which is quadratic for large messages.
            self._buffer += data
        return len(self._buffer) == num_bytes

    def send(self, socket):
        """Send as much of the buffer as *socket* accepts in one ``send`` call.

        Returns:
            True once the whole buffer has been sent, False otherwise.
        """
        view = memoryview(self._buffer)
        sent = socket.send(view)
        # Keep a zero-copy view of the unsent tail instead of re-slicing bytes.
        self._buffer = view[sent:]
        return len(self._buffer) == 0


class Protocol:
    """Base class for the step-wise protocol state machines in this module."""

    def _run_iteration(self, block=True):
        """Advance the protocol one step; subclasses return True once finished."""
        pass

    def run(self, block=True):
        """Drive the protocol.

        With ``block=True``, iterate until the protocol finishes and return
        True. With ``block=False``, perform a single iteration and return
        whether the protocol has finished.
        """
        while True:
            if self._run_iteration(block=block):
                # protocol is done
                return True
            if not block:
                # not done the protocol yet, but don't block
                return False


class ChainedProtocol(Protocol):
    """Run a list of protocols one after another; ``None`` entries are skipped."""

    def __init__(self, protocols):
        self.protocols = protocols
        self.current_protocol = 0

        self.states = enum.Enum('CHAIN_STATES', 'READY_TO_BEGIN RUNNING DONE')
        self.state = self.states.READY_TO_BEGIN

    def _run_iteration(self, block=True):
        if self.state is self.states.DONE:
            # Already finished; report it so a repeated blocking run() doesn't spin.
            return True

        if self.state is self.states.READY_TO_BEGIN:
            self.state = self.states.RUNNING

        if self.state is self.states.RUNNING:
            # Check the bound before indexing so an empty chain completes
            # immediately instead of raising IndexError.
            if self.current_protocol < len(self.protocols):
                protocol = self.protocols[self.current_protocol]
                if protocol is None or protocol.run(block=block):
                    self.current_protocol += 1

            if self.current_protocol >= len(self.protocols):
                self.state = self.states.DONE
                return True

        return False


class Socks4Protocol(Protocol):
    """Ask a SOCKS4/4a proxy, already connected on *socket*, to CONNECT to *addr_port*."""

    def __init__(self, socket, addr_port, username=None):
        self.socket = socket
        self.addr_port = addr_port
        self.username = username

        self.states = enum.Enum('SOCKS_4_STATES', 'READY_TO_BEGIN CONNECTING_TO_PROXY WAITING_FOR_PROXY DONE')
        self.state = self.states.READY_TO_BEGIN

        self.protocol_helper = None

    def _run_iteration(self, block=True):
        if self.state is self.states.READY_TO_BEGIN:
            self.protocol_helper = ProtocolHelper()
            self.protocol_helper.set_buffer(self.socks_cmd(self.addr_port, self.username))
            self.state = self.states.CONNECTING_TO_PROXY

        if self.state is self.states.CONNECTING_TO_PROXY:
            if self.protocol_helper.send(self.socket):
                self.protocol_helper = ProtocolHelper()
                self.state = self.states.WAITING_FOR_PROXY

        if self.state is self.states.WAITING_FOR_PROXY:
            response_size = 8  # SOCKS4 replies are 8 bytes
            if self.protocol_helper.recv(self.socket, response_size):
                response = self.protocol_helper.get_buffer()
                # Byte 1 is the reply code; 0x5a means "request granted".
                if response[1] != 0x5a:
                    raise ProtocolException('Could not connect to SOCKS proxy, msg: %x'%(response[1],))
                self.state = self.states.DONE
                return True

        return False

    def socks_cmd(self, addr_port, username=None):
        """Build a SOCKS4 CONNECT request for *addr_port* = (host, port).

        If ``socket.inet_aton`` cannot parse *host*, the SOCKS4a form is used:
        the address field is 0.0.0.1 and the UTF-8 hostname is appended.
        *username* may be a str (UTF-8 encoded), bytes, or None.

        Raises:
            ProtocolException: if *username* contains a NUL byte.
        """
        socks_version = 4
        command = 1  # CONNECT
        dnsname = b''
        host, port = addr_port

        if isinstance(username, str):
            username = bytes(username, 'utf8')

        if username is None:
            username = b''
        elif b'\x00' in username:
            raise ProtocolException('Username cannot contain a NUL character.')

        username = username + b'\x00'

        try:
            addr = socket.inet_aton(host)
        except socket.error:
            addr = b'\x00\x00\x00\x01'
            dnsname = bytes(host, 'utf8') + b'\x00'

        return struct.pack('!BBH', socks_version, command, port) + addr + username + dnsname


def _resolve_acceleration(use_acceleration):
    """Return the effective acceleration flag; ``None`` means "use it if available"."""
    if use_acceleration is None:
        return accelerated_functions is not None
    if use_acceleration and accelerated_functions is None:
        raise ProtocolException('Acceleration was requested, but the accelerated_functions module is not available.')
    return use_acceleration


class PushDataProtocol(Protocol):
    """Send *total_bytes* of random data, then wait for a b'RECEIVED' reply.

    The data is preceded by a 16-byte header: the total size and the sender's
    buffer length, each as an 8-byte big-endian unsigned integer.
    """

    def __init__(self, socket, total_bytes, send_buffer_len=None, use_acceleration=None):
        if send_buffer_len is None:
            send_buffer_len = 1024*512

        use_acceleration = _resolve_acceleration(use_acceleration)

        self.socket = socket
        self.total_bytes = total_bytes
        self.use_acceleration = use_acceleration

        self.states = enum.Enum('PUSH_DATA_STATES', 'READY_TO_BEGIN SEND_INFO PUSH_DATA RECV_CONFIRMATION DONE')
        self.state = self.states.READY_TO_BEGIN

        self.byte_buffer = os.urandom(send_buffer_len)
        self.bytes_written = 0
        self.protocol_helper = None

    def _run_iteration(self, block=True):
        if self.state is self.states.READY_TO_BEGIN:
            info = self.total_bytes.to_bytes(8, byteorder='big', signed=False)
            info += len(self.byte_buffer).to_bytes(8, byteorder='big', signed=False)
            self.protocol_helper = ProtocolHelper()
            self.protocol_helper.set_buffer(info)
            self.state = self.states.SEND_INFO

        if self.state is self.states.SEND_INFO:
            if self.protocol_helper.send(self.socket):
                self.state = self.states.PUSH_DATA

        if self.state is self.states.PUSH_DATA:
            if self.use_acceleration:
                if not block:
                    logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
                ret_val = accelerated_functions.push_data(self.socket.fileno(), self.total_bytes, self.byte_buffer)
                if ret_val < 0:
                    raise ProtocolException('Error while pushing data.')
                self.bytes_written = self.total_bytes
            else:
                bytes_remaining = self.total_bytes - self.bytes_written
                if bytes_remaining > 0:
                    data_size = min(len(self.byte_buffer), bytes_remaining)
                    # A memoryview slice hands the buffer to send() without copying it.
                    self.bytes_written += self.socket.send(memoryview(self.byte_buffer)[:data_size])

            if self.bytes_written == self.total_bytes:
                # finished sending the data
                logging.debug('Finished sending the data (%d bytes).', self.bytes_written)
                self.protocol_helper = ProtocolHelper()
                self.state = self.states.RECV_CONFIRMATION

        if self.state is self.states.RECV_CONFIRMATION:
            response_size = 8
            if self.protocol_helper.recv(self.socket, response_size):
                response = self.protocol_helper.get_buffer()
                if response != b'RECEIVED':
                    raise ProtocolException('Did not receive the expected message: {}'.format(response))
                self.state = self.states.DONE
                return True

        return False


class PullDataProtocol(Protocol):
    """Receive the data announced by a PushDataProtocol peer, then reply b'RECEIVED'.

    After completion ``data_size`` is the number of bytes received and
    ``elapsed_time`` the measured transfer time used by ``calc_transfer_rate``.
    """

    def __init__(self, socket, use_acceleration=None):
        use_acceleration = _resolve_acceleration(use_acceleration)

        self.socket = socket
        self.use_acceleration = use_acceleration

        self.states = enum.Enum('PULL_DATA_STATES', 'READY_TO_BEGIN RECV_INFO PULL_DATA SEND_CONFIRMATION DONE')
        self.state = self.states.READY_TO_BEGIN

        self.data_size = None
        self.recv_buffer_len = None
        self.bytes_read = 0
        self.protocol_helper = None
        self._time_of_first_byte = None
        self.elapsed_time = None

    def _run_iteration(self, block=True):
        if self.state is self.states.READY_TO_BEGIN:
            self.protocol_helper = ProtocolHelper()
            self.state = self.states.RECV_INFO

        if self.state is self.states.RECV_INFO:
            info_size = 16
            if self.protocol_helper.recv(self.socket, info_size):
                response = self.protocol_helper.get_buffer()
                self.data_size = int.from_bytes(response[0:8], byteorder='big', signed=False)
                self.recv_buffer_len = int.from_bytes(response[8:16], byteorder='big', signed=False)
                if self.data_size > 0 and self.recv_buffer_len == 0:
                    # recv(0) can never make progress, so this would hang.
                    raise ProtocolException('Peer announced {} bytes with a receive buffer length of 0.'.format(
                        self.data_size))
                self.state = self.states.PULL_DATA

        if self.state is self.states.PULL_DATA:
            if self.use_acceleration:
                if not block:
                    logging.warning('Protocol set to non-blocking, but using the blocking accelerated function.')
                (ret_val, elapsed_time) = accelerated_functions.pull_data(self.socket.fileno(), self.data_size, self.recv_buffer_len)
                if ret_val < 0:
                    raise ProtocolException('Error while pulling data.')
                self.bytes_read = self.data_size
                self.elapsed_time = elapsed_time
            else:
                self._pull_some_data()

            if self.bytes_read == self.data_size:
                # finished receiving the data
                logging.debug('Finished receiving the data.')
                self.protocol_helper = ProtocolHelper()
                self.protocol_helper.set_buffer(b'RECEIVED')
                self.state = self.states.SEND_CONFIRMATION

        if self.state is self.states.SEND_CONFIRMATION:
            if self.protocol_helper.send(self.socket):
                self.state = self.states.DONE
                return True

        return False

    def _pull_some_data(self):
        """Do at most one recv() of the pure-Python receive loop and update the timing."""
        bytes_remaining = self.data_size - self.bytes_read
        if bytes_remaining > 0:
            data = self.socket.recv(min(self.recv_buffer_len, bytes_remaining))
            if not data:
                raise ProtocolException('Connection closed after receiving {} of {} bytes.'.format(
                    self.bytes_read, self.data_size))
            self.bytes_read += len(data)
            if self._time_of_first_byte is None:
                # The clock starts once the first recv() has returned data.
                self._time_of_first_byte = time.perf_counter()

        if self.bytes_read == self.data_size and self.elapsed_time is None:
            if self._time_of_first_byte is None:
                self.elapsed_time = 0.0  # empty transfer: nothing was timed
            else:
                self.elapsed_time = time.perf_counter() - self._time_of_first_byte

    def calc_transfer_rate(self):
        """Return the average transfer rate in bytes/s.

        Only valid once the data has been received. Returns 0.0 for an empty
        transfer and ``float('inf')`` if the measured time is zero.
        """
        assert self.data_size is not None and self.elapsed_time is not None
        if self.data_size == 0:
            return 0.0
        if self.elapsed_time == 0:
            # Clock resolution can be too coarse for tiny transfers.
            return float('inf')
        return self.data_size/self.elapsed_time


class SendDataProtocol(Protocol):
    """Send *data* with a 20-byte big-endian length prefix, then wait for b'RECEIVED'."""

    def __init__(self, socket, data):
        self.socket = socket
        self.send_data = data

        self.states = enum.Enum('SEND_DATA_STATES', 'READY_TO_BEGIN SEND_INFO SEND_DATA RECV_CONFIRMATION DONE')
        self.state = self.states.READY_TO_BEGIN

        self.protocol_helper = None

    def _run_iteration(self, block=True):
        if self.state is self.states.READY_TO_BEGIN:
            info = len(self.send_data).to_bytes(_DATA_LEN_FIELD_SIZE, byteorder='big', signed=False)
            self.protocol_helper = ProtocolHelper()
            self.protocol_helper.set_buffer(info)
            self.state = self.states.SEND_INFO

        if self.state is self.states.SEND_INFO:
            if self.protocol_helper.send(self.socket):
                self.protocol_helper = ProtocolHelper()
                self.protocol_helper.set_buffer(self.send_data)
                self.state = self.states.SEND_DATA

        if self.state is self.states.SEND_DATA:
            if self.protocol_helper.send(self.socket):
                self.protocol_helper = ProtocolHelper()
                self.state = self.states.RECV_CONFIRMATION

        if self.state is self.states.RECV_CONFIRMATION:
            response_size = 8
            if self.protocol_helper.recv(self.socket, response_size):
                response = self.protocol_helper.get_buffer()
                if response != b'RECEIVED':
                    raise ProtocolException('Did not receive the expected message: {}'.format(response))
                self.state = self.states.DONE
                return True

        return False


class ReceiveDataProtocol(Protocol):
    """Receive a 20-byte length-prefixed message into ``received_data``, then reply b'RECEIVED'."""

    def __init__(self, socket):
        self.socket = socket

        self.states = enum.Enum('RECV_DATA_STATES', 'READY_TO_BEGIN RECV_INFO RECV_DATA SEND_CONFIRMATION DONE')
        self.state = self.states.READY_TO_BEGIN

        self.protocol_helper = None
        self.data_size = None
        self.received_data = None

    def _run_iteration(self, block=True):
        if self.state is self.states.READY_TO_BEGIN:
            self.protocol_helper = ProtocolHelper()
            self.state = self.states.RECV_INFO

        if self.state is self.states.RECV_INFO:
            if self.protocol_helper.recv(self.socket, _DATA_LEN_FIELD_SIZE):
                response = self.protocol_helper.get_buffer()
                # NOTE(review): the length comes from the peer unchecked.
                self.data_size = int.from_bytes(response, byteorder='big', signed=False)
                self.protocol_helper = ProtocolHelper()
                self.state = self.states.RECV_DATA

        if self.state is self.states.RECV_DATA:
            if self.protocol_helper.recv(self.socket, self.data_size):
                self.received_data = self.protocol_helper.get_buffer()
                self.protocol_helper = ProtocolHelper()
                self.protocol_helper.set_buffer(b'RECEIVED')
                self.state = self.states.SEND_CONFIRMATION

        if self.state is self.states.SEND_CONFIRMATION:
            if self.protocol_helper.send(self.socket):
                self.state = self.states.DONE
                return True

        return False


class ServerListener():
    "A TCP listener, binding, listening and accepting new connections."

    def __init__(self, endpoint, accept_callback):
        self.callback = accept_callback

        self.s = socket.socket()
        try:
            self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.s.bind(endpoint)
            self.s.listen(0)
        except OSError:
            # Don't leak the listening socket if bind/listen fails.
            self.s.close()
            raise

    def accept(self):
        """Accept one connection (blocking) and pass the new socket to the callback."""
        newsock, endpoint = self.s.accept()
        logging.debug("New client from %s:%d (fd=%d)", endpoint[0], endpoint[1], newsock.fileno())
        self.callback(newsock)


class SimpleClientConnectionProtocol(Protocol):
    """Connect to *endpoint*, optionally through a SOCKS4 *proxy*, and push *total_bytes*.

    The TCP connection is opened in the constructor.
    """

    def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None):
        self.endpoint = endpoint
        self.data_generator = data_generator
        self.total_bytes = total_bytes
        self.proxy = proxy
        self.username = username

        self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY PUSH_DATA DONE')
        self.state = self.states.READY_TO_BEGIN

        self.socket = socket.socket()
        self.sub_protocol = None

        try:
            if self.proxy is None:
                logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
                self.socket.connect(self.endpoint)
            else:
                logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
                self.socket.connect(self.proxy)
        except OSError:
            # Don't leak the file descriptor when the connection fails.
            self.socket.close()
            raise

    def _new_push_protocol(self):
        """Create the PushDataProtocol that sends the payload on this connection."""
        # NOTE(review): data_generator is forwarded as PushDataProtocol's
        # send_buffer_len (it was previously passed positionally), so it must
        # be an int buffer size or None -- confirm this is the intent.
        return PushDataProtocol(self.socket, self.total_bytes, send_buffer_len=self.data_generator)

    def _run_iteration(self, block=True):
        if self.state is self.states.READY_TO_BEGIN:
            if self.proxy is None:
                self.sub_protocol = self._new_push_protocol()
                self.state = self.states.PUSH_DATA
            else:
                self.sub_protocol = Socks4Protocol(self.socket, self.endpoint, username=self.username)
                self.state = self.states.CONNECT_TO_PROXY

        if self.state is self.states.CONNECT_TO_PROXY:
            if self.sub_protocol.run(block=block):
                self.sub_protocol = self._new_push_protocol()
                self.state = self.states.PUSH_DATA

        if self.state is self.states.PUSH_DATA:
            if self.sub_protocol.run(block=block):
                self.state = self.states.DONE
                return True

        return False


class SimpleServerConnectionProtocol(Protocol):
    """Pull data on an accepted *socket* and report (conn_id, bytes, bytes/s) to *bandwidth_callback*."""

    def __init__(self, socket, conn_id, bandwidth_callback=None):
        self.socket = socket
        self.conn_id = conn_id
        self.bandwidth_callback = bandwidth_callback

        self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN PULL_DATA DONE')
        self.state = self.states.READY_TO_BEGIN

        self.sub_protocol = None

    def _run_iteration(self, block=True):
        if self.state is self.states.READY_TO_BEGIN:
            self.sub_protocol = PullDataProtocol(self.socket)
            self.state = self.states.PULL_DATA

        if self.state is self.states.PULL_DATA:
            if self.sub_protocol.run(block=block):
                self.state = self.states.DONE
                if self.bandwidth_callback:
                    self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size,
                                            self.sub_protocol.calc_transfer_rate())
                return True

        return False


if __name__ == '__main__':
    import sys
    logging.basicConfig(level=logging.DEBUG)

    if len(sys.argv) < 2 or sys.argv[1] not in ('client', 'server'):
        print('Usage: {} client|server'.format(sys.argv[0]), file=sys.stderr)
        sys.exit(1)

    if sys.argv[1] == 'client':
        endpoint = ('127.0.0.1', 4747)
        proxy = ('127.0.0.1', 9003)  # set to None to connect to the endpoint directly
        # random username without NUL bytes (set to None to send an empty one)
        username = bytes(x for x in os.urandom(12) if x != 0)
        data_MB = 40

        client = SimpleClientConnectionProtocol(endpoint, data_MB*2**20, proxy=proxy, username=username)
        try:
            client.run()
        finally:
            client.socket.close()
    else:
        import multiprocessing

        endpoint = ('127.0.0.1', 4747)
        processes = []
        conn_counter = [0]

        def bw_callback(conn_id, data_size, transfer_rate):
            logging.info('Avg Transferred (MB): %.4f', data_size/(1024**2))
            logging.info('Avg Transfer rate (MB/s): %.4f', transfer_rate/(1024**2))

        def start_server_conn(sock, conn_id):
            server = SimpleServerConnectionProtocol(sock, conn_id, bandwidth_callback=bw_callback)
            try:
                server.run()
            except KeyboardInterrupt:
                pass
            finally:
                sock.close()

        def accept_callback(sock):
            conn_id = conn_counter[0]
            conn_counter[0] += 1

            # NOTE: nested target functions require the 'fork' start method.
            p = multiprocessing.Process(target=start_server_conn, args=(sock, conn_id))
            processes.append(p)
            p.start()
            # The child owns the connection now; close the listener's copy so
            # the fd doesn't leak for every accepted client.
            sock.close()

        listener = ServerListener(endpoint, accept_callback)

        try:
            while True:
                listener.accept()
        except KeyboardInterrupt:
            print()

        for p in processes:
            p.join()