"""
Module to capture kernel params used
while evaluating kernel metadata expr.
"""

from enum import Enum
from typing import Literal
from pydantic import BaseModel, Field, ConfigDict, computed_field


class ForbidExtraBaseModel(BaseModel):
    """
    Base model that rejects undeclared fields.

    Configured with ``extra="forbid"`` so that passing any attribute the model
    does not declare fails validation instead of being silently accepted.
    """

    # Inherited by every subclass in this module.
    model_config = ConfigDict(extra="forbid")


class SyncType(str, Enum):
    """
    Synchronization type of a kernel I/O interface.

    Because the enum also derives from ``str``, each member compares equal to
    its string value (e.g. ``SyncType.STREAM == "stream"``).
    """

    # Declaration order is the enum's iteration order.
    ASYNC = "async"
    ASYNC_REUSE = "async_reuse"
    STREAM = "stream"


class SyncTypeParams(ForbidExtraBaseModel):
    """
    Sync type of each kernel interface: inputs I0-I2 and output O0.

    Every field is required; the model declares no defaults.
    """

    # Inputs
    I0: SyncType = Field(..., description="Input 0 sync type")
    I1: SyncType = Field(..., description="Input 1 sync type")
    I2: SyncType = Field(..., description="Input 2 sync type")
    # Output
    O0: SyncType = Field(..., description="Output 0 sync type")


class SubVolumeParams(ForbidExtraBaseModel):
    """
    Subvolume dimensions shared by every op.

    All four dimensions are required integers.
    """

    # Spatial extent
    H: int = Field(..., description="Height")
    W: int = Field(..., description="Width")
    # Channel extent
    Co: int = Field(..., description="ChannelOutput")
    Ci: int = Field(..., description="ChannelInput")


class TimeSplitParams(ForbidExtraBaseModel):
    """
    Time split dimensions shared by every op.

    Besides the four required dimensions, the model exposes read-only
    computed fields M, K and N that alias H, Ci and Co respectively, so
    expressions can use matrix-multiply naming.
    """

    # Spatial extent
    H: int = Field(..., description="Height")
    W: int = Field(..., description="Width")
    # Channel extent
    Co: int = Field(..., description="ChannelOutput")
    Ci: int = Field(..., description="ChannelInput")

    # Matmul aliases. The property docstrings below are kept verbatim
    # because pydantic uses them as the computed-field descriptions.
    @computed_field
    @property
    def M(self) -> int:
        """Matrix multiply M dimension (corresponds to H)"""
        return self.H

    @computed_field
    @property
    def K(self) -> int:
        """Matrix multiply K dimension (corresponds to Ci)"""
        return self.Ci

    @computed_field
    @property
    def N(self) -> int:
        """Matrix multiply N dimension (corresponds to Co)"""
        return self.Co


class SubVolume2DParams(ForbidExtraBaseModel):
    """
    Two-dimensional subvolume (rows x columns) used by the Broadcast op.
    """

    R: int = Field(..., description="Rows")
    C: int = Field(..., description="Columns")


class ConvSubVolumeParams(SubVolumeParams):
    """
    Subvolume params for the Conv op.

    Extends the base subvolume with kernel size and stride fields.
    """

    # Kernel window size
    Kh: int = Field(..., description="KernelHeight")
    Kw: int = Field(..., description="KernelWidth")
    # Only strides of 1 or 2 pass validation
    Sh: Literal[1, 2] = Field(..., description="StrideHeight")
    Sw: Literal[1, 2] = Field(..., description="StrideWidth")


class ConvTimeSplitParams(TimeSplitParams):
    """
    Time split params for the Conv op.

    Adds the kernel height on top of the base time split (including its
    M/K/N computed fields).
    """

    Kh: int = Field(..., description="KernelHeight")


class KernelBaseParams(ForbidExtraBaseModel):
    """
    Base class for kernel parameter models that build the variable dict
    needed to evaluate kernel metadata expressions.

    Note: This class provides the minimal context needed for JSON variable evaluation.
    Derived values (e.g., outer_time_iters, inner_loop) should be defined in the
    JSON "variables" section to avoid duplication and stay synchronized with the
    kernel metadata.
    """

    def eval_dict(self) -> dict:
        """
        Build the dict the metadata loader uses for expression evaluation.

        Returns:
            A new dict mapping each field name to its value, followed by every
            computed field declared on the model class. Values are taken as-is;
            nested models remain model instances rather than being converted to
            dicts.
        """
        d = dict(self)
        # Read computed fields from the class: accessing ``model_computed_fields``
        # on an instance is deprecated since Pydantic 2.11.
        for field_name in type(self).model_computed_fields:
            d[field_name] = getattr(self, field_name)
        return d


# NOTE: Ideally, var_dict could be built directly from the metadata fields
# (parameters and variables). However, since the metadata contains many fields
# that are currently unused, defining and maintaining a dedicated class for
# kernel parameters provides stricter validation and helps avoid issues with
# unused or extra variables.


class mmult_qdq_int16x8_params(KernelBaseParams):
    """
    Variables for evaluating the activated_mmult_qdq_int16x8 kernel
    metadata expressions.
    """

    # Geometry
    subvolume: SubVolumeParams
    time_split: TimeSplitParams

    # Template parameters needed for performance evaluation
    has_actv_sum: Literal[0, 1] = Field(..., description="Activation sum flag")
    vector_coeff: Literal[0, 1, 2] = Field(..., description="Vector coefficient mode")


class mmult_qdq_int16x16_params(KernelBaseParams):
    """
    Variables for evaluating the activated_mmult_qdq_int16x16 kernel
    metadata expressions.
    """

    # Geometry
    subvolume: SubVolumeParams
    time_split: TimeSplitParams

    # fmt: off
    # Template parameters for int16x16 (has_actv_sum also accepts 2 here,
    # unlike the int16x8 variant)
    has_actv_sum: Literal[0, 1, 2] = Field(..., description="Activation sum flag")
    has_vector_coeffs: Literal[0, 1, 2] = Field(..., description="Has vector coefficients")
    # Non-template parameter
    vector_coeffs: Literal[0, 1, 2] = Field(..., description="Vector coefficients mode")
    # fmt: on


class conv_int8x8_params(KernelBaseParams):
    """
    Variables for evaluating the biased_conv_int8x8 kernel metadata
    expressions.
    """

    # Geometry
    subvolume: ConvSubVolumeParams
    time_split: ConvTimeSplitParams
    hardened_loop: Literal[-1, 1, 2, 3]

    # Template parameters needed for performance evaluation
    has_relu6: Literal[0, 1] = Field(..., description="Has ReLU6 activation")
    has_lrelu: Literal[0, 1] = Field(..., description="Has Leaky ReLU activation")
    has_bias: Literal[0, 1] = Field(..., description="Has bias")

    # Extra runtime parameters (all required; none has a default)
    H_outer: Literal[0, 1] = Field(..., description="Outer H loop flag")
    do_bias: Literal[0, 1] = Field(..., description="Bias flag")
    activation: Literal["ReLU", "ReLU6", "linear", "LReLU"] = Field(
        ..., description="Activation type"
    )
    lrelu_alpha: int = Field(..., description="Leaky ReLU alpha")


class conv_qdq_int16x8_params(KernelBaseParams):  # pylint: disable=too-many-public-methods
    """
    Variables for evaluating the activated_conv_qdq_int16x8 kernel
    metadata expressions.

    All fields are required, including ``sync_type``.
    """

    # Geometry
    subvolume: ConvSubVolumeParams
    time_split: ConvTimeSplitParams
    hardened_loop: Literal[-1] = Field(..., description="Hardened loop flag")
    # Template parameters for conv_qdq (these differ from conv_int8x8's)
    has_actv_sum: Literal[0, 1] = Field(..., description="Activation sum flag")
    vector_coeff: Literal[0, 1, 2] = Field(..., description="Vector coefficient mode")
    # Extra runtime parameters
    H_outer: Literal[0, 1] = Field(..., description="Outer H loop flag")
    # Required: SyncTypeParams declares no defaults for I0/I1/I2/O0, so it
    # cannot be default-constructed. A ``default_factory=SyncTypeParams``
    # would always fail validation when this field is omitted.
    sync_type: SyncTypeParams = Field(..., description="Sync type configuration")


class BroadcastQuantizationShifts(KernelBaseParams):
    """
    Quantization shifts used for the Xint8 kernel.

    Holds one required shift per interface: inputs I0, I1 and output O0.
    """

    I0: int = Field(..., description="Shift for input 0")
    I1: int = Field(..., description="Shift for input 1")
    O0: int = Field(..., description="Shift for output 0")


class Broadcast2DParams(KernelBaseParams):
    """
    Parameters for 2D broadcast kernels.

    Fields with a literal default (gran, loop_range, use_mmac,
    max_kernel_cfg_lr) may be omitted; all other fields are required.
    """

    subvolume: SubVolume2DParams
    # NOTE(review): presumably the subvolume of input 1; confirm with the kernel spec
    subvolume_in1: SubVolume2DParams
    C: int = Field(..., description="OFM Channels")
    R: int = Field(..., description="OFM Rows")
    gran: int = 64
    has_scalar_broadcast: bool = Field(..., description="Has Scalar Broadcast")
    loop_range: int = 8
    use_mmac: bool = False
    max_kernel_cfg_lr: int = 8
    # Required: BroadcastQuantizationShifts declares no defaults for I0/I1/O0,
    # so it cannot be default-constructed. A ``default_factory`` would always
    # fail validation when this field is omitted.
    quantization_shifts: BroadcastQuantizationShifts = Field(
        ..., description="Quantization shifts"
    )
