#!/usr/bin/env python3

# Usage: ./run_native_bench [-z] [niters]
# Use -z to benchmark the legacy zkp code; otherwise benchmark our sigma-rs code
# niters defaults to 10

import os
import re
import subprocess
import sys

# Command-line handling.  An optional leading "-z" selects the legacy zkp
# crate instead of the sigma-rs one.  An optional next argument overrides the
# iteration count.  Arguments are consumed from sys.argv as they are parsed.
progname = sys.argv.pop(0)
legacy_zkp = bool(sys.argv) and sys.argv[0] == "-z"
if legacy_zkp:
    sys.argv.pop(0)  # drop the flag so the iteration count (if any) is sys.argv[0]
    loxdir, cargo_features = (
        "../application-lox-zkp/crates/lox-library",
        "--features=bridgeauth",
    )
else:
    loxdir, cargo_features = (
        "../application-lox/crates/lox-extensions",
        "--features=bridgeauth,test",
    )

niters = sys.argv[0] if sys.argv else "10"

# Exported so that the cargo subprocess started later inherits it
# (presumably it is read by the bench_ tests to pick their iteration count).
os.environ['LOX_BENCH_NITERS'] = niters

# loxdir is relative to this script's own (symlink-resolved) directory, so hop
# there first.  The caller's working directory therefore does not matter.
script_dir = os.path.dirname(os.path.realpath(progname))
os.chdir(script_dir)
os.chdir(loxdir)

# Map the "----NAME----" section headers printed by the benchmarks to the
# protocol names used in the output tables.  Headers not listed here are
# ignored, along with any statistics printed under them.
protocol_map = {
    'BLOCKAGE-MIGRATION': 'Blockage Migration',
    'CHECK-BLOCKAGE': 'Check Blockage',
    'ISSUE-INVITATION': 'Issue Invite',
    'LEVEL-UP-2: 44 days': 'Level Up',
    'TRUST-MIGRATION-0: 30 days': 'Trust Migration',
    'OPEN-INVITATION': 'Open Invitation',
    'REDEEM-INVITATION': 'Redeem Invite',
    'TRUST-PROMOTION-1: 30 days': 'Trust Promotion',
    'UPDATE-CRED': 'Update Credential',
    'UPDATE-INVITE': 'Update Invite',
}

# Per-protocol results, keyed by table protocol name.  Values are kept as the
# exact strings the benchmark printed so the tables reproduce them verbatim.
req_size = {}     # upper bound of "Request bytes range: [lo, hi]"
resp_size = {}    # upper bound of "Response bytes range: [lo, hi]"
client_time = {}  # (X, Y) from "Total client ms: X ± Y"
server_time = {}  # (X, Y) from "Response ms: X ± Y"

section_re = re.compile(r'----(.*)----')
req_re = re.compile(r'Request bytes range: \[\d+, (\d+)\]')
resp_re = re.compile(r'Response bytes range: \[\d+, (\d+)\]')
client_re = re.compile(r'Total client ms: ([\.\d]+) ± ([\.\d]+)')
server_re = re.compile(r'Response ms: ([\.\d]+) ± ([\.\d]+)')

# Run only the tests whose names match "bench_", single-threaded, and with
# --nocapture so the tests' printed statistics reach our pipe.  Using Popen
# as a context manager guarantees the child is waited on (and the pipe is
# closed) even if parsing raises.
protocol = None
with subprocess.Popen(
    ["cargo", "test", "--release", "bench_", cargo_features,
     "--", "--nocapture", "--test-threads", "1"],
    stdout=subprocess.PIPE, text=True,
) as proc:
    for line in proc.stdout:
        # Echo everything so the raw benchmark output stays visible.
        print(line, end='')
        if matched := section_re.match(line):
            # Unknown headers map to None rather than aborting the whole run.
            protocol = protocol_map.get(matched.group(1))
            continue
        if protocol is None:
            # Not inside a recognised section; nothing to record.
            continue
        if matched := req_re.match(line):
            req_size[protocol] = matched.group(1)
        elif matched := resp_re.match(line):
            resp_size[protocol] = matched.group(1)
        elif matched := client_re.match(line):
            client_time[protocol] = (matched.group(1), matched.group(2))
        elif matched := server_re.match(line):
            server_time[protocol] = (matched.group(1), matched.group(2))

# The with-block has waited for cargo to exit.  A non-zero status means the
# build or a benchmark failed, so the collected numbers cannot be trusted.
if proc.returncode != 0:
    sys.exit(f"cargo test exited with status {proc.returncode}")

# The order in which we output the protocols
protocol_list = [
    'Open Invitation',
    'Trust Promotion',
    'Trust Migration',
    'Level Up',
    'Issue Invite',
    'Redeem Invite',
    'Check Blockage',
    'Blockage Migration',
    'Update Invite',
    'Update Credential',
]

# Placeholder for a measurement the benchmark output never reported.
MISSING_CELL = "N/A"


def fmt_timing(timings, p):
    """Return protocol p's timing as "X (Y)", or MISSING_CELL if absent.

    timings maps a protocol name to the (X, Y) string pair parsed from a
    benchmark line of the form "... ms: X ± Y".
    """
    if p not in timings:
        return MISSING_CELL
    value, spread = timings[p]
    return f"{value} ({spread})"


print("\n=== Table 2 ===\n")
print("protocol,client native ms,server native ms")
for p in protocol_list:
    print(f"{p},{fmt_timing(client_time, p)},{fmt_timing(server_time, p)}")

print("\n=== Table 5 ===\n")
print("protocol,request size,response size")
for p in protocol_list:
    print(f"{p},{req_size.get(p, MISSING_CELL)},{resp_size.get(p, MISSING_CELL)}")
print()

# Report gaps only after printing, so the rows that did succeed are not lost.
# Still exit non-zero so a partial table isn't mistaken for a complete one.
incomplete = [
    p for p in protocol_list
    if any(p not in results
           for results in (client_time, server_time, req_size, resp_size))
]
if incomplete:
    sys.exit("missing benchmark results for: " + ", ".join(incomplete))
