from dmacompiler import BackEnd
from dataflow.mha.mha_build import mha_3p0_build_qdq

import argparse
import subprocess
import os


def test_psr_mha_3p0(test_id: int, mode: str):
    """Build the PSR MHA 3p0 QDQ operator for one predefined test configuration.

    Args:
        test_id: Which configuration to build (1-4). It selects a model-data
            directory and a tiler output JSON file, both resolved relative to
            the directory containing this file.
        mode: ``"hw"`` selects ``BackEnd.TxnHostPatch``. Any other value
            selects ``BackEnd.Adf``.

    Raises:
        ValueError: If ``test_id`` is not one of the known configurations.
        subprocess.CalledProcessError: If ``git rev-parse HEAD`` fails. This
            can only happen for non-Adf back ends, where the revision is
            needed to name the release directory.
    """
    kernel_names = [
        "run_act_K_preprocess",
        "run_act_V_preprocess",
        "run_qkt_gemm_qdq",
        "run_sfmx_i16_to_i16",
        "run_smxv_gemm_qdq",
    ]
    kernel_includes = ["super.hh", "mha_qdq/wrapper_mha_3p0_4x4_i16i16.cc"]

    model_data_256 = "model_data_256x64x256_256x256x64/"
    model_data_1024 = "model_data_1024x64x1024_1024x1024x64/"

    # test_id -> (model data directory, tiler output JSON).
    # Ids 1/2 use the 8x4-overlay tiler outputs, ids 3/4 the 4x4-overlay ones.
    test_configs = {
        1: (model_data_256, "qkt_sm_smxv_1x256x64x256x64_8x4overlay_uint16.json"),
        2: (model_data_1024, "qkt_sm_smxv_1x1024x64x1024x64_8x4overlay_uint16.json"),
        3: (model_data_256, "qkt_sm_smxv_1x256x64x256x64_4x4overlay_uint16.json"),
        4: (model_data_1024, "qkt_sm_smxv_1x1024x64x1024x64_4x4overlay_uint16.json"),
    }
    # Validate with an explicit raise: a bare `assert` is stripped under
    # `python -O`, which would let an unknown id fall through to a NameError.
    try:
        model_data_name, tiler_output_name = test_configs[test_id]
    except KeyError:
        raise ValueError(f"Invalid test id: {test_id}") from None

    # NOTE(review): any mode other than "hw" (not just "sim") selects Adf.
    back_end = BackEnd.TxnHostPatch if mode == "hw" else BackEnd.Adf

    # The git revision only names the release directory, which exists only for
    # non-Adf back ends. Skipping the git call otherwise means simulation runs
    # no longer require git or a checkout.
    if back_end == BackEnd.Adf:
        op_release_path = None
    else:
        byte_string = subprocess.check_output(["git", "rev-parse", "HEAD"])
        cur_version = byte_string.decode("utf-8")[:7]
        op_release_path = f"./psr_mha_3p0_{test_id}_{cur_version}/"

    curr_dir = os.path.dirname(__file__)
    tiler_output_path = os.path.join(curr_dir, tiler_output_name)
    model_data_path = os.path.join(curr_dir, model_data_name)

    mha_3p0_build_qdq(
        tiler_output_path,
        back_end,
        kernel_names,
        kernel_includes,
        model_data_path,
        op_release_path,
    )


if __name__ == "__main__":
    # Command-line entry point: choose a test configuration id and whether to
    # build for hardware ("hw") or simulation ("sim"), then run the build.
    cli = argparse.ArgumentParser(
        description="Test script for the mha 3p0 operator"
    )
    cli.add_argument(
        "--id",
        type=int,
        required=True,
        help="Id of the test that we want to run",
    )
    cli.add_argument(
        "--mode",
        choices=["hw", "sim"],
        required=True,
        help="Run the test in hardware or simulation mode",
    )

    options = cli.parse_args()
    test_psr_mha_3p0(options.id, options.mode)
