#!/usr/bin/env python3
"""
Group operator directories by DTYPE_OFM and run hardware validation per group.

This script:
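1. Scans operator directories for a config file (conv_cfg.json, gemm_cfg.json,
   broadcast_cfg.json or uniop_cfg.json)
2. Extracts DTYPE_OFM (or DTYPE_OUT when DTYPE_OFM is missing) from each config
3. Groups operators by (normalized) dtype
4. Creates dtype-specific directories and moves the operators into them
5. Runs hardware validation for each dtype group
6. Combines the per-group output_*.json results and prints the failed tests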
"""

import os
import sys
import json
import shutil
import argparse
from pathlib import Path
from collections import defaultdict
from tabulate import tabulate
from buildtest.common import run_hw_validation


def find_config_file(op_path: str) -> str | None:
    """
    Find conv_cfg.json or gemm_cfg.json in operator directory.

    Args:
        op_path: Path to operator directory

    Returns:
        Path to config file if found, None otherwise
    """
    for cfg_name in ['conv_cfg.json', 'gemm_cfg.json', 'broadcast_cfg.json', 'uniop_cfg.json']:
        cfg_path = os.path.join(op_path, cfg_name)
        if os.path.exists(cfg_path):
            return cfg_path
    return None


def normalize_dtype(dtype: str) -> str:
    """
    Map a config dtype name onto the canonical name used for grouping.

    Exactly two aliases are rewritten; every other value is returned as-is:

      - 'uint8'   -> 'int8'  (unsigned 8-bit outputs share the int8 group)
      - 'float32' -> 'fp32'

    This is NOT a general unsigned -> signed conversion: for example
    'uint16' is returned unchanged.

    Args:
        dtype: Dtype string as read from the operator config (e.g. 'uint8', 'int16')

    Returns:
        Normalized dtype name
    """
    if dtype == 'uint8':
        return 'int8'
    if dtype == 'float32':
        return 'fp32'
    return dtype


def extract_dtype_from_config(cfg_path: str) -> str | None:
    """
    Extract DTYPE_OFM from config JSON file.

    Args:
        cfg_path: Path to config file

    Returns:
        DTYPE_OFM value if found, None otherwise
    """
    try:
        with open(cfg_path, 'r', encoding='utf-8') as f:
            cfg = json.load(f)
            return cfg.get('DTYPE_OFM', cfg.get('DTYPE_OUT'))
    except Exception as e:  # pylint: disable=broad-except
        print(f"Warning: Failed to read {cfg_path}: {e}", file=sys.stderr)
        return None


def group_operators_by_dtype(output_dir: str) -> tuple[dict[str, list[str]], list[str]]:
    """
    Scan output directory and group operator folders by their output dtype.

    Every immediate subdirectory of ``output_dir`` is inspected; plain files
    are ignored. A folder's dtype is read from the config file located by
    find_config_file() via extract_dtype_from_config(), then passed through
    normalize_dtype(). Folders with no config file, or whose config yields no
    dtype value, are reported as unassigned.

    Entries are visited in sorted name order so group membership, dtype order
    and log output are reproducible regardless of filesystem listing order.

    Args:
        output_dir: Root output directory containing operator folders

    Returns:
        Tuple of (dtype_groups dict mapping normalized dtype -> operator folder
        names, list of unassigned operator folder names)
    """
    dtype_groups = defaultdict(list)
    unassigned_dirs = []

    # os.listdir() order is arbitrary; sort for deterministic grouping.
    for op_dir in sorted(os.listdir(output_dir)):
        op_path = os.path.join(output_dir, op_dir)
        if not os.path.isdir(op_path):
            continue

        cfg_file = find_config_file(op_path)
        dtype = extract_dtype_from_config(cfg_file) if cfg_file else None
        if not dtype:
            unassigned_dirs.append(op_dir)
            continue

        dtype_groups[normalize_dtype(dtype)].append(op_dir)

    return dict(dtype_groups), unassigned_dirs


def move_operators_to_dtype_dir(output_dir: str, dtype: str, op_dirs: list[str]) -> str:
    """
    Create ``<output_dir>/dtype_<dtype>`` and move the given operator folders into it.

    Moves are best-effort. If a folder cannot be moved, or its destination
    already exists, it is left where it is and a warning is printed to stderr.

    Args:
        output_dir: Root output directory
        dtype: Data type name
        op_dirs: List of operator directory names (relative to output_dir) to move

    Returns:
        Path to dtype-specific directory
    """
    dtype_dir = os.path.join(output_dir, f'dtype_{dtype}')
    os.makedirs(dtype_dir, exist_ok=True)

    for op_dir in op_dirs:
        src = os.path.join(output_dir, op_dir)
        dst = os.path.join(dtype_dir, op_dir)
        # shutil.move() moves *into* an existing directory, which would nest the
        # operator as dtype_dir/op_dir/op_dir (e.g. on a re-run). Refuse instead.
        if os.path.exists(dst):
            print(f"Warning: Failed to move {src} to {dst}: destination already exists",
                  file=sys.stderr)
            continue
        try:
            shutil.move(src, dst)
        except OSError as e:  # shutil.Error is a subclass of OSError
            print(f"Warning: Failed to move {src} to {dst}: {e}", file=sys.stderr)

    return dtype_dir


def combine_json_results(dtype_dirs: list[str]) -> list[dict]:
    """
    Combine the output_*.json result files found in each dtype directory.

    Only files directly inside each directory are considered (the glob is
    non-recursive). Within a directory, files are read in sorted path order so
    the combined list is deterministic. Each file is expected to contain a JSON
    list of result dicts. Unreadable files, and files holding any other JSON
    shape, are reported on stderr and skipped.

    Args:
        dtype_dirs: List of dtype directory paths

    Returns:
        Combined list of test result dictionaries
    """
    combined_results = []

    for dtype_dir_path in dtype_dirs:
        # Non-recursive: only <dtype_dir>/output_*.json. Path.glob() yields in
        # arbitrary order, so sort for a stable result ordering.
        for json_path in sorted(Path(dtype_dir_path).glob('output_*.json')):
            try:
                with open(json_path, 'r', encoding='utf-8') as f:
                    dtype_results = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Warning: Failed to read {json_path}: {e}", file=sys.stderr)
                continue

            if isinstance(dtype_results, list):
                combined_results.extend(dtype_results)
            else:
                print(f"Warning: Skipping {json_path}: expected a JSON list, "
                      f"got {type(dtype_results).__name__}", file=sys.stderr)

    return combined_results


def print_results_table(results: list[dict]) -> None:
    """
    Print failed test results as a GitHub-style table, followed by a summary line.

    A result counts as passed only when its 'Pass or Fail' field equals 'pass'.
    The comparison is case-insensitive and ignores surrounding whitespace. A
    missing or null field, or any other value, counts as a failure. The table
    is omitted when nothing failed; the summary line is always printed.

    Args:
        results: List of test result dictionaries
    """
    # `or ''` guards against an explicit null; str() against non-string values.
    failed_results = [
        r for r in results
        if str(r.get('Pass or Fail') or '').strip().lower() != 'pass'
    ]
    if failed_results:
        print(tabulate(failed_results, headers='keys', tablefmt='github'))
    print(f"Summary: {len(results)} tests processed, {len(failed_results)} failed")


def validate_dtype_groups(dtype_groups: dict[str, list[str]], hw_host: str, output_dir: str) -> list[str]:
    """
    Relocate each dtype group into its own directory and validate it on hardware.

    Groups are handled in the iteration order of ``dtype_groups``. For each
    group, the operator folders are first moved with
    move_operators_to_dtype_dir(), then run_hw_validation() is invoked on the
    resulting directory.

    Args:
        dtype_groups: Mapping of dtype name -> operator folder names in that group
        hw_host: Hardware host IP address, forwarded to run_hw_validation()
        output_dir: Root output directory the operator folders currently live in

    Returns:
        Paths of the dtype directories, in the order they were processed
    """
    processed_dirs: list[str] = []
    for group_dtype, members in dtype_groups.items():
        group_dir = move_operators_to_dtype_dir(output_dir, group_dtype, members)
        processed_dirs.append(group_dir)

        member_count = len(members)
        print(f"\nRunning hardware validation for dtype={group_dtype} ({member_count} operators)")
        run_hw_validation(out_dir=group_dir, dtype=group_dtype, host=hw_host)
    return processed_dirs


def main():
    """
    Command-line entry point: group operators by dtype and validate each group.

    Steps:
      1. Group operator folders under --output-dir by normalized dtype.
      2. Abort with exit status 1 if any folder lacks a dtype, unless every
         such folder is an ``op_pdi*`` folder.
      3. Print EXPECTED_DI_PASS_COUNT (one per dtype group) for CI.
      4. Validate each group on --hw-host, then print the combined failure
         table between COMBINED_HARDWARE_VALIDATION_RESULTS_START/_END markers.

    A missing or non-directory --output-dir is rejected through argparse
    (usage message, exit status 2) rather than surfacing as a traceback.
    Failed tests in the combined results do not change the exit status here.
    """
    parser = argparse.ArgumentParser(
        description='Group operators by DTYPE_OFM and run hardware validation'
    )
    parser.add_argument(
        '--output-dir',
        required=True,
        help='Root output directory containing operator folders'
    )
    parser.add_argument(
        '--hw-host',
        required=True,
        help='Hardware host IP address'
    )

    args = parser.parse_args()

    # Fail with a clear usage error instead of an os.listdir() traceback.
    if not os.path.isdir(args.output_dir):
        parser.error(f"--output-dir is not a directory: {args.output_dir}")

    # Group operators by DTYPE_OFM
    dtype_groups, unassigned_dirs = group_operators_by_dtype(args.output_dir)

    if unassigned_dirs:
        print(f"Found {len(unassigned_dirs)} operator(s) without DTYPE_OFM:", file=sys.stderr)
        for op_dir in unassigned_dirs:
            print(f"  - {op_dir}", file=sys.stderr)
        # op_pdi* folders are expected to have no dtype config; anything else is an error.
        if not all(name.startswith('op_pdi') for name in unassigned_dirs):
            print("Please ensure all operators have valid conv_cfg.json or gemm_cfg.json with DTYPE_OFM.", file=sys.stderr)
            sys.exit(1)

    if not dtype_groups:
        print("Warning: No operators found with DTYPE_OFM", file=sys.stderr)
        sys.exit(0)

    # Log expected DI_PASS count for CI validation (one per dtype group)
    expected_count = len(dtype_groups)
    print(f"EXPECTED_DI_PASS_COUNT={expected_count}")

    # Run hardware validation for each dtype group
    dtype_dirs = validate_dtype_groups(dtype_groups, args.hw_host, args.output_dir)

    # Combine and display all results
    print("COMBINED_HARDWARE_VALIDATION_RESULTS_START")
    combined_results = combine_json_results(dtype_dirs)
    print_results_table(combined_results)
    print("COMBINED_HARDWARE_VALIDATION_RESULTS_END")


if __name__ == '__main__':
    main()
