'''
This module contains all type definitions required for subsystems to interact
with each other. These can be divided into several groups as follows.

AIE physical resources

High-level data transfer IR

Core instructions

Low-level buffer descriptor IR

Low-level layer control sequence
'''


import struct
from enum import Enum
from typing import List, Dict, Tuple, Optional, Union, Type, Any
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

from . import config


################################################################################
#
# User API for AIE Physical Resources
#
################################################################################


class DevGen(Enum):
    """Device generation selector.

    The access-pattern design-rule check compares ``config.DEV_GEN`` against
    these members (``Aie4`` enables the "dim zero step must be 1" rule).
    """
    Aie2p = 1
    Aie4 = 2


class CascDir(Enum):
    """Cascade direction: vertical or horizontal.

    NOTE(review): not referenced in the visible part of this module --
    presumably the direction of a core-to-core cascade; confirm with callers.
    """
    Vertical = 1
    Horizontal = 2


class DmaDir(Enum):
    """Direction of a DMA channel.

    Within this module, write-side transfers are required to use ``S2MM``
    channels and read-side transfers ``MM2S`` channels.
    """
    S2MM = 1
    MM2S = 2


class DmaChannel:
    """A DMA channel identified by its direction and channel index.

    Instances are hashable and compare equal when both the direction and the
    index match, so they can be used as dictionary keys (e.g. the
    ``buffer_ids`` mapping passed to ``CoreInstr.to_bytes``).

    Attributes:
        dir: Direction of the channel.
        id: Index of the channel.
    """
    __slots__ = ('dir', 'id')

    def __init__(self, dir: DmaDir, id: int):
        self.dir = dir
        self.id = id

        # NOTE(review): presumably verifies that the config module has been
        # initialised before any resource is created -- confirm in config.
        config.check_init()

    # Make DmaChannel hashable

    def __str__(self) -> str:
        """Return ``'<dir>_<id>'`` with the direction name in lower case."""
        return self.dir.name.lower() + '_' + str(self.id)

    def __eq__(self, other) -> bool:
        """Compare by direction and index; defer for foreign types."""
        # Returning NotImplemented lets `channel == None` (or any unrelated
        # object) evaluate to False instead of raising AttributeError.
        if not isinstance(other, DmaChannel):
            return NotImplemented
        return (self.dir == other.dir and
                self.id == other.id)

    def __hash__(self) -> int:
        """Hash of the string form, consistent with ``__eq__``."""
        return hash(str(self))


class TileType(Enum):
    """Kind of AIE tile: compute core, memory tile, or shim tile."""
    Core = 0
    Memtile = 1
    Shim = 2


class AieTile:
    """Location of a single AIE tile: its kind, column and row.

    Core tiles may sit on any row in ``[0, config.NUM_AIE_ROWS)``; memtiles
    and shim tiles are only accepted on row 0. Every tile must lie in a
    column in ``[0, config.NUM_AIE_COLS)``.

    Instances are hashable and compare equal when kind, column and row match.

    Raises:
        ValueError: if the location is outside the configured array.
    """
    __slots__ = ('type', 'col', 'row')

    def __init__(self, type: TileType, col: int, row: int = 0):
        self.type = type
        self.col = col
        self.row = row

        config.check_init()
        # Check that location is valid
        col_in_range = 0 <= self.col < config.NUM_AIE_COLS
        if self.type == TileType.Core:
            if not (col_in_range and (0 <= self.row < config.NUM_AIE_ROWS)):
                raise ValueError('Invalid core location!')
        elif self.type == TileType.Memtile:
            if not (col_in_range and (self.row == 0)):
                raise ValueError('Invalid memtile location!')
        else:
            # Any other tile type is validated as a shim tile.
            if not (col_in_range and (self.row == 0)):
                raise ValueError('Invalid shim location!')

    # Make AieTile hashable

    def __str__(self) -> str:
        """Return ``'<type>_<col>_<row>'`` with the type name in lower case."""
        return self.type.name.lower() + '_' + str(self.col) + '_' + str(self.row)

    def __eq__(self, other) -> bool:
        """Compare by kind, column and row; defer for foreign types."""
        # Returning NotImplemented makes comparison with unrelated objects
        # evaluate to False instead of raising AttributeError.
        if not isinstance(other, AieTile):
            return NotImplemented
        return (self.type == other.type and
                self.col == other.col and
                self.row == other.row)

    def __hash__(self) -> int:
        """Hash of the string form, consistent with ``__eq__``."""
        return hash(str(self))


class AieDma:
    """A DMA channel on a specific tile.

    The channel index must lie in ``[0, max]`` where ``max`` is the
    ``config.MAX_<TILE>_<DIR>_DMA_CHANNEL`` limit matching the tile kind and
    the channel direction.

    Instances are hashable and compare equal when tile and channel match.

    Raises:
        ValueError: if the channel index is not available on this tile kind.
    """
    __slots__ = ('tile', 'channel')

    def __init__(self, tile: AieTile, channel: DmaChannel):
        self.tile = tile
        self.channel = channel

        config.check_init()
        is_s2mm = self.channel.dir == DmaDir.S2MM
        # Pick the one limit that applies to this tile kind and direction.
        if self.tile.type == TileType.Core:
            tile_label = 'core'
            max_channel = (config.MAX_CORE_S2MM_DMA_CHANNEL if is_s2mm else
                           config.MAX_CORE_MM2S_DMA_CHANNEL)
        elif self.tile.type == TileType.Memtile:
            tile_label = 'memtile'
            max_channel = (config.MAX_MEMTILE_S2MM_DMA_CHANNEL if is_s2mm else
                           config.MAX_MEMTILE_MM2S_DMA_CHANNEL)
        else:
            # Any other tile type is validated as a shim tile.
            tile_label = 'shim'
            max_channel = (config.MAX_SHIM_S2MM_DMA_CHANNEL if is_s2mm else
                           config.MAX_SHIM_MM2S_DMA_CHANNEL)
        if not (0 <= self.channel.id <= max_channel):
            raise ValueError(f'Invalid {tile_label} DMA channel!')

    # Make AieDma hashable

    def __str__(self) -> str:
        """Return ``'<tile>_<channel>'``."""
        return str(self.tile) + '_' + str(self.channel)

    def __eq__(self, other) -> bool:
        """Compare by tile and channel; defer for foreign types."""
        # Returning NotImplemented makes comparison with unrelated objects
        # evaluate to False instead of raising AttributeError.
        if not isinstance(other, AieDma):
            return NotImplemented
        return (self.tile == other.tile and
                self.channel == other.channel)

    def __hash__(self) -> int:
        """Hash of the string form, consistent with ``__eq__``."""
        return hash(str(self))


class DmaConnection:
    """A stream connection from a reading DMA to a writing DMA.

    The reading side must be an ``MM2S`` channel and the writing side an
    ``S2MM`` channel. Only these tile pairings are accepted
    (reader -> writer): shim -> memtile, memtile -> core, core -> memtile,
    memtile -> shim.

    Instances are hashable and compare equal when both endpoints match.

    Raises:
        ValueError: if the directions or the tile pairing are not allowed.
    """
    __slots__ = ('read_dma', 'write_dma')

    def __init__(self, read_dma: AieDma, write_dma: AieDma):
        self.read_dma = read_dma
        self.write_dma = write_dma

        config.check_init()
        valid_dir = (read_dma.channel.dir == DmaDir.MM2S and
                     write_dma.channel.dir == DmaDir.S2MM)
        if not valid_dir:
            raise ValueError('Invalid DMA direction in stream connection')
        allowed_routes = (
            (TileType.Shim, TileType.Memtile),
            (TileType.Memtile, TileType.Core),
            (TileType.Core, TileType.Memtile),
            (TileType.Memtile, TileType.Shim),
        )
        route = (read_dma.tile.type, write_dma.tile.type)
        if route not in allowed_routes:
            raise ValueError('Invalid tile locations for stream connection!')

    # Make DmaConnection hashable

    def __str__(self) -> str:
        """Return ``'<read_dma>_<write_dma>'``."""
        return str(self.read_dma) + '_' + str(self.write_dma)

    def __eq__(self, other) -> bool:
        """Compare by both endpoints; defer for foreign types."""
        # Returning NotImplemented makes comparison with unrelated objects
        # evaluate to False instead of raising AttributeError.
        if not isinstance(other, DmaConnection):
            return NotImplemented
        return (self.read_dma == other.read_dma and
                self.write_dma == other.write_dma)

    def __hash__(self) -> int:
        """Hash of the string form, consistent with ``__eq__``."""
        return hash(str(self))


class CoreConnection:
    """A directed connection between two distinct core tiles.

    Instances are hashable and compare equal when source and destination
    match.

    Raises:
        AssertionError: if either endpoint is not a core tile.
        RuntimeError: if source and destination are the same tile.
    """
    __slots__ = ('src_core', 'dst_core')

    def __init__(self, src_core: AieTile, dst_core: AieTile):
        self.src_core = src_core
        self.dst_core = dst_core

        config.check_init()
        assert self.src_core.type == TileType.Core
        assert self.dst_core.type == TileType.Core
        if self.src_core == self.dst_core:
            raise RuntimeError(f"Invalid core connection from Src {src_core} to Dst {dst_core}!")

    # Make CoreConnection hashable

    def __str__(self) -> str:
        """Return ``'<src_core>_<dst_core>'``."""
        return str(self.src_core) + '_' + str(self.dst_core)

    def __eq__(self, other) -> bool:
        """Compare by source and destination; defer for foreign types."""
        # Returning NotImplemented makes comparison with unrelated objects
        # evaluate to False instead of raising AttributeError.
        if not isinstance(other, CoreConnection):
            return NotImplemented
        return ((self.src_core == other.src_core) and
                (self.dst_core == other.dst_core))

    def __hash__(self) -> int:
        """Hash of the string form, consistent with ``__eq__``."""
        return hash(str(self))


class OverlayShape:
    """Rectangular region of the AIE array used by an overlay.

    The region spans ``num_cols`` columns starting at ``start_col`` and
    ``num_rows`` rows starting at ``start_row``; it must fit inside the
    ``config.NUM_AIE_COLS`` x ``config.NUM_AIE_ROWS`` array.

    Instances are hashable and compare equal when all four fields match.

    Raises:
        ValueError: if the region starts at a negative index or extends past
            the configured array in either dimension.
    """
    __slots__ = ('num_cols', 'num_rows', 'start_col', 'start_row')

    def __init__(
        self,
        num_cols: int,
        num_rows: int,
        start_col: int = 0,
        start_row: int = 0
    ):
        self.num_cols = num_cols
        self.num_rows = num_rows
        self.start_col = start_col
        self.start_row = start_row

        config.check_init()
        if not (0 <= self.start_col and self.start_col + self.num_cols <= config.NUM_AIE_COLS):
            raise ValueError('Invalid overlay column shape!')
        if not (0 <= self.start_row and self.start_row + self.num_rows <= config.NUM_AIE_ROWS):
            # This is the row check, so report rows rather than columns.
            raise ValueError('Invalid overlay row shape!')

    # Make OverlayShape hashable

    def __str__(self) -> str:
        """Return ``'(num_cols, num_rows, start_col, start_row)'``."""
        return f'({self.num_cols}, {self.num_rows}, {self.start_col}, {self.start_row})'

    def __eq__(self, other) -> bool:
        """Compare all four fields; defer for foreign types."""
        # Returning NotImplemented makes comparison with unrelated objects
        # evaluate to False instead of raising AttributeError.
        if not isinstance(other, OverlayShape):
            return NotImplemented
        return (self.num_cols == other.num_cols and
                self.num_rows == other.num_rows and
                self.start_col == other.start_col and
                self.start_row == other.start_row)

    def __hash__(self) -> int:
        """Hash of the string form, consistent with ``__eq__``."""
        return hash(str(self))


def core_tile(col: int, row: int) -> AieTile:
    """Build the core tile at column ``col`` and row ``row``."""
    return AieTile(type=TileType.Core, col=col, row=row)


def memory_tile(col: int) -> AieTile:
    """Build the memtile in column ``col`` (row is left at its default)."""
    return AieTile(type=TileType.Memtile, col=col)


def shim_tile(col: int) -> AieTile:
    """Build the shim tile in column ``col`` (row is left at its default)."""
    return AieTile(type=TileType.Shim, col=col)


def core_dma(col: int, row: int, dir: DmaDir, id: int) -> AieDma:
    """Build the DMA with direction ``dir`` and index ``id`` on core tile (``col``, ``row``)."""
    tile = AieTile(type=TileType.Core, col=col, row=row)
    channel = DmaChannel(dir=dir, id=id)
    return AieDma(tile=tile, channel=channel)


def memtile_dma(col: int, dir: DmaDir, id: int) -> AieDma:
    """Build the DMA with direction ``dir`` and index ``id`` on the memtile in column ``col``."""
    tile = AieTile(type=TileType.Memtile, col=col)
    channel = DmaChannel(dir=dir, id=id)
    return AieDma(tile=tile, channel=channel)


def shim_dma(col: int, dir: DmaDir, id: int) -> AieDma:
    """Build the DMA with direction ``dir`` and index ``id`` on the shim tile in column ``col``."""
    tile = AieTile(type=TileType.Shim, col=col)
    channel = DmaChannel(dir=dir, id=id)
    return AieDma(tile=tile, channel=channel)


################################################################################
#
# User API for High-Level Data Transfer IR
#
################################################################################


class BackEnd(Enum):
    """Back-end selector.

    NOTE(review): not referenced in the visible part of this module --
    presumably selects the code-generation target; confirm with callers.
    """
    Adf = 0
    TxnHostPatch = 1
    CertAsm = 2


class SyncStrategy(Enum):
    """Synchronisation strategy attached to a ``DataTransfer``.

    ``Default`` is the value used when a ``DataTransfer`` does not specify one.
    NOTE(review): the meaning of the individual strategies is defined by the
    consumers of ``DataTransfer.sync_strategy``, outside this section.
    """
    Default = 0
    Parallel_1_to_N = 1
    Parallel_N_to_1 = 2
    Serial_M_to_N = 3
    Async = 4
    Remote_Barrier = 5

def check_access_pattern_design_rules(
    dma: AieDma,
    length: int,
    offset: int,
    step: List[int],
    wrap: List[int],
    padding: List[Tuple[int, int]],
    iter_step: Optional[int],
    iter_wrap: Optional[int],
    shim_buffer_index: Optional[int]
):
    """Validate one DMA access pattern against the limits in ``config``.

    Only memtile and shim DMAs are accepted. The checks run in a fixed order
    and the first violated rule raises.

    Args:
        dma: DMA the pattern is programmed on.
        length: Transfer length; must be within the tile's buffer length limit.
        offset: Non-negative address offset.
        step: Per-dimension steps; must have as many entries as ``wrap`` or
            one more.
        wrap: Per-dimension wraps, each at least 1.
        padding: Per-dimension pairs of padding amounts; memtile MM2S only.
        iter_step: Optional iteration step.
        iter_wrap: Optional iteration wrap.
        shim_buffer_index: Optional buffer index; shim only.

    Raises:
        ValueError: describing the first rule that is violated.
    """
    tile_kind = dma.tile.type
    if tile_kind not in (TileType.Memtile, TileType.Shim):
        raise ValueError('Invalid tile location for transfer parameter!')
    if offset < 0:
        raise ValueError('Invalid address offset!')

    num_steps = len(step)
    if num_steps == 0:
        raise ValueError('Invalid number of step dimensions!')
    num_wraps = len(wrap)
    if num_steps != num_wraps and num_steps != num_wraps + 1:
        raise ValueError('Number of step and wrap dimensions are incompatible!')
    if config.DEV_GEN == DevGen.Aie4 and step[0] != 1:
        raise ValueError('Invalid dim zero step, must be 1!')

    if iter_wrap is not None and not (1 <= iter_wrap <= config.MAX_ITER_WRAP):
        raise ValueError('Invalid iteration wrap!')

    # Memtile and shim share the same rule shapes; only limits and the word
    # used in the error message differ.
    on_memtile = tile_kind == TileType.Memtile
    if on_memtile:
        label = 'memtile'
        max_length = config.MAX_MEMTILE_BUFFER_LENGTH
        max_dims = config.MAX_MEMTILE_DIMS
        max_step = config.MAX_MEMTILE_STEP
        max_wrap = config.MAX_MEMTILE_WRAP
    else:
        label = 'shim'
        max_length = config.MAX_SHIM_BUFFER_LENGTH
        max_dims = config.MAX_SHIM_DIMS
        max_step = config.MAX_SHIM_STEP
        max_wrap = config.MAX_SHIM_WRAP

    if not (0 <= length <= max_length):
        raise ValueError(f'Invalid {label} transfer length!')
    if num_steps > max_dims or num_wraps > max_dims - 1:
        raise ValueError(f'Invalid number of {label} transfer dimensions!')
    if any(not (config.MIN_STEP_SIZE <= s <= max_step) for s in step):
        raise ValueError(f'Invalid {label} transfer step!')
    if any(not (1 <= w <= max_wrap) for w in wrap):
        raise ValueError(f'Invalid {label} transfer wrap!')
    if iter_step is not None and not (config.MIN_STEP_SIZE <= iter_step <= max_step):
        raise ValueError(f'Invalid {label} transfer iteration step!')

    if not on_memtile:
        # Shim: no padding allowed, buffer index must be in range if given.
        if len(padding) > 0:
            raise ValueError('Invalid shim transfer padding!')
        if (shim_buffer_index is not None and
                not (0 <= shim_buffer_index <= config.MAX_SHIM_ADDR)):
            raise ValueError('Invalid shim buffer index!')
        return

    # Memtile: constant padding rules.
    num_pads = len(padding)
    if num_pads > config.MAX_MEMTILE_PAD_DIMS:
        raise ValueError('Invalid memtile transfer padding dimensions!')
    if num_pads > 0 and dma.channel.dir != DmaDir.MM2S:
        raise ValueError('Invalid DMA direction for constant padding!')
    # Dimensions 0..2 each have their own limit, applied to both pad amounts.
    pad_limits = (
        config.MAX_MEMTILE_D0_PAD,
        config.MAX_MEMTILE_D1_PAD,
        config.MAX_MEMTILE_D2_PAD,
    )
    bad_padding = any(
        not (0 <= pad[0] <= limit) or not (0 <= pad[1] <= limit)
        for pad, limit in zip(padding, pad_limits)
    )
    # Padding entries at index >= len(wrap) must have a zero second element.
    for _, pad_after in padding[num_wraps:]:
        if pad_after != 0:
            bad_padding = True
    if bad_padding:
        raise ValueError('Invalid memtile constant padding!')
    if shim_buffer_index is not None:
        raise ValueError('Invalid shim buffer index for memtile!')


class TransferParams:
    """Access pattern of one DMA, optionally reconfigured per phase.

    Each of ``length``, ``offset``, ``step``, ``wrap``, ``padding``,
    ``iter_step`` and ``iter_wrap`` is either a single value used for every
    reconfiguration or a list holding one value per reconfiguration. For
    ``step``, ``wrap`` and ``padding`` (which are lists themselves) the
    per-reconfiguration form is detected by the presence of a nested list.

    Every reconfiguration is validated with
    ``check_access_pattern_design_rules`` at construction time.

    Args:
        dma: DMA the pattern runs on.
        length: Transfer length(s).
        offset: Address offset(s); defaults to 0.
        step: Step list(s); defaults to ``[1]``.
        wrap: Wrap list(s); defaults to ``[]``.
        padding: Padding list(s); defaults to ``[]``.
        iter_step: Optional iteration step(s).
        iter_wrap: Optional iteration wrap(s).
        shim_buffer_index: Optional shim buffer index (not reconfigurable).
        name: Free-form name of the transfer.

    Raises:
        ValueError: if any reconfiguration violates a design rule.
    """
    __slots__ = ('dma', '_length', '_offset', '_step', '_wrap', '_padding',
                 '_iter_step', '_iter_wrap',
                 'shim_buffer_index', 'name')

    def __init__(
        self,
        dma: AieDma,
        length: Union[List[int], int],
        offset: Union[List[int], int] = 0,
        step: Optional[Union[List[List[int]], List[int]]] = None,
        wrap: Optional[Union[List[List[int]], List[int]]] = None,
        padding: Optional[Union[List[List[Tuple[int, int]]], List[Tuple[int, int]]]] = None,
        iter_step: Union[List[Optional[int]], Optional[int]] = None,
        iter_wrap: Union[List[Optional[int]], Optional[int]] = None,
        shim_buffer_index: Optional[int] = None,
        name: Optional[str] = "No BD Name",
    ):
        self.dma = dma
        self._length = length
        self._offset = offset
        # The list defaults are created per instance; a list literal in the
        # signature would be one object shared by every instance using it.
        self._step = [1] if step is None else step
        self._wrap = [] if wrap is None else wrap
        self._padding = [] if padding is None else padding
        self._iter_step = iter_step
        self._iter_wrap = iter_wrap
        self.shim_buffer_index = shim_buffer_index
        self.name = name

        config.check_init()
        for i in range(self._num_reconfig()):
            check_access_pattern_design_rules(
                self.dma,
                self.length_i(i),
                self.offset_i(i),
                self.step_i(i),
                self.wrap_i(i),
                self.padding_i(i),
                self.iter_step_i(i),
                self.iter_wrap_i(i),
                self.shim_buffer_index,
            )

    def _num_reconfig(self) -> int:
        """Return the number of reconfigurations (1 if nothing is per-phase).

        This is the longest per-reconfiguration list among all fields.
        """
        return max(
            len(self._length) if isinstance(self._length, list) else 1,
            len(self._offset) if isinstance(self._offset, list) else 1,
            len(self._step) if any(isinstance(elem, list) for elem in self._step) else 1,
            len(self._wrap) if any(isinstance(elem, list) for elem in self._wrap) else 1,
            len(self._padding) if any(isinstance(elem, list) for elem in self._padding) else 1,
            len(self._iter_step) if isinstance(self._iter_step, list) else 1,
            len(self._iter_wrap) if isinstance(self._iter_wrap, list) else 1,
        )

    def length_i(self, iter: int) -> int:
        """Return the transfer length for reconfiguration ``iter``."""
        if isinstance(self._length, list):
            return self._length[iter]
        return self._length

    def offset_i(self, iter: int) -> int:
        """Return the address offset for reconfiguration ``iter``."""
        if isinstance(self._offset, list):
            return self._offset[iter]
        return self._offset

    def step_i(self, iter: int) -> List[int]:
        """Return the step list for reconfiguration ``iter``."""
        if any(isinstance(elem, list) for elem in self._step):
            return self._step[iter]
        return self._step

    def wrap_i(self, iter: int) -> List[int]:
        """Return the wrap list for reconfiguration ``iter``."""
        if any(isinstance(elem, list) for elem in self._wrap):
            return self._wrap[iter]
        return self._wrap

    def padding_i(self, iter: int) -> List[Tuple[int, int]]:
        """Return the padding list for reconfiguration ``iter``."""
        if any(isinstance(elem, list) for elem in self._padding):
            return self._padding[iter]
        return self._padding

    def iter_step_i(self, iter: int) -> Optional[int]:
        """Return the iteration step for reconfiguration ``iter``."""
        if isinstance(self._iter_step, list):
            return self._iter_step[iter]
        return self._iter_step

    def iter_wrap_i(self, iter: int) -> Optional[int]:
        """Return the iteration wrap for reconfiguration ``iter``."""
        if isinstance(self._iter_wrap, list):
            return self._iter_wrap[iter]
        return self._iter_wrap


class DataTransfer:
    """High-level description of a buffer and the DMA transfers that use it.

    A transfer lives on a memtile or shim tile, owns one or more buffer
    addresses of ``buffer_size`` each, and lists the write-side (S2MM) and
    read-side (MM2S) access patterns. All fields are validated on
    construction; the first violated rule raises ``ValueError``.

    Attributes:
        repeat_counts: Non-negative repeat count per phase.
        tile: Memtile or shim tile owning the buffer.
        buffer_addrs: Buffer base addresses.
        buffer_size: Size of each buffer.
        write_params: Access patterns of the writing (S2MM) DMAs.
        read_params: Access patterns of the reading (MM2S) DMAs.
        sync_strategy: Synchronisation strategy.
        reuse_ratio: Value in ``[1, config.MAX_LOCK_VALUE]``; values above 1
            are only accepted with a single buffer address.
        buffer_split: Value of at least 1.
    """
    __slots__ = ('repeat_counts', 'tile', 'buffer_addrs', 'buffer_size',
                 'write_params', 'read_params',
                 'sync_strategy', 'reuse_ratio', 'buffer_split')

    def __init__(
        self,
        repeat_counts: List[int],
        tile: AieTile,
        buffer_addrs: List[int],
        buffer_size: int,
        write_params: List[TransferParams],
        read_params: List[TransferParams],
        sync_strategy: SyncStrategy = SyncStrategy.Default,
        reuse_ratio: int = 1,
        buffer_split: int = 1
    ):
        self.repeat_counts = repeat_counts
        self.tile = tile
        self.buffer_addrs = buffer_addrs
        self.buffer_size = buffer_size
        self.write_params = write_params
        self.read_params = read_params
        self.sync_strategy = sync_strategy
        self.reuse_ratio = reuse_ratio
        self.buffer_split = buffer_split

        config.check_init()

        # --- repeat counts and tile kind --------------------------------
        if any(count < 0 for count in self.repeat_counts):
            raise ValueError('Invalid core repeat count!')
        on_memtile = self.tile.type == TileType.Memtile
        if not on_memtile and self.tile.type != TileType.Shim:
            raise ValueError('Invalid tile location for data transfer!')

        all_params = self.write_params + self.read_params

        # --- which DMAs may touch this buffer ---------------------------
        if on_memtile:
            # A memtile buffer may be accessed from DMAs in nearby columns,
            # but only through the lower-numbered "neighbor" channels.
            for param in all_params:
                if param.dma.channel.dir == DmaDir.S2MM:
                    neighbor_limit = config.MAX_MEMTILE_S2MM_NEIGHBOR_CHANNEL
                else:
                    neighbor_limit = config.MAX_MEMTILE_MM2S_NEIGHBOR_CHANNEL
                col_distance = abs(tile.col - param.dma.tile.col)
                too_far = col_distance > config.MAX_NEIGHBOR_ACCESS
                wrong_channel = (col_distance > 0) and (param.dma.channel.id > neighbor_limit)
                if too_far or wrong_channel:
                    raise ValueError('Invalid neighbor access!')
        else:
            # A shim buffer may only be accessed by DMAs on the same tile.
            for param in all_params:
                if self.tile != param.dma.tile:
                    raise ValueError('Invalid tile type for neighbor access!')

        # --- buffer addresses -------------------------------------------
        if on_memtile:
            # The address space extends over this memtile plus the reachable
            # columns to its right, one memtile-sized window per column.
            cols_to_the_right = (config.NUM_AIE_COLS - 1) - self.tile.col
            reachable_cols = min(config.MAX_NEIGHBOR_ACCESS, cols_to_the_right)
            last_valid_addr = ((reachable_cols + 1) * (config.MAX_MEMTILE_ADDR + 1)) - 1
            for addr in self.buffer_addrs:
                if addr % config.MEMTILE_ADDR_GRAN != 0:
                    raise ValueError(f'Buffer address should be {config.MEMTILE_ADDR_GRAN} bits aligned!')
                if (addr < 0) or addr + self.buffer_size - 1 > last_valid_addr:
                    raise ValueError('Invalid memtile address!')
        else:
            for addr in self.buffer_addrs:
                if not (0 <= addr <= config.MAX_SHIM_ADDR):
                    raise ValueError('Invalid shim address!')

        # --- DMA directions ---------------------------------------------
        if any(param.dma.channel.dir != DmaDir.S2MM for param in self.write_params):
            raise ValueError('Invalid DMA direction for write transfer parameter!')
        if any(param.dma.channel.dir != DmaDir.MM2S for param in self.read_params):
            raise ValueError('Invalid DMA direction for read transfer parameter!')

        # --- reuse ratio and buffer split -------------------------------
        if not (1 <= self.reuse_ratio <= config.MAX_LOCK_VALUE):
            raise ValueError('Invalid reuse ratio!')
        if self.reuse_ratio > 1 and len(self.buffer_addrs) > 1:
            raise ValueError('Invalid reuse ratio for multiple buffering scheme!')
        if self.buffer_split < 1:
            raise ValueError('Invalid buffer split!')

        # --- reconfiguration counts -------------------------------------
        num_phases = len(self.repeat_counts)
        for param in all_params:
            if param._num_reconfig() not in (1, num_phases):
                raise ValueError('Invalid number of transfer param reconfigs!')

        # --- iteration step without an iteration wrap -------------------
        def steps_without_wrap(param: TransferParams, index: int) -> bool:
            # True when an iteration step is set but no iteration wrap is.
            return (param.iter_step_i(index) is not None and
                    param.iter_wrap_i(index) is None)

        for param in all_params:
            num_reconfig = param._num_reconfig()
            if num_reconfig == 1:
                # A single configuration runs for all phases back to back.
                exceeds_limit = (
                    sum(self.repeat_counts) > config.MAX_ITER_WRAP and
                    steps_without_wrap(param, 0)
                )
            else:
                exceeds_limit = any(
                    self.repeat_counts[i] > config.MAX_ITER_WRAP and
                    steps_without_wrap(param, i)
                    for i in range(num_reconfig)
                )
            if exceeds_limit:
                raise ValueError('Invalid repeat count used with iteration step!')


################################################################################
#
# User API for Core Instructions
#
################################################################################


def uint16_to_bytes(x: int) -> bytes:
    """Encode ``x`` as two bytes, unsigned, little-endian.

    ``x`` must lie in ``[0, 65535]``; this is checked with an assertion.
    """
    assert 0 <= x <= 0xFFFF
    return x.to_bytes(2, 'little', signed=False)


class CoreInstr(ABC):
    """Abstract base of all core instructions.

    Holds the opcode values written as the first 16-bit field of each
    encoded instruction and declares the serialisation interface.
    """
    __slots__ = ()

    # Opcode values
    LOOP_OP = 0
    BUFFER_ACQ_OP = 1
    BUFFER_REL_OP = 2
    BD_CONFIG_OP = 3
    KERNEL_CALL_OP = 4
    KERNEL_CALL_IN0_IN1_OP = 5

    @abstractmethod
    def to_bytes(
        self,
        kernel_ids: Dict[str, int],
        buffer_ids: Dict[DmaChannel, int]
    ) -> bytes:
        """Return the binary encoding of this instruction.

        Args:
            kernel_ids: Maps kernel names to their numeric ids.
            buffer_ids: Maps DMA channels to their numeric buffer ids.
        """


class Loop(CoreInstr):
    """Core instruction that repeats a sequence of instructions.

    Attributes:
        num_iters: Non-negative number of iterations.
        loop_body: Instruction instances executed on each iteration.

    Raises:
        ValueError: if ``num_iters`` is negative.
    """
    __slots__ = ('num_iters', 'loop_body')

    def __init__(
        self,
        num_iters: int,
        loop_body: List[CoreInstr]
    ):
        self.num_iters = num_iters
        self.loop_body = loop_body

        config.check_init()
        if self.num_iters < 0:
            raise ValueError('Invalid number of loop iterations!')

    def to_bytes(
        self,
        kernel_ids: Dict[str, int],
        buffer_ids: Dict[DmaChannel, int]
    ) -> bytes:
        """Encode the loop header followed by the encoded body.

        The header is four 16-bit fields: opcode, iteration count, body
        length in bytes, and a zero field.
        """
        # Join once instead of growing a bytes object instruction by instruction.
        body = b''.join(
            instr.to_bytes(kernel_ids, buffer_ids) for instr in self.loop_body
        )
        return (uint16_to_bytes(CoreInstr.LOOP_OP) +
                uint16_to_bytes(self.num_iters) +
                uint16_to_bytes(len(body)) +
                uint16_to_bytes(0) +
                body)


class AcqBuffer(CoreInstr):
    """Core instruction that acquires the buffer bound to a DMA channel.

    Attributes:
        dma_channel: Channel whose buffer is acquired.
        disable: When true, the instruction encodes to no bytes at all.
    """
    __slots__ = ('dma_channel', 'disable')

    def __init__(
        self,
        dma_channel: DmaChannel,
        disable: bool = False,
    ):
        self.dma_channel = dma_channel
        self.disable = disable

        config.check_init()

    def to_bytes(
        self,
        kernel_ids: Dict[str, int],
        buffer_ids: Dict[DmaChannel, int]
    ) -> bytes:
        """Encode as four 16-bit fields: opcode, buffer id, and two zeros.

        Returns an empty byte string when the instruction is disabled.
        """
        # Return before the lookup: a disabled instruction emits nothing, so
        # it must not fail when its channel has no entry in `buffer_ids`.
        if self.disable:
            return b''
        return (uint16_to_bytes(CoreInstr.BUFFER_ACQ_OP) +
                uint16_to_bytes(buffer_ids[self.dma_channel]) +
                uint16_to_bytes(0) +
                uint16_to_bytes(0))


class RelBuffer(CoreInstr):
    """Core instruction that releases the buffer bound to a DMA channel.

    Attributes:
        dma_channel: Channel whose buffer is released.
        disable: When true, the instruction encodes to no bytes at all.
    """
    __slots__ = ('dma_channel', 'disable')

    def __init__(
        self,
        dma_channel: DmaChannel,
        disable: bool = False,
    ):
        self.dma_channel = dma_channel
        self.disable = disable

        config.check_init()

    def to_bytes(
        self,
        kernel_ids: Dict[str, int],
        buffer_ids: Dict[DmaChannel, int]
    ) -> bytes:
        """Encode as four 16-bit fields: opcode, buffer id, and two zeros.

        Returns an empty byte string when the instruction is disabled.
        """
        # Return before the lookup: a disabled instruction emits nothing, so
        # it must not fail when its channel has no entry in `buffer_ids`.
        if self.disable:
            return b''
        return (uint16_to_bytes(CoreInstr.BUFFER_REL_OP) +
                uint16_to_bytes(buffer_ids[self.dma_channel]) +
                uint16_to_bytes(0) +
                uint16_to_bytes(0))


class ConfigBuffer(CoreInstr):
    """Core instruction that configures the buffer of a core DMA channel.

    Attributes:
        dma_channel: Channel being configured.
        ping_addr: Address of the ping buffer.
        pong_addr: Address of the pong buffer, or ``None`` for no pong buffer.
        buffer_size: Size of each buffer; encoded divided by the word size 4.
        offset: Offset field of the access pattern.
        step: Per-dimension steps; defaults to ``[1]``.
        wrap: Per-dimension wraps; defaults to ``[]``.
        repeat_count: Starts as ``None`` and must be assigned before
            ``to_bytes`` is called.

    Raises:
        ValueError: if an address, alignment, length, dimension count, step
            or wrap violates the core limits in ``config``.
    """
    __slots__ = ('dma_channel', 'ping_addr', 'pong_addr', 'buffer_size', 'offset', 'step', 'wrap', 'repeat_count')

    # Encoded in the pong-address field when there is no pong buffer.
    DISABLE_PONG_ADDR = 0xFFFF

    def __init__(
        self,
        dma_channel: DmaChannel,
        ping_addr: int,
        pong_addr: Optional[int],
        buffer_size: int,
        offset: int = 0,
        step: Optional[List[int]] = None,
        wrap: Optional[List[int]] = None
    ):
        self.dma_channel = dma_channel
        self.ping_addr = ping_addr
        self.pong_addr = pong_addr
        self.buffer_size = buffer_size
        self.offset = offset
        # The list defaults are created per instance; a list literal in the
        # signature would be one object shared by every instance using it.
        self.step = [1] if step is None else step
        self.wrap = [] if wrap is None else wrap
        self.repeat_count: Optional[int] = None

        config.check_init()
        if not (0 <= self.ping_addr and self.ping_addr + self.buffer_size <= config.MAX_CORE_ADDR):
            raise ValueError('Invalid ping address for core buffer config!')
        if self.ping_addr % config.MIN_CORE_BUFFER_ALIGNMENT != 0:
            raise ValueError('Invalid core ping address alignment!')

        if not ((self.pong_addr is None) or
                (0 <= self.pong_addr and self.pong_addr + self.buffer_size <= config.MAX_CORE_ADDR)):
            raise ValueError('Invalid pong address for core buffer config!')
        if self.pong_addr is not None and self.pong_addr % config.MIN_CORE_BUFFER_ALIGNMENT != 0:
            raise ValueError('Invalid core pong address alignment!')

        if not (0 <= self.buffer_size // 4 <= config.MAX_CORE_BUFFER_LENGTH):
            raise ValueError('Invalid core transfer length!')

        if len(self.step) > config.MAX_CORE_DIMS or len(self.wrap) > config.MAX_CORE_DIMS - 1:
            raise ValueError('Invalid number of core transfer dimensions!')
        if len(self.step) not in (len(self.wrap), len(self.wrap) + 1):
            raise ValueError('Number of step and wrap dimensions are incompatible!')
        for s in self.step:
            if not (config.MIN_STEP_SIZE <= s <= config.MAX_CORE_STEP):
                raise ValueError('Invalid core transfer step!')
        for w in self.wrap:
            if not (1 <= w <= config.MAX_CORE_WRAP):
                raise ValueError('Invalid core transfer wrap!')

    def to_bytes(
        self,
        kernel_ids: Dict[str, int],
        buffer_ids: Dict[DmaChannel, int]
    ) -> bytes:
        """Encode the configuration as twelve 16-bit fields.

        Field order: opcode, repeat count, buffer id, ping address, pong
        address, length, offset, three steps, two wraps. Addresses and length
        are expressed in 4-byte words.
        """
        assert self.repeat_count is not None
        word_size = 4
        ping_encoding = self.ping_addr // word_size
        if self.pong_addr is None:
            pong_encoding = ConfigBuffer.DISABLE_PONG_ADDR
        else:
            pong_encoding = self.pong_addr // word_size
        length_encoding = self.buffer_size // word_size
        # Unused dimensions are filled with step 1 and wrap 0.
        step = self.step + [1] * (config.MAX_CORE_DIMS - len(self.step))
        wrap = self.wrap + [0] * (config.MAX_CORE_DIMS - len(self.wrap))
        asm = (uint16_to_bytes(CoreInstr.BD_CONFIG_OP) +
               uint16_to_bytes(self.repeat_count) +
               uint16_to_bytes(buffer_ids[self.dma_channel]) +
               uint16_to_bytes(ping_encoding) +
               uint16_to_bytes(pong_encoding) +
               uint16_to_bytes(length_encoding) +
               uint16_to_bytes(self.offset) +
               uint16_to_bytes(step[0]) +
               uint16_to_bytes(step[1]) +
               uint16_to_bytes(step[2]) +
               uint16_to_bytes(wrap[0]) +
               uint16_to_bytes(wrap[1]))
        return asm


class CallKernel(CoreInstr):
    """Core instruction that invokes a named kernel with a parameter blob.

    The parameter blob is zero-padded at construction so that its length is
    a multiple of 8 bytes.

    Attributes:
        opcode: Opcode written into the encoded instruction.
        kernel_name: Name looked up in ``kernel_ids`` when encoding.
        kernel_params: Zero-padded parameter bytes.
    """
    __slots__ = ('opcode', 'kernel_name', 'kernel_params')

    def __init__(
        self,
        kernel_name: str,
        kernel_params: bytes = b''
    ):
        param_align = 8
        self.opcode = CoreInstr.KERNEL_CALL_OP
        self.kernel_name = kernel_name
        # Bytes needed to reach the next multiple of `param_align`
        # (0 when the length is already aligned).
        pad_len = -len(kernel_params) % param_align
        self.kernel_params = kernel_params + (b'\x00' * pad_len)

        config.check_init()

    def to_bytes(
        self,
        kernel_ids: Dict[str, int],
        buffer_ids: Dict[DmaChannel, int]
    ) -> bytes:
        """Encode the call header followed by the padded parameters.

        The header is four 16-bit fields: opcode, kernel id, parameter
        length in bytes, and a zero field.
        """
        asm = (uint16_to_bytes(self.opcode) +
               uint16_to_bytes(kernel_ids[self.kernel_name]) +
               uint16_to_bytes(len(self.kernel_params)) +
               uint16_to_bytes(0) +
               self.kernel_params)
        return asm


################################################################################
#
# Internal Low-Level Data Transfer IR
#
################################################################################


class Lock:
    """A hardware lock on a tile, with its id and initial value.

    The initial value must lie in ``[0, config.MAX_LOCK_VALUE]`` and the id
    in ``[0, max]`` where ``max`` is the lock-id limit of the tile kind.
    Both are checked with assertions.
    """
    __slots__ = ('aie_tile', 'id', 'init_value')

    def __init__(
        self,
        aie_tile: AieTile,
        id: int,
        init_value: int
    ):
        self.aie_tile, self.id, self.init_value = aie_tile, id, init_value

        # Check for invalid lock allocations

        assert 0 <= self.init_value <= config.MAX_LOCK_VALUE
        tile_kind = self.aie_tile.type
        # Any tile kind other than core or memtile uses the shim limit.
        assert 0 <= self.id <= (
            config.MAX_CORE_LOCK_ID if tile_kind == TileType.Core else
            config.MAX_MEMTILE_LOCK_ID if tile_kind == TileType.Memtile else
            config.MAX_SHIM_LOCK_ID
        )

    # __str__ to be used for name in generated code

    def __str__(self) -> str:
        """Return ``'lock_<tile>_id<id>'``."""
        return 'lock_{}_id{}'.format(self.aie_tile, self.id)


class BufferDescriptor:
    """Low-level buffer descriptor (BD) of one DMA.

    ``offset``, ``length``, ``step``, ``wrap``, ``padding``, ``iter_step``
    and ``iter_wrap`` are either a single value used for every
    reconfiguration or a list holding one value per reconfiguration. For
    ``step``, ``wrap`` and ``padding`` (which are lists themselves) the
    per-reconfiguration form is detected by the presence of a nested list.

    The lock acquire value must lie in
    ``[-config.MAX_LOCK_VALUE, config.MAX_LOCK_VALUE]`` and the release value
    in ``[0, config.MAX_LOCK_VALUE]``; both are checked with assertions.

    Args:
        aie_dma: DMA this descriptor belongs to.
        id: Descriptor id.
        buffer_addr: Base address of the buffer.
        offset: Address offset(s).
        length: Transfer length(s).
        step: Step list(s); defaults to ``[1]``.
        wrap: Wrap list(s); defaults to ``[]``.
        padding: Padding list(s); defaults to ``[]``.
        iter_step: Optional iteration step(s).
        iter_wrap: Optional iteration wrap(s).
        lock_enable: Whether locking is enabled.
        lock_acq: Lock to acquire, if any.
        lock_acq_value: Value used when acquiring.
        lock_rel: Lock to release, if any.
        lock_rel_value: Value used when releasing.
        use_next_bd: Whether this descriptor chains to ``next_bd``.
        next_bd: Descriptor chained after this one, if any.
        packet_enable: Whether packet mode is enabled.
        packet_id: Optional packet id.
        is_lock_bd: Flag stored as given.
        fold: Value stored as given.
        name: Free-form name of the descriptor.
        barrier_id: Optional barrier id.
    """
    __slots__ = ('aie_dma', 'id', 'buffer_addr',
                 '_offset', '_length', '_step', '_wrap', '_padding', '_iter_step', '_iter_wrap',
                 'lock_enable', 'lock_acq', 'lock_acq_value', 'lock_rel', 'lock_rel_value',
                 'use_next_bd', 'next_bd',
                 'packet_enable', 'packet_id', 'is_lock_bd', 'fold', 'name', 'barrier_id')

    def __init__(
        self,
        aie_dma: AieDma,
        id: int,
        buffer_addr: int = 0,
        offset: Union[List[int], int] = 0,
        length: Union[List[int], int] = 0,
        step: Optional[Union[List[List[int]], List[int]]] = None,
        wrap: Optional[Union[List[List[int]], List[int]]] = None,
        padding: Optional[Union[List[List[Tuple[int, int]]], List[Tuple[int, int]]]] = None,
        iter_step: Union[List[Optional[int]], Optional[int]] = None,
        iter_wrap: Union[List[Optional[int]], Optional[int]] = None,
        lock_enable: bool = False,
        lock_acq: Optional[Lock] = None,
        lock_acq_value: int = 0,
        lock_rel: Optional[Lock] = None,
        lock_rel_value: int = 0,
        use_next_bd: bool = False,
        next_bd: Optional['BufferDescriptor'] = None,
        packet_enable: bool = False,
        packet_id: Optional[int] = None,
        is_lock_bd: Optional[bool] = False,
        fold: Optional[int] = 0,
        name: Optional[str] = "No BD Name",
        barrier_id: Optional[int] = None,
    ):
        self.aie_dma = aie_dma
        self.id = id
        self.buffer_addr = buffer_addr
        self._offset = offset
        self._length = length
        # The list defaults are created per instance; a list literal in the
        # signature would be one object shared by every instance using it.
        self._step = [1] if step is None else step
        self._wrap = [] if wrap is None else wrap
        self._padding = [] if padding is None else padding
        self._iter_step = iter_step
        self._iter_wrap = iter_wrap
        self.lock_enable = lock_enable
        self.lock_acq = lock_acq
        self.lock_acq_value = lock_acq_value
        self.lock_rel = lock_rel
        self.lock_rel_value = lock_rel_value
        self.use_next_bd = use_next_bd
        self.next_bd = next_bd
        self.packet_enable = packet_enable
        self.packet_id = packet_id
        self.is_lock_bd = is_lock_bd
        self.fold = fold
        self.name = name
        self.barrier_id = barrier_id

        # Check for invalid BD allocations
        # TODO: Give a better error message for invalid lock values
        assert -config.MAX_LOCK_VALUE <= self.lock_acq_value <= config.MAX_LOCK_VALUE
        assert 0 <= self.lock_rel_value <= config.MAX_LOCK_VALUE

    def length_i(self, iter: int) -> int:
        """Return the transfer length for reconfiguration ``iter``."""
        if isinstance(self._length, list):
            return self._length[iter]
        return self._length

    def offset_i(self, iter: int) -> int:
        """Return the address offset for reconfiguration ``iter``."""
        if isinstance(self._offset, list):
            return self._offset[iter]
        return self._offset

    def step_i(self, iter: int) -> List[int]:
        """Return the step list for reconfiguration ``iter``."""
        if any(isinstance(elem, list) for elem in self._step):
            return self._step[iter]
        return self._step

    def wrap_i(self, iter: int) -> List[int]:
        """Return the wrap list for reconfiguration ``iter``."""
        if any(isinstance(elem, list) for elem in self._wrap):
            return self._wrap[iter]
        return self._wrap

    def padding_i(self, iter: int) -> List[Tuple[int, int]]:
        """Return the padding list for reconfiguration ``iter``."""
        if any(isinstance(elem, list) for elem in self._padding):
            return self._padding[iter]
        return self._padding

    def iter_step_i(self, iter: int) -> Optional[int]:
        """Return the iteration step for reconfiguration ``iter``."""
        if isinstance(self._iter_step, list):
            return self._iter_step[iter]
        return self._iter_step

    def iter_wrap_i(self, iter: int) -> Optional[int]:
        """Return the iteration wrap for reconfiguration ``iter``."""
        if isinstance(self._iter_wrap, list):
            return self._iter_wrap[iter]
        return self._iter_wrap


class BufferTask:
    """
    A buffer descriptor together with a repeat count.

    Attributes:
        buffer_descriptor: The BufferDescriptor this task refers to.
        repeat_count: Number of repeats; asserted to lie in
            ``[1, config.MAX_REPEAT_COUNT]`` at construction time.
    """
    __slots__ = ('buffer_descriptor', 'repeat_count')

    def __init__(self, buffer_descriptor: BufferDescriptor, repeat_count: int):
        self.buffer_descriptor = buffer_descriptor
        self.repeat_count = repeat_count

        # Reject repeat counts outside the range allowed by the configuration.
        assert 1 <= repeat_count and repeat_count <= config.MAX_REPEAT_COUNT


class DataBuffer:
    """
    Plain container grouping buffer descriptors, locks and buffer tasks.

    Attributes:
        buffer_descriptors: List of BufferDescriptor objects.
        locks: List of Lock objects.
        buffer_tasks: Nested list of BufferTask objects. The meaning of the
            outer and inner levels is not defined here.

    The constructor stores the given lists as-is (no copies, no validation).
    """
    __slots__ = ('buffer_descriptors', 'locks', 'buffer_tasks')

    def __init__(self, buffer_descriptors: List[BufferDescriptor], locks: List[Lock], buffer_tasks: List[List[BufferTask]]):
        self.buffer_descriptors = buffer_descriptors
        self.locks = locks
        self.buffer_tasks = buffer_tasks


################################################################################
#
# Internal Low-Level Control Code Sequence
#
################################################################################


class BdConfig:
    """
    A buffer descriptor paired with an iteration index.

    The ``*_i()`` accessors forward to the matching ``bd.*_i(self.iter)``
    method, so a BdConfig exposes the access-pattern values that the
    descriptor holds for that one iteration.

    Attributes:
        bd: The underlying BufferDescriptor.
        iter: Iteration index passed to the descriptor's ``*_i`` accessors.
    """
    __slots__ = ('bd', 'iter')

    def __init__(self, bd: 'BufferDescriptor', iter: int):
        self.bd = bd
        self.iter = iter

    def needs_reconfig(self, other) -> bool:
        """
        Return True when ``other`` has a different access pattern than ``self``.

        Both configs must describe the same BD (same DMA, id, buffer address,
        lock settings and chaining); this is checked with assertions. Only
        length, offset, step, wrap, padding, iter_step and iter_wrap are
        compared for the result.

        Raises:
            AssertionError: If the two configs differ in anything other than
                the access pattern, including the case where only one of them
                has a ``next_bd``.
        """
        # NOTE: For now, we assume reconfigurations won't
        # change the chaining and locking, only the access pattern.
        mine, theirs = self.bd, other.bd
        assert mine.aie_dma == theirs.aie_dma
        assert mine.id == theirs.id
        assert mine.buffer_addr == theirs.buffer_addr
        assert mine.lock_enable == theirs.lock_enable
        assert mine.lock_acq == theirs.lock_acq
        assert mine.lock_acq_value == theirs.lock_acq_value
        assert mine.lock_rel == theirs.lock_rel
        assert mine.lock_rel_value == theirs.lock_rel_value
        assert mine.use_next_bd == theirs.use_next_bd
        # Compare successors without dereferencing None: when exactly one side
        # has a next BD this must fail the assertion rather than raise
        # AttributeError on ``None.id``.
        if mine.next_bd is None or theirs.next_bd is None:
            assert mine.next_bd is None and theirs.next_bd is None
        else:
            assert mine.next_bd.id == theirs.next_bd.id
        return (
            (self.length_i() != other.length_i()) or
            (self.offset_i() != other.offset_i()) or
            (self.step_i() != other.step_i()) or
            (self.wrap_i() != other.wrap_i()) or
            (self.padding_i() != other.padding_i()) or
            (self.iter_step_i() != other.iter_step_i()) or
            (self.iter_wrap_i() != other.iter_wrap_i())
        )

    # __str__ to be used for BD name in generated code

    def __str__(self) -> str:
        """
        Build the BD name: ``bd_<tile>_id<id>`` plus optional suffixes.

        ``_config<iter>`` is appended when ``iter > 0`` and ``_fold<fold>``
        when the descriptor's ``fold > 0``.
        """
        config_str = f'_config{self.iter}' if self.iter > 0 else ''
        fold_str = f'_fold{self.bd.fold}' if self.bd.fold > 0 else ''
        return 'bd_' + str(self.bd.aie_dma.tile) + '_id' + str(self.bd.id) + config_str + fold_str

    def length_i(self) -> int:
        """Return ``bd.length_i`` for this config's iteration."""
        return self.bd.length_i(self.iter)

    def offset_i(self) -> int:
        """Return ``bd.offset_i`` for this config's iteration."""
        return self.bd.offset_i(self.iter)

    def step_i(self) -> List[int]:
        """Return ``bd.step_i`` for this config's iteration."""
        return self.bd.step_i(self.iter)

    def wrap_i(self) -> List[int]:
        """Return ``bd.wrap_i`` for this config's iteration."""
        return self.bd.wrap_i(self.iter)

    def padding_i(self) -> List[Tuple[int, int]]:
        """Return ``bd.padding_i`` for this config's iteration."""
        return self.bd.padding_i(self.iter)

    def iter_step_i(self) -> Optional[int]:
        """Return ``bd.iter_step_i`` for this config's iteration."""
        return self.bd.iter_step_i(self.iter)

    def iter_wrap_i(self) -> Optional[int]:
        """Return ``bd.iter_wrap_i`` for this config's iteration."""
        return self.bd.iter_wrap_i(self.iter)


class ControlOpVisitor:
    """
    Visitor interface with one ``visit_*`` hook per control op class.

    Each ControlOp subclass in this module calls exactly one of these hooks
    from its ``apply`` method, passing itself as ``op``.

    NOTE: the hooks are decorated with ``@abstractmethod`` but this class does
    not derive from ``ABC``, so the decorator is not enforced: the class can
    be instantiated and a hook that is not overridden simply returns ``None``.
    """
    __slots__ = ()

    @abstractmethod
    def visit_config_buffer_descriptor(self, op):
        """Hook called by ``ConfigBufferDescriptor.apply`` with the op."""

    @abstractmethod
    def visit_patch_ddr_addr(self, op):
        """Hook called by ``PatchDdrAddr.apply`` with the op."""

    @abstractmethod
    def visit_set_lock_value(self, op):
        """Hook called by ``SetLockValue.apply`` with the op."""

    @abstractmethod
    def visit_enqueue_task(self, op):
        """Hook called by ``EnqueueTask.apply`` with the op."""

    @abstractmethod
    def visit_wait_dma_done(self, op):
        """Hook called by ``WaitDmaDone.apply`` with the op."""

    @abstractmethod
    def visit_aqcuire_lock(self, op):
        """
        Hook called by ``AcquireLock.apply`` with the op.

        NOTE: the method name is misspelled ("aqcuire"), but
        ``AcquireLock.apply`` dispatches to exactly this name, so it must be
        kept as is.
        """

    @abstractmethod
    def visit_remote_barrier(self, op):
        """Hook called by ``RemoteBarrier.apply`` with the op."""


class ControlOp(ABC):
    """
    Abstract base class for control ops.

    Subclasses implement ``apply`` by calling the one ``visit_*`` method of
    the visitor that corresponds to their own class.
    """
    __slots__ = ()

    @abstractmethod
    def apply(self, visitor: ControlOpVisitor) -> Any:
        """
        Dispatch this op to ``visitor``.

        Args:
            visitor: A ControlOpVisitor *instance*. The hooks are instance
                methods called as ``visitor.visit_*(op)``, so the visitor
                class itself would not work here.

        Returns:
            Whatever the visitor hook returns.
        """
        pass


class ConfigBufferDescriptor(ControlOp):
    """
    Control op carrying one BdConfig.

    Dispatches to ``visit_config_buffer_descriptor`` on the visitor.
    """
    # The trailing comma matters: ('cfg') is just the string 'cfg', so code
    # iterating __slots__ would see single characters instead of slot names.
    __slots__ = ('cfg',)

    def __init__(self, cfg: BdConfig):
        self.cfg = cfg

    def apply(self, visitor: ControlOpVisitor) -> Any:
        """Pass this op to the visitor instance and return the hook's result."""
        return visitor.visit_config_buffer_descriptor(self)


class PatchDdrAddr(ControlOp):
    """
    Control op carrying one BdConfig.

    Dispatches to ``visit_patch_ddr_addr`` on the visitor.
    """
    # The trailing comma matters: ('cfg') is just the string 'cfg', so code
    # iterating __slots__ would see single characters instead of slot names.
    __slots__ = ('cfg',)

    def __init__(self, cfg: BdConfig):
        self.cfg = cfg

    def apply(self, visitor: ControlOpVisitor) -> Any:
        """Pass this op to the visitor instance and return the hook's result."""
        return visitor.visit_patch_ddr_addr(self)


class SetLockValue(ControlOp):
    """
    Control op carrying one Lock.

    Dispatches to ``visit_set_lock_value`` on the visitor.
    """
    # The trailing comma matters: ('lock') is just the string 'lock', so code
    # iterating __slots__ would see single characters instead of slot names.
    __slots__ = ('lock',)

    def __init__(self, lock: Lock):
        self.lock = lock

    def apply(self, visitor: ControlOpVisitor) -> Any:
        """Pass this op to the visitor instance and return the hook's result."""
        return visitor.visit_set_lock_value(self)


class EnqueueTask(ControlOp):
    """
    Control op carrying a list of BdConfig objects and a repeat count.

    Dispatches to ``visit_enqueue_task`` on the visitor.

    Attributes:
        cfg_chain: List of BdConfig objects, stored as given.
        repeat_count: Repeat count, stored as given (not validated here).
    """
    __slots__ = ('cfg_chain', 'repeat_count')

    def __init__(self, cfg_chain: List[BdConfig], repeat_count: int):
        self.cfg_chain = cfg_chain
        self.repeat_count = repeat_count

    def apply(self, visitor: ControlOpVisitor) -> Any:
        """Pass this op to the visitor instance and return the hook's result."""
        return visitor.visit_enqueue_task(self)


class WaitDmaDone(ControlOp):
    """
    Control op carrying one AieDma.

    Dispatches to ``visit_wait_dma_done`` on the visitor.
    """
    # The trailing comma matters: ('dma') is just the string 'dma', so code
    # iterating __slots__ would see single characters instead of slot names.
    __slots__ = ('dma',)

    def __init__(self, dma: AieDma):
        self.dma = dma

    def apply(self, visitor: ControlOpVisitor) -> Any:
        """Pass this op to the visitor instance and return the hook's result."""
        return visitor.visit_wait_dma_done(self)


class AcquireLock(ControlOp):
    """
    Control op carrying one Lock.

    Dispatches to ``visit_aqcuire_lock`` on the visitor. The hook name is
    misspelled in ControlOpVisitor and is used here exactly as declared.
    """
    # The trailing comma matters: ('lock') is just the string 'lock', so code
    # iterating __slots__ would see single characters instead of slot names.
    __slots__ = ('lock',)

    def __init__(self, lock: Lock):
        self.lock = lock

    def apply(self, visitor: ControlOpVisitor) -> Any:
        """Pass this op to the visitor instance and return the hook's result."""
        return visitor.visit_aqcuire_lock(self)


class RemoteBarrier(ControlOp):
    """
    Control op carrying an integer barrier id.

    Dispatches to ``visit_remote_barrier`` on the visitor.
    """
    # The trailing comma matters: ('id') is just the string 'id', so code
    # iterating __slots__ would see single characters instead of slot names.
    __slots__ = ('id',)

    def __init__(self, id: int):
        self.id = id

    def apply(self, visitor: ControlOpVisitor) -> Any:
        """Pass this op to the visitor instance and return the hook's result."""
        return visitor.visit_remote_barrier(self)


class LayerControl:
    """
    Container for the control op sequences of one layer.

    Every attribute starts out as its own empty list and is meant to be
    filled in after construction.

    Attributes:
        control_pkts: List of ControlOp instances.
        startup_control: List of ControlOp instances.
        dataflow_phases: List of lists of ControlOp instances
            (presumably one inner list per phase -- confirm against users).
        final_barrier: List of ControlOp instances.
    """
    __slots__ = ('control_pkts', 'startup_control', 'dataflow_phases', 'final_barrier')

    def __init__(self):
        # The lists hold op instances, hence ControlOp and not Type[ControlOp]
        # (the latter would describe lists of classes).
        self.control_pkts: List[ControlOp] = []
        self.startup_control: List[ControlOp] = []
        self.dataflow_phases: List[List[ControlOp]] = []
        self.final_barrier: List[ControlOp] = []
        

@dataclass
class DmaPaddingMap:
    """
    Settings for DMA constant padding.

    pad_value_per_element:
        Padding value for a single data element, as given by the user.
        - bits_per_element == 8:
            int in [0, 255]
        - bits_per_element == 16:
            * int in [0, 65535] -> used as the 16-bit pattern
            * float             -> encoded as bfloat16 (upper half of the
                                   float32 encoding, no rounding)
        - bits_per_element == 32:
            * int   -> masked to its low 32 bits
            * float -> float32 encoding

    bits_per_element:
        Element width in bits; must be 8, 16 or 32.

    enable_dma_pad:
        Whether constant padding is enabled. When False, ``pad_value`` is 0
        and ``pad_value_per_element`` is not validated.

    pad_value:
        Derived 32-bit padding word (not an ``__init__`` argument); computed
        in ``__post_init__``.
    """

    pad_value_per_element: Union[int, float] = 0
    bits_per_element: int = 8
    enable_dma_pad: bool = False

    pad_value: int = field(init=False, repr=True)

    def __post_init__(self):

        def float32_bits(value: float) -> int:
            # Reinterpret the single-precision encoding of `value` as an
            # unsigned 32-bit integer.
            return struct.unpack("<I", struct.pack("<f", value))[0]

        def build_pad_word(width: int, elem: Union[int, float]) -> int:
            # Expand one element-sized padding value into a 32-bit word:
            #   8-bit  -> byte repeated four times
            #   16-bit -> halfword repeated twice
            #   32-bit -> used directly
            if width == 8:
                if not isinstance(elem, int):
                    raise TypeError("8-bit pad value must be an int")
                if elem < 0 or elem > 0xFF:
                    raise ValueError("8-bit pad must be in [0, 255]")
                # Multiplying by 0x01010101 copies the byte into all four lanes.
                return elem * 0x01010101

            if width == 16:
                if isinstance(elem, int):
                    # int16/uint16 path
                    if elem < 0 or elem > 0xFFFF:
                        raise ValueError("16-bit pad (int) must be in [0, 65535]")
                    half = elem
                elif isinstance(elem, float):
                    # bfloat16 payload = upper 16 bits of the float32 encoding
                    half = float32_bits(elem) >> 16
                else:
                    raise TypeError("16-bit pad must be int or float (bfloat16)")
                # Multiplying by 0x00010001 copies the halfword into both lanes.
                return half * 0x00010001

            if width == 32:
                if isinstance(elem, float):
                    return float32_bits(elem)
                if isinstance(elem, int):
                    return elem & 0xFFFFFFFF
                raise TypeError("32-bit pad must be int or float")

            raise ValueError("bits_per_element must be 8, 16, or 32")

        # The element width is validated even when padding is disabled.
        if self.bits_per_element not in (8, 16, 32):
            raise ValueError(
                f"bits_per_element must be 8, 16, or 32, got {self.bits_per_element}"
            )

        if self.enable_dma_pad:
            # Compute the HW-format padding word
            self.pad_value = build_pad_word(
                self.bits_per_element,
                self.pad_value_per_element,
            )
        else:
            # Padding disabled -> fixed zero pad value
            self.pad_value = 0
