import sys
# Make the local dataflow helpers and the sibling dmacompiler / kernels checkouts
# importable. These paths are relative, so they resolve against the current
# working directory, not against this file's location.
sys.path.append('.')
sys.path.append('../../../dmacompiler')
sys.path.append('../../../kernels')

from mha.mini_mha.overlay import sfmx_params, mha_qkt_sfmx_qdq_params, smxv_qdq_params, bcast_add_params
from typing import List
from functools import partial

from dataflow.dataflow_common import overlay_stack_addr, overlay_8x4_dma_connections, overlay_8x4_core_stream_bdcast, overlay_4x4_core_stream_bdcast, overlay_4x4_dma_connections
# NOTE: TransferParams appears twice in this import list; the duplicate is harmless.
from dmacompiler import OverlayShape, DataTransfer, TransferParams, BackEnd, \
    DmaChannel, DmaDir, AieTile, TileType, CascDir, \
    ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    generate_transfer_params,\
    memtile_dma, shim_dma, \
    TransferParams,\
    run_layer_compilation, \
    pack_reconfig_transfers

# Module-import side effects: set global dmacompiler configuration for this layer
# (config.ENABLE_BUSY_POLL) and select the AIE2P device generation.
from dmacompiler import set_dev_gen, DevGen, config
config.ENABLE_BUSY_POLL = True
set_dev_gen(DevGen.Aie2p)

def ceildiv(x: int, d: int) -> int:
    """Return ``x / d`` rounded toward positive infinity, using integer math only.

    Raises ZeroDivisionError when ``d`` is 0.
    """
    quotient, remainder = divmod(x, d)
    # divmod floors the quotient; an inexact division needs one step up.
    return quotient + (1 if remainder else 0)

def iceil(x: int, d: int) -> int:
    """Round ``x`` up to the nearest multiple of ``d``.

    Equivalent to ``ceildiv(x, d) * d``; the ceiling division is inlined here.
    """
    return d * -(-x // d)

# Pre-bind the dmacompiler helpers for 16-bit (2-byte) elements.
create_transfer_params = partial(generate_transfer_params, bits_per_block=16)
bd_reconfig = partial(pack_reconfig_transfers, bits_per_elem=16)

def generate_dataflow_2p1_mha(
    H: int,
    Sin_q: int,
    Sin_kv: int,
    Sin_dh: int,
    Sq_subv: int,
    aie_overlay_cols: int,  
    CodeBackend: BackEnd, 
    kernel_names: List[str],
    kernel_includes: List[str],     
    disable_fast_pm : bool,  
    enable_mask_vector = False,
    enable_bias_matrix = False,
    debug_print = False  
):
    config.ENABLE_FAST_PM = not disable_fast_pm
    enable_bias_matrix = True 
    AieCols = aie_overlay_cols
    AieRows = 4
    Num4x4  = 2 if AieCols == 8 else 1
    NumAieCompCols = AieCols//Num4x4

    QdqNodes    = 6
    QdqPrm      = 16
    QdqPrmBytes = 4
    CoreQdqPrmSize = (QdqNodes * QdqPrm * QdqPrmBytes)
    bytes_per_word = 4
    CoreAlignSize  = 64
    CorePrmSize    = 1024

    IfmBytes = 2   
    OutBytes = 2  
    TdmBytes = 4
    C0Bytes  = 8

    assert Sq_subv >= 16 and (Sq_subv % 16 == 0)    
    Sq = Sq_subv 
    
    #################################################################################################
    ################### Below for Deriving General N Spatial split  #################################
    #################################################################################################
    
    N    = Sin_kv  ## Out : M x N
    Nout = iceil(Sin_kv, 8)  ## DD-Padded for Output M x N is actually : M x Nout, Nout is multiple of 8
    
    Ngran  = 8
    N_next = iceil(N, AieRows * NumAieCompCols * Ngran)
    Skv =  N_next // (AieRows * NumAieCompCols)   # K is split across 4x4 
    
    N_per_col = N_next // NumAieCompCols 
    
    Nsubv  = N_next //  (AieRows * NumAieCompCols)
    Nlayer_core = [0] * (AieRows * NumAieCompCols)
    
    remain = N
    for core_id in range(16):
        Nlayer_core[core_id] = Nsubv if (remain-Nsubv >= 0) else remain % Nsubv
        remain -= Nlayer_core[core_id]
    
    assert(N == sum(Nlayer_core))

    #Nlayer_col  = [0] * NumAieCompCols
    #for col_id in range(4):
    #    Nlayer_col[col_id] = sum(Nlayer_core[col_id*4:(col_id+1)*4])
    #assert(N == sum(Nlayer_col))
    
    #################################################################################################
    #################################################################################################
    #################################################################################################
    
    Dh = Sin_dh
    assert Skv * AieRows * NumAieCompCols == N_next #Sin_kv
    assert Sq == 16
    Tm = (iceil(Sin_q, Num4x4 * Sq)// Sq) // Num4x4
    To = Tm 
    
    Mask = 1 if(enable_mask_vector) else 0
    Bias = 1 if(enable_bias_matrix) else 0

    if debug_print:
        print(f"Sin_q: {Sin_q} Sin_kv: {Sin_kv}, Skv: {Skv}, Sq: {Sq}, Dh: {Dh}, N_next: {N_next},")
        print("Sq:",Sq)
        print("Skv * AieRows * NumAieCompCols:",Skv * AieRows * NumAieCompCols)
        print("N_next:",N_next)
        print(f"Tm: {Tm}, To: {To}")

    CoreQrySize     = (Sq  * Dh   * IfmBytes) * Num4x4
    CoreKeySize     = (Skv * Dh   * IfmBytes)
    CoreMaskSize    = (1   * Skv  * IfmBytes)
    CoreOutSize     = (Sq  * Skv  * OutBytes)
    CoreTdm1Size    = (Sq  * Skv  * OutBytes) 
    CoreTdm2Size    = (Sq  * Skv  * OutBytes) 
    CoreBiasSize    = Bias * Sq * Skv * IfmBytes
    CoreAct1SumSize = iceil(TdmBytes*Sq, 512)
    CoreAct2SumSize = iceil(Skv * TdmBytes, 512)
    CoreC0Size = Skv * C0Bytes
    
    CoreQryPingAddr    = 0                                                   
    CoreKeyPingAddr    = iceil(CoreQryPingAddr + CoreQrySize, CoreAlignSize)
    CoreTdm1Addr       = iceil(CoreKeyPingAddr + CoreKeySize, CoreAlignSize)                        
    CoreTdm2Addr       = iceil(CoreTdm1Addr + CoreTdm1Size, CoreAlignSize)
    CoreOutPingAddr    = CoreTdm2Addr                                                                
    
    CoreQdqPingAddr    = iceil(CoreTdm2Addr + CoreTdm2Size, CoreAlignSize)           
    assert CoreOutPingAddr + CoreOutSize <= CoreQdqPingAddr               
    CoreAct1SumAddr    = iceil(CoreQdqPingAddr + CoreQdqPrmSize, CoreAlignSize)           
    CoreAct2SumAddr    = iceil(CoreAct1SumAddr + CoreAct1SumSize, CoreAlignSize)                     
    CoreC0Addr         = iceil(CoreAct2SumAddr + CoreAct2SumSize, CoreAlignSize)
    CoreMaskPingAddr   = iceil(CoreC0Addr + CoreC0Size, CoreAlignSize)      
    CoreScratchAddr    = iceil(CoreMaskPingAddr + CoreMaskSize,CoreAlignSize)    
    
    MemtilePrmSize    = CorePrmSize * (AieRows)
    MemtileQdqPrmSize = CoreQdqPrmSize
    MemtileQrySize    = CoreQrySize
    MemtileKeySize    = CoreKeySize * (AieRows)
    MemtileMaskSize   = CoreMaskSize* (AieRows)
    MemtileBiasSize   = CoreBiasSize* (AieRows)
    MemtileOutSize    = CoreOutSize * (AieRows) 
    MemtileQkvSize    = MemtileQrySize + MemtileKeySize + MemtileMaskSize + MemtileBiasSize
    
    if debug_print:
        print(f"CoreQryPingAddr: {CoreQryPingAddr}, CoreKeyPingAddr: {CoreKeyPingAddr}, CoreTdm1Addr: {CoreTdm1Addr}, CoreTdm2Addr: {CoreTdm2Addr}, CoreOutPingAddr: {CoreOutPingAddr}, CoreQdqPingAddr: {CoreQdqPingAddr}, CoreAct1SumAddr: {CoreAct1SumAddr}, CoreAct2SumAddr: {CoreAct2SumAddr}, CoreC0Addr: {CoreC0Addr}, CoreScratchAddr: {CoreScratchAddr}")
        print(f"MemtileQrySize: {MemtileQrySize}, MemtileKeySize: {MemtileKeySize}, MemtileOutSize: {MemtileOutSize}")

    MemtilePrmPingAddr   = 0   
    MemtileQkvPingAddr   = MemtilePrmPingAddr + MemtilePrmSize
    MemtileQkvPongAddr   = MemtileQkvPingAddr + MemtileQkvSize
    MemtileOutPingAddr   = MemtileQkvPongAddr + MemtileQkvSize
    MemtileOutPongAddr   = MemtileOutPingAddr + MemtileOutSize
    MemtileQdqPingAddr   = MemtileOutPongAddr + MemtileOutSize

    ShimQdqPrmSize = MemtileQdqPrmSize
    
    CorePrmWords       = CorePrmSize // bytes_per_word
    MemtilePrmWords    = MemtilePrmSize // bytes_per_word
    MemtileQdqPrmWords = MemtileQdqPrmSize // bytes_per_word
    ShimQdqPrmWords = ShimQdqPrmSize // bytes_per_word

    K_is_tranposed_on_DDR = False
    transposeK = 0 if(K_is_tranposed_on_DDR) else 1   ## traversal pattern. 0 is default
    sum_mode = 1 if(K_is_tranposed_on_DDR) else 0     ## use mha_mode field now. 0 --> OP_SUM , 1 --> OP_SUM_T
    perform_8x8_block_transpose = 0 if(K_is_tranposed_on_DDR) else 1  

    def Memtile(col: int):
        return AieTile(TileType.Memtile, col, 0)
    
    def Shimtile(col: int):
        return AieTile(TileType.Shim, col, 0)

    ##########################################################################################
    ###   Start of Core-Instruction dDefinition :
    ##########################################################################################
    dummy_var = 0
    def get_core_instrs(core_col_id:int, core_row_id:int):
        assert 0 <= core_row_id < 4
        assert 0 <= core_col_id < 8

        assert(Mask==0 or Mask==1)
        assert(Bias==0 or Bias==1)
        match (Bias<<1)+Mask:
            case 0: # Non of Mask nor Bias existed
                KernelFunc = [ CallKernel('run_presoftmax_dequant', mha_qkt_sfmx_qdq_params( 
                        (Sq, Dh), (Dh, Skv), 1, 0, core_col_id, core_row_id, dummy_var, CoreTdm1Addr, CoreTdm2Addr, CoreQdqPingAddr, 
                        CoreAct1SumAddr, CoreAct2SumAddr, CoreC0Addr, CoreScratchAddr, CoreQryPingAddr, CoreKeyPingAddr, 0, 0)), ]
            case 1: # Mask Exist, Bias non existed
                KernelFunc = [ CallKernel('run_bcast_add_mini', bcast_add_params(Sq, Skv, dummy_var, CoreQdqPingAddr, CoreTdm1Addr, 
                                    CoreMaskPingAddr, CoreTdm2Addr, CoreTdm1Addr, Mask, Bias, 1)) ]
            case 2: # Bias Exist, Mask non existed
                KernelFunc = [ ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreTdm2Addr, None, CoreBiasSize),
                               AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                               CallKernel('run_bcast_add_mini', bcast_add_params(Sq, Skv, dummy_var, CoreQdqPingAddr, CoreTdm1Addr, 
                                    CoreMaskPingAddr, CoreTdm2Addr, CoreTdm1Addr,    0, Bias, 1)),
                               RelBuffer(DmaChannel(DmaDir.S2MM, 0)) ]
            case 3: # Both Mask and Bias existed
                KernelFunc = [ CallKernel('run_bcast_add_mini', bcast_add_params(Sq, Skv, dummy_var, CoreQdqPingAddr, CoreTdm1Addr, 
                                    CoreMaskPingAddr, CoreTdm2Addr, CoreTdm1Addr,    1,   0, 1)),
                               ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreTdm2Addr, None, CoreBiasSize),
                               AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                               CallKernel('run_bcast_add_mini', bcast_add_params(Sq, Skv, dummy_var, CoreQdqPingAddr, CoreTdm1Addr, 
                                    CoreMaskPingAddr, CoreTdm2Addr, CoreTdm1Addr,    0,   1, 0)),
                               RelBuffer(DmaChannel(DmaDir.S2MM, 0)) ]
                  
        true_num_cols = Nlayer_core[(core_col_id % NumAieCompCols) * NumAieCompCols + core_row_id]

        return  [
            
        Loop(H, [
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreQdqPingAddr, None, CoreQdqPrmSize),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
            
            ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOutPingAddr, None, CoreOutSize),  # output is Pinned to TDM2 buffer
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreKeyPingAddr, None, CoreKeySize),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),      # key  unicast
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            ]
            +
            ([
                ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreMaskPingAddr, None, CoreMaskSize),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),  # mask unicast
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            ] if(Mask==1) else [])
            +
            [
            CallKernel('run_mini_mha_preprocess', mha_qkt_sfmx_qdq_params(
                    (Sq, Dh), (Dh, Skv), sum_mode, perform_8x8_block_transpose, core_col_id, core_row_id, dummy_var,
                    CoreTdm1Addr, CoreTdm2Addr, CoreQdqPingAddr, CoreAct1SumAddr, CoreAct2SumAddr,
                    CoreC0Addr, CoreScratchAddr, CoreQryPingAddr, CoreKeyPingAddr, 0, 0)),
            Loop(Tm , [    # loop over M, M//Msubv times
                AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),  # output
                ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreQryPingAddr, None, CoreQrySize),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),  # query broadcast
                CallKernel('run_gemm_qdq_mini', smxv_qdq_params(
                            (Sq, Dh), (Dh, Skv), 1, 0, core_col_id, core_row_id, 1, 1, dummy_var,
                            CoreTdm1Addr, CoreTdm2Addr, CoreQdqPingAddr, CoreAct1SumAddr, CoreAct2SumAddr,
                            CoreC0Addr, CoreScratchAddr, CoreQryPingAddr, CoreKeyPingAddr, CoreMaskPingAddr, CoreTdm2Addr, 128, transposeK)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                ] + 
                KernelFunc
                + [
                CallKernel('run_softmax_qdq', sfmx_params(
                            true_num_cols, Sq, Skv, 1, (core_col_id % 4), core_row_id, (CoreQdqPingAddr+320), (CoreQdqPingAddr+256), \
                            CoreTdm1Addr, CoreTdm2Addr, 1, 1.00
                        )),
                RelBuffer(DmaChannel(DmaDir.MM2S, 0))
            ]),
        ]),
        
        ]
    
    shape = OverlayShape(AieCols, AieRows)

    core_instrs_array = []
    for col in range(AieCols):
        for row in range(AieRows):
            core_instrs_array.append(get_core_instrs(col, row))
    
    instr_dict = {}
    for col in range(AieCols):
        for row in range(AieRows):
            instr_dict[AieTile(TileType.Core, col, row)] = core_instrs_array[col*AieRows+row]
    
    ##########################################################################################
    ###   Start of Mem-tile transfer :
    ##########################################################################################
    
    M    = Sin_q   ## Qry : M x P
    P    = Dh      ## Key : N x P
    Pgran= 8 
    
    Md = M % (Num4x4 * Sq)
    Mp = (M-Md) if(Md != 0) else (M-(Num4x4*Sq))    ## if M <= Num4x4 * Sq , Mp == 0
    Mr = M - Mp
    H_phases = 1 if(Md == 0) else 2   ## Number of phases per head
    Mp = M if(Md == 0) else Mp

    ## In case M % (Msubv*Num4x4) != 0, there is remainder of M left to be streamed. del_M calculate the remainder in this case
    def del_M(col : int):
        out = min(Mr, Sq) if (col <= 3) else max(Mr - Sq, 0)
        return out

    def N_off_min(col : int, input = True):
        col_mod_4 = col % NumAieCompCols
        N_cap = N if(input) else Nout
        return min(col_mod_4 * N_per_col, N_cap) 
        
    def N_off_max(col : int, input = True):
        col_mod_4 = col % NumAieCompCols
        N_cap = N if(input) else Nout
        return min((col_mod_4+1) * N_per_col, N_cap)
    
    Q_array_split = 2 
    NumPhases = H * H_phases  ## Total number of phases = Number of heads * Number of phases per head
    Qry_memtile_buffer  = f"Xq:{Dh} Yq:{Q_array_split*Sq} Xq:{Pgran}"
    Bias_memtile_buffer = f"Yb:{Sq} Xb:{AieRows*Skv}"
    Key_memtile_buffer  = f"Yk:{AieRows*Skv} Xk:{Dh}"
    Mask_memtile_buffer = f"Xm:{Nout}"
    Out_memtile_buffer  = f"Yo:{Sq} Xo:{AieRows*Skv}"
    
    def mem_Qry_s2mm(phase : int, col : int) -> str : 
        if((phase%H_phases)==0):
            return f"Yq:0:{2*Sq} Xq:0:{Dh}"
        else:
            return f"Yq:0:{Mr} Xq:0:{Dh}"
            #return f"Yq:0:{del_M(col)} Xq:0:{Dh}"

    def mem_Qry_mm2s() -> str:
        return f"Yq:{0}:{2*Sq}:{Sq} Xq:0:{Dh}:{Pgran} Yq:0:{Sq} Xq:0:{Pgran}"
    
    def mem_Qry_repeatCount() -> List:
        return [Tm-1,1]*H if (H_phases == 2) else [Tm]*H

    def mem_Mask_s2mm(phase : int, col : int) -> str : 
        if((phase%H_phases)==0):
            return f"Xm:0:{N_off_max(col, False)-N_off_min(col, False)}"  ## Mask need to use False as it is 1 x N, N inner
        else:
            return f"Xm:0:0"
        
    def mem_Mask_mm2s(row : int) -> str:
        return f"Xm:{row*Skv}:{(row+1)*Skv}"

    def mem_Key_s2mm(phase : int, col : int) -> str : 
        if((phase%H_phases)==0):
            return f"Yk:0:{N_off_max(col)-N_off_min(col)} Xk:0:{Dh}"
        else:
            return f"Yk:0:0 Xk:0:0"
        
    def mem_Key_mm2s(row : int) -> str:
        return f"Xk:0:{Dh}:{Pgran} Yk:{row*Skv}:{(row+1)*Skv} Xk:0:{Pgran}"
    
    def mem_KeyMask_repeatCount() -> List:
        return [1,0]*H if (H_phases == 2) else [1]*H

    def mem_Bias_s2mm(phase : int, col : int) -> str : 
        if((phase%H_phases)==0):
            return f"Yb:0:{    Sq    } Xb:0:{N_off_max(col, False)-N_off_min(col, False)}"
        else:
            return f"Yb:0:{del_M(col)} Xb:0:{N_off_max(col, False)-N_off_min(col, False)}"
        
    def mem_Bias_mm2s(row : int) -> str:
        return f"Xb:{row*Skv}:{(row+1)*Skv}:{Ngran} Yb:{0}:{Sq} Xb:0:{Ngran}"
    
    def mem_Bias_repeatCount() -> List:
        return [Tm-1,1]*H if (H_phases == 2) else [Tm]*H
        
    def mem_Out_mm2s(phase : int, col : int) -> str : 
        del_Nout = N_off_max(col, False) - N_off_min(col, False)
        if((phase%H_phases)==0):
            return f"Yo:0:{    Sq    } Xo:0:{del_Nout}"
        else:
            return f"Yo:0:{del_M(col)} Xo:0:{del_Nout}"
        
    def mem_Out_s2mm(row : int) -> str:
        return f"Xo:{row*Skv}:{(row+1)*Skv}:{Ngran} Yo:0:{Sq} Xo:0:{Ngran}"

    def mem_Out_repeatCount() -> List:
        return [To-1,1]*H if (H_phases == 2) else [To]*H
    
    def mem_Qdq_repeatCount() -> List:
        return [1,0]*H  if (H_phases == 2) else [1]*H
    
    def mem_LayerPrm_repeatCount() -> List:
        return [1]+[0]*(H*H_phases-1)

            
    mem_bt = [ 
            DataTransfer(
                mem_Bias_repeatCount(),
                AieTile(TileType.Memtile, col), [MemtileQkvPingAddr], MemtileQkvSize,
                ([
                    bd_reconfig(memtile_dma(col, DmaDir.S2MM,  1 ),  
                        [Bias_memtile_buffer]*NumPhases,
                        [mem_Bias_s2mm(p, col) for p in range(NumPhases)],    
                        buffer_offset=[MemtileQrySize+MemtileKeySize+MemtileMaskSize]*NumPhases
                    )
                ])
                ,
                ([
                    bd_reconfig(memtile_dma(col, DmaDir.MM2S, row),  
                        [Bias_memtile_buffer]*NumPhases,
                        [mem_Bias_mm2s(row) for p in range(NumPhases)],   
                        buffer_offset=[MemtileQrySize+MemtileKeySize+MemtileMaskSize]*NumPhases,
                    ) for row in range(AieRows)   
                ])
            ) for col in range(AieCols)
    ]

    mem_qt = [ 
            DataTransfer(
                mem_Qry_repeatCount(),
                AieTile(TileType.Memtile, col), [MemtileQkvPingAddr], MemtileQkvSize,
                [
                    bd_reconfig(memtile_dma(col, DmaDir.S2MM,  0 ),  
                        [Qry_memtile_buffer]*NumPhases,
                        [mem_Qry_s2mm(p, col) for p in range(NumPhases)],
                        buffer_offset=[0]*NumPhases,
                    )
                ]
                ,
                [
                    bd_reconfig(memtile_dma(col, DmaDir.MM2S, 4 ),  
                        [Qry_memtile_buffer]*NumPhases,
                        [mem_Qry_mm2s() for p in range(NumPhases)],
                        buffer_offset=[0]*NumPhases,
                    ) 
                ]
                
            ) for col in range(0,AieCols, Num4x4)
    ]

    mem_kmt = [     
            DataTransfer(    
                mem_KeyMask_repeatCount(),
                AieTile(TileType.Memtile, col), [MemtileQkvPingAddr], MemtileQkvSize,
                [
                    bd_reconfig(memtile_dma(col, DmaDir.S2MM,  1 ),  
                        [Key_memtile_buffer]* NumPhases,       
                        [mem_Key_s2mm(p, col) for p in range(NumPhases)],          
                        buffer_offset=[MemtileQrySize]*NumPhases,
                    ),
                ]
                +
                ([
                    bd_reconfig(memtile_dma(col, DmaDir.S2MM,  1 ),  
                        [Mask_memtile_buffer]* NumPhases,
                        [mem_Mask_s2mm(p, col) for p in range(NumPhases)],
                        buffer_offset=[MemtileQrySize+MemtileKeySize]*NumPhases,
                    )
                ] if (Mask==1) else [])
                ,  
                [
                    bd_reconfig(memtile_dma(col, DmaDir.MM2S, row),  
                        [Key_memtile_buffer]* NumPhases,
                        [mem_Key_mm2s(row) for p in range(NumPhases)],
                        buffer_offset=[MemtileQrySize]*NumPhases,
                    ) for row in range(AieRows) 
                ]
                +
                ([
                    bd_reconfig(memtile_dma(col, DmaDir.MM2S, row),  
                        [Mask_memtile_buffer]* NumPhases,
                        [mem_Mask_mm2s(row) for p in range(NumPhases)],
                        buffer_offset=[MemtileQrySize+MemtileKeySize]*NumPhases,
                    ) for row in range(AieRows)   
                ] if (Mask==1) else [])
            ) for col in range(AieCols)
    ]

    mem_ot = [ 
            DataTransfer(   
                mem_Out_repeatCount(),
                AieTile(TileType.Memtile, col), [MemtileOutPingAddr], MemtileOutSize,
                [
                    bd_reconfig(memtile_dma(col, DmaDir.S2MM, row+2),  
                        [Out_memtile_buffer]*NumPhases, 
                        [mem_Out_s2mm(row) for p in range(NumPhases)],
                        buffer_offset=[0]*NumPhases
                    ) for row in range(AieRows)
                ],
                [
                    bd_reconfig(memtile_dma(col, DmaDir.MM2S,   5  ),  
                        [Out_memtile_buffer]*NumPhases, 
                        [mem_Out_mm2s(p, col) for p in range(NumPhases)],
                        buffer_offset=[0]*NumPhases
                    )
                ]
            ) for col in range(AieCols)
    ]
    
    mem_qdq_t = [
        DataTransfer(
            mem_Qdq_repeatCount(), Memtile(col), [MemtileQdqPingAddr], MemtileQdqPrmSize,
            [TransferParams(memtile_dma(col, DmaDir.S2MM, 0),  MemtileQdqPrmWords)],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, 4),  MemtileQdqPrmWords)]
            ) for col in range(0,AieCols,Num4x4)
    ]
    
    memtile_transfers = [
        DataTransfer(
            mem_LayerPrm_repeatCount(), Memtile(col),  [MemtilePrmPingAddr], MemtilePrmSize,
            [TransferParams(memtile_dma(col, DmaDir.S2MM,  1 ), MemtilePrmWords)],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, row), CorePrmWords, offset=(row * CorePrmWords)) for row in range(AieRows)]  
        ) for col in range(AieCols)
    ] + mem_qdq_t + mem_kmt + mem_qt + (mem_bt if (Bias==1) else []) + mem_ot
    
    ##########################################################################################
    ###   Start of Shim-tile transfer :
    ##########################################################################################
    
    ## offset on axis M, function of column:col and63 phase:h
    def M_off( head_idx: int, col: int):
        return head_idx*M + (0 if col < NumAieCompCols else Sq)
        
    shim_bo_Qry  = f"Yq:{H*(M+N)} Xq:{P}"
    shim_bo_Key  = f"Yk:{H*(M+N)} Xk:{P}"   
    shim_bo_Mask = f"Ym:{H} Xm:{Nout}"         ## Mask : Hx1xN
    shim_bo_Bias = f"Yb:{H*M} Xb:{Nout}"   
    shim_bo_Out  = f"Yo:{H*M} Xo:{Nout}"
        
    def shim_Qry_mm2s(phase : int, col : int) -> str:
        p = phase
        h = phase // H_phases
        #if((p%H_phases)==0):
        #    return f"Yq:{M_off(h,col)}:{M_off(h,col)+Mp}:{Num4x4*Sq}    Yq:0:{Sq} Xq:0:{P}"
        #else:
        #    return f"Yq:{M_off(h,col)+Mp}:{M_off(h,col)+Mp+del_M(col)} Xq:0:{P}"
        if((p%H_phases)==0):
            return f"Yq:{h*M}:{h*M+Mp} Xq:0:{P}"
        else:
            return f"Yq:{h*M+Mp}:{h*M+Mp+Mr} Xq:0:{P}"
        
    def shim_Qry_repeatCount() -> List:
        return [1,1]*(H) if (H_phases == 2) else [1]*H

    def shim_Key_mm2s(phase : int, col : int) -> str:
        p = phase
        h = p // H_phases
        if((p%H_phases)==0):
            return f"Yk:{N*h+N_off_min(col)}:{N*h+N_off_max(col)} Xk:0:{P}"
        else:
            return f"Yk:0:0 Xk:0:0"
        
    def shim_Mask_mm2s(phase : int, col : int) -> str:
        p = phase
        if((p%H_phases)==0):
            return f"Ym:0:1 Xm:{N_off_min(col, False)}:{N_off_max(col, False)}"
        else:
            return f"Ym:0:0 Xm:0:0"
        
    def shim_KeyMask_repeatCount() -> List:
        return [1,0]*(H) if (H_phases == 2) else [1]*H
        
    def shim_Bias_mm2s(phase : int, col : int) -> str:
        p = phase
        h = p // H_phases
        if((p%H_phases)==0):
            return f"Yb:{M_off(h,col)}:{M_off(h,col)+Mp}:{Num4x4 * Sq} Yb:0:{Sq} Xb:{N_off_min(col, False)}:{N_off_max(col, False)}"
        else:
            return f"Yb:{M_off(h,col)+Mp}:{M_off(h,col)+Mp+del_M(col)} Xb:{N_off_min(col, False)}:{N_off_max(col, False)}"
      
    def shim_Bias_repeatCount() -> List:
        return [1,1]*(H) if (H_phases == 2) else [1]*H  

    def shim_Out_s2mm(phase : int, col : int) -> str:
        p = phase
        h = p // H_phases
        if((p%H_phases)==0):
            return f"Yo:{M_off(h,col)}:{M_off(h,col)+Mp}:{Num4x4*Sq} Yo:0:{Sq} Xo:{N_off_min(col, False)}:{N_off_max(col, False)}"
        else:
            return f"Yo:{M_off(h,col)+Mp}:{M_off(h,col)+Mp+del_M(col)} Xo:{N_off_min(col, False)}:{N_off_max(col, False)}"

    def shim_Out_repeatCount() -> List:
        return [1,1]*(H) if (H_phases == 2) else [1]*H
    
    def shim_Qdq_repeatCount() -> List:
        return [1,0]*H if (H_phases == 2) else [1]*H

    def shim_LayerPrm_repeatCount() -> List:
        return [1]+[0]*(H*H_phases-1)
    
    shim_km_t = [
        DataTransfer(
            shim_KeyMask_repeatCount(),
            AieTile(TileType.Shim, col, 0), [1], H*(M+N)*P*IfmBytes,
            [],
            [
                bd_reconfig(shim_dma(col, DmaDir.MM2S, 1), 
                        [shim_bo_Key]* NumPhases,
                        [shim_Key_mm2s( p, col) for p in range(NumPhases)],
                        buffer_offset=[H*M*P*IfmBytes]*NumPhases
                ),
            ]
            +
            ([
                bd_reconfig(shim_dma(col, DmaDir.MM2S, 1), 
                        [shim_bo_Mask]* NumPhases,
                        [shim_Mask_mm2s(p, col) for p in range(NumPhases)],
                        buffer_offset=[H*(M+N)*P*IfmBytes]* NumPhases
                ),
            ] if(Mask==1) else [])
        ) for col in range(AieCols)
    ]
    
    shim_q_t = [
        DataTransfer(
            shim_Qry_repeatCount(),
            AieTile(TileType.Shim, col, 0), [1], H*(M+N)*P*IfmBytes,
            [],
            [
                bd_reconfig(shim_dma(col, DmaDir.MM2S, 0), 
                        [shim_bo_Qry]*NumPhases,
                        [shim_Qry_mm2s( p, col) for p in range(NumPhases)],
                        buffer_offset=[0]*NumPhases
                ) 
            ]
        ) for col in range(0, AieCols, Num4x4)
    ]
    
    shim_b_t = [
        DataTransfer(
            shim_Bias_repeatCount(),
            AieTile(TileType.Shim, col, 0), [1], H*(M+N)*P*IfmBytes,
            [],
            [
                bd_reconfig(shim_dma(col, DmaDir.MM2S, 1), 
                        [shim_bo_Bias]*NumPhases,
                        [shim_Bias_mm2s(p, col) for p in range(NumPhases)],
                        buffer_offset=[(H*(M+N)*P+H*1*Nout)*IfmBytes]*NumPhases  
                ),
            ]
        ) for col in range(AieCols)
    ]
    
    shim_ot_te = [ 
        DataTransfer(
            shim_Out_repeatCount(),
            AieTile(TileType.Shim, col, 0), [0], H*M*N*OutBytes,
            [
                bd_reconfig(shim_dma(col, DmaDir.S2MM, 0), 
                        [shim_bo_Out]*NumPhases,
                        [shim_Out_s2mm( p, col) for p in range(NumPhases)],
                        buffer_offset=[0]*NumPhases
                ) 
            ],
            []
        ) for col in range(AieCols)
    ]

    shim_qdq = [
        DataTransfer(
            shim_Qdq_repeatCount(), Shimtile(col), [2], ShimQdqPrmSize,
            [],
            [TransferParams(shim_dma(col, DmaDir.MM2S, 0), ShimQdqPrmWords)]
        ) for col in range(0,AieCols, Num4x4)
    ]

    shim_transfers = [
        DataTransfer(
            shim_LayerPrm_repeatCount(), Shimtile(col), [3], MemtilePrmSize,
            [],
            [TransferParams(shim_dma(col, DmaDir.MM2S, 1), MemtilePrmWords, offset=((col * MemtilePrmWords)))]
        ) for col in range(AieCols)
    ] + shim_qdq + shim_km_t + shim_q_t + (shim_b_t if(Bias) else []) + shim_ot_te 

    dma_connections = overlay_8x4_dma_connections() if Num4x4 == 2 else overlay_4x4_dma_connections()
    
    if debug_print:
        print("Start MHA-2p1 Layer compilation:")
        print("CodeBackend = ", CodeBackend)
        print("kernel_names:", kernel_names)
        print("kernel_includes:", kernel_includes)

    run_layer_compilation(
        shape,
        kernel_names,
        kernel_includes,
        instr_dict, 
        memtile_transfers,
        shim_transfers,
        dma_connections=dma_connections,
        back_end=CodeBackend,
        core_stack_addr=overlay_stack_addr(), 
        param_channel_id=0,
        layer_file='dma.hpp',
        casc_dir = CascDir.Vertical,
        core_connections = overlay_8x4_core_stream_bdcast() if Num4x4 == 2 else overlay_4x4_core_stream_bdcast()
    )
    if debug_print:
        print("Exits MHA-2p1 Layer compilation")

