"""AIE4 hardware test script for running tests on remote Windows machines."""

import os
import platform
import shutil
import subprocess
import sys
import tarfile
import tempfile
import fnmatch
from datetime import datetime
from pathlib import Path
import json

import paramiko
from dotenv import load_dotenv
from scp import SCPClient
from utils.build_utils import get_ml_timeline_log_level


def _is_windows_local():
    # This script itself is running on a Windows machine (the DUT)
    return platform.system().lower().startswith("win")


# Load settings from the CI runner's .env file.
# NOTE(review): presumably this supplies HW_PWD, which HW_test reads via
# os.getenv — confirm against the runner configuration.
load_dotenv("/actions-runners/.env")

# When this script runs on a Windows machine, make sure the listed packages
# are importable, pip-installing any that are missing into the running
# interpreter (sys.executable).
# NOTE(review): paramiko is imported unconditionally at the top of the module,
# so if it were missing the import would already have failed before this
# check runs; its entry in the list below can never take effect.
if platform.system().lower().startswith("win"):
    import importlib

    def ensure_pkg(package):
        """Ensure a Python package is installed, install if missing.

        Tries to import *package*; on ImportError runs
        ``<sys.executable> -m pip install <package> --quiet``. The pip run
        uses ``check=False``, so a failed install is not raised. The import
        name is used as the pip distribution name.
        """
        try:
            importlib.import_module(package)
        except ImportError:
            print(f"[SETUP] Installing missing package: {package}")
            subprocess.run(
                [sys.executable, "-m", "pip", "install", package, "--quiet"],
                check=False,
            )

    for pkg in ["paramiko", "openpyxl", "prettytable", "pandas"]:
        ensure_pkg(pkg)

# dtype strings accepted by HW_test; the chosen value is forwarded to the
# test script as ``--dtype <value>``.
VALID_DTYPES = {"int16", "uint16", "bf16", "int8", "uint8", "fp32", "int4"}


def copy_file_with_smbclient(src, dst, host, username, password):
    """Copy a local file into a directory on the remote ``Users`` SMB share.

    The file keeps its basename and is written to ``<dst>/<basename(src)>``
    on ``//<host>/Users``. Errors are printed and swallowed (best-effort), so
    a failed upload does not abort the caller.

    Args:
        src: local path of the file to upload.
        dst: destination directory, relative to the ``Users`` share.
        host: SMB server hostname or IP address.
        username: SMB account name.
        password: SMB account password.
    """
    try:
        if not os.path.exists(src):
            print(f"Source path '{src}' does not exist.")
            return
        dst_path = os.path.join(dst, os.path.basename(src))
        # An argv list (no shell) keeps the password, host and paths from
        # being re-interpreted by /bin/sh (quotes, `$`, `;`, spaces, ...).
        smbclient_command = [
            "smbclient",
            f"//{host}/Users",
            "-U",
            f"{username}%{password}",
            "-c",
            f'put "{src}" "{dst_path}"; exit',
        ]
        subprocess.run(smbclient_command, check=True)
        print(f"Copied '{src}' to '{dst}'")
    except subprocess.CalledProcessError as e:
        # Do not print `e` itself: its message embeds the full command line,
        # which contains the password.
        print(f"An error occurred: smbclient exited with status {e.returncode}")
    except Exception as e:  # pylint: disable=W0718
        print(f"An error occurred: {e}")


def create_dir_with_smbclient(dst, host, username, password):
    """Create directory *dst* on the remote ``Users`` SMB share.

    Errors are printed and swallowed (best-effort).

    Args:
        dst: directory to create, relative to ``//<host>/Users``.
        host: SMB server hostname or IP address.
        username: SMB account name.
        password: SMB account password.
    """
    # An argv list (no shell) keeps the password and path from being
    # re-interpreted by /bin/sh.
    smbclient_command = [
        "smbclient",
        f"//{host}/Users",
        "-U",
        f"{username}%{password}",
        "-c",
        f'mkdir "{dst}"; exit',
    ]
    try:
        subprocess.run(smbclient_command, check=True)
        print(f"Created directory '{dst}'")
    except subprocess.CalledProcessError as e:
        # The exception message embeds the command line (and thus the
        # password), so only the exit status is reported.
        print(f"An error occurred: smbclient exited with status {e.returncode}")
    except Exception as e:  # pylint: disable=W0718
        print(f"An error occurred: {e}")


def create_ssh_client(hostname, port, username, password):
    """Open a password-authenticated SSH connection and return the client.

    System known-hosts are loaded; unknown host keys are accepted
    automatically (AutoAddPolicy). The caller must ``close()`` the client.
    """
    ssh = paramiko.SSHClient()
    ssh.load_system_host_keys()
    accept_unknown = paramiko.AutoAddPolicy()
    ssh.set_missing_host_key_policy(accept_unknown)
    ssh.connect(hostname, port=port, username=username, password=password)
    return ssh


def copy_file_from_remote(hostname, port, username, password, remote_path, local_path):
    """Copy a file from remote Windows machine to local using SCP.

    Best-effort: any error (connection, missing remote file, ...) is printed
    and swallowed. The SSH connection is always closed afterwards.

    Args:
        hostname, port, username, password: SSH connection parameters.
        remote_path: path of the file on the remote host.
        local_path: local destination path.
    """
    ssh_client = None
    try:
        ssh_client = create_ssh_client(hostname, port, username, password)
        with SCPClient(ssh_client.get_transport()) as scp:
            scp.get(remote_path, local_path)
        print(f"Copied '{remote_path}' to '{local_path}'")
    except Exception as e:  # pylint: disable=W0718
        print(f"An error occurred: {e}")
    finally:
        # Previously the client was never closed, leaking one SSH session
        # per copied file.
        if ssh_client is not None:
            ssh_client.close()


def run_remote_script(
    host,
    port,
    username,
    password,
    remote_script_path,
    perf_testing_flag,
    dtype_flag,
    mode_flag,
    debug_flag=False,
):
    """Execute dolphin_test_aie4.py on remote Windows machine via SSH."""
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(hostname=host, port=port, username=username, password=password)
        command = f"python {remote_script_path}"
        if perf_testing_flag:
            command += " --perf_testing"
        if dtype_flag:
            command += f" --dtype {dtype_flag}"
        if mode_flag == "op":
            command += " --op"
        elif mode_flag == "sg":
            command += " --sg"
        if debug_flag:
            command += " --print"
        _stdin, stdout, stderr = client.exec_command(command)
        stdout.channel.recv_exit_status()
        print("Output:", stdout.read().decode())
        print("Errors:", stderr.read().decode())
    finally:
        client.close()


def run_remote_script_golden_io(
    host, port, username, password, remote_script_path, subfolder, mode="copy"
):
    """Execute copy_golden_io.py on remote Windows machine via SSH.

    Runs ``python <remote_script_path> <subfolder> <mode>`` and prints the
    remote stdout and stderr.

    Returns:
        The remote command's exit status (callers may ignore it).
    """
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(hostname=host, port=port, username=username, password=password)
        command = f"python {remote_script_path} {subfolder} {mode}"
        _stdin, stdout, stderr = client.exec_command(command)
        # Read output first; waiting on the exit status before draining the
        # streams can hang when the output exceeds the SSH channel window.
        out_text = stdout.read().decode(errors="replace")
        err_text = stderr.read().decode(errors="replace")
        exit_status = stdout.channel.recv_exit_status()
        print("Output:", out_text)
        print("Errors:", err_text)
        return exit_status
    finally:
        client.close()


def generate_output_path(name, output_dir, timestamp):
    """Build the results path ``<output_dir>/<name>_<timestamp>.json``.

    The parts are joined with a literal "/" regardless of platform.
    """
    file_name = "{}_{}.json".format(name, timestamp)
    return "/".join((str(output_dir), file_name))


def copy_and_run_script(src_dir, script_name, des_dir):
    """Copy ``src_dir/script_name`` into *des_dir* and execute the copy.

    The copy is run with the ``python`` found on PATH, with *des_dir* as the
    working directory; a non-zero exit raises CalledProcessError.
    """
    shutil.copy(os.path.join(src_dir, script_name), des_dir)
    copied_script = os.path.join(des_dir, script_name)
    subprocess.run(
        ["python", copied_script],
        cwd=des_dir,
        check=True,
    )


def create_tarball(base_dir, file_list, tar_name):
    """
    Create tar.gz archive(s) from file list, splitting if size exceeds 1GB.

    Archives are written to the system temp directory as
    ``<tar_name>_<n>.tar.gz``. A new archive is started whenever adding the
    next file would push the current one past 1GB of (uncompressed) input.
    Files that do not exist are skipped with a message.

    Args:
        base_dir: root directory for relative path calculation
        file_list: list of [folder_path, relative_file_path] pairs
        tar_name: base name for tar.gz archive

    Returns:
        list of created tar.gz file paths (empty if nothing was archived)
    """
    max_size_bytes = 1024 * 1024 * 1024  # 1GB per tar file

    total_size = sum(
        os.path.getsize(os.path.join(folder, rel))
        for folder, rel in file_list
        if os.path.isfile(os.path.join(folder, rel))
    )
    print(f"Total size: {total_size // (1024 * 1024)} MB")

    tar_paths = []
    current_tar = None
    current_size = 0
    completed_size = 0
    last_decile = -1  # highest 10% progress bucket already reported

    try:
        for folder, rel_path in file_list:
            full_path = os.path.join(folder, rel_path)
            if not os.path.isfile(full_path):
                print(f"Skipping missing file: {full_path}")
                continue

            item_size = os.path.getsize(full_path)

            # Roll over to a new archive when this file would exceed the cap.
            if current_tar is None or current_size + item_size > max_size_bytes:
                if current_tar is not None:
                    current_tar.close()
                archive_name = os.path.join(
                    tempfile.gettempdir(), f"{tar_name}_{len(tar_paths)}.tar.gz"
                )
                current_tar = tarfile.open(archive_name, "w:gz")  # pylint: disable=R1732
                tar_paths.append(archive_name)
                current_size = 0

            # Store files under base_dir by their relative path; anything
            # outside base_dir (or on another drive) goes in at top level.
            try:
                candidate_arc = os.path.relpath(full_path, base_dir)
                arcname = (
                    os.path.basename(full_path)
                    if candidate_arc.startswith("..")
                    else candidate_arc
                )
            except ValueError:
                arcname = os.path.basename(full_path)

            tarinfo = current_tar.gettarinfo(full_path, arcname=arcname)
            with open(full_path, "rb") as fobj:
                current_tar.addfile(tarinfo, fobj)
            current_size += item_size
            completed_size += tarinfo.size

            # Report at most once per 10% bucket, even if one large file
            # jumps several buckets at once.
            progress = completed_size / total_size if total_size else 1
            decile = int(progress * 10)
            if decile > last_decile:
                last_decile = decile
                print(f"Progress: {progress:.0%}")
    finally:
        # Close the open archive even if adding a file raised.
        if current_tar is not None:
            current_tar.close()

    return tar_paths


def copy_and_untar_allfolders(
    base_dir,
    remote_base,
    host,
    username,
    password,
    port,
    file_list,
    basic_file_list,
    timestamp,
    mode_flag="op",
    filter_patterns=None,
):
    """
    Copy and extract hardware test files to remote Windows machine.

    Selects the sub-folders of *base_dir* (optionally filtered) and collects
    the requested per-folder files plus *basic_file_list*. These are packed
    with ``create_tarball``. Each archive is uploaded to
    ``<remote_base>/<timestamp>`` on the ``Users`` SMB share and then untarred
    there over SSH. All errors are printed rather than raised.

    Args:
        base_dir: local directory containing op_* folders
        remote_base: remote directory, relative to the ``Users`` share
        host, username, password, port: SMB/SSH connection parameters
        file_list: file names to pick from each selected folder
        basic_file_list: [folder, relative_path] pairs always included
        timestamp: run id; names the remote sub-directory and the archives
        mode_flag: "op" for hw_package subdir, anything else (e.g. "sg")
            for a flat folder structure
        filter_patterns: fnmatch wildcard(s) like ["*MatMul*", "*Conv*"];
            a single string is accepted, None/empty selects every folder
    """
    try:
        if not os.path.exists(base_dir):
            print(f"Source directory '{base_dir}' does not exist.")
            return

        if filter_patterns is None:
            filter_patterns = []
        elif isinstance(filter_patterns, str):
            filter_patterns = [filter_patterns]

        folders = [
            f for f in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, f))
        ]

        if filter_patterns:
            folders = [
                f
                for f in folders
                if any(fnmatch.fnmatch(f, pat) for pat in filter_patterns)
            ]
            if not folders:
                print(f"[FILTER] No matching folders for patterns: {filter_patterns}")
                return

        print(f"[INFO] Folders selected: {folders}")

        matching_files = list(basic_file_list)
        for subfolder in folders:
            subfolder_path = os.path.join(base_dir, subfolder)
            for fname in file_list:
                rel_path = (
                    os.path.join("hw_package", fname) if mode_flag == "op" else fname
                )
                if os.path.isfile(os.path.join(subfolder_path, rel_path)):
                    matching_files.append([subfolder_path, rel_path])

        print(f"Creating tarball: {timestamp}.tar.gz ...")
        tar_paths = create_tarball(base_dir, matching_files, timestamp)

        # SSH-side path of the remote extraction directory (same for every archive).
        remote_ssh_dir = f"/Users/{remote_base}/{timestamp}".replace(
            "\\", "/"
        ).replace("//", "/")

        for tar_file in tar_paths:
            print(f"Tarball created: {tar_file}")
            print(f"Uploading {tar_file} → {remote_base}/{timestamp} ...")
            copy_file_with_smbclient(
                tar_file, os.path.join(remote_base, timestamp), host, username, password
            )
            print(f"Uploaded {tar_file}")

            ssh_client = create_ssh_client(host, port, username, password)
            try:
                untar_cmd = (
                    f'cd "{remote_ssh_dir}" && tar -xzf {os.path.basename(tar_file)}'
                )
                _stdin, stdout, stderr = ssh_client.exec_command(untar_cmd)
                exit_code = stdout.channel.recv_exit_status()

                if exit_code == 0:
                    print(f"Untarred successfully in remote: {remote_ssh_dir}")
                else:
                    print("Untar failed:")
                    print(stderr.read().decode(errors="replace"))
            finally:
                # Close the session even if exec/decode raised.
                ssh_client.close()
    except Exception as e:  # pylint: disable=W0718
        print(f"An error occurred: {e}")


def ensure_dir_local(path):
    """Make sure the local directory *path* exists, creating parents as needed.

    An existing directory is left untouched (``exist_ok=True``).
    """
    os.makedirs(name=path, exist_ok=True)


def copy_selected_tree_local(dst_root, file_list_pairs, keep_rel=True):
    """Copy the existing files named by *file_list_pairs* under *dst_root*.

    Each pair is ``(base, rel)`` and names the source ``base/rel`` (just
    ``rel`` when *base* is empty). Sources that are not regular files are
    skipped silently. When *keep_rel* is true and *base* is non-empty, the
    file goes into a sub-directory of *dst_root* named after the source
    file's immediate parent directory (only that one level, not the whole
    relative path). Otherwise it is copied flat into *dst_root*. File
    metadata is preserved (``shutil.copy2``).
    """
    os.makedirs(dst_root, exist_ok=True)
    for base, rel in file_list_pairs:
        source = os.path.join(base or "", rel)
        if os.path.isfile(source):
            target_dir = dst_root
            if keep_rel and base:
                parent_name = os.path.basename(os.path.dirname(source))
                target_dir = os.path.join(dst_root, parent_name)
            os.makedirs(target_dir, exist_ok=True)
            shutil.copy2(source, os.path.join(target_dir, os.path.basename(source)))


def copy_output_subtrees_local(base_dir, dst_root, subfile_names, mode_flag="op"):
    """
    Copy hardware binaries from each subfolder to destination root.

    For every sub-directory ``X`` of *base_dir*, the files in
    *subfile_names* that exist in the source directory are copied to
    ``dst_root/X`` (created only when at least one file is present).

    Args:
        base_dir: directory whose immediate sub-directories are scanned
        dst_root: destination root (created if missing)
        subfile_names: file names to look for in each source directory
        mode_flag: "op" reads from ``X/hw_package`` (folders without it are
            skipped); any other value reads directly from ``X``
    """
    os.makedirs(dst_root, exist_ok=True)
    use_hw_package = mode_flag == "op"
    with os.scandir(base_dir) as entries:
        for entry in entries:
            if not entry.is_dir():
                continue
            source_dir = (
                os.path.join(entry.path, "hw_package") if use_hw_package else entry.path
            )
            if use_hw_package and not os.path.isdir(source_dir):
                continue
            wanted = [
                name
                for name in subfile_names
                if os.path.isfile(os.path.join(source_dir, name))
            ]
            if not wanted:
                continue
            target_dir = os.path.join(dst_root, os.path.basename(entry.path))
            os.makedirs(target_dir, exist_ok=True)
            for name in wanted:
                shutil.copy2(os.path.join(source_dir, name), os.path.join(target_dir, name))


def run_local_script(
    py_path, perf_testing_flag, dtype_flag, mode_flag, debug_flag=False
):
    """Run dolphin_test_aie4.py locally on Windows DUT.

    Invokes ``python <py_path>`` with optional flags: ``--perf_testing``,
    ``--dtype <dtype_flag>``, ``--op``/``--sg`` (for mode "op"/"sg") and
    ``--print`` (debug). A non-zero exit raises CalledProcessError.
    """
    if mode_flag == "op":
        mode_args = ["--op"]
    elif mode_flag == "sg":
        mode_args = ["--sg"]
    else:
        mode_args = []
    argv = [
        "python",
        py_path,
        *(["--perf_testing"] if perf_testing_flag else []),
        *(["--dtype", dtype_flag] if dtype_flag else []),
        *mode_args,
        *(["--print"] if debug_flag else []),
    ]
    subprocess.run(argv, check=True)


def fetch_file_local(src_path, local_dest_path):
    """Copy a file locally, creating the destination's parent directories.

    File metadata is preserved (``shutil.copy2``). A missing source raises
    the usual ``FileNotFoundError`` from the copy.
    """
    parent = os.path.dirname(local_dest_path)
    # A bare file name has no directory part, and os.makedirs("") raises
    # FileNotFoundError, so only create parents when there are any.
    if parent:
        os.makedirs(parent, exist_ok=True)
    shutil.copy2(src_path, local_dest_path)


def get_test_files_for_golden_io(golden_io):
    """Return the per-folder file names to ship to the DUT for this run.

    With no golden IO (None) or in "update" mode, ifm.bin and ofm.bin are
    included alongside control.elf, param.bin and wgt.bin. Otherwise (copy
    mode) only control.elf, param.bin and wgt.bin are included. When the
    ML-timeline log level is above 0, the timeline/trace configuration files
    are appended.
    """
    needs_io = golden_io is None or "update" in golden_io
    if needs_io:
        selected = ["control.elf", "ifm.bin", "ofm.bin", "param.bin", "wgt.bin"]
    else:
        selected = ["control.elf", "param.bin", "wgt.bin"]

    if get_ml_timeline_log_level() > 0:
        selected.extend(
            ["ml_timeline_metadata.json", "xrt.ini", "aie_trace_config.json"]
        )

    return selected


def get_golden_io_config(golden_io):
    """Parse golden_io parameter and return (mode, subfolders).

    Returns ``(None, [])`` when *golden_io* is None. Returns
    ``("update", <entries other than "update">)`` when "update" is present,
    and ``("copy", golden_io)`` otherwise. Exits the process with status 1
    when "update" is given without any sub-folder.
    """
    if golden_io is None:
        return None, []
    if "update" not in golden_io:
        return "copy", golden_io
    targets = [entry for entry in golden_io if entry != "update"]
    if targets:
        return "update", targets
    print("[ERROR] No subfolder specified for update mode.")
    sys.exit(1)


def run_local_windows_test(
    output,
    dest_root,
    basic_file_list,
    golden_io,
    mode,
    perf_testing,
    dtype_for_test,
    timestamp,
    debug=False,
):
    """Execute hardware test on local Windows machine.

    Steps:
      1. Stage the test scripts and the per-folder test files in *dest_root*.
      2. Run copy_golden_io.py for each golden-IO sub-folder, if requested.
      3. Run dolphin_test_aie4.py.
      4. Copy output.json / output.xlsx (and output_summary.json if present)
         back into *output* with a timestamped name.
      5. Delete *dest_root*.

    A failing script raises CalledProcessError (and *dest_root* is kept).
    """
    os.makedirs(dest_root, exist_ok=True)

    # Copy basic files (test scripts)
    copy_selected_tree_local(
        dst_root=dest_root,
        file_list_pairs=basic_file_list,
        keep_rel=False,
    )

    # Copy test files and handle golden IO if needed
    test_files = get_test_files_for_golden_io(golden_io)
    copy_output_subtrees_local(
        base_dir=output,
        dst_root=dest_root,
        subfile_names=test_files,
        mode_flag=mode,
    )

    # Run golden IO script if needed
    golden_mode, golden_subfolders = get_golden_io_config(golden_io)
    if golden_mode:
        local_copy_golden = os.path.join(dest_root, "copy_golden_io.py")
        for subfolder in golden_subfolders:
            # Script path and its positional args must be separate argv items,
            # matching the remote invocation `python <script> <subfolder> <mode>`.
            # No --op/--sg/--dtype flags: those belong to dolphin_test_aie4.py.
            subprocess.run(
                ["python", local_copy_golden, subfolder, golden_mode], check=True
            )

    # Run the main test script
    local_test_py = os.path.join(dest_root, "dolphin_test_aie4.py")
    print("executing dolphin_test_aie4.py locally")
    run_local_script(
        local_test_py, perf_testing, dtype_for_test, mode, debug_flag=debug
    )

    # Fetch results
    local_path = generate_output_path("output", output, timestamp)
    fetch_file_local(os.path.join(dest_root, "output.json"), local_path)
    fetch_file_local(
        os.path.join(dest_root, "output.xlsx"),
        local_path.replace(".json", ".xlsx"),
    )

    local_path_summary_src = os.path.join(dest_root, "output_summary.json")
    if os.path.isfile(local_path_summary_src):
        local_path_summary = generate_output_path("output_summary", output, timestamp)
        fetch_file_local(local_path_summary_src, local_path_summary)

    # Cleanup
    shutil.rmtree(dest_root, ignore_errors=True)
    print(f"Directory {dest_root} removed successfully.")


def run_remote_windows_test(
    output,
    destination_path,
    host,
    username,
    password,
    port,
    basic_file_list,
    golden_io,
    mode,
    patterns,
    perf_testing,
    dtype_for_test,
    timestamp,
    debug=False,
):
    """Execute hardware test on remote Windows machine.

    Steps:
      1. Create ``<destination_path>/<timestamp>`` on the ``Users`` share.
      2. Upload and extract the test files (filtered by *patterns* unless
         golden IO is used).
      3. In golden-IO mode, run copy_golden_io.py per sub-folder and stop.
      4. Otherwise run dolphin_test_aie4.py, fetch its result files into
         *output*, and fetch each sub-folder's AIE_HW_Debug.txt.

    Result/log fetches are best-effort: failures are printed, not raised.

    Args:
        destination_path: remote run root relative to ``C:/Users`` (for
            example "Administrator/Desktop/WAIC_test").
    """
    # Absolute remote directory of this run, derived from destination_path
    # instead of being hard-coded in several places.
    remote_root = f"C:/Users/{destination_path}/{timestamp}".replace("\\", "/")

    create_dir_with_smbclient(
        f"{destination_path}/{timestamp}", host, username, password
    )

    # Determine which files to copy and golden IO configuration
    test_files = get_test_files_for_golden_io(golden_io)
    golden_mode, golden_subfolders = get_golden_io_config(golden_io)

    # Copy and untar files to remote
    copy_and_untar_allfolders(
        output,
        destination_path,
        host,
        username,
        password,
        port,
        test_files,
        basic_file_list,
        timestamp,
        mode_flag=mode,
        filter_patterns=patterns if golden_io is None else None,
    )

    # Run golden IO script if needed
    if golden_mode:
        remote_script_path = f"{remote_root}/copy_golden_io.py"
        for subfolder in golden_subfolders:
            run_remote_script_golden_io(
                host,
                port,
                username,
                password,
                remote_script_path,
                subfolder,
                golden_mode,
            )
        return  # Golden IO mode doesn't run main test

    # Run main test script
    remote_script_path = f"{remote_root}/dolphin_test_aie4.py"
    print("executing dolphin_test_aie4.py on remote machine")
    run_remote_script(
        host,
        port,
        username,
        password,
        remote_script_path,
        perf_testing,
        dtype_for_test,
        mode,
        debug_flag=debug,
    )

    # Fetch results
    remote_path = f"{remote_root}/output.json"
    remote_path_2 = f"{remote_root}/output.xlsx"
    # NOTE(review): the local runner fetches "output_summary.json" while this
    # path fetches "out_summary.json" — confirm which name the test script writes.
    remote_path_3 = f"{remote_root}/out_summary.json"

    local_path = generate_output_path("output", output, timestamp)
    print(f"dolphin_test_aie4.py completed, copying output.json to {local_path}")
    copy_file_from_remote(host, port, username, password, remote_path, local_path)
    copy_file_from_remote(
        host,
        port,
        username,
        password,
        remote_path_2,
        local_path.replace(".json", ".xlsx"),
    )

    local_path_summary = generate_output_path("out_summary", output, timestamp)
    copy_file_from_remote(
        host,
        port,
        username,
        password,
        remote_path_3,
        local_path_summary,
    )

    # --------  copy AIE_HW_Debug.txt from each subfolder --------
    # Batch copy all log files in a single SSH session
    ssh_client = None
    try:
        ssh_client = create_ssh_client(host, port, username, password)
        with SCPClient(ssh_client.get_transport()) as scp:
            for entry in os.scandir(output):
                if not entry.is_dir():
                    continue
                subfolder_name = entry.name
                remote_log = f"{remote_root}/{subfolder_name}/AIE_HW_Debug.txt"
                local_log_dir = os.path.join(output, subfolder_name)
                os.makedirs(local_log_dir, exist_ok=True)
                local_log_path = os.path.join(local_log_dir, "AIE_HW_Debug.txt")
                try:
                    scp.get(remote_log, local_log_path)
                    print(f"Copied '{remote_log}' to '{local_log_path}'")
                except Exception as e:  # pylint: disable=W0718
                    print(f"[WARN] Could not copy log for {subfolder_name}: {e}")
    except Exception as e:  # pylint: disable=W0718
        print(f"[WARN] Could not establish SSH session for log collection: {e}")
    finally:
        # Previously the session leaked whenever scandir/scp setup raised.
        if ssh_client is not None:
            ssh_client.close()


def check_test_results(output, timestamp):
    """Check if all tests passed and print result.

    Reads the run's ``output_<timestamp>.json`` in *output* (path from
    ``generate_output_path``). Prints ``DI_PASS`` only when the file holds a
    non-empty JSON list whose items are all objects with
    ``"Pass or Fail" == "Pass"``. A missing, unreadable, malformed, empty or
    differently-shaped file prints ``DI_FAIL``.
    """
    local_path = Path(generate_output_path("output", output, timestamp))
    passed = False

    if local_path.exists():
        try:
            with open(local_path, encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            print(f"[WARN] Could not read test results from {local_path}: {e}")
        else:
            # Validate the shape explicitly instead of relying on a swallowed
            # AttributeError for non-list / non-dict content.
            passed = (
                isinstance(data, list)
                and bool(data)
                and all(
                    isinstance(item, dict) and item.get("Pass or Fail") == "Pass"
                    for item in data
                )
            )

    print("DI_PASS" if passed else "DI_FAIL")


def remove_remote_directory(hostname, port, username, password, directory):
    """
    Recursively delete a directory on the remote Windows host over SSH.

    Runs ``rd /s /q "<directory>"`` remotely and prints its output. Errors
    are printed, never raised. The SSH connection is always closed.

    Args:
        hostname (str): The hostname or IP address of the remote machine.
        port (int): SSH port to connect to.
        username (str): SSH account name.
        password (str): SSH account password.
        directory (str): Remote directory to delete.

    Returns:
        None
    """
    print("running...")
    ssh_client = None
    try:
        ssh_client = create_ssh_client(hostname, port, username, password)
        command = f'rd /s /q "{directory}"'
        # exec_command returns (stdin, stdout, stderr); the first item is
        # write-only and cannot be read.
        _stdin, stdout, stderr = ssh_client.exec_command(command)
        out_text = stdout.read().decode(errors="replace")
        err_text = stderr.read().decode(errors="replace")
        exit_status = stdout.channel.recv_exit_status()
        print(out_text)
        print(err_text)
        if exit_status == 0:
            print(f"Directory {directory} removed successfully.")
        else:
            print(
                f"Failed to remove directory {directory} on {hostname} "
                f"(exit status {exit_status})."
            )
    except Exception as e:  # pylint: disable=W0718
        print(
            f"An error occurred while trying to remove the directory {directory} on {hostname}: {e}"
        )
    finally:
        if ssh_client is not None:
            ssh_client.close()


def normalize_filter_patterns(filter_patterns):
    """
    Turn user-supplied filter input into a list of fnmatch wildcard patterns.

    Accepts:
      - None or any falsy value -> []
      - a string such as "MatMul", "*MatMul*" or "MatMul,Conv"
        (split on commas); other non-list values are stringified first
      - a list/tuple such as ["MatMul", "Conv"] or ["*MatMul*", "*Conv*"]
    Each entry is stripped and empty entries are dropped. An entry containing
    neither "*" nor "?" is wrapped as "*<entry>*".

    Returns:
      list[str] of patterns.
    """
    if not filter_patterns:
        return []

    if isinstance(filter_patterns, (list, tuple)):
        candidates = filter_patterns
    else:
        candidates = str(filter_patterns).split(",")

    cleaned = (str(item).strip() for item in candidates)
    return [
        token if ("*" in token or "?" in token) else f"*{token}*"
        for token in cleaned
        if token
    ]


def HW_test(
    output: str,
    HW_req: str,
    host: str,
    perf_testing: bool = False,
    golden_io: list[str] | None = None,
    dtype: str = "int16",
    mode: str = "op",  # "op" (default) or "sg"
    filter_patterns: str | None = None,
    debug: bool = False,
):
    """
    Run test on hardware.
    output -- path to output directory (will be normalized)

    mode:
      - "op"  / "operator"  -> operator mode (Output/op_x/hw_package/...)
      - "sg"  / "subgraph"  -> subgraph mode (current behavior: Output/sg_x/...)
    """

    # ---------- normalize paths ----------
    output = os.path.abspath(output)
    HW_req = os.path.abspath(HW_req)

    # ---------- normalize & validate dtype ----------
    if dtype is None:
        dtype_for_test = None
    else:
        if not isinstance(dtype, str):
            raise TypeError(f"dtype must be a string or None, got {type(dtype)}")
        if dtype not in VALID_DTYPES:
            raise ValueError(
                f"Invalid dtype '{dtype}'. Must be one of: {sorted(VALID_DTYPES)}"
            )
        dtype_for_test = dtype

    # ---------- normalize & validate mode ----------
    mode = (mode or "op").lower()
    if mode not in {"op", "operator", "sg", "subgraph"}:
        raise ValueError("mode must be one of: 'op', 'operator', 'sg', 'subgraph'")
    if mode == "operator":
        mode = "op"
    if mode == "subgraph":
        mode = "sg"

    # ---------- normalize filter patterns ----------
    patterns = normalize_filter_patterns(filter_patterns)
    print(f"[FILTER] Active filter patterns: {patterns if patterns else 'None'}")

    # ---------------- Main body ----------------
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    source_path = f"{HW_req}"
    destination_path = "Administrator/Desktop/WAIC_test"
    basic_file_list = [
        ["", os.path.join(HW_req, f)]
        for f in ["dolphin_test_aie4.py", "copy_golden_io.py", "ml_timeline_parser.py"]
    ]
    print(f"HW IP: {host}")
    port = 22
    username = "Administrator"
    password = os.getenv("HW_PWD", "amdlabp@ssw0rd")
    LOCAL_WINDOWS = _is_windows_local()
    dest_root = os.path.join(
        "C:/Users/Administrator/Desktop/WAIC_test", timestamp
    ).replace("\\", "/")

    # Prepare instance_ids_list.json for subgraph mode
    if mode == "sg":
        copy_and_run_script(source_path, "collect_instance_ids.py", output)
        kernel_list_path = os.path.join(output, "instance_ids_list.json")
        if os.path.isfile(kernel_list_path):
            basic_file_list.append(["", kernel_list_path])

    # Run tests on local or remote Windows machine
    # Run tests on local or remote Windows machine
    if LOCAL_WINDOWS:
        run_local_windows_test(
            output,
            dest_root,
            basic_file_list,
            golden_io,
            mode,
            perf_testing,
            dtype_for_test,
            timestamp,
            debug=debug,
        )
    else:
        try:
            run_remote_windows_test(
                output,
                destination_path,
                host,
                username,
                password,
                port,
                basic_file_list,
                golden_io,
                mode,
                patterns,
                perf_testing,
                dtype_for_test,
                timestamp,
                debug=debug,
            )
        finally:
            check_test_results(output, timestamp)

    remote_path_remove = f"C:/Users/Administrator/Desktop/WAIC_test/{timestamp}"
    remove_remote_directory(host, port, username, password, remote_path_remove)
