"""Dataflow for Resize (NNI)."""
from typing import Tuple, List, no_type_check

from utils.utils_common import (
    overlay_3x4_core_stack_addr,
    log,
)
from scheduler.common import (
    L3Alloc,
    L3Alloc_to_Shim,
    overlay_3x4_F_ids, overlay_3x4_A_ids,
    overlay_3x4_S_ids, overlay_3x4_dma_connections,
    prm_shim_memory,
    prm_shim_mm2s, prm_memtile_memory,
    prm_memtile_s2mm, prm_memtile_mm2s
)
from buildscripts.common import ScheduleInputs

from dmacompiler import (
    DevGen, set_dev_gen, config,
    OverlayShape,
    DataTransfer, SyncStrategy,
    AieTile, TileType, DmaChannel,
    DmaDir, memtile_dma, shim_dma,
    ConfigBuffer, AcqBuffer, RelBuffer,
    compute_buffer_size,
    generate_transfer_params,
    generate_shim_data_transfer,
    run_layer_compilation,
    pack_reconfig_transfers
)

from tiler.resize_tiler import ResizeMapping

# Import-time side effect: select the Aie4 device generation in dmacompiler
# before any of the dataflow helpers below are used.
# NOTE(review): presumably this is global dmacompiler state shared with other
# dataflow modules - confirm that importing this module cannot clash with them.
set_dev_gen(DevGen.Aie4)


def Yi_slice(dims: ResizeMapping, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Return the input-row window handled by ``col`` at iteration ``start_iter``.

    The result is ``(Yi_start, Yi_stop, Yi_stride, Yi_size)``:

    * ``Yi_start`` / ``Yi_stop`` - unclamped row bounds of the window,
    * ``Yi_stride`` - row distance between consecutive iterations of a column
      (``dims.aie_cols * dims.Yos``),
    * ``Yi_size`` - number of rows of the window that lie inside
      ``[0, dims.Yi]``.
    """
    rows_per_col = dims.Yos
    Yi_stride = dims.aie_cols * rows_per_col
    Yi_start = col * rows_per_col + start_iter * Yi_stride

    # A window that starts beyond the tensor is collapsed to an empty one.
    if Yi_start <= dims.Yi:
        Yi_stop = Yi_start + dims.Yis
    else:
        Yi_stop = Yi_start

    def clamp_to_tensor(row: int) -> int:
        return max(0, min(row, dims.Yi))

    Yi_size = clamp_to_tensor(Yi_stop) - clamp_to_tensor(Yi_start)
    return Yi_start, Yi_stop, Yi_stride, Yi_size


def Yi_split_iters(dims: ResizeMapping, col: int) -> List[Tuple[int, int]]:
    """Group the ``dims.Y_loop`` iterations of ``col`` into phases.

    A phase is a run of consecutive iterations that can share one transfer
    description. Iterations whose input window lies completely inside
    ``[0, dims.Yi]`` are merged greedily; an iteration whose window is clipped
    or out of range always forms a phase of length one.

    Args:
        dims: Resize mapping; ``Y_loop`` and ``Yi`` are read here, the rest
            through ``Yi_slice``.
        col: AIE column index.

    Returns:
        A list of ``(start_iter, num_iters)`` tuples in iteration order. The
        ``num_iters`` values always sum to exactly ``dims.Y_loop``.
    """
    def fits_without_padding(iteration: int) -> bool:
        Yi_start, Yi_stop, _, _ = Yi_slice(dims, col, iteration)
        return Yi_start >= 0 and Yi_stop <= dims.Yi

    split: List[Tuple[int, int]] = []
    curr_iter = 0
    while curr_iter < dims.Y_loop:
        num_iters = 1
        if fits_without_padding(curr_iter):
            # Extend the phase only over iterations that are part of the
            # schedule: without the Y_loop bound a phase could run past the
            # last iteration whenever further windows still fit inside Yi,
            # making the phases cover more iterations than Y_loop.
            while (curr_iter + num_iters < dims.Y_loop
                   and fits_without_padding(curr_iter + num_iters)):
                num_iters += 1
        split.append((curr_iter, num_iters))
        curr_iter += num_iters
    return split


def Yo_slice(dims: ResizeMapping, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Return the row window derived from ``dims.Yos`` for ``col`` at ``start_iter``.

    The result is ``(Yo_start, Yo_stop, Yo_stride, Yo_size)``. ``Yo_stop`` is
    clipped to ``dims.Yo``; a window that starts beyond ``dims.Yo`` is empty
    (``Yo_stop == Yo_start``). ``Yo_stride`` is ``dims.aie_cols * dims.Yos``.
    """
    Yo_stride = dims.aie_cols * dims.Yos
    Yo_start = col * dims.Yos + start_iter * Yo_stride

    if Yo_start <= dims.Yo:
        # Clip the end of the window to the tensor height.
        Yo_stop = min(Yo_start + dims.Yos, dims.Yo)
    else:
        Yo_stop = Yo_start

    return Yo_start, Yo_stop, Yo_stride, Yo_stop - Yo_start


def ifm_shim_repeat_counts(y_loop: int) -> List[int]:
    """Return ``y_loop`` repeat counts: 1 for the first entry, 0 for all others.

    Raises ``IndexError`` when ``y_loop`` is not positive, because there is no
    first entry to set.
    """
    counts = [0] * y_loop
    counts[0] = 1
    return counts


def Yi_repeat_counts(dims: ResizeMapping, col: int) -> List[int]:
    """Return the per-iteration repeat counts of the IFM memtile transfer.

    The list has ``dims.Y_loop * dims.scale_X * dims.scale_Y`` entries. Only
    the first entry of each group of ``scale_X * scale_Y`` entries can be 1,
    and only for Y iterations that belong to the first phase reported by
    ``Yi_split_iters``; every other entry is 0.
    """
    counts = [0] * (dims.Y_loop * dims.scale_X * dims.scale_Y)
    _, first_phase_iters = Yi_split_iters(dims, col)[0]
    for y_iter in range(dims.Y_loop):
        group_head = y_iter * dims.scale_X * dims.scale_Y
        counts[group_head] = int(y_iter < first_phase_iters)
    return counts


def ifm_shim_memory(dims: ResizeMapping) -> str:
    """Return the ``Yi Xi Ci`` memory-format string of the IFM buffer."""
    rows, width, channels = dims.Yi, dims.Xi, dims.Ci
    return f'Yi:{rows} Xi:{width} Ci:{channels}'


def ifm_shim_mm2s(dims: ResizeMapping, col: int) -> List[str]:
    """Return one shim MM2S access-pattern string per phase of ``col``.

    For each ``(start_iter, num_iters)`` phase from ``Yi_split_iters`` the
    pattern steps over ``num_iters`` windows spaced ``Yi_stride`` rows apart
    and reads the window of the phase's first iteration, clipped to
    ``[0, dims.Yi]``, across the full ``Xi`` and ``Ci`` extents.
    """
    patterns: List[str] = []
    for first_iter, iter_count in Yi_split_iters(dims, col):
        Yi_start, Yi_stop, Yi_stride, _ = Yi_slice(dims, col, first_iter)
        row_lo = max(0, Yi_start)
        row_hi = min(Yi_stop, dims.Yi)
        patterns.append(
            f'Xi:0:{dims.Xi}:{dims.Xi} '
            f'Yi:0:{iter_count * Yi_stride}:{Yi_stride} '
            f'Ci:0:{dims.Ci}:{dims.Ci} '
            f'Yi:{row_lo}:{row_hi} Xi:0:{dims.Xi} Ci:0:{dims.Ci}'
        )
    return patterns


def ifm_memtile_memory(dims: ResizeMapping, col: int) -> List[str]:
    """Return the memtile IFM buffer shape string for every DMA iteration.

    Each phase from ``Yi_split_iters`` contributes
    ``num_iters * dims.scale_Y * dims.scale_X`` identical entries, built from
    the window of the phase's first iteration. A window with no valid rows
    falls back to ``dims.Yis`` rows so the shape is never empty.
    """
    shapes: List[str] = []
    for first_iter, iter_count in Yi_split_iters(dims, col):
        copies = iter_count * dims.scale_Y * dims.scale_X
        _, _, _, rows = Yi_slice(dims, col, first_iter)
        if rows <= 0:
            rows = dims.Yis
        shapes.extend([f'Yi:{rows} Xi:{dims.Xi} Ci:{dims.Ci}'] * copies)
    return shapes


def ifm_memtile_s2mm(dims: ResizeMapping, col: int) -> List[str]:
    """Return the memtile IFM access-pattern string for every DMA iteration.

    Each phase from ``Yi_split_iters`` contributes
    ``num_iters * dims.scale_Y * dims.scale_X`` identical entries covering the
    valid rows of the phase's first window and the full ``Xi`` / ``Ci``
    extents. Unlike ``ifm_memtile_memory`` there is no fallback: a window with
    no valid rows yields a ``Yi:0:0`` range.
    """
    patterns: List[str] = []
    for first_iter, iter_count in Yi_split_iters(dims, col):
        copies = iter_count * dims.scale_Y * dims.scale_X
        _, _, _, rows = Yi_slice(dims, col, first_iter)
        patterns.extend([f'Yi:0:{rows} Xi:0:{dims.Xi} Ci:0:{dims.Ci}'] * copies)
    return patterns


def ofm_shim_memory(dims: ResizeMapping) -> str:
    """Return the ``Yo Xo Co`` memory-format string of the OFM buffer."""
    rows, width, channels = dims.Yo, dims.Xo, dims.Co
    return f'Yo:{rows} Xo:{width} Co:{channels}'


def ofm_shim_s2mm(dims: ResizeMapping, col: int) -> List[str]:
    """Return the shim S2MM access-pattern string for every DMA iteration.

    Entries are ordered by Y iteration, then vertical repeat
    (``range(dims.scale_Y)``), then horizontal phase
    (``range(dims.scale_X)``), giving
    ``dims.Y_loop * dims.scale_Y * dims.scale_X`` strings. Each entry writes a
    band of at most ``dims.Yis`` rows, clipped to ``dims.Yo``, into every
    ``scale_X``-th column starting at the horizontal phase.
    """
    patterns: List[str] = []
    for y_iter in range(dims.Y_loop):
        for y_rep in range(dims.scale_Y):
            for x_phase in range(dims.scale_X):
                slice_start, _, _, _ = Yo_slice(dims, col, y_iter)
                # Scale the slice start by scale_Y, then step one band of
                # Yis rows per vertical repeat.
                scaled_start = slice_start * dims.scale_Y
                row_lo = min(dims.Yo, scaled_start + y_rep * dims.Yis)
                row_hi = min(dims.Yo, row_lo + dims.Yis)
                patterns.append(
                    f'Xo:0:{dims.Xo}:{dims.Xo} '
                    f'Co:0:{dims.Co}:{dims.Co} '
                    f'Yo:{row_lo}:{row_hi} '
                    f'Xo:{x_phase % dims.scale_X}:{dims.Xo}:{dims.scale_X} Co:0:{dims.Co}'
                )
    return patterns


# NOTE: Mypy doesn't correctly infer types with the abstract base construct
# used by the core instruction list, so we disable type checking locally here
@no_type_check
def compile_dataflow(schedule_input: ScheduleInputs):
    '''Compile the Resize (NNI) dataflow for the mapping in ``schedule_input``.

    Builds the core instruction list and the memtile and shim DMA transfer
    descriptions for every AIE column, then passes them to
    ``run_layer_compilation``.

    Args:
        schedule_input: Scheduler inputs. This function reads ``mapping``
            (used as a ``ResizeMapping``), ``L3_alloc``, ``kernel_names``,
            ``kernel_includes``, ``backend``, ``layer_file_name`` and
            ``dma_pad`` from it.

    Returns:
        A tuple ``(shim_prm_offset_next_layer, shim_wgt_offset_next_layer)``:
        this layer's parameter XRT offset advanced by the computed parameter
        buffer size, and its weight XRT offset advanced by a fixed 4.
    '''
    dims: ResizeMapping = schedule_input.mapping
    L3_alloc: L3Alloc | None = schedule_input.L3_alloc
    shim_alloc = L3Alloc_to_Shim(L3_alloc)
    log(f"Shim Allocator: {shim_alloc}")

    # Memtile layout: the layer-parameter region starts at address 0 and is
    # sized as aie_rows * MAX_CORE_LAYER_PARAM_SIZE; the IFM buffer follows it.
    prm_memtile_addr = 0
    prm_memtile_size = dims.aie_rows * config.MAX_CORE_LAYER_PARAM_SIZE

    ifm_ping_addr = prm_memtile_addr + prm_memtile_size
    ifm_memtile_size = dims.Nis * dims.Yis * dims.Xis * dims.Cis * dims.ifm_bytes

    # NOTE(review): the OFM size is computed with ifm_bytes while the OFM
    # transfer below is packed with ofm_bits - confirm that input and output
    # element widths are always equal for this layer.
    ofm_shim_size = dims.Yo * dims.Xo * dims.Co * dims.ifm_bytes

    # Core program: one configure / acquire / release sequence on S2MM
    # channel 0. NOTE(review): no other core buffers are set up here, so the
    # cores presumably only receive the layer parameters - confirm.
    core_instrs = [ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), 0, 0, 0),
                   AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                   RelBuffer(DmaChannel(DmaDir.S2MM, 0))]

    # Number of DMA iterations: each Y iteration is expanded once per
    # (scale_Y, scale_X) combination.
    interpolation_repeat = dims.Y_loop * dims.scale_Y * dims.scale_X

    memtile_transfers = []
    # Layer parameters: per column, written into the memtile once (repeat count
    # 1 in the first iteration only) and read out once per core row.
    prm_memtile_transfers = [
        DataTransfer(
            [1] + [0] * (interpolation_repeat - 1),
            AieTile(TileType.Memtile, col),
            [prm_memtile_addr], prm_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                prm_memtile_memory(),
                prm_memtile_s2mm())],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                prm_memtile_memory(),
                prm_memtile_mm2s(row)) for row in range(dims.aie_rows)],
            sync_strategy=SyncStrategy.Serial_M_to_N
            ) for col in range(dims.aie_cols)
        ]
    memtile_transfers += prm_memtile_transfers

    # IFM through the memtile: the S2MM (write) and MM2S (read) sides are both
    # packed from the same ifm_memtile_memory / ifm_memtile_s2mm pattern lists,
    # with reuse_ratio set to scale_Y * scale_X.
    # NOTE(review): the write side shares the F_ids()[0] channel with the
    # parameter transfer above - presumably serialized by dmacompiler; confirm.
    ifm_memtile_transfers = [
        DataTransfer(
            Yi_repeat_counts(dims, col),
            AieTile(TileType.Memtile, col),
            [ifm_ping_addr], ifm_memtile_size,
            [pack_reconfig_transfers(
                memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                ifm_memtile_memory(dims, col),
                ifm_memtile_s2mm(dims, col),
                [1 for _ in range(interpolation_repeat)],
                dims.ifm_bits)],
            [pack_reconfig_transfers(
                memtile_dma(col, DmaDir.MM2S, overlay_3x4_S_ids(col)[0]),
                ifm_memtile_memory(dims, col),
                ifm_memtile_s2mm(dims, col),
                [1 for _ in range(interpolation_repeat)],
                dims.ifm_bits)],
            reuse_ratio=dims.scale_Y*dims.scale_X
            ) for col in range(dims.aie_cols)
        ]
    memtile_transfers += ifm_memtile_transfers

    shim_transfers = []
    # Layer parameters from the parameter XRT buffer, issued once per column.
    prm_shim_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (interpolation_repeat - 1),
            shim_dma(col, DmaDir.MM2S, 0),
            shim_alloc.prm_xrt_idx,
            prm_shim_memory(),
            prm_shim_mm2s(col),
            buffer_offset=shim_alloc.prm_xrt_offset
            ) for col in range(dims.aie_cols)
        ]
    shim_transfers += prm_shim_transfers

    # IFM reads: one shim transfer per column and per phase returned by
    # ifm_shim_mm2s. Every one of them uses the same repeat list (1 in the
    # first iteration, 0 afterwards); the enumerate index is unused.
    ifm_shim_transfers = [
        generate_shim_data_transfer(
            ifm_shim_repeat_counts(interpolation_repeat),
            shim_dma(col, DmaDir.MM2S, 0),
            shim_alloc.ifm_xrt_idx,
            ifm_shim_memory(dims),
            fmt,
            bits_per_block=dims.ifm_bits,
            buffer_offset=shim_alloc.ifm_xrt_offset,
        ) for col in range(dims.aie_cols) for _, fmt in enumerate(ifm_shim_mm2s(dims, col))
    ]
    shim_transfers += ifm_shim_transfers

    # OFM writes: one reconfigured S2MM pattern per DMA iteration, each run
    # once; the second (read-side) parameter list is left empty.
    ofm_shim_transfers = [
        DataTransfer(
            [1] * interpolation_repeat,
            AieTile(TileType.Shim, col), [shim_alloc.ofm_xrt_idx], ofm_shim_size,
            [pack_reconfig_transfers(
                shim_dma(col, DmaDir.S2MM, 0),
                [ofm_shim_memory(dims)] * interpolation_repeat,
                ofm_shim_s2mm(dims, col),
                [1] * interpolation_repeat,
                dims.ofm_bits,
                buffer_offset=[shim_alloc.ofm_xrt_offset],
            )],
            [],
        ) for col in range(dims.aie_cols)
    ]
    shim_transfers += ofm_shim_transfers

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        schedule_input.kernel_names,
        schedule_input.kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_3x4_dma_connections(),
        back_end=schedule_input.backend,
        core_stack_addr=overlay_3x4_core_stack_addr(),
        param_channel_id=0,
        layer_file=schedule_input.layer_file_name,
        dma_padding_map=schedule_input.dma_pad,
    )

    # Offsets for the following layer: advance past this layer's parameters.
    # NOTE(review): this layer has no weight transfers, yet the weight offset
    # still advances by a fixed 4 - presumably a minimum placeholder size;
    # confirm the unit and the reason.
    shim_prm_size = compute_buffer_size(prm_shim_memory())
    shim_wgt_size = 4

    shim_prm_offset_next_layer = shim_alloc.prm_xrt_offset + shim_prm_size
    shim_wgt_offset_next_layer = shim_alloc.wgt_xrt_offset + shim_wgt_size

    return shim_prm_offset_next_layer, shim_wgt_offset_next_layer
