#!/usr/bin/python3
"""Connection protocols for a grouped bulk-transfer experiment.

A client connects to the server (optionally through a proxy, using
``basic_protocols.Socks4Protocol``), sends an 8-byte big-endian unsigned
group ID and then pushes a fixed number of bytes. The server reads the group
ID, pulls the data and reports per-connection metrics through callbacks.

Usage: ``<script> client [wait_until]`` or ``<script> server``.
"""
import enum
import itertools
import logging
import multiprocessing
import os
import queue
import socket
import sys
import time

import basic_protocols


class ClientConnectionProtocol(basic_protocols.Protocol):
    """Client side of a single transfer connection.

    States: READY_TO_BEGIN -> [CONNECT_TO_PROXY] -> SEND_GROUP_ID
    -> PUSH_DATA -> DONE. The TCP connection is opened in the constructor.

    Args:
        endpoint: address of the server, passed to ``socket.connect``.
        total_bytes: number of bytes handed to ``PushDataProtocol``.
        data_generator: optional data source forwarded to ``PushDataProtocol``.
        proxy: optional proxy address; when given, the socket connects to the
            proxy and a ``Socks4Protocol`` handshake targets ``endpoint``.
        username: forwarded to ``Socks4Protocol``.
        wait_until: optional absolute ``time.time()`` value; the group ID is
            not sent before this moment. ``int(wait_until*1000)`` is used as
            the group ID. None means "no group" (ID 0) and no waiting.
    """

    def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None, wait_until=None):
        self.endpoint = endpoint
        self.data_generator = data_generator
        self.total_bytes = total_bytes
        self.proxy = proxy
        self.username = username
        self.wait_until = wait_until

        self.states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY SEND_GROUP_ID PUSH_DATA DONE')
        self.state = self.states.READY_TO_BEGIN

        self.socket = socket.socket()
        self.sub_protocol = None
        # A group id of 0 means "no group".
        self.group_id = int(self.wait_until*1000) if self.wait_until is not None else 0

        try:
            if self.proxy is None:
                logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
                self.socket.connect(self.endpoint)
            else:
                logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
                self.socket.connect(self.proxy)
        except OSError:
            # Don't leak the descriptor when the connection cannot be made.
            self.socket.close()
            raise

    def _make_group_id_protocol(self):
        """Return a sub-protocol that sends the 8-byte big-endian unsigned group ID."""
        group_id_bytes = self.group_id.to_bytes(8, byteorder='big', signed=False)
        return basic_protocols.SendDataProtocol(self.socket, group_id_bytes)

    def _start_time_reached(self, block):
        """Return True once ``wait_until`` has passed (always True without a group).

        In blocking mode this sleeps for the remaining time first. Negative
        delays are skipped because ``time.sleep`` rejects them.
        """
        if self.wait_until is None:
            return True
        remaining = self.wait_until - time.time()
        if block and remaining > 0:
            time.sleep(remaining)
        return time.time() >= self.wait_until

    def _run_iteration(self, block=True):
        """Advance the state machine as far as possible.

        Returns True when the data push completes (state becomes DONE) and
        False otherwise. Presumably ``basic_protocols.Protocol.run`` calls
        this repeatedly until it returns True — confirm against the base class.
        """
        if self.state is self.states.READY_TO_BEGIN:
            if self.proxy is None:
                self.sub_protocol = self._make_group_id_protocol()
                self.state = self.states.SEND_GROUP_ID
            else:
                self.sub_protocol = basic_protocols.Socks4Protocol(self.socket, self.endpoint, username=self.username)
                self.state = self.states.CONNECT_TO_PROXY

        if self.state is self.states.CONNECT_TO_PROXY:
            if self.sub_protocol.run(block=block):
                self.sub_protocol = self._make_group_id_protocol()
                self.state = self.states.SEND_GROUP_ID

        if self.state is self.states.SEND_GROUP_ID:
            if self._start_time_reached(block) and self.sub_protocol.run(block=block):
                self.sub_protocol = basic_protocols.PushDataProtocol(self.socket, self.total_bytes, data_generator=self.data_generator, send_max_bytes=1024*512)
                self.state = self.states.PUSH_DATA

        if self.state is self.states.PUSH_DATA:
            if self.sub_protocol.run(block=block):
                self.state = self.states.DONE
                return True

        return False


class ServerConnectionProtocol(basic_protocols.Protocol):
    """Server side of a single transfer connection.

    States: READY_TO_BEGIN -> RECV_GROUP_ID -> PULL_DATA -> DONE.

    Args:
        socket: an already-accepted connection socket.
        conn_id: identifier passed back to the callbacks.
        group_id_callback: optional ``f(conn_id, group_id)``; ``group_id`` is
            None when the client sent 0 ("no group").
        bandwidth_callback: optional ``f(conn_id, data_size, transfer_rate)``
            with values taken from ``PullDataProtocolWithMetrics``.
    """

    def __init__(self, socket, conn_id, group_id_callback=None, bandwidth_callback=None):
        self.socket = socket
        self.conn_id = conn_id
        self.group_id_callback = group_id_callback
        self.bandwidth_callback = bandwidth_callback

        self.states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN RECV_GROUP_ID PULL_DATA DONE')
        self.state = self.states.READY_TO_BEGIN

        self.sub_protocol = None

    def _run_iteration(self, block=True):
        """Advance the state machine; return True once the data pull completes."""
        if self.state is self.states.READY_TO_BEGIN:
            self.sub_protocol = basic_protocols.ReceiveDataProtocol(self.socket)
            self.state = self.states.RECV_GROUP_ID

        if self.state is self.states.RECV_GROUP_ID:
            if self.sub_protocol.run(block=block):
                group_id = int.from_bytes(self.sub_protocol.received_data, byteorder='big', signed=False)
                if group_id == 0:
                    # A group of 0 means no group.
                    group_id = None
                # The callback is optional (defaults to None), same as bandwidth_callback.
                if self.group_id_callback is not None:
                    self.group_id_callback(self.conn_id, group_id)
                self.sub_protocol = basic_protocols.PullDataProtocolWithMetrics(self.socket)
                self.state = self.states.PULL_DATA

        if self.state is self.states.PULL_DATA:
            if self.sub_protocol.run(block=block):
                self.state = self.states.DONE
                if self.bandwidth_callback is not None:
                    self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size, self.sub_protocol.calc_transfer_rate())
                return True

        return False


def _drain_queue(q):
    """Return every item currently retrievable from *q* without blocking."""
    items = []
    try:
        while True:
            items.append(q.get(False))
    except queue.Empty:
        pass
    return items


def _serve_connection(sock, conn_id, group_queue, bw_queue, done_queue):
    """Child-process entry point: serve one connection and report results.

    This is defined at module level, with the queues passed as arguments, so
    it also works under the 'spawn' and 'forkserver' start methods, not only
    'fork'. Results go into the queues so the parent can display them later.
    *conn_id* is always put on *done_queue*, and the socket is always closed.
    """
    def group_id_callback(conn_id, group_id):
        group_queue.put({'conn_id': conn_id, 'group_id': group_id})

    def bw_callback(conn_id, data_size, transfer_rate):
        bw_queue.put({'conn_id': conn_id, 'data_size': data_size, 'transfer_rate': transfer_rate})

    server = ServerConnectionProtocol(sock, conn_id, group_id_callback=group_id_callback, bandwidth_callback=bw_callback)
    try:
        server.run()
    except KeyboardInterrupt:
        pass
    finally:
        sock.close()
        done_queue.put(conn_id)


def _report_groups(group_values, bw_values):
    """Log average transferred size and transfer rate per group.

    Both arguments map conn_id -> the dict a child put on its queue.
    NOTE: connections without a group (group_id None) are averaged together
    as if they were one group.
    """
    for group in {x['group_id'] for x in group_values.values()}:
        conns_in_group = {conn_id for conn_id, val in group_values.items() if val['group_id'] == group}
        in_group = [x for x in bw_values.values() if x['conn_id'] in conns_in_group]
        if in_group:
            avg_data_size = sum(x['data_size'] for x in in_group)/len(in_group)
            avg_transfer_rate = sum(x['transfer_rate'] for x in in_group)/len(in_group)

            logging.info('Group size: %d', len(in_group))
            logging.info('Avg Transferred (MB): %.4f', avg_data_size/(1024**2))
            logging.info('Avg Transfer rate (MB/s): %.4f', avg_transfer_rate/(1024**2))


def _run_client(argv):
    """Open one client connection; argv[2], if present, is the integer wait_until."""
    endpoint = ('127.0.0.1', 4747)
    # Set to e.g. ('127.0.0.1', 9003) to connect through a proxy.
    proxy = None
    # Zero bytes are filtered out, presumably because the SOCKS4 user-id field
    # is NUL-terminated on the wire.
    username = bytes([x for x in os.urandom(12) if x != 0])
    data_MB = 500000

    wait_until = int(argv[2]) if len(argv) > 2 else None

    client = ClientConnectionProtocol(endpoint, data_MB*2**20, proxy=proxy, username=username, wait_until=wait_until)
    client.run()


def _run_server():
    """Accept connections until Ctrl-C, one child process per connection, then report."""
    endpoint = ('127.0.0.1', 4747)
    processes_map = {}
    joinable_connections = multiprocessing.Queue()
    group_queue = multiprocessing.Queue()
    bw_queue = multiprocessing.Queue()
    group_values = {}
    bw_values = {}
    conn_ids = itertools.count()

    def collect_results():
        # Keep the result pipes drained. A child that has put data blocks on
        # exit until its queue is flushed, and joining it could otherwise hang.
        for val in _drain_queue(group_queue):
            group_values[val['conn_id']] = val
        for val in _drain_queue(bw_queue):
            bw_values[val['conn_id']] = val

    def accept_callback(sock):
        conn_id = next(conn_ids)
        p = multiprocessing.Process(target=_serve_connection, args=(sock, conn_id, group_queue, bw_queue, joinable_connections))
        processes_map[conn_id] = p
        p.start()
        # The child now has its own copy of the connection (inherited or
        # transferred during start()). Close the parent's copy so descriptors
        # don't accumulate.
        sock.close()

    listener = basic_protocols.ServerListener(endpoint, accept_callback)

    try:
        while True:
            listener.accept()
            collect_results()
            for conn_id in _drain_queue(joinable_connections):
                processes_map.pop(conn_id).join()
    except KeyboardInterrupt:
        print()

    collect_results()
    _report_groups(group_values, bw_values)

    for p in processes_map.values():
        p.join()


def main(argv):
    """Dispatch to client or server mode based on argv[1]."""
    logging.basicConfig(level=logging.DEBUG)
    mode = argv[1] if len(argv) > 1 else None
    if mode == 'client':
        _run_client(argv)
    elif mode == 'server':
        _run_server()
    else:
        sys.exit('usage: {} client [wait_until] | server'.format(argv[0]))


if __name__ == '__main__':
    main(sys.argv)