# fmt: on
from dataclasses import dataclass
import re
from typing import Any, ClassVar, Dict, Optional


@dataclass
class KernelMetadataLoader:
    """Accessors over a kernel metadata registry shared at class level.

    The registry is the class attribute ``_kernel_dict``; ``load_dict`` merges
    entries into it and every instance method reads from it, so all instances
    see the same data. Entries are looked up by op type name; the special key
    ``"dtype_downcast_map"`` is read by ``get_hardware_datatype``.
    """

    # Shared registry: populated via load_dict, read by all getters.
    _kernel_dict: ClassVar[Dict[str, Any]] = {}

    # FIXME: remove when we have information about the native dtype for these
    # ops in the Dataflow kernel file. Any op type starting with one of these
    # prefixes is reported as ["any"]. ("GatherElements" is already covered by
    # the "Gather" prefix; it is listed explicitly for clarity.)
    _ANY_DTYPE_OP_PREFIXES: ClassVar[tuple[str, ...]] = (
        "Reshape",
        "Flatten",
        "Gather",
        "GatherElements",
        "Split",
        "Squeeze",
        "Unsqueeze",
        "ReduceSum",
    )

    # Captures a numeric index directly after a leading "Concat" (and an
    # optional "_"), so "Concat3_fp" can be normalized to "Concat_fp".
    _CONCAT_INDEX_RE: ClassVar[re.Pattern] = re.compile(r"^(Concat)(\d+)(_?)")

    @classmethod
    def load_dict(cls, kernel_dict: Dict[str, Any]) -> None:
        """Merge ``kernel_dict`` into the shared registry.

        Existing keys are overwritten (``dict.update`` semantics); the merged
        data is visible through every instance.
        """
        cls._kernel_dict.update(kernel_dict)

    def get_kernel_dict(self) -> Dict[str, Any]:
        """Return the shared registry itself (not a copy)."""
        return self._kernel_dict

    def get_kernel_for_op(self, op_type: str) -> Optional[dict]:
        """
        Retrieves the kernel stored under exactly ``op_type``, or None if absent.
        """
        return self._kernel_dict.get(op_type)

    def get_native_dtype(
        self, op_type: str, default: Any = None
    ) -> Optional[list[str]]:
        """
        Get the native dtype for ``op_type`` from the kernel metadata registry.

        Lookup order:
          1. Op types starting with a prefix in ``_ANY_DTYPE_OP_PREFIXES``
             return ``["any"]``.
          2. A numeric index right after a leading "Concat" is removed
             (e.g. "Concat3_fp" -> "Concat_fp").
          3. Exact key lookup. If that yields None, the first key (in registry
             insertion order) that contains ``op_type`` as a substring and maps
             to a dict is used instead.
          4. The matched dict's "native_dtype" value is returned; ``default``
             is returned when no dict entry was found or it lacks that key.

        Args:
            op_type: Op type name to look up.
            default: Fallback value; ``None`` means ``["Not Found"]``.

        Returns:
            The stored "native_dtype" value (the registry's own object, not a
            copy, so callers should not mutate it), ``["any"]``, or ``default``.
        """
        if default is None:
            default = ["Not Found"]

        # str.startswith accepts a tuple of prefixes: true if any one matches.
        if op_type.startswith(self._ANY_DTYPE_OP_PREFIXES):
            return ["any"]

        if "Concat" in op_type:
            op_type = self._CONCAT_INDEX_RE.sub(r"\1\3", op_type)

        op_kernel = self._kernel_dict.get(op_type)
        if op_kernel is None:
            # Fall back to a partial (substring) match on registry keys,
            # considering only dict-valued entries.
            op_kernel = next(
                (
                    value
                    for key, value in self._kernel_dict.items()
                    if op_type in key and isinstance(value, dict)
                ),
                None,
            )

        if isinstance(op_kernel, dict):
            return op_kernel.get("native_dtype", default)
        return default

    def get_qdq_selector(self, op_type: str) -> Optional[dict]:
        """
        Retrieves the 'qdq_selector' value from ``op_type``'s "kernel_param".

        Returns:
            None if ``op_type`` has no entry, the entry is empty/falsy, or the
            entry is not a dict; ``{}`` if the entry has no dict-valued
            "kernel_param" or that dict has no "qdq_selector"; otherwise the
            stored "qdq_selector" value.
        """
        kernel = self._kernel_dict.get(op_type)
        # Guard non-dict entries the same way get_native_dtype does, instead of
        # failing with AttributeError on .get().
        if not kernel or not isinstance(kernel, dict):
            return None
        kernel_param = kernel.get("kernel_param")
        # A missing, null, or non-dict "kernel_param" carries no selector.
        if not isinstance(kernel_param, dict):
            return {}
        return kernel_param.get("qdq_selector", {})

    def get_hardware_datatype(self, dtype: str) -> str:
        """
        Return the hardware data type (downcast) for ``dtype``.

        Looks ``dtype`` up in the registry's "dtype_downcast_map" entry. If the
        map is absent or not a dict, or ``dtype`` is not in it, ``dtype`` is
        returned unchanged.
        Example, assuming the loaded map contains
        {"int64": "int16", "float32": "bfloat16"}:
          get_hardware_datatype("int64") -> "int16"
          get_hardware_datatype("float32") -> "bfloat16"
          get_hardware_datatype("uint16") -> "uint16"
        """
        mapping = self._kernel_dict.get("dtype_downcast_map")
        # Tolerate a null/non-dict map rather than crashing on .get().
        if not isinstance(mapping, dict):
            return dtype
        return mapping.get(dtype, dtype)
