"""
Contains all common logic / data structure shared between dataflow and tiling stage
"""

import os
from dataclasses import dataclass
from datetime import datetime
from typing import List, Tuple, Optional, Union, NamedTuple
from abc import ABC
from pydantic import BaseModel, ConfigDict, computed_field, model_validator
from dmacompiler import AieTile

#
# Arithmetic Helpers
#


def ceildiv(x: int, d: int) -> int:
    """Divide x by d, rounding the quotient up (intended for positive d)"""
    # Bias the numerator by d - 1 so that floor division rounds upward.
    quotient, _remainder = divmod(x + (d - 1), d)
    return quotient


def floordiv(x: int, d: int) -> int:
    """Divide x by d, rounding the quotient down"""
    quotient, _remainder = divmod(x, d)
    return quotient


def iceil(x: int, d: int) -> int:
    """Round x up to the nearest multiple of d"""
    # Number of d-sized chunks needed to cover x, then scale back up.
    chunk_count = ceildiv(x, d)
    return d * chunk_count


def ifloor(x: int, d: int) -> int:
    """Round x down to the nearest multiple of d"""
    # Number of whole d-sized chunks that fit in x, then scale back up.
    whole_chunks = x // d
    return whole_chunks * d


#
# BaseShape
#

class Shape3(NamedTuple):
    # Feature-map shape without a leading N axis: (Y, X, C)
    Y: int
    X: int
    C: int


class Shape4(NamedTuple):
    # Feature-map shape with a leading N axis: (N, Y, X, C)
    N: int
    Y: int
    X: int
    C: int


# A single shape, or either a single shape or a list of shapes.
Shape = Union[Shape3, Shape4]
Shapes = Union[Shape, List[Shape]]


class BaseShape(BaseModel, ABC):
    """Define the shape of any operation

    Immutable (frozen) pydantic model holding the input and output
    feature-map shapes of an operation.
    """

    model_config = ConfigDict(frozen=True)
    # Input feature-map shape(s): one Shape3/Shape4, or a list of them
    ifm: Shapes
    # Output feature-map shape(s); declared with the same type as ifm
    ofm: Shapes


class BaseMapping(BaseModel, ABC):
    """Define subvolume size, spatial split, and L1 buffer placement

    Immutable (frozen) pydantic model. No L1 allocation field is declared
    on this base class itself; BaseMappingWithL1 adds `l1_alloc`.
    """

    model_config = ConfigDict(frozen=True)
    # OFM padding, one 3-tuple -- presumably (Y, X, C) order; TODO confirm
    ofm_pad: tuple[int, int, int]
    # IFM padding: one 3-tuple, or a list of 3-tuples (presumably one per input)
    ifm_pad: tuple[int, int, int] | list[tuple[int, int, int]]
    # IFM subvolume size: one 3-tuple, or a list of 3-tuples
    ifm_subv: tuple[int, int, int] | list[tuple[int, int, int]]
    # OFM subvolume size as a single 3-tuple
    ofm_subv: tuple[int, int, int]
    # Unpacked by BaseDims.from_shape_and_mapping as
    # (N_split, Y_split, X_split, Co_split)
    spatial_split: tuple[int, int, int, int]
    # Unpacked by BaseDims.from_shape_and_mapping as
    # (Ci_loop, Y_loop, X_loop, Co_loop)
    iters: tuple[int, int, int, int]
    # Kernel granularity, variable length -- meaning of each entry is not
    # visible in this file; TODO confirm with the kernel definitions
    kernel_gran: tuple[int, ...]
    # Presumably element bit widths of the IFM / WGT / OFM / bias tensors
    ifm_bits: int
    wgt_bits: int
    ofm_bits: int
    bias_bits: int


def normalize_shape_to_4d(
    shape: Union[Tuple[int, ...], List[int]],
) -> Tuple[int, int, int, int]:
    """Normalize tensor shape to 4D NHWC format by prepending 1s.

    Args:
        shape: Input tensor shape (0D to 4D, or 5D with outermost dim == 1).
            Tuples and lists (or any other iterable of ints) are accepted.

    Returns:
        4D shape tuple with 1s prepended as needed

    Raises:
        ValueError: If shape has >4 dimensions (or 5D with outermost != 1)

    Examples:
        >>> normalize_shape_to_4d((64,))
        (1, 1, 1, 64)
        >>> normalize_shape_to_4d((224, 224, 3))
        (1, 224, 224, 3)
        >>> normalize_shape_to_4d([224, 224, 3])
        (1, 224, 224, 3)
        >>> normalize_shape_to_4d((4, 224, 224, 3))
        (4, 224, 224, 3)
        >>> normalize_shape_to_4d((1, 4, 224, 224, 3))
        (4, 224, 224, 3)
    """
    # Coerce to a tuple up front: the final concatenation below only works
    # tuple-to-tuple, so a list input would otherwise raise TypeError.
    shape = tuple(shape)

    # Handle 5D tensors with outermost dimension of 1
    if len(shape) == 5:
        if shape[0] != 1:
            raise ValueError(
                f"Unsupported 5D tensor with outermost dimension {shape[0]} != 1"
            )
        shape = shape[1:]  # discard outermost dim if 1

    if len(shape) > 4:
        raise ValueError(f"Cannot normalize shape with >4 dimensions: {shape}")

    # Prepend 1s to make it 4D: (C,) -> (1,1,1,C), (H,W,C) -> (1,H,W,C), etc.
    return (1,) * (4 - len(shape)) + shape


def is_prime(n: int) -> bool:
    """Return True when n is a prime number (trial division by 6k ± 1)"""
    if n in (2, 3):
        return True
    # Everything below 2 and every other multiple of 2 or 3 is not prime.
    if n < 2 or n % 2 == 0 or n % 3 == 0:
        return False
    # Remaining candidate divisors are of the form 6k - 1 and 6k + 1;
    # only those up to sqrt(n) need to be tried.
    candidate = 5
    while candidate * candidate <= n:
        if not (n % candidate and n % (candidate + 2)):
            return False
        candidate += 6
    return True


#
# Shape unpacking helper and BaseDims
#
# def get_nyxc(fm: list[Shape] | Shape) -> tuple[list[int] | int, ...]:
#     """Converts fm to (N, Y, X, C)"""
#     if isinstance(fm, Shape3):
#         return (1, fm.Y, fm.X, fm.C)  # add N=1 if not present
#     if isinstance(fm, Shape4):
#         return (fm.N, fm.Y, fm.X, fm.C)
#     if isinstance(fm, list) and isinstance(fm[0], Shape3):
#         return ([1] * len(fm), *[[fm_i[d] for fm_i in fm] for d in range(3)])
#     if isinstance(fm, list) and isinstance(fm[0], Shape4):
#         return tuple([fm_i[d] for fm_i in fm] for d in range(4))
#     raise ValueError(f"Invalid fm value: {fm}")


def get_nyxc_from_iter(ifm: list[tuple] | tuple) -> tuple[list[int] | int, ...]:
    """Return ifm as (N, Yi, Xi, Ci). If N is not present, add N=1. If ifm is a list of tuples, return as a tuple of lists."""
    if isinstance(ifm, tuple) and len(ifm) == 3:
        return (1, *ifm)  # add N=1 if not present
    if isinstance(ifm, tuple) and len(ifm) == 4:
        return ifm
    if isinstance(ifm, list) and isinstance(ifm[0], tuple):
        # check all ifms have same dimensionality
        assert len(set(len(ifm_subv) for ifm_subv in ifm)) == 1
        ndims = len(ifm[0])
        if ndims == 3:
            return (
                [1] * len(ifm),
                *[[ifm_i[d] for ifm_i in ifm] for d in range(ndims)],
            )
        if ndims == 4:
            return tuple([ifm_i[d] for ifm_i in ifm] for d in range(ndims))
    raise ValueError(f"Invalid ifm value: {ifm}")


@dataclass(slots=True)
class BaseDims(ABC):
    """Base class for dimension handling with common attributes

    Flat record of tensor dimensions, subvolume dimensions, split and loop
    counts and AIE array size. The comments on the field groups below state
    where from_shape_and_mapping takes each value from; fields it does not
    set keep their defaults.
    """

    # Tensor dims. N/Yi/Xi/Ci come from shape.ifm, Yo/Xo/Co from shape.ofm.
    # Yi/Xi/Ci are lists when shape.ifm is a list of shapes.
    # NOTE(review): in that list case N is also a list although annotated int.
    N: int = 1
    Yi: Union[int, list[int]] = 0
    Xi: Union[int, list[int]] = 0
    Ci: Union[int, list[int]] = 0
    Yo: int = 0
    Xo: int = 0
    Co: int = 0
    # Subvolume dims. Nis/Yis/Xis/Cis come from mapping.ifm_subv,
    # Yos/Xos/Cos from mapping.ofm_subv.
    # NOTE(review): from_shape_and_mapping discards the N entry of ofm_subv,
    # so Nos stays at its default 0 -- confirm this is intended.
    Nis: Union[int, list[int]] = 0
    Yis: Union[int, list[int]] = 0
    Xis: Union[int, list[int]] = 0
    Cis: Union[int, list[int]] = 0
    Nos: Union[int, list[int]] = 0
    Yos: int = 0
    Xos: int = 0
    Cos: int = 0
    # Presumably kernel size (Ky, Kx), stride (Sy, Sx) and padding (Py, Px);
    # not set by from_shape_and_mapping -- TODO confirm against subclasses
    Ky: int = 0
    Kx: int = 0
    Sy: int = 1
    Sx: int = 1
    Py: int = 0
    Px: int = 0
    # Spatial split counts, taken from mapping.spatial_split
    N_split: int = 1
    Y_split: int = 1
    X_split: int = 1
    Co_split: int = 1
    # Loop counts. Ci/Y/X/Co come from mapping.iters; N_loop is always
    # passed as 1 by from_shape_and_mapping.
    N_loop: int = 1
    Y_loop: int = 1
    X_loop: int = 1
    Co_loop: int = 1
    Ci_loop: int = 1
    # AIE array size in columns and rows
    aie_cols: int = 3
    aie_rows: int = 4
    param_subv_size: int = 1024  # Default size for layer parameters

    @property
    def input_shape(
        self,
    ) -> tuple[int, int, int, int] | list[tuple[int, int, int, int]]:
        """Return input tensor shape (N, Yi, Xi, Ci)

        The value is always a single 4-tuple; its entries are themselves
        lists when the corresponding fields hold lists.
        NOTE(review): the annotated `list[tuple[...]]` alternative is never
        produced by this code.
        """
        return (self.N, self.Yi, self.Xi, self.Ci)

    @property
    def output_shape(self) -> tuple[int, int, int, int]:
        """Return output tensor shape (N, Yo, Xo, Co)"""
        return (self.N, self.Yo, self.Xo, self.Co)

    @property
    def spatial_split(self) -> tuple[int, int, int, int]:
        """Return spatial split configuration (N_split, Y_split, X_split, Co_split)"""
        return (self.N_split, self.Y_split, self.X_split, self.Co_split)

    @property
    def aie_array_shape(self) -> tuple[int, int]:
        """Return AIE array dimensions (cols, rows)"""
        return (self.aie_cols, self.aie_rows)

    def total_aie_cores(self) -> int:
        """Calculate total number of AIE cores (aie_cols * aie_rows)"""
        return self.aie_cols * self.aie_rows

    @classmethod
    def from_shape_and_mapping(
        cls,
        shape: BaseShape,
        mapping: BaseMapping,
        aie_cols=3,
        aie_rows=4,
    ) -> "BaseDims":
        """Initialize dims from shape and mapping objects

        Args:
            shape: Supplies the tensor dims via shape.ifm and shape.ofm.
            mapping: Supplies subvolume dims (ifm_subv, ofm_subv), the
                spatial split and the loop counts (iters).
            aie_cols: Number of AIE columns stored on the result.
            aie_rows: Number of AIE rows stored on the result.

        Returns:
            A new instance of `cls`. Fields not passed below (Nos, the
            Ky/Kx/Sy/Sx/Py/Px group, param_subv_size) keep their defaults.

        Raises:
            ValueError: Propagated from get_nyxc_from_iter for shapes it
                cannot interpret.
        """
        # Each helper call yields (N, Y, X, C); missing N is filled with 1.
        N, Yi, Xi, Ci = get_nyxc_from_iter(shape.ifm)
        Nis, Yis, Xis, Cis = get_nyxc_from_iter(mapping.ifm_subv)
        # The N entries of the OFM shape and OFM subvolume are discarded.
        _, Yo, Xo, Co = get_nyxc_from_iter(shape.ofm)
        _, Yos, Xos, Cos = get_nyxc_from_iter(mapping.ofm_subv)
        N_split, Y_split, X_split, Co_split = mapping.spatial_split
        # Note the order: iters leads with Ci_loop, not an N loop count.
        Ci_loop, Y_loop, X_loop, Co_loop = mapping.iters
        return cls(
            N=N,
            Yi=Yi,
            Xi=Xi,
            Ci=Ci,
            Yo=Yo,
            Xo=Xo,
            Co=Co,
            Nis=Nis,
            Yis=Yis,
            Xis=Xis,
            Cis=Cis,
            Yos=Yos,
            Xos=Xos,
            Cos=Cos,
            N_split=N_split,
            Y_split=Y_split,
            X_split=X_split,
            Co_split=Co_split,
            N_loop=1,
            Y_loop=Y_loop,
            X_loop=X_loop,
            Co_loop=Co_loop,
            Ci_loop=Ci_loop,
            aie_cols=aie_cols,
            aie_rows=aie_rows,
        )


#
# Split modes and helper function
#

# Each entry is (N_split=1, Y_split, X_split, Co_split)
DMA_ONLY_SPATIAL_SPLIT_MODES = [(1, 3, 1, 1)]

# Every (N_split=1, Y_split, X_split, Co_split) whose
# Y_split * X_split * Co_split equals 12, ordered by ascending Co_split
# and, within one Co_split, by descending Y_split.
SPATIAL_SPLIT_MODES = [
    (1, y_split, 12 // (co_split * y_split), co_split)
    for co_split in (1, 2, 3, 4, 6, 12)
    for y_split in range(12 // co_split, 0, -1)
    if (12 // co_split) % y_split == 0
]


def split_to_mode(dims: BaseDims) -> int:
    """
    Return the channel allocation mode for the spatial split held in dims.

    Mode 0 means IFM unicast / WGT broadcast, mode 1 means IFM broadcast /
    WGT unicast. Only N_split, Y_split, X_split and Co_split are read, so
    any BaseDims (or subclass such as ConvDims) instance works:
        mode = split_to_mode(conv_dims)
        mode = split_to_mode(BaseDims(Y_split=3, X_split=1, Co_split=4))

    Raises KeyError when the split is not one of SPATIAL_SPLIT_MODES.
    """
    # (N_split, Y_split, X_split, Co_split) keys, grouped by resulting mode
    ifm_unicast_splits = (
        (1, 12, 1, 1),
        (1, 6, 2, 1),
        (1, 4, 3, 1),
        (1, 3, 4, 1),
        (1, 2, 6, 1),
        (1, 1, 12, 1),
        (1, 6, 1, 2),
        (1, 3, 2, 2),
        (1, 2, 3, 2),
        (1, 1, 6, 2),
        (1, 4, 1, 3),
        (1, 2, 2, 3),
        (1, 1, 4, 3),
    )
    ifm_broadcast_splits = (
        (1, 3, 1, 4),
        (1, 1, 3, 4),
        (1, 2, 1, 6),
        (1, 1, 2, 6),
        (1, 1, 1, 12),
    )
    mode_lookup = {
        **dict.fromkeys(ifm_unicast_splits, 0),
        **dict.fromkeys(ifm_broadcast_splits, 1),
    }
    # The table must cover exactly the supported split modes.
    assert sorted(mode_lookup) == sorted(SPATIAL_SPLIT_MODES)
    split_key = (dims.N_split, dims.Y_split, dims.X_split, dims.Co_split)
    return mode_lookup[split_key]


def core_to_split(dims: BaseDims, col: int, row: int) -> tuple[int, int, int, int]:
    """
    Map core (col, row) to logical image split (N_idx, Y_idx, X_idx, Co_idx)

    Reads aie_rows, Y_split, X_split and Co_split from dims and delegates to
    _core_to_split, so it works with ConvDims as well as BaseDims directly:
        n_idx, y_idx, x_idx, co_idx = core_to_split(conv_dims, col, row)

        base_dims = BaseDims(Y_split=2, X_split=3, Co_split=2, aie_cols=3, aie_rows=4)
        n_idx, y_idx, x_idx, co_idx = core_to_split(base_dims, col, row)
    """
    return _core_to_split(
        n_rows=dims.aie_rows,
        Y_split=dims.Y_split,
        X_split=dims.X_split,
        Co_split=dims.Co_split,
        col=col,
        row=row,
    )


def _core_to_split(
    n_rows, Y_split, X_split, Co_split, col: int, row: int
) -> tuple[int, int, int, int]:
    """
    Map a physical core to its image block position for a given split.

    The core is flattened to a 1d index k = col * n_rows + row, and the
    per-split formula below turns k into (N_idx, Y_idx, X_idx, Co_idx).
    N_idx is 0 for every supported split. Raises KeyError when
    (1, Y_split, X_split, Co_split) is not a supported split mode.
    """
    # Key format is (N_split, Y_split, X_split, Co_split)
    # Val maps the flattened core index k to the image block position
    mode_lookup = {
        #
        # IFM unicast / WGT broadcast
        #
        # Co_split = 1
        (1, 12, 1, 1): lambda k: (0, k, 0, 0),
        (1, 6, 2, 1): lambda k: (0, k // 2, k % 2, 0),
        (1, 4, 3, 1): lambda k: (0, k % 4, k // 4, 0),
        (1, 3, 4, 1): lambda k: (0, k // 4, k % 4, 0),
        (1, 2, 6, 1): lambda k: (0, (k // 2) % 2, (k % 2) + 2 * (k // 4), 0),
        (1, 1, 12, 1): lambda k: (0, 0, k, 0),
        # Co_split = 2
        (1, 6, 1, 2): lambda k: (0, k // 2, 0, k % 2),
        (1, 3, 2, 2): lambda k: (0, (k // 2) // 2, (k // 2) % 2, k % 2),
        (1, 2, 3, 2): lambda k: (0, (k // 2) % 2, k // 4, k % 2),
        (1, 1, 6, 2): lambda k: (0, 0, k // 2, k % 2),
        # Co_split = 3
        (1, 4, 1, 3): lambda k: (0, k % 4, 0, k // 4),
        (1, 2, 2, 3): lambda k: (0, (k % 4) // 2, (k % 4) % 2, k // 4),
        (1, 1, 4, 3): lambda k: (0, 0, k % 4, k // 4),
        #
        # IFM broadcast / WGT unicast
        #
        # Co_split = 4
        (1, 3, 1, 4): lambda k: (0, k // 4, 0, k % 4),
        (1, 1, 3, 4): lambda k: (0, 0, k // 4, k % 4),
        # Co_split = 6
        (1, 2, 1, 6): lambda k: (0, k % 2, 0, k // 2),
        (1, 1, 2, 6): lambda k: (0, 0, k % 2, k // 2),
        # Co_split = 12
        (1, 1, 1, 12): lambda k: (0, 0, 0, k),
    }
    # The table must cover exactly the supported split modes.
    assert sorted(mode_lookup) == sorted(SPATIAL_SPLIT_MODES)
    to_block_position = mode_lookup[(1, Y_split, X_split, Co_split)]
    # Flatten core to 1d index
    flat_core_id = (col * n_rows) + row
    N_idx, Y_idx, X_idx, Co_idx = to_block_position(flat_core_id)
    return N_idx, Y_idx, X_idx, Co_idx


class L1BufferAlloc(BaseModel):
    """L1 buffer allocation details

    Immutable (frozen) pydantic model describing one L1 buffer: its size
    and its ping address, plus an optional pong address.
    """

    model_config = ConfigDict(frozen=True)

    # Buffer size -- presumably in bytes; TODO confirm unit
    size: int
    # Address of the ping buffer
    ping_addr: int
    # Address of the pong buffer; None when no pong buffer is allocated
    pong_addr: Optional[int] = None

    @property
    def is_double_buffered(self) -> bool:
        """Check if buffer is double buffered, i.e. a pong address is set"""
        return self.pong_addr is not None


@dataclass(frozen=True)
class L2Alloc:
    """L2 fused tensor locations decided by graph-level analysis.

    Frozen dataclass; instances cannot be modified after construction.
    """

    # Each tuple = (AieTile (Memtile), local_offset)
    # IFM location: a single tuple, or a list of them (presumably one per input)
    ifm_L2_loc: Union[List[Tuple[AieTile, int]], Tuple[AieTile, int]]
    # OFM location as a single (AieTile, local_offset) tuple
    ofm_L2_loc: Tuple[AieTile, int]
    # WGT / PRM locations as nested lists mixing AieTile and int entries --
    # presumably [tile, offset] pairs; exact layout not visible here, TODO confirm
    wgt_l2_loc: List[List[Union[AieTile, int]]]
    prm_l2_loc: List[List[Union[AieTile, int]]]
    # Presumably toggles for filling the IFM into L2 / spilling the OFM
    # out of L2 -- verify against the code that consumes this class
    enable_ifm_fill: bool
    enable_ofm_spill: bool
    # L2 fusion flag; off unless explicitly requested
    enable_L2_fusion: bool = False


def L1Alloc_from_dict(l1_alloc: dict[str, tuple]) -> dict[str, L1BufferAlloc]:
    """Return a copy of l1_alloc with legacy tuples turned into L1BufferAlloc.

    Tuple values are unpacked as (size, ping_addr, pong_addr); values of any
    other type are carried over unchanged. The input dict is not modified.
    """

    def _to_buffer_alloc(entry):
        # Only the legacy tuple format needs converting.
        if not isinstance(entry, tuple):
            return entry
        size, ping_addr, pong_addr = entry
        return L1BufferAlloc(size=size, ping_addr=ping_addr, pong_addr=pong_addr)

    return {name: _to_buffer_alloc(entry) for name, entry in l1_alloc.items()}


class BaseMappingWithL1(BaseMapping):
    """Mapping with L1 buffer allocation details

    Extends BaseMapping with `l1_alloc`, a dict of named L1BufferAlloc
    entries. The computed fields below read the "ifm", "wgt" and "ofm"
    entries and raise KeyError if the corresponding key is absent.
    """

    model_config = ConfigDict(frozen=True)

    # Buffer name -> allocation; legacy (size, ping_addr, pong_addr) tuples
    # are converted by build_from_legacy before validation.
    l1_alloc: dict[str, L1BufferAlloc]

    @model_validator(mode="before")
    @classmethod
    def build_from_legacy(cls, data: object) -> object:
        """Support legacy tuple-based l1_alloc format

        Runs before field validation. When the raw input is a dict whose
        "l1_alloc" value is itself a dict, tuple entries in it are converted
        to L1BufferAlloc objects. Any other input is returned untouched.
        """
        # A "before" validator can receive non-dict input; leave that for
        # pydantic's own validation to accept or reject.
        if not isinstance(data, dict):
            return data

        l1_alloc = data.get("l1_alloc")
        # If l1_alloc contains tuples, convert to L1BufferAlloc objects
        if isinstance(l1_alloc, dict):
            # Build a new dict rather than writing into the caller's one, so
            # constructing a model never mutates the input it was given.
            data = {**data, "l1_alloc": L1Alloc_from_dict(l1_alloc)}

        return data

    # Computed properties for convenient access
    @computed_field  # type: ignore[misc]
    @property
    def ifm_L1_size(self) -> int:
        """Return IFM L1 buffer size"""
        return self.l1_alloc["ifm"].size

    @computed_field  # type: ignore[misc]
    @property
    def ifm_L1_ping_addr(self) -> int:
        """Return IFM L1 ping buffer address"""
        return self.l1_alloc["ifm"].ping_addr

    @computed_field  # type: ignore[misc]
    @property
    def ifm_L1_pong_addr(self) -> Optional[int]:
        """Return IFM L1 pong buffer address (None if not double buffered)"""
        return self.l1_alloc["ifm"].pong_addr

    @computed_field  # type: ignore[misc]
    @property
    def wgt_L1_size(self) -> int:
        """Return WGT L1 buffer size"""
        return self.l1_alloc["wgt"].size

    @computed_field  # type: ignore[misc]
    @property
    def wgt_L1_ping_addr(self) -> int:
        """Return WGT L1 ping buffer address"""
        return self.l1_alloc["wgt"].ping_addr

    @computed_field  # type: ignore[misc]
    @property
    def wgt_L1_pong_addr(self) -> Optional[int]:
        """Return WGT L1 pong buffer address (None if not double buffered)"""
        return self.l1_alloc["wgt"].pong_addr

    @computed_field  # type: ignore[misc]
    @property
    def ofm_L1_size(self) -> int:
        """Return OFM L1 buffer size"""
        return self.l1_alloc["ofm"].size

    @computed_field  # type: ignore[misc]
    @property
    def ofm_L1_ping_addr(self) -> int:
        """Return OFM L1 ping buffer address"""
        return self.l1_alloc["ofm"].ping_addr

    @computed_field  # type: ignore[misc]
    @property
    def ofm_L1_pong_addr(self) -> Optional[int]:
        """Return OFM L1 pong buffer address (None if not double buffered)"""
        return self.l1_alloc["ofm"].pong_addr


class BaseMappingWithL1AndL2(BaseMappingWithL1):
    """Mapping with L1 and L2 strategies

    Adds one free-form string per tensor naming its L2 strategy. The set of
    valid strategy names is not defined in this file -- TODO confirm with
    the code that consumes these fields.
    """

    # L2 strategy name for the IFM tensor
    ifm_L2_strategy: str
    # L2 strategy name for the WGT tensor
    wgt_L2_strategy: str
    # L2 strategy name for the OFM tensor
    ofm_L2_strategy: str


#
# Constant Addr/ Size for Core
#


def overlay_3x4_core_stack_addr() -> int:
    """Return the overlay core stack address (2 KiB below 125824)"""
    stack_bytes = 2 * 1024
    upper_bound = 125824
    return upper_bound - stack_bytes


def overlay_3x4_core_stack_size() -> int:
    """Return the overlay core stack size (2 KiB)"""
    kib = 1024
    return 2 * kib


def overlay_3x4_core_heap_size() -> int:
    """Return the overlay core heap size (two 2 KiB parts, 4 KiB in total)"""
    part_bytes = 2 * 1024
    return part_bytes + part_bytes


#
# logging APIs
#


def is_log_enabled():
    """Return 1 if LOG_ENABLED is "1", "true" or "yes" (any case), else 0"""
    # An unset variable falls back to "false", i.e. logging disabled.
    raw_setting = os.getenv("LOG_ENABLED", "false")
    enabled = raw_setting.lower() in ("1", "true", "yes")
    return int(enabled)


def log(*args, **kwargs):
    """
    Print a timestamped message, but only when logging is enabled.

    Takes the same positional and keyword arguments as the built-in print;
    the output is prefixed with "[<YYYY-mm-dd HH:MM:SS>] [LOG] ".
    Logging is controlled by the LOG_ENABLED environment variable and can be
    turned ON/OFF with `export LOG_ENABLED=true` in bash
    or with `python build_aie4.py ............... --enable_log True/False`.
    """
    if not is_log_enabled():
        return
    now_text = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{now_text}] [LOG] ", *args, **kwargs)


@dataclass
class ReadBins:
    """Flags for chained IFM and WGT reads (1 = enabled, 0 = disabled)"""

    # Chained IFM read flag; disabled (0) by default
    read_ifm: int = 0
    # Chained WGT read flag; disabled (0) by default
    read_wgt: int = 0


def is_subgraph_debug() -> bool:
    """Return True only when AIE4_SUBGRAPH_DEBUG is exactly "1\""""
    # Unset falls back to "0"; no other spelling ("true", "yes") enables it.
    debug_flag = os.getenv("AIE4_SUBGRAPH_DEBUG", "0")
    return debug_flag == "1"


def is_save_subgraph_json() -> bool:
    """Return True when subgraph JSON should be saved.

    Saving is requested when the SAVE_SUBGRAPH_JSON environment variable is
    "1", "true" or "yes" (case-insensitive), or when subgraph debug mode is
    on (see is_subgraph_debug).
    """
    # Kept as a real bool so the result matches the declared return type
    # (wrapping the test in int() made the env-var path return the int 1).
    save_requested = os.getenv("SAVE_SUBGRAPH_JSON", "false").lower() in (
        "1",
        "true",
        "yes",
    )
    return save_requested or is_subgraph_debug()
