"""
Contains all common logic shared between dataflow code.
"""

from dataclasses import dataclass, field
from enum import Enum
from typing import List, Tuple, Optional, Union, Any

from utils.utils_common import (
    BaseDims,
    core_to_split,
)

from dmacompiler import (
    DevGen,
    set_dev_gen,
    config,
    DmaConnection,
    DmaDir,
    core_dma,
    memtile_dma,
    shim_dma,
    AieTile,
    TileType,
    CoreConnection
)

# Module-level side effect: importing this module calls into dmacompiler to
# select the Aie4 device generation.
set_dev_gen(DevGen.Aie4)


class LinearOpType(Enum):
    """Enumerates the linear-op data type variants; each member carries a distinct integer code."""

    # conv / gemm variants
    conv_A8W8_noqdq = 1
    gemm_A16W8_qdq = 2
    gemm_A16W4_qdq = 3
    gemm_A8W8_qdq = 4
    conv_A16W8_qdq = 5
    conv_A16W4_qdq = 6
    conv_A8W8_qdq = 7
    gemm_A16A16_v1 = 8
    gemm_A16A16_v2 = 9
    # dwc variants
    dwc_A8W8_noqdq = 10
    dwc_A8W8_qdq = 11
    dwc_A16W8_qdq = 12

    def to_dtype(self) -> int:
        """Return the integer dtype code of this member (its enum value)."""
        dtype_code = self.value
        return dtype_code


def align_up(offset: int, alignment: Optional[int]) -> int:
    """Round offset up to the next multiple of alignment.

    The offset is returned unchanged when alignment is None or when it is
    already a multiple of alignment.
    """
    if alignment is None:
        return offset
    overshoot = offset % alignment
    if overshoot == 0:
        return offset
    # Drop the part past the previous boundary, then step one full alignment up.
    return offset + (alignment - overshoot)

#
# 3 Column x 4 Row Overlay Definition
#


def overlay_3x4_O_ids() -> list[int]:
    """Memtile s2mm channel ids used for core output (core to memtile)."""
    core_out_channel_ids = [0, 1, 4, 5]
    return core_out_channel_ids


def overlay_3x4_F_ids() -> list[int]:
    """Memtile s2mm channel ids used for fill (shim to memtile)."""
    fill_channel_ids = [2, 3, 6, 7]
    return fill_channel_ids


def overlay_3x4_A_ids() -> list[int]:
    """Memtile mm2s channel ids used for unicast (memtile to core)."""
    unicast_channel_ids = [0, 2, 5, 7]
    return unicast_channel_ids


def overlay_3x4_B_ids(col: int) -> list[int]:
    """Channel ID allocations for memtile mm2s to core broadcast.

    Args:
        col: Overlay column index; must be 0, 1 or 2.

    Returns:
        A fresh two-element list of memtile mm2s channel ids:
        [1, 3] for columns 0 and 2, [6, 8] for column 1.

    Raises:
        AssertionError: If col is not one of the three overlay columns.
    """
    if col in (0, 2):
        return [1, 3]
    if col == 1:
        return [6, 8]
    # Raised explicitly instead of `assert False` so the check is not stripped
    # under `python -O` (which previously made an invalid column return []).
    # AssertionError is kept so callers observe the same exception type.
    raise AssertionError(f"overlay_3x4_B_ids: invalid column {col!r}, expected 0, 1 or 2")


def overlay_3x4_S_ids(col: int) -> list[int]:
    """Channel ID allocations for memtile mm2s to shim spill.

    Args:
        col: Overlay column index; must be 0, 1 or 2.

    Returns:
        A fresh two-element list of memtile mm2s channel ids:
        [4, 8] for columns 0 and 2, [4, 9] for column 1.

    Raises:
        AssertionError: If col is not one of the three overlay columns.
    """
    if col in (0, 2):
        return [4, 8]
    if col == 1:
        return [4, 9]
    # Raised explicitly instead of `assert False` so the check is not stripped
    # under `python -O` (which previously made an invalid column return []).
    # AssertionError is kept so callers observe the same exception type.
    raise AssertionError(f"overlay_3x4_S_ids: invalid column {col!r}, expected 0, 1 or 2")


def unicast_channels() -> list[tuple[int, int]]:
    """Pair each unicast memtile mm2s channel id with its position: (row, id)."""
    return [(row, channel_id) for row, channel_id in enumerate(overlay_3x4_A_ids())]


def broadcast_channels(col: int) -> list[tuple[int, int]]:
    """Generate broadcast memtile mm2s channel allocations (row, id) for column col"""
    return [(row, channel_id) for row, channel_id in enumerate(overlay_3x4_B_ids(col))]


def overlay_3x4_param_channel_id() -> int:
    """Channel id used for the overlay's layer-parameter transfer."""
    param_channel_id = 0
    return param_channel_id


def overlay_3x4_dma_connections() -> list[DmaConnection]:
    """
    Generates dma connections for the 3 column x 4 row overlay.

    The returned list is the concatenation, in this order, of:
      1. shim mm2s -> memtile s2mm (fill), 4 per column
      2. memtile mm2s -> core s2mm channel 0 (unicast), one per core
      3. memtile mm2s -> core s2mm channel 1 (broadcast), one per core;
         rows sharing the same parity use the same memtile channel
      4. core mm2s channel 0 -> memtile s2mm (out), one per core
      5. memtile mm2s -> shim s2mm (spill), 2 per column
    Within each group the column is the outer loop.

    NOTE(review): an earlier note here described a temporary workaround for an
    AIE-MAPPER limit of 2 shim MM2S channels. The code already uses
    num_fill = 4 with the F_ids length assert enabled, so that workaround
    appears to be gone - confirm with the arch team.
    """
    aie_cols = 3
    aie_rows = 4
    num_fill = 4
    num_spill = 2

    fill_ids = overlay_3x4_F_ids()
    unicast_ids = overlay_3x4_A_ids()
    out_ids = overlay_3x4_O_ids()

    # The channel tables must match the overlay geometry assumed below.
    assert len(out_ids) == aie_rows
    assert len(fill_ids) == num_fill
    assert len(unicast_ids) == aie_rows
    assert len(overlay_3x4_B_ids(0)) == (aie_rows // 2)
    assert len(overlay_3x4_S_ids(0)) == num_spill

    cols = range(aie_cols)
    rows = range(aie_rows)

    shim_to_memtile = [
        DmaConnection(shim_dma(col, DmaDir.MM2S, i), memtile_dma(col, DmaDir.S2MM, fill_ids[i]))
        for col in cols
        for i in range(num_fill)
    ]
    memtile_to_core_unicast = [
        DmaConnection(memtile_dma(col, DmaDir.MM2S, unicast_ids[row]), core_dma(col, row, DmaDir.S2MM, 0))
        for col in cols
        for row in rows
    ]
    memtile_to_core_broadcast = [
        # row % 2 picks one of the two broadcast channels of the column
        DmaConnection(memtile_dma(col, DmaDir.MM2S, overlay_3x4_B_ids(col)[row % 2]), core_dma(col, row, DmaDir.S2MM, 1))
        for col in cols
        for row in rows
    ]
    core_to_memtile = [
        DmaConnection(core_dma(col, row, DmaDir.MM2S, 0), memtile_dma(col, DmaDir.S2MM, out_ids[row]))
        for col in cols
        for row in rows
    ]
    memtile_to_shim = [
        DmaConnection(memtile_dma(col, DmaDir.MM2S, overlay_3x4_S_ids(col)[i]), shim_dma(col, DmaDir.S2MM, i))
        for col in cols
        for i in range(num_spill)
    ]

    return (
        shim_to_memtile
        + memtile_to_core_unicast
        + memtile_to_core_broadcast
        + core_to_memtile
        + memtile_to_shim
    )


def overlay_3x4_col_core_stream_bdcast() -> List[CoreConnection]:
    """Core stream connections from the row-0 core of each column to rows 1..3 of that column.

    Connections are ordered by destination row first, then by column.
    """
    aie_cols = 3
    aie_rows = 4
    connections = []
    for row in range(1, aie_rows):
        for col in range(aie_cols):
            source = AieTile(TileType.Core, col, 0)
            target = AieTile(TileType.Core, col, row)
            connections.append(CoreConnection(source, target))
    return connections
#
# Layer Parameter Transfer
#


def prm_memtile_memory() -> str:
    """Layer parameters L2 data order: one Row entry per AIE row, Byte sized by the max core layer param size"""
    num_rows = config.NUM_AIE_ROWS
    param_bytes = config.MAX_CORE_LAYER_PARAM_SIZE
    return f"Row:{num_rows} Byte:{param_bytes}"


def prm_shim_memory() -> str:
    """Layer parameters DDR data order: Col x Row x Byte, sized from the dmacompiler config"""
    num_cols = config.NUM_AIE_COLS
    num_rows = config.NUM_AIE_ROWS
    param_bytes = config.MAX_CORE_LAYER_PARAM_SIZE
    return f"Col:{num_cols} Row:{num_rows} Byte:{param_bytes}"


def prm_memtile_mm2s(row: int) -> str:
    """Layer parameters memtile mm2s access pattern selecting the single Row entry of the core at row"""
    row_stop = row + 1
    param_bytes = config.MAX_CORE_LAYER_PARAM_SIZE
    return f"Row:{row}:{row_stop} Byte:0:{param_bytes}"


def prm_memtile_s2mm() -> str:
    """Layer parameters memtile s2mm access pattern covering every Row entry in full"""
    num_rows = config.NUM_AIE_ROWS
    param_bytes = config.MAX_CORE_LAYER_PARAM_SIZE
    return f"Row:0:{num_rows} Byte:0:{param_bytes}"


def prm_shim_mm2s(col: int) -> str:
    """Layer parameters shim mm2s access pattern selecting the Col entry of col with all of its rows"""
    col_stop = col + 1
    num_rows = config.NUM_AIE_ROWS
    param_bytes = config.MAX_CORE_LAYER_PARAM_SIZE
    return f"Col:{col}:{col_stop} Row:0:{num_rows} Byte:0:{param_bytes}"


#
# Shim DDR Patch index
#


@dataclass
class ShimAlloc:
    """Mapping for shim_alloc idx.

    Holds one buffer id per tensor kind. The constructor takes the ids
    positionally or by keyword in the order ifm, wgt, ofm, prm.

    The ids are declared as dataclass fields. With a hand-written __init__ and
    no annotated fields the decorator saw zero fields, so every instance
    compared equal to every other and repr() showed no values.
    """

    ifm_buffer_id: int
    wgt_buffer_id: int
    ofm_buffer_id: int
    prm_buffer_id: int


def shim_alloc() -> ShimAlloc:
    """Return the default shim buffer ids: ofm=0, ifm=1, wgt=2, prm=3"""
    return ShimAlloc(
        ifm_buffer_id=1,
        wgt_buffer_id=2,
        ofm_buffer_id=0,
        prm_buffer_id=3,
    )


@dataclass
class ShimAllocator:
    """Mapping for shim idx and offset

    Holds one index and one offset per tensor kind (ifm, prm, wgt, ofm).
    Index defaults are taken from shim_alloc() each time an instance is
    created without that argument; offsets default to 0. Only the ifm
    index/offset may also be given as a list of ints.
    """

    # Indices: default to the corresponding buffer id returned by shim_alloc().
    ifm_xrt_idx: Union[int, List[int]] = field(default_factory=lambda: shim_alloc().ifm_buffer_id)
    prm_xrt_idx: int = field(default_factory=lambda: shim_alloc().prm_buffer_id)
    wgt_xrt_idx: int = field(default_factory=lambda: shim_alloc().wgt_buffer_id)
    ofm_xrt_idx: int = field(default_factory=lambda: shim_alloc().ofm_buffer_id)

    # Offsets paired with the indices above (units not visible here - presumably bytes, TODO confirm).
    ifm_xrt_offset: Union[int, List[int]] = 0
    prm_xrt_offset: int = 0
    wgt_xrt_offset: int = 0
    ofm_xrt_offset: int = 0


@dataclass
class TensorSlicer:
    """
    Separate class for handling all dimension slicing operations.

    The *_slice methods return (start, stop, stride) for one core and one loop
    iteration; the *_slice_iter methods additionally return the slice size.
    All values are computed from attributes of `dims`.

    Attributes:
        dims: Layer dimensions object (a BaseDims). The methods read attributes
            such as Yo, Yos, Y_split, Y_loop, Sy, Py, Yi, Yis and their
            X / Co / Ci counterparts from it.
    """

    # Declared as a real dataclass field. With a hand-written __init__ and no
    # annotated fields the decorator saw zero fields, so every TensorSlicer
    # compared equal to every other and repr() did not show the dims.
    dims: "BaseDims"

    def Yi_slice(self, col: int, row: int, i: int) -> tuple[int, int, int]:
        """Slice for axis Yi at core (col, row) during iteration i of Y_loop.

        Returns (Yi_start, Yi_stop, Yi_stride). When the matching output start
        is not below dims.Yo, the empty slice (dims.Yi, dims.Yi) is returned.
        """
        dims = self.dims
        _, Y_idx, _, _ = core_to_split(dims, col, row)
        # In bounds when the output row this core starts at is still inside Yo.
        in_bounds = ((Y_idx * dims.Yos) + (i * dims.Yos * dims.Y_split)) < dims.Yo
        Yi_stride = dims.Yos * dims.Sy * dims.Y_split
        if in_bounds:
            # Subtracting Py can make the start negative.
            Yi_start = (Y_idx * dims.Yos * dims.Sy) + (i * Yi_stride) - dims.Py
            Yi_stop = Yi_start + dims.Yis
        else:
            Yi_start = dims.Yi
            Yi_stop = dims.Yi
        return Yi_start, Yi_stop, Yi_stride

    def Xi_slice(self, col: int, row: int, i: int) -> tuple[int, int, int]:
        """Slice for axis Xi at core (col, row) during iteration i of X_loop.

        Returns (Xi_start, Xi_stop, Xi_stride). When the matching output start
        is not below dims.Xo, the empty slice (dims.Xi, dims.Xi) is returned.
        """
        dims = self.dims
        _, _, X_idx, _ = core_to_split(dims, col, row)
        # In bounds when the output column this core starts at is still inside Xo.
        in_bounds = ((X_idx * dims.Xos) + (i * dims.Xos * dims.X_split)) < dims.Xo
        Xi_stride = dims.Xos * dims.Sx * dims.X_split
        if in_bounds:
            # Subtracting Px can make the start negative.
            Xi_start = (X_idx * dims.Xos * dims.Sx) + (i * Xi_stride) - dims.Px
            Xi_stop = Xi_start + dims.Xis
        else:
            Xi_start = dims.Xi
            Xi_stop = dims.Xi
        return Xi_start, Xi_stop, Xi_stride

    def Yo_slice(self, col: int, row: int, i: int) -> tuple[int, int, int]:
        """Slice for axis Yo at core (col, row) during iteration i of Y_loop.

        Returns (Yo_start, Yo_stop, Yo_stride); start and stop are clamped to dims.Yo.
        """
        dims = self.dims
        _, Y_idx, _, _ = core_to_split(dims, col, row)
        Yo_stride = dims.Yos * dims.Y_split
        Yo_start = min((Y_idx * dims.Yos) + (i * Yo_stride), dims.Yo)
        Yo_stop = min(Yo_start + dims.Yos, dims.Yo)
        return Yo_start, Yo_stop, Yo_stride

    def Xo_slice(self, col: int, row: int, i: int) -> tuple[int, int, int]:
        """Slice for axis Xo at core (col, row) during iteration i of the X_loop.

        Returns (Xo_start, Xo_stop, Xo_stride); start and stop are clamped to dims.Xo.
        """
        dims = self.dims
        _, _, X_idx, _ = core_to_split(dims, col, row)
        Xo_stride = dims.Xos * dims.X_split
        Xo_start = min((X_idx * dims.Xos) + (i * Xo_stride), dims.Xo)
        Xo_stop = min(Xo_start + dims.Xos, dims.Xo)
        return Xo_start, Xo_stop, Xo_stride

    def Co_slice(self, col: int, row: int, i: int) -> tuple[int, int, int]:
        """Slice for axis Co at core (col, row) during iteration i of the Co_loop.

        Returns (Co_start, Co_stop, Co_stride); start and stop are clamped to dims.Co.
        """
        dims = self.dims
        _, _, _, Co_idx = core_to_split(dims, col, row)
        Co_stride = dims.Cos * dims.Co_split
        Co_start = min((Co_idx * dims.Cos) + (i * Co_stride), dims.Co)
        Co_stop = min(Co_start + dims.Cos, dims.Co)
        return Co_start, Co_stop, Co_stride

    def Yo_slice_iter(self, col: int, start_iter: int) -> Tuple[int, int, int, int]:
        """Slice for axis Yo at core col during iteration start_iter of the Y_loop.

        Returns (Yo_start, Yo_stop, Yo_stride, Yo_size). The start is not
        clamped; when it lies beyond dims.Yo the slice is empty (stop == start).
        """
        dims = self.dims
        Yo_stride = dims.Y_split * dims.Yos
        Yo_start = (col * dims.Yos) + (start_iter * Yo_stride)
        Yo_stop = min(Yo_start + dims.Yos, dims.Yo) if Yo_start <= dims.Yo else Yo_start
        Yo_size = Yo_stop - Yo_start
        return (Yo_start, Yo_stop, Yo_stride, Yo_size)

    def Xo_slice_iter(self, col: int, start_iter: int) -> Tuple[int, int, int, int]:
        """Slice for axis Xo at core col during iteration start_iter of the X_loop.

        Returns (Xo_start, Xo_stop, Xo_stride, Xo_size). The start is not
        clamped; when it lies beyond dims.Xo the slice is empty (stop == start).
        """
        dims = self.dims
        # NOTE(review): the stride uses dims.aie_cols here, while Xo_slice uses
        # dims.X_split - confirm this difference is intended.
        Xo_stride = dims.aie_cols * dims.Xos
        Xo_start = (col * dims.Xos) + (start_iter * Xo_stride)
        Xo_stop = min(Xo_start + dims.Xos, dims.Xo) if Xo_start <= dims.Xo else Xo_start
        Xo_size = Xo_stop - Xo_start
        return (Xo_start, Xo_stop, Xo_stride, Xo_size)

    def Co_slice_iter(self, col: int, start_iter: int) -> Tuple[int, int, int, int]:
        """Slice for axis Co at core col during iteration start_iter of the Co_loop.

        Returns (Co_start, Co_stop, Co_stride, Co_size). Unlike the Yo/Xo
        variants, the slice spans up to a full Co_stride (Cos * Co_split)
        rather than a single Cos, clamped to dims.Co.
        """
        dims = self.dims
        Co_stride = dims.Cos * dims.Co_split  # = dims.Com
        Co_start = (col * dims.Cos) + (start_iter * Co_stride)
        Co_stop = min(Co_start + Co_stride, dims.Co) if Co_start <= dims.Co else Co_start
        Co_size = Co_stop - Co_start
        return (Co_start, Co_stop, Co_stride, Co_size)

    def Yo_split_iters(self, col: int) -> List[Tuple[int, int]]:
        """Split Y_loop iterations for core col into runs of full slices.

        Returns a list of (start_iter, num_iters). A run groups consecutive
        iterations whose Yo slice has the full size dims.Yos; an iteration with
        a partial or empty slice forms its own run of length 1.
        """

        def can_iterate(start_iter: int, num_iters: int) -> bool:
            # True when the last iteration of the candidate run is a full slice.
            _, _, _, Yo_size = self.Yo_slice_iter(col, start_iter + num_iters - 1)
            return Yo_size == self.dims.Yos

        split = []
        curr_iters = 0
        while curr_iters < self.dims.Y_loop:
            start_iter = curr_iters
            num_iters = 1
            if can_iterate(start_iter, num_iters):
                # Extend the run while the next iteration is still a full slice.
                while can_iterate(start_iter, num_iters + 1):
                    num_iters += 1
            split.append((start_iter, num_iters))
            curr_iters += num_iters
        return split

    def Co_split_iters(self) -> List[int]:
        """
        Split Co_loop iterations into full repeats and an optional depad phase.

        Returns [full_repeats, depad_phase] where:
        - full_repeats: number of full iterations where all Co_split cores are
          active (at least 1, even when Co is smaller than Cos * Co_split)
        - depad_phase: 0 or 1 indicating if there's a partial iteration with
          some cores inactive; only reported when dims.Co_loop > 1
        """
        dims = self.dims

        # Channels covered by one iteration across all Co splits.
        Co_per_split = dims.Cos * dims.Co_split

        full_repeats = dims.Co // Co_per_split
        if full_repeats == 0:
            full_repeats = 1

        # A remainder needs depadding only when there is more than one Co iteration.
        remainder = dims.Co % Co_per_split
        depad_phase = 1 if (remainder > 0 and dims.Co_loop > 1) else 0

        return [full_repeats, depad_phase]

    def Ci_split_iters(self) -> List[int]:
        """
        Split Ci_loop iters into full slices and Ci pad phases.

        Returns [full_repeats, pad_phase]: the number of whole dims.Cis chunks
        in dims.Ci, and 1 if a partial chunk remains (else 0).
        """
        full_repeats = self.dims.Ci // self.dims.Cis
        remainder = self.dims.Ci % self.dims.Cis
        pad_phase = 1 if remainder > 0 else 0

        return [full_repeats, pad_phase]


def unpack_pair(v: Union[List[int], Tuple[int, int], List[List[int]]], default: Tuple[int, int]) -> Tuple[int, int]:
    """Normalise v to an (id, off) tuple of ints.

    Accepts [id, off], (id, off), or a nested [[id, off], ...] (the first entry
    is used). Extra trailing elements are ignored. Returns default when v is
    None or does not contain at least two elements.
    """
    if v is None:
        return default
    candidate = v
    # Nested form: descend into the first entry.
    if isinstance(candidate, (list, tuple)) and candidate:
        head = candidate[0]
        if isinstance(head, (list, tuple)):
            candidate = head
    if not isinstance(candidate, (list, tuple)) or len(candidate) < 2:
        return default
    return (int(candidate[0]), int(candidate[1]))


@dataclass(frozen=True)
class L3Alloc:
    """L3 fused tensor locations decided by graph-level analysis.

    Each field is an [idx, offset] pair (optionally followed by a size) or a
    list of such entries. The defaults are [1, 0] for ifm, [0, 0] for ofm,
    [2, 0] for wgt and [3, 0] for prm.
    """

    ifm: Union[List[int], Tuple[int, int], List[List[int]]] = field(default_factory=lambda: [1, 0])
    ofm: Union[List[int], Tuple[int, int], List[List[int]]] = field(default_factory=lambda: [0, 0])
    wgt: Union[List[int], Tuple[int, int], List[List[int]]] = field(default_factory=lambda: [2, 0])
    prm: Union[List[int], Tuple[int, int], List[List[int]]] = field(default_factory=lambda: [3, 0])

    def to_shim(self) -> ShimAllocator:
        """Convert L3Alloc to ShimAllocator.

        Accepted shapes per field:
          - [idx, offset] or (idx, offset)
          - [idx, offset, size] (size ignored)
          - [[idx0, off0], [idx1, off1, size1], ...]
        For ifm, a nested list yields list-valued idx/offset containing every
        entry that has integer idx and offset. For prm, wgt and ofm only the
        first entry of a nested list is used.
        """

        def split_ifm(
            val: Any, default: Tuple[int, int]
        ) -> Tuple[Union[int, List[int]], Union[int, List[int]]]:
            """Return list-valued (idxs, offs) for a nested IFM value, else a single pair."""
            looks_nested = (
                isinstance(val, list)
                and bool(val)
                and all(isinstance(entry, (list, tuple)) for entry in val)
            )
            if looks_nested:
                # Keep only entries that carry an integer idx and offset.
                usable = [
                    (entry[0], entry[1])
                    for entry in val
                    if len(entry) >= 2 and isinstance(entry[0], int) and isinstance(entry[1], int)
                ]
                if usable:
                    return [idx for idx, _ in usable], [off for _, off in usable]
            # Not nested, or nothing usable: treat as a single pair or triple.
            return unpack_pair(val, default)

        ifm_idx, ifm_off = split_ifm(self.ifm, (1, 0))
        prm_idx, prm_off = unpack_pair(self.prm, (3, 0))
        wgt_idx, wgt_off = unpack_pair(self.wgt, (2, 0))
        ofm_idx, ofm_off = unpack_pair(self.ofm, (0, 0))

        return ShimAllocator(
            ifm_xrt_idx=ifm_idx,
            ifm_xrt_offset=ifm_off,
            prm_xrt_idx=prm_idx,
            prm_xrt_offset=prm_off,
            wgt_xrt_idx=wgt_idx,
            wgt_xrt_offset=wgt_off,
            ofm_xrt_idx=ofm_idx,
            ofm_xrt_offset=ofm_off,
        )


def L3Alloc_to_Shim(l3: Optional[L3Alloc]) -> ShimAllocator:
    """Convert l3 to a ShimAllocator, using a default-constructed L3Alloc when l3 is None"""
    alloc = l3 if l3 else L3Alloc()
    return alloc.to_shim()
