#!/usr/bin/env python3
"""
Model E2E Test Runner

SSH into a Windows machine and execute model E2E tests (PSD, PSU, etc).
Parses report.csv and generates markdown output for GitHub issues.

Workflow (from C:\\Users\\Administrator\\Desktop\\shajaisw):
    0. Sync PowerShell scripts from repo to remote via SFTP
    1. powershell .\\sync_models_repo.ps1    - Sync models from gitenterprise
    2. powershell .\\sync_aie4_repo.ps1      - Sync aie4_models at specific commit/PR
    3. powershell .\\compile_run_qhw4_debug.ps1 $modelName  - Compile and run
    4. powershell .\\run_compiled_qhw4_debug.ps1 $modelName - Rerun pre-compiled (optional)

Prerequisites:
    - DataGen directory with Consts and Activations\\ort subdirectories
    - Bootstrap Windows machine first using bootstrap_windows.sh

Supported models: psd1, psd2, psd3, psd4, psh, psi, psu0, psu1, vit-base

Usage:
    # Basic usage (syncs to origin/main)
    python e2e_test.py --model psd1

    # Test specific commit
    python e2e_test.py --model psd1 --commit abc123

    # Test a PR (fetches refs/pull/123/head)
    python e2e_test.py --model psd1 --pr 123

    # Skip git sync (use current state on Windows - for dev iteration)
    python e2e_test.py --model psd1 --skip-sync

    # Skip compilation (run pre-compiled only)
    python e2e_test.py --model psd1 --skip-compile

    # Skip execution (compile only - for split compile/run workflow)
    python e2e_test.py --model psd1 --skip-execute

Developer Workflow:
    1. Fast iteration with --skip-sync:
       - SSH to Windows and checkout your branch manually
       - Run: python e2e_test.py --model psd1 --skip-sync
       - Edit code on Windows, re-run with --skip-sync

    2. PR validation with --pr:
       - Create PR on gitenterprise
       - Run: python e2e_test.py --model psd1 --pr <pr_number>
"""

import argparse
import csv
import os
import sys
import threading
import time
from io import StringIO
from typing import Optional

import paramiko
from tabulate import tabulate


# Model names accepted by the --model command-line option (used as the
# argparse `choices` list in main()).
SUPPORTED_MODELS = ["psd1", "psd2", "psd3", "psd4", "psh", "psi", "psu0", "psu1", "vit-base"]

# Base working directory on the remote Windows machine. The PowerShell scripts
# are synced into it, commands are run from it, and each model's report.csv is
# read from the <BASE_WORK_DIR>\<model> subdirectory.
BASE_WORK_DIR = r"C:\Users\Administrator\Desktop\shajaisw"


def create_ssh_client(hostname: str, port: int, username: str, password: str) -> paramiko.SSHClient:
    """
    Create and return an SSH client connected to the remote host.

    Args:
        hostname: Host name or IP address to connect to
        port: SSH port
        username: SSH username
        password: SSH password

    Returns:
        A connected paramiko.SSHClient; the caller is responsible for closing it

    Raises:
        Whatever paramiko's connect() raises; the partially initialised client
        is closed before the exception is propagated.
    """
    client = paramiko.SSHClient()
    client.load_system_host_keys()
    # NOTE: AutoAddPolicy accepts host keys that are not in the system
    # known_hosts file instead of rejecting the connection.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(hostname, port=port, username=username, password=password)
    except Exception:
        # Do not leave a half-open client behind when the connection or
        # authentication attempt fails.
        client.close()
        raise
    return client


def run_powershell_command(client: paramiko.SSHClient, command: str, work_dir: str) -> tuple[int, str, str]:
    """
    Run a PowerShell command on the remote Windows machine.

    The command is wrapped as
    ``powershell -Command "Set-Location '<work_dir>'; <command>"`` and the
    call blocks until the remote command has finished. The command, its
    output and its exit code are echoed to the local stdout.

    Args:
        client: SSH client connected to Windows machine
        command: PowerShell command to execute
        work_dir: Working directory for the command

    Returns:
        Tuple of (exit_code, stdout, stderr)
    """
    # Wrap command to run in specified directory with PowerShell.
    # NOTE(review): work_dir and command are interpolated without escaping;
    # quote characters inside them presumably break the wrapper - confirm
    # before passing anything that is not a trusted, fixed string.
    full_command = f'powershell -Command "Set-Location \'{work_dir}\'; {command}"'

    print("\n" + "=" * 40)
    print(f"Running: {command}")
    print(f"  Working directory: {work_dir}")

    _stdin, stdout, stderr = client.exec_command(full_command)

    # Drain both streams BEFORE asking for the exit status: waiting for the
    # exit status first can hang forever once the remote output exceeds the
    # channel window. stderr is read on a helper thread so that neither stream
    # can stall the other while the command is still producing output.
    stderr_chunks: list[bytes] = []
    stderr_reader = threading.Thread(
        target=lambda: stderr_chunks.append(stderr.read()),
        daemon=True,
    )
    stderr_reader.start()
    stdout_bytes = stdout.read()
    stderr_reader.join()
    exit_code = stdout.channel.recv_exit_status()

    stdout_text = stdout_bytes.decode('utf-8', errors='replace')
    stderr_text = b"".join(stderr_chunks).decode('utf-8', errors='replace')

    if stdout_text.strip():
        print(f"stdout:\n{stdout_text}")
    if stderr_text.strip():
        print(f"stderr:\n{stderr_text}")
    print(f"exit_code: {exit_code}")
    print("=" * 40)

    return exit_code, stdout_text, stderr_text


def read_remote_file(client: paramiko.SSHClient, remote_path: str) -> Optional[str]:
    """
    Read a text file from the remote machine over SFTP.

    Args:
        client: SSH client connected to Windows machine
        remote_path: Path of the remote file; backslashes are converted to
            forward slashes before the SFTP request is made

    Returns:
        The file content decoded as UTF-8 (undecodable bytes are replaced),
        or None if the file could not be read. Failures are printed, not raised.
    """
    # Same normalisation sync_scripts_to_remote applies before talking SFTP.
    sftp_path = remote_path.replace("\\", "/")
    try:
        sftp = client.open_sftp()
        try:
            with sftp.file(sftp_path, 'r') as f:
                return f.read().decode('utf-8', errors='replace')
        finally:
            # Close the SFTP session on the failure path too.
            sftp.close()
    except Exception as e:  # pylint: disable=broad-except
        print(f"Failed to read {remote_path}: {e}")
        return None


def sync_scripts_to_remote(client: paramiko.SSHClient, remote_dir: str) -> bool:
    """
    Sync PowerShell scripts from local repo to remote Windows machine.

    Every ``*.ps1`` file in the ``scripts`` directory next to this file is
    uploaded over SFTP into ``remote_dir``. A missing or empty scripts
    directory is not treated as an error.

    Args:
        client: SSH client connected to Windows machine
        remote_dir: Remote directory to sync scripts to

    Returns:
        True if successful (or there was nothing to sync), False otherwise
    """
    # Get the directory where this script lives
    local_scripts_dir = os.path.join(os.path.dirname(__file__), "scripts")

    if not os.path.isdir(local_scripts_dir):
        print(f"Warning: Scripts directory not found: {local_scripts_dir}")
        return True  # Not a fatal error, scripts may already be on remote

    # Sorted so the upload order (and the log output) is deterministic.
    scripts = sorted(f for f in os.listdir(local_scripts_dir) if f.endswith('.ps1'))
    if not scripts:
        print("No .ps1 scripts found to sync")
        return True

    print(f"\n{'=' * 40}")
    print(f"Syncing {len(scripts)} scripts to {remote_dir}")

    try:
        sftp = client.open_sftp()
        try:
            for script in scripts:
                local_path = os.path.join(local_scripts_dir, script)
                remote_path = f"{remote_dir}\\{script}"
                # SFTP uses forward slashes even on Windows
                remote_path_sftp = remote_path.replace("\\", "/")

                print(f"  {script} -> {remote_path}")
                sftp.put(local_path, remote_path_sftp)
        finally:
            # Release the SFTP session even when an upload fails part-way.
            sftp.close()

        print("Scripts synced successfully")
        print("=" * 40)
        return True

    except Exception as e:  # pylint: disable=broad-except
        print(f"Failed to sync scripts: {e}")
        print("=" * 40)
        return False


# PCI device instance ID for NPU. Used as the default `instance_id` of
# pnp_toggle(), which passes it to pnputil.
# NOTE(review): this looks specific to one physical machine - confirm it is
# valid for every host this script is pointed at.
DEVICE_INSTANCE_ID = r"PCI\VEN_1022&DEV_17F1&SUBSYS_17F11022&REV_10\4&212933EF&0&0142"


def pnp_toggle(client: paramiko.SSHClient, action: str, instance_id: str = DEVICE_INSTANCE_ID) -> bool:
    """
    Enable or disable a PCI device on the remote machine with pnputil.

    Args:
        client: SSH client connected to Windows machine
        action: Either 'disable' or 'enable'
        instance_id: PCI device instance ID handed to pnputil

    Returns:
        True when the remote pnputil command exits with code 0, else False

    Raises:
        ValueError: If action is neither 'disable' nor 'enable'
    """
    # Map the requested action onto the matching pnputil switch.
    if action == "disable":
        switch = "/disable-device"
    elif action == "enable":
        switch = "/enable-device"
    else:
        raise ValueError("action must be 'disable' or 'enable'")

    remote_cmd = f"pnputil {switch} '{instance_id}'"

    # Only the first 50 characters of the instance ID are shown in the log.
    print(f"[PNP] {action.upper()} device: {instance_id[:50]}...")
    result = run_powershell_command(client, remote_cmd, "C:\\")

    return result[0] == 0


def parse_report_csv(csv_content: str, max_l2_norm_threshold: float = 2.0) -> tuple[list[dict], bool]:
    """
    Parse report.csv content and return structured data.

    CSV format: Op Name, max_diff, L2_norm, L2_norm per element, Error Count, Note
    Success criteria: at least one data row exists and no row has an
    "L2_norm per element" value above the threshold. A NaN value counts as a
    failure. Rows whose value is missing or not numeric are skipped.

    Args:
        csv_content: Raw CSV content
        max_l2_norm_threshold: Maximum acceptable L2_norm per element value (default: 2.0)

    Returns:
        Tuple of (list of row dicts, all_passed bool)
    """
    metric_name = 'L2_norm per element'

    reader = csv.DictReader(StringIO(csv_content))
    rows = list(reader)

    # No data = failure
    if not rows:
        return rows, False

    # Header cells may carry padding when the file uses ", " as separator, so
    # locate the metric column ignoring surrounding whitespace. The row dicts
    # themselves are returned untouched.
    metric_key = next(
        (
            name for name in (reader.fieldnames or [])
            if isinstance(name, str) and name.strip() == metric_name
        ),
        metric_name,
    )

    all_passed = True
    for row in rows:
        raw_value = row.get(metric_key)
        if raw_value is None:
            # Column absent, or this row is shorter than the header.
            continue
        try:
            value = float(raw_value)
        except (TypeError, ValueError):
            # Non-numeric entries do not fail the run, but an unexpected
            # non-empty value is reported instead of being dropped silently.
            if str(raw_value).strip():
                print(f"Warning: ignoring non-numeric '{metric_name}' value: {raw_value!r}")
            continue
        # Written as "not <=" so that NaN is treated as a failure.
        if not value <= max_l2_norm_threshold:
            all_passed = False

    return rows, all_passed


def generate_markdown_table(rows: list[dict]) -> str:
    """
    Render parsed CSV rows as a GitHub-flavoured markdown table.

    The column order is taken from the keys of the first row; a key missing
    from a later row is rendered as an empty cell.

    Args:
        rows: Row dicts as returned by parse_report_csv

    Returns:
        The markdown table, or a placeholder message when there are no rows
    """
    if rows:
        columns = list(rows[0])
        body = []
        for record in rows:
            body.append([record.get(column, '') for column in columns])
        return tabulate(body, headers=columns, tablefmt="github")

    return "*No data in report*"


def npu_reset(client: paramiko.SSHClient):
    """
    Reset the NPU device by disabling and then re-enabling it via pnputil.

    Both steps are always attempted. A failing step is reported on stdout
    rather than raised, so a reset problem does not abort the test run but is
    visible in the log.

    Args:
        client: SSH client connected to Windows machine
    """
    print("[PNP] disabling device...")
    if not pnp_toggle(client, "disable"):
        print("[PNP] Warning: disable failed; attempting enable anyway")
    # Short pause between disable and enable.
    time.sleep(1)
    print("[PNP] enabling device...")
    if not pnp_toggle(client, "enable"):
        print("[PNP] Warning: enable failed; the NPU device may be left disabled")
    # Pause again after re-enabling, before the caller continues.
    time.sleep(2)


def kill_processes_using_directory(client: paramiko.SSHClient, _work_dir: str):
    """
    Force-stop remote processes whose name matches python* or waic*.

    The processes are selected by name only; `_work_dir` is accepted but not
    used, so no check is made that a process actually holds files open in
    that directory. Errors from the PowerShell cmdlets are suppressed.

    Args:
        client: SSH client connected to Windows machine
        _work_dir: Unused
    """
    # Stop the usual suspects (python, waic_runner), then let the remote side
    # settle for a second.
    script_lines = [
        "",
        "Get-Process -Name python*,waic* -ErrorAction SilentlyContinue"
        " | Stop-Process -Force -ErrorAction SilentlyContinue",
        "Start-Sleep -Seconds 1",
        "",
    ]
    remote_script = "\n".join(script_lines)

    run_powershell_command(client, remote_script, "C:\\")
    # Additional local one-second pause after the remote command returns.
    time.sleep(1)


def run_step(client: paramiko.SSHClient, command: str, work_dir: str, model_dir: str, report: StringIO) -> bool:
    """
    Run a step command and handle errors.

    Before the command is run, python*/waic* processes are killed and the NPU
    device is reset. Afterwards report.csv is read from model_dir, checked
    with parse_report_csv and appended to the report as a markdown table.

    Args:
        client: SSH client connected to Windows machine
        command: PowerShell command to execute
        work_dir: Working directory for the command
        model_dir: Directory where report.csv is generated
        report: StringIO to write results

    Returns:
        True only if the command exited with code 0, report.csv could be read
        and every row passed; False otherwise
    """
    success = True

    # Kill leftover processes and reset NPU device before tests
    kill_processes_using_directory(client, model_dir)
    npu_reset(client)

    exit_code, stdout, stderr = run_powershell_command(
        client, command, work_dir
    )
    if exit_code != 0:
        report.write(f"**{command} failed** (exit code: {exit_code})\n```\n{stderr or stdout}\n```\n")
        success = False

    # Parse report.csv from model directory.
    # NOTE(review): the file is read even when the command failed; if an
    # earlier run left a report.csv behind, that stale file would be shown
    # here - confirm the remote scripts remove it first.
    report_path_win = f"{model_dir}\\report.csv"
    csv_content = read_remote_file(client, report_path_win)

    if csv_content:
        rows, all_passed = parse_report_csv(csv_content)
        success = success and all_passed
        # End the table with a blank line: without it, GitHub markdown absorbs
        # whatever is written to the report next as an extra table row.
        report.write(generate_markdown_table(rows) + "\n\n")
    else:
        report.write("report.csv not found\n")
        success = False

    return success


def run_e2e_test(
    host: str,
    port: int,
    username: str,
    password: str,
    model: str,
    skip_compile: bool = False,
    skip_execute: bool = False,
    commit: Optional[str] = None,
    skip_sync: bool = False,
    pr_number: Optional[int] = None
) -> tuple[bool, str]:
    """
    Run the model E2E test suite on a remote Windows machine.

    Steps: connect over SSH, upload the PowerShell scripts, optionally sync
    the git repositories, then compile and/or run the model. Any exception is
    recorded in the report instead of being raised, and the SSH connection is
    closed on every exit path.

    Args:
        host: Windows machine IP address
        port: SSH port
        username: SSH username
        password: SSH password
        model: Model name (e.g., psd1, psd2)
        skip_compile: If True, only run pre-compiled version
        skip_execute: If True, only compile without running
        commit: Git commit SHA to checkout (default: origin/main)
        skip_sync: If True, skip all git sync operations (use current state)
        pr_number: PR number to checkout (fetches refs/pull/{pr}/head)

    Returns:
        Tuple of (success bool, markdown report string)
    """
    report = StringIO()
    model_dir = f"{BASE_WORK_DIR}\\{model}"
    client = None

    try:
        print(f"Connecting to {host}:{port} as {username}...")
        client = create_ssh_client(host, port, username, password)
        print("Connected successfully")

        # Sync PowerShell scripts from repo to remote
        if not sync_scripts_to_remote(client, BASE_WORK_DIR):
            report.write("**Failed to sync scripts to remote**\n")
            return False, report.getvalue()

        if skip_sync:
            print("\n[SKIP] Skipping git sync (--skip-sync specified)")
            print("       Using current state on Windows machine")
        else:
            # Sync models repository
            exit_code, stdout, stderr = run_powershell_command(
                client, ".\\sync_models_repo.ps1", BASE_WORK_DIR
            )
            if exit_code != 0:
                report.write(f"**sync_models_repo.ps1 failed** (exit code: {exit_code})\n```\n{stderr or stdout}\n```\n")
                return False, report.getvalue()

            # Sync aie4_models repository
            # Priority: --pr > --commit > default (origin/main)
            if pr_number:
                # Fetch PR ref and checkout
                sync_cmd = f".\\sync_aie4_repo.ps1 -Commit 'pull/{pr_number}/head'"
                print(f"\n[PR] Fetching PR #{pr_number} (refs/pull/{pr_number}/head)")
            elif commit:
                sync_cmd = f".\\sync_aie4_repo.ps1 -Commit '{commit}'"
            else:
                sync_cmd = ".\\sync_aie4_repo.ps1"

            exit_code, stdout, stderr = run_powershell_command(
                client, sync_cmd, BASE_WORK_DIR
            )
            if exit_code != 0:
                report.write(f"**sync_aie4_repo.ps1 failed** (exit code: {exit_code})\n```\n{stderr or stdout}\n```\n")
                return False, report.getvalue()

        # Compile and run tests
        if skip_compile:
            # Only run pre-compiled version
            step_success = run_step(client, f".\\run_compiled_qhw4_debug.ps1 {model}", BASE_WORK_DIR, model_dir, report)
        elif skip_execute:
            # Only compile, don't run
            report.write("**Compile only mode (--skip-execute)**\n")
            exit_code, stdout, stderr = run_powershell_command(
                client, f".\\compile_run_qhw4_debug.ps1 {model} -SkipExecute", BASE_WORK_DIR
            )
            if exit_code != 0:
                report.write(f"**compile failed** (exit code: {exit_code})\n```\n{stderr or stdout}\n```\n")
                step_success = False
            else:
                report.write(f"Compiled artifacts in: `{model_dir}\\waic_work_compile`\n")
                step_success = True
        else:
            # Try compile first, fallback to pre-compiled if it fails.
            # The fallback is taken for any failed step, including a run that
            # compiled fine but did not meet the report.csv threshold.
            step_success = run_step(client, f".\\compile_run_qhw4_debug.ps1 {model}", BASE_WORK_DIR, model_dir, report)
            if not step_success:
                report.write("\n**Compile failed, trying pre-compiled version...**\n\n")
                run_cmd = f".\\run_compiled_qhw4_debug.ps1 {model}"
                step_success = run_step(client, run_cmd, BASE_WORK_DIR, model_dir, report)

        return step_success, report.getvalue()

    except Exception as e:  # pylint: disable=broad-except
        report.write(f"**Error:** {str(e)}\n")
        return False, report.getvalue()

    finally:
        # Single cleanup point: runs on early returns and after exceptions,
        # so the SSH connection is never left open.
        if client is not None:
            client.close()


def main():
    """
    Main entry point.

    Parses the command-line options, rejects conflicting combinations, runs
    the E2E test, writes the markdown report to a file and to stdout, and
    exits with status 0 on success or 1 on failure.
    """
    parser = argparse.ArgumentParser(description="Run model E2E tests on remote Windows machine")
    parser.add_argument(
        "--host",
        default="10.228.203.217",
        help="Windows machine IP address"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=22,
        help="SSH port (default: 22)"
    )
    parser.add_argument(
        "--username",
        default="Administrator",
        help="SSH username (default: Administrator)"
    )
    parser.add_argument(
        "--password",
        # No credential is embedded in the source: the password has to come
        # from the command line or the environment, which also makes the
        # "Password required" check below effective.
        default=os.environ.get("PSD_WINDOWS_PASSWORD"),
        help="SSH password (or set PSD_WINDOWS_PASSWORD env var)"
    )
    parser.add_argument(
        "--model",
        default="psd1",
        choices=SUPPORTED_MODELS,
        help=f"Model to test (default: psd1, choices: {', '.join(SUPPORTED_MODELS)})"
    )
    parser.add_argument(
        "--skip-compile",
        action="store_true",
        help="Skip compilation and only run pre-compiled version"
    )
    parser.add_argument(
        "--skip-execute",
        action="store_true",
        help="Skip execution (compile only - for split compile/run workflow)"
    )
    parser.add_argument(
        "--commit",
        default=None,
        help="Git commit SHA to checkout aie4_models (default: origin/main)"
    )
    parser.add_argument(
        "--pr",
        type=int,
        default=None,
        metavar="NUMBER",
        help="PR number to test (fetches refs/pull/<NUMBER>/head from origin)"
    )
    parser.add_argument(
        "--skip-sync",
        action="store_true",
        help="Skip git sync operations (use current state on Windows for fast dev iteration)"
    )
    parser.add_argument(
        "--output",
        help="Output markdown file path (default: {model}_e2e_report.md)"
    )

    args = parser.parse_args()

    # Validate mutually exclusive options
    if args.skip_sync and (args.commit or args.pr):
        print("Error: --skip-sync cannot be used with --commit or --pr")
        sys.exit(1)
    if args.commit and args.pr:
        print("Error: --commit and --pr are mutually exclusive")
        sys.exit(1)
    if args.skip_compile and args.skip_execute:
        print("Error: --skip-compile and --skip-execute are mutually exclusive")
        sys.exit(1)

    if not args.password:
        print("Error: Password required. Set --password or PSD_WINDOWS_PASSWORD env var")
        sys.exit(1)

    output_file = args.output or f"{args.model}_e2e_report.md"

    success, report = run_e2e_test(
        host=args.host,
        port=args.port,
        username=args.username,
        password=args.password,
        model=args.model,
        skip_compile=args.skip_compile,
        skip_execute=args.skip_execute,
        commit=args.commit,
        skip_sync=args.skip_sync,
        pr_number=args.pr
    )

    # Write report
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(report)
    print(f"\nReport written to: {output_file}")

    # Print report to stdout as well
    print("\n" + "=" * 60)
    print(report)
    print("=" * 60)

    # The process exit status mirrors the test result.
    sys.exit(0 if success else 1)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
