import os
import shutil
import subprocess
import argparse

from HW_requirements.test_script import HW_test
# Directory anchors. The trailing comments name the expected directory at each
# level of the repository layout.
CURRDIR = os.path.dirname(os.path.abspath(__file__))  # dataflow/mha
FLOW = os.path.abspath(os.path.join(CURRDIR, '..'))  # dataflow
ROOTDIR = os.path.abspath(os.path.join(CURRDIR, '..', '..'))  # WAIC

# Build tool and test scripts launched by this driver, plus the directory
# handed to HW_test.
XCLBIN_SCRIPT = os.path.join(FLOW, 'xclbin', 'xclbin_build.py')
mha_script_mini = os.path.join(FLOW, 'mha', 'test_mha_mini.py')
mha_script_2p1 = os.path.join(FLOW, 'mha', 'test_mha_2p1_test.py')
HW_req = os.path.join(ROOTDIR, 'HW_requirements')

# Single output root; the three names are aliases for the same path, used by
# the cleanup step, the folder-naming helper and the build/HW-test step.
base_folder = os.path.join(ROOTDIR, 'WAIC_Outputs_mha')
output_root = base_folder
WAIC_out = base_folder  # default fallback


def create_folder_name(command, base_dir=None):
    """Return the output folder path for one test-script command line.

    The folder is named after the script and its shape flags:
    ``<base_dir>/<prefix>_<Sin_q>_<Sin_kv>_<Sin_dh>_<cols>_<M_subv>``, where
    *prefix* is ``mha_mini`` for ``test_mha_mini.py`` and ``mha_2p1`` for
    ``test_mha_2p1_test.py``. Flags are looked up by name, so their order on
    the command line does not matter. Commands for any other script, or that
    lack one of those flags, map to *base_dir* itself. If ``--mask`` appears
    anywhere in the command, ``_m`` is appended to the chosen path.

    Args:
        command: whitespace-separated command line, e.g.
            ``"python3 <path>/test_mha_mini.py --Sin_q 64 ... --mask"``.
        base_dir: parent directory of the folder; defaults to the
            module-level ``base_folder``.

    Returns:
        The folder path as a string. The folder is not created here.
    """
    if base_dir is None:
        base_dir = base_folder

    prefixes = {
        "test_mha_mini.py": "mha_mini",
        "test_mha_2p1_test.py": "mha_2p1",
    }
    # NOTE(review): --H and --bias are not part of the name, so two configs
    # that differ only in those flags share a folder.
    name_flags = ("--Sin_q", "--Sin_kv", "--Sin_dh", "--cols", "--M_subv")

    parts = command.split()
    script_name = os.path.basename(parts[1]) if len(parts) > 1 else ""

    # Pair every "--flag" with the token that follows it (first occurrence
    # wins); a following token that is itself a flag is not a value.
    flag_values = {}
    for flag, value in zip(parts, parts[1:]):
        if flag.startswith("--") and not value.startswith("--"):
            flag_values.setdefault(flag, value)

    folder_name = base_dir
    prefix = prefixes.get(script_name)
    if prefix is not None and all(f in flag_values for f in name_flags):
        folder_name = os.path.join(
            base_dir, "_".join([prefix] + [flag_values[f] for f in name_flags])
        )

    if "--mask" in parts:
        folder_name += "_m"
    return folder_name

def run_xclbin(output_dir, disable_fast_pm=False):
    """Build the xclbin for the MHA kernels by invoking ``XCLBIN_SCRIPT``.

    The build is run with overlay ``8x4`` and a fixed list of kernels and
    source files. Its stdout/stderr are not captured: they stream straight to
    this process's console so a long build can be followed live.

    Args:
        output_dir: directory passed to the build script via ``-d``.
        disable_fast_pm: when True, ``--disable_fast_pm`` is forwarded.

    Returns:
        The ``subprocess.CompletedProcess`` of the build. A non-zero exit
        code is reported on stdout but does not raise.
    """
    command = [
        'python', XCLBIN_SCRIPT,
        '-o', '8x4',
        '-k', 'run_gemm_qdq_mini,run_softmax_qdq,run_bcast_add_mini,run_mini_mha_preprocess,run_presoftmax_dequant',
        '-i', 'super.hh,mha_qdq/wrapper_mini_mha.cc,norm/softmax.cc',
        '-d', output_dir,
    ]
    if disable_fast_pm:
        command.append('--disable_fast_pm')

    # Output is not captured (no stdout/stderr pipes), so there is nothing to
    # print afterwards; only the exit status is inspected.
    result = subprocess.run(command)
    if result.returncode != 0:
        print(f"Error: xclbin build exited with code {result.returncode}")
    return result

# Command-line interface. Passing "-HW_IP None" skips the automatic HW run at
# the end of the script.
parser = argparse.ArgumentParser(description="Standalone MHA script")
parser.add_argument(
    "-HW_IP", "--HW_IP",
    default="10.228.45.202",
    help="Set HW IP address",
)
parser.add_argument(
    "-clean", "--delete_dir",
    action="store_true",
    help="delete output directory if it already exists",
)
parser.add_argument(
    "-mha_type",
    choices=["mha_mini", "mha_2p1"],
    help="Specify MHA type to run (default: run all types)",
)
parser.add_argument(
    '--disable_fast_pm',
    action="store_true",
    help="Disable fast pm mode",
)

args = parser.parse_args()

disable_fast_pm = args.disable_fast_pm

# Optionally wipe the whole output tree left by a previous invocation.
if args.delete_dir:
    if os.path.exists(output_root):
        print(f"Cleaning up existing output directory: {output_root}")
        shutil.rmtree(output_root)
    else:
        print(f"No existing output directory to clean at: {output_root}")


# Artifacts produced by each test script and collected after every run.
files_to_copy = [
    'ifm.bin', 'ofm.bin', 'ctrl.bin', 'txn.bin', 'param.bin', 'wgt.bin', 'patch.json', 'ctrl_meta.json'
]

# Remove stale artifacts so they cannot be mistaken for fresh outputs.
# Paths are relative to the current working directory, the same location the
# run sections collect artifacts from.
# NOTE(review): confirm the test scripts write to the CWD rather than CURRDIR.
for artifact in files_to_copy:
    if os.path.exists(artifact):
        os.remove(artifact)
        print(f"Removed {artifact}")
    else:
        print(f"{artifact} does not exist")

print("File removal process completed.")

if args.mha_type is None or args.mha_type == "mha_mini":

    commands = [
        f"python3 {mha_script_mini} --Sin_q 64 --Sin_kv 64 --Sin_dh 64 --H 1 --cols 8 --M_subv 16 --mode hw",
        f"python3 {mha_script_mini} --Sin_q 64 --Sin_kv 77 --Sin_dh 64 --H 8 --cols 8 --M_subv 16 --mode hw",
        f"python3 {mha_script_mini} --Sin_q 256 --Sin_kv 77 --Sin_dh 64 --H 4 --cols 8 --M_subv 16 --mode hw",
        f"python3 {mha_script_mini} --Sin_q 1024 --Sin_kv 77 --Sin_dh 64 --H 5 --cols 8 --M_subv 16 --mode hw",
        f"python3 {mha_script_mini} --Sin_q 4096 --Sin_kv 77 --Sin_dh 64 --H 2 --cols 8 --M_subv 16 --mode hw",
        f"python3 {mha_script_mini} --Sin_q 64 --Sin_kv 64 --Sin_dh 64 --H 6 --cols 8 --M_subv 16 --mode hw --mask",
        f"python3 {mha_script_mini} --Sin_q 1 --Sin_kv 64 --Sin_dh 64 --H 6 --cols 8 --M_subv 16 --mode hw --mask",
        f"python3 {mha_script_mini} --Sin_q 131  --Sin_kv 77 --Sin_dh 64 --H 1 --cols 8 --M_subv 16 --mode  hw --bias 1",
        f"python3 {mha_script_mini} --Sin_q 151 --Sin_kv 151 --Sin_dh 128 --H 6 --cols 8 --M_subv 16 --mode hw --mask --mini_3p0 0",
        #f"python3 {mha_script_mini} --Sin_q 77  --Sin_kv 77 --Sin_dh 64 --H 80 --cols 8 --M_subv 16 --mode  hw --bias 10",
    ]

    # Example simulation run:
    #python3 test_mha_mini.py --Sin_q 64 --Sin_kv 64 --Sin_dh 64 --H 6 --cols 8 --M_subv 16 --mode sim --mask
    # Run each command and copy its artifacts to the corresponding folder.
    for command in commands:
        # Skip configurations whose test script is missing.
        script_file = command.split()[1]
        if not os.path.exists(script_file):
            print(f"Script file {script_file} not found. Skipping command: {command}")
            continue
        if disable_fast_pm:
            command += " --disable_fast_pm"

        # Artifacts are copied (not moved) below, so they would otherwise
        # survive into the next iteration. Clear them first so that a run that
        # fails, or writes fewer files, cannot hand the previous
        # configuration's outputs to this configuration's folder.
        for artifact in files_to_copy:
            if os.path.exists(artifact):
                os.remove(artifact)

        # Each entry is a complete shell command line.
        result = subprocess.run(command, shell=True)
        if result.returncode != 0:
            print(f"Command failed with exit code {result.returncode}: {command}")

        # Folder name is derived from the command line's shape flags.
        folder_name = create_folder_name(command)
        os.makedirs(folder_name, exist_ok=True)

        for artifact in files_to_copy:
            if os.path.exists(artifact):
                shutil.copy(artifact, folder_name)

    print("All tests are done and files are copied to respective folders.")


if args.mha_type is None or args.mha_type == "mha_2p1":
    # Remove artifacts still in the working directory (e.g. those the
    # mha_mini section copied rather than moved) so the first 2p1 run does
    # not pick them up.
    for artifact in files_to_copy:
        if os.path.exists(artifact):
            os.remove(artifact)
            print(f"Removed {artifact}")
        else:
            print(f"{artifact} does not exist")

    print("File removal process completed.")

    commands = [
        f"python3 {mha_script_2p1} --Sin_q 151 --Sin_kv 151 --Sin_dh 128 --H 6 --cols 8 --M_subv 16 --mode hw",
        f"python3 {mha_script_2p1} --Sin_q 77 --Sin_kv 256 --Sin_dh 64 --H 1 --cols 8 --M_subv 16 --mode hw",
        f"python3 {mha_script_2p1} --Sin_q 256 --Sin_kv 256 --Sin_dh 64 --H 2 --cols 8 --M_subv 16 --mode hw",
        f"python3 {mha_script_2p1} --Sin_q 1024 --Sin_kv 1024 --Sin_dh 64 --H 2 --cols 8 --M_subv 16 --mode hw",
        f"python3 {mha_script_2p1} --Sin_q 4096 --Sin_kv 4096 --Sin_dh 64 --H 2 --cols 8 --M_subv 16 --mode hw",
    ]

    # Run each command and move its artifacts into the corresponding folder.
    # Moving leaves the working directory empty for the next command, so
    # outputs cannot leak between configurations.
    for command in commands:
        # Skip configurations whose test script is missing.
        script_file = command.split()[1]
        if not os.path.exists(script_file):
            print(f"Script file {script_file} not found. Skipping command: {command}")
            continue
        if disable_fast_pm:
            command += " --disable_fast_pm"

        # Each entry is a complete shell command line.
        result = subprocess.run(command, shell=True)
        if result.returncode != 0:
            print(f"Command failed with exit code {result.returncode}: {command}")

        # Folder name is derived from the command line's shape flags.
        folder_name = create_folder_name(command)
        os.makedirs(folder_name, exist_ok=True)

        for artifact in files_to_copy:
            if os.path.exists(artifact):
                # Give the full destination file path: shutil.move into a
                # directory raises if that directory already holds a file of
                # the same name (e.g. on a rerun without -clean).
                shutil.move(artifact, os.path.join(folder_name, artifact))

    print("All tests are done and files are moved to respective folders.")

# Build the xclbin into the output root.
print("Start building xclbin")
run_xclbin(WAIC_out, disable_fast_pm)

# Connect to the hardware for testing. "-HW_IP None" opts out of the automatic
# run so the artifacts can be copied and tested separately.
print("Start HW testing")

if args.HW_IP == "None":
    print("Running with manual copy and separate hw run")
else:
    HW_test(
        WAIC_out,
        HW_req,
        xclbin=0,
        xclbin_path=WAIC_out,
        overlay="8x4",
        use_bsub=False,
        host=args.HW_IP,
        perf_testing=False,
        golden_io=None,
        rename=False,
        profile_perf=False,
        rel_err_pc=True,
        disable_fast_pm=disable_fast_pm,
    )
