#!/usr/bin/env python3

# Portions cribbed from https://www.w3reference.com/blog/python3-http-server-post-example/

from http.server import SimpleHTTPRequestHandler, HTTPServer
import json
import subprocess
import sys
import threading
import time

# Number of log uploads received so far (incremented by PostHandler.do_POST).
num_uploads = 0
# Number of uploads to accept before shutting everything down; may be
# overridden by the first command-line argument.
niters = 3
# The running HTTPServer; assigned in the __main__ block.
httpd = None
# subprocess.Popen handle for the currently running lox distributor, if any.
lox_distributor = None
# Uploaded log entries are written here (the file is truncated at startup).
# UTF-8 is explicit so non-ASCII uploads don't depend on the locale encoding.
log_file = open("uploaded_log", "w", encoding="utf-8")

def shutdown_servers():
    """Stop the lox distributor child process, then the HTTP server.

    Must run on a thread other than the one executing ``serve_forever()``:
    ``httpd.shutdown()`` blocks until that loop exits, so calling it from a
    request handler on the server thread would deadlock.
    """
    if lox_distributor is not None:
        lox_distributor.terminate()
        # Reap the child so it does not linger; escalate if SIGTERM is ignored.
        try:
            lox_distributor.wait(timeout=10)
        except subprocess.TimeoutExpired:
            lox_distributor.kill()
            lox_distributor.wait()
    # Short grace period before stopping the HTTP server.
    time.sleep(2)
    httpd.shutdown()

def restart_lox_distributor():
    """(Re)start the lox distributor from an empty database.

    Stops the currently running distributor (if any) and deletes
    ``../lox-distributor/lox_db``. It then launches
    ``cargo run --release --features test-branch`` in ``../lox-distributor``
    and stores the Popen handle in the ``lox_distributor`` global. It sleeps
    2 seconds afterwards, presumably to give the distributor time to start
    (TODO: confirm this is long enough).
    """
    global lox_distributor
    if lox_distributor is not None:
        lox_distributor.terminate()
        # terminate() only sends the signal. Wait for the old process to
        # actually exit before wiping lox_db and starting a replacement,
        # otherwise it may still be using the database while it is deleted.
        try:
            lox_distributor.wait(timeout=10)
        except subprocess.TimeoutExpired:
            lox_distributor.kill()
            lox_distributor.wait()
    subprocess.run(["/bin/rm", "-rf", "lox_db"], cwd="../lox-distributor")
    lox_distributor = subprocess.Popen(["/usr/bin/cargo", "run", "--release", "--features",
                                        "test-branch"], cwd="../lox-distributor")
    time.sleep(2)


class PostHandler(SimpleHTTPRequestHandler):
    """HTTP handler that collects uploaded log entries via ``POST /log``.

    Other methods (GET, HEAD, ...) are inherited unchanged from
    ``SimpleHTTPRequestHandler``. A new handler instance is created for
    every request, so state shared across requests lives in module globals
    or in class attributes.
    """

    # Class-level so it is shared by every handler instance: ensures the
    # shutdown thread is started at most once.
    _shutdown_started = False

    def _send_text(self, code, text):
        """Send a complete ``text/plain`` UTF-8 response with status *code*."""
        body = text.encode('utf-8')
        self.send_response(code)
        self.send_header('Content-Type', 'text/plain; charset=utf-8')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_POST(self):
        """Record one uploaded log entry and reset the lox distributor.

        The request body must be a JSON-encoded string, which is written
        verbatim to ``log_file``. The distributor is then restarted. The
        response body is ``"0"`` while more uploads are expected. It is
        ``"1"`` once ``niters`` uploads have arrived, at which point the
        servers are shut down from a background thread.

        Responds 404 for any path other than ``/log``. Responds 400 for a
        bad ``Content-Length`` or a body that is not a well-formed JSON
        string.
        """
        global num_uploads

        if self.path != "/log":
            self._send_text(404, "Bad URL for POST\n")
            return

        try:
            content_length = int(self.headers.get('Content-Length', 0))
        except ValueError:
            content_length = -1
        if content_length < 0:
            # A negative length would make rfile.read() block until EOF.
            self._send_text(400, "Bad Content-Length\n")
            return

        try:
            entry = json.loads(self.rfile.read(content_length).decode('utf-8'))
        except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
            self._send_text(400, "Malformed POST body\n")
            return
        if not isinstance(entry, str):
            # log_file.write() only accepts str.
            self._send_text(400, "Expected a JSON string\n")
            return

        log_file.write(entry)
        # Flush so each entry reaches the file even if the process dies early.
        log_file.flush()

        restart_lox_distributor()

        num_uploads += 1
        done = num_uploads >= niters
        self._send_text(200, "1" if done else "0")

        # ">=" plus the once-only flag (rather than "== niters") also covers
        # niters <= 0, where the count would never equal niters and the
        # server would otherwise never exit.
        if done and not PostHandler._shutdown_started:
            PostHandler._shutdown_started = True
            # shutdown_servers calls httpd.shutdown(), which blocks until
            # serve_forever() returns, so it cannot run on this thread.
            threading.Thread(target=shutdown_servers).start()
 
if __name__ == '__main__':
    # Optional first argument overrides how many uploads to accept.
    if len(sys.argv) > 1:
        niters = int(sys.argv[1])
    try:
        restart_lox_distributor()
        server_address = ('', 8000)
        httpd = HTTPServer(server_address, PostHandler)
        try:
            httpd.serve_forever()
        finally:
            httpd.server_close()
    finally:
        # Don't leave the distributor running if startup or serving fails
        # (e.g. port 8000 already in use) or is interrupted (e.g. Ctrl-C).
        if lox_distributor is not None and lox_distributor.poll() is None:
            lox_distributor.terminate()
            lox_distributor.wait()
        log_file.close()
