import os
import shutil
import re
from prettytable import PrettyTable
import json
import argparse
import glob
from openpyxl import Workbook
import subprocess


# Dtype names accepted as the output-comparison dtype. main() warns and falls
# back to get_dtype() when the resolved dtype is not one of these.
VALID_DTYPES = {"int16", "uint16", "bf16", "int8", "uint8"}


def pick_threshold(
    thresholds_by_dtype: dict, dtype: str, qdq_mode: int | None, default_tol: int
) -> int:
    if qdq_mode is None:
        return default_tol
    table = thresholds_by_dtype.get(dtype)
    if not isinstance(table, dict):
        return default_tol
    return int(table.get(f"qdq_mode_{qdq_mode}", default_tol))


def load_qdq_mode_config(path: str) -> tuple[dict, dict]:
    """
    Read the --qdq_mode JSON file and return (folder_map, thresholds_by_dtype).

    Expected layout (both top-level keys are optional):
    {
      "folders": {
        "CONV_1x64x64x320": { "dtype": "uint16", "qdq_mode": 2 },
        "MHA_1x8x64x64":    { "dtype": "uint8",  "qdq_mode": 3 }
      },
      "thresholds": {
        "uint16": { "qdq_mode_0": 1, "qdq_mode_1": 1, "qdq_mode_2": 256, "qdq_mode_3": 0 },
        "uint8":  { "qdq_mode_0": 1, "qdq_mode_1": 1, "qdq_mode_2": 1,   "qdq_mode_3": 0 }
      }
    }

    A file that cannot be read or parsed produces a warning and two empty
    dicts. A document that is not a dict, or that has neither key, also
    yields two empty dicts. Folder entries lacking "dtype" or "qdq_mode" and
    non-dict threshold tables are skipped.
    """
    try:
        with open(path, "r") as fh:
            config = json.load(fh)
    except Exception as e:
        print(f"Warning: failed to load --qdq_mode JSON at {path}: {e}")
        return {}, {}

    # Only the explicit "folders"/"thresholds" schema is understood.
    recognized = isinstance(config, dict) and (
        "folders" in config or "thresholds" in config
    )
    if not recognized:
        return {}, {}

    folder_map = {}
    folder_section = config.get("folders", {})
    if isinstance(folder_section, dict):
        folder_map = {
            name: {"dtype": spec["dtype"], "qdq_mode": int(spec["qdq_mode"])}
            for name, spec in folder_section.items()
            if isinstance(spec, dict) and "dtype" in spec and "qdq_mode" in spec
        }

    thresholds = {}
    threshold_section = config.get("thresholds", {})
    if isinstance(threshold_section, dict):
        # Shallow-copy each table so the result does not alias the parsed JSON.
        thresholds = {
            dtype_name: dict(table)
            for dtype_name, table in threshold_section.items()
            if isinstance(table, dict)
        }

    return folder_map, thresholds


def load_dtype_map(path: str) -> dict:
    """
    Load a folder -> dtype mapping from a JSON file.

    Two layouts are accepted:
      { "FolderA": "uint16", "FolderB": "int16" }
    or:
      [ {"folder": "FolderA", "dtype": "uint16"}, ... ]

    List items that are not dicts, or that lack "folder"/"dtype", are skipped.
    Any failure while reading or converting prints a warning and results in
    an empty dict; any other JSON top-level type also yields an empty dict.
    """
    mapping = {}
    try:
        with open(path, "r") as fh:
            payload = json.load(fh)
        if isinstance(payload, dict):
            mapping = payload
        elif isinstance(payload, list):
            mapping = {
                record["folder"]: record["dtype"]
                for record in payload
                if isinstance(record, dict)
                and "folder" in record
                and "dtype" in record
            }
    except Exception as e:
        print(f"Warning: failed to load dtype map from {path}: {e}")
    return mapping


def resolve_dtype(
    folder_name: str,
    cli_dtype: str | None,
    dtype_map: dict,
    qdq_folder_dtype: str | None,
) -> str:
    # priority: --dtype > qdq per-folder dtype > dtype-map > get_dtype()
    if cli_dtype:
        return cli_dtype
    if qdq_folder_dtype:
        return qdq_folder_dtype
    if folder_name in dtype_map:
        return dtype_map[folder_name]
    return get_dtype(folder_name)


def build_test_command(
    exe: str,
    xclbin: str,
    perf_testing: bool,
    profile_perf: bool,
    rel_err_pc: bool,
    dtype: str,
    threshold: int | None,
    print_results: bool,
) -> list[str]:
    """
    xrt_flow.exe <xclbin> [num_threads] [num_runs] [out_compare_dtype] [Debug_flag] [error_threshold?|"mha"?]

    - Always pass: xclbin, num_threads=1, num_runs, dtype, Debug_flag=0
    - If rel_err_pc=True → append "mha" (and ignore threshold)
    - Else, if threshold is not None → append str(threshold)
    - Else → append nothing
    """
    iterations = "1000" if (perf_testing or profile_perf) else "1"
    print_results = "1" if print_results else "0"
    if rel_err_pc:
        dtype = "uint16"
    args = [exe, xclbin, "1", iterations, dtype, print_results]

    if threshold is not None:
        args.append(str(threshold))
    elif rel_err_pc:
        args.append("mha")

    return args


def get_error_tolerance(dir_name) -> int:
    """Return the default error tolerance for a test folder.

    The operator name is the part of ``dir_name`` before the first "_";
    operators without a dedicated entry get a tolerance of 20.
    """
    tolerance_by_op = {
        "MHA": 1000,
        "Add": 500,
        "LayerNormalization": 256,
        "GroupNormalization": 256,
        "Concat": 500,
    }
    op_prefix, _, _ = dir_name.partition("_")
    return tolerance_by_op.get(op_prefix, 20)


def get_dtype(dir_name) -> str:
    """Return the default comparison dtype for a test folder.

    The operator name is the part of ``dir_name`` before the first "_".
    LayerNormalization, GroupNormalization and Concat default to "uint16";
    every other operator defaults to "int16".
    """
    op_prefix, _, _ = dir_name.partition("_")
    if op_prefix in ("LayerNormalization", "GroupNormalization", "Concat"):
        return "uint16"
    return "int16"


def capture_value_after_keywords(file_path, keywords):
    """Extract the token that follows each keyword in a log file.

    Args:
        file_path: Path of the text file to scan.
        keywords: Literal strings to look for. The special keyword
            "TEST PASSED!" is treated as a presence check.

    Returns:
        A dict with one entry per keyword. For "TEST PASSED!" the value is
        "Pass" or "Fail". For every other keyword it is the run of
        non-whitespace characters immediately after the first occurrence,
        or None when the keyword is not found.

    Raises:
        OSError: if ``file_path`` cannot be opened.
    """
    with open(file_path, "r") as file:
        content = file.read()

    values = {}
    for keyword in keywords:
        if keyword == "TEST PASSED!":
            values[keyword] = "Pass" if keyword in content else "Fail"
            continue

        # re.escape makes the keyword match literally, so keywords containing
        # regex metacharacters (e.g. "(us)", "%", "+") are not misinterpreted.
        # NOTE(review): the \b after a keyword that ends in a space means the
        # value must start with a word character, so a value such as "-1.5"
        # is reported as None - confirm that is intended.
        found = re.search(rf"\b{re.escape(keyword)}\b(\S+)", content)
        values[keyword] = found.group(1) if found else None

    return values


def main(
    perf_testing: bool,
    profile_perf: bool,
    rel_err_pc: bool,
    cli_dtype: str | None,
    dtype_map_path: str | None,
    qdq_mode_path: str | None,
    print_results: bool,
) -> int:
    """Run the test executable in every subfolder of the script directory.

    For each subfolder (except "cut_graphs" and "DataGen") the executable and
    xclbin are copied in, the command is run with output appended to that
    folder's log.txt, and the metrics found in the log are collected. The
    results are printed as a table and written to output.xlsx and output.json
    in the script directory. The process working directory is changed to the
    script directory.

    Args:
        perf_testing: Run 1000 iterations instead of 1.
        profile_perf: Run 1000 iterations, copy xrt.ini if present, set
            ENABLE_PROFILE=1 and collect record_timer_ts.json per folder.
        rel_err_pc: Pass the "mha" tag instead of a numeric threshold.
        cli_dtype: Global dtype override, or None.
        dtype_map_path: Optional JSON file mapping folder -> dtype.
        qdq_mode_path: Optional JSON file with per-folder dtype/qdq_mode and
            per-dtype threshold tables.
        print_results: Forwarded as the executable's debug/print flag.

    Returns:
        0 on success; 1 if a folder is missing a required input file or if
        no results were collected.
    """
    # Executable & assets
    executable_name = "xrt_flow_test_patch_datatype_debug.exe"
    xclbin_name = "out.xclbin"

    output_file = "output.xlsx"
    output_json_file = "output.json"
    keywords = [
        "Maximum Error = ",
        "Maximum Error Percentage = ",
        "L2 norm of Error = ",
        "L2 norm per element = ",
        "Root Mean square Error = ",
        "Root Mean Absolute Error = ",
        "Average Relative Error Percentage = ",
        "iterations = ",
        "TEST PASSED!",
    ]
    required_files = [
        "ifm.bin",
        "ofm.bin",
        "param.bin",
        "txn.bin",
        "wgt.bin",
        "ctrl.bin",
        "patch.json",
    ]

    # Paths & folders. After the chdir, bare names from listdir() resolve
    # against the script directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)
    current_dir = os.getcwd()
    skip_folders = ["cut_graphs", "DataGen"]
    folder_names = [
        name
        for name in os.listdir(script_dir)
        if os.path.isdir(name) and name not in skip_folders
    ]

    # dtype map (optional)
    dtype_map = {}
    if dtype_map_path:
        try:
            dtype_map = load_dtype_map(dtype_map_path)
        except Exception as e:
            print(f"Warning: failed to load dtype-map '{dtype_map_path}': {e}")

    # qdq_mode config (optional)
    folder_qdq_map, thresholds_by_dtype = ({}, {})
    global_qdq_mode = None  # optional global default; currently never set
    if qdq_mode_path:
        try:
            folder_qdq_map, thresholds_by_dtype = load_qdq_mode_config(qdq_mode_path)
        except Exception as e:
            print(f"Warning: failed to load --qdq_mode '{qdq_mode_path}': {e}")

    # pm_id list (optional)
    pm_id_map = {}
    pm_id_list_path = os.path.join(current_dir, "pm_id_list.json")
    if os.path.isfile(pm_id_list_path):
        # Load from JSON file
        with open(pm_id_list_path, "r") as f:
            for entry in json.load(f):
                pm_id_map[entry["folder"]] = entry["pm_id"]
    else:
        # Default mode: look for pm_*.bin and txn_pm_*.bin next to the script
        # (the same place they are later copied from). The lookup does not
        # depend on the folder, so it is done once.
        # NOTE(review): the warning text speaks of "any folder" but only the
        # script directory is scanned - confirm which is intended.
        pm_bin_files = glob.glob(os.path.join(script_dir, "pm_*.bin"))
        txn_pm_bin_files = glob.glob(os.path.join(script_dir, "txn_pm_*.bin"))
        if pm_bin_files and txn_pm_bin_files:
            # Assign default pm_id since ID is unknown
            pm_id_map = {folder: 0 for folder in folder_names}
        if not pm_id_map:
            print(
                "Warning: No pm_id_list.json and no pm_*.bin files found in any folder. pm_id_map will be empty."
            )

    # Distribute executables and clear logs
    for folder_name in folder_names:
        folder_path = os.path.join(current_dir, folder_name)
        try:
            shutil.copy(executable_name, folder_path)
            shutil.copy(xclbin_name, folder_path)
            if profile_perf and os.path.exists(os.path.join(current_dir, "xrt.ini")):
                shutil.copy("xrt.ini", folder_path)
            log_path = os.path.join(folder_path, "log.txt")
            if os.path.exists(log_path):
                os.remove(log_path)
        except Exception as e:
            print(f"Warning: pre-copy failed for '{folder_name}': {e}")

    # Prepare table/results
    shapeTable = PrettyTable(
        [
            "Shape",
            "Maximum Error",
            "Maximum Error %",
            "L2 norm",
            "L2 norm p/element",
            "RMS error",
            "RMA error",
            "Average Relative Error %",
            "XRT time(us)",
            "Pass/Fail",
        ]
    )
    table_rows = []  # same rows as shapeTable, kept for the Excel export
    results_list = []

    # Iterate folders
    for folder_name in folder_names:
        folder_path = os.path.join(current_dir, folder_name)

        # Required inputs present?
        missing = [
            f
            for f in required_files
            if not os.path.exists(os.path.join(folder_path, f))
        ]
        if missing:
            print(f"Error: Missing in '{folder_name}': {', '.join(missing)}")
            return 1

        # ---------- Resolve dtype & threshold ----------
        # base fallback
        base_tol = get_error_tolerance(folder_name)

        # qdq overrides ONLY if --qdq_mode is given
        qdq_folder_dtype = None
        qdq_mode_val = None
        if qdq_mode_path:
            entry = folder_qdq_map.get(folder_name)
            if isinstance(entry, dict):
                qdq_folder_dtype = entry.get("dtype")
                qdq_mode_val = entry.get("qdq_mode")
            if qdq_mode_val is None:
                qdq_mode_val = global_qdq_mode  # optional global default

        # dtype priority: --dtype > qdq per-folder > dtype-map > get_dtype()
        dtype_used = resolve_dtype(folder_name, cli_dtype, dtype_map, qdq_folder_dtype)

        if dtype_used not in VALID_DTYPES:
            print(
                f"Warning: dtype '{dtype_used}' for '{folder_name}' not in {VALID_DTYPES}. Using get_dtype()."
            )
            dtype_used = get_dtype(folder_name)

        # threshold selection:
        if qdq_mode_path:
            tol_val = pick_threshold(
                thresholds_by_dtype, dtype_used, qdq_mode_val, base_tol
            )
        else:
            tol_val = base_tol

        # Build command (last arg optional)
        cmd_list = build_test_command(
            exe="xrt_flow_test_patch_datatype_debug",
            xclbin=xclbin_name,
            perf_testing=perf_testing,
            profile_perf=profile_perf,
            rel_err_pc=rel_err_pc,
            dtype=dtype_used,
            threshold=None if rel_err_pc else tol_val,
            print_results=print_results,
        )

        # Attach -id if present
        if folder_name in pm_id_map:
            pm_id = pm_id_map[folder_name]
            cmd_list += ["-id", str(pm_id)]
            # Copy pm_id-based files
            pm_file = f"pm_{pm_id}.bin"
            txn_file = f"txn_pm_{pm_id}.bin"
            for f in [pm_file, txn_file]:
                src = os.path.join(script_dir, f)
                dst = os.path.join(folder_path, f)
                if os.path.exists(src):
                    shutil.copy(src, dst)
                    print(f"Copied {f} to {folder_name}")
                else:
                    print(f"Missing file: {f} — skipping")

        print(f"Running in {folder_name}: {' '.join(cmd_list)}")
        if profile_perf:
            os.environ["ENABLE_PROFILE"] = "1"

        # Run inside folder; append to log.txt. The finally clause restores
        # the working directory even if launching the executable raises.
        os.chdir(folder_path)
        try:
            with open("log.txt", "a") as lf:
                lf.write("CMD: " + " ".join(cmd_list) + "\n")
                # Flush so the CMD line precedes the child's own output,
                # which is written straight to the same file descriptor.
                lf.flush()
                subprocess.run(cmd_list, stdout=lf, stderr=lf, check=False)
        finally:
            os.chdir(current_dir)

        # Parse results
        file_path = os.path.join(folder_name, "log.txt")
        if not os.path.exists(file_path):
            print(f"log.txt missing in {folder_name}")
            continue
        captured = capture_value_after_keywords(file_path, keywords)

        row = [folder_name] + [captured[keyword] for keyword in keywords]
        shapeTable.add_row(row)
        table_rows.append(row)

        results_list.append(
            {
                "Shape": folder_name,
                "Maximum Error": captured[keywords[0]],
                "Maximum Error Percentage": captured[keywords[1]],
                "L2 norm": captured[keywords[2]],
                "L2 norm per element": captured[keywords[3]],
                "Root Mean squar": captured[keywords[4]],
                "iterations time(us)": captured[keywords[7]],
                "Root Mean Absolute error": captured[keywords[5]],
                "Average Relative Error Percentage ": captured[keywords[6]],
                "Pass or Fail": captured[keywords[8]],
                "dtype_used": dtype_used,
                "threshold_used": None if rel_err_pc else tol_val,
                "qdq_mode_used": qdq_mode_val if qdq_mode_path else None,
            }
        )

        # Optional profiling artifact: rename it per folder and copy it up
        # one level, next to the script.
        if profile_perf:
            original = os.path.join(folder_name, "record_timer_ts.json")
            if os.path.exists(original):
                new_name = f"record_timer_ts_{folder_name}.json"
                new_path = os.path.join(folder_name, new_name)
                try:
                    os.rename(original, new_path)
                    dest = os.path.abspath(os.path.join(folder_name, "..", new_name))
                    shutil.copy(new_path, dest)
                    print(f"Renamed and copied to: {dest}")
                except Exception as e:
                    print(f"Warning moving profile json in '{folder_name}': {e}")

    # Print table
    print(shapeTable)
    if not results_list:
        print("Error: Result list is empty")
        return 1
    print("Tasks completed successfully!")

    # Save Excel. Rows are written from the collected values rather than by
    # re-parsing the rendered table text, which produced a blank row from the
    # header separator and broke on values containing "|". Metrics that were
    # not found (None) become empty cells.
    wb = Workbook()
    ws = wb.active
    ws.title = "shape"
    ws.append(shapeTable.field_names)
    for row in table_rows:
        ws.append(row)
    wb.save(output_file)

    # Save JSON
    with open(output_json_file, "w") as jf:
        json.dump(results_list, jf, indent=4)

    print(f"Results saved to {output_file} and {output_json_file}.")
    return 0


# Command-line entry point: parse the flags and hand them to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run xrt_flow_test_patch_datatype_debug over subfolders."
    )
    parser.add_argument(
        "--perf_testing", action="store_true", help="Use many iterations (1000)."
    )
    parser.add_argument(
        "--profile_perf", action="store_true", help="Enable timer/profiling mode."
    )
    parser.add_argument(
        "--rel_err_pc",
        action="store_true",
        help="Use 'mha' tag as last arg; do not pass threshold.",
    )
    # The examples list only values present in VALID_DTYPES; anything else is
    # replaced by the per-folder default with a warning inside main().
    parser.add_argument(
        "--dtype",
        type=str,
        default=None,
        help="Global dtype override (e.g., int16, uint16, bf16, uint8).",
    )
    parser.add_argument(
        "--dtype-map", type=str, default=None, help="JSON mapping of folder->dtype."
    )
    parser.add_argument(
        "--qdq_mode",
        type=str,
        default=None,
        help="JSON with per-folder dtype/qdq_mode and per-dtype threshold tables (see header comment).",
    )
    parser.add_argument(
        "--print", action="store_true", help="Enable printing of detailed results."
    )
    args = parser.parse_args()

    # Propagate main()'s status as the process exit code so callers (shell
    # scripts, CI) can detect a failed run.
    raise SystemExit(
        main(
            args.perf_testing,
            args.profile_perf,
            args.rel_err_pc,
            args.dtype,
            args.dtype_map,
            args.qdq_mode,
            args.print,
        )
    )
