"""
Load kernel metadata as class objects which can be
used by the tiler/buffer allocator.
"""

from __future__ import annotations
from enum import Enum
from dataclasses import dataclass
import os
import ast
from graphlib import TopologicalSorter, CycleError
from typing import Any, Callable, List, Dict, Type, Union, TYPE_CHECKING, Optional, Set
from abc import ABC, abstractmethod
from functools import lru_cache
from operator import attrgetter
from pydantic import BaseModel, Field, ConfigDict
import json5
from cachetools import LRUCache, cachedmethod
import tiler.kernel_metadata.build_kernel_metadata_vars as kmv  # kernel metadata vars
from tiler.base_tiler import npeval, _eval_syms


if TYPE_CHECKING:
    # Type checkers see the full recursive definition
    MetadataValueType = str | List[str] | List[int] | Dict[str, "MetadataValueType"]
else:
    # beartype doesn't support recursive types: https://github.com/beartype/beartype/issues/364
    # Runtime alias: the recursive Dict[str, MetadataValueType] member is loosened
    # to a plain ``dict`` so no self-reference is needed at runtime.
    MetadataValueType = Union[str, List[str], List[int], dict]

# Result types KernelMetadataLoader.eval_expr() accepts (checked with isinstance there)
EvalExprReturnType = Union[bool, int, float]  # expected value of kernel expr eval


class OpType(str, Enum):
    """Kernel operator types.

    Each member's value is the op-name string. Because the enum mixes in
    ``str``, members compare equal to their raw string values.
    """
    # Conv / matmul kernels (each mapped to its own metadata JSON)
    BIASED_CONV_INT8X8 = "biased_conv_int8x8"
    ACTIVATED_MMULT_QDQ_INT16X8 = "activated_mmult_qdq_int16x8"
    ACTIVATED_MMULT_QDQ_INT16X16 = "activated_mmult_qdq_int16x16"
    ACTIVATED_CONV_QDQ_INT16X8 = "activated_conv_qdq_int16x8"
    # BroadcastShape has the op_name which we use to determine the loader
    # Broadcast binary ops: "<Op>_qdq_BroadCast_<dtypes>"
    BDCAST_ADD2D_UINT8 = "Add_qdq_BroadCast_uint8xuint8xuint8"
    BDCAST_ADD2D_INT8 = "Add_qdq_BroadCast_int8xint8xint8"
    BDCAST_ADD2D_UINT16 = "Add_qdq_BroadCast_uint16xuint16xuint16"
    BDCAST_ADD2D_INT16 = "Add_qdq_BroadCast_int16xint16xint16"
    BDCAST_MUL2D_UINT16 = "Mul_qdq_BroadCast_uint16xuint16xuint16"
    BDCAST_MUL2D_INT16 = "Mul_qdq_BroadCast_int16xint16xint16"
    BDCAST_MUL2D_UINT8 = "Mul_qdq_BroadCast_uint8xuint8xuint8"
    BDCAST_MUL2D_INT8 = "Mul_qdq_BroadCast_int8xint8xint8"
    BDCAST_SUB2D_UINT8 = "Sub_qdq_BroadCast_uint8xuint8xuint8"
    BDCAST_SUB2D_INT8 = "Sub_qdq_BroadCast_int8xint8xint8"
    BDCAST_SUB2D_UINT16 = "Sub_qdq_BroadCast_uint16xuint16xuint16"
    BDCAST_SUB2D_INT16 = "Sub_qdq_BroadCast_int16xint16xint16"
    BDCAST_DIV2D_UINT16 = "Div_qdq_BroadCast_uint16xuint16xuint16"
    BDCAST_DIV2D_INT16 = "Div_qdq_BroadCast_int16xint16xint16"
    BDCAST_DIV2D_UINT8 = "Div_qdq_BroadCast_uint8xuint8xuint8"
    BDCAST_DIV2D_INT8 = "Div_qdq_BroadCast_int8xint8xint8"
    # Elementwise binary ops: "<Op>_qdq_EleWise_<dtypes>"
    EWISE_ADD2D_UINT8 = "Add_qdq_EleWise_uint8xuint8xuint8"
    EWISE_ADD2D_INT8 = "Add_qdq_EleWise_int8xint8xint8"
    EWISE_ADD2D_UINT16 = "Add_qdq_EleWise_uint16xuint16xuint16"
    EWISE_ADD2D_INT16 = "Add_qdq_EleWise_int16xint16xint16"
    EWISE_MUL2D_UINT16 = "Mul_qdq_EleWise_uint16xuint16xuint16"
    EWISE_MUL2D_INT16 = "Mul_qdq_EleWise_int16xint16xint16"
    EWISE_MUL2D_UINT8 = "Mul_qdq_EleWise_uint8xuint8xuint8"
    EWISE_MUL2D_INT8 = "Mul_qdq_EleWise_int8xint8xint8"
    EWISE_SUB2D_UINT8 = "Sub_qdq_EleWise_uint8xuint8xuint8"
    EWISE_SUB2D_INT8 = "Sub_qdq_EleWise_int8xint8xint8"
    EWISE_SUB2D_UINT16 = "Sub_qdq_EleWise_uint16xuint16xuint16"
    EWISE_SUB2D_INT16 = "Sub_qdq_EleWise_int16xint16xint16"
    EWISE_DIV2D_UINT8 = "Div_qdq_EleWise_uint8xuint8xuint8"
    EWISE_DIV2D_INT8 = "Div_qdq_EleWise_int8xint8xint8"
    EWISE_DIV2D_UINT16 = "Div_qdq_EleWise_uint16xuint16xuint16"
    EWISE_DIV2D_INT16 = "Div_qdq_EleWise_int16xint16xint16"


@dataclass(frozen=True)
class OpMetadataConfig:
    """Configuration for kernel operator metadata.

    Immutable (frozen) pairing of a metadata JSON file name with the params
    class that describes the kernel's parameters.
    """

    # File name of the metadata JSON; KernelMetadataLoader.load() resolves it
    # under $AIE4_ROOT_DIR/tiler/kernel_metadata/
    json_file: str
    # Params class recorded by load() and returned by get_metadata()
    param_class: Type[kmv.KernelBaseParams]


# Structured mapping of operator types to their metadata configuration.
# Conv/matmul kernels each get a dedicated JSON file here; the broadcast and
# elementwise binary ops are registered below from the grouped op lists.
op_to_metadata_mapping: Dict[OpType, OpMetadataConfig] = {
    OpType.ACTIVATED_MMULT_QDQ_INT16X8: OpMetadataConfig(
        json_file="activated_mmult_qdq_int16x8.json",
        param_class=kmv.mmult_qdq_int16x8_params,
    ),
    OpType.BIASED_CONV_INT8X8: OpMetadataConfig(
        json_file="biased_conv_int8x8.json", param_class=kmv.conv_int8x8_params
    ),
    # NOTE(review): the int16x16 mmult reuses the int16x8 params class — confirm intended.
    OpType.ACTIVATED_MMULT_QDQ_INT16X16: OpMetadataConfig(
        json_file="activated_mmult_qdq_int16x16.json",
        param_class=kmv.mmult_qdq_int16x8_params,
    ),
    OpType.ACTIVATED_CONV_QDQ_INT16X8: OpMetadataConfig(
        json_file="activated_conv_qdq_int16x8.json",
        param_class=kmv.conv_qdq_int16x8_params,
    ),
}

# Add and Sub ops (broadcast + elementwise, all dtypes) share add2d_bf16x16.json
add2d_bf16x16_ops = [
    OpType.BDCAST_ADD2D_UINT8,
    OpType.BDCAST_ADD2D_INT8,
    OpType.BDCAST_ADD2D_UINT16,
    OpType.BDCAST_ADD2D_INT16,
    OpType.BDCAST_SUB2D_UINT8,
    OpType.BDCAST_SUB2D_INT8,
    OpType.BDCAST_SUB2D_UINT16,
    OpType.BDCAST_SUB2D_INT16,
    OpType.EWISE_ADD2D_UINT8,
    OpType.EWISE_ADD2D_INT8,
    OpType.EWISE_ADD2D_UINT16,
    OpType.EWISE_ADD2D_INT16,
    OpType.EWISE_SUB2D_UINT8,
    OpType.EWISE_SUB2D_INT8,
    OpType.EWISE_SUB2D_UINT16,
    OpType.EWISE_SUB2D_INT16,
]
# Mul ops (broadcast + elementwise) share mul_bf16x16.json
mul2d_bf16x16_ops = [
    OpType.BDCAST_MUL2D_UINT16,
    OpType.BDCAST_MUL2D_INT16,
    OpType.BDCAST_MUL2D_UINT8,
    OpType.BDCAST_MUL2D_INT8,
    OpType.EWISE_MUL2D_UINT16,
    OpType.EWISE_MUL2D_INT16,
    OpType.EWISE_MUL2D_UINT8,
    OpType.EWISE_MUL2D_INT8,
]

# Div ops (broadcast + elementwise) share div_bf16x16.json
div2d_bf16x16_ops = [
    OpType.BDCAST_DIV2D_UINT16,
    OpType.BDCAST_DIV2D_INT16,
    OpType.BDCAST_DIV2D_UINT8,
    OpType.BDCAST_DIV2D_INT8,
    OpType.EWISE_DIV2D_UINT16,
    OpType.EWISE_DIV2D_INT16,
    OpType.EWISE_DIV2D_UINT8,
    OpType.EWISE_DIV2D_INT8,
]

# Register every broadcast/elementwise op with Broadcast2DParams. A comprehension
# is used so no loop variable leaks into the module namespace; insertion order
# (add/sub ops, then mul ops, then div ops) matches the lists above.
op_to_metadata_mapping.update(
    {
        op: OpMetadataConfig(json_file=json_file, param_class=kmv.Broadcast2DParams)
        for ops, json_file in (
            (add2d_bf16x16_ops, "add2d_bf16x16.json"),
            (mul2d_bf16x16_ops, "mul_bf16x16.json"),
            (div2d_bf16x16_ops, "div_bf16x16.json"),
        )
        for op in ops
    }
)


# Module-level cached function for loading JSON metadata
@lru_cache(maxsize=128)
def _load_json_metadata(metadata_path: str) -> Dict[str, Any]:
    """
    Load and cache JSON metadata from file.

    Results are memoized per process by ``functools.lru_cache``, keyed on the
    path string, so repeated loads of the same file skip disk I/O and parsing.
    ``lru_cache`` keeps its own bookkeeping consistent when used from several
    threads, but it does NOT serialize concurrent first calls: two threads
    that miss the cache for the same path at the same time may both read the
    file. The cache is not shared between processes.

    The same dict object is returned to every caller for a given path, so
    callers must treat it as read-only. Mutating it would change the cached
    data seen by all later loads.

    Args:
        metadata_path: Path to the JSON5 metadata file

    Returns:
        Parsed JSON data as dictionary

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError), plus
            whatever ``json5.load`` raises on malformed content.

    Example:
        >>> import os
        >>> path = os.path.expandvars("$AIE4_ROOT_DIR/tiler/kernel_metadata/activated_mmult_qdq_int16x8.json")
        >>> _load_json_metadata.cache_clear()
        >>> data1 = _load_json_metadata(path)
        >>> data2 = _load_json_metadata(path)
        >>> data1 == data2  # Same content
        True
        >>> _load_json_metadata.cache_info().currsize  # Cached after first load
        1
    """
    with open(metadata_path, encoding="utf-8") as f:
        return json5.load(f)


class BaseOpParserCfg(BaseModel):
    """Config class used in BaseOpParser.__init__()

    Carries one kernel's parsed metadata document and the op type it belongs to.
    """

    # NOTE - Need to use "Any" because of json dict having unclear valuetype
    json_data: Dict[str, Any]  # store metadata per field
    # KernelMetadataLoader.load() passes an OpType member here (a str subclass)
    op_type: str  # op_name


class BaseOpParser(ABC):
    """Abstract base for kernel metadata parsers.

    Holds one kernel's parsed JSON document and exposes a getter per
    top-level section. Subclasses decide where the requirement sub-sections
    (granularity, minimums, bank placements) are located.
    """

    def __init__(self, cfg: BaseOpParserCfg):
        """Keep the JSON document and op type carried by ``cfg``."""
        self.json_data = cfg.json_data
        self.op_type = cfg.op_type

    @abstractmethod
    def get_granularity(self) -> Dict[str, MetadataValueType]:
        """Return the kernel's granularity requirements (subclass-defined)."""

    @abstractmethod
    def get_minimums(self) -> Dict[str, MetadataValueType]:
        """Return the kernel's minimums requirements (subclass-defined)."""

    @abstractmethod
    def get_bank_placements(self) -> Dict[str, MetadataValueType]:
        """Return the kernel's bank placement requirements (subclass-defined)."""

    def __get_field__(self, field: str) -> Dict[str, Any] | MetadataValueType:
        """Look up ``field`` in the JSON document.

        ``field`` may contain dots to walk nested objects, e.g.
        ``"requirements.granularity"``.

        Raises:
            ValueError: if a path component is absent, or a non-dict value is
                reached before the path is exhausted.
        """
        node = self.json_data
        for key in field.split("."):
            if not isinstance(node, dict) or key not in node:
                raise ValueError(f"{field} field is missing.")
            node = node[key]
        return node

    def get_name(self) -> str:
        """Return the ``name`` section."""
        return self.__get_field__("name")

    def get_templates(self) -> Dict[str, MetadataValueType]:
        """Return the ``templates`` section."""
        return self.__get_field__("templates")

    def get_parameters(self) -> Dict[str, MetadataValueType]:
        """Return the ``parameters`` section."""
        return self.__get_field__("parameters")

    def get_variables(self) -> Dict[str, MetadataValueType]:
        """Return the ``variables`` section."""
        return self.__get_field__("variables")

    def get_requirements(self) -> Dict[str, MetadataValueType]:
        """Return the ``requirements`` section."""
        return self.__get_field__("requirements")

    def get_kernel_interface(self) -> Dict[str, MetadataValueType]:
        """Return the ``kernel_interface`` section."""
        return self.__get_field__("kernel_interface")

    def get_performance(self) -> Dict[str, MetadataValueType]:
        """Return the ``performance`` section."""
        return self.__get_field__("performance")

    def get_kernel_parameter_setup(self) -> Dict[str, MetadataValueType]:
        """Return the ``kernel_parameter_setup`` section."""
        return self.__get_field__("kernel_parameter_setup")


# Default OpParser impl
class OpParser(BaseOpParser):
    """Default kernel metadata parser.

    Reads the requirement sub-sections from the ``requirements`` object of
    the kernel's JSON document.
    """

    def _requirement(self, key: str) -> MetadataValueType:
        """Return ``requirements.<key>`` from the JSON document."""
        return self.__get_field__(".".join(("requirements", key)))

    def get_granularity(self) -> MetadataValueType:
        """Return ``requirements.granularity``."""
        return self._requirement("granularity")

    def get_minimums(self) -> MetadataValueType:
        """Return ``requirements.minimums``."""
        return self._requirement("minimums")

    def get_bank_placements(self) -> MetadataValueType:
        """Return ``requirements.bank_placements``."""
        return self._requirement("bank_placements")


class KernelMetadataLoaderCfg(BaseModel):
    """Cfg class for kernel metadata loader"""

    # Needed for BaseOpParser, OpMetadataConfig and the kmv params classes,
    # which are not pydantic models
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Maps each op type to its metadata JSON file and params class
    op_type_to_json: Dict[OpType, OpMetadataConfig] = op_to_metadata_mapping
    # Cache of loaded parser *instances*, filled by KernelMetadataLoader.load()
    parsers: Dict[OpType, BaseOpParser] = Field(default_factory=dict)
    # Parser class instantiated for every loaded op (parsing method)
    parser_cls: Type[BaseOpParser] = OpParser
    # Cache of params classes, filled by KernelMetadataLoader.load()
    param_cls: Dict[OpType, Type[kmv.KernelBaseParams]] = Field(
        default_factory=dict
    )


class MetadataReturnType(BaseModel):
    """Return type of KernelMetadataLoader.get_metadata()

    Bundles the parser loaded for an op with that op's params class.
    """

    # Needed at least for BaseOpParser, a plain ABC rather than a pydantic model
    model_config = ConfigDict(arbitrary_types_allowed=True)
    parser: BaseOpParser  # parser instance created by KernelMetadataLoader.load()
    param_class: Type[kmv.KernelBaseParams]  # params class itself (a type, not an instance)


class KernelMetadataLoader:
    """
    Kernel metadata loader used by other modules to load and evaluate kernel
    requirements, variables and performance expressions.

    Typical use: call ``load(op_type)`` for each op of interest, then query it
    through ``get_metadata`` / ``eval_expr`` / ``eval_performance``.

    Caching and concurrency:
    - Metadata files are read through the module-level ``_load_json_metadata``,
      which is memoized with ``functools.lru_cache``. Repeated loads of the same
      file within one process reuse the parsed data instead of re-reading it.
    - ``lru_cache`` keeps its own state consistent under threads but does not
      block concurrent first calls, so two threads may each read the same file
      once. The cache is per process and nothing is shared between processes.
    - ``parsers`` and ``param_cls`` are plain dicts mutated by ``load()``
      without any locking.
    """

    def __init__(self, cfg: KernelMetadataLoaderCfg):
        """
        Args:
            cfg: KernelMetadataLoaderCfg with the op -> JSON mapping, the parser
                class to instantiate, and the parser / params-class caches.
        """
        self.op_type_to_json: Dict[OpType, OpMetadataConfig] = cfg.op_type_to_json
        # Holds parser *instances* created by load(), not parser classes
        self.parsers: Dict[OpType, BaseOpParser] = cfg.parsers
        self.param_cls: Dict[OpType, Type[kmv.KernelBaseParams]] = cfg.param_cls
        self.parser_cls: Type[BaseOpParser] = cfg.parser_cls

    def __get_parser(self, op_type: OpType) -> BaseOpParser:
        """Return the loaded parser for ``op_type``.

        Raises:
            ValueError: if ``load(op_type)`` has not been called.
        """
        if op_type not in self.parsers:
            raise ValueError(f"Parser for '{op_type}' not loaded. Call load() first.")
        return self.parsers[op_type]

    def __get_metadata_param_class(self, op_type: OpType) -> Type[kmv.KernelBaseParams]:
        """Return the params class registered for ``op_type``.

        Raises:
            ValueError: if ``load(op_type)`` has not been called.
        """
        if op_type not in self.param_cls:
            raise ValueError(
                f"metadata_param_class for '{op_type}' not loaded. Call load() first."
            )
        return self.param_cls[op_type]

    def load(self, op_type: OpType):
        """
        Load the metadata JSON for ``op_type`` and register its parser and
        params class in ``parsers`` / ``param_cls``.

        The file is resolved as ``$AIE4_ROOT_DIR/tiler/kernel_metadata/<json_file>``
        and read through the lru_cache-memoized ``_load_json_metadata``, so
        repeated loads of the same file in this process skip file I/O. Both
        registries are updated only after the file has been read and the parser
        constructed, so a failed load leaves them unchanged.

        Raises:
            ValueError: if ``op_type`` has no entry in ``op_type_to_json``.
            OSError: if the metadata file cannot be opened. For example, if
                ``$AIE4_ROOT_DIR`` is unset, ``os.path.expandvars`` leaves it
                unexpanded and the resulting path does not exist.
        """
        if op_type not in self.op_type_to_json:
            raise ValueError(f"No JSON file mapped for op_type: {op_type}")
        config = self.op_type_to_json[op_type]
        metadata_path = os.path.expandvars(
            f"$AIE4_ROOT_DIR/tiler/kernel_metadata/{config.json_file}"
        )

        # Use cached JSON loading - repeated loads of the same file skip disk I/O
        data = _load_json_metadata(metadata_path)
        parser = self.parser_cls(BaseOpParserCfg(json_data=data, op_type=op_type))

        # Register both entries together, only once loading has succeeded
        self.param_cls[op_type] = config.param_class
        self.parsers[op_type] = parser

    def get_metadata(self, op_type: OpType) -> MetadataReturnType:
        """Return the parser and params class loaded for ``op_type``.

        Raises:
            ValueError: if ``load(op_type)`` has not been called.
        """
        if op_type not in self.parsers:
            raise ValueError(f"Parser for '{op_type}' not loaded. Call load() first.")
        return MetadataReturnType(
            parser=self.__get_parser(op_type),
            param_class=self.__get_metadata_param_class(op_type),
        )

    def __infer_parser(
        self, metadata_params: kmv.KernelBaseParams, op_type: Optional[OpType] = None
    ) -> Optional[BaseOpParser]:
        """
        Infer parser from op_type or metadata_params class.

        NOTE: several op types share one params class (e.g. every broadcast /
        elementwise op uses Broadcast2DParams). When ``op_type`` is omitted and
        more than one such op is loaded, the first loaded match is returned.
        Pass ``op_type`` explicitly whenever a loader holds several ops.

        Args:
            metadata_params: Kernel parameters to infer from if op_type not provided
            op_type: Optional operation type (if provided, used directly)

        Returns:
            Parser instance, or None if no matching parser found

        Raises:
            ValueError: if ``op_type`` is given but has not been loaded.
        """
        if op_type is not None:
            return self.__get_parser(op_type)

        # Infer op_type from the params class (exact type match, no subclasses)
        param_class = type(metadata_params)
        for loaded_op_type in self.parsers.keys():
            if self.__get_metadata_param_class(loaded_op_type) == param_class:
                return self.__get_parser(loaded_op_type)

        return None

    def build_eval_context(
        self,
        metadata_params: kmv.KernelBaseParams,
        op_type: Optional[OpType] = None,
        extra_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Build a complete evaluation context for expressions.

        This combines:
        1. All fields from metadata_params (via eval_dict())
        2. JSON variables from metadata
        3. kernel_parameter_setup variables from metadata
        4. Any additional extra_params (added before the variables are evaluated)

        Args:
            metadata_params: Kernel parameters (subvolume, time_split, etc.)
            op_type: Optional operation type (inferred from params class if not provided)
            extra_params: Additional parameters to add to context

        Returns:
            New dictionary with all variables needed for expression evaluation

        Raises:
            ValueError: if ``op_type`` is given but not loaded; if the parser
                is missing the "variables" or "kernel_parameter_setup" section
                (the default parsers raise on missing fields); if a
                kernel_parameter_setup entry is malformed or uses ``sizeof``;
                or if a variable cannot be evaluated or variables form a cycle.
        """
        # Copy so that the updates below never write into the object returned
        # by eval_dict() (its ownership isn't visible from here).
        eval_context = dict(metadata_params.eval_dict())

        # Add extra parameters if provided
        if extra_params:
            eval_context.update(extra_params)

        # Infer parser from op_type or params class
        parser = self.__infer_parser(metadata_params, op_type)

        # If we found a parser, combine and evaluate all variables together
        if parser is not None:
            # Collect all variables from both sections
            all_variables = {}

            # Get variables section
            variables = parser.get_variables()
            if variables:
                all_variables.update(variables)

            # Get and parse kernel_parameter_setup section; on a name clash
            # these entries override the "variables" section
            kernel_params = parser.get_kernel_parameter_setup()
            if kernel_params and isinstance(kernel_params, dict):
                param_variables = self.__parse_kernel_params(kernel_params)
                if param_variables:
                    all_variables.update(param_variables)

            # Evaluate all variables together in one topological sort
            if all_variables:
                eval_context = self.__eval_variables_in_order(
                    all_variables, eval_context
                )

        return eval_context

    def eval_expr(
        self,
        expr: str,
        metadata_params: kmv.KernelBaseParams,
        op_type: Optional[OpType] = None,
    ) -> EvalExprReturnType:
        """
        Evaluate an expression using metadata_params and JSON variables.

        Args:
            expr: The expression to evaluate
            metadata_params: Kernel parameters (provides base context)
            op_type: Optional operation type (inferred from params class if not provided)

        Returns:
            Evaluated result (bool, int or float)

        Raises:
            ValueError: if the result is not a bool/int/float. Any error from
                building the context or evaluating is printed together with the
                params and then re-raised unchanged.
        """
        try:
            # Build full evaluation context including JSON variables
            var_dict = self.build_eval_context(metadata_params, op_type)
            val = npeval(expr, var_dict)  # pylint: disable=eval-used
            if not isinstance(val, EvalExprReturnType):
                raise ValueError(
                    f"Expression '{expr}' did not evaluate to a boolean or integer."
                )
            return val
        except Exception as e:
            print(
                f"Error: {e} for expression '{expr}' with params {metadata_params.eval_dict()}."
            )
            # Bare raise keeps the original traceback untouched
            raise

    def eval_performance(
        self,
        metadata_params: kmv.KernelBaseParams,
        op_type: OpType,
        extra_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Union[int, float, bool]]:
        """
        Evaluate all performance expressions from kernel metadata.

        This method:
        1. Loads the variables section and evaluates them in dependency order
        2. Loads the performance section expressions
        3. Evaluates each performance metric in declaration order, using
           parameters + variables + previously computed metrics as context
        4. Returns a dictionary of metric_name -> computed_value

        Args:
            metadata_params: Kernel parameters (subvolume, time_split, etc.)
            op_type: Operation type to get metadata for
            extra_params: Additional parameters not in metadata_params (e.g., H_outer, do_bias)

        Returns:
            Dictionary mapping performance metric names to their computed values
            (empty if the performance section is empty or not a dict)

        Raises:
            ValueError: if ``op_type`` is not loaded, the context cannot be
                built, or a performance expression fails to evaluate.

        Example:
            >>> from tiler.kernel_metadata import build_kernel_metadata_vars as kmv
            >>> from tiler.load_kernel_metadata import KernelMetadataLoader, KernelMetadataLoaderCfg, OpType
            >>> loader = KernelMetadataLoader(KernelMetadataLoaderCfg())
            >>> loader.load(OpType.BIASED_CONV_INT8X8)
            >>> params = kmv.conv_int8x8_params(
            ...     subvolume=kmv.ConvSubVolumeParams(H=1, W=64, Ci=64, Co=64, Kh=3, Kw=3, Sh=1, Sw=1),
            ...     time_split=kmv.ConvTimeSplitParams(H=1, W=64, Ci=64, Co=64, Kh=3),
            ...     hardened_loop=1,
            ...     has_relu6=1,
            ...     has_lrelu=1,
            ...     has_bias=1,
            ...     H_outer=0,
            ...     do_bias=1,
            ...     activation='ReLU',
            ...     lrelu_alpha=0
            ... )
            >>> perf = loader.eval_performance(params, OpType.BIASED_CONV_INT8X8, {})
            >>> isinstance(perf, dict)
            True
        """
        # Build full evaluation context including JSON variables and extra params
        eval_context = self.build_eval_context(metadata_params, op_type, extra_params)

        # Get parser to access performance section
        parser = self.__get_parser(op_type)

        # Get performance section
        performance = parser.get_performance()
        if not performance or not isinstance(performance, dict):
            return {}

        # Evaluate each performance expression
        perf_results = {}
        for metric_name, expr in performance.items():
            if isinstance(expr, str):
                try:
                    value = npeval(expr, eval_context)
                    # Store result and add to context for dependent expressions
                    perf_results[metric_name] = value
                    eval_context[metric_name] = value
                except Exception as e:
                    raise ValueError(
                        f"Failed to evaluate performance metric '{metric_name}': {expr}"
                    ) from e
            else:
                # Non-string values (constants)
                perf_results[metric_name] = expr
                eval_context[metric_name] = expr

        return perf_results

    def __extract_variable_dependencies(
        self, expr: str, available_vars: Set[str]
    ) -> Set[str]:
        """
        Extract variable dependencies from an expression using AST parsing.

        Args:
            expr: Expression string to analyze
            available_vars: Set of variable names that could be dependencies

        Returns:
            Set of variable names that this expression depends on; empty if the
            expression cannot be parsed as a Python expression
        """
        try:
            # Strip whitespace to avoid SyntaxError from leading/trailing spaces
            tree = ast.parse(expr.strip(), mode="eval")
            dependencies = set()

            for node in ast.walk(tree):
                if isinstance(node, ast.Name):
                    # Only track dependencies on variables we're evaluating
                    # Skip built-in functions and symbols from _eval_syms
                    if node.id in available_vars and node.id not in _eval_syms:
                        dependencies.add(node.id)

            return dependencies
        except Exception:  # pylint: disable=W0718
            return set()  # Exceptions mean we can't extract dependencies

    def __eval_variables_in_order(
        self, variables: Dict[str, Union[str, int, bool]], initial_context: dict
    ) -> dict:
        """
        Evaluate variables in dependency order using topological sort.

        Variables can reference each other, so we use AST parsing to extract
        dependencies and topological sort to determine the correct evaluation order.
        Values that are neither strings nor numbers (e.g. lists/dicts) are skipped.

        NOTE: dependencies on names already present in ``initial_context`` are
        dropped from the graph, even when a variable of the same name is being
        evaluated here. Such a shadowing variable is therefore not guaranteed to
        be evaluated before the expressions that reference it.

        Args:
            variables: Dictionary of variable_name -> expression
            initial_context: Initial evaluation context (from parameters)

        Returns:
            New context dictionary with all variables evaluated (initial_context
            itself is not modified)

        Raises:
            ValueError: on circular dependencies or when a variable fails to evaluate.
        """
        context = initial_context.copy()

        # Build dependency graph for ALL string expressions
        # Don't add constants to context yet - wait until evaluation
        available_vars = set(variables.keys()) | set(context.keys())
        dependency_graph = {}

        for var_name, expr in variables.items():
            if isinstance(expr, str):
                deps = self.__extract_variable_dependencies(expr, available_vars)
                # Only filter out dependencies from initial_context (parameters),
                # not from variables we're about to evaluate
                # Also remove self-references to avoid circular dependencies
                deps_on_variables = deps - set(initial_context.keys()) - {var_name}
                dependency_graph[var_name] = deps_on_variables
            elif isinstance(expr, (int, bool, float)):
                # Constants have no dependencies
                dependency_graph[var_name] = set()

        # Topologically sort and evaluate
        try:
            sorter = TopologicalSorter(dependency_graph)
            for var_name in sorter.static_order():
                if var_name in variables:
                    expr = variables[var_name]
                    if isinstance(expr, str):
                        try:
                            value = npeval(expr, context)
                            context[var_name] = value
                        except Exception as e:
                            raise ValueError(
                                f"Cannot evaluate variable '{var_name}': {expr}"
                            ) from e
                    elif isinstance(expr, (int, bool, float)):
                        context[var_name] = expr
        # CycleError subclasses ValueError, so it must be caught first
        except CycleError as e:
            raise ValueError(
                f"Circular dependency detected in variable: {e}, vars: {list(variables.keys())}"
            ) from e
        except ValueError as e:
            raise ValueError(
                f"Cannot evaluate variable: {e}, vars: {list(variables.keys())}"
            ) from e

        return context

    def __parse_kernel_params(
        self, kernel_params: Dict[str, Union[str, dict]]
    ) -> Dict[str, str]:
        """
        Parse kernel_parameter_setup section and extract variable expressions.

        The kernel_parameter_setup has entries like:
            "datatype varname": "expression"
        or for conditional types:
            "datatype varname": {"key": "...", "values": {...}}

        This method extracts the variable name and expression, stripping the
        type prefix and any ":bitfield" suffix. Non-string (conditional) entries
        are skipped.

        C-specific constructs:
        Expressions containing ``sizeof(`` are rejected with a ValueError.
        ``sizeof`` is a C operator and ``dtype`` refers to a C struct, so such
        expressions (e.g. ``"64 * M * sizeof(dtype.I0)"``) cannot be evaluated
        in Python.

        Args:
            kernel_params: Dictionary from kernel_parameter_setup section

        Returns:
            Dictionary of varname -> expression (type prefix removed)

        Raises:
            ValueError: if an expression contains ``sizeof(`` or a key is not
                of the form "type var_name".
        """
        variables = {}

        for type_and_name, expr in kernel_params.items():
            # Skip non-string expressions (conditionals are handled elsewhere)
            if not isinstance(expr, str):
                continue

            # Reject C-specific expressions that can't be evaluated in Python
            if "sizeof(" in expr:
                raise ValueError(
                    f"C-style expression {expr} containing sizeof(..) is not supported"
                )

            # Parse "datatype varname" or "datatype varname:bitfield"
            # Split on last space to handle types like "dims_2d_param"
            parts = type_and_name.rsplit(" ", 1)
            if len(parts) == 2:
                var_name = parts[1].split(":")[0]  # Remove bitfield notation if present
                variables[var_name] = expr
            else:
                raise ValueError(
                    f"Unexpected kernel_param key format, expected 'type var_name', found {type_and_name}"
                )

        return variables

    @staticmethod
    def build(op: OpType) -> KernelMetadataLoader:
        """Create a loader with the default cfg and load metadata for ``op``."""
        loader = KernelMetadataLoader(KernelMetadataLoaderCfg())
        loader.load(op)
        return loader

    @staticmethod
    def build_with_defaultcfg() -> KernelMetadataLoader:
        """Create a loader with the default cfg without loading any op."""
        loader = KernelMetadataLoader(KernelMetadataLoaderCfg())
        return loader


class KernelMetadataForOp:
    """Helper class that owns a KernelMetadataLoader for one specific op type.

    ``validate()`` and ``eval_performance()`` results are memoized per
    instance in bounded LRU caches, keyed on the op type and the JSON
    serialization of the params model.
    """

    def __init__(self, op_type: OpType):
        self.op_type = op_type
        self.loader = KernelMetadataLoader.build(op_type)
        # Per-instance memo for get_metadata(). An lru_cache on the method would
        # live on the class and keep up to maxsize instances alive.
        self._metadata: Optional[MetadataReturnType] = None
        # Caches for validate() / eval_performance() results
        self._validate_cache = LRUCache(maxsize=256)
        self._perf_cache = LRUCache(maxsize=256)

    def get_metadata(self) -> MetadataReturnType:
        """Return the (memoized) metadata and params class for this op."""
        if self._metadata is None:
            self._metadata = self.loader.get_metadata(self.op_type)
        return self._metadata

    def get_parser(self) -> BaseOpParser:
        """Return the metadata parser instance for this op."""
        return self.get_metadata().parser

    def __check_condition__(
        self,
        metadata_params: kmv.KernelBaseParams,
        getter: Callable[[BaseOpParser], Any],
    ) -> None:
        """
        Check ``metadata_params`` against one requirement section.

        Args:
            metadata_params: Kernel parameters to check.
            getter: Extracts the section to check from the parser,
                e.g. ``lambda p: p.get_granularity()``.

        A dict section is treated as the performance table and only has to be
        evaluatable via ``loader.eval_performance``. Otherwise the section (a
        list, or a single expression) is evaluated expression by expression.
        A ``False`` boolean result fails the check, and non-boolean results are
        accepted.

        Raises:
            ValueError: if a boolean expression evaluates to False or the
                performance table cannot be evaluated. Errors from
                ``eval_expr`` propagate unchanged.
        """
        requirement = getter(self.get_parser())

        # Special handling for performance dict - ensure all expressions are evaluatable
        if isinstance(requirement, dict):
            # Use eval_performance which handles dependency order
            try:
                self.loader.eval_performance(metadata_params, self.op_type)
            except Exception as e:
                raise ValueError(
                    f"Performance expressions are not evaluatable: {e}"
                ) from e
            return

        requirements = requirement if isinstance(requirement, list) else [requirement]
        for req in requirements:
            # Pass op_type explicitly instead of inferring it from the params class
            expr_result = self.loader.eval_expr(req, metadata_params, self.op_type)
            if isinstance(expr_result, bool) and not expr_result:
                raise ValueError(
                    f"Granularity expression '{req}' not satisfied for {metadata_params.eval_dict()}."
                )

    # The key lambdas use the real parameter name so keyword calls work. They
    # return the tuple itself (not its hash) so distinct params never collide.
    @cachedmethod(
        cache=attrgetter("_validate_cache"),
        key=lambda self, metadata_params: (
            self.op_type,
            metadata_params.model_dump_json(),
        ),
    )
    def validate(self, metadata_params: kmv.KernelBaseParams) -> None:
        """
        Validate incoming metadata params against kernel requirements.

        Checks granularity, minimums and the performance table in that order.
        Successful results are cached, keyed on (op_type, JSON serialization
        of the params). Failures raise ValueError and are not cached.
        """
        self.__check_condition__(
            metadata_params, lambda parser: parser.get_granularity()
        )
        self.__check_condition__(metadata_params, lambda parser: parser.get_minimums())
        self.__check_condition__(
            metadata_params, lambda parser: parser.get_performance()
        )

    @cachedmethod(
        cache=attrgetter("_perf_cache"),
        key=lambda self, metadata_params: (
            self.op_type,
            metadata_params.model_dump_json(),
        ),
    )
    def eval_performance(
        self, metadata_params: kmv.KernelBaseParams
    ) -> dict[str, int | float | bool]:
        """Return the evaluated performance table (cached per params)."""
        return self.loader.eval_performance(metadata_params, self.op_type)
