"""Compile GAP L3 dataflow for AIE4."""

from dmacompiler import (
    DevGen,
    OverlayShape,
    DataTransfer,
    SyncStrategy,
    generate_transfer_params,
    DmaChannel,
    DmaDir,
    AieTile,
    TileType,
    memtile_dma,
    shim_dma,
    Loop,
    ConfigBuffer,
    AcqBuffer,
    RelBuffer,
    CallKernel,
    run_layer_compilation,
    set_dev_gen,
)

from utils.utils_common import (
    overlay_3x4_core_stack_addr,
    iceil,
)
from scheduler.common import (
    overlay_3x4_dma_connections,
    prm_memtile_memory,
    prm_shim_memory,
    prm_memtile_s2mm,
    prm_memtile_mm2s,
    prm_shim_mm2s,
    overlay_3x4_A_ids,
    overlay_3x4_F_ids,
    overlay_3x4_O_ids,
    overlay_3x4_S_ids,
)
from scheduler.gap.gap_comman import GAPDims, gen_aie4_gap_params

# Module-level side effect: selects the AIE4 device generation in dmacompiler
# as soon as this module is imported. Presumably this is process-wide state
# that other dmacompiler users also observe -- confirm in dmacompiler.
set_dev_gen(DevGen.Aie4)


def gap_memtile_ifm_memory(dims: GAPDims) -> str:
    """Return the layout string describing the memtile IFM buffer.

    Fields left to right: Ci (``aie_rows * Cis``), Yi, Ci (``Cis``), Xi,
    Ci (``C_g``).
    """
    outer_ci = dims.aie_rows * dims.Cis
    # The two spaces before the second "Ci" are reproduced verbatim from the
    # original layout string.
    return "Ci:{} Yi:{}  Ci:{} Xi:{} Ci:{}".format(
        outer_ci, dims.Yi, dims.Cis, dims.Xi, dims.C_g
    )


def gap_memtile_ifm_s2mm(dims: GAPDims, col: int) -> str:
    """Return the S2MM access pattern that writes IFM data into a memtile.

    Args:
        dims: GAP dimensions and parameters.
        col: AIE column index.

    Returns:
        The access-pattern string, or ``""`` when ``col`` is not one of the
        first ``dims.aie_cols_used`` columns.
    """
    if col >= dims.aie_cols_used:
        return ""
    ci_span = dims.aie_rows * dims.Cis
    return (
        f"Ci:0:{ci_span}:{dims.Cis} "
        f"Yi:0:{dims.Yi} Xi:0:{dims.Xi} Ci:0:{dims.Cis}"
    )


def gap_memtile_ifm_mm2s(dims: GAPDims, row: int) -> str:
    """Return the MM2S access pattern that reads IFM data out of a memtile.

    Without padding, core row ``row`` uses the outer Ci slice bounds
    ``row * Cis`` .. ``(row + 1) * Cis`` and Yi is stepped by ``Yis``.
    With ``dims.Pad`` set, Ci spans ``0 .. Pad_Cis`` in steps of ``C_g``.

    NOTE(review): the padded pattern does not depend on ``row``, so every
    core row gets the same pattern -- confirm this is intended when
    ``dims.aie_rows > 1``.
    """
    if not dims.Pad:
        ci_lo = row * dims.Cis
        ci_hi = (row + 1) * dims.Cis
        return (
            f"Ci:{ci_lo}:{ci_hi}:{dims.Cis} Yi:0:{dims.Yi}:{dims.Yis} "
            f"Yi:0:{dims.Yis} Ci:0:{dims.Cis}:{dims.C_g} Xi:0:{dims.Xi} Ci:0:{dims.C_g}"
        )
    return f"Yi:0:{dims.Yi} Ci:0:{dims.Pad_Cis}:{dims.C_g} Xi:0:{dims.Xi} Ci:0:{dims.C_g}"


def gap_memtile_ofm_memory(dims: GAPDims) -> str:
    """Return the layout string describing the memtile OFM buffer.

    Fields left to right: Yo, Co (``aie_rows * Cos``), Xo, Co (``Cos``).
    """
    layout = (
        ("Yo", dims.Yo),
        ("Co", dims.aie_rows * dims.Cos),
        ("Xo", dims.Xo),
        ("Co", dims.Cos),
    )
    return " ".join(f"{label}:{extent}" for label, extent in layout)


def gap_memtile_ofm_s2mm(dims: GAPDims, row: int) -> str:
    """Return the S2MM access pattern that writes one core row's OFM into a memtile.

    Core row ``row`` uses the Co slice bounds ``row * Cos`` ..
    ``(row + 1) * Cos``; Yo and Xo span their full extents.
    """
    co_begin = row * dims.Cos
    co_end = (row + 1) * dims.Cos
    head = f"Yo:0:{dims.Yo} Co:{co_begin}:{co_end}:{dims.Cos}"
    tail = f"Xo:0:{dims.Xo} Co:0:{dims.Cos}"
    return f"{head} {tail}"


def gap_memtile_ofm_mm2s(dims: GAPDims, col: int) -> str:
    """Return the MM2S access pattern that reads the OFM out of a memtile.

    Returns:
        The access-pattern string covering all ``aie_rows * Cos`` output
        channels, or ``""`` when ``col`` is not one of the first
        ``dims.aie_cols_used`` columns.
    """
    if col >= dims.aie_cols_used:
        return ""
    co_total = dims.aie_rows * dims.Cos
    return f"Yo:0:{dims.Yo} Co:0:{co_total}:{dims.Cos} Xo:0:{dims.Xo} Co:0:{dims.Cos}"


def gap_shimtile_ifm_memory(dims: GAPDims) -> str:
    """Return the layout string describing the shim-side IFM buffer (Yi, Xi, Ci)."""
    layout = (("Yi", dims.Yi), ("Xi", dims.Xi), ("Ci", dims.Ci))
    return " ".join(f"{label}:{extent}" for label, extent in layout)


def gap_shimtile_ifm_mm2s(dims: GAPDims, col: int) -> str:
    """Return the MM2S access pattern that streams one column's IFM share from the shim.

    ``dims.Ci`` is floor-divided among the ``aie_cols_used`` active columns.
    Column ``col`` uses the Ci slice bounds ``col * chunk`` ..
    ``(col + 1) * chunk`` with step ``Cis``.

    Returns:
        The access-pattern string, or ``""`` for columns at or beyond
        ``dims.aie_cols_used``.
    """
    if col >= dims.aie_cols_used:
        return ""
    chunk = dims.Ci // dims.aie_cols_used
    start, stop = col * chunk, (col + 1) * chunk
    return (
        f"Ci:{start}:{stop}:{dims.Cis} Yi:0:{dims.Yi} "
        f"Xi:0:{dims.Xi} Ci:0:{dims.Cis}"
    )


def gap_shimtile_ofm_memory(dims: GAPDims) -> str:
    """Return the layout string describing the shim-side OFM buffer (Yo, Xo, Co)."""
    layout = (("Yo", dims.Yo), ("Xo", dims.Xo), ("Co", dims.Co))
    return " ".join(f"{label}:{extent}" for label, extent in layout)


def gap_shimtile_ofm_s2mm(dims: GAPDims, col: int) -> str:
    """Return the S2MM access pattern that writes one column's OFM share to the shim.

    ``dims.Co`` is floor-divided among the ``aie_cols_used`` active columns.
    Column ``col`` uses the Co slice bounds ``col * chunk`` ..
    ``(col + 1) * chunk`` with step ``Cos``.

    Returns:
        The access-pattern string, or ``""`` for columns at or beyond
        ``dims.aie_cols_used``.
    """
    if col >= dims.aie_cols_used:
        return ""
    chunk = dims.Co // dims.aie_cols_used
    start, stop = col * chunk, (col + 1) * chunk
    return (
        f"Co:{start}:{stop}:{dims.Cos} Yo:0:{dims.Yo} "
        f"Xo:0:{dims.Xo} Co:0:{dims.Cos}"
    )


def compile_gap_dataflow_l3(dims: GAPDims) -> None:
    """
    Compile GAP dataflow for AIE4.

    Builds, for the overlay described by ``dims``:

    * a per-core instruction list for every (col, row) in
      ``aie_cols`` x ``aie_rows`` (all cores get the same program),
    * per-column memtile transfers for parameters, IFM and OFM,
    * per-column shim transfers for parameters, IFM and OFM,

    and passes them to ``run_layer_compilation`` together with the 3x4
    overlay DMA connections and core stack address.

    Memory placement:

    * Memtile: parameters at 0, IFM immediately after, OFM immediately
      after IFM (no alignment between them).
    * Core L1: IFM at 0, then OFM, then the TDM region; the OFM and TDM
      start addresses are rounded up with ``iceil(..., dims.align_l1)``.

    Args:
        dims: GAP dimensions and parameters

    Returns:
        None. The compiled output comes from ``run_layer_compilation``
        (presumably written to ``dims.layer_file_name`` -- confirm in
        dmacompiler).
    """
    # DMA channel assignments. On the core, parameters and IFM both use S2MM
    # channel 0; core_param_s2mm_channel is passed to run_layer_compilation
    # as param_channel_id.
    core_param_s2mm_channel = 0
    core_ifm_s2mm_channel = 0
    core_ofm_mm2s_channel = 0

    shim_param_mm2s_channel = 0
    shim_ifm_mm2s_channel = 1
    shim_ofm_s2mm_channel = 0

    # Buffer ids used as the location argument of the shim DataTransfers
    # (the memtile DataTransfers pass byte addresses in that position).
    prm_buffer_id = 3
    ifm_buffer_id = 1
    ofm_buffer_id = 0

    # Per-core buffer sizes. IFM/OFM are in bytes
    # (element count * bits // bits_per_byte).
    CoreParamSize = dims.prm_size
    CoreIfmSize = dims.Yis * dims.Xis * dims.Cis * dims.act_bits // dims.bits_per_byte
    CoreOfmSize = dims.Yos * dims.Xos * dims.Cos * dims.out_bits // dims.bits_per_byte
    # NOTE(review): no TDM size is computed (a `CoreTdmSize = CoreOfmSize`
    # line was left commented out here), so the TDM region starting at
    # CoreTdmPingAddr is not size-checked in this function.

    # Memtile buffers hold one core-sized slice per core row of the column.
    MemtileParamSize = CoreParamSize * dims.aie_rows
    MemtileIfmSize = CoreIfmSize * dims.aie_rows
    MemtileOfmSize = CoreOfmSize * dims.aie_rows

    # NOTE(review): shim transfer sizes reuse the per-column memtile sizes,
    # while the shim IFM/OFM layout strings span the full Ci/Co extent --
    # confirm DataTransfer expects the per-column size here.
    ShimParamSize = MemtileParamSize
    ShimIfmSize = MemtileIfmSize
    ShimOfmSize = MemtileOfmSize

    # Memtile layout: [params | IFM | OFM], packed back to back from 0.
    MemtileParamAddr = 0
    MemtileIfmPingAddr = MemtileParamAddr + MemtileParamSize
    MemtileOfmPingAddr = MemtileIfmPingAddr + MemtileIfmSize

    # With padding, the core IFM buffer is sized for Pad_Cis channels.
    # Order matters: this runs AFTER MemtileIfmSize/ShimIfmSize were derived
    # from the unpadded CoreIfmSize, so only the L1 IFM buffer grows.
    if dims.Pad:
        CoreIfmSize = dims.Yis * dims.Xis * dims.Pad_Cis * dims.act_bits // dims.bits_per_byte

    # Core L1 layout: IFM at 0, then OFM, then the TDM region.
    CoreIfmPingAddr = 0
    CoreOfmPingAddr = iceil(CoreIfmPingAddr + CoreIfmSize, dims.align_l1)
    CoreTdmPingAddr = iceil(CoreOfmPingAddr + CoreOfmSize, dims.align_l1)

    def get_core_instrs(dims: GAPDims, _core_col_id: int, _core_row_id: int) -> list:
        """Build the instruction list for one core.

        The column/row ids are accepted but not used, so every core gets the
        same program. The OFM MM2S buffer is configured and acquired once up
        front and released only after the last kernel call. An IFM S2MM
        buffer is configured, acquired and released around each of the
        ``dims.Y`` kernel calls.

        The two integer arguments passed to ``gen_aie4_gap_params`` differ
        between the single-call case (Y == 1: 0, 0) and the first (1, 0),
        middle (0, 1) and last (1, 1) calls when Y > 1; their meaning is
        defined in ``scheduler.gap.gap_comman``.

        NOTE(review): returns None when ``dims.Y < 1``, and that None is
        stored in ``instr_dict`` unchanged.
        """
        if dims.Y == 1:
            return [
                ConfigBuffer(
                    DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel),
                    CoreOfmPingAddr,
                    None,
                    CoreOfmSize,
                ),
                AcqBuffer(DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel)),
                ConfigBuffer(
                    DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel),
                    CoreIfmPingAddr,
                    None,
                    CoreIfmSize,
                ),
                AcqBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)),
                CallKernel(
                    dims.kernel_names[0],
                    gen_aie4_gap_params(dims, 0, 0, CoreTdmPingAddr),
                ),
                RelBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)),
                RelBuffer(DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel)),
            ]

        if dims.Y > 1:
            # NOTE(review): Y == 2 produces Loop(0, ...); confirm dmacompiler
            # treats a zero-count Loop as "no iterations".
            return [
                ConfigBuffer(
                    DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel),
                    CoreOfmPingAddr,
                    None,
                    CoreOfmSize,
                ),
                AcqBuffer(DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel)),
                # First of the Y kernel calls.
                ConfigBuffer(
                    DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel),
                    CoreIfmPingAddr,
                    None,
                    CoreIfmSize,
                ),
                AcqBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)),
                CallKernel(
                    dims.kernel_names[0],
                    gen_aie4_gap_params(dims, 1, 0, CoreTdmPingAddr),
                ),
                RelBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)),
                # Middle Y - 2 kernel calls.
                Loop(
                    dims.Y - 2,
                    [
                        ConfigBuffer(
                            DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel),
                            CoreIfmPingAddr,
                            None,
                            CoreIfmSize,
                        ),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)),
                        CallKernel(
                            dims.kernel_names[0],
                            gen_aie4_gap_params(dims, 0, 1, CoreTdmPingAddr),
                        ),
                        RelBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)),
                    ],
                ),
                # Last kernel call, then release the OFM buffer.
                ConfigBuffer(
                    DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel),
                    CoreIfmPingAddr,
                    None,
                    CoreIfmSize,
                ),
                AcqBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)),
                CallKernel(
                    dims.kernel_names[0],
                    gen_aie4_gap_params(dims, 1, 1, CoreTdmPingAddr),
                ),
                RelBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)),
                RelBuffer(DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel)),
            ]
        return None

    # One instruction list per core in the aie_cols x aie_rows grid.
    instr_dict = {}
    for col in range(dims.aie_cols):
        for row in range(dims.aie_rows):
            instr_dict[AieTile(TileType.Core, col, row)] = get_core_instrs(dims, col, row)

    # Memtile transfers, one DataTransfer per column for each of:
    #   1. parameters: one S2MM in (F id 0), one MM2S out per core row (A ids);
    #   2. IFM: one S2MM in (F id 1), one MM2S out per core row (A ids) with
    #      enable_padding=True, synchronized with SyncStrategy.Parallel_1_to_N;
    #   3. OFM: one S2MM in per core row (O ids), one MM2S out (S id of col).
    memtile_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Memtile, col),
            [MemtileParamAddr],
            MemtileParamSize,
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                    prm_memtile_memory(),
                    prm_memtile_s2mm(),
                    dims.prm_bits,
                )
            ],
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                    prm_memtile_memory(),
                    prm_memtile_mm2s(row),
                    dims.prm_bits,
                )
                for row in range(dims.aie_rows)
            ],
        )
        for col in range(dims.aie_cols)
    ] + [
        DataTransfer(
            [1],
            AieTile(TileType.Memtile, col),
            [MemtileIfmPingAddr],
            MemtileIfmSize,
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[1]),
                    gap_memtile_ifm_memory(dims),
                    gap_memtile_ifm_s2mm(dims, col),
                    dims.act_bits,
                )
            ],
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                    gap_memtile_ifm_memory(dims),
                    gap_memtile_ifm_mm2s(dims, row),
                    dims.act_bits,
                    enable_padding=True,
                )
                for row in range(dims.aie_rows)
            ],
            sync_strategy=SyncStrategy.Parallel_1_to_N,
        )
        for col in range(dims.aie_cols)
    ] + [
        DataTransfer(
            [1],
            AieTile(TileType.Memtile, col),
            [MemtileOfmPingAddr],
            MemtileOfmSize,
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.S2MM, overlay_3x4_O_ids()[row]),
                    gap_memtile_ofm_memory(dims),
                    gap_memtile_ofm_s2mm(dims, row),
                    dims.out_bits,
                )
                for row in range(dims.aie_rows)
            ],
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, overlay_3x4_S_ids(col)[0]),
                    gap_memtile_ofm_memory(dims),
                    gap_memtile_ofm_mm2s(dims, col),
                    dims.out_bits,
                )
            ],
        )
        for col in range(dims.aie_cols)
    ]

    # Shim transfers, one DataTransfer per column for each of: parameters
    # (MM2S, channel 0), IFM (MM2S, channel 1) and OFM (S2MM, channel 0).
    # Each shim transfer uses a single direction; the other list is empty.
    shim_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Shim, col),
            [prm_buffer_id],
            ShimParamSize,
            [],
            [
                generate_transfer_params(
                    shim_dma(col, DmaDir.MM2S, shim_param_mm2s_channel),
                    prm_shim_memory(),
                    prm_shim_mm2s(col),
                    dims.prm_bits,
                )
            ],
        )
        for col in range(dims.aie_cols)
    ] + [
        DataTransfer(
            [1],
            AieTile(TileType.Shim, col),
            [ifm_buffer_id],
            ShimIfmSize,
            [],
            [
                generate_transfer_params(
                    shim_dma(col, DmaDir.MM2S, shim_ifm_mm2s_channel),
                    gap_shimtile_ifm_memory(dims),
                    gap_shimtile_ifm_mm2s(dims, col),
                    dims.act_bits,
                )
            ],
        )
        for col in range(dims.aie_cols)
    ] + [
        DataTransfer(
            [1],
            AieTile(TileType.Shim, col),
            [ofm_buffer_id],
            ShimOfmSize,
            [
                generate_transfer_params(
                    shim_dma(col, DmaDir.S2MM, shim_ofm_s2mm_channel),
                    gap_shimtile_ofm_memory(dims),
                    gap_shimtile_ofm_s2mm(dims, col),
                    dims.out_bits,
                )
            ],
            [],
        )
        for col in range(dims.aie_cols)
    ]

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        dims.kernel_names,
        dims.kernel_includes,
        instr_dict,
        memtile_transfers,
        shim_transfers,
        overlay_3x4_dma_connections(),
        back_end=dims.back_end,
        param_channel_id=core_param_s2mm_channel,
        core_stack_addr=overlay_3x4_core_stack_addr(),
        layer_file=dims.layer_file_name,
    )
