import os
import sys

from kernel.softmax_fp16x16.softmax_params import softmax_layer_params, SoftmaxDims

from dmacompiler import (
    DevGen,
    OverlayShape,
    DataTransfer,
    TransferParams,
    BackEnd,
    DmaChannel,
    DmaDir,
    AieTile,
    TileType,
    memtile_dma,
    shim_dma,
    ConfigBuffer,
    AcqBuffer,
    RelBuffer,
    CallKernel,
    run_layer_compilation,
    set_dev_gen,
)

from kerneltest.overlay_1x1 import overlay_stack_addr, aie4_overlay_dma_connections

# Absolute path of the directory that contains this file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Append "<CURRDIR>/../../dmacompiler" to the module search path.
# NOTE(review): this statement runs after the `from dmacompiler import ...`
# block earlier in the file, so it cannot help that import resolve -- confirm
# whether it is still needed or should run before the imports.
sys.path.append(os.path.join(CURRDIR, "..", "..", "dmacompiler"))


# Import-time side effect: set the dmacompiler device generation to
# DevGen.Aie4 before any dataflow in this module is built.
set_dev_gen(DevGen.Aie4)


def Memtile(col: int):
    """Return the ``TileType.Memtile`` tile for column ``col``.

    The third ``AieTile`` argument is always 0 (presumably the row index --
    confirm against ``AieTile``).
    """
    tile = AieTile(TileType.Memtile, col, 0)
    return tile


def Shimtile(col: int):
    """Return the ``TileType.Shim`` tile for column ``col``.

    The third ``AieTile`` argument is always 0 (presumably the row index --
    confirm against ``AieTile``).
    """
    tile = AieTile(TileType.Shim, col, 0)
    return tile


def compile_uniop_1x1_dataflow(
    Bsubv_: int, Msubv_: int, Nsubv_: int, backend=BackEnd.Adf
):
    """Build and compile the softmax fp16x16 kernel-test dataflow on a 1x1 overlay.

    The function sizes the core (L1) and memtile buffers from the sub-volume
    dimensions, describes the core instruction sequence and the memtile/shim
    DMA transfers, and passes everything to ``run_layer_compilation``.

    Core buffer layout, in ascending address order and without overlap:
    input, mask, output, QDQ parameter block, QDQ buffer. The end address of
    the QDQ buffer is printed as "Last L1 Addr".

    Args:
        Bsubv_: Multiplied by ``Msubv_`` to get the row count used for buffer
            sizing; also forwarded to ``SoftmaxDims``.
        Msubv_: Per-``Bsubv_`` row count; forwarded to the kernel parameters.
        Nsubv_: Column count; forwarded to the kernel parameters.
        backend: Forwarded as ``back_end`` to ``run_layer_compilation``.

    Returns:
        None. The result of ``run_layer_compilation`` is not returned.
    """
    kernel_names = ["run_softmax_fp16x16"]
    kernel_includes = ["super.hh", "softmax_fp16x16/softmax_fp16x16_wrapper.cc"]

    AieRows = 1
    AieCols = 1

    Msubv = Bsubv_ * Msubv_
    Nsubv = Nsubv_
    IfmBytes = 2  # bytes per input element (presumably fp16)

    LayerPrmSize = 1024
    CoreInputSize = Msubv * Nsubv * IfmBytes
    # The mask occupies 1/16 of the input's byte size.
    CoreMaskSize = CoreInputSize // 16
    CoreOutputSize = CoreInputSize

    # Fixed 256-byte reservation for the QDQ parameter block. The payload
    # formula noted for it is (QdqNodes * QdqPrm + 1) * QdqPrmBytes with
    # QdqNodes = 2, QdqPrm = 3, QdqPrmBytes = 4; the "+ 1" is for the
    # nlf_enable flag.
    CoreQdqPrmSize = 256
    # Parameter block and buffer together take 1024 bytes.
    CoreQdqBuffersize = 1024 - CoreQdqPrmSize

    MemtileInputSize = CoreInputSize
    MemtileMaskSize = CoreMaskSize
    MemtileOutputSize = CoreOutputSize

    ShimInputSize = MemtileInputSize
    ShimMaskSize = MemtileMaskSize
    ShimOutputSize = MemtileOutputSize

    # Core (L1) layout: each region starts where the previous one ends.
    CoreInputPingAddr = 0
    CoreMaskPingAddr = CoreInputPingAddr + CoreInputSize
    CoreOutputPingAddr = CoreMaskPingAddr + CoreMaskSize
    # The QDQ block starts after the output buffer so that the output and the
    # QDQ regions handed to the kernel never share addresses.
    CoreQdqParamAddr = CoreOutputPingAddr + CoreOutputSize
    CoreQdqBufferAddr = CoreQdqParamAddr + CoreQdqPrmSize

    print("Last L1 Addr :", CoreQdqBufferAddr + CoreQdqBuffersize)

    # Memtile layout: layer parameters, then input, mask and output.
    MemtilePrmPingAddr = 0
    MemtileInPingAddr = MemtilePrmPingAddr + LayerPrmSize * (AieRows)
    MemtileMaskPingAddr = MemtileInPingAddr + CoreInputSize * (AieRows)
    MemtileOutPingAddr = MemtileMaskPingAddr + CoreMaskSize * (AieRows)

    core_instrs = [
        # Bind the output buffer to MM2S channel 0 and the input buffer to
        # S2MM channel 0.
        ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOutputPingAddr, None, CoreOutputSize),
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreInputPingAddr, None, CoreInputSize),
        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
        # The same S2MM channel is re-pointed at the mask buffer for a second
        # acquire/release pair.
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreMaskPingAddr, None, CoreMaskSize),
        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
        # The kernel call is bracketed by acquire/release of the output channel.
        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
        CallKernel(
            "run_softmax_fp16x16",
            softmax_layer_params(
                CoreInputPingAddr,
                CoreMaskPingAddr,
                CoreOutputPingAddr,
                CoreQdqParamAddr,
                # NOTE(review): 384 exceeds CoreQdqPrmSize (256), so this
                # address falls inside the QDQ buffer region -- confirm the
                # intended offset against the kernel.
                CoreQdqParamAddr + 384,
                CoreQdqBufferAddr,
                Nsubv_,
                Msubv_,
                Nsubv_,
                Msubv_ * Nsubv_ // 32,
                Msubv_ * Nsubv_ // 8,
                SoftmaxDims(Bsubv_, Msubv_, Nsubv_),
            ),
        ),
        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
    ]

    # Memtile transfers. Lengths are byte sizes divided by 4 (presumably
    # 32-bit words -- confirm against TransferParams).
    mem_in = [
        DataTransfer(
            [1],
            Memtile(col),
            [MemtileInPingAddr],
            MemtileInputSize,
            [TransferParams(memtile_dma(col, DmaDir.S2MM, 0), MemtileInputSize // 4)],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, 0), MemtileInputSize // 4)],
        )
        for col in range(AieCols)
    ]

    mem_mask = [
        DataTransfer(
            [1],
            Memtile(col),
            [MemtileMaskPingAddr],
            MemtileMaskSize,
            [TransferParams(memtile_dma(col, DmaDir.S2MM, 0), MemtileMaskSize // 4)],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, 0), MemtileMaskSize // 4)],
        )
        for col in range(AieCols)
    ]

    # The output path uses its own memtile channels (S2MM 2 in, MM2S 5 out).
    mem_out = [
        DataTransfer(
            [1],
            Memtile(col),
            [MemtileOutPingAddr],
            MemtileOutputSize,
            [TransferParams(memtile_dma(col, DmaDir.S2MM, 2), MemtileOutputSize // 4)],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, 5), MemtileOutputSize // 4)],
        )
        for col in range(AieCols)
    ]

    mem_layer_param = [
        DataTransfer(
            [1],
            Memtile(col),
            [MemtilePrmPingAddr],
            LayerPrmSize,
            [TransferParams(memtile_dma(col, DmaDir.S2MM, 0), LayerPrmSize // 4)],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, 0), LayerPrmSize // 4)],
        )
        for col in range(AieCols)
    ]

    # ------------------------------------------------------------------------------

    # Shim transfers. The third argument ([1], [0], [3]) presumably selects
    # the external buffer index -- confirm against DataTransfer.
    shim_in = [
        DataTransfer(
            [1],
            Shimtile(col),
            [1],
            ShimInputSize,
            [],
            [TransferParams(shim_dma(col, DmaDir.MM2S, 0), ShimInputSize // 4)],
        )
        for col in range(AieCols)
    ]

    # The mask is read from the same buffer index as the input, offset past
    # the input data.
    shim_mask = [
        DataTransfer(
            [1],
            Shimtile(col),
            [1],
            ShimMaskSize,
            [],
            [
                TransferParams(
                    shim_dma(col, DmaDir.MM2S, 0),
                    ShimMaskSize // 4,
                    offset=ShimInputSize // 4,
                )
            ],
        )
        for col in range(AieCols)
    ]

    shim_out = [
        DataTransfer(
            [1],
            Shimtile(col),
            [0],
            ShimOutputSize,
            [TransferParams(shim_dma(col, DmaDir.S2MM, 0), ShimOutputSize // 4)],
            [],
        )
        for col in range(AieCols)
    ]

    # Each column reads its own LayerPrmSize-byte slice of the parameter buffer.
    shim_layer_param = [
        DataTransfer(
            [1],
            Shimtile(col),
            [3],
            LayerPrmSize,
            [],
            [
                TransferParams(
                    shim_dma(col, DmaDir.MM2S, 0),
                    LayerPrmSize // 4,
                    offset=col * LayerPrmSize // 4,
                )
            ],
        )
        for col in range(AieCols)
    ]

    # Layer parameters are listed first, followed by input, mask and output.
    memtile_transfers = mem_layer_param + mem_in + mem_mask + mem_out
    shim_transfers = shim_layer_param + shim_in + shim_mask + shim_out

    run_layer_compilation(
        OverlayShape(AieCols, AieRows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        aie4_overlay_dma_connections(AieCols, AieRows),
        back_end=backend,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
    )


if __name__ == "__main__":
    # Script entry point: the three sub-volume dimensions are required
    # positional command-line arguments, and the ADF backend is passed by
    # keyword so it cannot be mistaken for a dimension.
    import argparse

    parser = argparse.ArgumentParser(
        description="Compile the softmax fp16x16 1x1 overlay dataflow."
    )
    parser.add_argument("Bsubv", type=int, help="value passed as Bsubv_")
    parser.add_argument("Msubv", type=int, help="value passed as Msubv_")
    parser.add_argument("Nsubv", type=int, help="value passed as Nsubv_")
    args = parser.parse_args()

    compile_uniop_1x1_dataflow(
        args.Bsubv, args.Msubv, args.Nsubv, backend=BackEnd.Adf
    )
