"""
PDI variant selection utilities.

This module selects the most suitable PDI variant for a model/subgraph based on
its required kernel set.

Given:
  - required combined_kernel_names
  - required combined_kernel_includes

It finds a pdi_variants entry that is a superset of the requirements
(can contain extra kernels/includes, but never fewer).

If no suitable variant exists, the module terminates compilation with a hard
failure to prevent generating an invalid or incomplete build artifact.
"""
from __future__ import annotations

import os
import sys
from typing import Any, Dict, List, Mapping, Iterable, Set, Tuple
from utils.unique_pdi_variants import pdi_variants as PDI_VARIANTS


def find_suitable_pdi_variant(
    pdi_variants: Mapping[str, Mapping[str, Any]],
    required_kernel_names: Mapping[str, Any],
    required_kernel_includes: List[str],
) -> str:
    """
    Pick the tightest pdi_variants entry that covers the required kernels/includes.

    A candidate qualifies only if it is a SUPERSET of the requirements (it may
    contain more, but never fewer).

    Matching rules:
      - Kernel names: candidate must contain ALL keys in required_kernel_names.
        Only the keys are compared; the kernel IDs/values are ignored.
      - Includes: candidate must contain ALL strings in required_kernel_includes.

    Selection among qualifying candidates:
      - fewest extra kernels first, then fewest extra includes;
      - on a full tie, the candidate encountered first while iterating
        pdi_variants is kept.

    Args:
      pdi_variants: mapping of variant key -> entry. Each entry is read via
        "combined_kernel_names" (a mapping) and "combined_kernel_includes"
        (a list/tuple/set); a missing or None field counts as empty.
      required_kernel_names: mapping whose keys are the required kernel names
        (None counts as "no requirement").
      required_kernel_includes: list/tuple/set of required include strings
        (None counts as "no requirement").

    Returns:
      The key of the matching variant, e.g. "pdi_combination_3".

    Raises:
      TypeError: if a kernel-names field is not a mapping, or an includes
        field is not a list/tuple/set (a bare string is rejected on purpose,
        since it would otherwise be split into characters).

    Note:
      When no variant qualifies, this function does NOT raise. It prints a
      diagnostic to stderr and terminates the process with os._exit(1), so
      callers cannot catch or recover from that case.
    """
    def _as_set_includes(includes: Iterable[str]) -> Set[str]:
        # None means "nothing listed", not an error.
        if includes is None:
            return set()
        if not isinstance(includes, (list, tuple, set, frozenset)):
            raise TypeError("combined_kernel_includes must be a list/tuple/set of strings")
        return {str(x) for x in includes}

    def _as_set_kernels(names: Mapping[str, Any]) -> Set[str]:
        # None means "nothing listed", not an error.
        if names is None:
            return set()
        # Accept any mapping (dict, OrderedDict, MappingProxyType, ...): only
        # the keys are used, so requiring a concrete dict is unnecessarily strict.
        if not isinstance(names, Mapping):
            raise TypeError("combined_kernel_names must be a dict")
        return {str(k) for k in names}

    req_kernels = _as_set_kernels(required_kernel_names)
    req_includes = _as_set_includes(required_kernel_includes)

    best_key: str | None = None
    best_score: Tuple[int, int] | None = None  # prefer minimal extra stuff

    for variant_key, variant in pdi_variants.items():
        cand_kernels = _as_set_kernels(variant.get("combined_kernel_names", {}))
        cand_incs = _as_set_includes(variant.get("combined_kernel_includes", []))

        # Reject candidates that miss any required kernel or include.
        if not (req_kernels <= cand_kernels and req_includes <= cand_incs):
            continue

        # Choose the "tightest" superset: fewest extra kernels, then fewest
        # extra includes. Tuples compare lexicographically, which gives
        # exactly that priority order.
        score = (len(cand_kernels - req_kernels), len(cand_incs - req_includes))

        # Strict "<" keeps the earliest candidate when scores tie.
        if best_score is None or score < best_score:
            best_score = score
            best_key = variant_key

    if best_key is None:
        msg = (
            "Model compilation failed: no suitable PDI found. "
            f"Required kernels={sorted(req_kernels)}; "
            f"required includes={sorted(req_includes)}"
        )
        # flush=True matters: os._exit() skips normal interpreter shutdown,
        # so anything still buffered would otherwise be lost.
        print(msg, file=sys.stderr, flush=True)
        # hard exit
        os._exit(1)

    return best_key


if __name__ == "__main__":
    # Manual smoke run: look up a hard-coded requirement set in the imported
    # PDI_VARIANTS table and print the key that was selected.

    # Required kernels (name -> id); only the names take part in matching.
    demo_kernel_names = {
        'run_quant': 11,
        'run_bdcastadd_16': 17,
        'run_layernorm_fp16x16': 14,
        'run_gemm_int16x8': 4,
        'run_dequant': 10,
        'run_gemm_int16x16_transpose': 16,
        'run_softmax_fp16x16': 6,
        'run_lut_fp16x16': 20,
        'run_l2norm_fp16x16': 7,
    }

    # Required include files.
    demo_kernel_includes = [
        'super.hh',
        'q/q.hpp',
        'q/q_wrapper.cc',
        'broadcast/run_bdcastadd_wrapper.cc',
        'q/q_impl.hpp',
        'dq/dq_impl.hpp',
        'layer_norm_fp16x16/layer_norm_fp16x16_wrapper.cc',
        'gemm_qdq_int16x8/gemm_int16x8_wrapper.cc',
        'dq/dq.hpp',
        'dq/dq_wrapper.cc',
        'gemm_qdq_int16x16_transpose/gemm_int16x16_transpose_wrapper.cc',
        'softmax_fp16x16/softmax_fp16x16_wrapper.cc',
        'linear_approx_bf16/linear_approx_bf16_wrapper.cc',
        'l2norm_fp16x16/l2norm_fp16x16_wrapper.cc',
    ]

    chosen_variant = find_suitable_pdi_variant(
        PDI_VARIANTS,
        demo_kernel_names,
        demo_kernel_includes,
    )
    print("PDI Variant Picked:", chosen_variant)
