"""This files includes the common functions that will be used across operators"""

from __future__ import annotations

import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace, field
from enum import Enum
from typing import Dict, Any, Union, TypeVar, Sequence, List, Optional, Tuple, Type

from utils.utils_common import L2Alloc, ReadBins
from scheduler.common import L3Alloc
from dmacompiler import BackEnd, memory_tile, DmaPaddingMap


# Registered kernel names. A kernel's numeric ID is its position in this list
# (see get_kernel_id), so append new kernels at the end: inserting, removing or
# reordering entries changes the ID of every kernel after that point.
KERNELS: List[str] = [
    "run_conv_noqdq_a8w8",
    "run_maxpool_int8x8",
    "run_matadd_int8",
    "run_globalavgpool_int8x8",
    "run_gemm_int16x8",
    "run_gemm_int16x4",
    "run_softmax_fp16x16",
    "run_l2norm_fp16x16",
    "run_silu",
    "run_copy_fp16x16",
    "run_dequant",
    "run_quant",
    "run_matadd_16",
    "run_mul_16",
    "run_layernorm_fp16x16",
    "run_conv_qdq_a16w8",
    "run_gemm_int16x16_transpose",
    "run_bdcastadd_16",
    "run_bdcastmul_16",
    "run_bdcastadd_8",
    "run_lut_fp16x16",
    "run_dwc_qdq_a16w8",
    "run_group_norm_qdq",
    "run_bdcastdiv_16",
]


def get_kernel_id(kernel_name: str) -> int:
    """
    Return the numeric ID of ``kernel_name``, i.e. its index in ``KERNELS``.

    Raises:
        ValueError: if ``kernel_name`` is not listed in ``KERNELS``. The message
            names the unknown kernel and lists every registered name.
    """
    try:
        return KERNELS.index(kernel_name)
    except ValueError as exc:
        # Separate the two sentences; previously "...common.py Valid names are"
        # ran together.
        raise ValueError(
            f"Unknown kernel name: '{kernel_name}'. "
            "Register kernel name in `KERNELS` list in buildscripts/common.py. "
            f"Valid names are: {', '.join(KERNELS)}"
        ) from exc


class BaseKernelSelector:
    """
    Base class for operator-specific kernel selection.

    Subclasses implement :meth:`select` to decide which metadata key to use
    for a given lookup. The return value must be the *field name* of an entry
    in the operator's metadata, e.g. 'kernel_names', 'kernel_names_dwc',
    'kernel_includes_1'. Instances are stored under the 'kernel_selector'
    key of an operator's OperatorsRegistry metadata.
    """

    def __call__(
        self,
        field_name: str,
        attrs: Dict[str, Any],
        operator: str,
        metadata: Dict[str, Any],
    ) -> str:
        """
        Make the selector callable, as expected by
        OperatorsRegistry._select_kernel_field, by forwarding every argument
        to :meth:`select` by keyword.

        Most subclasses use only some of these arguments (often `attrs` and
        `field_name`).
        """
        return self.select(field_name=field_name, attrs=attrs, operator=operator, metadata=metadata)

    def select(
        self,
        field_name: str,
        attrs: Dict[str, Any],
        operator: str,
        metadata: Dict[str, Any],
    ) -> str:
        """
        Decide which field name to use from `metadata`.

        Args:
            field_name: base field being resolved ('kernel_names' or
                'kernel_includes').
            attrs: operator JSON attributes.
            operator: registered operator name.
            metadata: the operator's full registry metadata.

        Returns:
            The metadata key to read.

        Raises:
            NotImplementedError: always in the base class; subclasses must
                override this method.
        """
        raise NotImplementedError("Subclasses must implement `select`.")


class OperatorsRegistry:
    """
    Class-level registry of operator build metadata.

    Maps operator names to a metadata dict (testbench, dataflow/build scripts,
    kernel names/includes, optional kernel-selection rule). Storage lives in
    class attributes, so every registration is shared by all users of the
    class. Names registered together may also be recorded under a group key.
    """

    # operator name -> metadata dict; names registered by one add_operator
    # call share the same dict object.
    _operators: Dict[str, Any] = {}
    # group key -> {"names": [operator names], "metadata": shared metadata dict}
    _groups: Dict[str, Dict[str, Any]] = {}

    # Required fields for every operator (flat or nested)
    REQUIRED_FIELDS = {
        "testbench",
        "dataflow_script",
        "build_script",
        "kernel_names",
        "kernel_includes",
    }

    @classmethod
    def add_operator(
        cls, name: Union[str, Sequence[str]], metadata: Dict[str, Any], group_key: Optional[str] = None,
    ) -> None:
        """
        Register one or many operator names that share ``metadata``.

        The call is all-or-nothing. The names, metadata fields, duplicate names
        and ``group_key`` are all validated before anything is written, so a
        failing call leaves the registry unchanged.

        Args:
            name: a single operator name or a sequence of names.
            metadata: operator metadata (see ``_validate_operator_fields``).
            group_key: optional key under which the names are also recorded
                as a group.

        Raises:
            ValueError: if no names are given, the metadata is invalid, a name
                is already registered (or repeated within ``name``), or
                ``group_key`` is already registered.
            TypeError: if a name or ``group_key`` is not a non-empty string.
        """
        # Normalize names to a list
        names = [name] if isinstance(name, str) else list(name)
        if not names:
            raise ValueError("At least one operator name must be provided.")
        if not all(isinstance(n, str) and n for n in names):
            raise TypeError("All operator names must be non-empty strings.")

        # Validate once using the first name (for clearer error context)
        cls._validate_operator_fields(names[0], metadata)

        # Reject duplicates (against the registry and within `names`) before
        # mutating anything, so a failure cannot leave a partial registration.
        pending = set()
        for n in names:
            if n in cls._operators or n in pending:
                raise ValueError(f"Operator '{n}' is already registered.")
            pending.add(n)

        # Validate the group key up front for the same reason.
        if group_key is not None:
            if not isinstance(group_key, str) or not group_key:
                raise TypeError("group_key must be a non-empty string.")
            if group_key in cls._groups:
                raise ValueError(f"Group key '{group_key}' is already registered.")

        # All checks passed: register the same metadata under each name
        for n in names:
            cls._operators[n] = metadata  # same dict reference is fine

        if group_key is not None:
            cls._groups[group_key] = {
                "names": list(names),
                "metadata": metadata,
            }

    @classmethod
    def _validate_operator_fields(cls, name: str, metadata: Dict[str, Any]) -> None:
        """
        Validate the field set and field types of an operator's metadata.

        Rules:
          - every key in REQUIRED_FIELDS must be present;
          - extra keys are allowed only if they start with 'kernel_names_' or
            'kernel_includes_', or are exactly 'kernel_selector';
          - 'testbench' is a list of str; 'dataflow_script' and 'build_script'
            are str;
          - every 'kernel_names*' value is a list of str or a dict of
            str -> int;
          - every 'kernel_includes*' value is a list of str;
          - 'kernel_selector', if present, is a BaseKernelSelector instance.

        Raises:
            ValueError: on the first rule violation found.
        """
        missing_fields = cls.REQUIRED_FIELDS - metadata.keys()
        extra_fields = metadata.keys() - cls.REQUIRED_FIELDS

        if missing_fields:
            raise ValueError(f"Missing required fields for '{name}': {missing_fields}")

        # Allow extra fields only if they are:
        #   - kernel_names_* / kernel_includes_*
        #   - or 'kernel_selector' (callable rule for kernel selection)
        allowed_prefixes = ("kernel_names_", "kernel_includes_")
        allowed_exact = {"kernel_selector"}

        disallowed = {
            f
            for f in extra_fields
            if not (f in allowed_exact or f.startswith(allowed_prefixes))
        }
        if disallowed:
            raise ValueError(
                f"Invalid fields in '{name}': {disallowed}. "
                f"Expected base fields: {cls.REQUIRED_FIELDS} plus optional "
                f"'kernel_names_*' / 'kernel_includes_*' variants and "
                f"'kernel_selector'."
            )

        # Validate data types for required fields
        if not isinstance(metadata["testbench"], list) or not all(
            isinstance(f, str) for f in metadata["testbench"]
        ):
            raise ValueError(f"'testbench' in '{name}' must be a list of strings.")
        if not isinstance(metadata["dataflow_script"], str):
            raise ValueError(f"'dataflow_script' in '{name}' must be a string.")
        if not isinstance(metadata["build_script"], str):
            raise ValueError(f"'build_script' in '{name}' must be a string.")
        # Validate ALL kernel_names* fields
        for key, value in metadata.items():
            if key == "kernel_names" or key.startswith("kernel_names_"):
                if isinstance(value, list):
                    if not all(isinstance(k, str) for k in value):
                        raise ValueError(
                            f"'{key}' in '{name}' must be a list of strings."
                        )
                elif isinstance(value, dict):
                    if not all(isinstance(k, str) for k in value.keys()) or not all(
                        isinstance(v, int) for v in value.values()
                    ):
                        raise ValueError(
                            f"'{key}' in '{name}' must be a dict with string keys "
                            f"and integer values."
                        )
                else:
                    raise ValueError(
                        f"'{key}' in '{name}' must be a list of strings or a dict "
                        f"with string keys and integer values."
                    )

        # Validate ALL kernel_includes* fields
        for key, value in metadata.items():
            if key == "kernel_includes" or key.startswith("kernel_includes_"):
                if not isinstance(value, list) or not all(
                    isinstance(k, str) for k in value
                ):
                    raise ValueError(
                        f"'{key}' in '{name}' must be a list of strings."
                    )

        # Validate optional kernel_selector field (if present)
        if "kernel_selector" in metadata:
            selector = metadata["kernel_selector"]
            if not isinstance(selector, BaseKernelSelector):
                raise ValueError(
                    f"'kernel_selector' in '{name}' must be an instance of "
                    f"BaseKernelSelector (or a subclass). Got {type(selector).__name__!r}."
                )

    @classmethod
    def _select_kernel_field(cls, name: str, field_name: str, attrs: Dict[str, Any]) -> Any:
        """
        Internal helper to select the appropriate kernel field for an operator.

        Args:
            name:       operator name (e.g., 'Conv_qdq_int16xint8xint16')
            field_name: base field name ('kernel_names' or 'kernel_includes')
            attrs:      JSON attributes dict (first value returned by
                        parse_json_to_dict_with_op)

        Behavior:
            1. Fetch operator metadata from the registry.
            2. If a 'kernel_selector' (BaseKernelSelector) is present in the
               metadata, call it to obtain a field name
               (e.g., 'kernel_names_dwc', 'kernel_includes_1').
               - The selector is called as:
                     selector(field_name=field_name, attrs=attrs,
                              operator=name, metadata=op_cfg)
               - It must return a string key. If that key is missing in
                 the metadata, we fall back to the base field.
            3. If there is no selector, simply return the base field.

        Raises:
            KeyError: if the operator is not registered or lacks `field_name`.
            TypeError: if the selector returns a non-string.
        """
        op_cfg = cls.get_operator(name)
        if not op_cfg:
            raise KeyError(f"Operator '{name}' is not registered.")

        if field_name not in op_cfg:
            raise KeyError(
                f"Missing required field '{field_name}' in metadata for operator '{name}'."
            )

        selector = op_cfg.get("kernel_selector")
        if selector is not None:
            # At this point _validate_operator_fields has already enforced
            # that selector is a BaseKernelSelector instance, so it's callable.
            selected_key = selector(
                field_name=field_name,
                attrs=attrs,
                operator=name,
                metadata=op_cfg,
            )
            if not isinstance(selected_key, str):
                raise TypeError(
                    f"'kernel_selector' for '{name}' must return a field name string, "
                    f"got {type(selected_key).__name__!r} instead."
                )
            if selected_key in op_cfg:
                return op_cfg[selected_key]
            # fall back if the suggested field is missing
            return op_cfg[field_name]

        # Default: base field
        return op_cfg[field_name]

    @classmethod
    def get_kernel_names(cls, name: str, attrs: Dict[str, Any]) -> Any:
        """
        Return the appropriate 'kernel_names' value for an operator, taking
        into account JSON attributes and any registered kernel selection rule.

        If the operator has a 'kernel_selector' (BaseKernelSelector), that
        selector decides which metadata key to use (e.g. 'kernel_names',
        'kernel_names_dwc', 'kernel_names_1', etc.). Otherwise the base
        'kernel_names' field is returned.
        """
        return cls._select_kernel_field(name, "kernel_names", attrs)

    @classmethod
    def get_kernel_includes(cls, name: str, attrs: Dict[str, Any]) -> Any:
        """
        Return the appropriate 'kernel_includes' value for an operator, taking
        into account JSON attributes and any registered kernel selection rule.

        If the operator has a 'kernel_selector' (BaseKernelSelector), that
        selector decides which metadata key to use (e.g. 'kernel_includes',
        'kernel_includes_dwc', 'kernel_includes_1', etc.). Otherwise the base
        'kernel_includes' field is returned.
        """
        return cls._select_kernel_field(name, "kernel_includes", attrs)

    @classmethod
    def get_operators(cls) -> Dict[str, Any]:
        """Return the live name -> metadata dict of all operators (not a copy)."""
        return cls._operators

    @classmethod
    def get_operator(cls, name: str) -> Dict[str, Any]:
        """Return a specific operator's metadata, or {} if it is not registered."""
        return cls._operators.get(name, {})

    @classmethod
    def get_operators_by_build_script(cls, build_script: str) -> List[str]:
        """
        Return a list of operator names linked to the given build_script.

        If multiple operator names share one metadata dict (multi-name registration),
        only return the first registered name for that metadata entry.
        """
        seen_metadata = set()
        result = []

        for name, meta in cls._operators.items():
            if meta.get("build_script") == build_script:
                meta_id = id(meta)
                # ensure we only return the first registered name for each metadata group
                if meta_id not in seen_metadata:
                    seen_metadata.add(meta_id)
                    result.append(name)

        return result

    @classmethod
    def get_group(cls, key: str) -> Dict[str, Any]:
        """
        Return a dict for the given group key containing:
          - 'operator_names': list of operator names
          - all metadata fields (testbench, build_script, etc.)
        Returns {} if the key is not found.
        """
        group = cls._groups.get(key)
        if not group:
            return {}
        # Merge operator names with metadata into a single dict
        result = {"operator_names": list(group["names"])}
        result.update(group["metadata"])
        return result

    @classmethod
    def get_group_operator_names(cls, key: str) -> List[str]:
        """Return all operator names registered under a given group key ([] if unknown)."""
        group = cls._groups.get(key)
        return list(group["names"]) if group else []

    @classmethod
    def get_group_field(cls, key: str, field_name: str) -> Any:
        """
        Convenience accessor:
        - field_name == 'operator_names' -> returns list of names
        - any other field_name -> returns metadata[field_name], or None if missing
        """
        if field_name == "operator_names":
            return cls.get_group_operator_names(key)
        group = cls._groups.get(key)
        if not group:
            return None
        return group["metadata"].get(field_name)


class OpRegistryGroupKey(Enum):
    """
    Enum representing the valid group keys registered in OperatorsRegistry.

    These keys correspond to operator groups such as Conv_A16,
    MatMul_A16W4, and Broadcast_A16, and are used to query or organize
    operator sets throughout the build pipeline.
    """
    CONV_A16 = "conv_a16"
    CONV_DWC_A16 = "conv_dwc_a16"
    MATMUL_A16W4 = "matmul_a16w4"
    MATMUL_A16W8 = "matmul_a16w8"
    #  can only register one key per operator group,
    # need multiple add_operator calls for bdcast
    BDCAST_ADD_A16 = "bdcast_add_a16"
    BDCAST_ADD_A8 = "bdcast_add_a8"
    BDCAST_MUL_A16 = "bdcast_mul_a16"
    BDCAST_MUL_A8 = "bdcast_mul_a8"
    BDCAST_SUB_A16 = "bdcast_sub_a16"
    BDCAST_SUB_A8 = "bdcast_sub_a8"
    BDCAST_DIV_A16 = "bdcast_div_a16"
    BDCAST_DIV_A8 = "bdcast_div_a8"

    ACTXACT_A16 = "actxact_a16"
    LP_NORM = "lp_norm"
    GP_NORM = "gp_norm"
    MUL = "ele_wise_mul"
    Q = "quant"
    DQ = "dequant"


def save_cfg_json(cfg: dict, path: str) -> None:
    """
    Write ``cfg`` to ``path`` as JSON (2-space indent, keys sorted, UTF-8).

    The document is serialized before the file is opened. If serialization
    fails, any existing file at ``path`` is therefore left untouched instead of
    being truncated or half-written.

    Raises:
        TypeError / ValueError: if ``cfg`` cannot be serialized by ``json``
            (e.g. non-serializable values or circular references).
        OSError: if ``path`` cannot be opened for writing.
    """
    text = json.dumps(cfg, indent=2, sort_keys=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)


def normalize_shape(shape):
    """
    Coerce a tensor shape into a 4-tuple laid out as (N, Y, X, C).

    Accepts a bare int (treated as a 1-D shape) or any iterable of dims.
    A 5-D shape is accepted only when its leading dim is 1, which is dropped.
    Shorter shapes are left-padded with 1s.

    Raises:
        ValueError: if ``shape`` is None, is 5-D with a leading dim other
            than 1, or has more than 5 dims.
    """
    if shape is None:
        raise ValueError("Input/Output shape not found in JSON")

    dims = [shape] if isinstance(shape, int) else list(shape)

    if len(dims) == 5:
        leading, *rest = dims
        if leading == 1:
            dims = rest
        else:
            raise ValueError(
                f"Unsupported 5D tensor with outermost dimension {leading} != 1"
            )

    if len(dims) > 4:
        raise ValueError(
            f"Unsupported shape {dims}: only up to 4D (or 5D with batch=1) supported"
        )

    missing = 4 - len(dims)
    return tuple([1] * missing + dims)


@dataclass(frozen=True)
class BaseOp:
    """
    BaseOp holds an operator's shapes, L2/L3 allocations, dataflow choice,
    I/O dtypes/signs, read-bin policy, weight-formatting config and DMA
    padding settings.

    The dataclass is frozen, so fields cannot be reassigned after construction.
    Use ``dataclasses.replace`` to derive a modified copy.
    NOTE(review): frozen dataclasses hash their fields, so ``hash()`` on an
    instance raises TypeError whenever a field is unhashable, e.g. list-valued
    input dims or ``wgt_fmt`` (WGTFormatting is a non-frozen dataclass).
    """

    # Shape: input dims accept an int or a list of ints (presumably one value
    # per input tensor for multi-input ops -- TODO confirm); output dims are
    # single ints.
    Ni: int | list[int]
    Yi: int | list[int]
    Xi: int | list[int]
    Ci: int | list[int]
    No: int
    Yo: int
    Xo: int
    Co: int

    # Allocations
    L2: L2Alloc
    L3: L3Alloc

    # 0 -> L2 dataflow, 1 -> L3 dataflow (same encoding as
    # ScheduleInputs.dataflow_type, which OpBuild.run_op dispatches on)
    dataflow_type: int

    # Chained DI
    read_bins: ReadBins

    # DType/Sign metadata
    sign_A: int
    sign_W: int
    sign_O: int
    dtype_A: int
    dtype_W: int
    dtype_O: int

    # Runner Specific Config.
    # WGTFormatting is defined later in this module; the forward reference
    # works because annotations are lazy (`from __future__ import annotations`).
    wgt_fmt: WGTFormatting

    # DMA Padding Value and whether DMA padding is enabled (presumably the pair
    # produced by OpBuild._get_dma_pad -- confirm in the op parsers)
    pad_value: int | float
    is_dma_pad: bool


@dataclass
class ScheduleInputs:
    """
    Standard input bundle for an op's compile_dataflow hooks
    (OpBuild.L2_schedule / OpBuild.L3_schedule).

    Built by OpBuild.tiler(). OpBuild.run_op then dispatches on
    ``dataflow_type``: 0 -> L2_schedule, 1 -> L3_schedule; any other value
    raises.
    """

    shape: Any
    mapping: Any
    dataflow_type: int  # 0 -> L2 dataflow, 1 -> L3 dataflow
    L2_alloc: Optional[Any]
    L3_alloc: Optional[Dict[str, Any]]

    # Scheduling metadata. OpBuild.run_op (via _finalize_schedule_inputs)
    # always overwrites backend, kernel_names, kernel_includes and
    # layer_file_name with its resolved values, so anything a tiler sets for
    # those four fields is replaced. dma_pad is left as the tiler set it.
    dma_pad: DmaPaddingMap = field(default_factory=DmaPaddingMap)
    backend: Optional[BackEnd] = BackEnd.Adf
    kernel_names: Optional[Dict[str, int]] = None
    kernel_includes: Optional[List[str]] = None
    layer_file_name: Optional[str] = "dma.hpp"


# Type variable for the per-op dataclass that OpBuild._parse_from_dict returns
# and the other OpBuild hooks (shape, tiler, preproc) receive as `op_class`.
P = TypeVar("P")  # per-op dataclass type


class OpBuild(ABC):
    """
    Generic orchestrator for an operator's build flow.

    Subclasses implement the abstract hooks. :meth:`run_op` drives them in
    this order: parse -> shape -> tiler -> finalize scheduling metadata ->
    L2_schedule / L3_schedule -> preproc.
    """

    @staticmethod
    def _get_tile_offset(dct: dict) -> tuple[memory_tile, int]:
        """
        Coerce a single-entry ``{'<tile>': <offset>}`` dict into
        ``(memory_tile(tile), offset)``.

        Raises:
            ValueError: if ``dct`` does not have exactly one entry (raised by
                the tuple unpacking).
        """
        ((tile, offset),) = dct.items()
        return memory_tile(int(tile)), int(offset)

    @staticmethod
    def _get_wgt_addr(wgt_addr) -> list[list[int]]:
        """Coerce [[tile, [ping, pong]], ...] → [[memory_tile(tile), ping, pong], ...] (ints)."""
        return [[memory_tile(int(t)), int(p), int(q)] for t, (p, q) in wgt_addr]

    @staticmethod
    def _get_prm_addr(prm_addr) -> list[list[int]]:
        """Coerce [[tile, off], ...] → [[memory_tile(tile), off], ...] (ints)."""
        return [[memory_tile(int(t)), int(off)] for t, off in prm_addr]

    @staticmethod
    def _get_dma_pad(node_dict: dict, is_model_data: bool) -> tuple[int, bool]:
        """
        Extract the DMA padding value from
        ``node_dict["attributes"]["const_padding_value"]``.

        Accepted raw values: missing/None, the string "None" (any case), a
        scalar (int, float, str), or a sequence. For a sequence only the first
        element is used, and an empty sequence counts as missing.

        Returns:
            (pad_value, is_dma_pad):
              - (0, False) if no padding value is present, or if
                ``is_model_data`` is True and the value cannot be converted
                with ``int()``;
              - (int(value), True) if ``is_model_data`` is True;
              - (0, True) if ``is_model_data`` is False (a value is present,
                but the number itself is not needed on the metadata side).
        """
        attribute = (node_dict or {}).get("attributes") or {}
        raw = attribute.get("const_padding_value")

        # Missing key or explicit None
        if raw is None:
            return 0, False

        # Scalars (floats included) are used directly. Previously a float fell
        # through to `raw[0]` and raised TypeError. Anything else is treated as
        # a sequence; an empty one previously raised IndexError.
        if isinstance(raw, (int, float, str)):
            val = raw
        else:
            val = raw[0] if len(raw) > 0 else None

        # First element is None, or the string "None" (any case)
        if val is None:
            return 0, False
        if isinstance(val, str) and val.lower() == "none":
            return 0, False

        # Model-data side: the actual integer value is needed
        if is_model_data:
            try:
                return int(val), True
            except (TypeError, ValueError, OverflowError):
                return 0, False

        # Metadata side: padding is enabled, value not needed
        return 0, True

    @abstractmethod
    def default_kernel_names(self) -> Dict[str, int]:
        """Return the op's default kernel_names map, used when run_op receives None."""
        return {}

    @abstractmethod
    def default_kernel_includes(self) -> List[str]:
        """Return the op's default kernel_includes list, used when run_op receives None."""
        return []

    def default_layer_file_name(self, backend: BackEnd) -> str:
        """Return "aie4_dma.cpp" for BackEnd.CertAsm, otherwise "dma.hpp"."""
        return "aie4_dma.cpp" if backend == BackEnd.CertAsm else "dma.hpp"

    @abstractmethod
    def op_type(self) -> Type[P]:
        """Return the op-specific dataclass type (what _parse_from_dict builds)."""
        raise NotImplementedError

    @abstractmethod
    def _parse_from_dict(
        self,
        data: dict,
        shim_prm_offset: int,
        shim_wgt_offset: int,
        read_bins: ReadBins,
        read_model_data: bool,
        model_data_path: str,
    ) -> P:
        """Convert an operator JSON dict into the op-specific dataclass."""
        raise NotImplementedError

    @abstractmethod
    def shape(self, op_class: P) -> Any:
        """Return the op's shape object (passed on to ``tiler``)."""
        raise NotImplementedError

    @abstractmethod
    def tiler(self, dims_shape: Any, op_class: P) -> ScheduleInputs:
        """Return the op's tiling and allocation configuration."""
        raise NotImplementedError

    @abstractmethod
    def L2_schedule(self, schedule_input: ScheduleInputs) -> Tuple[int, int]:
        """Compile the L2 dataflow (dataflow_type == 0); return (next_prm, next_wgt)."""
        raise NotImplementedError

    @abstractmethod
    def L3_schedule(self, schedule_input: ScheduleInputs) -> Tuple[int, int]:
        """Compile the L3 dataflow (dataflow_type == 1); return (next_prm, next_wgt)."""
        raise NotImplementedError

    @abstractmethod
    def preproc(self, schedule_input: ScheduleInputs, op_class: P) -> None:
        """Save the op's preproc directives as JSON (runs after scheduling)."""
        raise NotImplementedError

    def _finalize_schedule_inputs(
        self,
        schedule_input: ScheduleInputs,
        backend: BackEnd,
        kernel_names: Dict[str, int],
        kernel_includes: List[str],
        layer_file_name: str,
    ) -> ScheduleInputs:
        """
        Return a copy of ``schedule_input`` with backend, kernel_names,
        kernel_includes and layer_file_name set to the given values.

        These four fields are always overwritten, so values the tiler put
        there are not preserved. run_op resolves defaults before calling this.
        All other fields (including dma_pad) are kept as-is.
        """
        return replace(
            schedule_input,
            backend=backend,
            kernel_names=kernel_names,
            kernel_includes=kernel_includes,
            layer_file_name=layer_file_name,
        )

    def run_op(
        self,
        data: Dict,
        backend: BackEnd = BackEnd.Adf,
        kernel_names: Optional[Dict[str, int]] = None,
        kernel_includes: Optional[List[str]] = None,
        shim_prm_offset: int = 0,
        shim_wgt_offset: int = 0,
        layer_file_name: Optional[str] = None,
        read_bins: ReadBins = ReadBins(0, 0),
        read_model_data: bool = False,
        model_data_path: str = "",
    ) -> Tuple[int, int]:
        """
        Run the full build flow for one operator.

        Args:
            data: operator JSON dict (parsed via ``_parse_from_dict``) or an
                already-built instance of ``self.op_type()``.
            backend: code-generation backend; also picks the default layer
                file name.
            kernel_names, kernel_includes, layer_file_name: override the op's
                defaults when not None.
            shim_prm_offset, shim_wgt_offset, read_bins, read_model_data,
            model_data_path: forwarded to ``_parse_from_dict``.

        Returns:
            The (next_prm, next_wgt) pair from the selected schedule hook.

        Raises:
            TypeError: if ``data`` is neither a dict nor an ``op_type()`` instance.
            AssertionError: if the tiler's dataflow_type is not 0 or 1.
        """
        # Defaults
        if kernel_names is None:
            kernel_names = self.default_kernel_names()
        if kernel_includes is None:
            kernel_includes = self.default_kernel_includes()
        if layer_file_name is None:
            layer_file_name = self.default_layer_file_name(backend)

        # Parse (or accept an already-parsed op dataclass)
        if isinstance(data, dict):
            op_class: P = self._parse_from_dict(
                data,
                shim_prm_offset,
                shim_wgt_offset,
                read_bins,
                read_model_data,
                model_data_path,
            )
        else:
            expected_type = self.op_type()
            if not isinstance(data, expected_type):
                expected_name = getattr(expected_type, "__name__", repr(expected_type))
                raise TypeError(
                    f"Expected `data` to be a dict or a {expected_name} instance, "
                    f"got {type(data).__name__!r}."
                )
            op_class = data

        # Shape + tiler
        dims_shape = self.shape(op_class)
        schedule_input = self.tiler(dims_shape, op_class)

        # Finalize scheduling metadata into ScheduleInputs
        schedule_input = self._finalize_schedule_inputs(
            schedule_input, backend, kernel_names, kernel_includes, layer_file_name
        )

        # Dispatch by dataflow
        if schedule_input.dataflow_type == 0:
            next_prm, next_wgt = self.L2_schedule(schedule_input)
        elif schedule_input.dataflow_type == 1:
            next_prm, next_wgt = self.L3_schedule(schedule_input)
        else:
            raise AssertionError(
                f"Unsupported dataflow type: {schedule_input.dataflow_type}"
            )

        # Preproc
        self.preproc(schedule_input, op_class)
        return next_prm, next_wgt


@dataclass
class WGTFormatting:
    """
    Fields required for WGT formatting of one node.

    ``node_name``, ``model_data_path`` and ``read_model_data`` are required.
    Each ``dtype_*`` string is optional and defaults to the literal
    placeholder "Missing_Value" when not supplied.
    """

    node_name: str
    model_data_path: str
    read_model_data: bool
    # Per-tensor dtype strings, one per role suffix (act, wgt, bias, ofm,
    # gamma, beta); "Missing_Value" means the caller did not provide it.
    dtype_act: str = "Missing_Value"
    dtype_wgt: str = "Missing_Value"
    dtype_bias: str = "Missing_Value"
    dtype_ofm: str = "Missing_Value"
    dtype_gamma: str = "Missing_Value"
    dtype_beta: str = "Missing_Value"


def dtype_info(dtype: str) -> tuple[int, int]:
    """
    Return ``(num_bits, sign)`` for a dtype name such as "int8", "uint16",
    "float", "bfloat16" or "float8e4m3fn" (case-insensitive, surrounding
    whitespace ignored).

    - ``num_bits`` is the first run of digits in the name, so "float8e4m3fn"
      is 8 bits. With no digits it defaults to 32, except "bool" (8),
      "double" (64) and "half" (16).
    - ``sign`` is 1 for int*, float*, bfloat*, complex*, double and half,
      and 0 for uint*, bool and anything unrecognized.
    """
    name = dtype.strip().lower()

    # Digit-less names whose width is not the 32-bit default
    special = {"bool": (8, 0), "double": (64, 1), "half": (16, 1)}
    if name in special:
        return special[name]

    # Use only the first digit run: joining every digit turned names such as
    # "float8e4m3fn" into 843 bits.
    match = re.search(r"\d+", name)
    bits = int(match.group()) if match else 32

    if name.startswith("uint"):
        sign = 0
    elif name.startswith(("int", "float", "bfloat", "complex")):
        sign = 1
    else:
        sign = 0

    return (bits, sign)


# NOTE: For AIE4, all conv / GEMM kernels with QDQ support only float data-type
# coefficients. DTYPE_C0 presumably holds their bit width (32, i.e. float32);
# TODO confirm against the kernels that consume it.
DTYPE_C0: int = 32


def bytes_to_bits(num_bytes: int) -> int:
    """
    Return how many bits ``num_bytes`` bytes hold (8 bits per byte).

    Raises:
        ValueError: if ``num_bytes`` is negative.
    """
    bits_per_byte = 8
    if num_bytes < 0:
        raise ValueError("Number of bytes cannot be negative")
    total_bits = num_bytes * bits_per_byte
    return total_bits


def bits_to_bytes(num_bits: int) -> int:
    """
    Return the number of whole bytes needed to hold ``num_bits`` bits.

    Partial bytes round up, e.g. 9 bits -> 2 bytes, 8 bits -> 1 byte.

    Raises:
        ValueError: if ``num_bits`` is negative.
    """
    bits_per_byte = 8
    if num_bits < 0:
        raise ValueError("Number of bits cannot be negative")
    # Adding (bits_per_byte - 1) before floor division gives the ceiling.
    padded_bits = num_bits + (bits_per_byte - 1)
    return padded_bits // bits_per_byte
