import argparse
import os
import json
from typing import Dict
import subprocess

from dmacompiler import BackEnd
from dataflow.mha.mha_build import mha_2p1_build_qdq
from dataflow.mha.overlay_qkt_sm_8x4 import Mha2p1Parameters, generate_standalone_buffer_allocation


def _git_short_hash(repo_dir: str) -> str:
    """Return the first 7 characters of ``git rev-parse HEAD`` run in *repo_dir*.

    An empty *repo_dir* (``os.path.dirname`` of a bare file name) is mapped to
    ``None`` so git runs in the current working directory instead of failing
    on ``chdir("")``.
    """
    byte_string = subprocess.check_output(
        ["git", "rev-parse", "HEAD"], cwd=repo_dir or None
    )
    return byte_string.decode("utf-8")[0:7]


def run_dataflow(tiler_output_file: str, model_data_folder: str, back_end):
    """Build the MHA 2p1 QDQ dataflow for one tiler configuration.

    Args:
        tiler_output_file: Name of the tiler-output JSON file, resolved
            relative to this script's directory.
        model_data_folder: Name of the model-data folder, resolved relative
            to this script's directory.
        back_end: ``BackEnd`` value to build for. For ``BackEnd.Adf`` no
            release folder is passed (``out_folder=None``). For any other
            back end the output goes to ``./psr_mha_2p1_<7-char git hash>/``,
            relative to the current working directory.

    Raises:
        subprocess.CalledProcessError: if ``git rev-parse HEAD`` fails while
            computing the release folder name (non-ADF back ends only).
    """
    kernel_names = ["run_qtxk", "run_sfmx"]
    kernel_includes = ["super.hh", "mha_qdq/wrapper_qkt_sm_4x4_i16i16.cc"]

    curr_dir = os.path.dirname(__file__)
    tiler_output_path = os.path.join(curr_dir, tiler_output_file)
    model_data_dir_path = os.path.join(curr_dir, model_data_folder)

    with open(tiler_output_path, "r", encoding="utf-8") as fd:
        tiler_output = json.load(fd)

    parameters = Mha2p1Parameters.compute_internal_parameters(tiler_output)
    buffer_alloc = generate_standalone_buffer_allocation(tiler_output)

    # Only non-ADF back ends use a versioned release folder. Skip the git call
    # for ADF so simulation runs do not depend on git/a checkout. Resolve the
    # hash from the script's own directory rather than the caller's cwd.
    if back_end == BackEnd.Adf:
        op_release_path = None
    else:
        op_release_path = "./psr_mha_2p1_" + _git_short_hash(curr_dir) + "/"

    mha_2p1_build_qdq(
        parameters,
        buffer_alloc,
        back_end,
        kernel_names,
        kernel_includes,
        model_data=model_data_dir_path,
        out_folder=op_release_path,
    )


def test_mha_2p1_1024x64x1024(back_end):
    """Build the 1024x64x1024 MHA 2p1 configuration for *back_end*."""
    run_dataflow(
        tiler_output_file="qkt_sm_1024x64x1024_uint16.json",
        model_data_folder="model_data_1024x64x1024_1024x1024x64/",
        back_end=back_end,
    )


def test_mha_2p1_4096x64x4096(back_end):
    """Build the 4096x64x4096 MHA 2p1 configuration for *back_end*."""
    run_dataflow(
        tiler_output_file="qkt_sm_4096x64x4096_uint16.json",
        model_data_folder="model_data_4096x64x4096_4096x4096x64/",
        back_end=back_end,
    )


def test_mha_2p1_256x64x256(back_end):
    """Build the 256x64x256 MHA 2p1 configuration for *back_end*."""
    run_dataflow(
        tiler_output_file="qkt_sm_256x64x256_uint16.json",
        model_data_folder="model_data_256x64x256_256x256x64/",
        back_end=back_end,
    )


if __name__ == "__main__":
    # Map each accepted --id value to the test it runs. argparse rejects any
    # id not listed here, instead of silently doing nothing.
    tests = {
        1: test_mha_2p1_256x64x256,
        2: test_mha_2p1_1024x64x1024,
        3: test_mha_2p1_4096x64x4096,
    }

    parser = argparse.ArgumentParser(description="Test script for the mha 2p1 operator")

    parser.add_argument(
        "--id",
        required=True,
        type=int,
        choices=sorted(tests),
        help="Id of the test that we want to run",
    )
    parser.add_argument(
        "--mode",
        required=True,
        choices=["hw", "sim"],
        help="Run the test in hardware or simulation mode",
    )

    args = parser.parse_args()

    # "hw" builds with BackEnd.TxnHostPatch; "sim" builds with BackEnd.Adf.
    back_end = BackEnd.TxnHostPatch if args.mode == "hw" else BackEnd.Adf
    tests[args.id](back_end)
