'''
This module implements the behavioral simulator for dataflow deadlock detection.
External facing functions are documented below.

run_overlay_deadlock_check - runs a behavioral simulation of a dataflow.
If a deadlock is detected, it prints a dump of the task queue state and
then raises a RuntimeError exception.
'''


from typing import List, Dict, Optional, Union, Type, Deque
from collections import OrderedDict, deque

from .types import (
    AieTile, TileType, DmaDir, DmaChannel, AieDma, OverlayShape,
    TransferParams, DataTransfer, DmaConnection,
    CoreInstr, Loop, AcqBuffer, RelBuffer, ConfigBuffer
)
from . import config


# Compute unique DMA channels for transfer params
def transfer_param_dmas(transfer_params: List[TransferParams]) -> List[AieDma]:
    '''
    Return the distinct DMAs referenced by the given transfer params.

    Each DMA is listed once, at the position where it first appears in
    transfer_params. The DMAs must be hashable.
    '''
    first_seen = OrderedDict()
    for param in transfer_params:
        # setdefault keeps the first-seen key and ignores later duplicates
        first_seen.setdefault(param.dma, None)
    return list(first_seen)


# Compute chains of transfer params that share the same DMA channel
def transfer_param_chains(transfer_params: List[TransferParams]) -> List[List[TransferParams]]:
    '''
    Group transfer params into chains that share the same DMA.

    Chains are ordered by the first appearance of their DMA in
    transfer_params, and the params inside a chain keep their original
    relative order. An empty input yields an empty list.

    The grouping is done in a single pass keyed on param.dma (which must
    be hashable), rather than rescanning the whole list once per DMA.
    '''
    chains_by_dma = OrderedDict()
    for param in transfer_params:
        chains_by_dma.setdefault(param.dma, []).append(param)
    return list(chains_by_dma.values())


class CoreTask:
    '''
    A single pending transfer on one DMA channel.

    dma_channel - channel the transfer runs on
    length      - remaining transfer length; must be positive on creation
    '''
    __slots__ = ('dma_channel', 'length')

    def __init__(self, dma_channel: DmaChannel, length: int):
        # A zero or negative length task is never valid
        assert length > 0
        self.length = length
        self.dma_channel = dma_channel

    def __str__(self) -> str:
        '''Render the task for debug logging.'''
        fields = f'Channel: {self.dma_channel}, Length: {self.length}'
        return f'({fields})'


# This function converts the user specified core instructions to a
# list of core tasks to be used by the behavioral deadlock simulator.
# As part of conversion, we check that the user specified locking
# pattern is valid
def conv_to_core_tasks(core_instrs: List[Type[CoreInstr]]) -> List[CoreTask]:
    '''
    Flatten core instructions into the ordered list of core DMA tasks.

    Loops are unrolled num_iters times. A task is emitted when an S2MM
    buffer is acquired and when an MM2S buffer is released, provided the
    configured buffer length (buffer_size // 4) is non-zero. Instructions
    of any other kind are ignored.

    Raises RuntimeError when the locking pattern is invalid: acquire or
    release before the buffer is configured, acquiring a buffer twice,
    releasing a buffer that is not acquired, or finishing with a buffer
    still acquired.
    '''
    word_size = 4

    # Every core channel starts out not acquired, with no configured length
    core_channels = (
        [DmaChannel(DmaDir.S2MM, idx) for idx in range(config.MAX_CORE_S2MM_DMA_CHANNEL + 1)] +
        [DmaChannel(DmaDir.MM2S, idx) for idx in range(config.MAX_CORE_MM2S_DMA_CHANNEL + 1)]
    )
    buffer_lengths: Dict[DmaChannel, Optional[int]] = {channel: None for channel in core_channels}
    is_acquired: Dict[DmaChannel, bool] = {channel: False for channel in core_channels}

    tasks: List[CoreTask] = []

    def expand(instrs: List[Type[CoreInstr]]):
        # Walk the instructions in program order, appending tasks as found
        for instr in instrs:
            if isinstance(instr, Loop):
                for _ in range(instr.num_iters):
                    expand(instr.loop_body)

            elif isinstance(instr, AcqBuffer):
                channel = instr.dma_channel
                length = buffer_lengths[channel]
                if length is None:
                    raise RuntimeError('Invalid attempt to acquire buffer with uninitialized config!')
                if is_acquired[channel]:
                    raise RuntimeError('Invalid attempt to acquire buffer multiple times!')
                is_acquired[channel] = True

                # Acquiring an input (S2MM) buffer produces a task
                if (channel.dir == DmaDir.S2MM) and (length > 0):
                    tasks.append(CoreTask(channel, length))

            elif isinstance(instr, RelBuffer):
                channel = instr.dma_channel
                length = buffer_lengths[channel]
                if length is None:
                    raise RuntimeError('Invalid attempt to release buffer with uninitialized config!')
                if not is_acquired[channel]:
                    raise RuntimeError('Invalid attempt to release buffer without acquire!')
                is_acquired[channel] = False

                # Releasing an output (MM2S) buffer produces a task
                if (channel.dir == DmaDir.MM2S) and (length > 0):
                    tasks.append(CoreTask(channel, length))

            elif isinstance(instr, ConfigBuffer):
                buffer_lengths[instr.dma_channel] = instr.buffer_size // word_size

    expand(core_instrs)

    # Every acquire must have been matched by a release at the end
    if any(is_acquired.values()):
        raise RuntimeError('Invalid attempt to acquire buffer without release!')

    return tasks


# This class stores the simulation state for a single core
class CoreSim:
    '''
    Simulation state for a single core.

    aie_tile                - the core tile being simulated
    core_tasks              - ordered tasks derived from the core instructions
    task_index              - index of the task currently being serviced
    pending_transfer_length - total length still to be transferred
    '''
    __slots__ = ('aie_tile', 'core_tasks', 'task_index', 'pending_transfer_length')

    def __init__(self, aie_tile: AieTile, core_instrs: List[Type[CoreInstr]]):
        self.aie_tile = aie_tile
        self.core_tasks = conv_to_core_tasks(core_instrs)
        self.task_index = 0
        self.pending_transfer_length = sum(task.length for task in self.core_tasks)

        assert aie_tile.type == TileType.Core

    def available_transfer_length(self, dma_channel: DmaChannel) -> int:
        '''
        Return the remaining length of the current task if it runs on
        dma_channel, otherwise 0 (also 0 once all tasks are done).
        '''
        if self.task_index < len(self.core_tasks):
            current = self.core_tasks[self.task_index]
            if dma_channel == current.dma_channel:
                return current.length
        return 0

    def run_transfer(self, dma_channel: DmaChannel, length: int):
        '''
        Consume length from the current task, which must be on dma_channel.
        Advances to the next task once the current one is fully consumed.
        '''
        assert self.task_index < len(self.core_tasks)
        current = self.core_tasks[self.task_index]
        assert dma_channel == current.dma_channel
        assert length <= current.length
        current.length -= length
        if current.length == 0:
            self.task_index += 1
        self.pending_transfer_length -= length


class DmaTask:
    '''
    A queued DMA transfer that is repeated repeat_count times.

    repeat_count    - remaining number of repetitions
    transfer_length - length of one repetition
    curr_length     - length left in the current repetition
    enable_locking  - whether the lock gates this task (off by default)
    lock, acq_val   - lock counter and the value it is compared against
    pending_tasks   - tasks whose lock is incremented by DmaSim each time
                      a locked repetition of this task completes
    '''
    __slots__ = ('repeat_count', 'transfer_length', 'curr_length', 'enable_locking', 'lock', 'acq_val', 'pending_tasks')

    def __init__(self, repeat_count: int, length: int):
        assert repeat_count > 0
        assert length > 0
        self.repeat_count = repeat_count
        self.transfer_length = self.curr_length = length
        # Locking is off until the task is linked with its counterparts
        self.enable_locking = False
        self.lock = self.acq_val = 0
        self.pending_tasks = []

    def __str__(self) -> str:
        '''Render the task for debug logging.'''
        fields = (
            f'Repeat Count: {self.repeat_count}',
            f'Transfer Length: {self.transfer_length}',
            f'Pending Length: {self.curr_length}',
            f'Enable Lock: {self.enable_locking}',
            f'Lock: {self.lock}',
            f'Acq Val: {self.acq_val}',
        )
        return '(' + ', '.join(fields) + ')'


# This class stores the simulation state for a single DMA
class DmaSim:
    '''
    Simulation state for the DMA channels of a single tile.

    aie_tile                - the tile being simulated
    task_queues             - one FIFO of DmaTask per DMA channel
    pending_transfer_length - total length still to be transferred
    '''
    __slots__ = ('aie_tile', 'task_queues', 'pending_transfer_length')

    def __init__(self, aie_tile: AieTile):
        self.aie_tile = aie_tile
        # Memtiles use the memtile channel limits; any other tile type
        # falls back to the shim channel limits
        if self.aie_tile.type == TileType.Memtile:
            max_s2mm_dma_channel = config.MAX_MEMTILE_S2MM_DMA_CHANNEL
            max_mm2s_dma_channel = config.MAX_MEMTILE_MM2S_DMA_CHANNEL
        else:
            max_s2mm_dma_channel = config.MAX_SHIM_S2MM_DMA_CHANNEL
            max_mm2s_dma_channel = config.MAX_SHIM_MM2S_DMA_CHANNEL

        queues: Dict[DmaChannel, Deque[DmaTask]] = {}
        for direction, last_id in ((DmaDir.S2MM, max_s2mm_dma_channel),
                                   (DmaDir.MM2S, max_mm2s_dma_channel)):
            for channel_id in range(last_id + 1):
                queues[DmaChannel(direction, channel_id)] = deque()
        self.task_queues = queues
        self.pending_transfer_length = 0

    def enqueue_task(self, channel: DmaChannel, task: DmaTask):
        '''Append task to the channel queue and account for its total length.'''
        self.pending_transfer_length += task.repeat_count * task.transfer_length
        self.task_queues[channel].append(task)

    def _channel_available(self, dma_channel: DmaChannel) -> bool:
        '''
        A channel can transfer when its queue is non-empty and the head
        task is either not locked or has lock >= acq_val.
        '''
        queue = self.task_queues[dma_channel]
        if len(queue) == 0:
            return False
        head = queue[0]
        if not head.enable_locking:
            return True
        return head.lock >= head.acq_val

    def available_transfer_length(self, dma_channel: DmaChannel) -> int:
        '''
        Return the length left in the head task's current repetition, or
        0 when the channel cannot transfer.
        '''
        queue = self.task_queues[dma_channel]
        if self._channel_available(dma_channel):
            return queue[0].curr_length
        return 0

    def run_transfer(self, dma_channel: DmaChannel, length: int):
        '''
        Consume length from the head task of dma_channel.

        When a repetition completes, the repeat count drops, the current
        length is reloaded and, for locked tasks, the task's own lock is
        cleared while the lock of every pending task is incremented. The
        task leaves the queue once no repetitions remain.
        '''
        queue = self.task_queues[dma_channel]
        assert self._channel_available(dma_channel)
        head = queue[0]
        assert length <= head.curr_length
        head.curr_length -= length
        if head.curr_length == 0:
            head.repeat_count -= 1
            head.curr_length = head.transfer_length
            if head.enable_locking:
                head.lock = 0
                for waiter in head.pending_tasks:
                    waiter.lock += 1
        if head.repeat_count == 0:
            queue.popleft()
        self.pending_transfer_length -= length


class DmaChain:
    '''
    One reading (MM2S) DMA together with every writing (S2MM) DMA that is
    connected to it. Data moves through all endpoints in lock step.
    '''
    __slots__ = ('read_dma', 'write_dmas')

    def __init__(self, read_dma: AieDma, write_dmas: List[AieDma]):
        self.read_dma = read_dma
        self.write_dmas = write_dmas

        assert self.read_dma.channel.dir == DmaDir.MM2S
        assert all(dma.channel.dir == DmaDir.S2MM for dma in self.write_dmas)

    def available_transfer_length(self, simulators: Dict[AieTile, Union[CoreSim, DmaSim]]) -> int:
        '''
        Return the largest length every endpoint can transfer right now,
        i.e. the minimum of the available lengths of all endpoints.
        '''
        endpoints = [self.read_dma, *self.write_dmas]
        return min(
            simulators[dma.tile].available_transfer_length(dma.channel)
            for dma in endpoints
        )

    def run_transfer(self, simulators: Dict[AieTile, Union[CoreSim, DmaSim]], length: int):
        '''Transfer length on the read DMA first, then on each write DMA in order.'''
        for dma in [self.read_dma, *self.write_dmas]:
            simulators[dma.tile].run_transfer(dma.channel, length)


def sim_enqueue_transfer(
    simulators: Dict[AieTile, Union[CoreSim, DmaSim]],
    iter: int,
    repeat_count: int,
    write_params: List[TransferParams],
    read_params: List[TransferParams],
    reuse_ratio: int
):
    '''
    Enqueue the DMA tasks of one data transfer for iteration iter.

    The write and read params are grouped into chains per DMA. Every chain
    that has at least one param with a non-zero length in this iteration
    becomes one DmaTask, enqueued on the DMA of that chain. Reader tasks
    have their length multiplied by reuse_ratio.

    When the transfer has both writers and readers they are linked through
    locks: writers start with their lock satisfied, readers start blocked,
    and each side lists the other as its pending tasks.

    simulators   - per-tile simulators; tasks are enqueued on the tile of
                   each chain's DMA
    iter         - iteration index passed to TransferParams.length_i
    repeat_count - number of repetitions of each task; must be positive
    write_params - params of the writing side of the transfer
    read_params  - params of the reading side of the transfer
    reuse_ratio  - multiplier applied to the reader task length
    '''
    assert repeat_count > 0
    assert len(read_params) + len(write_params) > 0

    def active_chains(params: List[TransferParams]) -> List[List[TransferParams]]:
        # Keep only chains that move data in this iteration
        return [
            chain for chain in transfer_param_chains(params)
            if any(param.length_i(iter) > 0 for param in chain)
        ]

    def chain_length(chain: List[TransferParams]) -> int:
        return sum(param.length_i(iter) for param in chain)

    # The tasks must be built from, and later paired with, the SAME filtered
    # chain lists. Pairing them with the unfiltered chains would shift every
    # task after a zero-length chain onto the wrong DMA channel.
    write_chains = active_chains(write_params)
    read_chains = active_chains(read_params)
    writers = [DmaTask(repeat_count, chain_length(chain)) for chain in write_chains]
    readers = [DmaTask(repeat_count, chain_length(chain) * reuse_ratio) for chain in read_chains]

    if len(writers) > 0 and len(readers) > 0:
        for writer in writers:
            writer.enable_locking = True
            writer.lock = len(readers)
            writer.acq_val = len(readers)
            writer.pending_tasks = readers
        for reader in readers:
            reader.enable_locking = True
            reader.lock = 0
            reader.acq_val = len(writers)
            reader.pending_tasks = writers

    for task, chain in zip(writers + readers, write_chains + read_chains):
        dma = chain[0].dma
        simulators[dma.tile].enqueue_task(dma.channel, task)


def log_deadlock_report(
    shape: OverlayShape,
    simulators: Dict[AieTile, Union[CoreSim, DmaSim]],
    dma_chains: List[DmaChain]
) -> str:
    '''
    Build the text of the deadlock report.

    The report has two parts. The first lists the remaining tasks of every
    core in the overlay and the queued tasks of every memtile and shim DMA
    channel ("<empty>" when there are none). The second lists every DMA
    chain in which at least one endpoint still has data available to
    transfer.
    '''
    entry_indent = '\n    '

    def render_queue(tasks) -> str:
        # One indented line per task, or a placeholder for an empty queue
        rendered = entry_indent.join(str(task) for task in tasks)
        if rendered == '':
            rendered = '<empty>'
        return entry_indent + rendered

    def render_dma_tile(tile, channel_limits) -> list:
        # One heading plus queue listing per channel, S2MM before MM2S
        tile_sim = simulators[tile]
        pieces = []
        for direction, last_channel in channel_limits:
            for channel_id in range(last_channel + 1):
                channel = DmaChannel(direction, channel_id)
                pieces.append(f'\n{tile} {channel}')
                pieces.append(render_queue(tile_sim.task_queues[channel]))
        return pieces

    columns = range(shape.start_col, shape.start_col + shape.num_cols)

    report = ['''
Hardware Task Queue Report
--------------------------------
''']

    # Cores: only the tasks that have not completed yet
    for col in columns:
        for row in range(shape.start_row, shape.start_row + shape.num_rows):
            tile = AieTile(TileType.Core, col, row)
            core_sim = simulators[tile]
            report.append(f'\n{tile}')
            report.append(render_queue(core_sim.core_tasks[core_sim.task_index:]))
    report.append('\n')

    for col in columns:
        report.extend(render_dma_tile(
            AieTile(TileType.Memtile, col, 0),
            ((DmaDir.S2MM, config.MAX_MEMTILE_S2MM_DMA_CHANNEL),
             (DmaDir.MM2S, config.MAX_MEMTILE_MM2S_DMA_CHANNEL)),
        ))
    report.append('\n')

    for col in columns:
        report.extend(render_dma_tile(
            AieTile(TileType.Shim, col, 0),
            ((DmaDir.S2MM, config.MAX_SHIM_S2MM_DMA_CHANNEL),
             (DmaDir.MM2S, config.MAX_SHIM_MM2S_DMA_CHANNEL)),
        ))

    report.append('''

Stalled Graph Connection Report
--------------------------------
''')

    # A chain is reported when any endpoint still has data ready to move
    for chain in dma_chains:
        endpoints = [chain.read_dma, *chain.write_dmas]
        ready = [
            simulators[dma.tile].available_transfer_length(dma.channel) > 0
            for dma in endpoints
        ]
        if any(ready):
            write_list = ', '.join(str(dma) for dma in chain.write_dmas)
            report.append(f'\n{chain.read_dma} -> ({write_list})')

    return ''.join(report)


def tile_in_bounds(shape: OverlayShape, tile: AieTile) -> bool:
    '''
    Return True when tile lies inside the overlay described by shape.

    The overlay covers columns start_col .. start_col + num_cols - 1. Core
    tiles must also lie in rows start_row .. start_row + num_rows - 1,
    while memtile and shim tiles must be in row 0. Any other tile type is
    out of bounds.

    Both upper bounds are exclusive, matching the ranges for which
    run_overlay_deadlock_check creates simulators; a tile one past the
    last column or row has no simulator and must be rejected here.
    '''
    end_col = shape.start_col + shape.num_cols
    if not (shape.start_col <= tile.col < end_col):
        return False

    if tile.type == TileType.Core:
        end_row = shape.start_row + shape.num_rows
        return shape.start_row <= tile.row < end_row

    if tile.type in (TileType.Memtile, TileType.Shim):
        return tile.row == 0

    return False


def check_data_transfers_valid(
    shape: OverlayShape,
    memtile_transfers: List[DataTransfer],
    shim_transfers: List[DataTransfer]
):
    '''
    Validate the memtile and shim transfer lists.

    Raises ValueError when a transfer sits on the wrong tile type, when
    both lists are empty, when the transfers disagree on the number of
    iterations (length of repeat_counts), or when a transfer tile lies
    outside the overlay shape.
    '''
    if any(transfer.tile.type != TileType.Memtile for transfer in memtile_transfers):
        raise ValueError('Invalid tile type for memtile transfer!')
    if any(transfer.tile.type != TileType.Shim for transfer in shim_transfers):
        raise ValueError('Invalid tile type for shim transfer!')

    data_transfers = memtile_transfers + shim_transfers
    if not data_transfers:
        raise ValueError('Memtile or shim transfer list must be non-empty!')

    # The first transfer sets the iteration count all others must match
    expected_iters = len(data_transfers[0].repeat_counts)
    for transfer in data_transfers:
        if len(transfer.repeat_counts) != expected_iters:
            raise ValueError('All transfers must have the same length for repeat count list!')
        if not tile_in_bounds(shape, transfer.tile):
            raise ValueError('Invalid tile location!')


def check_dma_connections_valid(
    shape: OverlayShape,
    dma_connections: List[DmaConnection]
):
    '''
    Raise ValueError when either end of a DMA connection lies on a tile
    outside the overlay shape. The read side is checked before the write
    side.
    '''
    for cxn in dma_connections:
        if not tile_in_bounds(shape, cxn.read_dma.tile):
            raise ValueError('Invalid tile location!')
        if not tile_in_bounds(shape, cxn.write_dma.tile):
            raise ValueError('Invalid tile location!')


def run_overlay_deadlock_check(
    shape: OverlayShape,
    core_instrs: Union[List[Type[CoreInstr]], Dict[AieTile, List[Type[CoreInstr]]]],
    memtile_transfers: List[DataTransfer],
    shim_transfers: List[DataTransfer],
    dma_connections: List[DmaConnection],
    param_channel_id: int = 0
):
    '''
    Run a behavioral simulation of the dataflow and fail on deadlock.

    shape             - overlay region; one core simulator is created per
                        (col, row) and one memtile and one shim simulator
                        per column
    core_instrs       - either one instruction list shared by every core,
                        or a dict with an instruction list per core tile
    memtile_transfers - data transfers on memtiles
    shim_transfers    - data transfers on shims
    dma_connections   - read DMA to write DMA connections; connections
                        that share a read DMA form one chain
    param_channel_id  - S2MM channel on which every core first configures,
                        acquires and releases a buffer of
                        config.MAX_CORE_LAYER_PARAM_SIZE

    For every iteration, the transfers with a positive repeat count are
    enqueued and the chains are run until none can make progress.

    Raises ValueError for invalid transfers or connections. Prints the
    deadlock report and raises RuntimeError when a DMA channel still has
    queued tasks at the start of an iteration that uses it, or when a
    core, memtile or shim has pending transfer length after the last
    iteration.
    '''
    def param_channel() -> DmaChannel:
        return DmaChannel(DmaDir.S2MM, param_channel_id)

    # Prepended to the instructions of every core
    param_instrs = [
        ConfigBuffer(param_channel(), 0x0, None, config.MAX_CORE_LAYER_PARAM_SIZE),
        AcqBuffer(param_channel()),
        RelBuffer(param_channel()),
    ]

    check_data_transfers_valid(shape, memtile_transfers, shim_transfers)
    check_dma_connections_valid(shape, dma_connections)

    columns = range(shape.start_col, shape.start_col + shape.num_cols)
    rows = range(shape.start_row, shape.start_row + shape.num_rows)

    simulators: Dict[AieTile, Union[CoreSim, DmaSim]] = {}

    def fail(message: str):
        # Dump the simulator state before reporting the failure
        print(log_deadlock_report(shape, simulators, dma_chains))
        raise RuntimeError(message)

    per_tile_instrs = isinstance(core_instrs, dict)
    for col in columns:
        for row in rows:
            core_tile = AieTile(TileType.Core, col, row)
            tile_instrs = core_instrs[core_tile] if per_tile_instrs else core_instrs
            simulators[core_tile] = CoreSim(core_tile, param_instrs + tile_instrs)
    for col in columns:
        for tile_type in (TileType.Memtile, TileType.Shim):
            dma_tile = AieTile(tile_type, col, 0)
            simulators[dma_tile] = DmaSim(dma_tile)

    # Connections sharing a read DMA are simulated together as one chain
    read_dmas = list(OrderedDict.fromkeys([cxn.read_dma for cxn in dma_connections]))
    dma_chains = [
        DmaChain(read_dma, [cxn.write_dma for cxn in dma_connections if cxn.read_dma == read_dma])
        for read_dma in read_dmas
    ]

    data_transfers = memtile_transfers + shim_transfers
    num_iters = len(data_transfers[0].repeat_counts)

    for i in range(num_iters):
        active_transfers = [
            transfer for transfer in data_transfers if transfer.repeat_counts[i] > 0
        ]

        # A channel used in this iteration must have drained its queue
        for transfer in active_transfers:
            for param in (transfer.write_params + transfer.read_params):
                if len(simulators[param.dma.tile].task_queues[param.dma.channel]) > 0:
                    fail('DMA deadlock detected!')

        for transfer in active_transfers:
            sim_enqueue_transfer(
                simulators,
                i,
                transfer.repeat_counts[i],
                transfer.write_params,
                transfer.read_params,
                transfer.reuse_ratio,
            )

        # Sweep over the chains until a full sweep moves no data
        made_progress = True
        while made_progress:
            made_progress = False
            for chain in dma_chains:
                transfer_length = chain.available_transfer_length(simulators)
                if transfer_length > 0:
                    chain.run_transfer(simulators, transfer_length)
                    made_progress = True

    # Anything still pending after the last iteration can never complete
    for col in columns:
        for row in rows:
            if simulators[AieTile(TileType.Core, col, row)].pending_transfer_length != 0:
                fail('Core failed to complete!')

    for tile_type, message in ((TileType.Memtile, 'Memtile failed to complete!'),
                               (TileType.Shim, 'Shim failed to complete!')):
        for col in columns:
            if simulators[AieTile(tile_type, col, 0)].pending_transfer_length != 0:
                fail(message)
