"""Compile the GAP (global average pooling) L2 dataflow for AIE4."""

from dmacompiler import (
    DevGen,
    OverlayShape,
    DataTransfer,
    SyncStrategy,
    generate_transfer_params,
    DmaChannel,
    DmaDir,
    AieTile,
    TileType,
    memtile_dma,
    shim_dma,
    Loop,
    ConfigBuffer,
    AcqBuffer,
    RelBuffer,
    CallKernel,
    run_layer_compilation,
    set_dev_gen,
)
from utils.utils_common import (
    overlay_3x4_core_stack_addr,
    iceil, log,
)
from scheduler.common import (
    overlay_3x4_dma_connections,
    prm_memtile_memory,
    prm_shim_memory,
    prm_memtile_s2mm,
    prm_memtile_mm2s,
    prm_shim_mm2s,
    overlay_3x4_A_ids,
    overlay_3x4_F_ids,
    overlay_3x4_O_ids,
    overlay_3x4_S_ids,
    L3Alloc, L3Alloc_to_Shim,
)
from scheduler.gap.gap_comman import GAPDims, gen_aie4_gap_params
from buildscripts.common import ScheduleInputs
# Module-level side effect: select the AIE4 device generation in dmacompiler
# as soon as this module is imported.
set_dev_gen(DevGen.Aie4)


def gap_memtile_ifm_memory(dims: GAPDims) -> str:
    """Return the memtile IFM buffer shape as ``"Yi:<n> Xi:<n> Ci:<n>"``."""
    extents = (("Yi", dims.Yi), ("Xi", dims.Xi), ("Ci", dims.Ci))
    return " ".join(f"{axis}:{size}" for axis, size in extents)


def gap_memtile_ifm_s2mm(dims: GAPDims, _col: int) -> str:
    """Return the memtile IFM S2MM pattern: every axis over its full range.

    ``_col`` is accepted for signature parity with the other pattern
    builders; it does not affect the result.
    """
    axes = (("Yi", dims.Yi), ("Xi", dims.Xi), ("Ci", dims.Ci))
    ranges = [f"{axis}:0:{extent}" for axis, extent in axes]
    return " ".join(ranges)


def gap_memtile_ifm_mm2s(dims: GAPDims, row: int) -> str:
    """Return the memtile IFM MM2S pattern feeding core ``row``.

    The outer ``Ci`` range selects this row's ``Cis``-wide channel slice in
    steps of ``C_g``; the inner ``Ci`` range walks ``0..C_g`` within a step.
    """
    slice_start = dims.Cis * row
    slice_stop = dims.Cis * (row + 1)
    segments = [
        f"Yi:0:{dims.Yi}",
        f"Ci:{slice_start}:{slice_stop}:{dims.C_g}",
        f"Xi:0:{dims.Xi}",
        f"Ci:0:{dims.C_g}",
    ]
    return " ".join(segments)


def gap_memtile_ofm_memory(dims: GAPDims) -> str:
    """Return the memtile OFM buffer shape as ``"Yo:<n> Xo:<n> Co:<n>"``."""
    parts = []
    for axis, size in (("Yo", dims.Yo), ("Xo", dims.Xo), ("Co", dims.Co)):
        parts.append(f"{axis}:{size}")
    return " ".join(parts)


def gap_memtile_ofm_s2mm(dims: GAPDims, row: int) -> str:
    """Return the memtile OFM S2MM pattern for data arriving from core ``row``.

    The outer ``Co`` range places this row's ``Cos``-wide channel slice in
    steps of ``C_g``; the inner ``Co`` range walks ``0..C_g`` within a step.
    """
    co_begin, co_end = dims.Cos * row, dims.Cos * (row + 1)
    head = f"Yo:0:{dims.Yo}"
    channel_slice = f"Co:{co_begin}:{co_end}:{dims.C_g}"
    tail = f"Xo:0:{dims.Xo} Co:0:{dims.C_g}"
    return " ".join((head, channel_slice, tail))


def gap_memtile_ofm_mm2s(dims: GAPDims, _col: int) -> str:
    """Return the memtile OFM MM2S pattern: every axis over its full range.

    ``_col`` is accepted for signature parity and does not affect the result.
    """
    full_ranges = (
        f"Yo:0:{dims.Yo}",
        f"Xo:0:{dims.Xo}",
        f"Co:0:{dims.Co}",
    )
    return " ".join(full_ranges)


def gap_shimtile_ifm_memory(dims: GAPDims) -> str:
    """Return the shim (L3) IFM buffer shape as ``"Yi:<n> Xi:<n> Ci:<n>"``."""
    height, width, channels = dims.Yi, dims.Xi, dims.Ci
    return " ".join((f"Yi:{height}", f"Xi:{width}", f"Ci:{channels}"))


def gap_shimtile_ifm_mm2s(dims: GAPDims, _col: int) -> str:
    """Return the shim IFM MM2S pattern: every axis over its full range.

    ``_col`` is accepted for signature parity and does not affect the result.
    """
    ranges = []
    for axis, extent in (("Yi", dims.Yi), ("Xi", dims.Xi), ("Ci", dims.Ci)):
        ranges.append(f"{axis}:0:{extent}")
    return " ".join(ranges)


def gap_shimtile_ofm_memory(dims: GAPDims) -> str:
    """Return the shim (L3) OFM buffer shape as ``"Yo:<n> Xo:<n> Co:<n>"``."""
    shape = {"Yo": dims.Yo, "Xo": dims.Xo, "Co": dims.Co}
    return " ".join(f"{axis}:{size}" for axis, size in shape.items())


def gap_shimtile_ofm_s2mm(dims: GAPDims, _col: int) -> str:
    """Return the shim OFM S2MM pattern: every axis over its full range.

    ``_col`` is accepted for signature parity and does not affect the result.
    """
    axes = (("Yo", dims.Yo), ("Xo", dims.Xo), ("Co", dims.Co))
    return " ".join([f"{name}:0:{extent}" for name, extent in axes])


def compile_gap_dataflow_l2(schedule_input: ScheduleInputs):
    """
    Compile the GAP (global average pooling) L2 dataflow for AIE4.

    Builds the core instruction streams plus the memtile (L2) and shim (L3)
    data transfers for one GAP layer and passes them to
    ``run_layer_compilation``.

    Args:
        schedule_input: Layer scheduling inputs. ``shape`` is the ``GAPDims``
            for the layer (including ``fusion_param``, which gives the L2
            tile/address of the IFM, OFM and parameter buffers); ``L3_alloc``
            is converted with ``L3Alloc_to_Shim`` to obtain the shim XRT
            offsets; the kernel names/includes, backend, layer file name and
            DMA padding map are forwarded to ``run_layer_compilation``.

    Returns:
        ``(shim_prm_offset_next_layer, shim_wgt_offset_next_layer)``: the
        shim parameter and weight XRT offsets for the next layer.
    """
    dims: GAPDims = schedule_input.shape
    L3_alloc: L3Alloc | None = schedule_input.L3_alloc
    shim_alloc = L3Alloc_to_Shim(L3_alloc)
    log(f"Shim Allocator: {shim_alloc}")

    # DMA channel numbers on the core and shim side.
    core_param_s2mm_channel = 0
    core_ifm_s2mm_channel = 0
    core_ofm_mm2s_channel = 0

    shim_param_mm2s_channel = 0
    shim_ifm_mm2s_channel = 1
    shim_ofm_s2mm_channel = 0

    # Buffer ids passed to the shim DataTransfers.
    prm_buffer_id = 3
    ifm_buffer_id = 1
    ofm_buffer_id = 0

    # Per-core buffer sizes (bits converted to bytes via bits_per_byte).
    CoreParamSize = dims.prm_size
    CoreIfmSize = dims.Yis * dims.Xis * dims.Cis * dims.act_bits // dims.bits_per_byte
    CoreOfmSize = dims.Yos * dims.Xos * dims.Cos * dims.out_bits // dims.bits_per_byte

    # A column's memtile buffer holds the data for all aie_rows cores.
    MemtileParamSize = CoreParamSize * dims.aie_rows
    MemtileIfmSize = CoreIfmSize * dims.aie_rows
    MemtileOfmSize = CoreOfmSize * dims.aie_rows

    ShimParamSize = MemtileParamSize
    ShimIfmSize = MemtileIfmSize
    ShimOfmSize = MemtileOfmSize

    if dims.Pad:
        # Only the core IFM buffer is sized with the padded channel count
        # (Pad_Cis); the memtile/shim sizes above keep the unpadded Cis.
        CoreIfmSize = (
            dims.Yis * dims.Xis * dims.Pad_Cis * dims.act_bits // dims.bits_per_byte
        )

    # Core L1 layout: IFM at 0, then OFM, then the TDM buffer handed to the
    # kernel, each start rounded up with iceil to dims.align_l1.
    CoreIfmPingAddr = 0
    CoreOfmPingAddr = iceil(CoreIfmPingAddr + CoreIfmSize, dims.align_l1)
    CoreTdmPingAddr = iceil(CoreOfmPingAddr + CoreOfmSize, dims.align_l1)

    fusion_param = dims.fusion_param

    # L2 placement is dictated by the fusion plan as (tile, address) pairs.
    ifm_L2_tile, MemtileIfmPingAddr = fusion_param.ifm_L2_loc
    ifm_L2_tile = ifm_L2_tile.col
    log(f"MemtileIfmPingAddr: {MemtileIfmPingAddr}")
    log(f"ifm_L2_tile: {ifm_L2_tile}")
    prm_L2_alloc_tiles = [entry[0] for entry in fusion_param.prm_l2_loc]  # Use the keys as the tile index
    prm_L2_addrs = [entry[1] for entry in fusion_param.prm_l2_loc]
    log(f"prm_L2_alloc_tiles: {prm_L2_alloc_tiles}")
    log(f"prm_L2_addrs: {prm_L2_addrs}")

    ofm_L2_tile, MemtileOfmPingAddr = fusion_param.ofm_L2_loc
    ofm_L2_tile = ofm_L2_tile.col
    log(f"MemtileOfmPingAddr: {MemtileOfmPingAddr}")
    log(f"ofm_L2_tile: {ofm_L2_tile}")

    def _as_int(v):
        """Return int(v); for a list/tuple, recurse into its first element (0 if empty)."""
        if isinstance(v, (list, tuple)):
            return _as_int(v[0]) if v else 0
        return int(v)

    # NOTE(review): ShimParamSize above is taken before this normalization,
    # preserved from the original ordering.
    MemtileParamSize = _as_int(MemtileParamSize)

    def _ifm_chunk_instrs(dims: GAPDims, flag_a: int, flag_b: int) -> list:
        """Receive one IFM chunk into L1, run the GAP kernel on it, release it."""
        return [
            ConfigBuffer(
                DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel),
                CoreIfmPingAddr,
                None,
                CoreIfmSize,
            ),
            AcqBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)),
            CallKernel(
                "run_globalavgpool_int8x8",
                gen_aie4_gap_params(dims, flag_a, flag_b, CoreTdmPingAddr),
            ),
            RelBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)),
        ]

    def get_core_instrs(
        dims: GAPDims, _core_col_id: int, _core_row_id: int
    ) -> list | None:
        """Build one core's instruction stream, or None when dims.Y < 1.

        The OFM buffer is acquired once and held while all dims.Y IFM chunks
        are processed, then released. The (flag_a, flag_b) pairs given to
        gen_aie4_gap_params are: (0, 0) for a single chunk; otherwise
        (1, 0) for the first, (0, 1) for each middle and (1, 1) for the last.
        """
        if dims.Y < 1:
            return None

        instrs = [
            ConfigBuffer(
                DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel),
                CoreOfmPingAddr,
                None,
                CoreOfmSize,
            ),
            AcqBuffer(DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel)),
        ]
        if dims.Y == 1:
            instrs += _ifm_chunk_instrs(dims, 0, 0)
        else:
            instrs += _ifm_chunk_instrs(dims, 1, 0)
            # Emit the middle loop only when there are middle chunks, rather
            # than a zero-trip Loop(0, ...) for dims.Y == 2.
            if dims.Y > 2:
                instrs.append(Loop(dims.Y - 2, _ifm_chunk_instrs(dims, 0, 1)))
            instrs += _ifm_chunk_instrs(dims, 1, 1)
        instrs.append(RelBuffer(DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel)))
        return instrs

    # Only the cores in the IFM memtile's column compute; the rest idle.
    instr_dict: dict[AieTile, list] = {
        AieTile(TileType.Core, col, row): (
            get_core_instrs(dims, col, row) if col == ifm_L2_tile else []
        )
        for col in range(dims.aie_cols)
        for row in range(dims.aie_rows)
    }

    memtile_transfers = [
        # Parameters: one transfer per column, broadcast to every core row.
        DataTransfer(
            [1],
            prm_L2_alloc_tiles[col],
            [prm_L2_addrs[col]],
            MemtileParamSize,
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                    prm_memtile_memory(),
                    prm_memtile_s2mm(),
                    dims.prm_bits,
                )
            ],
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                    prm_memtile_memory(),
                    prm_memtile_mm2s(row),
                    dims.prm_bits,
                )
                for row in range(dims.aie_rows)
            ],
        )
        for col in range(dims.aie_cols)
    ] + [
        # IFM: optionally filled from the shim, then split across core rows.
        DataTransfer(
            [1],
            AieTile(TileType.Memtile, col),
            [MemtileIfmPingAddr],
            MemtileIfmSize,
            (
                [
                    generate_transfer_params(
                        memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[1]),
                        gap_memtile_ifm_memory(dims),
                        gap_memtile_ifm_s2mm(dims, col),
                        dims.act_bits,
                    )
                ]
                if fusion_param.enable_ifm_fill
                else []
            ),
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                    gap_memtile_ifm_memory(dims),
                    gap_memtile_ifm_mm2s(dims, row),
                    dims.act_bits,
                    enable_padding=True,
                )
                for row in range(dims.aie_rows)
            ],
            sync_strategy=(
                SyncStrategy.Parallel_1_to_N
                if fusion_param.enable_ifm_fill
                else SyncStrategy.Default
            ),
        )
        for col in [ifm_L2_tile]
    ] + [
        # OFM: gathered from every core row, optionally spilled to the shim.
        DataTransfer(
            [1],
            AieTile(TileType.Memtile, col),
            [MemtileOfmPingAddr],
            MemtileOfmSize,
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.S2MM, overlay_3x4_O_ids()[row]),
                    gap_memtile_ofm_memory(dims),
                    gap_memtile_ofm_s2mm(dims, row),
                    dims.out_bits,
                )
                for row in range(dims.aie_rows)
            ],
            (
                [
                    generate_transfer_params(
                        memtile_dma(col, DmaDir.MM2S, overlay_3x4_S_ids(col)[0]),
                        gap_memtile_ofm_memory(dims),
                        gap_memtile_ofm_mm2s(dims, col),
                        dims.out_bits,
                    )
                ]
                if fusion_param.enable_ofm_spill
                else []
            ),
            sync_strategy=(
                SyncStrategy.Parallel_N_to_1
                if fusion_param.enable_ofm_spill
                else SyncStrategy.Default
            ),
        )
        for col in [ofm_L2_tile]
    ]

    shim_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Shim, col),
            [prm_buffer_id],
            ShimParamSize,
            [],
            [
                generate_transfer_params(
                    shim_dma(col, DmaDir.MM2S, shim_param_mm2s_channel),
                    prm_shim_memory(),
                    prm_shim_mm2s(col),
                    dims.prm_bits,
                    buffer_offset=shim_alloc.prm_xrt_offset,
                )
            ],
        )
        for col in range(dims.aie_cols)
    ] + (
        [
            DataTransfer(
                [1],
                AieTile(TileType.Shim, col),
                [ifm_buffer_id],
                ShimIfmSize,
                [],
                [
                    generate_transfer_params(
                        shim_dma(col, DmaDir.MM2S, shim_ifm_mm2s_channel),
                        gap_shimtile_ifm_memory(dims),
                        gap_shimtile_ifm_mm2s(dims, col),
                        dims.act_bits,
                    )
                ],
            )
            for col in [ifm_L2_tile]
        ]
        if fusion_param.enable_ifm_fill
        else []
    ) + (
        [
            DataTransfer(
                [1],
                AieTile(TileType.Shim, col),
                [ofm_buffer_id],
                ShimOfmSize,
                [
                    generate_transfer_params(
                        shim_dma(col, DmaDir.S2MM, shim_ofm_s2mm_channel),
                        gap_shimtile_ofm_memory(dims),
                        gap_shimtile_ofm_s2mm(dims, col),
                        dims.out_bits,
                    )
                ],
                [],
            )
            # The spill is issued by the OFM memtile (column ofm_L2_tile), so
            # the receiving shim transfer must sit in that same column.
            for col in [ofm_L2_tile]
        ]
        if fusion_param.enable_ofm_spill
        else []
    )

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        schedule_input.kernel_names,
        schedule_input.kernel_includes,
        instr_dict,
        memtile_transfers,
        shim_transfers,
        overlay_3x4_dma_connections(),
        back_end=schedule_input.backend,
        param_channel_id=core_param_s2mm_channel,
        core_stack_addr=overlay_3x4_core_stack_addr(),
        layer_file=schedule_input.layer_file_name,
        dma_padding_map=schedule_input.dma_pad,
    )

    # Each column's shim transfer consumed ShimParamSize of the param buffer.
    shim_prm_offset_next_layer = shim_alloc.prm_xrt_offset + (ShimParamSize*dims.aie_cols)
    # NOTE(review): no weight transfer is issued here; confirm why the weight
    # offset still advances by a fixed 64.
    shim_wgt_offset_next_layer = shim_alloc.wgt_xrt_offset + 64

    return shim_prm_offset_next_layer, shim_wgt_offset_next_layer
