#!/usr/bin/env python3
"""
Submit and monitor pytest jobs to LSF with test selection and filtering capabilities.

This script allows you to submit pytest tests to LSF with the same selection
and filtering options that pytest provides, including:
- Markers (-m)
- Keywords (-k)
- Specific test files
- Specific test functions
- Target selection (--target)
"""

import os
import re
import subprocess
import shlex
from pathlib import Path
from typing import List, Optional
import typer
import pytest
from buildtest.common import BuildTarget
from graph.utilities import logger


# Repository paths
# AIE4_ROOT_DIR must be set before this module is imported: os.environ.get()
# returns None when the variable is missing, and Path(None) raises TypeError.
REPO_DIR = Path(os.environ.get("AIE4_ROOT_DIR")).absolute()
# buildtest directory of the repository; the default LSF log folder
# ("lsf_logs") is created underneath it.
BUILDTEST_DIR = REPO_DIR / "buildtest"
# Shell script that every submitted job runs via "bash", with the pytest
# command line as its arguments.
WRAPPER_SCRIPT = BUILDTEST_DIR / "pytest_lsf_wrapper.sh"

# Create typer app
# The commands below register themselves on this object via @app.command().
# add_completion=False turns off typer's shell-completion options, and
# rich_markup_mode="rich" selects Rich markup for the help text.
app = typer.Typer(
    help="Submit pytest jobs to LSF with test selection and filtering",
    add_completion=False,
    rich_markup_mode="rich"
)


def get_test_list(pytest_args: List[str]) -> List[str]:
    """
    Collect the node IDs of every test selected by the given pytest arguments.

    Pytest is invoked in-process in collection-only mode, and a small plugin
    object records the items pytest reports once collection has finished.
    The exit code of the pytest run is not inspected.

    Args:
        pytest_args: Pytest command-line arguments used for selection
            (markers, keywords, files, node IDs, ...)

    Returns:
        Node IDs of the collected tests (e.g., "test_conv.py::test_conv[0-l2]");
        an empty list if the collection hook never ran or nothing was selected
    """
    # Filled in by the plugin below; stays empty if the hook is never called.
    node_ids: List[str] = []

    class NodeIdRecorder:
        """Pytest plugin that remembers the node IDs seen at the end of collection"""

        def pytest_collection_finish(self, session):
            """Hook called by pytest after collection is finished"""
            node_ids[:] = [item.nodeid for item in session.items]

    # "--collect-only" keeps pytest from running anything, "-q" trims its output
    collect_argv = ["--collect-only", "-q"] + pytest_args
    pytest.main(collect_argv, plugins=[NodeIdRecorder()])

    return node_ids


def sanitize_job_name(test_id: str) -> str:
    """
    Turn a pytest node ID into a name intended to be safe as an LSF job name.

    The transformation, applied in this order:
    - "::" becomes "/" and every ".py" is dropped
    - "[" becomes "_" and "]" is dropped
    - commas and spaces become "_"
    - any character other than letters, digits, "_", "/", "." and "-" becomes "_"
    - runs of "_" are squeezed into one, and "_" is trimmed from both ends

    Example:
        test_conv.py::test_a16w8[shape_9_[32, 32, 640]]
        -> test_conv/test_a16w8_shape_9_32_32_640

    Args:
        test_id: Pytest test node ID

    Returns:
        Sanitized job name
    """
    # Literal substitutions; the order matters (", " must be handled before
    # "," and " " so that a comma-space pair yields a single underscore here).
    literal_substitutions = (
        ("::", "/"),
        (".py", ""),
        ("[", "_"),
        ("]", ""),
        (", ", "_"),
        (",", "_"),
        (" ", "_"),
    )

    cleaned = test_id
    for old, new in literal_substitutions:
        cleaned = cleaned.replace(old, new)

    # Whatever is still outside the allowed character set turns into "_"
    cleaned = re.sub(r'[^a-zA-Z0-9_/.-]', '_', cleaned)

    # Squeeze underscore runs produced by the steps above
    cleaned = re.sub(r'_+', '_', cleaned)

    return cleaned.strip('_')


def submit_lsf_job(
    test_id: str,
    target: Optional[BuildTarget] = None,
    job_name: Optional[str] = None,
    queue: str = "normal",
    mem_limit: str = "32GB",
    output_dir: Optional[Path] = None,
    additional_bsub_args: Optional[List[str]] = None,
    dry_run: bool = True,
    resources: str = "select[osdistro=rhel && (osver=ws8)]",
    hwtest: bool = False,
    output_root: Optional[str] = None,
    extra_pytest_args: Optional[List[str]] = None,
    artifact_dir: Optional[Path] = None
) -> Optional[str]:
    """
    Submit a single test to LSF.

    The job runs "bash <wrapper script> pytest <test_id> ..." through bsub.
    The full bsub command is always logged; it is only executed when
    dry_run is False.

    Args:
        test_id: Pytest test node ID (e.g., "test_conv.py::test_conv[0-l2]")
        target: Build target (sim, dataflow, cert, cert_sim). Optional for non-buildtest tests.
        job_name: LSF job name (auto-generated from test_id if None)
        queue: LSF queue name
        mem_limit: Memory limit for the job
        output_dir: Directory for stdout/stderr files (default: buildtest/lsf_logs)
        additional_bsub_args: Additional arguments to pass to bsub
        dry_run: When True (the default), only log the command; no directory is
            created, no log file is removed and nothing is submitted
        resources: Resource requirements for the job
        hwtest: Run the test on HW
        output_root: Output root directory for pytest --output-root (supports {{worker_id}} template)
        extra_pytest_args: Additional pytest arguments (e.g., ['--run-model-compilation'])
        artifact_dir: Directory for test artifacts (passed to pytest as --artifact-dir)

    Returns:
        Job ID parsed from the bsub output ("unknown" if the output contains no
        "Job <...>" line), or None on a dry run or if submission failed
    """
    # Generate job name from test ID if not provided
    if job_name is None:
        job_name = sanitize_job_name(test_id)

    # Resolve the log locations (nothing is touched on disk yet)
    if output_dir is None:
        output_dir = BUILDTEST_DIR / "lsf_logs"
    stdout_file = output_dir / f"{job_name}.out"
    stderr_file = output_dir / f"{job_name}.err"

    # Build bsub command
    bsub_cmd = [
        "bsub",
        "-L", "/bin/bash",
        "-J", job_name,
        "-q", queue,
        "-M", mem_limit,
        "-o", str(stdout_file),
        "-e", str(stderr_file),
        "-R", resources
    ]

    # Add additional bsub arguments if provided
    if additional_bsub_args:
        bsub_cmd.extend(additional_bsub_args)

    # Build pytest command; every part is shell-quoted below because test_id
    # may contain characters such as [ ] that are special to the shell
    pytest_cmd_parts = ["pytest", test_id, "-v", "-n", "0", "-s"]

    # Add target if provided (buildtest tests need this)
    if target:
        pytest_cmd_parts.extend(["--target", target.value])

    # append, not extend: extending with a str would add one list element per
    # character instead of the single "--hwtest" flag
    if hwtest:
        pytest_cmd_parts.append("--hwtest")

    # Add output_root if provided
    if output_root:
        pytest_cmd_parts.extend(["--output-root", output_root])

    # Add extra pytest args (e.g., --run-model-compilation)
    if extra_pytest_args:
        pytest_cmd_parts.extend(extra_pytest_args)

    # Add artifact_dir as pytest CLI argument
    if artifact_dir:
        pytest_cmd_parts.extend(["--artifact-dir", str(artifact_dir)])

    pytest_cmd = " ".join(shlex.quote(part) for part in pytest_cmd_parts)

    # Build the wrapper command: bash wrapper.sh <pytest command>.
    # The script path is quoted too so a repository path containing spaces or
    # other shell-special characters stays a single word.
    wrapper_cmd = f"bash {shlex.quote(str(WRAPPER_SCRIPT))} {pytest_cmd}"

    # NOTE(review): bsub_cmd is handed to subprocess.run as an argv list (no
    # shell involved), so this outer shlex.quote reaches bsub with literal quote
    # characters around the whole command. It keeps the logged line
    # copy-pasteable; confirm that bsub handles the quoted form as intended.
    bsub_cmd.append(shlex.quote(wrapper_cmd))
    logger.info("Submitting job with command: %s", " ".join(bsub_cmd))

    if dry_run:
        return None

    # From here on the job is really submitted, so prepare the log location.
    # job_name may contain "/" (sanitize_job_name maps "::" to "/"), hence the
    # second mkdir for the log file's own parent directory.
    output_dir.mkdir(parents=True, exist_ok=True)
    stdout_file.parent.mkdir(parents=True, exist_ok=True)

    # Remove the stdout log of a previous run so the new job starts with a
    # fresh file (bsub -o appends to an existing file).
    try:
        stdout_file.unlink()
    except FileNotFoundError:
        pass  # no previous log, nothing to clean up
    except PermissionError:
        logger.warning("Error: Permission denied. Could not delete '%s'.", stdout_file)
    else:
        logger.info("File '%s' successfully deleted.", stdout_file)

    # Submit job
    try:
        result = subprocess.run(
            bsub_cmd,
            capture_output=True,
            text=True,
            check=True
        )
    except subprocess.CalledProcessError as e:
        logger.error("Failed to submit %s: %s", job_name, e.stderr)
        return None

    # Extract job ID from bsub output
    # Typical output: "Job <12345> is submitted to queue <medium>."
    job_id = "unknown"
    for line in result.stdout.splitlines():
        if "Job <" in line:
            job_id = line.split("<")[1].split(">")[0]
            break

    logger.info("Submitted: %s (Job ID: %s)", job_name, job_id)
    return job_id


@app.command()
def submit(
    # Pytest selection arguments
    markers: Optional[str] = typer.Option(
        None, "-m", "--markers",
        help="Only run tests matching given marker expression (e.g., 'dma', 'pdi')"
    ),
    keywords: Optional[str] = typer.Option(
        None, "-k", "--keywords",
        help="Only run tests matching given keyword expression"
    ),
    testpaths: Optional[List[str]] = typer.Argument(
        None,
        help="Test files or test node IDs to run"
    ),
    # Target argument
    target: BuildTarget = typer.Option(
        BuildTarget.SIM,
        "--target",
        help="Build target for the operator",
        case_sensitive=False
    ),
    # HW test argument
    hwtest: bool = typer.Option(
        False,
        "--hwtest",
        help="Run hardware tests",
        case_sensitive=False
    ),
    # LSF arguments
    queue: str = typer.Option(
        "medium",
        "-q", "--queue",
        help="LSF queue name"
    ),
    mem: str = typer.Option(
        "16GB",
        "--mem",
        help="Memory limit per job"
    ),
    output_dir: Optional[Path] = typer.Option(
        None,
        "--output-dir",
        help="Directory for LSF stdout/stderr files (default: buildtest/lsf_logs)"
    ),
    job_prefix: Optional[str] = typer.Option(
        None,
        "--job-prefix",
        help="Prefix for LSF job names"
    ),
    # Additional options
    dry_run: bool = typer.Option(
        False,
        "--dry-run",
        help="Show what would be submitted without actually submitting"
    ),
    submit: bool = typer.Option(  # pylint: disable=redefined-outer-name
        False,
        "--submit",
        help="Submit the jobs"
    ),
    bsub_args: Optional[str] = typer.Option(
        None,
        "--bsub-args",
        help="Additional arguments to pass to bsub (quote the string)"
    ),
):
    """
    Submit pytest jobs to LSF with test selection and filtering.

    Without --submit or --dry-run the collected tests are only listed.

    Examples:

        # Collect all DMA tests with sim target
        python pytest_lsf.py submit -m dma --target sim

        # Dry run (show what would be submitted)
        python pytest_lsf.py submit -m dma --target sim --dry-run

        # Submit all DMA tests with sim target
        python pytest_lsf.py submit -m dma --target sim --submit

        # Submit tests matching keyword
        python pytest_lsf.py submit -k "conv and not l3" --target sim --submit

        # Submit specific test
        python pytest_lsf.py submit test_binary.py::test_binary[add_8-0] --target sim --submit
    """
    # Build pytest arguments for test collection
    pytest_args = []

    if markers:
        pytest_args.extend(["-m", markers])

    if keywords:
        pytest_args.extend(["-k", keywords])

    if testpaths:
        pytest_args.extend(testpaths)

    # Refuse to run without any selection criteria (this would otherwise
    # collect every test pytest can find)
    if not pytest_args:
        logger.error("Error: No test selection criteria provided.")
        logger.error("Use -m, -k, or specify test files/paths.")
        raise typer.Exit(code=1)

    # Get list of tests to submit
    logger.info("Collecting tests with: %s", " ".join(pytest_args))
    tests = get_test_list(pytest_args)

    if not tests:
        logger.error("No tests collected. Check your selection criteria.")
        raise typer.Exit(code=1)

    logger.info("Found %d test(s) to submit\n", len(tests))

    # Neither --submit nor --dry-run: just list what was collected and stop
    if not submit and not dry_run:
        logger.info("Tests that would be submitted:")
        for test in tests:
            logger.info("  - %s", test)
        raise typer.Exit(code=0)

    # Parse additional bsub args with shell-like rules so that quoted values
    # (e.g. "-R 'rusage[mem=4]'") stay one argument without the quote characters
    additional_bsub_args = None
    if bsub_args:
        try:
            additional_bsub_args = shlex.split(bsub_args)
        except ValueError as exc:
            # shlex.split raises ValueError on unbalanced quotes
            logger.error("Error: Could not parse --bsub-args %r: %s", bsub_args, exc)
            raise typer.Exit(code=1) from exc

    # Submit jobs
    submitted_jobs = []
    failed_jobs = []

    for i, test_id in enumerate(tests, 1):
        # Generate job name
        base_name = sanitize_job_name(test_id)
        job_name = f"{job_prefix}_{base_name}" if job_prefix else base_name

        logger.info("\n\n%s", "-"*60)
        logger.info("[%d/%d] Submitting: %s", i, len(tests), test_id)

        # When both --submit and --dry-run are given, dry_run wins: the
        # command is logged but not executed
        job_id = submit_lsf_job(
            test_id=test_id,
            target=target,
            job_name=job_name,
            queue=queue,
            mem_limit=mem,
            output_dir=output_dir,
            additional_bsub_args=additional_bsub_args,
            dry_run=dry_run,
            hwtest=hwtest
        )

        # A dry run returns no job ID; that is not a failure
        if dry_run:
            continue

        if job_id:
            submitted_jobs.append((job_name, job_id))
        else:
            failed_jobs.append(test_id)

    # Summary
    logger.info("\n%s", "="*60)
    if dry_run:
        logger.info("DRY RUN: Would have submitted %d job(s)", len(tests))
    else:
        logger.info("Successfully submitted: %d job(s)", len(submitted_jobs))
        if failed_jobs:
            logger.warning("Failed to submit: %d job(s)", len(failed_jobs))
            for test in failed_jobs:
                logger.warning("  - %s", test)

        logger.info("\nLogs will be written to: %s", output_dir or BUILDTEST_DIR / "lsf_logs")
        logger.info("\nMonitor jobs with: bjobs")
        logger.info("Check job output: bpeek <job_id>")
    logger.info("="*60)


@app.command()
def status(
    job_prefix: Optional[str] = typer.Option(
        None,
        "--job-prefix",
        help="Filter jobs by prefix"
    ),
    show_all: bool = typer.Option(
        False,
        "--all",
        help="Show all jobs (including completed)"
    ),
    verbose: bool = typer.Option(
        False,
        "-v", "--verbose",
        help="Show detailed job information"
    ),
):
    """
    Check status of submitted LSF jobs.

    Examples:

        # Show status of all running jobs
        python pytest_lsf.py status

        # Show all jobs including completed
        python pytest_lsf.py status --all

        # Show jobs with specific prefix
        python pytest_lsf.py status --job-prefix test_conv

        # Show detailed information
        python pytest_lsf.py status -v
    """
    # Build bjobs command
    cmd = ["bjobs", "-w"]

    if show_all:
        cmd.append("-a")

    if verbose:
        cmd.append("-l")

    # Run bjobs; only the launch itself can raise FileNotFoundError, so the
    # try block is limited to it
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=False
        )
    except FileNotFoundError as exc:
        logger.error("Error: 'bjobs' command not found. LSF may not be available.")
        raise typer.Exit(code=1) from exc

    # Both messages mean "nothing to show" and are treated as a normal,
    # empty result rather than as a bjobs failure
    no_jobs = any(
        message in result.stderr
        for message in ("No unfinished job found", "No job found")
    )

    if result.returncode != 0 and not no_jobs:
        logger.error("Error running bjobs: %s", result.stderr)
        raise typer.Exit(code=1)

    # Filter by job prefix if provided
    if job_prefix:
        lines = result.stdout.splitlines()
        if lines:
            # Print header
            logger.info(lines[0])
            # Filter and print matching jobs. The match is a substring test
            # on the whole output line, not only on the job name column.
            count = 0
            for line in lines[1:]:
                if job_prefix in line:
                    logger.info(line)
                    count += 1
            logger.info("\nFound %d job(s) matching prefix '%s'", count, job_prefix)
    else:
        # Print all output
        logger.info("\n%s", result.stdout)

    if no_jobs:
        logger.info("No jobs found")


@app.command()
def logs(
    job_name: str = typer.Argument(
        ...,
        help="Job name (without .out/.err extension)"
    ),
    output_dir: Optional[Path] = typer.Option(
        None,
        "--output-dir",
        help="Directory containing LSF logs (default: buildtest/lsf_logs)"
    ),
    show_errors: bool = typer.Option(
        False,
        "-e", "--errors",
        help="Show stderr instead of stdout"
    ),
    follow: bool = typer.Option(
        False,
        "-f", "--follow",
        help="Follow the log file (like tail -f)"
    ),
):
    """
    Show logs for a submitted job.

    Examples:

        # Show stdout for a job
        python pytest_lsf.py logs test_binary/test_binary_add_8-0

        # Show stderr
        python pytest_lsf.py logs test_binary/test_binary_add_8-0 --errors

        # Follow the log file
        python pytest_lsf.py logs test_binary/test_binary_add_8-0 --follow
    """
    if output_dir is None:
        output_dir = BUILDTEST_DIR / "lsf_logs"

    # Determine which log file to show
    ext = ".err" if show_errors else ".out"
    log_file = output_dir / f"{job_name}{ext}"

    if not log_file.exists():
        logger.error("Log file not found: %s", log_file)
        logger.info("Available log files:")
        # Job names can contain "/" (sanitize_job_name maps "::" to "/"), so
        # log files may sit in subdirectories: search recursively and print
        # names relative to output_dir, i.e. in the form this command accepts.
        for candidate in sorted(output_dir.rglob(f"*{ext}")):
            relative_name = candidate.relative_to(output_dir).as_posix()
            logger.info("  - %s", relative_name[:-len(ext)])
        raise typer.Exit(code=1)

    # Show the log file: "tail -f" keeps streaming until interrupted,
    # "cat" prints the current content once
    viewer = ["tail", "-f"] if follow else ["cat"]
    subprocess.run(viewer + [str(log_file)], check=False)


# Run the typer CLI (submit / status / logs) when executed as a script
if __name__ == "__main__":
    app()
