#!/usr/bin/env python3
"""
Generate model compilation results report.

Combines result.json files from each model's artifact directory
and generates a markdown table summarizing compilation status,
followed by the combined per-model subgraph summaries.
"""
import argparse
import sys
from pathlib import Path

from dataclasses import fields

from tabulate import tabulate

from aie4_bench.compilation_result import CompilationResult, CompilationStatus


def find_model_dirs(artifact_dir: Path) -> list[Path]:
    """Find all model artifact directories.

    A model directory is identified by containing either:
    - model_cfg.yaml (test started and wrote config)
    - result.json (test wrote a result, even if it failed early)

    The search is recursive, and a directory holding both marker files is
    reported only once.

    Args:
        artifact_dir: Path to the artifacts directory

    Returns:
        Sorted list of paths to model directories
    """
    # Find dirs with model_cfg.yaml OR result.json
    cfg_dirs = {cfg.parent for cfg in artifact_dir.rglob("model_cfg.yaml")}
    result_dirs = {r.parent for r in artifact_dir.rglob("result.json")}
    # Set iteration order follows string hashing, which is randomised per
    # process; sort so the report lists models in the same order on every run.
    return sorted(cfg_dirs | result_dirs)


def combine_results(artifact_dir: Path) -> tuple[list[CompilationResult], str]:
    """Combine per-model results and subgraph summaries from an artifact directory.

    For every model directory found under ``artifact_dir``:
    - subgraph_summary.log, when present, is read and added to the combined
      summary text (a placeholder line is added if it cannot be read);
    - result.json is loaded with ``CompilationResult.from_json``. A missing
      file yields a FAILED placeholder result; a file that cannot be loaded
      yields an UNKNOWN placeholder result.

    Read/parse problems are reported on stderr and never abort the run.

    Args:
        artifact_dir: Path to the artifacts directory

    Returns:
        List of CompilationResult objects (includes placeholders for missing results)
        String which includes subgraph summary for all models, separated by blank lines
    """
    results: list[CompilationResult] = []
    subgraph_summaries: list[str] = []

    # Sort here so results and summaries come out in a reproducible order
    # regardless of the order in which the directories were discovered.
    model_dirs = sorted(find_model_dirs(artifact_dir))

    for model_dir in model_dirs:
        json_path = model_dir / "result.json"
        summary_path = model_dir / "subgraph_summary.log"

        if summary_path.exists():
            try:
                subgraph_summaries.append(summary_path.read_text().rstrip())
            except Exception as e:  # pylint: disable=broad-except
                # Best effort: keep building the report, but say which file
                # failed and why instead of swallowing the error silently.
                print(f"Warning: Failed to read {summary_path}: {e}", file=sys.stderr)
                subgraph_summaries.append(
                    "Failed to read Subgraph Summary!\n"
                )

        if not json_path.exists():
            # No result file means the test never got to write one.
            model_name = model_dir.name
            results.append(CompilationResult(
                model=model_name,
                status=CompilationStatus.FAILED,
                error=f"{model_name} - result.json missing (test crashed)"
            ))
            continue

        try:
            results.append(CompilationResult.from_json(json_path))
        except Exception as e:  # pylint: disable=broad-except
            print(f"Warning: Failed to read {json_path}: {e}", file=sys.stderr)
            results.append(CompilationResult(
                model=model_dir.name,
                status=CompilationStatus.UNKNOWN,
                error=f"Failed to parse result.json: {e}"
            ))

    return results, "\n\n".join(subgraph_summaries)


def print_results_table(results: list[CompilationResult]) -> None:
    """Print results in a formatted table.

    Args:
        results: List of CompilationResult objects
    """
    if not results:
        print("No model compilation results found.")
        return

    # Sort: FAILED first, then PASSED
    status_order = {CompilationStatus.FAILED: 0, CompilationStatus.UNKNOWN: 1, CompilationStatus.PASSED: 2}
    results.sort(key=lambda x: status_order.get(x.status, 99))

    # Truncate error messages for table display
    rows = []
    for r in results:
        error = r.error
        if error:
            # First line only, truncated
            error = error.split('\n')[0]
        rows.append([r.model, r.status.value, r.subgraphs, error])

    headers = [f.name.title() for f in fields(CompilationResult)]
    print(tabulate(rows, headers=headers, tablefmt="github"))

    # Summary
    total = len(results)
    passed = sum(1 for r in results if r.status == CompilationStatus.PASSED)
    failed = sum(1 for r in results if r.status == CompilationStatus.FAILED)
    print(f"\nSummary: {passed}/{total} models passed, {failed} failed")


def main():
    """Parse the command line and print the full report to stdout.

    The report is wrapped in MODEL_COMPILATION_RESULTS_START / _END marker
    lines and has two markdown sections: the compilation status table and
    the combined subgraph summaries.
    """
    cli = argparse.ArgumentParser(description='Generate model compilation results report')
    cli.add_argument('--artifact-dir', required=True, help='Path to artifacts directory')
    options = cli.parse_args()

    # Collect everything before printing so stderr warnings precede the report.
    compiled, summary_text = combine_results(Path(options.artifact_dir))

    print("MODEL_COMPILATION_RESULTS_START", "# Model Compilation Status", sep="\n")
    print_results_table(compiled)

    print(
        "\n---\n",
        "# Model Subgraph Summary",
        summary_text,
        "MODEL_COMPILATION_RESULTS_END",
        sep="\n",
    )


# Generate the report only when run as a script, not when imported as a module.
if __name__ == '__main__':
    main()
