'''
This module implements the mapping layer that lowers physical
BD allocations to a control code sequence.
External facing functions are documented below.

generate_layer_control - lowers the physical BD allocations to a list
of layer control opcodes. The opcodes are split into different lists
based on their meaning.
    - control_pkts contains operations to be assembled as control packet data
    - startup_control contains the operations to send the control packets
    - dataflow_phases contains the operations to run each dataflow subphase
    - final_barrier contains the operations to wait for the end of a layer
'''


from typing import List, Dict
from collections import OrderedDict, defaultdict

from .types import (
    DevGen,
    OverlayShape, DataBuffer, DmaConnection,
    AieTile, TileType, DmaDir, Lock, shim_dma,
    BufferDescriptor, BufferTask,
    ConfigBufferDescriptor, PatchDdrAddr, SetLockValue,
    EnqueueTask, WaitDmaDone, AcquireLock,
    BdConfig, LayerControl, RemoteBarrier,
)
from . import config


class BdAlloc:
    '''
    Represent a BD ID allocation assigned to a particular DMA channel

    The allocation is identified by the DMA of the buffer descriptor
    (bd.aie_dma) together with the BD ID (bd.id). Two allocations built
    from descriptors with equal DMA and equal ID compare equal and hash
    identically, so instances can be used as dictionary keys.
    '''
    __slots__ = ('dma', 'id')

    def __init__(self, bd: BufferDescriptor):
        # Only the identifying fields are kept, not the descriptor itself
        self.dma = bd.aie_dma
        self.id = bd.id

    def __str__(self) -> str:
        return f'{self.dma}_{self.id}'

    # Make BdAlloc hashable

    def __eq__(self, other) -> bool:
        # Returning NotImplemented lets Python fall back to its default
        # comparison for foreign types instead of failing with an
        # AttributeError on the missing dma/id attributes.
        if not isinstance(other, BdAlloc):
            return NotImplemented
        return (self.dma == other.dma) and (self.id == other.id)

    def __hash__(self) -> int:
        # Hash the string form, which is built from the same two fields
        # that __eq__ compares.
        return hash(str(self))


def filter_tasks(data_buffers: List[DataBuffer], iter: int, dir: DmaDir, enable_task_queue_optimization: bool = False) -> List[BufferTask]:
    """
        Collect the tasks of iteration `iter` whose DMA channel runs in
        direction `dir`, visiting the buffers in the order given.

        When enable_task_queue_optimization is set, the collected tasks
        are ordered by the fold of their buffer descriptor. Tasks with
        the same fold keep their original relative order.
    """
    selected = []
    for buffer in data_buffers:
        for task in buffer.buffer_tasks[iter]:
            if task.buffer_descriptor.aie_dma.channel.dir == dir:
                selected.append(task)

    if enable_task_queue_optimization:
        # list.sort is stable, which preserves the order of equal folds
        selected.sort(key=lambda task: task.buffer_descriptor.fold)

    return selected


def compute_first_wait_iter(data_buffers: List[DataBuffer]) -> int:
    '''
    Find the phase where the first mask poll occurs

    This is the first iteration in which some task targets a DMA that
    already had a task in an earlier iteration. If no iteration reuses
    a DMA, the total number of iterations is returned.

    The result is used to determine which control can be offloaded
    to control packets. Other subphase control is done with the
    control processor.
    '''
    num_iters = len(data_buffers[0].buffer_tasks)
    for buffer in data_buffers:
        assert len(buffer.buffer_tasks) == num_iters

    seen_dmas = set()
    for phase in range(num_iters):
        # Writes are inspected before reads
        phase_dmas = [
            task.buffer_descriptor.aie_dma
            for direction in (DmaDir.S2MM, DmaDir.MM2S)
            for task in filter_tasks(data_buffers, phase, direction)
        ]
        # Only DMAs from earlier iterations count. A DMA used twice
        # within this same iteration does not trigger a wait here.
        if any(dma in seen_dmas for dma in phase_dmas):
            return phase
        seen_dmas.update(phase_dmas)
    return num_iters


def bd_config_chain(
    bd: BufferDescriptor,
    curr_bd_configs: Dict[BdAlloc, BdConfig]
) -> List[BdConfig]:
    '''
    Extract the BD configuration list from chaining information

    Starting at bd, the next_bd links are followed until None is
    reached. For every BD visited, the configuration currently stored
    in curr_bd_configs for its allocation is looked up. The result
    lists the configurations in chain order.
    '''
    def linked_bds(node):
        # Visit the head first, then each BD linked after it
        while node is not None:
            yield node
            node = node.next_bd

    return [curr_bd_configs[BdAlloc(link)] for link in linked_bds(bd)]


def get_uc_idx(col: int):
    """
    Map a column to its index in the LayerControl list.

    There is one LayerControl per column when config.IS_MULTI_UC is
    set, so the column itself is the index. Otherwise index 0 is used
    for every column.
    """
    return col if config.IS_MULTI_UC else 0


def _walk_bd_chain(bd):
    '''Yield bd and then every BD linked after it through next_bd'''
    node = bd
    while node is not None:
        yield node
        node = node.next_bd


def generate_layer_control(
    shape: OverlayShape,
    data_buffers: List[DataBuffer],
    dma_connections: List[DmaConnection],
    enable_task_queue_optimization: bool,
) -> list[LayerControl]:
    '''
    Create the control sequence for a given layer.

    This function has four main parts
        1. Create control packet operations for initial BD/lock configuration and task enqueue
        2. Create startup control operations to send the control packets
        3. Create control operations for each phase of the dataflow
        4. Create the final barrier to synchronize on the end of a layer

    Arguments
        shape - provides start_col and num_cols, the range of columns
            for which memtile and shim control is generated
        data_buffers - buffers whose buffer_descriptors, locks and
            per-iteration buffer_tasks are lowered. Must be non-empty
            and all buffers must have the same number of iterations.
        dma_connections - read_dma/write_dma pairs used to select the
            DMAs that the final barrier waits on
        enable_task_queue_optimization - forwarded to filter_tasks when
            the task enqueue order is built

    Returns a list of LayerControl objects. The list has
    config.NUM_UC_USED entries when config.IS_MULTI_UC is set and a
    single entry otherwise. Operations for a column are appended to the
    entry selected by get_uc_idx.

    Raises AssertionError when config.DEV_GEN is not a supported device
    generation, and when an internal consistency check fails.
    '''

    if config.DEV_GEN == DevGen.Aie2p:
        ctrl_channel_id = config.SHIM_CTRL_MM2S_CHANNEL_ID
        ctrl_bd_id = config.SHIM_CTRL_PKT_BD_ID
    elif config.DEV_GEN == DevGen.Aie4:
        # TODO: Replace dummy values with dedicated control channels
        ctrl_channel_id = 0
        ctrl_bd_id = -1
    else:
        # Raised explicitly rather than with `assert False`, so that the
        # check is not stripped when Python runs with -O.
        raise AssertionError(f'Unsupported device generation: {config.DEV_GEN}')

    cols = range(shape.start_col, shape.start_col + shape.num_cols)
    tile_types = (TileType.Memtile, TileType.Shim)

    first_wait_iter = compute_first_wait_iter(data_buffers)

    # Every known BD defaults to its iteration 0 configuration. BDs that
    # are enqueued before the first wait use the configuration of the
    # iteration that enqueues them instead.
    first_bd_cfg_iter: Dict[BdAlloc, int] = {
        BdAlloc(bd): 0
        for buffer in data_buffers
        for bd in buffer.buffer_descriptors
    }
    for buffer in data_buffers:
        for phase in range(first_wait_iter):
            for task in buffer.buffer_tasks[phase]:
                for bd in _walk_bd_chain(task.buffer_descriptor):
                    first_bd_cfg_iter[BdAlloc(bd)] = phase

    if config.IS_MULTI_UC:
        layer = [LayerControl() for _ in range(config.NUM_UC_USED)]
    else:
        layer = [LayerControl()]

    curr_bd_configs: Dict[BdAlloc, BdConfig] = {}

    #
    # Control Packet Creation
    #

    # The enqueue order of each iteration before the first wait does not
    # depend on the tile, so it is computed once. Writes precede reads.
    startup_tasks = [
        (filter_tasks(data_buffers, phase, DmaDir.S2MM, enable_task_queue_optimization) +
         filter_tasks(data_buffers, phase, DmaDir.MM2S, enable_task_queue_optimization))
        for phase in range(first_wait_iter)
    ]

    num_bds: Dict[AieTile, int] = {}
    num_locks: Dict[AieTile, int] = {}
    num_tasks: Dict[AieTile, int] = {}
    has_ctrl_config: Dict[AieTile, bool] = {}
    for tile_type in tile_types:
        for col in cols:
            tile = AieTile(tile_type, col)
            uc_layer = layer[get_uc_idx(col)]
            sorted_bds = sorted([bd
                                 for buffer in data_buffers
                                 for bd in buffer.buffer_descriptors
                                 if bd.aie_dma.tile == tile],
                                key=(lambda bd: bd.id))
            sorted_locks = sorted([lock
                                   for buffer in data_buffers
                                   for lock in buffer.locks
                                   if lock.aie_tile == tile],
                                  key=(lambda lock: lock.id))
            num_bds[tile] = len(sorted_bds)
            num_locks[tile] = len(sorted_locks)
            has_ctrl_config[tile] = (
                (len(sorted_bds) > 0) or (len(sorted_locks) > 0)
            )
            for bd in sorted_bds:
                # NOTE: BDs first enqueued at or after first_wait_iter start
                # with their iteration 0 configuration and are reconfigured
                # during the dataflow phases when needs_reconfig reports a
                # change.
                # TODO: Confirm this covers every case where the first BD
                # enqueue has a different configuration than iteration 0.
                init_cfg = BdConfig(bd, first_bd_cfg_iter[BdAlloc(bd)])
                uc_layer.control_pkts.append(ConfigBufferDescriptor(init_cfg))
                curr_bd_configs[BdAlloc(bd)] = init_cfg
            for lock in sorted_locks:
                uc_layer.control_pkts.append(SetLockValue(lock))
            # Start the count at zero so that the tile has an entry even
            # when no iteration precedes the first wait.
            num_tasks[tile] = 0
            for phase_tasks in startup_tasks:
                tile_tasks = [task
                              for task in phase_tasks
                              if task.buffer_descriptor.aie_dma.tile == tile]
                num_tasks[tile] += len(tile_tasks)
                for task in tile_tasks:
                    uc_layer.control_pkts.append(
                        EnqueueTask(bd_config_chain(task.buffer_descriptor, curr_bd_configs),
                                    task.repeat_count))
            # NOTE: This completes the ctrl handshake by setting
            # the lock to +0 when the ctrl packet transfer is complete.
            if has_ctrl_config[tile]:
                uc_layer.control_pkts.append(
                    SetLockValue(Lock(AieTile(tile_type, col), 0, +0)))

    #
    # Startup Control Creation
    #

    # NOTE: This initiates the ctrl handshake by setting the
    # locks to +1 at the beginning of the transaction.
    # If this tile has no ctrl packet configuration, then we just
    # initialize the lock to +0, which is required for the "NOP" lock
    # in parallel synchronization methods.
    for tile_type in tile_types:
        for col in cols:
            if has_ctrl_config[AieTile(tile_type, col)]:
                layer[get_uc_idx(col)].startup_control.append(
                    SetLockValue(Lock(AieTile(tile_type, col), 0, +1)))
    # NOTE: Here we assume the memory order is as follows.
    # Memtile Ctrl Pkts Col 0
    # Memtile Ctrl Pkts Col 1
    # ...
    # Shim Ctrl Pkts Col 0
    # Shim Ctrl Pkts Col 1
    # ...
    # The length of each transfer is determined by the
    # number of BD and lock configs for each tile.
    ctrl_offset = 0
    for tile_type in tile_types:
        for col in cols:
            tile = AieTile(tile_type, col)
            if not has_ctrl_config[tile]:
                continue
            uc_layer = layer[get_uc_idx(col)]
            # The extra lock packet is the handshake lock release that
            # terminates the control packets of the tile.
            ctrl_length = (
                (num_bds[tile] * config.BD_CONFIG_CTRL_PKT_WORDS) +
                ((num_locks[tile] + 1) * config.LOCK_CONFIG_CTRL_PKT_WORDS) +
                (num_tasks[tile] * config.TASK_ENQUEUE_CTRL_PKT_WORDS)
            )
            ctrl_bd = BufferDescriptor(
                shim_dma(col, DmaDir.MM2S, ctrl_channel_id),
                ctrl_bd_id,
                buffer_addr=config.SHIM_CTRL_BUFFER_IDX,
                offset=ctrl_offset,
                length=ctrl_length,
                name="Control Handshake",
            )
            ctrl_cfg = BdConfig(ctrl_bd, 0)
            uc_layer.startup_control.append(ConfigBufferDescriptor(ctrl_cfg))
            curr_bd_configs[BdAlloc(ctrl_bd)] = ctrl_cfg
            uc_layer.startup_control.append(PatchDdrAddr(ctrl_cfg))
            uc_layer.startup_control.append(
                EnqueueTask(bd_config_chain(ctrl_bd, curr_bd_configs), 1))
            ctrl_offset += ctrl_length
        for col in cols:
            # NOTE: We cannot poll the shim control BD for completion,
            # since the control packets are enqueuing work. The handshake
            # lock acquire will indicate completion of control packets.
            if (tile_type == TileType.Memtile) and has_ctrl_config[AieTile(tile_type, col)]:
                layer[get_uc_idx(col)].startup_control.append(
                    WaitDmaDone(shim_dma(col, DmaDir.MM2S, ctrl_channel_id)))
    # NOTE: This waits for the handshake to complete, since the final
    # control packet will set the lock to +0.
    for tile_type in tile_types:
        for col in cols:
            if has_ctrl_config[AieTile(tile_type, col)]:
                layer[get_uc_idx(col)].startup_control.append(
                    AcquireLock(Lock(AieTile(tile_type, col), 0, +0)))

    #
    # Dataflow Phase Creation
    #

    num_iters = len(data_buffers[0].buffer_tasks)
    for buffer in data_buffers:
        assert len(buffer.buffer_tasks) == num_iters

    # DMAs started by the control packets. The insertion order is kept
    # because the final barrier may iterate over this mapping. These
    # calls intentionally leave the task queue optimization off, as the
    # ordering here must not follow the fold sort.
    prev_running_dmas = OrderedDict()
    for phase in range(first_wait_iter):
        write_tasks = filter_tasks(data_buffers, phase, DmaDir.S2MM)
        read_tasks = filter_tasks(data_buffers, phase, DmaDir.MM2S)
        for task in write_tasks + read_tasks:
            prev_running_dmas[task.buffer_descriptor.aie_dma] = None

    for col in cols:
        layer[get_uc_idx(col)].dataflow_phases = [[] for _ in range(num_iters)]
    for phase in range(first_wait_iter, num_iters):
        barrier_ids = set()
        curr_running_dmas = OrderedDict()
        curr_running_bd_ids = OrderedDict()
        write_tasks = filter_tasks(data_buffers, phase, DmaDir.S2MM, enable_task_queue_optimization)
        read_tasks = filter_tasks(data_buffers, phase, DmaDir.MM2S, enable_task_queue_optimization)
        for task in write_tasks + read_tasks:
            task_bd = task.buffer_descriptor
            task_dma = task_bd.aie_dma
            phase_ops = layer[get_uc_idx(task_dma.tile.col)].dataflow_phases[phase]

            # Wait once per DMA and iteration, and only if the DMA was
            # started in an earlier iteration.
            if (task_dma in prev_running_dmas) and (task_dma not in curr_running_dmas):
                phase_ops.append(WaitDmaDone(task_dma))

            # NOTE: Here it is safe to reconfigure the BD, since any
            # outstanding tasks from previous iterations that reference
            # this BD are complete. We only reconfigure if the BD changes
            # from its current configuration. We also track the current
            # running BD allocations as an internal safety check.
            for bd in _walk_bd_chain(task_bd):
                bd_alloc = BdAlloc(bd)
                curr_cfg = curr_bd_configs[bd_alloc]
                next_cfg = BdConfig(bd, phase)
                if curr_cfg.needs_reconfig(next_cfg):
                    assert bd_alloc not in curr_running_bd_ids
                    assert phase >= 1
                    phase_ops.append(ConfigBufferDescriptor(next_cfg))
                    curr_bd_configs[bd_alloc] = next_cfg
                    # Only shim BDs get their DDR address patched
                    if bd.aie_dma.tile.type == TileType.Shim:
                        phase_ops.append(PatchDdrAddr(next_cfg))

            phase_ops.append(
                EnqueueTask(bd_config_chain(task_bd, curr_bd_configs),
                            task.repeat_count))

            if isinstance(task_bd.barrier_id, int):
                barrier_ids.add(task_bd.barrier_id)

            curr_running_dmas[task_dma] = None
            for bd in _walk_bd_chain(task_bd):
                curr_running_bd_ids[BdAlloc(bd)] = None
        prev_running_dmas.update(curr_running_dmas)

        # Each barrier requested in this iteration is appended once per
        # column, after all of the enqueue operations of the iteration.
        for barrier_id in barrier_ids:
            for col in cols:
                layer[get_uc_idx(col)].dataflow_phases[phase].append(RemoteBarrier(barrier_id))

    #
    # Final Barrier Creation
    #

    # Preference order: memtile-to-shim outputs, then core-to-memtile
    # outputs, then every DMA that ran during the layer.
    shim_out_dmas = [dma for cxn in dma_connections for dma in (cxn.read_dma, cxn.write_dma)
                     if ((cxn.read_dma.tile.type == TileType.Memtile) and
                         (cxn.write_dma.tile.type == TileType.Shim) and
                         (cxn.write_dma in prev_running_dmas))]
    memtile_out_dmas = [cxn.write_dma for cxn in dma_connections
                        if ((cxn.read_dma.tile.type == TileType.Core) and
                            (cxn.write_dma.tile.type == TileType.Memtile) and
                            (cxn.write_dma in prev_running_dmas))]
    if len(shim_out_dmas) > 0:
        final_wait_dmas = shim_out_dmas
    elif len(memtile_out_dmas) > 0:
        final_wait_dmas = memtile_out_dmas
    else:
        final_wait_dmas = list(prev_running_dmas)
    for dma in final_wait_dmas:
        layer[get_uc_idx(dma.tile.col)].final_barrier.append(WaitDmaDone(dma))

    # Remote Barrier on Final_Barrier (For L2 and L3 schedules)
    for col in cols:
        layer[get_uc_idx(col)].final_barrier.append(RemoteBarrier(0))

    return layer
