import sys
sys.path.append('.')
sys.path.append('../../../dmacompiler')
sys.path.append('../../../kernels')

from math import sqrt, floor
from typing import Tuple, List
from typing import Optional, Any, Union, Dict
import struct
from os import path
from mha.mini_mha.mha_params import generate_layer_kernel_params1, MhaSubvDims, OPMode
from mha.mini_mha.mha_params_smxv import generate_layer_kernel_params, MhaSubvDims, OPMode

from dataflow.dataflow_common import overlay_8x4_dma_connections, overlay_4x4_dma_connections
from dataflow_common import disable_fast_pm_backend
from dmacompiler import OverlayShape, DataTransfer, TransferParams, SyncStrategy, BackEnd, \
    DmaChannel, DmaDir, AieDma, AieTile, TileType, DmaConnection, \
    ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    memtile_dma, shim_dma, core_dma, \
    generate_shim_data_transfer, \
    run_layer_compilation, BackEnd, \
    generate_transfer_params

from dmacompiler import set_dev_gen, DevGen, config
# Module-level dmacompiler settings; these assignments take effect as soon as
# this module is imported.
config.ENABLE_BUSY_POLL = True
config.ENABLE_FAST_PM = False


# Select the Aie2p device generation for dmacompiler (also executed at import time).
set_dev_gen(DevGen.Aie2p)

def phase_length(Sin_q, Sq_subv, H_aie, Tqkv, Bias):
    """Return the number of phases (used as ``H_phase`` by the callers in this file).

    Args:
        Sin_q: Query sequence length.
        Sq_subv: Query subvolume size; one pass covers ``Sq_subv * 4`` query rows.
        H_aie: Number of heads handled per column.
        Tqkv: Number of full passes per head.
        Bias: Truthy when bias handling is enabled.

    Returns:
        Without bias: ``2 * H_aie`` when there is at least one full pass plus a
        partial one, otherwise ``H_aie``.
        With bias: ``1 + (Tqkv + pad) * H_aie`` where ``pad`` is 1 when ``Sin_q``
        is not a multiple of the rows per pass and 0 otherwise.
    """
    rows_per_pass = Sq_subv * 4
    has_remainder = Sin_q % rows_per_pass != 0

    if Bias:
        extra_phase = 1 if has_remainder else 0
        return 1 + (Tqkv + extra_phase) * H_aie

    if has_remainder and Sin_q // rows_per_pass >= 1:
        return 2 * H_aie
    return H_aie

def shim_phase_pattern(H_aie, H_phase, Tqkv, Bias):
    """Build the per-phase pattern list for the shim side.

    Args:
        H_aie: Number of heads handled per column.
        H_phase: Total number of phases (see ``phase_length``).
        Tqkv: Number of full passes per head.
        Bias: Truthy when bias handling is enabled.

    Returns:
        A list of 0/1 entries.
        Without bias: ``[1, 0]`` per head when ``H_aie != H_phase``, otherwise
        ``[1]`` per head.
        With bias: a leading ``0`` followed, for every head, by a ``1`` and then
        zeros for the remaining phases of that head.
    """
    if not Bias:
        per_head = [1, 0] if H_aie != H_phase else [1]
        return per_head * H_aie

    # Phases left over after the leading phase and Tqkv phases per head.
    # They must be shared evenly between the heads: adding the whole surplus to
    # every head made the list longer than H_phase whenever H_aie > 1.
    surplus = H_phase - (1 + Tqkv * H_aie)
    pad_per_head = surplus // H_aie if H_aie > 0 else 0
    per_head = [1] + [0] * (Tqkv - 1 + pad_per_head)
    return [0] + per_head * H_aie
    
def memtile_phase_pattern(H_aie, H_phase, Tqkv, Bias):
    """Build the per-phase pattern list for the memtile side.

    Args:
        H_aie: Number of heads handled per column.
        H_phase: Total number of phases.
        Tqkv: Number of full passes per head.
        Bias: Truthy when bias handling is enabled.

    Returns:
        With bias: ``[0]`` followed by ``H_phase - 1`` ones.
        Without bias: ``[Tqkv, 1]`` per head when ``H_aie != H_phase``,
        otherwise ``[Tqkv]`` per head.
    """
    if Bias:
        return [0] + [1] * (H_phase - 1)

    per_head = [Tqkv] if H_aie == H_phase else [Tqkv, 1]
    return per_head * H_aie
    
def exact_mapping(H_aie, H_phase):
    """Return True when ``H_aie`` and ``H_phase`` differ, False when they are equal.

    NOTE(review): the name suggests the opposite polarity; confirm against callers.
    """
    counts_differ = H_aie != H_phase
    return counts_differ

def pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
    """Tell whether phase ``h_phase`` is a padding phase.

    Args:
        h_phase: Index of the phase being queried.
        H_aie: Number of heads handled per column.
        H_phase: Total number of phases.
        Bias: Truthy when bias handling is enabled.
        Tqkv: Number of full passes per head.
        M_pad: Truthy when the query length needs a padded pass.

    Returns:
        Always a bool. False when ``H_aie == H_phase`` or when ``M_pad`` is
        falsy. Otherwise, without bias every odd phase is a padding phase; with
        bias every non-zero phase that is a multiple of ``Tqkv + 1`` is one.
    """
    # The original fell off the end (returning None) when M_pad was falsy
    # because the final branch lacked a `return`; always return a real bool.
    if H_aie == H_phase or not M_pad:
        return False

    if not Bias:
        return h_phase % 2 == 1

    # Phase 0 is never a padding phase in the bias layout.
    if h_phase == 0:
        return False
    return h_phase % (Tqkv + 1) == 0

def H_index(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
    """Map a phase index to the per-column head index it belongs to.

    Args:
        h_phase: Index of the phase being queried.
        H_aie: Number of heads handled per column.
        H_phase: Total number of phases.
        Bias: Truthy when bias handling is enabled.
        Tqkv: Number of full passes per head.
        M_pad: Truthy when the query length needs a padded pass.

    Returns:
        With bias: ``(h_phase - 1) // phases_per_head`` where a head spans
        ``Tqkv + 1`` phases if ``M_pad`` is truthy and ``Tqkv`` otherwise.
        Without bias: ``h_phase`` when ``H_aie == H_phase``, else ``h_phase // 2``.
    """
    if Bias:
        phases_per_head = Tqkv + 1 if M_pad else Tqkv
        return (h_phase - 1) // phases_per_head

    if H_aie == H_phase:
        return h_phase
    return h_phase // 2


def pack_reconfig(
    dma: AieDma,
    num_phases: int,
    fmts: list[tuple[int, str, str]],
    bits_per_block: int,
    buffer_offset: int = 0
) -> TransferParams:
    """Pack tuples with (phase_index, memory_format, tiling_format) into a single BD allocation.

    Args:
        dma: DMA the transfer parameters are generated for.
        num_phases: Number of phases in the resulting per-phase lists.
        fmts: ``(phase_index, memory_format, tiling_format)`` entries; each one
            overwrites the slot of its phase.
        bits_per_block: Forwarded to ``generate_transfer_params``.
        buffer_offset: Forwarded to ``generate_transfer_params``.

    Returns:
        A ``TransferParams`` whose per-phase lists hold the generated values for
        the phases named in ``fmts``; every other phase keeps the defaults
        (length 0, offset 0, step ``[1]``, empty wrap/padding, no iteration).
    """
    # Per-phase defaults; list-valued slots get a distinct list object per phase.
    lengths: Any = [0] * num_phases
    offsets: Any = [0] * num_phases
    steps: Any = [[1] for _ in range(num_phases)]
    wraps: Any = [[] for _ in range(num_phases)]
    paddings: Any = [[] for _ in range(num_phases)]
    iter_steps: Any = [None] * num_phases
    iter_wraps: Any = [None] * num_phases

    for phase_idx, memory_fmt, tiling_fmt in fmts:
        generated: Any = generate_transfer_params(
            dma,
            memory_fmt,
            tiling_fmt,
            bits_per_block,
            enable_padding=True,
            buffer_offset=buffer_offset,
        )
        # Only entry 0 of the generated parameters is copied into this phase's slot.
        lengths[phase_idx] = generated.length_i(0)
        offsets[phase_idx] = generated.offset_i(0)
        steps[phase_idx] = generated.step_i(0)
        wraps[phase_idx] = generated.wrap_i(0)
        paddings[phase_idx] = generated.padding_i(0)
        iter_steps[phase_idx] = generated.iter_step_i(0)
        iter_wraps[phase_idx] = generated.iter_wrap_i(0)

    return TransferParams(
        dma,
        lengths,
        offset=offsets,
        step=steps,
        wrap=wraps,
        padding=paddings,
        iter_step=iter_steps,
        iter_wrap=iter_wraps,
    )

def ceildiv(x: int, d: int) -> int:
    """Return ``x / d`` rounded up to the nearest integer (ceiling division)."""
    quotient, remainder = divmod(x, d)
    # Floor division already rounds down; bump by one whenever something is left over.
    return quotient + 1 if remainder else quotient

def iceil(x: int, d: int) -> int:
    """Round ``x`` up to the nearest multiple of ``d``."""
    multiples_needed = ceildiv(x, d)
    return multiples_needed * d
    
def mha_qkt_sfmx_qdq_params(
    input: Tuple[int, int],
    output: Tuple[int, int],
    mha_mode: int,
    multi_core : int,
    col_id : int,
    row_id : int,
    Sin_kv : int,
    core_tdm1_addr : int,
    core_tdm2_addr : int,
    core_qdq_addr : int,
    core_act1_addr : int,
    core_act2_addr : int,
    core_C0_addr : int,
    core_scratch_addr : int,
    core_query_addr : int,
    core_key_addr : int,
    core_val_addr : int,
    core_msk_addr : int,
) -> bytes:
    """Build the kernel parameter blob via ``generate_layer_kernel_params1``.

    Args:
        input: ``(M, K)`` pair; only ``M`` is used from it.
        output: ``(K, N)`` pair; ``K`` and ``N`` are both taken from here.
        mha_mode, multi_core, col_id, row_id, Sin_kv: Forwarded unchanged.
        core_*_addr: Core buffer addresses, forwarded unchanged.

    Returns:
        Whatever ``generate_layer_kernel_params1`` returns for these arguments
        together with a ``MhaSubvDims`` built from M/K/N and fixed constants.
    """
    # Fixed subvolume constants for this kernel.
    y_gran = 1
    x_gran = 8
    ci_gran = 8
    co_gran = 8
    kernel_y, kernel_x = 1, 1
    stride_y, stride_x = 1, 1
    elem_bytes = 2
    stride_efficiency = 0.5
    mem_align = 64

    # Both tuples must have exactly two entries; K comes from `output`.
    M, _ = input
    K, N = output

    subv_dims = MhaSubvDims(
        1,
        1, y_gran,
        M, x_gran,
        K, ci_gran,
        N, co_gran,
        kernel_y, kernel_x,
        stride_y, stride_x,
        OPMode.OP_SUM,
        elem_bytes,
        stride_efficiency,
        mem_align,
    )
    return generate_layer_kernel_params1(
        mha_mode,
        multi_core,
        col_id,
        row_id,
        Sin_kv,
        core_tdm1_addr,
        core_tdm2_addr,
        core_qdq_addr,
        core_act1_addr,
        core_act2_addr,
        core_C0_addr,
        core_scratch_addr,
        core_query_addr,
        core_key_addr,
        core_val_addr,
        core_msk_addr,
        subv_dims,
    )

def smxv_qdq_params(
    input: Tuple[int, int],
    output: Tuple[int, int],
    mha_mode: int,
    multi_core : int,
    col_id : int, 
    row_id : int,
    first_tdm_iter : int, 
    final_tdm_iter : int,
    Sin_kv : int,
    core_tdm1_addr : int,
    core_tdm2_addr : int,
    core_qdq_addr : int,
    core_act1_addr : int,
    core_act2_addr : int,
    core_C0_addr : int,
    core_scratch_addr : int,
    core_key_addr : int,
    core_val_addr : int,
    core_msk_addr : int,
    core_out_addr : int,
    qdq_node_offset : int, 
    transpose_B : int,
) -> bytes:
    """Build the kernel parameter blob via ``generate_layer_kernel_params``.

    Args:
        input: ``(M, K)`` pair; only ``M`` is used from it.
        output: ``(K, N)`` pair; ``K`` and ``N`` are both taken from here.
        mha_mode, multi_core, col_id, row_id: Forwarded unchanged.
        first_tdm_iter, final_tdm_iter, Sin_kv: Forwarded unchanged.
        core_*_addr: Core buffer addresses, forwarded unchanged.
        qdq_node_offset, transpose_B: Forwarded unchanged.

    Returns:
        Whatever ``generate_layer_kernel_params`` returns for these arguments
        together with a ``MhaSubvDims`` built from M/K/N and fixed constants.
    """
    # Fixed subvolume constants for this kernel.
    y_gran = 1
    x_gran = 8
    ci_gran = 8
    co_gran = 8
    kernel_y, kernel_x = 1, 1
    stride_y, stride_x = 1, 1
    elem_bytes = 2
    stride_efficiency = 0.5
    mem_align = 64

    # Both tuples must have exactly two entries; K comes from `output`.
    M, _ = input
    K, N = output

    subv_dims = MhaSubvDims(
        1,
        1, y_gran,
        M, x_gran,
        K, ci_gran,
        N, co_gran,
        kernel_y, kernel_x,
        stride_y, stride_x,
        OPMode.OP_SUM,
        elem_bytes,
        stride_efficiency,
        mem_align,
    )
    return generate_layer_kernel_params(
        mha_mode,
        multi_core,
        col_id,
        row_id,
        first_tdm_iter,
        final_tdm_iter,
        Sin_kv,
        core_tdm1_addr,
        core_tdm2_addr,
        core_qdq_addr,
        core_act1_addr,
        core_act2_addr,
        core_C0_addr,
        core_scratch_addr,
        core_key_addr,
        core_val_addr,
        core_msk_addr,
        core_out_addr,
        qdq_node_offset,
        transpose_B,
        subv_dims,
    )

def sfmx_params(
    Nlayer: int,
    Msubv: int,
    Nsubv: int,
    SplitType : int,
    col_id : int, 
    row_id : int,
    q_node_addr : int, 
    dq_node_addr : int,
    in_addr : int,
    out_addr : int,
    fuse_mode : int,
    scalefactor : float
):
    """Serialize the softmax kernel parameters into a 28-byte blob.

    Layout: the eleven integer arguments in signature order, each as an
    unsigned 16-bit little-endian value, then one zero 16-bit filler (keeps the
    following field 4-byte aligned), then the float32 bit pattern of
    ``scalefactor`` as an unsigned 32-bit little-endian value.

    Raises:
        OverflowError: If an integer argument does not fit in an unsigned
            16-bit field.
    """
    # Reinterpret the float32 encoding of the scale factor as a uint32.
    (scale_bits,) = struct.unpack('I', struct.pack('f', scalefactor))

    alignment_filler = 0
    halfword_fields = (
        Nlayer,
        Msubv,
        Nsubv,
        SplitType,
        col_id,
        row_id,
        q_node_addr,
        dq_node_addr,
        in_addr,
        out_addr,
        fuse_mode,
        alignment_filler,
    )
    halfwords = b''.join(
        field.to_bytes(length=2, byteorder='little', signed=False)
        for field in halfword_fields
    )
    return halfwords + scale_bits.to_bytes(length=4, byteorder='little', signed=False)

def bcast_add_params(
    Msubv: int,
    Nsubv: int,
    Nlayer: int,
    dq_node_addr : int,
    act_addr : int,
    mask_vector_addr : int,
    bias_matrix_addr : int,
    out_addr : int,
    mask_vector_exist : int,
    bias_matrix_exist : int,
    perform_dq : int
):
    """Serialize the broadcast-add kernel parameters into a 24-byte blob.

    Layout: the eleven arguments in signature order, each as an unsigned
    16-bit little-endian value, followed by one zero 16-bit filler.

    Raises:
        OverflowError: If an argument does not fit in an unsigned 16-bit field.
    """
    trailing_filler = 0
    halfword_fields = (
        Msubv,
        Nsubv,
        Nlayer,
        dq_node_addr,
        act_addr,
        mask_vector_addr,
        bias_matrix_addr,
        out_addr,
        mask_vector_exist,
        bias_matrix_exist,
        perform_dq,
        trailing_filler,
    )
    return b''.join(
        field.to_bytes(length=2, byteorder='little', signed=False)
        for field in halfword_fields
    )

def generate_dataflow_mini_mha(
    H: int,
    Sin_q: int,
    Sin_kv: int,
    Sin_dh: int,
    Sq_subv: int,
    aie_overlay_cols: int,  
    CodeBackend: BackEnd, 
    kernel_names: Union[List[str], Dict[str, int]],
    kernel_includes: List[str],
    disable_fast_pm = False,
    enable_attn_mask = False,
    K_is_tranposed_on_DDR = False,
    G: Optional[int] = None,
    G_seq: Optional[int] = None,
    B: Optional[int] = None,
    B_seq: Optional[int] = None,
    HYX: bool = True,
    mini_3p0:int = 1
):
    """
    Generate dataflow compilation for Mini Multi-Head Attention (MHA) on AIE overlay.
    
    Args:
        H: Number of attention heads
        Sin_q: Sequence length for query tensors
        Sin_kv: Sequence length for key and value tensors  
        Sin_dh: Head dimension size
        Sq_subv: Query subvolume size for tiling
        aie_overlay_cols: Number of AIE columns in the overlay 
        CodeBackend: Backend configuration for code generation
        kernel_names: List of kernel names or dictionary mapping kernels to IDs
        kernel_includes: List of header files to include for kernels
        disable_fast_pm: If True, disables fast parameter memory optimization
        enable_attn_mask: If True, enables attention masking (vector per head) support
        K_is_tranposed_on_DDR: If True, key tensor is stored transposed in DDR
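        G: Optional number of key/value groups for grouped query attention (defaults to H)
        G_seq: Optional per-head sequence giving the key/value group index used by each head
            (defaults to [0, 1, ..., H-1])
        B: Optional number of bias matrices; passing any value enables bias handling
            (when None, bias handling is disabled)
        B_seq: Optional per-head sequence giving the bias matrix index used by each head
            (defaults to [0, 1, ..., H-1])
        HYX: If True, DDR addressing places the Q blocks of all heads first, followed by the
            K blocks and then the V blocks; otherwise Q, K and V are addressed per head in
            consecutive Q/K/V chunks
        mini_3p0: If truthy, Value data is loaded and the softmax-times-V GEMM is run
            (output width Sin_dh); otherwise that stage is omitted and the output width is
            Sin_kv rounded up to a multiple of 8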
        
    Returns:
        None. Generates dma.hpp.
    """
    if disable_fast_pm:
            disable_fast_pm_backend()
            print(f"Fast PM disabled")

    # Initialize constants and parameters
    AieCols = aie_overlay_cols
    AieRows = 4
    Mask = 1 if(enable_attn_mask) else 0
    
    # MHA Heads distribution
    H_aie =  ((H-1) // AieCols) + 1
    H_col = H%AieCols

    # Group size
    if G is None:
        G = H
    if G_seq is None:
        G_seq = [i for i in range(H)]

    # Bias size
    if B is None:
        B = H
        Bias = False
    else:
        Bias = True
    
    if B_seq is None:
        B_seq = [i for i in range(H)]
    
    # Odd shape padding
    Skv = (((Sin_kv - 1) // 8) + 1) * 8  
    Dh = Sin_dh
    Sq = Sq_subv

    # Check if the shape is valid based on tiling parameters
    assert Sq >= 16

    # M padding:
    if Sin_q % (Sq * AieRows) != 0:
        M_pad = True
    else:
        M_pad = False

    # Iterations on the core
    Tqkv = max(1, Sin_q // (Sq * AieRows))
    To = max(1, Sin_q // (Sq * AieRows))
    
    # Phase length and pattern
    H_phase = phase_length(Sin_q, Sq_subv, H_aie, Tqkv, Bias)

    # Qdq and prm sizes
    QdqNodes = 6
    QdqPrm = 16
    QdqPrmBytes = 4
    CoreQdqPrmSize = QdqNodes * QdqPrm * QdqPrmBytes
    CoreAlignSize = 64
    CorePrmSize = 1024
    C0Bytes = 8
    
    IfmBytes = 2
    OutBytes = 2
    TdmBytes = 4
    bits_per_byte = 8
    bytes_per_word = 4
    QKt_postGemm_qdq_Offset = 128
    SMxV_postGemm_qdq_offset = 192
    
    # 2p1 or 3p0
    out_dim = Dh if mini_3p0 else Skv

    # Core sizes and padded size based on kernel requirements
    CoreQuerySize = Sq * Dh * IfmBytes
    CoreKeySize = Skv * Dh * IfmBytes
    CoreValSize = Skv * Dh * IfmBytes if mini_3p0 else 0
    CoreMaskSize = Mask * Skv * IfmBytes
    CoreBiasSize = (Sq * Skv * IfmBytes if Bias else 0)  # use Tdm2 buffer 
    CoreOutSize = Sq * out_dim * OutBytes
    CoreAct1SumSize = iceil(Skv * TdmBytes + 1, 256)
    CoreAct2SumSize = iceil(Skv * TdmBytes + 1, 256)
    CoreC0Size = Skv * C0Bytes
    CoreTdmBufSize = 2 * Skv * Sq * TdmBytes

    # Conditional address allocation
    if mini_3p0:
        # Mini 3P0: Query -> Key -> Value -> Mask -> QDQ -> TDM1 -> TDM2
        CoreQueryPingAddr = 0
        CoreKeyPingAddr = iceil(CoreQueryPingAddr + CoreQuerySize, CoreAlignSize)
        CoreValPingAddr = iceil(CoreKeyPingAddr + CoreKeySize, CoreAlignSize)
        CoreMaskPingAddr = iceil(CoreValPingAddr + CoreValSize, CoreAlignSize)
        CoreQdqPingAddr = iceil(CoreMaskPingAddr + CoreMaskSize, CoreAlignSize)
        CoreTdm1Addr = iceil(CoreQdqPingAddr + CoreQdqPrmSize, CoreAlignSize)
        CoreTdm2Addr = iceil(CoreTdm1Addr + CoreTdmBufSize // 2, CoreAlignSize)
        CoreOutPingAddr = CoreTdm1Addr  # Output uses TDM1
        CoreAct1SumAddr = iceil(CoreTdm2Addr + CoreTdmBufSize // 2, CoreAlignSize)
        CoreAct2SumAddr = iceil(CoreAct1SumAddr + CoreAct1SumSize, CoreAlignSize)
        CoreC0Addr = iceil(CoreAct2SumAddr + CoreAct2SumSize, CoreAlignSize)
        CoreScratchAddr = iceil(CoreC0Addr + CoreC0Size, CoreAlignSize)
    else:
        # Mini 2P1: Key -> Mask -> QDQ -> TDM1 -> TDM2 (Query/Output share TDM2)
        CoreValPingAddr = 0
        CoreKeyPingAddr = 0
        CoreMaskPingAddr = iceil(CoreKeyPingAddr + CoreKeySize, CoreAlignSize)
        CoreQdqPingAddr = iceil(CoreMaskPingAddr + CoreMaskSize, CoreAlignSize)
        CoreTdm1Addr = iceil(CoreQdqPingAddr + CoreQdqPrmSize, CoreAlignSize)
        CoreTdm2Addr = iceil(CoreTdm1Addr + CoreTdmBufSize // 2, CoreAlignSize)
        CoreOutPingAddr = CoreTdm2Addr 
        CoreAct2SumAddr = iceil(CoreTdm2Addr + CoreTdmBufSize // 2, CoreAlignSize)
        CoreC0Addr = iceil(CoreAct2SumAddr + CoreAct2SumSize, CoreAlignSize)
        CoreScratchAddr = iceil(CoreC0Addr + CoreC0Size, CoreAlignSize)
        CoreQueryPingAddr = CoreTdm2Addr
        CoreC0Addr = iceil(CoreTdm2Addr + CoreQuerySize, CoreAlignSize)
        CoreAct1SumAddr = iceil(CoreC0Addr + CoreC0Size, CoreAlignSize)
        assert CoreAct1SumAddr + CoreAct1SumSize <= CoreAct2SumAddr, "CoreAct1SumAddr + CoreAct1SumSize should be less than or equal to CoreAct2SumAddr"

    print("CoreQueryPingAddr", CoreQueryPingAddr)
    print("CoreKeyPingAddr", CoreKeyPingAddr)
    print("CoreValPingAddr", CoreValPingAddr)
    print("CoreMaskPingAddr", CoreMaskPingAddr)
    print("CoreTdm1Addr", CoreTdm1Addr)
    print("CoreTdm2Addr", CoreTdm2Addr)
    print("CoreOutPingAddr", CoreOutPingAddr)
    print("CoreQdqPingAddr", CoreQdqPingAddr)
    print("CoreAct1SumAddr", CoreAct1SumAddr)
    print("CoreAct2SumAddr", CoreAct2SumAddr)
    print("CoreC0Addr", CoreC0Addr)
    print("CoreScratchAddr", CoreScratchAddr)
    
    # Memtile sizes
    MemtileQuerySize = Sq * Dh * IfmBytes * AieRows
    MemtileKeySize = Sin_kv * Dh * IfmBytes
    MemtileValSize = Sin_kv * Dh * IfmBytes if mini_3p0 else 0
    MemtileMaskSize = Mask * Skv * IfmBytes
    MemtileBiasSize = CoreBiasSize*AieRows
    MemtileQkvSize = MemtileQuerySize + MemtileKeySize + MemtileValSize     
    MemtileOutSize = Sq * out_dim * OutBytes * AieRows
    MemtilePrmSize = CorePrmSize
    MemtileQdqPrmSize = CoreQdqPrmSize

    # Memtile addresses
    MemtilePrmPingAddr = 0   
    MemtileQkvPingAddr = MemtilePrmPingAddr + MemtilePrmSize
    MemtileOutPingAddr = MemtileQkvPingAddr + MemtileQkvSize + MemtileBiasSize + MemtileMaskSize
    MemtileQdqPingAddr = MemtileOutPingAddr + MemtileOutSize
    
    print("MemtilePrmPingAddr", MemtilePrmPingAddr)
    print("MemtileQkvPingAddr", MemtileQkvPingAddr)
    print("MemtileOutPingAddr", MemtileOutPingAddr)
    print("MemtileQdqPingAddr", MemtileQdqPingAddr)

    
    # Shim sizes
    ShimQuerySize = MemtileQuerySize
    ShimKeySize = MemtileKeySize
    ShimValSize = MemtileValSize
    ShimMaskSize = MemtileMaskSize
    ShimQkvSize = ShimQuerySize + ShimKeySize + ShimValSize + ShimMaskSize
    ShimOutSize = MemtileOutSize
    ShimPrmSize = MemtilePrmSize
    ShimQdqPrmSize = MemtileQdqPrmSize * H
    
    # Size in Words used for linear Data Transfer which does not use Tiling and other related functions
    CorePrmWords = CorePrmSize // bytes_per_word
    CoreQdqPrmWords = CoreQdqPrmSize // bytes_per_word
    MemtilePrmWords = MemtilePrmSize // bytes_per_word
    MemtileQdqPrmWords = MemtileQdqPrmSize // bytes_per_word
    ShimPrmWords = ShimPrmSize // bytes_per_word
    ShimQdqPrmWords = ShimQdqPrmSize // bytes_per_word

    transposeK = 0 if(K_is_tranposed_on_DDR) else 1   ## traversal pattern. 0 is default
    sum_mode = 1 if(K_is_tranposed_on_DDR) else 0  ## use mha_mode field now. 0 --> OP_SUM , 1 --> OP_SUM_T
    perform_8x8_block_transpose = 0 if(K_is_tranposed_on_DDR) else 1  
        
    # Helper functions
    def Memtile(col: int):
        return AieTile(TileType.Memtile, col, 0)
    
    def Shimtile(col: int):
        return AieTile(TileType.Shim, col, 0)
    
    def get_core_instrs(core_col_id: int, core_row_id: int):
        assert 2 <= core_row_id < 6
        assert 0 <= core_col_id < 8
        
        return [
            Loop(H_aie, [
                # Configure and process QDQ parameters
                ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreQdqPingAddr, None, CoreQdqPrmSize),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
    
                # Configure and process Key data
                ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreKeyPingAddr, None, CoreKeySize),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                
                # Run MHA preprocessing kernel
                CallKernel('run_mini_mha_preprocess', mha_qkt_sfmx_qdq_params(
                    (Sq, Dh), (Dh, Skv), sum_mode, perform_8x8_block_transpose, 
                    core_col_id, core_row_id, Sin_kv,
                    CoreTdm1Addr, CoreTdm2Addr, CoreQdqPingAddr, CoreAct1SumAddr, CoreAct2SumAddr,
                    CoreC0Addr, CoreScratchAddr, CoreQueryPingAddr, CoreKeyPingAddr, CoreValPingAddr,
                    CoreMaskPingAddr)),
                
                # Configure and process Value data
                *(
                    (ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreValPingAddr, None, CoreValSize),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),)
                    if mini_3p0 else []
                ),

                # Configure Mask data
                *(
                   (ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreMaskPingAddr, None, CoreMaskSize),
                   AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                   RelBuffer(DmaChannel(DmaDir.S2MM, 0)),)
                   if enable_attn_mask else []
                ),
                
                # Main processing loop
                Loop(Tqkv + 1 if (M_pad and Sin_q > Sq*AieRows) else Tqkv, [

                    # Configure Output buffer
                    ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOutPingAddr, None, CoreOutSize),
                    AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                    
                    # Configure Query data
                    ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreQueryPingAddr, None, CoreQuerySize),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    
                    # Run Q*K^T GEMM operation
                    CallKernel('run_gemm_qdq_mini', smxv_qdq_params(
                        (Sq, Dh), (Dh, Skv), 0, 0, 0, 0, 1, 1, 0,
                        CoreTdm1Addr, CoreTdm2Addr, CoreQdqPingAddr, CoreAct1SumAddr, CoreAct2SumAddr,
                        CoreC0Addr, CoreScratchAddr, CoreQueryPingAddr, CoreKeyPingAddr, CoreMaskPingAddr, 
                        CoreTdm2Addr, QKt_postGemm_qdq_Offset, transposeK)),
                    
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                ]
                
                # Conditional attention mask and bias processing
                + (
                    # Case 1: Attention mask enabled, no bias
                    [CallKernel('run_bcast_add_mini', bcast_add_params(
                        Sq, Skv, Sin_kv, CoreQdqPingAddr, CoreTdm1Addr, CoreMaskPingAddr, 
                        CoreTdm2Addr, CoreTdm1Addr, 1 if(enable_attn_mask) else 0, 0, 1))]
                    if (enable_attn_mask and not Bias) 
                    
                    # Case 2: Bias enabled, no attention mask
                    else [   
                        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreTdm2Addr, None, CoreBiasSize),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        CallKernel('run_bcast_add_mini', bcast_add_params(
                            Sq, Skv, Sin_kv, CoreQdqPingAddr, CoreTdm1Addr, CoreMaskPingAddr, 
                            CoreTdm2Addr, CoreTdm1Addr, 1 if(enable_attn_mask) else 0, 0, 1)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0))
                    ]
                    if (Bias and not enable_attn_mask) 
                    
                    # Case 3: Both attention mask and bias enabled
                    else [   
                        CallKernel('run_bcast_add_mini', bcast_add_params(
                            Sq, Skv, Sin_kv, CoreQdqPingAddr, CoreTdm1Addr, CoreMaskPingAddr, 
                            CoreTdm2Addr, CoreTdm1Addr, 1 if(enable_attn_mask) else 0, 0, 1)),
    
                        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreTdm2Addr, None, CoreBiasSize),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        CallKernel('run_bcast_add_mini', bcast_add_params(
                            Sq, Skv, Sin_kv, CoreQdqPingAddr, CoreTdm1Addr, CoreMaskPingAddr, 
                            CoreTdm2Addr, CoreTdm1Addr, 1 if(enable_attn_mask) else 0, 0, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0))
                    ]
                    if (enable_attn_mask and Bias) 
                    
                    # Case 4: No attention mask, no bias - run presoftmax dequantization
                    else [CallKernel('run_presoftmax_dequant', mha_qkt_sfmx_qdq_params(
                        (Sq, Dh), (Dh, Skv), 1, 0, core_col_id, core_row_id, Sin_kv,
                        CoreTdm1Addr, CoreTdm2Addr, CoreQdqPingAddr, CoreAct1SumAddr, CoreAct2SumAddr,
                        CoreC0Addr, CoreScratchAddr, CoreQueryPingAddr, CoreKeyPingAddr, CoreValPingAddr, 
                        CoreMaskPingAddr))]
                )
                
                # Softmax and final GEMM operations
                + [
                    # Apply softmax with quantization
                    CallKernel('run_softmax_qdq', sfmx_params(
                        Sin_kv, Sq, Skv, 0, (core_col_id % 4), core_row_id, 
                        (CoreQdqPingAddr+320), (CoreQdqPingAddr+256),
                        CoreTdm1Addr, CoreTdm2Addr, 1, 1.00
                    )),
                    
                    # Run attention scores * Value GEMM operation
                 *(
                     [CallKernel('run_gemm_qdq_mini', smxv_qdq_params(
                         (Sq, Skv), (Skv, Dh), 0, 0, 0, 0, 1, 1, 0,
                         CoreTdm1Addr, CoreTdm2Addr, CoreQdqPingAddr, CoreAct1SumAddr, CoreAct2SumAddr,
                         CoreC0Addr, CoreScratchAddr, CoreTdm2Addr, CoreValPingAddr, CoreMaskPingAddr, 
                         CoreTdm1Addr, SMxV_postGemm_qdq_offset, 0))]
                     if mini_3p0 else []
                 ),
                    
                    RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                ]) # end of Tqkv loop
            ]) # end of H_aie loop
        ]
    # Generate core instructions
    core_instrs_array = []
    for col in range(AieCols):
        for row in range(2, AieRows + 2):
            core_instrs_array.append(get_core_instrs(col, row))
    
    instr_dict = {}
    for col in range(AieCols):
        for row in range(AieRows):
            instr_dict[AieTile(TileType.Core, col, row)] = core_instrs_array[col * AieRows + row]
    
    assert Tqkv == To

    def q_memtile_memory():
        return f"Yq:{Sq * AieRows} Xq:{Dh}"
    def q_mem_s2mm(col: int, h_phase: int):
        if h_phase == 0 and Bias:
            return f""
        h_eff = H_index(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad)
        idx_h = col + h_eff * AieCols 
    
        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return f"Yq:0:{Sin_q % (AieRows * Sq)} Xq:0:{Dh}" if H > idx_h else f"Yq:0:0 Xq:0:0"
        else:
            return f"Yq:0:0 Xq:0:0" if H <= idx_h else (f"Yq:0:{Sq * AieRows} Xq:0:{Dh}" if Sin_q >= Sq * AieRows else f"Yq:0:{Sin_q} Xq:0:{Dh}")
    def q_mem_mm2s(row: int):
        return f"Xq:0:{Dh}:8 Yq:{Sq * row}:{Sq * (row + 1)} Xq:0:8"

    def q_shim_memory():
        return f"Yq:{ H*Sin_q + G*(Sin_kv + Sin_kv)} Xq:{Dh}"
    def q_shim_mm2s(col: int, h_phase: int):
        if h_phase == 0 and Bias:
            return f""
        h_eff = H_index(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad)
        idx_h = col + h_eff * AieCols 
        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return f"Yq:0:0 Xq:0:0"
        qkv_size = Sin_q + Sin_kv + Sin_kv
        if HYX:
            return f"Yq:{idx_h*Sin_q}:{(idx_h+1)*Sin_q}:{Sin_q} Yq:0:{Sin_q} Xq:0:{Dh}" if H > idx_h else f"Yq:0:0 Xq:0:0"
        else: #YHX
            return f"Yq:{idx_h*qkv_size}:{idx_h*qkv_size+Sin_q}:{Sin_q} Yq:0:{Sin_q} Xq:0:{Dh}" if H > idx_h else f"Yq:0:0 Xq:0:0"
    
    def b_memtile_memory():
        return f"Yb:{Sq * AieRows} Xb:{Skv}"
    def b_mem_s2mm(col: int, h_phase: int):
        if h_phase == 0 and Bias:
            return f""
        h_eff = H_index(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad)
        idx_h = col + h_eff * AieCols 
        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return f"Yb:0:{Sin_q%(AieRows*Sq)} Xb:0:{Skv}" if H > idx_h else f"Yb:0:0 Xb:0:0"
        return f"Yb:0:0 Xb:0:0" if H <= idx_h else f"Yb:0:{Sq * AieRows} Xb:0:{Skv}" if (Sin_q >= Sq * AieRows)   else f"Yb:0:{Sin_q} Xb:0:{Skv}"
    def b_mem_mm2s(row: int):
        return f"Xb:0:{Skv}:8 Yb:{Sq * row}:{Sq * (row + 1)} Xb:0:8"

    def b_shim_memory():
        return f"Yb:{B*Sin_q} Xb:{Skv}"
    def b_shim_mm2s(col: int, h_phase: int):
        if h_phase == 0 and Bias:
            return f""
        h_eff = H_index(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad)
        idx_h = col + h_eff * AieCols 
        
        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return f"Yb:0:0 Xb:0:0"

        idx_b = B_seq[idx_h] if H > idx_h  else None
        return f"Yb:{idx_b*Sin_q}:{(idx_b+1)*Sin_q}:{Sin_q} Yb:0:{Sin_q} Xb:0:{Skv}" if H > idx_h else f"Yb:0:0 Xb:0:0"



    def o_memtile_memory():
        return f"Yo:{Sq * AieRows} Xo:{out_dim}"
    def o_mem_mm2s(col: int, h_phase: int):
        if h_phase == 0 and Bias:
            return f""
        h_eff = H_index(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad)
        idx_h = col + h_eff * AieCols 
        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return f"Yo:0:{Sin_q % (AieRows * Sq)} Xo:0:{out_dim}" if H > idx_h else f"Yo:0:0 Xo:0:0"
        return f"Yo:0:0 Xo:0:0" if H <= idx_h else f"Yo:0:{Sq * AieRows} Xo:0:{out_dim}" if (Sin_q >= Sq * AieRows)   else f"Yo:0:{Sin_q} Xo:0:{out_dim}"
    def o_mem_s2mm(row: int):
        return f"Xo:0:{out_dim}:8 Yo:{Sq * row}:{Sq * (row + 1)} Xo:0:8"
    
    def o_shim_memory():
        return f"Yo:{H * Sin_q} Xo:{out_dim}"
    def o_shim_s2mm(col: int, h_phase: int):
        if h_phase == 0 and Bias:
            return f""
        h_eff = H_index(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad)
        idx_h = col + h_eff * AieCols 
        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return f"Yo:0:0 Xo:0:0"
        return f"Yo:{(idx_h)*Sin_q}:{(idx_h+1)*Sin_q}:{Sin_q} Yo:0:{Sin_q} Xo:0:{out_dim}" if H > (idx_h) else f"Yo:0:0 Xo:0:0"

    def k_memtile_memory():
        return f"Xk:{Dh} Yk:{Sin_kv} Xk:8"
    def k_mem_s2mm(col: int, h_phase: int):
        if h_phase == 0 and Bias:
            return f""
        h_eff = H_index(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad)
        idx_h = col + h_eff * AieCols 
        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return f"Yk:0:0 Xk:0:0"
        return f"Yk:0:{Sin_kv} Xk:0:{Dh}" if H > idx_h else f"Yk:0:0 Xk:0:0"
    def k_mem_mm2s():
        return f"Xk:0:{Dh}:8 Yk:0:{Skv} Xk:0:8"
    
    def k_shim_memory():
        return f"Yk:{ H*Sin_q + G*(Sin_kv + Sin_kv)} Xk:{Dh}"
    def k_shim_mm2s(col: int, h_phase: int):
        if h_phase == 0 and Bias:
            return f""
        h_eff = H_index(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad)
        idx_h = col + h_eff * AieCols 
        qkv_size = Sin_q + Sin_kv + Sin_kv

        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return f"Yk:0:0 Xk:0:0"

        
        idx_g = G_seq[idx_h] if H > idx_h else None
        if HYX:
            return f"Yk:{Sin_q*H + idx_g*Sin_kv}:{Sin_q*H + (idx_g + 1) * Sin_kv} Xk:0:{Dh}" if H > idx_h else f"Yk:0:0 Xk:0:0"
        else: # TODO: YHX # G1 G2 is not clear here needs  tests data
            return f"Yk:{idx_g*qkv_size + Sin_q}:{idx_g*qkv_size + (Sin_q+Sin_kv)}:{Sin_kv} Yk:0:{Sin_kv} Xk:0:{Dh}" if H > idx_h else f"Yk:0:0 Xk:0:0"
    
    
    def v_memtile_memory():
        return f"Xv:{Dh} Yv:{Sin_kv} Xv:8"
    def v_mem_s2mm(col: int, h_phase: int):
        if h_phase == 0 and Bias:
            return f""
        h_eff = H_index(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad)
        idx_h = col + h_eff * AieCols 
        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return f"Yv:0:0 Xv:0:0"
        return f"Yv:0:{Sin_kv} Xv:0:{Dh}" if H > idx_h else f"Yv:0:0 Xv:0:0"
    def v_mem_mm2s():
        return f"Xv:0:{Dh}:8 Yv:0:{Skv} Xv:0:8"
    
    def v_shim_memory():
        return f"Yv:{H * Sin_q + G * (Sin_kv + Sin_kv)} Xv:{Dh}"
    def v_shim_mm2s(col: int, h_phase: int):
        """Return the shim MM2S access pattern that fetches V for one phase.

        Args:
            col: Column index; added to ``H_index(...) * AieCols`` to obtain
                the head index handled in this phase.
            h_phase: Phase index.

        Returns:
            ``""`` for phase 0 when ``Bias`` is set; the zero-sized pattern
            ``"Yv:0:0 Xv:0:0"`` for pad phases and for head indices >= ``H``;
            otherwise the slice of V rows for the group ``G_seq`` assigns to
            the head, in either the HYX or the YHX layout.
        """
        if Bias and h_phase == 0:
            return ""

        no_transfer = "Yv:0:0 Xv:0:0"
        # H_index is evaluated before the pad check, as in the original flow.
        head = col + H_index(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad) * AieCols
        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return no_transfer
        if head >= H:
            # This column has no head to process in this phase.
            return no_transfer

        group = G_seq[head]
        if HYX:
            # V rows start after the H*Sin_q rows and the G*Sin_kv rows.
            v_base = (Sin_q*H+Sin_kv*G)
            first_row = v_base + group*Sin_kv
            last_row = v_base + (group + 1)*Sin_kv
            return f"Yv:{first_row}:{last_row} Xv:0:{Dh}"

        # YHX layout: each group owns Sin_q + 2*Sin_kv consecutive rows.
        rows_per_group = Sin_q + Sin_kv + Sin_kv
        first_row = group*rows_per_group + (Sin_q+Sin_kv)
        last_row = group*rows_per_group + (Sin_q+Sin_kv+Sin_kv)
        return f"Yv:{first_row}:{last_row}:{Sin_kv} Yv:0:{Sin_kv} Xv:0:{Dh}"
    
    
    def m_memtile_memory():
        """Return the memtile memory layout string for the mask buffer.

        The layout is a single row (Y extent 1) of ``Skv`` elements.
        """
        rows = 1
        return f"Ym:{rows} Xm:{Skv}"
    def m_mem_s2mm(h_phase: int):
        """Return the memtile S2MM (write) access pattern for the mask.

        Args:
            h_phase: Phase index.

        Returns:
            ``""`` for phase 0 when ``Bias`` is set; ``"Ym:0:0 Xm:0:0"`` for
            pad phases or when ``Mask`` is falsy; otherwise one row of
            ``Skv`` elements.
        """
        if Bias and h_phase == 0:
            return ""

        no_transfer = "Ym:0:0 Xm:0:0"
        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return no_transfer
        if not Mask:
            return no_transfer
        return f"Ym:0:1 Xm:0:{Skv}"
    def m_mem_mm2s():
        """Return the memtile MM2S (read) access pattern for the mask.

        One row of ``Skv`` elements when ``Mask`` is truthy, otherwise the
        zero-sized pattern.
        """
        if Mask:
            return f"Ym:0:1 Xm:0:{Skv}"
        return "Ym:0:0 Xm:0:0"
    
    def m_shim_memory():
        """Return the shim-side memory layout string for the mask.

        The layout is a single row (Y extent 1) of ``Skv`` elements.
        """
        rows = 1
        return f"Ym:{rows} Xm:{Skv}"
    def m_shim_mm2s():
        """Return the shim MM2S (read) access pattern for the mask.

        One row of ``Skv`` elements when ``Mask`` is truthy, otherwise the
        zero-sized pattern.
        """
        if Mask:
            return f"Ym:0:1 Xm:0:{Skv}"
        return "Ym:0:0 Xm:0:0"
    
    def qdq_memtile_memory():
        """Return the memtile memory layout string for the QDQ parameters.

        The layout is a single row of ``QdqNodes*QdqPrm`` elements.
        """
        rows = 1
        return f"Yqdq:{rows} Xqdq:{QdqNodes*QdqPrm}"
    
    def qdq_mem_s2mm(col: int, h_phase: int):
        """Return the memtile S2MM (write) access pattern for QDQ parameters.

        Args:
            col: Column index; accepted for signature parity with the other
                helpers but not used here.
            h_phase: Phase index.

        Returns:
            ``""`` for phase 0 when ``Bias`` is set; ``"Yqdq:0:0 Xqdq:0:0"``
            for pad phases; otherwise one row of ``QdqNodes*QdqPrm`` elements.
        """
        if Bias and h_phase == 0:
            return ""
        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return "Yqdq:0:0 Xqdq:0:0"
        write_pattern = f"Yqdq:0:1 Xqdq:0:{QdqNodes*QdqPrm}"
        return write_pattern
    
    def qdq_mem_mm2s():
        """Return the memtile MM2S (read) access pattern for QDQ parameters.

        Always one row of ``QdqNodes*QdqPrm`` elements, independent of the
        column and the phase.
        """
        read_pattern = f"Yqdq:0:1 Xqdq:0:{QdqNodes*QdqPrm}"
        return read_pattern
    
    def qdq_shim_memory():
        """Return the shim-side memory layout string for the QDQ parameters.

        The layout is a single row holding ``H*QdqNodes*QdqPrm`` elements,
        i.e. one ``QdqNodes*QdqPrm`` slot per head index below ``H``.
        """
        rows = 1
        return f"Yqdq:{rows} Xqdq:{H*QdqNodes*QdqPrm}"
    
    def qdq_shim_mm2s(col: int, h_phase: int):
        """Return the shim MM2S access pattern that fetches QDQ parameters.

        Args:
            col: Column index; added to ``H_index(...) * AieCols`` to obtain
                the head index handled in this phase.
            h_phase: Phase index.

        Returns:
            ``""`` for phase 0 when ``Bias`` is set; ``"Yqdq:0:0 Xqdq:0:0"``
            for pad phases; otherwise the ``QdqNodes*QdqPrm``-wide slot of the
            head. Head indices >= ``H`` read the slot of the last head
            (``H - 1``) instead of producing an empty transfer.
        """
        if Bias and h_phase == 0:
            return ""

        # H_index is evaluated before the pad check, as in the original flow.
        head = col + H_index(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad) * AieCols
        if pad_phase(h_phase, H_aie, H_phase, Bias, Tqkv, M_pad):
            return "Yqdq:0:0 Xqdq:0:0"

        # Clamp out-of-range heads onto the last valid slot.
        slot = min(head, H - 1)
        return f"Yqdq:0:1 Xqdq:{slot*QdqNodes*QdqPrm}:{(slot + 1)*QdqNodes*QdqPrm}"
    
    
    
    # Memtile data transfers
    # Q transfers through the memtile: one DataTransfer per column, using the
    # QKV ping buffer and the repeat counts from memtile_phase_pattern().
    mem_qt = [
        DataTransfer(
            memtile_phase_pattern(H_aie, H_phase, Tqkv, Bias),
            AieTile(TileType.Memtile, col), [MemtileQkvPingAddr], MemtileQkvSize,
            # Write side: a single S2MM on memtile channel 1, reconfigured
            # for every h_phase with the pattern from q_mem_s2mm().
            [pack_reconfig(
                memtile_dma(col, DmaDir.S2MM, 1),
                H_phase,
                [(h_phase,
                q_memtile_memory(),
                q_mem_s2mm(col, h_phase)) for h_phase in range(H_phase)],
                bits_per_block=IfmBytes * bits_per_byte)],
            # Read side: one MM2S channel per row index; the pattern depends
            # on the row only (q_mem_mm2s(row)), not on the phase.
            [pack_reconfig(
                memtile_dma(col, DmaDir.MM2S, row),
                H_phase,
                [(h_phase,
                q_memtile_memory(),
                q_mem_mm2s(row)) for h_phase in range(H_phase)],
                bits_per_block=IfmBytes * bits_per_byte)
                for row in range(AieRows)],
                sync_strategy= SyncStrategy.Parallel_1_to_N
        )
        for col in range(AieCols)
    ]
    
    # Bias transfers through the memtile: one DataTransfer per column. This
    # list is only added to memtile_transfers when Bias is set.
    mem_qb_t = [
        DataTransfer(
            memtile_phase_pattern(H_aie, H_phase, Tqkv, Bias),
            AieTile(TileType.Memtile, col), [MemtileQkvPingAddr], MemtileQkvSize,
            # Write side: S2MM on memtile channel 0. The data lands after the
            # query, key and value regions of the shared QKV buffer.
            [
            pack_reconfig(
                memtile_dma(col, DmaDir.S2MM, 0),
                H_phase,
                [(h_phase,
                b_memtile_memory(),
                b_mem_s2mm(col, h_phase)) for h_phase in range(H_phase)],
                bits_per_block=IfmBytes * bits_per_byte,
                buffer_offset=MemtileQuerySize + MemtileKeySize + MemtileValSize)
                ],
            # Read side: one MM2S channel per row index, reading from the same
            # buffer offset as the write side.
            [
                
                    pack_reconfig(
                    memtile_dma(col, DmaDir.MM2S, row),
                    H_phase,
                    [(h_phase,
                    b_memtile_memory(),
                    b_mem_mm2s(row)) for h_phase in range(H_phase)],
                    bits_per_block=IfmBytes * bits_per_byte,
                    buffer_offset=MemtileQuerySize + MemtileKeySize + MemtileValSize)
                    for row in range(AieRows)
            ],
            sync_strategy= SyncStrategy.Parallel_1_to_N
        )
        for col in range(AieCols)
    ]
    
    
    # Output transfers through the memtile: one DataTransfer per column, using
    # the output ping buffer. The phase pattern is built with To (not Tqkv).
    mem_ot = [
        DataTransfer(
            memtile_phase_pattern(H_aie, H_phase, To, Bias),
            AieTile(TileType.Memtile, col), [MemtileOutPingAddr], MemtileOutSize,
            # Write side: one S2MM channel per row index, on memtile channels
            # row + 2 (i.e. channels 2 .. AieRows + 1).
            [pack_reconfig(
                memtile_dma(col, DmaDir.S2MM, row + 2),
                H_phase,
                [(h_phase,
                o_memtile_memory(),
                o_mem_s2mm(row)) for h_phase in range(H_phase)],
                bits_per_block=OutBytes * bits_per_byte)
                for row in range(AieRows)],
            # Read side: a single MM2S on memtile channel 5 drains the buffer.
            [pack_reconfig(
                memtile_dma(col, DmaDir.MM2S, 5),
                H_phase,
                [(h_phase,
                o_memtile_memory(),
                o_mem_mm2s(col, h_phase)) for h_phase in range(H_phase)],
                bits_per_block=OutBytes * bits_per_byte)],
                sync_strategy= SyncStrategy.Parallel_N_to_1
        )
        for col in range(AieCols)
    ]

    # K / V / mask transfers through the memtile: one DataTransfer per column.
    # K is always transferred; V is added only when mini_3p0 is truthy and the
    # mask only when enable_attn_mask is truthy. All three share the QKV ping
    # buffer and are separated by buffer_offset.
    mem_kvm_t = [
        DataTransfer(
            shim_phase_pattern(H_aie, H_phase, Tqkv, Bias),
            AieTile(TileType.Memtile, col), [MemtileQkvPingAddr], MemtileQkvSize,
            # Input transfers (S2MM): every entry uses memtile S2MM channel 0.
            [
                # K is written right after the query region.
                pack_reconfig(
                    memtile_dma(col, DmaDir.S2MM, 0),
                    H_phase,
                    [(h_phase,
                    k_memtile_memory(),
                    k_mem_s2mm(col, h_phase)) for h_phase in range(H_phase)],
                    bits_per_block=IfmBytes * bits_per_byte,
                    buffer_offset=MemtileQuerySize
                ),
            ] + 
            ([
                # V is written after the query and key regions.
                pack_reconfig(
                    memtile_dma(col, DmaDir.S2MM, 0),
                    H_phase,
                    [(h_phase,
                    v_memtile_memory(),
                    v_mem_s2mm(col, h_phase)) for h_phase in range(H_phase)],
                    bits_per_block=IfmBytes * bits_per_byte,
                    buffer_offset=MemtileQuerySize + MemtileKeySize
                ),
            ] if mini_3p0 else []) +
            ([
                # The mask is written after the query, key, value and bias
                # regions.
                pack_reconfig(
                    memtile_dma(col, DmaDir.S2MM, 0),
                    H_phase,
                    [(h_phase,
                    m_memtile_memory(),
                    m_mem_s2mm(h_phase)) for h_phase in range(H_phase)],
                    bits_per_block=IfmBytes * bits_per_byte,
                    buffer_offset=MemtileQuerySize + MemtileKeySize + MemtileValSize + MemtileBiasSize
                ),
            ] if enable_attn_mask else []),
            # Output transfers (MM2S): for each row index, the K entry and the
            # optional V and mask entries are emitted back to back, so the
            # flattened list is grouped by row.
            [
                func for row in range(AieRows) for func in (
                    [
                        pack_reconfig(
                            memtile_dma(col, DmaDir.MM2S, row),
                            H_phase,
                            [(h_phase,
                            k_memtile_memory(),
                            k_mem_mm2s()) for h_phase in range(H_phase)],
                            bits_per_block=IfmBytes * bits_per_byte,
                            buffer_offset=MemtileQuerySize
                        ),
                    ] +
                    ([
                        pack_reconfig(
                            memtile_dma(col, DmaDir.MM2S, row),
                            H_phase,
                            [(h_phase,
                            v_memtile_memory(),
                            v_mem_mm2s()) for h_phase in range(H_phase)],
                            bits_per_block=IfmBytes * bits_per_byte,
                            buffer_offset=MemtileQuerySize + MemtileKeySize
                        ),
                    ] if mini_3p0 else []) +
                    ([
                        pack_reconfig(
                            memtile_dma(col, DmaDir.MM2S, row),
                            H_phase,
                            [(h_phase,
                            m_memtile_memory(),
                            m_mem_mm2s()) for h_phase in range(H_phase)],
                            bits_per_block=IfmBytes * bits_per_byte,
                            buffer_offset=MemtileQuerySize + MemtileKeySize + MemtileValSize + MemtileBiasSize
                        ),
                    ] if enable_attn_mask else [])
                )
            ],
            sync_strategy=SyncStrategy.Parallel_1_to_N
        )
        for col in range(AieCols)
    ]
    
    # QDQ parameter transfers through the memtile: one DataTransfer per column,
    # using a dedicated QDQ ping buffer.
    mem_qdq_t = [
        DataTransfer(
            shim_phase_pattern(H_aie, H_phase, Tqkv, Bias),
            AieTile(TileType.Memtile, col), [MemtileQdqPingAddr], MemtileQdqPrmSize,
            # Write side: a single S2MM on memtile channel 0.
            [
                pack_reconfig(
                    memtile_dma(col, DmaDir.S2MM, 0),
                    H_phase,
                    [(h_phase,
                    qdq_memtile_memory(),
                    qdq_mem_s2mm(col, h_phase)) for h_phase in range(H_phase)],
                    bits_per_block= QdqPrmBytes * bits_per_byte),
            ],
            # Read side: one MM2S channel per row index, same pattern each phase.
            [
                pack_reconfig(
                    memtile_dma(col, DmaDir.MM2S, row),
                    H_phase,
                    [(h_phase,
                    qdq_memtile_memory(),
                    qdq_mem_mm2s()) for h_phase in range(H_phase)],
                    bits_per_block= QdqPrmBytes * bits_per_byte)
            for row in range(AieRows)],
            # Unlike the other memtile transfers, the sync strategy falls back
            # to SyncStrategy.Default when Bias is set.
            sync_strategy= SyncStrategy.Parallel_1_to_N if not Bias else SyncStrategy.Default
        )
        for col in range(AieCols)
    ]

    # Complete memtile transfer list. The first AieCols entries run in the
    # first phase only ([1] + [0] * (H_phase-1)): S2MM channel 1 receives
    # MemtilePrmWords and each MM2S channel in range(AieRows) sends
    # CorePrmWords. The QDQ, K/V/mask, Q, optional bias and output transfers
    # are appended after them, in that order.
    memtile_transfers = [
        DataTransfer(
            [1] + [0] * (H_phase-1),
            AieTile(TileType.Memtile, col, 0), [MemtilePrmPingAddr], MemtilePrmSize,
            [TransferParams(AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.S2MM, 1)), MemtilePrmWords)],
            [TransferParams(AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.MM2S, row)), CorePrmWords)
             for row in range(AieRows)]
        ) for col in range(AieCols)
    ] + mem_qdq_t + mem_kvm_t + mem_qt + (mem_qb_t if Bias else [])  + mem_ot
    
    # Shim data transfers
    # Shim-side K / V / mask reads: one DataTransfer per column with no S2MM
    # side. Every entry uses shim MM2S channel 0. K is always read; V only
    # when mini_3p0 is truthy and the mask only when enable_attn_mask is
    # truthy. The [1] argument is presumably the host buffer index - TODO
    # confirm against the DataTransfer API.
    shim_kvm_t = [
        DataTransfer(
            shim_phase_pattern(H_aie, H_phase, Tqkv, Bias),
            AieTile(TileType.Shim, col, 0), [1], ShimQkvSize,
            [],
            [
                pack_reconfig(
                    AieDma(Shimtile(col), DmaChannel(DmaDir.MM2S, 0)),
                    H_phase,
                    [(h_phase,
                    k_shim_memory(),
                    k_shim_mm2s(col, h_phase)) for h_phase in range(H_phase)],
                    bits_per_block=IfmBytes * bits_per_byte),
            ] +
            ([
                pack_reconfig(
                    AieDma(Shimtile(col), DmaChannel(DmaDir.MM2S, 0)),
                    H_phase,
                    [(h_phase,
                    v_shim_memory(),
                    v_shim_mm2s(col, h_phase)) for h_phase in range(H_phase)],
                    bits_per_block=IfmBytes * bits_per_byte),
            ] if mini_3p0 else []) +
            ([
                # The mask is read from a byte offset past the Q rows and the
                # K (and, with mini_3p0, V) rows.
                # NOTE(review): this offset scales the V term by mini_3p0 and
                # does not skip a bias region, whereas the K/V shim layouts
                # always count 2*Sin_kv rows per group - confirm the intended
                # host buffer layout.
                pack_reconfig(
                    AieDma(Shimtile(col), DmaChannel(DmaDir.MM2S, 0)),
                    H_phase,
                    [(h_phase,
                    m_shim_memory(),
                    m_shim_mm2s()) for h_phase in range(H_phase)],
                    bits_per_block=IfmBytes * bits_per_byte,
                    buffer_offset=(Sin_q*H + (Sin_kv + Sin_kv*mini_3p0)*G) * Dh * IfmBytes
                )
            ] if enable_attn_mask else [])
        )
        for col in range(AieCols)
    ]


    # Shim-side Q reads: one DataTransfer per column with no S2MM side and a
    # single MM2S on shim channel 1, reconfigured per h_phase with the pattern
    # from q_shim_mm2s().
    shim_q_t = [
        DataTransfer(
            shim_phase_pattern(H_aie, H_phase, Tqkv, Bias),
            AieTile(TileType.Shim, col, 0), [1], ShimQkvSize,  
            [],
            [pack_reconfig(
                AieDma(Shimtile(col), DmaChannel(DmaDir.MM2S, 1)),
                H_phase,
                [(h_phase,
                q_shim_memory(),
                q_shim_mm2s(col, h_phase)) for h_phase in range(H_phase)],
                bits_per_block=IfmBytes * bits_per_byte)
            ]
        )    
        for col in range(AieCols)
    ]

    # Shim-side bias reads: one DataTransfer per column with no S2MM side and
    # a single MM2S on shim channel 0. The read starts at a byte offset past
    # Sin_q*H + 2*Sin_kv*G rows of Dh elements. This list is only added to
    # shim_transfers when Bias is set.
    shim_qb_t = [
        DataTransfer(
            shim_phase_pattern(H_aie, H_phase, Tqkv, Bias),
            AieTile(TileType.Shim, col, 0), [1], ShimQkvSize,  
            [],
            [
            pack_reconfig(
                AieDma(Shimtile(col), DmaChannel(DmaDir.MM2S, 0)),
                H_phase,
                [(h_phase,
                b_shim_memory(),
                b_shim_mm2s(col, h_phase)) for h_phase in range(H_phase)],
                bits_per_block=IfmBytes * bits_per_byte,
                buffer_offset=(Sin_q*H + (Sin_kv + Sin_kv)*G) * Dh * IfmBytes)
            ]
        )    
        for col in range(AieCols)
    ]

    # Shim-side output writes: one DataTransfer per column with a single S2MM
    # on shim channel 0 and no MM2S side.
    # NOTE(review): this uses ShimQkvSize and IfmBytes although the memtile
    # output path uses OutBytes - confirm this is intended when the output
    # element size differs from the input element size.
    shim_ot_te = [
        DataTransfer(
            shim_phase_pattern(H_aie, H_phase, Tqkv, Bias),
            AieTile(TileType.Shim, col, 0), [0], ShimQkvSize,  
            [pack_reconfig(
                AieDma(Shimtile(col), DmaChannel(DmaDir.S2MM, 0)),
                H_phase,
                [(h_phase,
                o_shim_memory(),
                o_shim_s2mm(col, h_phase)) for h_phase in range(H_phase)],
                bits_per_block=IfmBytes * bits_per_byte)
            ],
            []
        )
        for col in range(AieCols)
    ]
    
    # Complete shim transfer list, in order:
    #   1. per column, a first-phase-only MM2S on shim channel 1 sending
    #      ShimPrmWords (second DataTransfer argument [3]);
    #   2. per column, the QDQ parameter read on shim channel 0 (argument [2]);
    #   3. the K/V/mask, Q, optional bias and output transfers.
    shim_transfers = [
        
        DataTransfer(
            [1] + [0]*(H_phase-1),
            AieTile(TileType.Shim, col, 0), [3], ShimPrmSize,
            [],
            [TransferParams(AieDma(AieTile(TileType.Shim, col, 0), DmaChannel(DmaDir.MM2S, 1)), ShimPrmWords)]
        ) for col in range(AieCols)
    
    ] + [
        
        DataTransfer(
            shim_phase_pattern(H_aie, H_phase, Tqkv, Bias),
            AieTile(TileType.Shim, col, 0), [2], ShimQdqPrmSize,
            [],
            [pack_reconfig(
                AieDma(Shimtile(col), DmaChannel(DmaDir.MM2S, 0)),
                H_phase,
                [(h_phase,
                qdq_shim_memory(),
                qdq_shim_mm2s(col, h_phase)) for h_phase in range(H_phase)],
                bits_per_block=QdqPrmBytes * bits_per_byte)
            ]
        ) for col in range(AieCols)

    ] + shim_kvm_t  + shim_q_t +  (shim_qb_t if Bias else [])  + shim_ot_te
    
    # Overlay geometry used for compilation.
    shape = OverlayShape(AieCols, AieRows)

    # Use the 8x4 connection table for 8 columns; any other column count
    # falls back to the 4x4 table.
    dma_connections = overlay_8x4_dma_connections() if AieCols == 8 else   overlay_4x4_dma_connections()
    
    # Fixed core stack address passed to the compiler as core_stack_addr.
    CoreStackAddr = 60352
    # Compile the layer from the kernel list, the core instruction dict and
    # the memtile/shim transfer lists built above; layer_file is 'dma.hpp'.
    run_layer_compilation(
        shape,
        kernel_names,
        kernel_includes,
        instr_dict, 
        memtile_transfers,
        shim_transfers,
        dma_connections=dma_connections,
        back_end=CodeBackend,
        core_stack_addr=CoreStackAddr,
        param_channel_id=0,
        layer_file='dma.hpp'
    )
    
