'''
Common definitions for convolutional config builders.
'''
from dataclasses import dataclass
from typing import Optional
from enum import IntEnum
from pydantic import computed_field

from dmacompiler import (
    set_dev_gen,
    DevGen,
)

from utils.utils_common import (
    BaseMappingWithL1,
    iceil, ceildiv,
    BaseDims,
    BaseShape,
    BaseMappingWithL1AndL2,
)

from scheduler.common import (
    LinearOpType,
)

# Select the AIE4 device generation in dmacompiler when this module is imported.
# NOTE(review): this is a module-level side effect on dmacompiler state; confirm
# that importing this module is meant to force the device generation globally.
set_dev_gen(DevGen.Aie4)


def conv_output(i: int, kernel: int, stride: int, pad: int) -> int:
    '''
    Compute the output extent of a convolution along one axis.

    Args:
        i: Input extent along the axis.
        kernel: Kernel extent along the axis.
        stride: Stride along the axis.
        pad: Padding added on each side of the axis.

    Returns:
        floor((i + 2 * pad - kernel) / stride) + 1, using floor division.
    '''
    padded_extent = i + 2 * pad
    return (padded_extent - kernel) // stride + 1


def conv_input(o: int, kernel: int, stride: int, pad: int = 0) -> int:
    '''
    Derive the input dimension from conv parameters.

    Returns the smallest input extent that conv_output() maps to ``o`` outputs
    for the given kernel, stride and padding:
    ``i = (o - 1) * stride + kernel - 2 * pad``.

    Args:
        o: Output extent along one axis.
        kernel: Kernel extent along that axis.
        stride: Stride along that axis.
        pad: Padding added on each side of that axis. The default of 0 gives
            the unpadded formula ``(o - 1) * stride + kernel``.

    Returns:
        The required input extent along that axis.
    '''
    return ((o - 1) * stride) + kernel - (2 * pad)


def align_Xis(
    Xos: int,
    Cin: int,
    Kx: int,
    Sx: int,
) -> int:
    '''
    Compute an aligned input width (Xis) for small-Cin convolutions.

    Used when Kx is folded together with Cin in the weights. The unaligned
    width needed to produce Xos outputs is (Xos - 1) * Sx + Kx. It is aligned
    up with iceil to a granularity of max(1, 64 // Cin). That granularity
    degrades to 1 once Cin exceeds 64.
    '''
    required_width = (Xos - 1) * Sx + Kx
    fold_granularity = max(1, 64 // Cin)
    return iceil(required_width, fold_granularity)


def is_split_valid(
    ofm: tuple[int, int, int],
    ofm_subv: tuple[int, int, int],
    spatial_split: tuple[int, int, int, int],
    enable_over_compute: bool,
) -> bool:
    '''
    Check that a spatial split meets the divisibility requirements.

    Args:
        ofm: Output feature map dims (Yo, Xo, Co).
        ofm_subv: Output subvolume dims (Yos, Xos, Cos).
        spatial_split: Split factors (N_split, Y_split, X_split, Co_split).
        enable_over_compute: If true, every split is accepted and no
            divisibility checks are performed.

    Returns:
        True when N_split is 1, and each of Y/X/Co is either an exact multiple
        of (subvolume * split) or is covered in a single loop iteration, and
        at least one of the Y/X loops has a single iteration.
    '''
    # Unpack first so malformed tuples still raise, even with over-compute.
    Yo, Xo, Co = ofm
    Yos, Xos, Cos = ofm_subv
    N_split, Y_split, X_split, Co_split = spatial_split

    # With over-compute enabled every split is accepted, so skip the checks
    # (and their divisions) entirely.
    if enable_over_compute:
        return True

    N_valid = N_split == 1  # Batch size is typically 1 for conv operations

    # Amount of each dimension covered by one iteration across all splits.
    Y_tile = Yos * Y_split
    X_tile = Xos * X_split
    Co_tile = Cos * Co_split

    Y_loop = ceildiv(Yo, Y_tile)
    X_loop = ceildiv(Xo, X_tile)
    Co_loop = ceildiv(Co, Co_tile)

    Y_valid = (Yo % Y_tile == 0) or (Y_loop == 1)
    X_valid = (Xo % X_tile == 0) or (X_loop == 1)
    Co_valid = (Co % Co_tile == 0) or (Co_loop == 1)

    # NOTE: This loop count requirement ensures that the output tensor
    # traversal is 5-dimensional, so it fits within a single BD. Future
    # development could optionally use reconfiguration to support
    # the 6th dimension. This implementation favors lower control overhead.
    loops_valid = (Y_loop == 1) or (X_loop == 1)

    return N_valid and Y_valid and X_valid and Co_valid and loops_valid


class ConvShape(BaseShape):
    """
    Define the shape of a conv operation.

    Extends BaseShape, whose ``ifm`` and ``ofm`` are unpacked by ConvDims as
    (Yi, Xi, Ci) and (Yo, Xo, Co). It adds the convolution parameters and the
    data-type information below.
    """
    kernel: tuple[int, int]  # (Ky, Kx)
    stride: tuple[int, int]  # (Sy, Sx)
    padding: tuple[int, int]  # (Py, Px)
    vector_coeff: int  # copied verbatim into ConvDims.vector_coeff; meaning not visible here
    linear_op_type: LinearOpType
    # NOTE(review): ConvDims takes enable_over_compute as its own constructor
    # argument instead of reading this field; confirm which one is authoritative.
    enable_over_compute: int
    # Presumably the bit widths of the IFM, weight, OFM and bias tensors; confirm.
    ifm_bits: int
    wgt_bits: int
    ofm_bits: int
    bias_bits: int
    # Presumably signedness flags for activations (A), weights (W) and
    # outputs (O); confirm against the kernel interface.
    sign_A: int
    sign_W: int
    sign_O: int
    group: int  # presumably the number of convolution groups; confirm
    transpose_wgts: int = 1
    # Optional original channel counts. Presumably Ci/Co before any padding
    # or alignment; confirm against the builders that set them.
    Ci_orig: Optional[int] = None
    Co_orig: Optional[int] = None


class ConvMapping(BaseMappingWithL1AndL2):
    """
    Mapping for convolution operations with auxiliary L1 buffers.

    Extends BaseMappingWithL1AndL2 with pydantic computed fields. They expose
    the size and ping address of the optional L1 allocations stored in
    ``l1_alloc`` under the keys 'tdm', 'wght_t_sb', 'vec' and 'qdq'. Each
    computed field is None when its buffer is not present in ``l1_alloc``.
    """

    @classmethod
    def from_base_mapping(
        cls,
        base_mapping: BaseMappingWithL1,
        ifm_L2_strategy: str,
        wgt_L2_strategy: str,
        ofm_L2_strategy: str,
    ) -> 'ConvMapping':
        """
        Create ConvMapping from BaseMappingWithL1 by adding L2 strategies.

        The padding, subvolumes, spatial split, iteration counts, kernel
        granularity, bit widths and ``l1_alloc`` are copied from
        ``base_mapping``. The three L2 strategy names come from the arguments.
        """
        return cls(
            ofm_pad=base_mapping.ofm_pad,
            ifm_pad=base_mapping.ifm_pad,
            ofm_subv=base_mapping.ofm_subv,
            ifm_subv=base_mapping.ifm_subv,
            spatial_split=base_mapping.spatial_split,
            iters=base_mapping.iters,
            kernel_gran=base_mapping.kernel_gran,
            ifm_bits=base_mapping.ifm_bits,
            wgt_bits=base_mapping.wgt_bits,
            ofm_bits=base_mapping.ofm_bits,
            bias_bits=base_mapping.bias_bits,
            ifm_L2_strategy=ifm_L2_strategy,
            wgt_L2_strategy=wgt_L2_strategy,
            ofm_L2_strategy=ofm_L2_strategy,
            l1_alloc=base_mapping.l1_alloc,
        )

    def _l1_alloc_field(self, buf_name: str, attr: str) -> Optional[int]:
        """Return ``attr`` of the L1 allocation ``buf_name``, or None if it is not allocated."""
        if buf_name in self.l1_alloc:
            return getattr(self.l1_alloc[buf_name], attr)
        return None

    @computed_field  # type: ignore[misc]
    @property
    def tdm_L1_size(self) -> Optional[int]:
        """Get the TDM L1 buffer size"""
        return self._l1_alloc_field('tdm', 'size')

    @computed_field  # type: ignore[misc]
    @property
    def wght_transpose_sb_L1_size(self) -> Optional[int]:
        """Get the Weights Transpose Scratch Buffer L1 buffer size"""
        return self._l1_alloc_field('wght_t_sb', 'size')

    @computed_field  # type: ignore[misc]
    @property
    def vec_L1_size(self) -> Optional[int]:
        """Get the Vector L1 buffer size"""
        return self._l1_alloc_field('vec', 'size')

    @computed_field  # type: ignore[misc]
    @property
    def qdq_L1_size(self) -> Optional[int]:
        """Get the QDQ L1 buffer size"""
        return self._l1_alloc_field('qdq', 'size')

    @computed_field  # type: ignore[misc]
    @property
    def qdq_L1_ping_addr(self) -> Optional[int]:
        """Get the QDQ L1 ping buffer address"""
        return self._l1_alloc_field('qdq', 'ping_addr')

    @computed_field  # type: ignore[misc]
    @property
    def tdm_L1_ping_addr(self) -> Optional[int]:
        """Get the TDM L1 ping buffer address"""
        return self._l1_alloc_field('tdm', 'ping_addr')

    @computed_field  # type: ignore[misc]
    @property
    def wght_transpose_sb_L1_ping_addr(self) -> Optional[int]:
        """Get the Weights Transpose Scratch Buffer L1 ping buffer address"""
        return self._l1_alloc_field('wght_t_sb', 'ping_addr')

    @computed_field  # type: ignore[misc]
    @property
    def vec_L1_ping_addr(self) -> Optional[int]:
        """Get the Vector L1 ping buffer address"""
        return self._l1_alloc_field('vec', 'ping_addr')


@dataclass(init=False, slots=True)
class ConvDims(BaseDims):
    '''
    Unpack shorthand accessors for conv dimensions.

    Flattens a ConvShape and a ConvMapping into plain attributes. The generic
    spatial dims, subvolumes, kernel/stride/padding, splits and loop counts
    are stored through BaseDims. The conv-specific granularities, bit widths,
    L1 buffer sizes and sign flags are stored on this class.
    '''
    Ci_gran: int
    Co_gran: int
    Co_gran_wgt: int
    ifm_bits: int
    ofm_bits: int
    wgt_bits: int
    bias_bits: int
    wgt_L1_size: int
    vector_coeff: int
    enable_over_compute: bool
    # L1 sizes taken from ConvMapping computed fields; None when not allocated.
    tdm_L1_size: Optional[int]
    vec_L1_size: Optional[int]
    qdq_param_size: Optional[int]
    ifm_to_xrt_idx: dict[str, int]
    sign_A: int
    sign_W: int
    sign_O: int
    transpose_wgts: int
    # Copied from ConvShape, where they default to None.
    Ci_orig: Optional[int]
    Co_orig: Optional[int]
    is_split: int
    wght_transpose_sb_L1_size: Optional[int]

    def __init__(self, shape: ConvShape, mapping: ConvMapping, enable_over_compute: bool = False):
        '''
        Build the flattened dims from a conv shape and its mapping.

        Args:
            shape: Conv operation shape (ifm/ofm dims, kernel, stride, padding,
                sign flags, ...).
            mapping: Tiling of the operation (subvolumes, spatial split, kernel
                granularity, bit widths, L1 allocations).
            enable_over_compute: Stored as-is on the instance.
                shape.enable_over_compute is not consulted here.
        '''
        # Unpack shape / mapping
        Yi, Xi, Ci = shape.ifm
        Yo, Xo, Co = shape.ofm
        Yis, Xis, Cis = mapping.ifm_subv
        Yos, Xos, Cos = mapping.ofm_subv
        Ky, Kx = shape.kernel
        Sy, Sx = shape.stride
        Py, Px = shape.padding
        N_split, Y_split, X_split, Co_split = mapping.spatial_split

        # Loop counts: iterations needed to cover each dimension once the
        # per-iteration extent is (subvolume * split).
        Ci_loop = ceildiv(Ci, Cis)
        Y_loop = ceildiv(Yo, (Yos * Y_split))
        X_loop = ceildiv(Xo, (Xos * X_split))
        Co_loop = ceildiv(Co, (Cos * Co_split))

        # Initialize parent class explicitly; zero-argument super() does not
        # work with @dataclass(slots=True) on older Python versions.
        BaseDims.__init__(
            self,
            N=1,
            Yi=Yi, Xi=Xi, Ci=Ci,
            Yo=Yo, Xo=Xo, Co=Co,
            Yis=Yis, Xis=Xis, Cis=Cis,
            Yos=Yos, Xos=Xos, Cos=Cos,
            Ky=Ky, Kx=Kx, Sy=Sy, Sx=Sx, Py=Py, Px=Px,
            N_split=N_split, Y_split=Y_split, X_split=X_split, Co_split=Co_split,
            N_loop=1, Y_loop=Y_loop, X_loop=X_loop, Co_loop=Co_loop, Ci_loop=Ci_loop,
            aie_cols=3, aie_rows=4,
        )

        self.Ci_gran, self.Co_gran = mapping.kernel_gran
        self.Co_gran_wgt = 32  # NOTE Co_gran out=64 wgt=32
        self.wgt_L1_size = mapping.wgt_L1_size
        self.ifm_bits = mapping.ifm_bits
        self.ofm_bits = mapping.ofm_bits
        self.wgt_bits = mapping.wgt_bits
        self.bias_bits = mapping.bias_bits
        self.vector_coeff = shape.vector_coeff
        self.enable_over_compute = enable_over_compute
        self.qdq_param_size = mapping.qdq_L1_size
        self.tdm_L1_size = mapping.tdm_L1_size
        self.wght_transpose_sb_L1_size = mapping.wght_transpose_sb_L1_size
        self.vec_L1_size = mapping.vec_L1_size
        self.transpose_wgts = shape.transpose_wgts
        self.Ci_orig = shape.Ci_orig
        self.Co_orig = shape.Co_orig
        self.is_split = 0
        self.ifm_to_xrt_idx = {"ifm0": 0, "ifm1": 1, "ifm2": 2}
        self.sign_A = shape.sign_A
        self.sign_W = shape.sign_W
        self.sign_O = shape.sign_O
        # Validation below is currently disabled (commented out).
        # assert is_split_valid(shape.ofm, mapping.ofm_subv, mapping.spatial_split, enable_over_compute)
        # if not enable_over_compute:
        #     assert self.Yo == conv_output(self.Yi, self.Ky, self.Sy, self.Py)
        #     assert self.Xo == conv_output(self.Xi, self.Kx, self.Sx, self.Px)

    def __str__(self) -> str:
        # Render this instance's base dims (not the BaseDims class object),
        # then the conv-specific values. The base method is called explicitly
        # for the same slots=True / super() reason as in __init__.
        base = BaseDims.__str__(self)
        return (
            f"{base}, ConvDims(Ci_gran={self.Ci_gran}, Co_gran={self.Co_gran}, "
            f"Co_gran_wgt={self.Co_gran_wgt}, "
            f"Ci_loop={self.Ci_loop}, Y_loop={self.Y_loop}, X_loop={self.X_loop}, "
            f"Co_loop={self.Co_loop}, ifm_bits={self.ifm_bits}, "
            f"ofm_bits={self.ofm_bits}, wgt_bits={self.wgt_bits}, "
            f"bias_bits={self.bias_bits}, wgt_L1_size={self.wgt_L1_size})"
        )


class ActMode(IntEnum):
    '''
    Activation function modes supported by conv kernel.

    A plain IntEnum, so members compare and hash by their integer value.
    Do not wrap it in @dataclass: the generated __eq__/__hash__ (with no
    fields) would make every member compare equal to every other member.
    '''
    AC_SRS = 0
    AC_RELU = 1
    AC_RELU6 = 2
    AC_LRELU = 3
    AC_HSWISH = 4
