import argparse
import os
import json
from typing import Dict
import subprocess

from dmacompiler import BackEnd
from dataflow.mha.mha_build import mha_mini_build_qdq
from OGOAT.src.Scheduling_Engine.schedules.BufferAllocatorResult import BufferAllocations
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list

# Root folder holding the model_data_* directories; None when TEST_DIR is not exported.
test_dir = os.getenv("TEST_DIR")

def ceildiv(x: int, d: int) -> int:
    """Return ``ceil(x / d)`` using exact integer arithmetic.

    Raises:
        ZeroDivisionError: if ``d`` is zero.
    """
    quotient, remainder = divmod(x, d)
    # divmod floors; any non-zero remainder means we must round up by one.
    return quotient + (1 if remainder else 0)

def iceil(x: int, d: int) -> int:
    """Round ``x`` up to the nearest multiple of ``d`` (i.e. ``ceil(x / d) * d``).

    Raises:
        ZeroDivisionError: if ``d`` is zero.
    """
    quotient, remainder = divmod(x, d)
    if remainder:
        # Not an exact multiple: step to the next one.
        quotient += 1
    return quotient * d

def generate_standalone_buffer_allocation(tiler_output: Dict) -> BufferAllocations:
    """Placeholder for deriving buffer allocations from a tiler output.

    TODO: not implemented. The function currently ignores ``tiler_output``
    and returns ``None``, despite the ``BufferAllocations`` return annotation.

    Args:
        tiler_output: Tiler result dictionary (schema not visible here).
    """
    return None

def _resolve_kernel_indices(kernel_list):
    """Map each kernel name to its index in ``kernel_func_list``, reporting unknown names."""
    kernel_names = {}
    for name in kernel_list:
        try:
            kernel_names[name] = kernel_func_list.index(name)
        except ValueError:
            # Best-effort: report the missing kernel and keep going, as before.
            print(f"Error: '{name}' not found in the kernel func list!")
    return kernel_names


def _model_data_folder(base_dir, H, Sin_q, Sin_kv, Sin_dh, bias):
    """Return the model-data folder path (with trailing '/') for the given shape."""
    if bias:
        # NOTE(review): the bias folder uses the raw Sin_q (no 1 -> 64 remap),
        # unlike the non-bias folder below; confirm this is intended.
        return f"{base_dir}/model_data_{Sin_q}x{Sin_dh}x{Sin_kv}_H_{H}/"
    # Sin_q == 1 is mapped onto the 64-row data folder.
    Sin_q_mod = 64 if Sin_q == 1 else Sin_q
    return (
        f"{base_dir}/model_data_{Sin_q_mod}x{Sin_dh}x{Sin_kv}"
        f"_{Sin_q_mod}x{Sin_kv}x{Sin_dh}/"
    )


def run_dataflow(H, Sin_q, Sin_kv, Sin_dh, M_subv, cols, attn_mask_exist, bias , mini_3p0, back_end, disable_fast_pm):
    """Build the standalone MHA-mini QDQ dataflow via ``mha_mini_build_qdq``.

    Args:
        H: Number of heads.
        Sin_q: Query size. ``Sin_q == 1`` also sets ``K_is_tranposed_on_DDR``.
        Sin_kv: Key/value size.
        Sin_dh: Head size.
        M_subv: Forwarded as ``Sq_subv``.
        cols: Forwarded as ``aie_overlay_cols``.
        attn_mask_exist: Forwarded as ``enable_attn_mask``.
        bias: Truthy selects the bias data folder; forwarded as ``B``.
        mini_3p0: Forwarded as ``mini_3p0``.
        back_end: ``dmacompiler.BackEnd`` value to build for.
        disable_fast_pm: Forwarded as ``disable_fast_pm``.

    Raises:
        RuntimeError: if the ``TEST_DIR`` environment variable is not set.
            Without it, the model-data path would silently become "None/...".
    """
    if test_dir is None:
        raise RuntimeError(
            "TEST_DIR environment variable is not set; it must point to the "
            "folder containing the model_data_* directories"
        )

    kernel_includes = [
        'super.hh',
        'mha_qdq/wrapper_mini_mha.cc',
        'norm/softmax.cc',
    ]
    kernel_list = [
        'run_gemm_qdq_mini',
        'run_softmax_qdq',
        'run_bcast_add_mini',
        'run_mini_mha_preprocess',
        'run_presoftmax_dequant',
    ]
    kernel_names = _resolve_kernel_indices(kernel_list)

    # Absolute path so the output folder does not depend on how __file__ was given.
    curr_dir = os.path.dirname(os.path.abspath(__file__))
    model_data_folder = _model_data_folder(test_dir, H, Sin_q, Sin_kv, Sin_dh, bias)

    mha_mini_build_qdq(
        H=H,
        Sin_q=Sin_q,
        Sin_kv=Sin_kv,
        Sin_dh=Sin_dh,
        Sq_subv=M_subv,
        aie_overlay_cols=cols,
        back_end=back_end,
        standalone_hw_test=True,
        kernel_names=kernel_names,
        kernel_includes=kernel_includes,
        front_end_only=False,
        disable_fast_pm=disable_fast_pm,
        model_data=model_data_folder,
        out_folder=curr_dir,
        enable_attn_mask=attn_mask_exist,
        K_is_tranposed_on_DDR=(Sin_q == 1),
        B=bias,
        mini_3p0=mini_3p0,
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test script for the mha mini operator")

    # Problem shape.
    parser.add_argument("--Sin_q", required=True, type=int, help="Query size")
    parser.add_argument("--Sin_kv", required=True, type=int, help="Key/Value size")
    parser.add_argument("--Sin_dh", required=True, type=int, help="head size")
    parser.add_argument("--H", required=True, type=int, help="number of heads")
    parser.add_argument(
        "--M_subv", required=True, type=int, help="number of subvolume for the QKV matrix"
    )
    parser.add_argument("--cols", required=True, type=int, help="Overlay selection")

    # Feature toggles.
    parser.add_argument(
        "--mask",
        action="store_true",
        default=False,
        help="Enable the attention mask (forwarded as enable_attn_mask)",
    )
    parser.add_argument(
        "--bias",
        type=int,
        default=None,
        help="Enable bias in the MHA mini operator (0 for no bias, 1 for bias enabled)",
    )
    parser.add_argument(
        "--mini_3p0",
        type=int,
        default=1,
        help="Integer flag forwarded to mha_mini_build_qdq as mini_3p0 (default: 1)",
    )
    parser.add_argument(
        "--mode",
        required=True,
        choices=["hw", "sim"],
        help="Run the test in hardware or simulation mode",
    )
    parser.add_argument(
        '--disable_fast_pm', action="store_true", default=False, help="Disable fast pm mode"
    )

    args = parser.parse_args()

    # "hw" builds transactions for the host patcher; "sim" targets the ADF back end.
    back_end = BackEnd.TxnHostPatch if args.mode == "hw" else BackEnd.Adf

    script_dir = os.path.dirname(os.path.abspath(__file__))
    makefile_path = os.path.join(script_dir, "Makefile")

    # Clean previous build artifacts using the Makefile next to this script.
    # NOTE(review): make runs in the current working directory, not script_dir;
    # confirm the Makefile's clean target does not rely on relative paths.
    subprocess.run(["make", "-f", makefile_path, "clean"], check=True)

    run_dataflow(
        H=args.H,
        Sin_q=args.Sin_q,
        Sin_kv=args.Sin_kv,
        Sin_dh=args.Sin_dh,
        M_subv=args.M_subv,
        cols=args.cols,
        attn_mask_exist=args.mask,
        bias=args.bias,
        mini_3p0=args.mini_3p0,
        back_end=back_end,
        disable_fast_pm=args.disable_fast_pm,
    )
