import sys
sys.path.append('.')
sys.path.append('../../dataflow/')
sys.path.append('../../dmacompiler/')


from dataflow_utils_shim import access_shim_rm_vert_shard, gen_Out_shim_data_transfer, gen_Qry_shim_data_transfer #, access_shim_rm_hori_shard
from dataflow_utils_mem import write_L2_rm_to_w8_subvolumes, read_L2_w8_to_rm_subvolumes
from mha_3p0_params import mha_3p0_qdq_params
from dataflow_common import overlay_8x4_dma_connections
from dataflow_common import overlay_4x4_dma_connections

from dmacompiler import OverlayShape, DataTransfer, TransferParams, SyncStrategy, BackEnd, \
    DmaChannel, DmaDir, AieDma, AieTile, TileType, DmaConnection, \
    ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    run_overlay_deadlock_check, \
    shim_dma, memtile_dma, \
    run_layer_compilation, \
    generate_shim_data_transfer

def align_to_64bytes(size):
    """Round ``size`` up to the next multiple of 64 (unchanged if already a multiple)."""
    alignment_mask = 64 - 1
    # Add the mask, then clear the low six bits.
    return (size + alignment_mask) & ~alignment_mask

def ceildiv(x: int, d: int) -> int:
    """Return ceil(x / d) using exact integer arithmetic (no float rounding)."""
    # Floor-dividing by -d and negating the result rounds toward +infinity.
    negated_floor = x // -d
    return -negated_floor


def iceil(x: int, d: int) -> int:
    """Round ``x`` up to a multiple of ``d`` (smallest multiple >= x for positive d)."""
    return d * ceildiv(x, d)

Num4x4 = 1                       # number of 4x4 groups; 1 selects the 4x4 DMA connections, otherwise 8x4
NumAieCompCols = 4               # compute columns per 4x4 group
AieRows = 4
AieCols = Num4x4 * NumAieCompCols
BytesPerWord = 4

# Command line: <backend> <shapeId>
#   backend: "Adf" -> BackEnd.Adf; any other value -> BackEnd.TxnHostPatch
#   shapeId: integer key selecting the (St, S) shape
if len(sys.argv) < 3:
    # Fail with a usage message rather than an IndexError on sys.argv.
    sys.exit(f"usage: {sys.argv[0]} <Adf|Txn> <shapeId>")

H = 2                            # head count (D = 64 * H); overridden for the Txn backend
backend_type = sys.argv[1]
shapeId = int(sys.argv[2])
if backend_type == "Adf":
    CodeBackEnd = BackEnd.Adf
else:
    CodeBackEnd = BackEnd.TxnHostPatch
    H = 20 if shapeId == 1 else 10


min_dim = 64
sfmx_mask = False
V_DDR_Tranposed = False

# (St, S) per shapeId. St sizes the Q/O tensors (St * Di, St * D) and S sizes the
# K/V tensors (Di * S, S * D).
# NOTE(review): shape 3 gives Sic = 256, so the per-core K and V buffers alone
# total 64 KiB and the CoreScratchAddr < 48 KiB assertion cannot pass -- confirm
# whether shape 3 is expected to build with this core memory layout.
_SHAPES_BY_ID = {
    1: (256, 256),
    2: (1024, 1024),
    3: (4096, 4096),
    4: (1024, 128),
    5: (256, 128),
}
try:
    (St, S) = _SHAPES_BY_ID[shapeId]
except KeyError:
    raise NotImplementedError(f"ShapeID ({shapeId}) not supported.") from None

D = 64 * H
Di = 64 * H

Sc = 32                                      # Q rows per core tile
Sic = S // (AieRows * NumAieCompCols)        # K/V rows handled by one core
Dh = D // H
Dih = Di // H
if Sic < 16:
    sfmx_mask = True

# Element sizes in bytes.
QryBytes    = 2
KeyBytes    = 2
ValBytes    = 2
OutBytes    = 2
QdqNodes    = 6                                       # CoreQdqPrmSize = QdqNodes * QdqPrm * QdqPrmBytes
QdqPrm      = 16                                      # parameters per QDQ node
QdqPrmBytes = 4
TdmBytes    = 4
C0Bytes     = 8   # only referenced by the commented-out CoreC0_K_Size formula below

# Per-core (L1) buffer sizes in bytes.
CoreQrySize     = (Sc  * Dih * QryBytes) * Num4x4
CoreKeySize     = (Dih * Sic * KeyBytes)
CoreValSize     = (Sic * Dh  * ValBytes)
CoreOutSize     = (Sc  * Dh  * OutBytes)
CoreTdmSize     = Sc * max(min_dim, Sic) * 4 # literal 4 == TdmBytes; 8192 bytes whenever Sic <= 64
CorePrmSize     = 1024  
CoreKeyValSize  = CoreKeySize + CoreValSize   # K and V share one unicast S2MM buffer starting at CoreKeyPingAddr
CoreQdqPrmSize  = (QdqNodes * QdqPrm * QdqPrmBytes) # 384 bytes
CoreAct1SumSize = iceil(align_to_64bytes(TdmBytes*Sc),256) # 256 bytes for Sc = 32; earlier variant: align_to_64bytes(TdmBytes*Sc)
CoreAct2SumSize = 1024 # fixed; earlier formula max(Sic, Dh) * TdmBytes * 4. NOTE: the fixed 1024 is needed for the test to pass
CoreC0_K_Size   = 1024 # fixed; earlier formula max(Sic, Dh) * C0Bytes (optionally * 4)
CoreC0_V_Size   = 1024  

# Unused padded-size variants, kept for reference:
#CoreKeyPadSize = (Dih * max(min_dim, Sic) * KeyBytes)
#CoreValPadSize = (max(min_dim, Sic) * Dh * ValBytes)

# Memtile (L2) buffer sizes in bytes: K/V hold one core slice per AIE row.
MemtileKeySize    = CoreKeySize * AieRows
MemtileValSize    = CoreValSize * AieRows
MemtileKvbSize    = MemtileKeySize + MemtileValSize
MemtilePrmSize    = CorePrmSize * AieRows 
MemtileQdqPrmSize = CoreQdqPrmSize                 # 384
#MemtileKeyPadSize = CoreKeyPadSize * AieRows
#MemtileValPadSize = CoreValPadSize * AieRows

# Shim-side transfer sizes in bytes. Q and O cover the whole tensor; K and V are
# per compute column (divided by NumAieCompCols).
ShimQrySize    = (St * Di * QryBytes)
ShimKeySize    = (Di * S  * KeyBytes) // NumAieCompCols
ShimValSize    = (S  * D  * ValBytes) // NumAieCompCols
ShimOutSize    = (St * D  * OutBytes)
ShimKvbSize    = ShimKeySize + ShimValSize
ShimQdqPrmSize = MemtileQdqPrmSize                 # 384

# Core (L1) local memory map. Byte addresses are laid out back-to-back from 0:
# Q | K | V | O | TDM1 | TDM2 | QDQ params | act1 sums | act2 sums | C0_K | C0_V | scratch
CoreQryPingAddr  = 0
CoreKeyPingAddr  = CoreQryPingAddr + CoreQrySize
CoreValPingAddr  = CoreKeyPingAddr + CoreKeySize
CoreOutAddr      = CoreValPingAddr + CoreValSize
CoreTdm1Addr     = CoreOutAddr     + CoreOutSize
CoreTdm2Addr     = CoreTdm1Addr    + CoreTdmSize
CoreQdqPingAddr  = CoreTdm2Addr    + CoreTdmSize                         
CoreAct1SumAddr  = CoreQdqPingAddr + CoreQdqPrmSize            ## NOTE: earlier notes mention reserving max(512, CoreQdqPrmSize) and 64-bit alignment (see utils); this reserves exactly CoreQdqPrmSize
CoreAct2SumAddr  = CoreAct1SumAddr + CoreAct1SumSize           ## NOTE: act1 sums for QDQ; size was discussed as 4*Sq or max(512, TdmBytes*Sq)
CoreC0_K_Addr    = CoreAct2SumAddr + CoreAct2SumSize           ## NOTE: earlier size formula was Skv * TdmBytes * 4; the extra factor of 4 was dropped after review (its origin was unclear)
CoreC0_V_Addr    = CoreC0_K_Addr   + CoreC0_K_Size
CoreScratchAddr  = CoreC0_V_Addr   + CoreC0_V_Size  
CoreStackAddr    = 56 * 1024   # passed to run_layer_compilation as core_stack_addr

print("CoreQrySize:", CoreQrySize)
print("CoreScratchAddr:", CoreScratchAddr)
# Everything before the scratch region must start below 48 KiB.
assert(CoreScratchAddr < 48 * 1024) 

print("CoreQryPingAddr", CoreQryPingAddr)
print("CoreKeyPingAddr", CoreKeyPingAddr)
print("CoreValPingAddr", CoreValPingAddr)
print("CoreOutAddr"    , CoreOutAddr    )
print("CoreTdm1Addr"   , CoreTdm1Addr   )
print("CoreTdm2Addr"   , CoreTdm2Addr   )


# Memtile (L2) memory map. Byte addresses are laid out back-to-back from 0:
# Q ping | Q pong (CoreQrySize each) | params | K/V ping | K/V pong | O (single buffer) | QDQ params
# NOTE(review): unlike the core map, no assertion checks that this fits in the memtile.
MemtileQryPingAddr  = 0
MemtileQryPongAddr  = MemtileQryPingAddr + CoreQrySize
MemtilePrmPingAddr  = MemtileQryPongAddr + CoreQrySize
MemtileKvbPingAddr  = MemtilePrmPingAddr + MemtilePrmSize
MemtileKvbPongAddr  = MemtileKvbPingAddr + MemtileKvbSize
MemtileOutPingAddr  = MemtileKvbPongAddr + MemtileKvbSize
MemtileQdqPingAddr  = MemtileOutPingAddr + CoreOutSize      
#MemtileMaskPingAddr = MemtileOutPingAddr + CoreOutSize

def get_Words_from_Sizes(v_sizes : list):
    """Convert byte sizes into word counts, where a word is ``BytesPerWord`` bytes.

    Every size must be an exact multiple of ``BytesPerWord``; otherwise an
    AssertionError is raised. Returns a tuple in the same order as ``v_sizes``.
    """
    def _bytes_to_words(nbytes):
        assert(nbytes % BytesPerWord == 0)
        return nbytes // BytesPerWord

    return tuple(_bytes_to_words(nbytes) for nbytes in v_sizes)

# Word-count views of the byte sizes above (each size must be word aligned).
(
    CoreQryWords,
    CoreKeyWords,
    CoreValWords,
    CoreOutWords,
    CorePrmWords,
    CoreQdqPrmWords,
) = get_Words_from_Sizes([
    CoreQrySize, CoreKeySize, CoreValSize, CoreOutSize, CorePrmSize, CoreQdqPrmSize,
])

(
    MemtileKeyWords,
    MemtileValWords,
    MemtileKvbWords,
    MemtilePrmWords,
    MemtileQdqPrmWords,
) = get_Words_from_Sizes([
    MemtileKeySize, MemtileValSize, MemtileKvbSize, MemtilePrmSize, MemtileQdqPrmSize,
])

(
    ShimQryWords,
    ShimKeyWords,
    ShimValWords,
    ShimKvbWords,
    ShimOutWords,
    ShimQdqPrmWords,
) = get_Words_from_Sizes([
    ShimQrySize, ShimKeySize, ShimValSize, ShimKvbSize, ShimOutSize, ShimQdqPrmSize,
])


def Memtile(col: int):
    """Return the memtile ``AieTile`` in column ``col`` (row index 0)."""
    memtile_row = 0
    return AieTile(TileType.Memtile, col, memtile_row)


def shim_tile(col: int):
    """Return the shim ``AieTile`` in column ``col`` (row index 0)."""
    shim_row = 0
    return AieTile(TileType.Shim, col, shim_row)

# Word offsets inside one memtile K/V buffer: K comes first, and V starts right
# after the AieRows per-core K slices.
MemtileKeyOffset = 0
MemtileValOffset = AieRows * CoreKeyWords


# Shim-side per-head shard sizes, in words (BytesPerWord bytes each). Each
# *RepeatNum is the number of such shards in the corresponding shim transfer.
ShimQryShardWords = (St * Dih * QryBytes) // BytesPerWord
ShimQryRepeatNum = ShimQryWords // ShimQryShardWords
assert ShimQryWords % ShimQryShardWords == 0

# One K shard = the AieRows core K slices fed to one column for one head.
ShimKeyShardWords = CoreKeyWords * AieRows
assert ShimKeyShardWords * NumAieCompCols == (Di * S  * KeyBytes) // H // BytesPerWord
# NOTE(review): ShimKeyOffset / ShimValOffset are only referenced in commented-out code.
ShimKeyOffset = ((S * Di * KeyBytes) // NumAieCompCols) // BytesPerWord
ShimKeyRepeatNum = ShimKeyWords // ShimKeyShardWords
assert ShimKeyWords % ShimKeyShardWords == 0

ShimValShardWords = CoreValWords * AieRows
assert ShimValShardWords * NumAieCompCols == (S  * D  * ValBytes) // H // BytesPerWord
ShimValOffset = ((S * D * ValBytes) // NumAieCompCols) // BytesPerWord
ShimValRepeatNum = ShimValWords // ShimValShardWords
assert ShimValWords % ShimValShardWords == 0


ShimOutShardWords = (St * Dh * OutBytes) // BytesPerWord
ShimOutRepeatNum = ShimOutWords // ShimOutShardWords
assert ShimOutWords % ShimOutShardWords == 0


# Repeat counts (number of buffer fills) for the memtile transfers:
#   Tq  - Q tiles per bcast memtile (one CoreQrySize tile per fill)
#   Tkv - K/V fills per memtile (one MemtileKvbSize block per fill)
#   To  - O tiles per output memtile
#   Li  - Q tiles each core processes per head (inner Loop in get_core_instrs)
Tq  = ShimQryWords // CoreQryWords
Tkv = ShimKvbWords // MemtileKvbWords
To  = ShimOutWords // CoreOutWords // Num4x4
Li  = St // Sc // Num4x4

assert ShimQryWords % CoreQryWords == 0
assert ShimKvbWords % MemtileKvbWords == 0
# NOTE(review): To also divides by Num4x4, which this check does not cover (Num4x4 == 1 here).
assert ShimOutWords % CoreOutWords == 0
assert Tq == To
assert St % Sc == 0

MemtileQryRepeat  = [Tq]
MemtileKeyValRepeat = [Tkv]
MemtileOutRepeat    = [To]

# NOTE(review): the three Shim*Repeat lists below are not used by the active code in this file.
ShimQryRepeat = [ShimQryRepeatNum]
ShimKvbRepeat = [ShimKeyRepeatNum]
ShimOutRepeat = [ShimOutRepeatNum]

# Cross-check: every shim tensor is moved as H per-head shards, and the memtile
# repeat counts match the per-head tiling.
assert ShimKeyRepeatNum == ShimValRepeatNum
assert ShimKeyRepeatNum == H
assert ShimQryRepeatNum == H
assert ShimOutRepeatNum == H
assert MemtileQryRepeat == [H * (St // Sc) // Num4x4]
assert MemtileKeyValRepeat == [H]
assert MemtileOutRepeat == [H * (St // Sc) // Num4x4]

print("MemtileQryRepeat=", MemtileQryRepeat)
print("CoreOutWords=", CoreOutWords)
print("ShimQryRepeatNum=", ShimQryRepeatNum)
print("ShimQryWords=",ShimQryWords)

shape = OverlayShape(AieCols, AieRows)

# Kernel sources and the kernel entry points invoked by CallKernel in get_core_instrs.
kernel_includes = ['super.hh', 'mha_qdq/wrapper_mha_3p0_4x4_i16i16.cc']
kernel_names = [
    'run_act_K_preprocess',
    'run_act_V_preprocess',
    'run_qkt_gemm_qdq',
    'run_sfmx_i16_to_i16',
    'run_smxv_gemm_qdq',
]  # earlier set: ['run_mha_sdxl', 'run_kv_to_bfp16', 'run_set_gemm_params']

# DMA channel numbers. Both overlay variants currently use the same values; the
# conditional form keeps a place for 8x4-specific values.
L2_MM2S_BCAST_CHAN      = 4 if Num4x4 == 1 else 4
L2_MM2S_UCAST_CHAN_BASE = 0 if Num4x4 == 1 else 0
CORE_UCAST_CHANNEL_ID   = 0 if Num4x4 == 1 else 0
CORE_BCAST_CHANNEL_ID   = 1 if Num4x4 == 1 else 1

# Columns that receive broadcast data (Q, QDQ params) and columns that emit O.
SET_BCAST_COLUMNS  = list(range(0, AieCols, Num4x4))
SET_OUTPUT_COLUMNS = list(range(3, AieCols, NumAieCompCols))

if Num4x4 == 1:
    dma_connections = overlay_4x4_dma_connections()
else:
    dma_connections = overlay_8x4_dma_connections()

# 1 when the core's unicast input arrives on S2MM channel 1, else 0.
ucast_to_core_s2mm1 = int(CORE_UCAST_CHANNEL_ID == 1)

def gen_mha_params(
    ucast_to_s2mm1,
    multi_core,
    aie_col_id,
    aie_row_id
):
    """Build the parameter blob passed to every CallKernel on one core.

    Forwards the caller's per-core flag, multi-core flag and tile coordinates to
    ``mha_3p0_qdq_params``. The remaining arguments come from module-level
    globals: the tile shapes (Sc, Dh) and (Dh, Sic), Dh, S, and the core-local
    addresses of the TDM, QDQ, act-sum, C0 and scratch buffers.
    """
    tile_shapes = ((Sc, Dh), (Dh, Sic))
    core_buffer_addrs = (
        CoreTdm1Addr, CoreTdm2Addr, CoreQdqPingAddr,
        CoreAct1SumAddr, CoreAct2SumAddr,
        CoreC0_K_Addr, CoreC0_V_Addr, CoreScratchAddr,
    )
    return mha_3p0_qdq_params(
        *tile_shapes,
        ucast_to_s2mm1, multi_core,
        aie_col_id, aie_row_id,
        Dh, S,
        *core_buffer_addrs,
    )

def _is_output_core(aie_col_id: int, aie_row_id: int) -> bool:
    """True for row 0 of any column listed in SET_OUTPUT_COLUMNS."""
    return aie_col_id in SET_OUTPUT_COLUMNS and aie_row_id == 0

def Conditional_AcqBuffer(dma_channel: DmaChannel, aie_col_id:int, aie_row_id:int):
    """Return ``[AcqBuffer(dma_channel)]`` for the output core, else ``[]``.

    The result is spliced into an instruction list with ``+``.
    """
    if not _is_output_core(aie_col_id, aie_row_id):
        return []
    return [AcqBuffer(dma_channel)]

def Conditional_RelBuffer(dma_channel: DmaChannel, aie_col_id:int, aie_row_id:int):
    """Return ``[RelBuffer(dma_channel)]`` for the output core, else ``[]``.

    The result is spliced into an instruction list with ``+``.
    """
    if not _is_output_core(aie_col_id, aie_row_id):
        return []
    return [RelBuffer(dma_channel)]

def get_core_instrs(core_col_id:int, core_row_id:int):
    """Return the instruction list for the compute core at (core_col_id, core_row_id).

    Program outline:
      1. Receive the QDQ parameters on the broadcast S2MM channel (acquire + release).
      2. Configure the broadcast S2MM channel for Q tiles and MM2S 0 for the O tile.
      3. Loop H times: configure/acquire the unicast K/V buffer, call the K and V
         preprocessing kernels, then loop Li times over Q tiles calling
         run_qkt_gemm_qdq, run_sfmx_i16_to_i16 and run_smxv_gemm_qdq, and finally
         release the K/V buffer.
    The O buffer is acquired/released only when Conditional_AcqBuffer /
    Conditional_RelBuffer select this core.
    """
    params = gen_mha_params(ucast_to_core_s2mm1, 1, core_col_id, core_row_id)

    # Factories so every instruction gets its own DmaChannel instance, as before.
    def bcast_s2mm():
        return DmaChannel(DmaDir.S2MM, CORE_BCAST_CHANNEL_ID)

    def ucast_s2mm():
        return DmaChannel(DmaDir.S2MM, CORE_UCAST_CHANNEL_ID)

    def out_mm2s():
        return DmaChannel(DmaDir.MM2S, 0)

    program = [
        # QDQ parameters: one-shot receive.
        ConfigBuffer(bcast_s2mm(), CoreQdqPingAddr, None, CoreQdqPrmSize),
        AcqBuffer(bcast_s2mm()),
        RelBuffer(bcast_s2mm()),

        # Q input window and O output window, configured once for all heads.
        ConfigBuffer(bcast_s2mm(), CoreQryPingAddr, None, CoreQrySize   ),
        ConfigBuffer(out_mm2s(), CoreOutAddr,     None, CoreOutSize   ),
    ]

    per_head = [
        ConfigBuffer(ucast_s2mm(), CoreKeyPingAddr, None, CoreKeyValSize),
        AcqBuffer(ucast_s2mm()),                                            # acq KV
        CallKernel('run_act_K_preprocess', params),
        CallKernel('run_act_V_preprocess', params),
    ]

    per_q_tile = (
        Conditional_AcqBuffer(out_mm2s(), core_col_id, core_row_id)        # acq O
        + [AcqBuffer(bcast_s2mm())]                                         # acq Q
        + [CallKernel(kernel, params) for kernel in (
            'run_qkt_gemm_qdq', 'run_sfmx_i16_to_i16', 'run_smxv_gemm_qdq')]
        + [RelBuffer(bcast_s2mm())]                                         # rel Q
        + Conditional_RelBuffer(out_mm2s(), core_col_id, core_row_id)      # rel O
    )
    per_head.append(Loop(Li, per_q_tile))
    per_head.append(RelBuffer(ucast_s2mm()))                                # rel KV

    program.append(Loop(H, per_head))
    return program

# Instruction lists for every compute core, flattened column-major:
# entry index = col * AieRows + row.
core_instrs_array = [
    get_core_instrs(col, row)
    for col in range(AieCols)
    for row in range(AieRows)
]


def access_linear_buffer(
    dma: AieDma,
    buffer_words: int,
    offset_words: int = 0,
) -> TransferParams:
    """Describe a ``buffer_words``-word access starting ``offset_words`` words in.

    No step/wrap arguments are passed, so the access pattern is whatever
    TransferParams uses by default (presumably linear -- confirm in dmacompiler).
    """
    params = TransferParams(dma, buffer_words, offset=offset_words)
    return params

# Memtile (L2) data transfers. As used here, each DataTransfer takes: repeat
# count(s), tile, buffer address(es) (two addresses = ping/pong), buffer size in
# bytes, the S2MM access list (into the memtile), the MM2S access list (out of it).
memtile_transfers = [
    # 1) Kernel params: every column receives MemtilePrmWords on S2MM 0 and sends one
    #    CorePrmWords slice per row on MM2S channel (L2_MM2S_UCAST_CHAN_BASE + row).
    DataTransfer( [1], Memtile(col), [MemtilePrmPingAddr], MemtilePrmSize,
        [TransferParams(memtile_dma(col, DmaDir.S2MM, 0), MemtilePrmWords)],
        [TransferParams(memtile_dma(col, DmaDir.MM2S, row+L2_MM2S_UCAST_CHAN_BASE), CorePrmWords, offset=(row * CorePrmWords)) for row in range(AieRows)]
    ) for col in range(AieCols)
] + [
    # 2) QDQ params: broadcast columns receive them on S2MM 0 and forward them whole
    #    on the broadcast MM2S channel.
    #    NOTE(review): builds AieDma(Memtile(col), DmaChannel(...)) directly where the
    #    other transfers use memtile_dma(...).
    DataTransfer( [1], Memtile(col), [MemtileQdqPingAddr], MemtileQdqPrmSize,
        [TransferParams(AieDma(Memtile(col), DmaChannel(DmaDir.S2MM, 0)), MemtileQdqPrmWords)],
        [TransferParams(AieDma(Memtile(col), DmaChannel(DmaDir.MM2S, L2_MM2S_BCAST_CHAN)), MemtileQdqPrmWords)]
    ) for col in SET_BCAST_COLUMNS
] + [
    # 3) Q tiles: MemtileQryRepeat ping/pong fills written on S2MM 0 through
    #    write_L2_rm_to_w8_subvolumes, then read out on the broadcast MM2S channel.
    DataTransfer( MemtileQryRepeat, Memtile(col), [MemtileQryPingAddr, MemtileQryPongAddr], CoreQrySize,
        [write_L2_rm_to_w8_subvolumes(      Sc,    Dih, QryBytes, memtile_dma(col, DmaDir.S2MM, 0), memtile_buffer_offset=0, id="Q", Num4x4=Num4x4)],
        [access_linear_buffer(memtile_dma(col, DmaDir.MM2S, L2_MM2S_BCAST_CHAN), CoreQryWords)] #[TransferParams(memtile_dma(col, DmaDir.MM2S, 0), CoreQryWords)]
    ) for col in SET_BCAST_COLUMNS
] + [
    # 4) K/V blocks: K and V both arrive on S2MM 1 (K at MemtileKeyOffset, V at
    #    MemtileValOffset, in words). Each row's MM2S channel sends that row's K
    #    slice followed by its V slice.
    DataTransfer( MemtileKeyValRepeat, Memtile(col), [MemtileKvbPingAddr, MemtileKvbPongAddr], MemtileKvbSize,
        [write_L2_rm_to_w8_subvolumes(AieRows*Sic, Dih, KeyBytes, memtile_dma(col, DmaDir.S2MM, 1), memtile_buffer_offset=MemtileKeyOffset, id="K"),
         write_L2_rm_to_w8_subvolumes(AieRows*Sic,  Dh, ValBytes, memtile_dma(col, DmaDir.S2MM, 1), memtile_buffer_offset=MemtileValOffset, id="V")],
        # MM2S: K slices for all rows, then V slices for all rows.
        [access_linear_buffer(memtile_dma(col, DmaDir.MM2S, row+L2_MM2S_UCAST_CHAN_BASE), CoreKeyWords, MemtileKeyOffset + (row * CoreKeyWords)) for row in range(AieRows)]
        +
        [access_linear_buffer(memtile_dma(col, DmaDir.MM2S, row+L2_MM2S_UCAST_CHAN_BASE), CoreValWords, MemtileValOffset + (row * CoreValWords)) for row in range(AieRows)],
        sync_strategy=SyncStrategy.Parallel_1_to_N
    ) for col in range(AieCols)
] + [
    # 5) O tiles (output columns only): written on S2MM 2 through
    #    read_L2_w8_to_rm_subvolumes, then read out on MM2S 5. Single buffer.
    DataTransfer(  MemtileOutRepeat, Memtile(col), [MemtileOutPingAddr], CoreOutSize,
        [read_L2_w8_to_rm_subvolumes(Sc, Dh, OutBytes, memtile_dma(col, DmaDir.S2MM, 2), memtile_buffer_offset=0, id="O")],
        [access_linear_buffer(memtile_dma(col, DmaDir.MM2S, 5), CoreOutWords)],
        sync_strategy=SyncStrategy.Parallel_N_to_1
    ) for col in SET_OUTPUT_COLUMNS
]


# Shim DMA channel assignments: MM2S 0 carries params, QDQ params and Q;
# MM2S 1 carries K and V; S2MM 0 receives O.
SHIM_MM2S_CHAN_ID_PRM      = 0
SHIM_MM2S_CHAN_ID_QDQ_PRM  = 0
SHIM_MM2S_CHAN_ID_Q_TENSOR = 0
SHIM_MM2S_CHAN_ID_K_TENSOR = 1
SHIM_MM2S_CHAN_ID_V_TENSOR = 1
SHIM_S2MM_CHAN_ID_O_TENSOR = 0
# Placeholder for the third size argument of the shim access helpers on the Q/K/V
# transfers (the O transfer passes Sc in that position).
DONTCARE = None
# Shim transfers. The third DataTransfer argument ([3], [2], [1]) is presumably a
# host buffer index -- confirm against dmacompiler.
shim_transfers = [
    # Kernel params: each column reads its own MemtilePrmWords slice from buffer [3].
    DataTransfer( [1], shim_tile(col), [3], MemtilePrmSize,
        [],
        [TransferParams(shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_PRM), MemtilePrmWords, offset=((col * MemtilePrmWords)))]
    ) for col in range(AieCols)
] + [
    # QDQ params: one copy per broadcast column from buffer [2].
    DataTransfer( [1], shim_tile(col), [2], ShimQdqPrmSize,
        [],
        [TransferParams(shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_QDQ_PRM), ShimQdqPrmWords)]
    ) for col in SET_BCAST_COLUMNS
] + [
    # Q: built by gen_Qry_shim_data_transfer (previous inline form kept below for reference).
    #DataTransfer( [H], shim_tile(col), [1], ShimQrySize,
    #    [],
    #    [access_shim_rm_vert_shard(  St            , Di, DONTCARE, Dih, QryBytes, H, shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_Q_TENSOR), bo_offset=0, id="Q")] #[TransferParams(shim_dma(col, DmaDir.MM2S, 0), ShimQryWords//H, '''ShimQryShardWords,''' step=[ShimQryStep0, ShimQryStep1], wrap=[ShimQryWrap0], iter_step=ShimQryIterStep)]
    #) for col in SET_BCAST_COLUMNS
    gen_Qry_shim_data_transfer(   St                , Di, DONTCARE, Dih, QryBytes, H, shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_Q_TENSOR), bo_offset_words=0, id="Q", aie_col=col, Num4x4=Num4x4
    )  for col in SET_BCAST_COLUMNS
] + [
    # K and V from buffer [1], H shards each. K starts at word ShimQryWords and V
    # after all NumAieCompCols K shards, which suggests Q, K, V are packed in that order
    # in one buffer (TODO confirm).
    DataTransfer( [H], shim_tile(col), [1], ShimKvbSize,
        [],
        [access_shim_rm_vert_shard(S, Di, DONTCARE, Dih, KeyBytes, H, shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_K_TENSOR), bo_offset=(ShimQryWords), id="K", aie_col=col), #[TransferParams(shim_dma(col, DmaDir.MM2S, 1), ShimKeyShardWords, offset=(ShimQryWords + (col * ShimKeyOffset)), step=[ShimKeyStep0, ShimKeyStep1], wrap=[ShimKeyWrap0], iter_step=ShimKeyIterStep),
         access_shim_rm_vert_shard(S,  D, DONTCARE,  Dh, ValBytes, H, shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_V_TENSOR), bo_offset=(ShimQryWords + (ShimKeyWords * NumAieCompCols)), id="V", aie_col=col)]# TransferParams(shim_dma(col, DmaDir.MM2S, 1), ShimValShardWords, offset=(ShimQryWords + (ShimKeyWords * AieCols) + (col * ShimValOffset)), step=ShimValBD['steps'], wrap=ShimValBD['wraps'], iter_step=ShimValBD['iter_step'])]  
        # Earlier variant with explicit per-column offsets, kept for reference:
        #[access_shim_rm_vert_shard(S, Di, DONTCARE, Dih, KeyBytes, H, shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_K_TENSOR), bo_offset=(ShimQryWords + ((col%NumAieCompCols) * ShimKeyOffset)), id="K", aie_col=col), #[TransferParams(shim_dma(col, DmaDir.MM2S, 1), ShimKeyShardWords, offset=(ShimQryWords + (col * ShimKeyOffset)), step=[ShimKeyStep0, ShimKeyStep1], wrap=[ShimKeyWrap0], iter_step=ShimKeyIterStep),
        # access_shim_rm_vert_shard(S,  D, DONTCARE,  Dh, ValBytes, H, shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_V_TENSOR), bo_offset=(ShimQryWords + (ShimKeyWords * NumAieCompCols) + ((col%NumAieCompCols) * ShimValOffset)), id="V", aie_col=col)]# TransferParams(shim_dma(col, DmaDir.MM2S, 1), ShimValShardWords, offset=(ShimQryWords + (ShimKeyWords * AieCols) + (col * ShimValOffset)), step=ShimValBD['steps'], wrap=ShimValBD['wraps'], iter_step=ShimValBD['iter_step'])]  
    ) for col in range(AieCols)
] + [
    # O: built by gen_Out_shim_data_transfer. Column 3 (the first output column)
    # writes at word offset 0; any other output column starts CoreOutWords in.
    #DataTransfer( [H], shim_tile(col), [0], ShimOutSize,
    #    [access_shim_rm_vert_shard(  St            ,  D,       Sc,  Dh, OutBytes, H, shim_dma(col, DmaDir.S2MM, SHIM_S2MM_CHAN_ID_O_TENSOR), bo_offset=(0 if(col==3) else CoreOutWords), id="O", Num4x4=Num4x4)], #[TransferParams(shim_dma(col, DmaDir.S2MM, 0), ShimOutShardWords, step=[ShimOutStep0, ShimOutStep1], wrap=[ShimOutWrap0], iter_step=ShimOutIterStep)],
    #    []
    #) for col in SET_OUTPUT_COLUMNS
    gen_Out_shim_data_transfer(   St                ,  D,       Sc,  Dh, OutBytes, H, shim_dma(col, DmaDir.S2MM, SHIM_S2MM_CHAN_ID_O_TENSOR), bo_offset_words=(0 if(col==3) else CoreOutWords), id="O", aie_col=col, Num4x4=Num4x4
    )  for col in SET_OUTPUT_COLUMNS
]

print("H:", H)
print("MemtileOutRepeat:", MemtileOutRepeat)
print("CoreOutWords:", CoreOutWords)
print("(St//Sc)//Num4x4* CoreOutWords:", (St//Sc)//Num4x4* CoreOutWords)


# Map each compute-core tile to its instruction list. core_instrs_array already
# holds these lists (flattened as col * AieRows + row, matching its col-outer /
# row-inner construction), so reuse them instead of calling get_core_instrs a
# second time for every core.
instr_dict = {}
for col in range(AieCols):
    for row in range(AieRows):
        instr_dict[AieTile(TileType.Core, col, row)] = core_instrs_array[col * AieRows + row]

# Run dmacompiler's overlay deadlock check on the full DMA program: per-core
# instructions, memtile and shim transfers, and the overlay's DMA connections.
run_overlay_deadlock_check(
    shape,
    instr_dict,
    memtile_transfers,
    shim_transfers,
    dma_connections,
    param_channel_id=CORE_UCAST_CHANNEL_ID
)

# Compile the layer for the backend chosen on the command line. The output is
# named 'run_dma_layer_config' and written to dma.hpp.
run_layer_compilation(
    shape,
    kernel_names=kernel_names,
    kernel_includes=kernel_includes,
    core_instrs=instr_dict,
    memtile_transfers=memtile_transfers,
    shim_transfers=shim_transfers,
    dma_connections=dma_connections,
    back_end=CodeBackEnd,
    core_stack_addr=CoreStackAddr,
    param_channel_id=CORE_UCAST_CHANNEL_ID,
    layer_name='run_dma_layer_config',
    layer_file='dma.hpp')

