#!/usr/bin/python3
"""Client/server protocol pair for (optionally time-synchronized) bulk transfers.

The client sends an 8-byte big-endian unsigned "group id" and then pushes
data; the server reads the group id, reports it through a callback, pulls the
data, and reports the transfer size and rate through another callback.
"""

import enum
import logging
import socket
import time

import basic_protocols


class ClientProtocol(basic_protocols.Protocol):
    """Sending side: send a group id, optionally wait for a start time, push data.

    State machine: READY_TO_BEGIN -> SEND_GROUP_ID -> PUSH_DATA -> DONE.
    """

    def __init__(self, socket, total_bytes, wait_until=None,
                 send_buffer_len=None, use_acceleration=None):
        """Create a client protocol.

        Args:
            socket: Connection handed to the ``basic_protocols`` sub-protocols.
                (The name shadows the ``socket`` module inside this method; it
                is kept for backward compatibility with keyword callers.)
            total_bytes: Amount of data passed to ``PushDataProtocol``.
            wait_until: Optional absolute start time on the ``time.time()``
                clock. When given, data is not pushed before this time, and
                ``int(wait_until * 1000)`` is used as the group id.
            send_buffer_len: Forwarded unchanged to ``PushDataProtocol``.
            use_acceleration: Forwarded unchanged to ``PushDataProtocol``.
        """
        self.socket = socket
        self.total_bytes = total_bytes
        self.wait_until = wait_until
        self.send_buffer_len = send_buffer_len
        self.use_acceleration = use_acceleration

        self.states = enum.Enum('CLIENT_CONN_STATES',
                                'READY_TO_BEGIN SEND_GROUP_ID PUSH_DATA DONE')
        self.state = self.states.READY_TO_BEGIN

        self.sub_protocol = None
        # The group id is the start time in milliseconds; 0 is the sentinel the
        # server interprets as "no group".
        self.group_id = int(self.wait_until * 1000) if self.wait_until is not None else 0

    def _run_iteration(self, block=True):
        """Advance the state machine as far as possible in one call.

        Args:
            block: Passed to each sub-protocol's ``run``. When true and
                ``wait_until`` is set, this call also sleeps until the start
                time before sending the group id.

        Returns:
            True on the call in which the data push completes, otherwise False.
            (A call made after reaching DONE also returns False.)
        """
        if self.state is self.states.READY_TO_BEGIN:
            group_id_bytes = self.group_id.to_bytes(8, byteorder='big', signed=False)
            self.sub_protocol = basic_protocols.SendDataProtocol(self.socket, group_id_bytes)
            self.state = self.states.SEND_GROUP_ID

        if self.state is self.states.SEND_GROUP_ID:
            if block and self.wait_until is not None:
                # time.sleep() raises ValueError for negative durations, so if
                # the start time has already passed, don't sleep at all.
                remaining = self.wait_until - time.time()
                if remaining > 0:
                    time.sleep(remaining)

            start_reached = self.wait_until is None or time.time() >= self.wait_until
            if start_reached and self.sub_protocol.run(block=block):
                self.sub_protocol = basic_protocols.PushDataProtocol(
                    self.socket, self.total_bytes,
                    send_buffer_len=self.send_buffer_len,
                    use_acceleration=self.use_acceleration)
                self.state = self.states.PUSH_DATA

        if self.state is self.states.PUSH_DATA:
            if self.sub_protocol.run(block=block):
                self.state = self.states.DONE
                return True

        return False


class ServerProtocol(basic_protocols.Protocol):
    """Receiving side: read the client's group id, then pull the data.

    State machine: READY_TO_BEGIN -> RECV_GROUP_ID -> PULL_DATA -> DONE.
    """

    def __init__(self, socket, conn_id, group_id_callback=None,
                 bandwidth_callback=None, use_acceleration=None):
        """Create a server protocol.

        Args:
            socket: Connection handed to the ``basic_protocols`` sub-protocols.
                (The name shadows the ``socket`` module inside this method; it
                is kept for backward compatibility with keyword callers.)
            conn_id: Opaque identifier passed back to both callbacks.
            group_id_callback: Optional ``callback(conn_id, group_id)`` called
                once the group id arrives; ``group_id`` is None when the
                client sent 0 ("no group").
            bandwidth_callback: Optional
                ``callback(conn_id, data_size, transfer_rate)`` called when the
                pull completes, with values taken from ``PullDataProtocol``.
            use_acceleration: Forwarded unchanged to ``PullDataProtocol``.
        """
        self.socket = socket
        self.conn_id = conn_id
        self.group_id_callback = group_id_callback
        self.bandwidth_callback = bandwidth_callback
        self.use_acceleration = use_acceleration

        self.states = enum.Enum('SERVER_CONN_STATES',
                                'READY_TO_BEGIN RECV_GROUP_ID PULL_DATA DONE')
        self.state = self.states.READY_TO_BEGIN

        self.sub_protocol = None

    def _run_iteration(self, block=True):
        """Advance the state machine as far as possible in one call.

        Args:
            block: Passed to each sub-protocol's ``run``.

        Returns:
            True on the call in which the data pull completes, otherwise False.
            (A call made after reaching DONE also returns False.)
        """
        if self.state is self.states.READY_TO_BEGIN:
            self.sub_protocol = basic_protocols.ReceiveDataProtocol(self.socket)
            self.state = self.states.RECV_GROUP_ID

        if self.state is self.states.RECV_GROUP_ID:
            if self.sub_protocol.run(block=block):
                group_id = int.from_bytes(self.sub_protocol.received_data,
                                          byteorder='big', signed=False)
                if group_id == 0:
                    # A group id of 0 means "no group".
                    group_id = None
                # The callback is optional (defaults to None); calling it
                # unconditionally would raise TypeError.
                if self.group_id_callback is not None:
                    self.group_id_callback(self.conn_id, group_id)
                self.sub_protocol = basic_protocols.PullDataProtocol(
                    self.socket, use_acceleration=self.use_acceleration)
                self.state = self.states.PULL_DATA

        if self.state is self.states.PULL_DATA:
            if self.sub_protocol.run(block=block):
                self.state = self.states.DONE
                if self.bandwidth_callback is not None:
                    self.bandwidth_callback(self.conn_id, self.sub_protocol.data_size,
                                            self.sub_protocol.calc_transfer_rate())
                return True

        return False