"""Run XRT host tests on AIE4 hardware using dolphin test framework."""
import os
import re
import json
import argparse
import subprocess
from subprocess import TimeoutExpired
import time
from pathlib import Path
from typing import Tuple

from prettytable import PrettyTable
from openpyxl import Workbook
from ml_timeline_parser import generate_ml_timeline_csv, get_uc_cycles_from_csv

# Output-compare dtypes accepted by this script; build_host_cmd() and main()
# substitute "int16" for any value that is not in this set.
VALID_DTYPES = {"int16", "uint16", "bf16", "int8", "uint8", "fp32", "int4"}

# PnP instance ID handed to pnp_toggle() when main() disables/re-enables the
# device (before the test loop and after every timed-out attempt).
DEVICE_INSTANCE_ID = r"PCI\VEN_1022&DEV_17F1&SUBSYS_17F11022&REV_10\4&212933EF&0&0142"

# Retry configuration
# Defaults for the --timeout (seconds per attempt) and --max-retries options.
DEFAULT_TIMEOUT_SECONDS = 10
DEFAULT_MAX_RETRIES = 3


# ML Timeline
def has_ml_timeline_jsons(bins_path: str | Path) -> bool:
    """
    Returns True if both ml_timeline_metadata.json and record_timer_ts.json
    exist inside bins_path, else False.
    """
    bins_path = Path(bins_path)
    return (bins_path / "ml_timeline_metadata.json").exists() and \
           (bins_path / "record_timer_ts.json").exists()


def generate_ml_timeline_profile(bins_path: str | Path) -> Tuple[int, int, int]:
    """
    If ml_timeline_metadata.json and record_timer_ts.json exist in bins_path,
    generate ml_timeline_profile.csv in the same directory.

    On any error, print the error and continue (no exception raised).
    """
    bins_path = Path(bins_path)

    metadata_path = bins_path / "ml_timeline_metadata.json"
    record_timer_path = bins_path / "record_timer_ts.json"
    outdir = bins_path

    if not metadata_path.exists():
        return None, None, None
    if not record_timer_path.exists():
        return None, None, None

    try:
        csv_path = generate_ml_timeline_csv(
            metadata_path=metadata_path,
            record_timer_path=record_timer_path,
            outdir=outdir,
        )
        uc0, uc2, uc4 = get_uc_cycles_from_csv(csv_path)
        print(f"[INFO] ML timeline profile CSV generated: {csv_path}")
        return uc0, uc2, uc4
    except Exception as e:  # pylint: disable=W0718
        print(f"[WARN] Failed to generate ML timeline profile for {bins_path}: {e}")
        return None, None, None

# ---------------------- Helpers ----------------------


def capture_value_after_keywords(file_path, keywords):
    """Extract the token following each keyword from a text file.

    The file is read as UTF-8 with undecodable bytes ignored. For every
    keyword the first occurrence (starting at a word boundary) is located and
    the run of non-whitespace characters right after it is returned, or None
    when the keyword does not occur. The special keyword "TEST PASSED!" maps
    to "Pass" when it appears anywhere in the file and to "Fail" otherwise.

    Returns a dict keyed by the given keywords.
    """
    with open(file_path, "r", encoding="utf-8", errors="ignore") as handle:
        text = handle.read()

    def _lookup(keyword):
        # The pass marker is a presence check, not a "keyword + value" pair.
        if keyword == "TEST PASSED!":
            return "Pass" if keyword in text else "Fail"
        found = re.search(rf"\b{re.escape(keyword)}(\S+)", text)
        if found is None:
            return None
        return found.group(1)

    return {keyword: _lookup(keyword) for keyword in keywords}


def _ensure_list(x):
    """Normalize x to a list: lists pass through, None becomes [], anything else is wrapped."""
    if isinstance(x, list):
        return x
    return [] if x is None else [x]


def load_instance_kernel_map(instance_json_path: Path) -> dict[str, str]:
    """
    Build a { folder_name: kernel_name } dict from instance_ids_list.json.

    The file is expected to hold a JSON list of objects. For each object:
    - "folder" names the key; entries whose folder is not a non-empty string
      are skipped.
    - "instance_id" (or, when that is missing/empty, "ids") may be a single
      value or a list; the first non-blank string, stripped, becomes the
      kernel name. With no usable id the folder maps to "".

    A missing file, unreadable JSON or an unexpected entry only prints a
    warning; whatever was collected up to that point is returned. A top-level
    value that is not a list yields an empty dict.
    """
    folder_to_kernel = {}
    if not instance_json_path.is_file():
        print(f"Warning: instance file not found: {instance_json_path}")
        return folder_to_kernel

    try:
        with open(instance_json_path, "r", encoding="utf-8") as handle:
            entries = json.load(handle)
        if not isinstance(entries, list):
            return folder_to_kernel
        for entry in entries:
            folder = entry.get("folder")
            candidates = _ensure_list(
                entry.get("instance_id") or entry.get("ids") or []
            )
            # First id that is a string with visible content, else "".
            kernel = next(
                (c.strip() for c in candidates if isinstance(c, str) and c.strip()),
                "",
            )
            if isinstance(folder, str) and folder:
                folder_to_kernel[folder] = kernel
    except Exception as e:  # pylint: disable=W0718
        print(f"Warning: failed to read {instance_json_path}: {e}")
    return folder_to_kernel


def build_host_cmd(
    xrt_flow_exe: str,
    elf_file: Path,
    kernel_name: str,
    path_bin: Path,
    perf_testing: bool,
    out_compare_dtype: str,
    debug_flag: bool,
) -> list[str]:
    """Assemble the argv list used to launch xrt_flow.exe.

    The resulting positional arguments are, in order: executable, ELF path,
    kernel name, bin directory, thread count (always "1"), iteration count
    ("1000" when perf_testing, otherwise "1"), compare dtype (falls back to
    "int16" when out_compare_dtype is not in VALID_DTYPES) and the debug flag
    ("1" or "0"). Every element is a string.
    """
    iterations = "1"
    if perf_testing:
        iterations = "1000"

    compare_dtype = "int16"
    if out_compare_dtype in VALID_DTYPES:
        compare_dtype = out_compare_dtype

    argv = [xrt_flow_exe, str(elf_file), kernel_name, str(path_bin)]
    argv += ["1", iterations, compare_dtype, "1" if debug_flag else "0"]
    return argv


def pnp_toggle(instance_id, action):
    """
    Disable or enable a device through the PowerShell PnP cmdlets.

    Args:
        instance_id: PnP instance ID passed to -InstanceId.
        action: 'disable' (Disable-PnpDevice) or 'enable' (Enable-PnpDevice).

    Returns:
        True when the PowerShell command exited with status 0.

    Raises:
        ValueError: action is neither 'disable' nor 'enable'.
        subprocess.CalledProcessError: PowerShell exited with a non-zero
            status. The captured stdout/stderr are printed before the
            exception propagates so the failure reason is visible.
    """
    if action not in ("disable", "enable"):
        raise ValueError("action must be 'disable' or 'enable'")

    ps_action = "Disable-PnpDevice" if action == "disable" else "Enable-PnpDevice"

    cmd = [
        "powershell",
        "-Command",
        f"{ps_action} -InstanceId \"{instance_id}\" -Confirm:$false"
    ]

    print(f"[PNP] Running: {' '.join(cmd)}")
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as err:
        # check=True raises before the prints below; surface the captured
        # output here, otherwise the diagnostics are lost on failure.
        print("Output:", err.stdout)
        print("Errors:", err.stderr)
        raise

    print("Output:", result.stdout)
    print("Errors:", result.stderr)
    return result.returncode == 0

# ---------------------- Main ----------------------


def _reset_device(instance_id, announce=None):
    """Power-cycle a device: PnP disable, wait 1 s, PnP enable, wait 2 s.

    When announce is given, a "[PNP] <announce>: ..." line is printed before
    each of the two steps.
    """
    if announce:
        print(f"[PNP] {announce}: disabling device...")
    pnp_toggle(instance_id, "disable")
    time.sleep(1)

    if announce:
        print(f"[PNP] {announce}: enabling device...")
    pnp_toggle(instance_id, "enable")
    time.sleep(2)


def _run_with_retries(cmd_list, work_dir, log_path, timeout, max_retries, label):
    """Run cmd_list in work_dir, retrying when an attempt exceeds timeout.

    stdout/stderr of the command go to log_path. The log is truncated on the
    first attempt and appended to on retries, so it only ever holds output of
    the current invocation. After every timed-out attempt the device is reset
    through PnP.

    Returns True as soon as one attempt finishes within the timeout (its exit
    status is not inspected), False when every attempt timed out.
    """
    cmd_text = " ".join(map(str, cmd_list))
    for attempt in range(1, max_retries + 1):
        # Start from an empty log: the parser takes the first match of each
        # keyword, so leftovers of an earlier run would be reported otherwise.
        log_mode = "w" if attempt == 1 else "a"
        try:
            with open(log_path, log_mode, encoding="utf-8") as lf:
                if attempt > 1:
                    lf.write(f"\n--- RETRY ATTEMPT {attempt} ---\n")
                lf.write("CMD: " + cmd_text + "\n")
                # The child writes straight to the underlying handle; flush
                # our buffered header first so it lands ahead of its output.
                lf.flush()
                subprocess.run(
                    cmd_list,
                    cwd=str(work_dir),
                    stdout=lf,
                    stderr=lf,
                    check=False,
                    timeout=timeout,
                )
            return True
        except TimeoutExpired:
            print(f"[TIMEOUT] {label}: attempt {attempt}/{max_retries} exceeded {timeout}s")
            # Reset device after timeout
            _reset_device(DEVICE_INSTANCE_ID)
            if attempt < max_retries:
                print(f"[RETRY] {label}: retrying...")
    return False


def main():
    """Run xrt_flow.exe for every test subfolder next to this script and report.

    The working directory is switched to the script's own directory. Each
    subdirectory is one test: in operator mode (default) its bins live in
    <folder>/hw_package and the kernel name is "aie4_models"; in subgraph
    mode (--sg) the bins live in the folder itself and the kernel name comes
    from the instance JSON (falling back to the folder name).

    For each test the command output is written to <folder>/AIE_HW_Debug.txt
    and parsed for error metrics and the pass marker. Results are printed as
    a table and saved to out_summary.json, output.xlsx and output.json in the
    script directory.

    Returns:
        0 when at least one result was recorded, 1 when there were none.
    """
    parser = argparse.ArgumentParser(
        description="Run xrt_flow.exe over all subfolders with required HW bins."
    )
    parser.add_argument(
        "--perf_testing", action="store_true",
        help="Use many iterations (1000) instead of 1."
    )
    parser.add_argument(
        "--dtype", type=str, default="int16",
        help=f"Output compare dtype (default: int16). One of {sorted(VALID_DTYPES)}."
    )
    parser.add_argument(
        "--print", action="store_true",
        help="Enable verbose debug flag to xrt_flow.exe (Debug_flag=1)."
    )
    parser.add_argument(
        "--xrt-flow-exe", type=str, default=r"C:\Users\Administrator\Downloads\NPU-Drivers\xrt\xrt\xrt_flow.exe",
        help="Name or path to xrt_flow executable (must be on PATH or provide full path)."
    )
    parser.add_argument(
        "--instance-json", type=str, default="instance_ids_list.json",
        help="Path to instance_ids_list.json to fetch kernel names."
    )
    parser.add_argument(
        "--timeout", type=int, default=DEFAULT_TIMEOUT_SECONDS,
        help=f"Timeout in seconds per test attempt (default: {DEFAULT_TIMEOUT_SECONDS})."
    )
    parser.add_argument(
        "--max-retries", type=int, default=DEFAULT_MAX_RETRIES,
        help=f"Max retry attempts on timeout (default: {DEFAULT_MAX_RETRIES})."
    )

    # ---------- Mode selection (op vs sg) ----------
    mode_group = parser.add_mutually_exclusive_group()
    mode_group.add_argument(
        "--op", "-operator", dest="mode", action="store_const", const="op",
        help="Operator mode (default)."
    )
    mode_group.add_argument(
        "--sg", "-subgraph", dest="mode", action="store_const", const="sg",
        help="Subgraph mode. Expects bins directly in each subfolder."
    )
    parser.set_defaults(mode="op")

    args = parser.parse_args()

    # store_const plus the default guarantee mode is exactly "op" or "sg".
    mode = args.mode
    debug_flag = args.print
    iterations_used = 1000 if args.perf_testing else 1

    # Loop-invariant validation, done once up front.
    dtype = args.dtype
    if dtype not in VALID_DTYPES:
        print(f"[CONFIG] Unknown dtype '{args.dtype}', falling back to 'int16'.")
        dtype = "int16"

    # With fewer than one attempt nothing would run, yet every test would be
    # reported as timed out; always try at least once.
    max_retries = args.max_retries
    if max_retries < 1:
        print(f"[CONFIG] max_retries={max_retries} is invalid, using 1.")
        max_retries = 1

    script_dir = Path(__file__).resolve().parent
    os.chdir(script_dir)
    current_dir = Path.cwd()

    # Load folder -> kernel name mapping (a relative --instance-json path is
    # resolved against the script directory because of the chdir above).
    if mode == "sg":
        kernel_map = load_instance_kernel_map(Path(args.instance_json))
    else:
        kernel_map = {}

    # Required files per subfolder (or per hw_package)
    required_files = [
        "control.elf",
        "ifm.bin",
        "ofm.bin",
        "param.bin",
        "wgt.bin",
    ]

    # Gather subfolders (each is an op or subgraph folder); sorted so the
    # report order is reproducible.
    folder_names = sorted(p.name for p in current_dir.iterdir() if p.is_dir())

    # Table & outputs
    shape_table = PrettyTable(
        [
            "Shape",
            "Maximum Error",
            "Maximum Error %",
            "L2 norm",
            "L2 norm p/element",
            "RMS error",
            "RMA error",
            "Average Relative Error %",
            "XRT time(us)",
            "E2E Cycle uC_0",
            "E2E Cycle uC_2",
            "E2E Cycle uC_4",
            "Pass/Fail",
        ]
    )
    results_list = []
    hung_tests = []
    output_xlsx = "output.xlsx"
    output_json = "output.json"

    # (result key, keyword searched in the log) for every parsed metric.
    metric_fields = [
        ("Maximum Error", "Maximum Error = "),
        ("Maximum Error Percentage", "Maximum Error Percentage = "),
        ("L2 norm", "L2 norm of Error = "),
        ("L2 norm per element", "L2 norm per element = "),
        ("Root Mean square", "Root Mean square Error = "),
        ("Root Mean Absolute error", "Root Mean Absolute Error = "),
        ("Average Relative Error Percentage", "Average Relative Error Percentage = "),
        ("iterations time(us)", "iterations = "),
    ]
    pass_keyword = "TEST PASSED!"
    metric_keys = [key for key, _ in metric_fields]
    metric_keywords = [keyword for _, keyword in metric_fields]
    keywords = metric_keywords + [pass_keyword]
    uc_keys = ["E2E Cycle uC_0", "E2E Cycle uC_2", "E2E Cycle uC_4"]

    # --- Always reset the NPU PCI device before XRT setup ---
    print(f"[CONFIG] timeout={args.timeout}s, max_retries={max_retries}")
    _reset_device(DEVICE_INSTANCE_ID, announce="Pre-test reset")

    for folder_name in folder_names:
        folder_path = current_dir / folder_name

        # Operator mode keeps the HW bins in <op>/hw_package, subgraph mode
        # directly under the folder.
        bins_path = folder_path / "hw_package" if mode == "op" else folder_path

        if not bins_path.is_dir():
            print(f"Skipping '{folder_name}': bins folder not found at {bins_path}")
            continue

        # Required inputs present?
        missing = [req for req in required_files if not (bins_path / req).exists()]
        if missing:
            print(f"Error: Missing in '{folder_name}' (mode={mode}): {', '.join(missing)}")
            continue

        # Resolve kernel name (keyed by top-level folder name in sg mode)
        if mode == "sg":
            kernel_name = kernel_map.get(folder_name, "")
            if not kernel_name:
                print(
                    f"Warning: No kernel name found for '{folder_name}' in "
                    f"{args.instance_json}. Using folder name as kernel."
                )
                kernel_name = folder_name  # fallback if needed
        else:
            # op mode: no instance_ids_list.json, always use fixed kernel name
            kernel_name = "aie4_models"

        # control.elf passed the existence check above; make sure it is a file.
        elf_file = bins_path / "control.elf"
        if not elf_file.is_file():
            print(f"Error: control.elf not found in '{bins_path}'. Skipping.")
            continue
        path_bin = bins_path

        # Build host command
        cmd_list = build_host_cmd(
            xrt_flow_exe=args.xrt_flow_exe,
            elf_file=elf_file,
            kernel_name=kernel_name,
            path_bin=path_bin,
            perf_testing=args.perf_testing,
            out_compare_dtype=dtype,
            debug_flag=debug_flag,
        )

        # Run parameters shared by every result entry of this test.
        run_info = {
            "dtype_used": dtype,
            "iterations_used": iterations_used,
            "debug_flag": 1 if debug_flag else 0,
            "kernel_name": kernel_name,
            "elf": str(elf_file),
            "mode": mode,
            "path_bin": str(path_bin),
        }

        # Execute with retry logic on timeout
        print(f"[RUN] {folder_name} (mode={mode}): {' '.join(map(str, cmd_list))}")
        log_path = folder_path / "AIE_HW_Debug.txt"
        completed = _run_with_retries(
            cmd_list, bins_path, log_path, args.timeout, max_retries, folder_name
        )

        if not completed:
            print(f"[FAILED] {folder_name}: all {max_retries} attempts timed out")
            # Record timeout in table and results: "-" for the 8 metrics and
            # the 3 cycle columns.
            placeholder_count = len(metric_keys) + len(uc_keys)
            shape_table.add_row([folder_name] + ["-"] * placeholder_count + ["TIMEOUT"])

            entry = {"Shape": folder_name}
            entry.update(dict.fromkeys(metric_keys + uc_keys))
            entry["Pass or Fail"] = "TIMEOUT"
            entry.update(run_info)
            results_list.append(entry)

            hung_tests.append(folder_name)
            continue

        # Generate ML Timeline CSV
        uc0_e2e = uc2_e2e = uc4_e2e = None
        if has_ml_timeline_jsons(bins_path):
            try:
                uc0_e2e, uc2_e2e, uc4_e2e = generate_ml_timeline_profile(bins_path)
            except Exception as e:  # pylint: disable=W0718
                print(f"[WARN] ML timeline profiling failed for '{folder_name}': {e}")

        # Parse results
        if not log_path.exists():
            print(f"Warning: {log_path.name} missing in {folder_name}")
            continue

        captured = capture_value_after_keywords(str(log_path), keywords)
        if not captured:
            print(f"Warning: no parsed values for '{folder_name}'.")
            continue

        metric_values = [captured[keyword] for keyword in metric_keywords]
        uc_values = [uc0_e2e, uc2_e2e, uc4_e2e]
        verdict = captured[pass_keyword]

        # Append table row
        shape_table.add_row([folder_name] + metric_values + uc_values + [verdict])

        # Save JSON result entry
        entry = {"Shape": folder_name}
        entry.update(zip(metric_keys, metric_values))
        entry.update(zip(uc_keys, uc_values))
        entry["Pass or Fail"] = verdict
        entry.update(run_info)
        results_list.append(entry)

    # Print table
    print(shape_table)
    if hung_tests:
        print(f"[HUNG] {len(hung_tests)} test(s) timed out: {', '.join(hung_tests)}")
    if not results_list:
        print("Error: Result list is empty")
        return 1
    print("Tasks completed successfully!")

    # ---- Build out_summary.json from the printed table/results_list ----
    # Status codes:
    # 1 = Pass
    # 2 = Fail (metrics present)
    # 3 = Fail and all other metric values are None
    # 4 = TIMEOUT (every attempt hung)
    summary = {}
    for row in results_list:
        passfail = (row["Pass or Fail"] or "").strip()

        if passfail == "Pass":
            code = 1
        elif passfail == "TIMEOUT":
            code = 4  # special status for hung tests
        else:
            all_none = all(
                (m is None) or (isinstance(m, str) and m.strip() == "")
                for m in (row.get(k) for k in metric_keys)
            )
            code = 3 if all_none else 2

        summary[row["Shape"]] = code

    with open("out_summary.json", "w", encoding="utf-8") as sf:
        json.dump(summary, sf, indent=2, ensure_ascii=False)

    print("Summary saved to out_summary.json")

    # Save Excel: reuse the rendered table so cells match what was printed.
    wb = Workbook()
    ws = wb.active
    ws.title = "shape"
    ws.append(shape_table.field_names)
    for line in shape_table.get_string().split("\n")[2:-1]:
        cells = [item.strip() for item in line.split("|")[1:-1]]
        # The "+----+" border under the header contains no "|" and yields no
        # cells; skipping it avoids a blank spreadsheet row.
        if cells:
            ws.append(cells)
    wb.save(output_xlsx)

    # Save JSON
    with open(output_json, "w", encoding="utf-8") as jf:
        json.dump(results_list, jf, indent=4, ensure_ascii=False)

    print(f"Results saved to {output_xlsx} and {output_json}.")
    return 0


if __name__ == "__main__":
    # Use main()'s return value (0, or 1 when no results were recorded) as
    # the process exit status.
    raise SystemExit(main())
