"""L3 Dataflow for Uniop"""
import sys
import os
from typing import Tuple

from dmacompiler import (
    set_dev_gen,
    DevGen,
    DataTransfer,
    AieTile,
    TileType,
    run_layer_compilation,
    shim_dma,
    memtile_dma,
    DmaDir,
    BackEnd,
    OverlayShape,
    CallKernel,
    Loop,
    SyncStrategy,
    compute_buffer_size,
    generate_transfer_params,
    generate_shim_data_transfer,
    ConfigBuffer,
    DmaChannel,
    AcqBuffer,
    RelBuffer,
    TransferParams
)
from utils.utils_common import (
    overlay_3x4_core_stack_addr,
    iceil,
    log,
)
from scheduler.common import (
    overlay_3x4_dma_connections,
    overlay_3x4_col_core_stream_bdcast,
    prm_memtile_memory,
    prm_shim_memory,
    prm_memtile_mm2s,
    prm_memtile_s2mm,
    prm_shim_mm2s,
    overlay_3x4_F_ids,
    overlay_3x4_A_ids,
    L3Alloc_to_Shim
)
from scheduler.uniop.uniop_util import UniOpTensor
from scheduler.uniop.uniop_common import UnaryShape, UnaryMapping

from kernel.l2norm_fp16x16.l2norm_params import l2norm_layer_params, L2normDims
from kernel.softmax_fp16x16.softmax_params import softmax_layer_params, SoftmaxDims, copy_layer_params
from kernel.groupnorm.gpn_params import gpn_layer_params
# from kernel.SiLU_exp2.silu_params import silu_layer_params
from kernel.dq.dq_params import dequant_layer_params
from kernel.q.q_params import quant_layer_params
from kernel.layer_norm_fp16x16.layer_norm_params import layernorm_layer_params, LayernormDims
from kernel.linear_approx_bf16.linear_approx_params import linear_approx_layer_params
from buildscripts.common import ScheduleInputs

# Absolute directory that contains this module.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Append the sibling tool directories (two levels up) to the module search path.
# NOTE(review): this runs after the imports at the top of the file have already
# executed, so it can only affect imports performed later - confirm that is intended.
sys.path.extend(
    os.path.join(CURRDIR, "..", "..", subdir)
    for subdir in ("dmacompiler", "dataflow", "kerneltest")
)


# Import-time side effect: passes DevGen.Aie4 to dmacompiler's set_dev_gen, so simply
# importing this module selects that device generation.
set_dev_gen(DevGen.Aie4)


def compile_uniop_3x4_dataflow(schedule_input: ScheduleInputs):
    """Compile the 3x4-overlay dataflow for one unary-operator layer.

    The function reads the operator shape and mapping from ``schedule_input``,
    computes the core and memtile buffer addresses, builds the memtile and shim
    data transfers plus the per-core instruction lists, and passes everything to
    ``run_layer_compilation``.

    Args:
        schedule_input: Supplies ``shape`` (UnaryShape), ``mapping`` (UnaryMapping),
            ``kernel_names``, ``kernel_includes``, ``L3_alloc``, ``backend`` and
            ``dma_pad``.

    Returns:
        A tuple ``(shim_prm_offset_next_layer, shim_wgt_offset_next_layer)``: the
        parameter offset advanced by ``compute_buffer_size(prm_shim_memory())`` and
        the constant-buffer offset advanced by the constant size used by this layer.

    Raises:
        AssertionError: if the function name or backend is not supported, the
            IFM/OFM subvolumes differ, the tensor/subvolume dims are not
            3-tuples, the core buffers reach the core stack address, or
            ``Npass`` is not 1 or 2.
    """
    dims: UnaryShape = schedule_input.shape
    mapping: UnaryMapping = schedule_input.mapping
    kernel_names: list[str] = schedule_input.kernel_names
    kernel_includes: list[str] = schedule_input.kernel_includes
    L3_alloc: dict[str, tuple[int, int]] | None = schedule_input.L3_alloc
    backend = schedule_input.backend

    function = dims.function
    true_C = mapping.TensorDim[2]
    true_NYX = mapping.TensorDim[0] * mapping.TensorDim[1]
    TensorDim = mapping.PaddedTDim
    SubvolumeDim = mapping.ifm_subv
    SpatialSplitMode = mapping.spatial_split
    signA = dims.ifmSign  # 0: Unsigned , 1: Signed (Effective if input is fixed point)
    signO = dims.ofmSign  # 0: Unsigned , 1: Signed (Effective if output is fixed point)

    assert mapping.ifm_subv == mapping.ofm_subv
    assert function in {"l2norm", "silu", "softmax", "copy", "dequant", "quant", "layernorm", "groupnorm", "gelu", "swish", "sigmoid", "tanh", "elu"}
    assert backend in [BackEnd.Adf, BackEnd.CertAsm]
    assert isinstance(TensorDim, Tuple) and isinstance(SubvolumeDim, Tuple)
    assert len(TensorDim) == 3
    assert len(SubvolumeDim) == 3

    # Functions that share the "run_lut_fp16x16" kernel and need the two LUT buffers.
    lut_functions = {"swish", "tanh", "sigmoid", "silu", "gelu", "elu"}
    is_lut_function = function in lut_functions
    is_groupnorm = function == "groupnorm"

    AieCols = 3
    AieRows = 4

    uniop_shim = L3Alloc_to_Shim(L3_alloc)
    log(f"function: {function}")
    log(f"Shim Allocator: {uniop_shim}")
    QdqNodes = 2
    QdqPrm = 3
    QdqPrmBytes = 128
    GammaBetaBytes = 2
    bytes_per_word = 4
    LayerPrmSize = 1024
    Npass = mapping.Npass

    # Group norm specific params
    Ngroups = mapping.Ngroups  # Used only for groupnorm
    if SpatialSplitMode == "N1X4C3":
        NgroupsPerCol = [11, 11, 10]  # TODO: generalize this, currently fixed to 32 groups
        NgroupsPerCore = [11]*4 + [11]*4 + [10]*4
    else:
        NgroupsPerCol = [12, 12, 8]  # TODO: generalize this, currently fixed to 32 groups
        NgroupsPerCore = [3]*10 + [2] + [0]
    MaxGroupsPerCol = max(NgroupsPerCol)
    GroupSize = true_C // Ngroups
    wgt_rep = SubvolumeDim[1] if is_groupnorm else 1
    MaxGammaBetaDim = MaxGroupsPerCol * GroupSize * AieCols
    MaxMemGammaBetaSize = MaxGammaBetaDim * wgt_rep * 2 * GammaBetaBytes
    granC = GroupSize if is_groupnorm else 64

    ActBytes = dims.ifmbytes
    OutBytes = dims.ofmbytes

    MemtilePrmPingAddr = 0

    log("SubvolumeDim:", SubvolumeDim)
    log("SpatialSplitMode:", SpatialSplitMode)

    # ---- Core (L1) buffer sizes ----
    CoreAlignSize = 128
    CoreInputSize = SubvolumeDim[0]*SubvolumeDim[1]*SubvolumeDim[2]*ActBytes
    CoreOutputSize = SubvolumeDim[0]*SubvolumeDim[1]*SubvolumeDim[2]*OutBytes
    CoreWeightSize = (CoreInputSize // ActBytes) // 8 if function == "softmax" else 0
    CoreQdqPrmSize = 4 * 32
    CoreQdqBuffersize = QdqNodes * QdqPrm * QdqPrmBytes

    CoreQdqTotalSize = CoreQdqPrmSize + CoreQdqBuffersize
    CoreGammaBetaSize = 2 * wgt_rep * SubvolumeDim[2] * GammaBetaBytes if function in {"layernorm", "groupnorm"} else 0
    CoreScratchSize = 4096 if is_groupnorm else 0
    # CoreSpillBuffersize = 0 if function in {"softmax", "l2norm", "copy"} else 1024
    Core2LUTSize = 8192 if is_lut_function else 0

    # ---- Core (L1) buffer addresses, each rounded with iceil(..., CoreAlignSize) ----
    CoreInputPingAddr = 0
    CoreOutputPingAddr = iceil(CoreInputPingAddr + CoreInputSize, CoreAlignSize)
    CoreWeightPingAddr = iceil(CoreOutputPingAddr + CoreOutputSize, CoreAlignSize)
    CoreScratchAddr = iceil(CoreWeightPingAddr + CoreWeightSize, CoreAlignSize)  # used only for groupnorm
    CoreQdqParamAddr = iceil(CoreScratchAddr + CoreScratchSize, CoreAlignSize)
    # NOTE(review): CoreQdqBufferAddr and CoreGammaBetaAddr are computed from the same
    # expression, so they are always equal. If the kernels treat the QDQ buffer as a
    # separate region this is an overlap - confirm against the kernel parameter layout.
    CoreQdqBufferAddr = iceil(CoreQdqParamAddr + CoreQdqTotalSize, CoreAlignSize)
    CoreGammaBetaAddr = iceil(CoreQdqParamAddr + CoreQdqTotalSize, CoreAlignSize)
    CoreLUT_AB_Addr = iceil(CoreGammaBetaAddr + CoreGammaBetaSize, CoreAlignSize)
    CoreLUT_CD_Addr = iceil(CoreLUT_AB_Addr + Core2LUTSize//2, CoreAlignSize)

    log("CoreQdqParamAddr:", CoreQdqParamAddr)
    log("CoreGammaBetaAddr:", CoreGammaBetaAddr)

    CoreSpillBuffA_Addr = iceil(CoreLUT_CD_Addr + Core2LUTSize//2, CoreAlignSize)
    # CoreSpillBuffB_Addr = iceil(CoreSpillBuffA_Addr + CoreSpillBuffersize, CoreAlignSize)

    log("CoreLUT_AB_Addr:", CoreLUT_AB_Addr)
    log("CoreLUT_CD_Addr:", CoreLUT_CD_Addr)
    log("last_addr:", CoreLUT_CD_Addr + (Core2LUTSize//2))
    log("Core2LUTSize", Core2LUTSize)
    # NOTE(review): this bound stops at the end of the second LUT half; the spill buffer
    # at CoreSpillBuffA_Addr (passed to the LUT kernel params) is not covered - confirm.
    assert CoreLUT_CD_Addr + (Core2LUTSize//2) < overlay_3x4_core_stack_addr()

    # ---- Memtile (L2) / shim constant sizes ----
    MemtileQdqTotalSize = CoreQdqTotalSize
    Memtile2LUTSize = Core2LUTSize
    # Only the groupnorm shim transfers read this, but defining it unconditionally
    # keeps the name bound on every path.
    ShimConstantBoQdqTotalSize = MemtileQdqTotalSize

    if is_groupnorm:
        MemtileGammaBetaSize = MaxMemGammaBetaSize
        ShimGammaBetaSize = true_C * 2 * GammaBetaBytes
        ShimConstantBoTotalSize = ShimConstantBoQdqTotalSize + ShimGammaBetaSize
    else:
        MemtileGammaBetaSize = CoreGammaBetaSize
        ShimConstantBoTotalSize = MemtileQdqTotalSize + MemtileGammaBetaSize + Memtile2LUTSize

    log("ShimConstantBoTotalSize", ShimConstantBoTotalSize)

    # ---- Memtile (L2) buffer addresses ----
    MemtileIfmPingAddr = 0
    MemtileQdqPingAddr = MemtileIfmPingAddr + AieRows * CoreInputSize
    MemtileGammaBetaPingAddr = MemtileQdqPingAddr + MemtileQdqTotalSize
    MemtileOfmPingAddr = MemtileGammaBetaPingAddr + MemtileGammaBetaSize
    # NOTE(review): the LUT staging buffer starts CoreInputSize bytes after the OFM base.
    # Whether that overlaps the OFM L2 buffer depends on how UniOpTensor sizes it - confirm.
    MemtileLUTsPingAddr = MemtileOfmPingAddr + CoreInputSize

    log("Last L1 Addr :", CoreQdqBufferAddr + CoreQdqBuffersize)
    # assert CoreWeightPingAddr + CoreWeightSize < overlay_3x4_core_stack_addr()
    log("ActBytes:", ActBytes)
    log("OutBytes:", OutBytes)
    assert Npass in {1, 2}

    Act = UniOpTensor(TensorDim, SubvolumeDim, SpatialSplitMode, L2BufferAddr=MemtileIfmPingAddr, bytes_per_elem=ActBytes, input_npass=Npass, granC=granC)
    Out = UniOpTensor(TensorDim, SubvolumeDim, SpatialSplitMode, L2BufferAddr=MemtileOfmPingAddr, bytes_per_elem=OutBytes, input_npass=Npass, granC=granC, inputBuf=False)

    Act.set_bo_id(uniop_shim.ifm_xrt_idx)
    Act.set_bo_offset(uniop_shim.ifm_xrt_offset)
    Out.set_bo_id(uniop_shim.ofm_xrt_idx)
    Out.set_bo_offset(uniop_shim.ofm_xrt_offset)

    log("Act.X.shm_spatial_split_factor", Act.X.shm_spatial_split_factor)
    log("Act.X.l2_dim", (Act.X.l2_dim))
    log("Act.X.l3_dim", (Act.X.l3_dim))

    def x_splits_evenly() -> bool:
        """Return True when the L3 X extent is a multiple of (shim split factor * L2 X extent)."""
        return (Act.X.l3_dim % (Act.X.shm_spatial_split_factor * Act.X.l2_dim)) == 0

    def first_phase_only_counts() -> list[int]:
        """Return the count list used by the param, QDQ and shim constant transfers.

        The list is passed as the first argument of the transfer constructors
        (presumably per-phase repeat counts - confirm against dmacompiler).
        A new list is built on every call so transfers never share one object.
        """
        return [1] if x_splits_evenly() else [1, 0] + [0, 0] * (Npass-1)

    def every_pass_counts() -> list[int]:
        """Return the count list used by the gamma/beta and LUT memtile transfers.

        NOTE(review): for Npass == 2 and an uneven split this gives [1, 0, 1, 0],
        whereas first_phase_only_counts() gives [1, 0, 0, 0] - confirm the
        difference between the two is intentional.
        """
        return [1] if x_splits_evenly() else [1, 0] * Npass

    # Layer parameters: one S2MM into each memtile, fanned out to the four rows of the column.
    mem_layer_param = [
        DataTransfer(
            first_phase_only_counts(),
            AieTile(TileType.Memtile, col, 0),
            [MemtilePrmPingAddr],
            LayerPrmSize,
            [generate_transfer_params(memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]), prm_memtile_memory(), prm_memtile_s2mm())],
            [generate_transfer_params(memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]), prm_memtile_memory(), prm_memtile_mm2s(row)) for row in range(AieRows)],
            sync_strategy=SyncStrategy.Parallel_1_to_N,
        )
        for col in range(AieCols)
    ]

    # QDQ constants: received once per memtile and sent unchanged to every row.
    mem_qdq_t = [
        DataTransfer(
            first_phase_only_counts(),
            AieTile(TileType.Memtile, col, 0),
            [MemtileQdqPingAddr],
            MemtileQdqTotalSize,
            [TransferParams(memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]), MemtileQdqTotalSize // bytes_per_word)],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]), MemtileQdqTotalSize // bytes_per_word) for row in range(AieRows)],
        )
        for col in range(AieCols)
    ]

    def get_mem_gamma_beta():
        """Build the memtile gamma/beta transfers (layernorm and groupnorm only).

        Returns an empty list for every other function.
        """
        if function == "layernorm":
            GammaBetaDim = SubvolumeDim[2]
            mem_gamma_beta_t = [
                DataTransfer(
                    every_pass_counts(),
                    AieTile(TileType.Memtile, col, 0), [MemtileGammaBetaPingAddr], MemtileGammaBetaSize,
                    [TransferParams(memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),  MemtileGammaBetaSize//bytes_per_word)],
                    [
                        generate_transfer_params(memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                                                 f"Ngb:2 gbdim:{GammaBetaDim}", f"gbdim:0:{GammaBetaDim}:32 Ngb:0:2 gbdim:0:32",
                                                 bits_per_block=GammaBetaBytes * 8)
                        for row in range(AieRows)
                    ]
                ) for col in range(AieCols)
            ]
        elif function == "groupnorm":
            # Gamma/beta channel counts owned by each column / each core (col-major core index).
            GammaBetaDimPerCol = [NgroupsPerCol[col]*GroupSize for col in range(AieCols)]
            GammaBetaDimPerCore = [NgroupsPerCore[x]*GroupSize for x in range(AieCols*AieRows)]

            gpn_wgt_mem_mem_format = f"Ngb:2 rep:{wgt_rep} gbdim:{MaxGammaBetaDim}"

            def get_mem_wgt_mm2s_tiling_format(row: int, col: int) -> str:
                """Return the MM2S tiling string selecting the gamma/beta slice for core (col, row)."""
                if SpatialSplitMode == "N1X4C3":
                    # Every row of a column receives that column's whole slice.
                    wgt_stt = sum(GammaBetaDimPerCol[0:col])
                    wgt_end = wgt_stt + GammaBetaDimPerCol[col]

                else:
                    # Each core receives its own slice.
                    wgt_stt = sum(GammaBetaDimPerCore[0:col*AieRows + row])
                    wgt_end = wgt_stt + GammaBetaDimPerCore[col*AieRows + row]

                wgt_tiling_format = f"Ngb:0:2 gbdim:{wgt_stt}:{wgt_end}:{GroupSize} rep:0:{wgt_rep} gbdim:0:{GroupSize}"
                return wgt_tiling_format

            mem_gamma_beta_t = [
                DataTransfer(
                    every_pass_counts(),
                    AieTile(TileType.Memtile, col, 0), [MemtileGammaBetaPingAddr], MemtileGammaBetaSize,
                    [generate_transfer_params(memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                                              gpn_wgt_mem_mem_format, f"rep:0:{wgt_rep} Ngb:0:2 gbdim:0:{true_C}",
                                              bits_per_block=GammaBetaBytes * 8)],
                    [
                        generate_transfer_params(memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                                                 gpn_wgt_mem_mem_format, get_mem_wgt_mm2s_tiling_format(row, col),
                                                 bits_per_block=GammaBetaBytes * 8)
                        for row in range(AieRows)
                    ]
                ) for col in range(AieCols)
            ]
        else:
            mem_gamma_beta_t = []

        return mem_gamma_beta_t

    # LUT constants. Built for every function (as before) but only added to the memtile
    # transfer list for the LUT-based functions.
    mem_Lutab_Lutcd_t = [
        DataTransfer(
            every_pass_counts(),
            AieTile(TileType.Memtile, col, 0),
            [MemtileLUTsPingAddr],
            Memtile2LUTSize,
            [TransferParams(memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]), Memtile2LUTSize // bytes_per_word)],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]), Memtile2LUTSize // bytes_per_word) for row in range(AieRows)],
        )
        for col in range(AieCols)
    ]

    shm_layer_param = [
        generate_shim_data_transfer(
            first_phase_only_counts(),
            shim_dma(col, DmaDir.MM2S, 0),
            uniop_shim.prm_xrt_idx,
            prm_shim_memory(),
            prm_shim_mm2s(col),
            buffer_offset=uniop_shim.prm_xrt_offset,
        )
        for col in range(AieCols)
    ]

    if is_groupnorm:
        # Groupnorm splits the constant buffer into a QDQ part followed by a gamma/beta part.
        shim_const_bo = [
            generate_shim_data_transfer(
                first_phase_only_counts(),
                shim_dma(col, DmaDir.MM2S, 0), uniop_shim.wgt_xrt_idx,
                f"Byte:{ShimConstantBoQdqTotalSize}", f"Byte:0:{ShimConstantBoQdqTotalSize}",
                buffer_offset=uniop_shim.wgt_xrt_offset
            ) for col in range(AieCols)
        ] + [
            generate_shim_data_transfer(
                # NOTE(review): wgt_rep is only applied when X splits evenly; the other
                # branch starts with 1 - confirm this matches the memtile S2MM side.
                [wgt_rep] if x_splits_evenly() else [1, 0] + [0, 0] * (Npass-1),
                shim_dma(col, DmaDir.MM2S, 0), uniop_shim.wgt_xrt_idx,
                f"Byte:{ShimConstantBoTotalSize}", f"Byte:{ShimConstantBoQdqTotalSize}:{ShimConstantBoTotalSize}",
                buffer_offset=uniop_shim.wgt_xrt_offset
            ) for col in range(AieCols)
        ]
    else:
        shim_const_bo = [
            generate_shim_data_transfer(
                first_phase_only_counts(),
                shim_dma(col, DmaDir.MM2S, 0), uniop_shim.wgt_xrt_idx,
                f"Byte:{ShimConstantBoTotalSize}", f"Byte:0:{ShimConstantBoTotalSize}",
                buffer_offset=uniop_shim.wgt_xrt_offset
            ) for col in range(AieCols)
        ]

    # Operator name -> kernel entry point passed to CallKernel.
    kernel_name_mapping = {
        "l2norm": "run_l2norm_fp16x16",
        "silu": "run_lut_fp16x16",  # "run_silu",
        "softmax": "run_softmax_fp16x16",
        "copy": "run_copy_fp16x16",
        "dequant": "run_dequant",
        "quant": "run_quant",
        "layernorm": "run_layernorm_fp16x16",
        "groupnorm": "run_group_norm_qdq",
        "gelu": "run_lut_fp16x16",  # "run_silu",
        "swish": "run_lut_fp16x16",
        "tanh": "run_lut_fp16x16",
        "sigmoid": "run_lut_fp16x16",
        "elu": "run_lut_fp16x16"
    }

    # Kernel parameters for every function except groupnorm, whose parameters are
    # generated per call inside get_gpn_core_instrs (layer_params stays None there).
    layer_params = None
    if function == "l2norm":
        layer_params = l2norm_layer_params(
            CoreInputPingAddr, CoreWeightPingAddr, CoreOutputPingAddr,
            CoreQdqParamAddr, CoreQdqBufferAddr, CoreQdqBufferAddr + 384,
            true_C,
            SubvolumeDim[1], SubvolumeDim[2], SubvolumeDim[1] * SubvolumeDim[2] // 32,
            signA, signO,
            L2normDims(
                1, SubvolumeDim[0], SubvolumeDim[1], SubvolumeDim[2], order_select=1
            ),
        )

    elif function == "softmax":
        layer_params = softmax_layer_params(
            CoreInputPingAddr, CoreWeightPingAddr, CoreOutputPingAddr,
            CoreQdqParamAddr, CoreQdqBufferAddr, CoreQdqBufferAddr+384,
            true_C, SubvolumeDim[1], SubvolumeDim[2], SubvolumeDim[1]*SubvolumeDim[2]//32, SubvolumeDim[1]*SubvolumeDim[2]//8,
            signA, signO,
            SoftmaxDims(SubvolumeDim[0], SubvolumeDim[1], SubvolumeDim[2])
        )

    elif function == "copy":
        layer_params = copy_layer_params(
            CoreInputPingAddr, CoreOutputPingAddr, SubvolumeDim[1]*SubvolumeDim[2], ActBytes, OutBytes
        )

    elif function == "dequant":
        layer_params = dequant_layer_params(CoreInputPingAddr, CoreOutputPingAddr, CoreQdqParamAddr, SubvolumeDim[1] * SubvolumeDim[2] // 32, signA)

    elif function == "quant":
        layer_params = quant_layer_params(CoreInputPingAddr, CoreOutputPingAddr, CoreQdqParamAddr, SubvolumeDim[1] * SubvolumeDim[2] // 32, signO)

    elif function == "layernorm":
        layer_params = layernorm_layer_params(
            CoreInputPingAddr, CoreGammaBetaAddr, CoreOutputPingAddr,
            CoreQdqParamAddr, CoreQdqBufferAddr, CoreQdqBufferAddr+384,
            true_C, SubvolumeDim[1], SubvolumeDim[2], SubvolumeDim[1]*SubvolumeDim[2]//32,
            signA, signO,
            LayernormDims(1, SubvolumeDim[0], SubvolumeDim[1], SubvolumeDim[2], order_select=1)
        )

    elif is_lut_function:
        layer_params = linear_approx_layer_params(
            CoreInputPingAddr, CoreLUT_AB_Addr, CoreLUT_CD_Addr,
            CoreSpillBuffA_Addr, CoreOutputPingAddr, CoreQdqParamAddr,
            CoreQdqParamAddr, CoreQdqParamAddr,
            signA, signO,
            idx_bias=2048.0,
            num_iters=SubvolumeDim[0]*SubvolumeDim[1]*SubvolumeDim[2]//32,
            idx_max=4.96875,
            idx_min=-4.96875,
            idx_mul=412.0)

    log("Act.get_iters(X)", Act.get_iters("X"))
    log("Act.get_iters(X) * Npass:", Act.get_iters("X") * Npass)

    def CallKernelFunc(params):
        """Wrap ``params`` in a one-element list holding the CallKernel for this function."""
        return [CallKernel(kernel_name_mapping[function], params)]

    def get_gpn_core_instrs(core_col_id: int, core_row_id: int):
        """Build the groupnorm instruction list for core (core_col_id, core_row_id).

        Sequence: load the constants, run op_type=0 over every X iteration with the
        final iteration flagged by is_last_iter=1, then run op_type=1 over every X
        iteration with both the input and output buffers acquired.
        """
        mean_var_iters = Act.get_iters("X")

        def gpn_params(op_type: int, is_last_iter: int):
            """Return the groupnorm kernel parameters for one pass type."""
            enable_global_reduce = 1 if SpatialSplitMode == "N1X4C3" else 0
            return gpn_layer_params(
                CoreInputPingAddr,
                CoreOutputPingAddr,
                CoreScratchAddr,
                CoreGammaBetaAddr,
                CoreQdqParamAddr, CoreQdqBufferAddr, CoreQdqBufferAddr+384,
                SubvolumeDim[1],
                SubvolumeDim[2],
                GroupSize,
                NgroupsPerCore[core_col_id*AieRows + core_row_id],
                op_type,
                is_last_iter,
                true_NYX,
                enable_global_reduce,
                signA,
                signO
            )

        def core_inst_mean_var_loop(iters: int, op_type: int, is_last_iter: int):
            """Loop that only consumes input: acquire S2MM, call the kernel, release S2MM."""
            return [
                Loop(
                    iters,
                    [
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    ]
                    + CallKernelFunc(gpn_params(op_type=op_type, is_last_iter=is_last_iter))
                    + [
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    ],
                ),
            ]

        def core_inst_norm(iters: int, op_type: int, is_last_iter: int):
            """Loop that consumes input and produces output around each kernel call."""
            return [
                Loop(
                    iters,
                    [
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),  # TODO: remove when norm loop is added
                    ]
                    + CallKernelFunc(gpn_params(op_type=op_type, is_last_iter=is_last_iter))
                    + [
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),  # TODO: remove when norm loop is added
                    ],
                ),
            ]

        # Gamma/beta bytes this particular core receives (depends on its group count).
        CoreGammaBetaSizePerCore = NgroupsPerCore[core_col_id*AieRows + core_row_id] * GroupSize * wgt_rep * 2 * GammaBetaBytes
        core_inst_preamble = [
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreQdqParamAddr, None, CoreQdqTotalSize + CoreGammaBetaSizePerCore),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreInputPingAddr, None, Act.core_amount()),
            ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOutputPingAddr, None, Out.core_amount()),
        ]

        # NOTE(review): when Act.get_iters("X") is 1 the first loop gets 0 iterations -
        # confirm that Loop accepts a zero count.
        core_inst = core_inst_preamble + core_inst_mean_var_loop(mean_var_iters-1, op_type=0, is_last_iter=0) +\
            core_inst_mean_var_loop(iters=1, op_type=0, is_last_iter=1) + core_inst_norm(mean_var_iters, op_type=1, is_last_iter=0)
        return core_inst

    # NOTE : The following layer_params are used if silu/gelu goes with poly-based implementation

    # elif function == "silu":
    #    layer_params = silu_layer_params(
    #        CoreInputPingAddr, CoreOutputPingAddr, CoreSpillBuffA_Addr,
    #        CoreSpillBuffB_Addr, CoreQdqParamAddr, CoreQdqBufferAddr, CoreQdqBufferAddr+384,
    #        SubvolumeDim[1]*SubvolumeDim[2]//32,
    #        k0s=-1+895.001,
    #        k1s=0.8083,
    #        k2s=-0.1084,
    #        k3s=0.0,
    #        signA=signA, signO=signO
    #    )

    # elif function == "gelu":
    #    layer_params = silu_layer_params(
    #        CoreInputPingAddr, CoreOutputPingAddr, CoreSpillBuffA_Addr,
    #        CoreSpillBuffB_Addr, CoreQdqParamAddr, CoreQdqBufferAddr, CoreQdqBufferAddr+384,
    #        SubvolumeDim[1]*SubvolumeDim[2]//32,
    #        k0s=-1+895.001,
    #        k1s=1.0955,
    #        k2s=-0.5631,
    #        k3s=0.0,
    #        signA=signA, signO=signO
    #    )
    # CallKernelFunc = [CallKernel(kernel_name_mapping[function], layer_params)]

    def get_core_instrs(core_col_id: int, core_row_id: int):
        """Return the instruction list for core (core_col_id, core_row_id).

        Groupnorm is delegated to get_gpn_core_instrs. Every other function loads
        its constants once and then runs one kernel call per X iteration.
        """
        if is_groupnorm:
            return get_gpn_core_instrs(core_col_id, core_row_id)

        return [
            # One-time load of QDQ + gamma/beta + LUT constants starting at CoreQdqParamAddr.
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreQdqParamAddr, None, CoreQdqTotalSize + CoreGammaBetaSize + Core2LUTSize),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            Loop(
                Act.get_iters("X"),
                [
                    ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreInputPingAddr, None, Act.core_amount()),
                    ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOutputPingAddr, None, Out.core_amount()),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                ]
                + CallKernelFunc(layer_params)
                + [
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                ],
            ),
        ]

    # The concatenation order below is kept exactly as it was.
    memtile_transfers = mem_layer_param + mem_qdq_t + \
        get_mem_gamma_beta() + \
        (mem_Lutab_Lutcd_t if is_lut_function else []) + \
        Act.L2_DataTransfer() + Out.L2_DataTransfer()
    shmtile_transfers = shm_layer_param + shim_const_bo + Act.L3_DataTransfer() + Out.L3_DataTransfer()

    run_layer_compilation(
        overlay_shape=OverlayShape(AieCols, AieRows),
        kernel_names=kernel_names,
        kernel_includes=kernel_includes,
        core_instrs={AieTile(TileType.Core, col, row): get_core_instrs(col, row) for row in range(AieRows) for col in range(AieCols)},
        memtile_transfers=memtile_transfers,
        shim_transfers=shmtile_transfers,
        dma_connections=overlay_3x4_dma_connections(),
        back_end=backend,
        core_stack_addr=overlay_3x4_core_stack_addr(),
        param_channel_id=0,
        layer_file="dma.hpp" if backend == BackEnd.Adf else "aie4_dma.cpp",
        core_connections=overlay_3x4_col_core_stream_bdcast(),
        dma_padding_map=schedule_input.dma_pad,
    )

    # Offsets at which the next layer's params / constants start in the shim buffers.
    prm_shim_size = compute_buffer_size(prm_shim_memory())
    wgt_shim_size = ShimConstantBoTotalSize
    shim_prm_offset_next_layer = uniop_shim.prm_xrt_offset + prm_shim_size
    shim_wgt_offset_next_layer = uniop_shim.wgt_xrt_offset + wgt_shim_size
    log("shim_prm_offset_next_layer", shim_prm_offset_next_layer)
    log("shim_wgt_offset_next_layer", shim_wgt_offset_next_layer)

    return shim_prm_offset_next_layer, shim_wgt_offset_next_layer
