"""
Map nodenames from unique-nodes JSON to alloc JSON keys.

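Modes:
1) Default (no -m/--assoc-map):
   - If -k PREFIX is given: for every unique-node block whose key starts with
     PREFIX, look up the alloc key of the block's nodenames[0] and print all of
     them as one comma-separated line.
   - Else: group unique-node keys by prefix (the text before the first '_';
     keys without '_' are skipped) and print "<prefix>: k1,k2,..." per prefix,
     where each k is the alloc key of a block's nodenames[0].

2) Association Map (-m/--assoc-map):
   - Ignore -k. For each unique-node block, look up the alloc key of every entry
     in its 'nodenames' list (including nodenames[0]). Skip names that have no
     alloc entry, and skip alloc entries whose 'is_compilable' flag is missing
     or falsy. Then print:
       "<alloc_key_for_nodenames0>: k0,k1,k2,..."
     If nodenames[0] itself is skipped, the first surviving key is used as the
     line's leading key instead.
   - When called as a library function, the same mapping is returned as a dict.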

This module can be used as a CLI or imported and called via `process_unique_alloc(...)`.
"""

import argparse
import json
import sys
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple

from utils.utils_common import log


def load_json(path: str) -> Any:
    """Read *path* as UTF-8 text and return the decoded JSON document.

    This is a CLI-oriented loader. When the file does not exist, or its
    content is not valid JSON, it writes an ``ERROR:`` line to stderr and
    terminates the process with exit status 2. It never returns in those cases.
    """
    try:
        with open(path, encoding="utf-8") as stream:
            payload = json.load(stream)
    except FileNotFoundError:
        print(f"ERROR: File not found: {path}", file=sys.stderr)
    except json.JSONDecodeError as exc:
        print(f"ERROR: Invalid JSON in {path}: {exc}", file=sys.stderr)
    else:
        return payload
    # Only reached after one of the error branches above has reported the problem.
    sys.exit(2)


def build_name_to_alloc_key(alloc: Mapping[str, Any]) -> Dict[str, str]:
    """
    Index the alloc document by each entry's ``name`` field.

    Returns a dict mapping ``name`` -> alloc key (converted with ``str``).
    Entries that are not dicts, or whose ``name`` is not a string, are ignored.
    If several entries share a name, the one encountered first keeps it.
    """
    mapping: Dict[str, str] = {}
    for raw_key, entry in alloc.items():
        name = entry.get("name") if isinstance(entry, dict) else None
        if isinstance(name, str):
            # setdefault leaves an existing mapping untouched -> first occurrence wins.
            mapping.setdefault(name, str(raw_key))
    return mapping


def _dedup_preserve(xs: Iterable[str]) -> List[str]:
    """Return xs with duplicates removed while preserving order."""
    seen: set[str] = set()
    out: List[str] = []
    for x in xs:
        if x not in seen:
            seen.add(x)
            out.append(x)
    return out


def _get_nodenames(block: Any) -> List[str]:
    """Return the string entries of ``block['nodenames']``, or [] if it is absent or not a list."""
    v = block.get("nodenames") if isinstance(block, dict) else None
    return [x for x in v if isinstance(x, str)] if isinstance(v, list) else []


def _iter_named_blocks(
    unique_nodes: Mapping[str, Any],
    accept: Callable[[str], bool] = lambda ukey: True,
) -> Iterator[Tuple[str, List[str]]]:
    """Yield ``(ukey, nodenames)`` for each dict block whose string key passes *accept*.

    Blocks whose key fails *accept* are skipped silently. Accepted blocks that
    have no string nodenames are skipped with a warning.
    """
    for ukey, block in unique_nodes.items():
        # Filter before inspecting nodenames so only blocks the caller selected can warn.
        if not isinstance(ukey, str) or not isinstance(block, dict) or not accept(ukey):
            continue
        nodenames = _get_nodenames(block)
        if not nodenames:
            log(f"WARNING: Block '{ukey}' missing 'nodenames'", file=sys.stderr)
            continue
        yield ukey, nodenames


def _first_alloc_key(
    ukey: str, nodenames: List[str], name_to_key: Mapping[str, str]
) -> Optional[str]:
    """Return the alloc key for ``nodenames[0]``. Warn and return None when it has none."""
    nm0 = nodenames[0]
    k = name_to_key.get(nm0)
    if k is None:
        log(f"WARNING: No alloc entry with name='{nm0}' (from '{ukey}')", file=sys.stderr)
    return k


def _compilable_alloc_keys(
    ukey: str,
    nodenames: List[str],
    name_to_key: Mapping[str, str],
    alloc: Mapping[str, Any],
) -> List[str]:
    """Map every nodename to its alloc key, in order.

    A nodename is skipped, with a warning, when it has no alloc entry or when
    that entry's ``is_compilable`` flag is missing or falsy.
    """
    mapped: List[str] = []
    for nm in nodenames:
        k = name_to_key.get(nm)
        if k is None:
            log(f"WARNING: No alloc entry with name='{nm}' (block '{ukey}')", file=sys.stderr)
            continue

        alloc_block = alloc.get(k, {})
        if not alloc_block.get("is_compilable", False):
            log(
                f"WARNING: alloc key '{k}' for nodename '{nm}' is not compilable; skipping",
                file=sys.stderr
            )
            continue

        mapped.append(k)
    return mapped


def process_unique_alloc(
    unique_path: str,
    alloc_path: str,
    startswith: Optional[str],
    association_map: bool,
) -> Dict[str, List[str]]:
    """
    Core entry point. Load both JSON files and perform the requested mapping.

    Args:
        unique_path: Path to unique-nodes JSON.
        alloc_path: Path to alloc JSON.
        startswith: Optional prefix used to filter unique-node keys. Ignored if
            association_map=True. An empty string is treated as "not given".
        association_map: If True, build
            {alloc_key(nodenames[0]): [alloc_keys_of_all_nodenames_including_self]}.
            Only alloc entries with a truthy ``is_compilable`` are included. If
            nodenames[0] itself is unmapped or not compilable, the first
            surviving key becomes the dict key, and a warning is logged.

    Returns:
        In association mode: dict as specified above, with keys de-duplicated
            in first-seen order. If two blocks yield the same primary key, the
            later block wins and a warning is logged.
        In non-association mode:
            - If startswith provided: {"__flat__": [k1, k2, ...]}, where the list
              holds the alloc keys of nodenames[0] in unique-node order.
            - Else: {prefix: [k1, k2, ...], ...}, grouped by operator prefix (the
              text before the first '_'). Keys without '_' are skipped.
        Missing names are skipped with warnings printed to stderr.
    """
    unique_nodes = load_json(unique_path)
    alloc = load_json(alloc_path)
    name_to_key = build_name_to_alloc_key(alloc)

    if association_map:
        assoc: Dict[str, List[str]] = {}
        for ukey, nodenames in _iter_named_blocks(unique_nodes):
            mapped = _compilable_alloc_keys(ukey, nodenames, name_to_key, alloc)
            if not mapped:
                continue

            primary = mapped[0]
            # mapped[0] is nodenames[0]'s key only if that name survived the filters above.
            if primary != name_to_key.get(nodenames[0]):
                log(
                    f"WARNING: nodenames[0]='{nodenames[0]}' of block '{ukey}' has no "
                    f"compilable alloc entry; using '{primary}' as its key",
                    file=sys.stderr
                )
            if primary in assoc:
                log(
                    f"WARNING: alloc key '{primary}' from block '{ukey}' replaces an "
                    f"earlier association with the same key",
                    file=sys.stderr
                )
            assoc[primary] = _dedup_preserve(mapped)  # include primary; order-preserving
        return assoc

    if startswith:
        flat: List[str] = []
        selected = _iter_named_blocks(unique_nodes, lambda ukey: ukey.startswith(startswith))
        for ukey, nodenames in selected:
            k = _first_alloc_key(ukey, nodenames, name_to_key)
            if k is not None:
                flat.append(k)
        return {"__flat__": flat}

    # Group by operator prefix from unique-node key (prefix = split('_', 1)[0]).
    grouped: Dict[str, List[str]] = defaultdict(list)
    for ukey, nodenames in _iter_named_blocks(unique_nodes, lambda ukey: "_" in ukey):
        k = _first_alloc_key(ukey, nodenames, name_to_key)
        if k is not None:
            grouped[ukey.split("_", 1)[0]].append(k)
    return dict(grouped)


def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point.

    Parses *argv* (``sys.argv[1:]`` when None), runs ``process_unique_alloc``,
    and writes one result line per entry to stdout. In grouped mode, prefixes
    with an empty key list produce no line.
    """
    parser = argparse.ArgumentParser(
        description="Map unique-nodes nodenames to alloc keys; optional association map."
    )
    parser.add_argument("-u", "--unique", required=True, help="Path to unique-nodes JSON file")
    parser.add_argument("-a", "--alloc", required=True, help="Path to alloc JSON file")
    parser.add_argument("-k", "--starts-with", help="Filter unique-node keys by this prefix (ignored with -m)")
    parser.add_argument(
        "-m", "--assoc-map", action="store_true",
        help="Build association: <alloc_key(nodenames[0])>: all_alloc_keys_for_that_block (includes self)"
    )
    opts = parser.parse_args(argv)

    mapping = process_unique_alloc(
        unique_path=opts.unique,
        alloc_path=opts.alloc,
        startswith=opts.starts_with,
        association_map=bool(opts.assoc_map),
    )

    # Render every output line up front, then emit them in order.
    if opts.assoc_map:
        lines = [f"{primary}: {','.join(keys)}" for primary, keys in mapping.items()]
    elif "__flat__" in mapping:
        lines = [",".join(mapping["__flat__"])]
    else:
        lines = [f"{prefix}: {','.join(keys)}" for prefix, keys in mapping.items() if keys]

    for line in lines:
        print(line)


if __name__ == "__main__":
    # Script invocation: let argparse read the real command line (argv=None -> sys.argv[1:]).
    main(argv=None)
