123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- #!/usr/bin/python3
- #
- import basic_protocols
- import logging
- import enum
- import time
- import socket
- #
class ClientConnectionProtocol(basic_protocols.Protocol):
    """Client side of a bulk-transfer test connection.

    Connects directly to ``endpoint`` or through a SOCKS4 proxy, sends an
    8-byte big-endian unsigned group id, then pushes ``total_bytes`` of data
    with basic_protocols.PushDataProtocol.
    """
    # One enum shared by all instances; ``self.states`` still resolves to it.
    states = enum.Enum('CLIENT_CONN_STATES', 'READY_TO_BEGIN CONNECT_TO_PROXY SEND_GROUP_ID PUSH_DATA DONE')
    #
    def __init__(self, endpoint, total_bytes, data_generator=None, proxy=None, username=None, wait_until=None):
        """Create the socket and connect it to the proxy, or to endpoint when no proxy is set.

        Args:
            endpoint: (host, port) address of the data server.
            total_bytes: number of bytes to push after the group id is sent.
            data_generator: passed through to basic_protocols.PushDataProtocol.
            proxy: (host, port) of a SOCKS4 proxy, or None to connect directly.
            username: user id passed to basic_protocols.Socks4Protocol.
            wait_until: time.time()-style timestamp in seconds; nothing is sent
                before it. It also defines the group id (in milliseconds).
                None means no waiting and no group.

        Raises:
            OSError: the connection attempt failed (the socket is closed first).
        """
        self.endpoint = endpoint
        self.data_generator = data_generator
        self.total_bytes = total_bytes
        self.proxy = proxy
        self.username = username
        self.wait_until = wait_until
        #
        self.state = self.states.READY_TO_BEGIN
        #
        # Computed before the socket exists, so a bad wait_until cannot leak it.
        # A group id of 0 means no group.
        self.group_id = int(self.wait_until*1000) if self.wait_until is not None else 0
        #
        self.socket = socket.socket()
        self.sub_protocol = None
        #
        try:
            if self.proxy is None:
                logging.debug('Socket %d connecting to endpoint %r...', self.socket.fileno(), self.endpoint)
                self.socket.connect(self.endpoint)
            else:
                logging.debug('Socket %d connecting to proxy %r...', self.socket.fileno(), self.proxy)
                self.socket.connect(self.proxy)
        except BaseException:
            # don't leak the file descriptor when the connection fails
            self.socket.close()
            raise
    #
    def _start_sending_group_id(self):
        # Make sending the 8-byte big-endian unsigned group id the next step.
        group_id_bytes = self.group_id.to_bytes(8, byteorder='big', signed=False)
        self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, group_id_bytes)
        self.state = self.states.SEND_GROUP_ID
    #
    def _run_iteration(self, block=True):
        """Advance the state machine by as many steps as possible.

        Returns True on the call that finishes pushing the data, else False.
        The state checks are consecutive ``if``s (not ``elif``s), so one call
        can pass through several states when sub-protocols complete at once.
        """
        if self.state is self.states.READY_TO_BEGIN:
            if self.proxy is None:
                self._start_sending_group_id()
            else:
                self.sub_protocol = basic_protocols.Socks4Protocol(self.socket, self.endpoint, username=self.username)
                self.state = self.states.CONNECT_TO_PROXY
        #
        if self.state is self.states.CONNECT_TO_PROXY:
            if self.sub_protocol.run(block=block):
                self._start_sending_group_id()
        #
        if self.state is self.states.SEND_GROUP_ID:
            if block and self.wait_until is not None:
                # time.sleep() raises ValueError for a negative duration, which
                # is what the unclamped delay is once wait_until has passed.
                delay = self.wait_until - time.time()
                if delay > 0:
                    time.sleep(delay)
            #
            # nothing (not even the group id) is sent before wait_until
            if (self.wait_until is None or time.time() >= self.wait_until) and self.sub_protocol.run(block=block):
                self.sub_protocol = basic_protocols.PushDataProtocol(self.socket, self.total_bytes,
                                                                     data_generator=self.data_generator,
                                                                     send_max_bytes=1024*512)
                self.state = self.states.PUSH_DATA
        #
        if self.state is self.states.PUSH_DATA:
            if self.sub_protocol.run(block=block):
                self.state = self.states.DONE
                return True
        #
        return False
- #
class ServerConnectionProtocol(basic_protocols.Protocol):
    """Server side of a bulk-transfer test connection.

    Receives the client's group id, then pulls data with
    basic_protocols.PullDataProtocol. Both results are reported through
    optional callbacks.
    """
    # One enum shared by all instances; ``self.states`` still resolves to it.
    states = enum.Enum('SERVER_CONN_STATES', 'READY_TO_BEGIN RECV_GROUP_ID PULL_DATA DONE')
    #
    def __init__(self, socket, conn_id, group_id_callback=None, bandwidth_callback=None):
        """
        Args:
            socket: the accepted, connected socket for this connection.
            conn_id: identifier passed back as the first argument of both callbacks.
            group_id_callback: optional callable(conn_id, group_id). group_id is
                the received big-endian unsigned integer, or None when it was 0
                (no group).
            bandwidth_callback: optional callable(conn_id, data_size, transfer_rate).
                The values come from the PullDataProtocol once the transfer is done.
        """
        self.socket = socket
        self.conn_id = conn_id
        self.group_id_callback = group_id_callback
        self.bandwidth_callback = bandwidth_callback
        #
        self.state = self.states.READY_TO_BEGIN
        #
        self.sub_protocol = None
    #
    def _run_iteration(self, block=True):
        """Advance the state machine by as many steps as possible.

        Returns True on the call that finishes pulling the data, else False.
        The state checks are consecutive ``if``s, so one call can pass through
        several states.
        """
        if self.state is self.states.READY_TO_BEGIN:
            self.sub_protocol = basic_protocols.ReceiveDataProtocol(self.socket)
            self.state = self.states.RECV_GROUP_ID
        #
        if self.state is self.states.RECV_GROUP_ID:
            if self.sub_protocol.run(block=block):
                group_id = int.from_bytes(self.sub_protocol.received_data, byteorder='big', signed=False)
                if group_id == 0:
                    # a group of 0 means no group
                    group_id = None
                #
                # the callback is optional (defaults to None), like bandwidth_callback
                if self.group_id_callback is not None:
                    self.group_id_callback(self.conn_id, group_id)
                self.sub_protocol = basic_protocols.PullDataProtocol(self.socket)
                self.state = self.states.PULL_DATA
        #
        if self.state is self.states.PULL_DATA:
            if self.sub_protocol.run(block=block):
                self.state = self.states.DONE
                if self.bandwidth_callback is not None:
                    self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size, self.sub_protocol.calc_transfer_rate())
                #
                return True
        #
        return False
- #
def _run_client(argv):
    """Run a single client connection against the local test server.

    argv is sys.argv. argv[2], if present, is the time.time()-style timestamp
    (seconds; fractions allowed) at which the client starts sending. Clients
    given the same timestamp report the same group id.
    """
    import os
    #
    endpoint = ('127.0.0.1', 4747)
    #endpoint = ('127.0.0.1', 8627)
    #proxy = ('127.0.0.1', 9003+int(argv[3])-1)
    #proxy = ('127.0.0.1', 9003)
    proxy = None
    # Random proxy user id with NUL bytes removed (presumably because the
    # SOCKS4 user id field is NUL-terminated).
    username = bytes(x for x in os.urandom(12) if x != 0)
    #username = None
    data_MB = 500000
    #
    # float() also accepts plain integer timestamps, so old invocations still work
    wait_until = float(argv[2]) if len(argv) > 2 else None
    #
    client = ClientConnectionProtocol(endpoint, data_MB*2**20, proxy=proxy, username=username, wait_until=wait_until)
    client.run()
#
#
def _run_server():
    """Serve connections on the local test endpoint until Ctrl-C, then log per-group averages.

    Each accepted connection is handled by a ServerConnectionProtocol in its
    own child process. The children send group ids and transfer statistics
    back to the parent through multiprocessing queues.
    """
    import collections
    import itertools
    import multiprocessing
    import queue
    #
    # The handlers below are closures. They only work as Process targets with
    # the 'fork' start method, because the other methods must pickle the target.
    # Request it explicitly instead of relying on the platform default, which
    # is not 'fork' on macOS or, from Python 3.14, on other POSIX systems.
    ctx = multiprocessing.get_context('fork')
    #
    endpoint = ('127.0.0.1', 4747)
    processes_map = {}  # conn_id -> child Process not yet joined
    conn_ids = itertools.count()
    joinable_connections = ctx.Queue()  # conn_ids of children that have finished
    group_queue = ctx.Queue()
    bw_queue = ctx.Queue()
    group_values = {}  # conn_id -> {'conn_id', 'group_id'}
    bw_values = {}  # conn_id -> {'conn_id', 'data_size', 'transfer_rate'}
    finished = []  # conn_ids read from joinable_connections but not yet joined
    #
    def group_id_callback(conn_id, group_id):
        # runs in the child; the parent collects it from group_queue
        group_queue.put({'conn_id':conn_id, 'group_id':group_id})
    #
    def bw_callback(conn_id, data_size, transfer_rate):
        # runs in the child; the parent collects it from bw_queue
        bw_queue.put({'conn_id':conn_id, 'data_size':data_size, 'transfer_rate':transfer_rate})
    #
    def start_server_conn(conn_socket, conn_id):
        # Child entry point: serve one connection, then tell the parent it can be joined.
        server = ServerConnectionProtocol(conn_socket, conn_id, group_id_callback=group_id_callback, bandwidth_callback=bw_callback)
        try:
            server.run()
        except KeyboardInterrupt:
            pass
        finally:
            conn_socket.close()
            joinable_connections.put(conn_id)
    #
    def accept_callback(socket):
        conn_id = next(conn_ids)
        p = ctx.Process(target=start_server_conn, args=(socket, conn_id))
        processes_map[conn_id] = p
        p.start()
        # close this process' copy of the socket; the forked child has its own
        socket.close()
    #
    def drain(q):
        # yield every item currently available on q without blocking
        try:
            while True:
                yield q.get(False)
        except queue.Empty:
            return
    #
    def drain_results():
        # Move everything the children have queued so far into the parent's state.
        for item in drain(group_queue):
            group_values[item['conn_id']] = item
        for item in drain(bw_queue):
            bw_values[item['conn_id']] = item
        finished.extend(drain(joinable_connections))
    #
    def join_draining(p):
        # A child only exits after its queue feeder threads have flushed every
        # buffered item into the queue pipes. Keep reading the queues while
        # waiting; otherwise a full pipe would deadlock the join.
        p.join(timeout=0.1)
        while p.exitcode is None:
            drain_results()
            p.join(timeout=0.1)
    #
    listener = basic_protocols.ServerListener(endpoint, accept_callback)
    #
    try:
        while True:
            listener.accept()
            drain_results()
            while finished:
                conn_id = finished.pop()
                join_draining(processes_map[conn_id])
                del processes_map[conn_id]
    except KeyboardInterrupt:
        print()
    #
    # Join the remaining children before reporting, so results that were still
    # in flight when Ctrl-C arrived are not lost.
    for p in list(processes_map.values()):
        join_draining(p)
    drain_results()
    #
    conns_by_group = collections.defaultdict(list)
    for conn_id, group_val in group_values.items():
        conns_by_group[group_val['group_id']].append(conn_id)
    #
    for group_id, member_ids in conns_by_group.items():
        # group_id None collects every connection that was not in a group
        in_group = [bw_values[c] for c in member_ids if c in bw_values]
        if in_group:
            avg_data_size = sum(x['data_size'] for x in in_group)/len(in_group)
            avg_transfer_rate = sum(x['transfer_rate'] for x in in_group)/len(in_group)
            #
            logging.info('Group id: %s', group_id)
            logging.info('Group size: %d', len(in_group))
            logging.info('Avg Transferred (MB): %.4f', avg_data_size/(1024**2))
            logging.info('Avg Transfer rate (MB/s): %.4f', avg_transfer_rate/(1024**2))
#
#
if __name__ == '__main__':
    import sys
    logging.basicConfig(level=logging.DEBUG)
    #
    mode = sys.argv[1] if len(sys.argv) > 1 else None
    if mode == 'client':
        _run_client(sys.argv)
    elif mode == 'server':
        _run_server()
    else:
        sys.exit('usage: {} client [wait_until] | server'.format(sys.argv[0]))
|