#!/usr/bin/python3
"""Client/server test harness built on basic_protocols.

Run as ``client`` to send a 100 MiB pseudo-random payload (optionally through a
proxy), or as ``server`` to receive payloads in forked child processes and
check them against the same payload.
"""

import enum
import logging
import socket
import time

import basic_protocols


class ClientConnectionProtocol(basic_protocols.Protocol):
    """Client side of a test connection: optionally traverse a proxy, then send data.

    The socket is created and connected in the constructor. The protocol steps
    are advanced by _run_iteration(). Presumably basic_protocols.Protocol.run()
    drives it until it returns True; that class is not visible here.
    """

    def __init__(self, endpoint, data, proxy=None, username=None):
        """Create the socket and connect it to the proxy (if given) or the endpoint.

        endpoint: address the data is ultimately destined for.
        data: payload handed to basic_protocols.SendDataProtocol.
        proxy: if not None, connect here first and reach *endpoint* through it.
        username: stored only; the Socks4Protocol path that would use it is
            currently disabled in _run_iteration().

        Raises OSError if the connection fails; the socket is closed first.
        """
        self.endpoint = endpoint
        self.data = data
        self.proxy = proxy
        self.username = username

        self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY SEND_DATA DONE')
        self.state = self.states.READY_TO_BEGIN

        self.socket = socket.socket()
        self.sub_protocol = None

        if self.proxy is None:
            logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
            target = self.endpoint
        else:
            logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
            target = self.proxy

        try:
            self.socket.connect(target)
        except OSError:
            # Don't leak the descriptor when the constructor fails.
            self.socket.close()
            raise

    def _run_iteration(self, block=True):
        """Advance the state machine as far as possible in one call.

        Each sub-protocol's run(block=...) is treated as returning truthy once
        that sub-protocol has finished. The state checks fall through, so
        several steps can complete in a single call.

        Returns True exactly once, when sending finishes. Otherwise returns
        False, including on any call made after the DONE state is reached.
        """
        if self.state is self.states.READY_TO_BEGIN:
            if self.proxy is None:
                self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, self.data)
                self.state = self.states.SEND_DATA
            else:
                # Alternative proxy handshake (would use self.username):
                # basic_protocols.Socks4Protocol(self.socket, self.endpoint, username=self.username)
                self.sub_protocol = basic_protocols.WeirdProtocol(self.socket, self.endpoint)
                self.state = self.states.CONNECT_TO_PROXY

        if self.state is self.states.CONNECT_TO_PROXY:
            if self.sub_protocol.run(block=block):
                # Proxy handshake finished; now send the payload over the same socket.
                self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, self.data)
                self.state = self.states.SEND_DATA

        if self.state is self.states.SEND_DATA:
            if self.sub_protocol.run(block=block):
                self.state = self.states.DONE
                return True

        return False


class ServerConnectionProtocol(basic_protocols.Protocol):
    """Server side of one accepted connection: receive data, then report it.

    The protocol steps are advanced by _run_iteration(), presumably driven by
    basic_protocols.Protocol.run(), which is not visible here.
    """

    def __init__(self, socket, conn_id, data_callback=None):
        """Wrap an already-accepted connection.

        socket: the connected socket to receive from.
        conn_id: identifier passed back to *data_callback*.
        data_callback: optional callable(conn_id, received_data), invoked once
            the receive completes.
        """
        self.socket = socket
        self.conn_id = conn_id
        self.data_callback = data_callback

        self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN RECV_DATA DONE')
        self.state = self.states.READY_TO_BEGIN

        self.sub_protocol = None

    def _run_iteration(self, block=True):
        """Advance the state machine as far as possible in one call.

        Returns True exactly once, when the receive finishes (after invoking
        data_callback, if one was given). Otherwise returns False, including on
        any call made after the DONE state is reached.
        """
        if self.state is self.states.READY_TO_BEGIN:
            self.sub_protocol = basic_protocols.ReceiveDataProtocol(self.socket)
            self.state = self.states.RECV_DATA

        if self.state is self.states.RECV_DATA:
            if self.sub_protocol.run(block=block):
                # data_callback defaults to None, so only call it when one was supplied.
                if self.data_callback is not None:
                    self.data_callback(self.conn_id, self.sub_protocol.received_data)
                self.state = self.states.DONE
                return True

        return False


def _run_client(data_to_send):
    """Send *data_to_send* to the hard-coded test endpoint via the hard-coded proxy."""
    import os

    endpoint = ('127.0.0.1', 4747)
    proxy = ('127.0.0.1', 12849)  # set to None to connect to the endpoint directly
    # Random 12-byte user ID with zero bytes removed. This is presumably so it can be
    # sent NUL-terminated by the (currently disabled) Socks4Protocol path.
    username = bytes([x for x in os.urandom(12) if x != 0])

    client = ClientConnectionProtocol(endpoint, data_to_send, proxy=proxy, username=username)
    client.run()


def _run_server(data_to_send):
    """Accept connections until Ctrl-C, receiving each one in a forked child process.

    Each received payload is compared against *data_to_send*. Children that
    have finished are reaped after every accepted connection. On Ctrl-C the
    accept loop stops and all remaining children are joined.
    """
    import multiprocessing
    import queue

    # The child target below is a closure. Only the 'fork' start method can hand
    # a closure to a child (spawn/forkserver would have to pickle it), and fork
    # is not the default everywhere, so request it explicitly.
    ctx = multiprocessing.get_context('fork')

    endpoint = ('127.0.0.1', 4747)
    processes = {}  # conn_id -> child Process that has not been reaped yet
    joinable_connections = ctx.Queue()  # conn_ids of children that have finished
    next_conn_id = 0

    def data_callback(conn_id, data):
        print('Received {} MB'.format(len(data)/(1024**2)))
        print('Data matches: {}'.format(data==data_to_send))

    def start_server_conn(conn_socket, conn_id):
        # Runs in the child process.
        server = ServerConnectionProtocol(conn_socket, conn_id, data_callback=data_callback)
        try:
            server.run()
        except KeyboardInterrupt:
            pass  # Ctrl-C: exit quietly; the finally block still cleans up
        finally:
            conn_socket.close()
            joinable_connections.put(conn_id)

    def accept_callback(conn_socket):
        nonlocal next_conn_id
        conn_id = next_conn_id
        next_conn_id += 1
        p = ctx.Process(target=start_server_conn, args=(conn_socket, conn_id))
        processes[conn_id] = p
        p.start()
        conn_socket.close()  # close this process' copy of the socket

    listener = basic_protocols.ServerListener(endpoint, accept_callback)

    try:
        while True:
            listener.accept()
            # Reap every child that has reported completion so far, without blocking.
            try:
                while True:
                    conn_id = joinable_connections.get(False)
                    processes[conn_id].join()
                    del processes[conn_id]
            except queue.Empty:
                pass
    except KeyboardInterrupt:
        print()

    for p in processes.values():
        p.join()


if __name__ == '__main__':
    import random
    import sys

    logging.basicConfig(level=logging.DEBUG)

    # Validate the mode before spending time generating the payload.
    if len(sys.argv) < 2 or sys.argv[1] not in ('client', 'server'):
        sys.exit('usage: {} client|server'.format(sys.argv[0]))

    # Both modes generate the payload from the same seed, so the server can
    # compare what it receives against exactly what the client sent.
    random.seed(10)
    data_to_send = bytearray(random.getrandbits(8) for _ in range(1024*1024*100))
    print('Generated bytes')

    if sys.argv[1] == 'client':
        _run_client(data_to_send)
    else:
        _run_server(data_to_send)