import os

from dataclasses import dataclass
from typing import Tuple, List, ClassVar, Dict

from dataflow.mha.mha_general_dflow.src.dataflow_utils_shim import (
    access_shim_rm_vert_shard,
    gen_Out_shim_data_transfer,
    gen_Qry_shim_data_transfer,
    # access_shim_rm_hori_shard,
)
from dataflow.mha.mha_general_dflow.src.dataflow_utils_mem import (
    write_L2_rm_to_w8_subvolumes,
    read_L2_w8_to_rm_subvolumes,
)
from dataflow.mha.mha_general_dflow.src.mha_3p0_params import mha_3p0_qdq_params

from dataflow.dataflow_common import (
    overlay_8x4_dma_connections,
    overlay_8x4_core_stream_bdcast,
)
from dataflow.dataflow_common import (
    overlay_4x4_dma_connections,
    overlay_4x4_core_stream_bdcast,
    overlay_stack_addr,
)

from OGOAT.src.Scheduling_Engine.schedules.BufferAllocatorResult import (
    BufferAllocations,
)

from dmacompiler import (
    OverlayShape,
    DataTransfer,
    TransferParams,
    SyncStrategy,
    BackEnd,
    DmaChannel,
    DmaDir,
    AieDma,
    AieTile,
    TileType,
    DmaConnection,
    CascDir,
    ConfigBuffer,
    AcqBuffer,
    RelBuffer,
    CallKernel,
    Loop,
    shim_dma,
    memtile_dma,
    run_layer_compilation,
    set_dev_gen,
    DevGen,
    config,
    generate_shim_data_transfer,
)
# Module import side effects: select the Aie2p device generation in
# dmacompiler and turn on its global ENABLE_BUSY_POLL config flag.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True

# Absolute path of the directory that contains this source file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))


def align_to_64bytes(size):
    """Round ``size`` up to the next multiple of 64 (already-aligned values are unchanged)."""
    low_bits = 64 - 1  # mask of the bits that must be zero in an aligned value
    return (size + low_bits) & ~low_bits


def ceildiv(x: int, d: int) -> int:
    """Return ceil(x / d) computed with integer floor division."""
    quotient, remainder = divmod(x, d)
    # divmod floors; any nonzero remainder means the exact quotient was
    # strictly above the floor, so step up by one.
    return quotient + (1 if remainder else 0)


def iceil(x: int, d: int) -> int:
    """Round ``x`` up to the nearest multiple of ``d``."""
    steps = -(-x // d)  # ceil(x / d): floor-divide the negation, then negate back
    return d * steps


# H = 2
# backend_type = sys.argv[1] # Adf or Txn
# shapeId = int(sys.argv[2])
# if (backend_type == "Adf") :
#    CodeBackEnd = BackEnd.Adf
# else:
#    CodeBackEnd = BackEnd.TxnHostPatch
#    #H = 20 if(shapeId == 1) else 10


# V_DDR_Tranposed = False


@dataclass(frozen=True)
class Mha3p0Parameters:
    """Shape and tiling parameters for the MHA 3.0 dataflow.

    Instances are normally built with ``compute_internal_parameters`` from
    the layer-parameter dictionary. The per-head comments use GEMM naming
    (M, K, N, L). ``__post_init__`` validates subvolume granularity.
    """

    # Overlay geometry
    AieCols: int
    AieRows: int

    Num4x4: int  # number of 4-column groups in the overlay (AieCols // 4)
    NumAieCompCols: int

    # Element sizes in bytes
    QryBytes: int
    KeyBytes: int
    ValBytes: int
    OutBytes: int

    St: int  # M (Per Head)
    Di: int  # K (Per Head)
    S: int  # N (Per Head)
    D: int  # L (Per Head)

    Sc: int
    Sic: int
    Dh: int
    Dih: int
    H: int

    sfmx_mask: bool = False

    # Constant values

    QdqPrm: ClassVar[int] = 16
    QdqPrmBytes: ClassVar[int] = 4
    QdqNodes: ClassVar[int] = 6
    TdmBytes: ClassVar[int] = 4
    C0Bytes: ClassVar[int] = 8

    min_dim: ClassVar[int] = 64
    BytesPerWord: ClassVar[int] = 4

    Msubv_gran: ClassVar[int] = 16  # Constraint coming from softmax
    Ksubv_gran: ClassVar[int] = 8  # Constraint coming from act2act gemm and softmax kernel
    Nsubv_gran: ClassVar[int] = 16  # Constraint coming from act2act gemm
    Lsubv_gran: ClassVar[int] = 16  # Constraint coming from act2act gemm

    def __post_init__(self):
        """Check that subvolume sizes respect the kernel granularities.

        Raises:
            AssertionError: if a subvolume size is not a multiple of (and at
                least) its granularity, or if ``St`` is not a multiple of ``Sc``.
        """
        assert self.Sc >= self.Msubv_gran and self.Sc % self.Msubv_gran == 0, \
            f"Sc={self.Sc} must be a multiple of {self.Msubv_gran} and >= {self.Msubv_gran}"
        assert self.Dih >= self.Ksubv_gran and self.Dih % self.Ksubv_gran == 0, \
            f"Dih={self.Dih} must be a multiple of {self.Ksubv_gran} and >= {self.Ksubv_gran}"
        assert self.Sic >= self.Nsubv_gran and self.Sic % self.Nsubv_gran == 0, \
            f"Sic={self.Sic} must be a multiple of {self.Nsubv_gran} and >= {self.Nsubv_gran}"
        assert self.Dh >= self.Lsubv_gran and self.Dh % self.Lsubv_gran == 0, \
            f"Dh={self.Dh} must be a multiple of {self.Lsubv_gran} and >= {self.Lsubv_gran}"

        assert self.St % self.Sc == 0, f"St={self.St} must be a multiple of Sc={self.Sc}"

    def get_Words_from_Sizes(self, v_sizes: List[int]) -> Tuple[int, ...]:
        """Convert byte sizes into word counts (``BytesPerWord`` bytes per word).

        Args:
            v_sizes: Sizes in bytes.

        Returns:
            Tuple of word counts, in the same order as ``v_sizes``.

        Raises:
            AssertionError: if a size is not a whole number of words.
        """
        words = []
        for size in v_sizes:
            assert size % self.BytesPerWord == 0, \
                f"size {size} is not a multiple of {self.BytesPerWord} bytes"
            words.append(size // self.BytesPerWord)
        return tuple(words)

    @staticmethod
    def compute_internal_parameters(parameters: Dict) -> "Mha3p0Parameters":
        """Build the parameters from the layer description dictionary.

        Reads ``overlay_info.shape`` (col/row), ``layer_info`` (tensor shapes
        and element byte sizes) and ``core_tile_params.subvols.q[0]`` (Sc).
        Only 4-row overlays and a single head (``H == 1``) are handled.

        Raises:
            AssertionError: if the overlay is not 4 rows high, the output
                shape does not match (St, D), or the derived subvolumes fail
                the ``__post_init__`` checks.
        """
        overlay_shape = parameters['overlay_info']['shape']
        AieCols = overlay_shape['col']
        AieRows = overlay_shape['row']

        NumAieCompCols = 4
        assert AieRows == 4, f"only 4-row overlays are supported, got AieRows={AieRows}"
        Num4x4 = AieCols // 4

        layer = parameters['layer_info']
        QryBytes = layer["in_q_bytes"]
        KeyBytes = layer["in_k_bytes"]
        ValBytes = layer["in_v_bytes"]
        OutBytes = layer["out_ofm_bytes"]

        St = layer['in_q_shape'][0]  # M (Per Head)
        Di = layer['in_q_shape'][1]  # K (Per Head)
        S = layer['in_v_shape'][0]   # N (Per Head)
        D = layer['in_v_shape'][1]   # L (Per Head)

        # NOTE: the K shape is not cross-checked against Q/V here (those checks are disabled).
        out_shape = layer['out_ofm_shape']
        assert out_shape[0] == St, f"out_ofm_shape[0]={out_shape[0]} != St={St}"
        assert out_shape[1] == D, f"out_ofm_shape[1]={out_shape[1]} != D={D}"

        H = 1  # single head per invocation
        Sc = parameters['core_tile_params']['subvols']['q'][0]
        # S is split across every core of a 4x4 group (AieRows * NumAieCompCols).
        Sic = S // (AieRows * NumAieCompCols)
        Dh = D // H
        Dih = Di // H

        sfmx_mask = False  # softmax masking is currently always disabled

        return Mha3p0Parameters(
            AieCols=AieCols,
            AieRows=AieRows,
            Num4x4=Num4x4,
            NumAieCompCols=NumAieCompCols,
            QryBytes=QryBytes,
            KeyBytes=KeyBytes,
            ValBytes=ValBytes,
            OutBytes=OutBytes,
            St=St,
            Di=Di,
            S=S,
            D=D,
            Sc=Sc,
            Sic=Sic,
            H=H,
            Dh=Dh,
            Dih=Dih,
            sfmx_mask=sfmx_mask
        )


# FIXME: as a first step, this function is used in the product flow,
# but once the dataflow can work with the buffer allocator values,
# we should only use this function for standalone development, testing and debugging.
def generate_standalone_buffer_allocations(p: Mha3p0Parameters) -> BufferAllocations:
    """Lay out core and memtile buffers for ``p`` without the buffer allocator.

    Core buffers are packed back-to-back from address 0 in the order
    q, k, v, ofm, tdm (two copies), qdq, act1_sum, act2_sum, c0_k, c0_v.
    The scratch area starts right after them, and the stack address comes
    from ``overlay_stack_addr()``. Memtile buffers are packed from address 0
    as q ping, q pong, prm, k ping, v ping, k pong, v pong, ofm, qdq, so each
    k buffer is immediately followed by the matching v buffer.

    Raises:
        AssertionError: if the core scratch start address is not below 48 KiB.
    """

    def _pack(sizes):
        # Start offset of each region when the regions are placed contiguously
        # from offset 0, plus the offset just past the last region.
        starts = []
        cursor = 0
        for size in sizes:
            starts.append(cursor)
            cursor += size
        return starts, cursor

    res = BufferAllocations()

    # -- Core (per-tile) buffer sizes, in bytes --
    q_size = p.Num4x4 * (p.Sc * p.Dih * p.QryBytes)
    k_size = p.Dih * p.Sic * p.KeyBytes
    v_size = p.Sic * p.Dh * p.ValBytes
    ofm_size = p.Sc * p.Dh * p.OutBytes
    tdm_size = p.TdmBytes * p.Sc * max(p.Sic, p.Dh)
    prm_size = config.MAX_CORE_LAYER_PARAM_SIZE
    qdq_size = p.QdqNodes * p.QdqPrm * p.QdqPrmBytes
    act1_size = iceil(align_to_64bytes(p.TdmBytes * p.Sc), 256)
    act2_size = 1024  # NOTE: we need this for test to pass
    c0_k_size = p.Dih * p.C0Bytes
    c0_v_size = p.Dh * p.C0Bytes

    core_starts, scratch_addr = _pack([
        q_size, k_size, v_size, ofm_size,
        tdm_size, tdm_size,
        qdq_size, act1_size, act2_size,
        c0_k_size, c0_v_size,
    ])
    (q_addr, k_addr, v_addr, ofm_addr,
     tdm1_addr, tdm2_addr,
     qdq_addr, act1_addr, act2_addr,
     c0_k_addr, c0_v_addr) = core_starts
    stack_addr = overlay_stack_addr()

    res.add_core_allocation("q", q_size, ping_addresses=[q_addr])
    res.add_core_allocation("k", k_size, ping_addresses=[k_addr])
    res.add_core_allocation("v", v_size, ping_addresses=[v_addr])
    res.add_core_allocation("ofm", ofm_size, ping_addresses=[ofm_addr])
    res.add_core_allocation("tdm", tdm_size, ping_addresses=[tdm1_addr, tdm2_addr])
    res.add_core_allocation("qdq", qdq_size, ping_addresses=[qdq_addr])
    res.add_core_allocation("act1_sum", act1_size, ping_addresses=[act1_addr])
    res.add_core_allocation("act2_sum", act2_size, ping_addresses=[act2_addr])
    res.add_core_allocation("c0_k", c0_k_size, ping_addresses=[c0_k_addr])
    res.add_core_allocation("c0_v", c0_v_size, ping_addresses=[c0_v_addr])

    # Scratch and stack only need an address; their size is recorded as 0.
    res.add_core_allocation("scratch", 0, ping_addresses=[scratch_addr])
    res.add_core_allocation("stack", 0, ping_addresses=[stack_addr])

    assert scratch_addr < 48 * 1024

    # -- Memtile buffer sizes, in bytes --
    mem_q_size = q_size
    mem_ofm_size = ofm_size
    mem_k_size = p.AieRows * k_size
    mem_v_size = p.AieRows * v_size
    mem_prm_size = p.AieRows * prm_size
    mem_qdq_size = qdq_size

    mem_starts, _ = _pack([
        mem_q_size, mem_q_size,   # q ping, q pong
        mem_prm_size,             # prm
        mem_k_size, mem_v_size,   # k ping, v ping
        mem_k_size, mem_v_size,   # k pong, v pong
        mem_ofm_size,             # ofm
        mem_qdq_size,             # qdq
    ])
    (mq_ping, mq_pong, mprm,
     mk_ping, mv_ping, mk_pong, mv_pong,
     mofm, mqdq) = mem_starts

    res.add_mem_allocation("q", mem_q_size, ping_addresses=[mq_ping], pong_addresses=[mq_pong])
    res.add_mem_allocation("prm", mem_prm_size, ping_addresses=[mprm])
    res.add_mem_allocation("k", mem_k_size, ping_addresses=[mk_ping], pong_addresses=[mk_pong])
    res.add_mem_allocation("v", mem_v_size, ping_addresses=[mv_ping], pong_addresses=[mv_pong])
    res.add_mem_allocation("ofm", mem_ofm_size, ping_addresses=[mofm])
    res.add_mem_allocation("qdq", mem_qdq_size, ping_addresses=[mqdq])

    return res


def generate_dataflow_mha_3p0(
    parameters: Mha3p0Parameters,
    buffer_alloc: BufferAllocations,
    backend_type: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str]
) -> Tuple[int, int, int, int, int, int, int, int]:
    """Generate and compile the DMA/core dataflow for the MHA 3.0 layer.

    Builds the per-core instruction lists and the memtile and shim data
    transfers for Q, K, V, the output and the (QDQ) layer parameters. It then
    hands everything to ``run_layer_compilation``.

    Args:
        parameters: Shape/tiling parameters (see ``Mha3p0Parameters``).
        buffer_alloc: Core allocations ("q", "k", "v", "ofm", "tdm", "qdq",
            "act1_sum", "act2_sum", "c0_k", "c0_v", "scratch", "stack") and
            memtile allocations ("q", "prm", "k", "v", "ofm", "qdq").
        backend_type: dmacompiler back end passed to ``run_layer_compilation``.
        kernel_names: Kernel names passed to ``run_layer_compilation``.
        kernel_includes: Kernel headers passed to ``run_layer_compilation``.

    Returns:
        ``(St, Di, S, D, Sc, Dih, Sic, Dh)`` taken from ``parameters``.

    Raises:
        AssertionError: if shapes and buffer sizes are inconsistent. Examples
            are non-divisible shards, repeat counts that differ from ``H``, or
            a core scratch address at or above 48 KiB.
    """

    p = parameters

    # ---- Buffer sizes (bytes) ----
    CoreQrySize = buffer_alloc.get_core_alloc("q").ping.size
    CoreKeySize = buffer_alloc.get_core_alloc("k").ping.size
    CoreValSize = buffer_alloc.get_core_alloc("v").ping.size
    CoreOutSize = buffer_alloc.get_core_alloc("ofm").ping.size

    CorePrmSize = config.MAX_CORE_LAYER_PARAM_SIZE
    # K and V are received into one core buffer that starts at the K address.
    CoreKeyValSize = CoreKeySize + CoreValSize
    CoreQdqPrmSize = buffer_alloc.get_core_alloc("qdq").ping.size

    MemtileKeySize = buffer_alloc.get_mem_alloc("k").ping.size
    MemtileValSize = buffer_alloc.get_mem_alloc("v").ping.size
    MemtileKvbSize = MemtileKeySize + MemtileValSize
    MemtilePrmSize = buffer_alloc.get_mem_alloc("prm").ping.size
    MemtileQdqPrmSize = buffer_alloc.get_mem_alloc("qdq").ping.size

    ShimQrySize = p.St * p.Di * p.QryBytes
    ShimKeySize = (p.Di * p.S * p.KeyBytes) // p.NumAieCompCols
    ShimValSize = (p.S * p.D * p.ValBytes) // p.NumAieCompCols
    ShimOutSize = p.St * p.D * p.OutBytes
    ShimKvbSize = ShimKeySize + ShimValSize
    ShimQdqPrmSize = MemtileQdqPrmSize

    # ---- Buffer addresses ----
    CoreQryPingAddr = buffer_alloc.get_core_alloc("q").ping.addresses[0]
    CoreKeyPingAddr = buffer_alloc.get_core_alloc("k").ping.addresses[0]
    CoreOutAddr = buffer_alloc.get_core_alloc("ofm").ping.addresses[0]
    CoreTdm1Addr = buffer_alloc.get_core_alloc("tdm").ping.addresses[0]
    CoreTdm2Addr = buffer_alloc.get_core_alloc("tdm").ping.addresses[1]
    CoreQdqPingAddr = buffer_alloc.get_core_alloc("qdq").ping.addresses[0]
    CoreAct1SumAddr = buffer_alloc.get_core_alloc("act1_sum").ping.addresses[0]
    CoreAct2SumAddr = buffer_alloc.get_core_alloc("act2_sum").ping.addresses[0]
    CoreC0_K_Addr = buffer_alloc.get_core_alloc("c0_k").ping.addresses[0]
    CoreC0_V_Addr = buffer_alloc.get_core_alloc("c0_v").ping.addresses[0]
    CoreScratchAddr = buffer_alloc.get_core_alloc("scratch").ping.addresses[0]
    CoreStackAddr = buffer_alloc.get_core_alloc("stack").ping.addresses[0]

    assert CoreScratchAddr < 48 * 1024

    MemtileQryPingAddr  = buffer_alloc.get_mem_alloc("q").ping.addresses[0]
    MemtileQryPongAddr  = buffer_alloc.get_mem_alloc("q").pong.addresses[0]
    MemtilePrmPingAddr  = buffer_alloc.get_mem_alloc("prm").ping.addresses[0]
    # The memtile K+V buffer ("Kvb") starts at the K allocation.
    MemtileKvbPingAddr  = buffer_alloc.get_mem_alloc("k").ping.addresses[0]
    MemtileKvbPongAddr  = buffer_alloc.get_mem_alloc("k").pong.addresses[0]
    MemtileOutPingAddr  = buffer_alloc.get_mem_alloc("ofm").ping.addresses[0]
    MemtileQdqPingAddr  = buffer_alloc.get_mem_alloc("qdq").ping.addresses[0]

    # ---- Sizes in words (get_Words_from_Sizes also asserts word alignment) ----
    CoreSizes = [CoreQrySize, CoreKeySize, CoreValSize, CoreOutSize, CorePrmSize, CoreQdqPrmSize]
    (CoreQryWords, CoreKeyWords, CoreValWords, CoreOutWords, CorePrmWords, CoreQdqPrmWords) = \
        p.get_Words_from_Sizes(CoreSizes)

    MemtileSizes = [MemtileKeySize, MemtileValSize, MemtileKvbSize, MemtilePrmSize, MemtileQdqPrmSize]
    (MemtileKeyWords, MemtileValWords, MemtileKvbWords, MemtilePrmWords, MemtileQdqPrmWords) = \
        p.get_Words_from_Sizes(MemtileSizes)

    ShimSizes = [ShimQrySize, ShimKeySize, ShimValSize, ShimKvbSize, ShimOutSize, ShimQdqPrmSize]
    (ShimQryWords, ShimKeyWords, ShimValWords, ShimKvbWords, ShimOutWords, ShimQdqPrmWords) = \
        p.get_Words_from_Sizes(ShimSizes)

    def Memtile(col: int):
        return AieTile(TileType.Memtile, col, 0)

    def shim_tile(col: int):
        return AieTile(TileType.Shim, col, 0)

    # Word offsets of the K and V parts inside the memtile Kvb buffer.
    MemtileKeyOffset = 0
    MemtileValOffset = CoreKeyWords * p.AieRows

    # ---- Shard sizes and repeat-count consistency checks ----
    ShimQryShardWords = (p.St * p.Dih * p.QryBytes) // 4
    ShimQryRepeatNum = ShimQryWords // ShimQryShardWords
    assert ShimQryWords % ShimQryShardWords == 0

    ShimKeyShardWords = CoreKeyWords * p.AieRows
    assert ShimKeyShardWords * p.NumAieCompCols == (p.Di * p.S * p.KeyBytes) // p.H // 4
    ShimKeyRepeatNum = ShimKeyWords // ShimKeyShardWords
    assert ShimKeyWords % ShimKeyShardWords == 0

    ShimValShardWords = CoreValWords * p.AieRows
    assert ShimValShardWords * p.NumAieCompCols == (p.S * p.D * p.ValBytes) // p.H // 4
    ShimValRepeatNum = ShimValWords // ShimValShardWords
    assert ShimValWords % ShimValShardWords == 0

    ShimOutShardWords = (p.St * p.Dh * p.OutBytes) // 4
    ShimOutRepeatNum = ShimOutWords // ShimOutShardWords
    assert ShimOutWords % ShimOutShardWords == 0

    Tq = ShimQryWords // CoreQryWords
    Tkv = ShimKvbWords // MemtileKvbWords
    To = ShimOutWords // CoreOutWords // p.Num4x4
    Li = p.St // p.Sc // p.Num4x4  # inner (Q/O) loop trip count per core

    assert ShimQryWords % CoreQryWords == 0
    assert ShimKvbWords % MemtileKvbWords == 0
    assert ShimOutWords % CoreOutWords == 0
    assert Tq == To

    MemtileQryRepeat = [Tq]
    MemtileKeyValRepeat = [Tkv]
    MemtileOutRepeat = [To]

    assert ShimKeyRepeatNum == ShimValRepeatNum
    assert ShimKeyRepeatNum == p.H
    assert ShimQryRepeatNum == p.H
    assert ShimOutRepeatNum == p.H
    assert MemtileQryRepeat == [p.H * (p.St // p.Sc) // p.Num4x4]
    assert MemtileKeyValRepeat == [p.H]
    assert MemtileOutRepeat == [p.H * (p.St // p.Sc) // p.Num4x4]

    print("MemtileQryRepeat=", MemtileQryRepeat)
    print("CoreOutWords=", CoreOutWords)
    print("ShimQryRepeatNum=", ShimQryRepeatNum)
    print("ShimQryWords=", ShimQryWords)

    shape = OverlayShape(p.AieCols, p.AieRows)

    # Channel assignments are identical for the single- and dual-4x4 overlays.
    L2_MM2S_BCAST_CHAN = 4
    L2_MM2S_UCAST_CHAN_BASE = 0
    CORE_UCAST_CHANNEL_ID = 0
    CORE_BCAST_CHANNEL_ID = 1
    SET_BCAST_COLUMNS = list(range(0, p.AieCols, p.Num4x4))
    SET_OUTPUT_COLUMNS = list(range(3, p.AieCols, p.NumAieCompCols))
    dma_connections = (
        overlay_4x4_dma_connections() if p.Num4x4 == 1 else overlay_8x4_dma_connections()
    )
    ucast_to_core_s2mm1 = 1 if CORE_UCAST_CHANNEL_ID == 1 else 0

    def gen_mha_params(ucast_to_s2mm1, multi_core, aie_col_id, aie_row_id):
        """Kernel parameter blob for the core at (aie_col_id, aie_row_id)."""
        return mha_3p0_qdq_params(
            (p.Sc, p.Dh),
            (p.Dh, p.Sic),
            ucast_to_s2mm1,
            multi_core,  # mha_mode, multi_core,
            aie_col_id,
            aie_row_id,
            p.Dh,
            p.S,  # Sic, S,
            CoreTdm1Addr,
            CoreTdm2Addr,
            CoreQdqPingAddr,
            CoreAct1SumAddr,
            CoreAct2SumAddr,
            CoreC0_K_Addr,
            CoreC0_V_Addr,
            CoreScratchAddr,
        )

    def Conditional_AcqBuffer(dma_channel: DmaChannel, aie_col_id: int, aie_row_id: int):
        # Only row 0 of an output column acquires/releases the output buffer.
        return [AcqBuffer(dma_channel)] if aie_col_id in SET_OUTPUT_COLUMNS and aie_row_id == 0 else []

    def Conditional_RelBuffer(dma_channel: DmaChannel, aie_col_id: int, aie_row_id: int):
        return [RelBuffer(dma_channel)] if (aie_col_id in SET_OUTPUT_COLUMNS and aie_row_id == 0) else []

    def get_core_instrs(core_col_id: int, core_row_id: int):
        """Instruction list for the core at (core_col_id, core_row_id)."""
        mha_params = gen_mha_params(ucast_to_core_s2mm1, 1, core_col_id, core_row_id)
        return [
            # QDQ parameters arrive once on the broadcast channel.
            ConfigBuffer(DmaChannel(DmaDir.S2MM, CORE_BCAST_CHANNEL_ID), CoreQdqPingAddr, None, CoreQdqPrmSize),
            AcqBuffer(   DmaChannel(DmaDir.S2MM, CORE_BCAST_CHANNEL_ID)),
            RelBuffer(   DmaChannel(DmaDir.S2MM, CORE_BCAST_CHANNEL_ID)),

            ConfigBuffer(DmaChannel(DmaDir.S2MM, CORE_BCAST_CHANNEL_ID), CoreQryPingAddr, None, CoreQrySize   ),
            ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOutAddr,     None, CoreOutSize   ),

            Loop(p.H, [
                ConfigBuffer(DmaChannel(DmaDir.S2MM, CORE_UCAST_CHANNEL_ID), CoreKeyPingAddr, None, CoreKeyValSize),
                AcqBuffer(   DmaChannel(DmaDir.S2MM, CORE_UCAST_CHANNEL_ID)),                           # acq KV
                CallKernel('run_act_K_preprocess', mha_params),
                CallKernel('run_act_V_preprocess', mha_params),

                Loop(Li, Conditional_AcqBuffer(DmaChannel(DmaDir.MM2S, 0), core_col_id, core_row_id) +   # acq O
                    [
                        AcqBuffer(DmaChannel(DmaDir.S2MM, CORE_BCAST_CHANNEL_ID)),                      # acq Q
                        CallKernel('run_qkt_gemm_qdq',    mha_params),
                        CallKernel('run_sfmx_i16_to_i16', mha_params),
                        CallKernel('run_smxv_gemm_qdq',   mha_params),

                        RelBuffer(DmaChannel(DmaDir.S2MM, CORE_BCAST_CHANNEL_ID))                       # rel Q
                    ] + Conditional_RelBuffer(DmaChannel(DmaDir.MM2S, 0), core_col_id, core_row_id)      # rel O
                ),
                RelBuffer(DmaChannel(DmaDir.S2MM, CORE_UCAST_CHANNEL_ID)),                              # rel KV
            ])
        ]

    def access_linear_buffer(
        dma: AieDma,
        buffer_words: int,
        offset_words: int = 0,
    ) -> TransferParams:
        """Contiguous transfer of ``buffer_words`` words starting at ``offset_words``."""
        return TransferParams(dma, buffer_words, offset=offset_words)

    memtile_transfers = [
        # Layer parameters: split across the per-row unicast MM2S channels.
        DataTransfer( [1], Memtile(col), [MemtilePrmPingAddr], MemtilePrmSize,
            [TransferParams(memtile_dma(col, DmaDir.S2MM, 0), MemtilePrmWords)],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, row+L2_MM2S_UCAST_CHAN_BASE), CorePrmWords, offset=(row * CorePrmWords)) for row in range(p.AieRows)]
        ) for col in range(p.AieCols)
    ] + [
        # QDQ parameters: sent out on the broadcast MM2S channel.
        DataTransfer( [1], Memtile(col), [MemtileQdqPingAddr], MemtileQdqPrmSize,
            [TransferParams(AieDma(Memtile(col), DmaChannel(DmaDir.S2MM, 0)), MemtileQdqPrmWords)],
            [TransferParams(AieDma(Memtile(col), DmaChannel(DmaDir.MM2S, L2_MM2S_BCAST_CHAN)), MemtileQdqPrmWords)]
        ) for col in SET_BCAST_COLUMNS
    ] + [
        # Q: ping/pong buffer, sent out on the broadcast MM2S channel.
        DataTransfer( MemtileQryRepeat, Memtile(col), [MemtileQryPingAddr, MemtileQryPongAddr], CoreQrySize,
            [write_L2_rm_to_w8_subvolumes(      p.Sc,    p.Dih, p.QryBytes, memtile_dma(col, DmaDir.S2MM, 0), memtile_buffer_offset=0, id="Q", Num4x4=p.Num4x4)],
            [access_linear_buffer(memtile_dma(col, DmaDir.MM2S, L2_MM2S_BCAST_CHAN), CoreQryWords)]
        ) for col in SET_BCAST_COLUMNS
    ] + [
        # K and V share S2MM channel 1 into one Kvb buffer; each row gets its own K and V slice.
        DataTransfer( MemtileKeyValRepeat, Memtile(col), [MemtileKvbPingAddr, MemtileKvbPongAddr], MemtileKvbSize,
            [write_L2_rm_to_w8_subvolumes(p.AieRows * p.Sic, p.Dih, p.KeyBytes, memtile_dma(col, DmaDir.S2MM, 1), memtile_buffer_offset=MemtileKeyOffset, id="K"),
             write_L2_rm_to_w8_subvolumes(p.AieRows * p.Sic, p.Dh,  p.ValBytes, memtile_dma(col, DmaDir.S2MM, 1), memtile_buffer_offset=MemtileValOffset, id="V")],

            [access_linear_buffer(memtile_dma(col, DmaDir.MM2S, row+L2_MM2S_UCAST_CHAN_BASE), CoreKeyWords, MemtileKeyOffset + (row * CoreKeyWords)) for row in range(p.AieRows)]
            +
            [access_linear_buffer(memtile_dma(col, DmaDir.MM2S, row+L2_MM2S_UCAST_CHAN_BASE), CoreValWords, MemtileValOffset + (row * CoreValWords)) for row in range(p.AieRows)],
            sync_strategy=SyncStrategy.Parallel_1_to_N
        ) for col in range(p.AieCols)
    ] + [
        # Output: only on the output columns.
        DataTransfer(  MemtileOutRepeat, Memtile(col), [MemtileOutPingAddr], CoreOutSize,
            [read_L2_w8_to_rm_subvolumes(p.Sc, p.Dh, p.OutBytes, memtile_dma(col, DmaDir.S2MM, 2), memtile_buffer_offset=0, id="O")],
            [access_linear_buffer(memtile_dma(col, DmaDir.MM2S, 5), CoreOutWords)],
            sync_strategy=SyncStrategy.Parallel_N_to_1
        ) for col in SET_OUTPUT_COLUMNS
    ]

    SHIM_MM2S_CHAN_ID_PRM      = 0
    SHIM_MM2S_CHAN_ID_QDQ_PRM  = 0
    SHIM_MM2S_CHAN_ID_Q_TENSOR = 0
    SHIM_MM2S_CHAN_ID_K_TENSOR = 1
    SHIM_MM2S_CHAN_ID_V_TENSOR = 1
    SHIM_S2MM_CHAN_ID_O_TENSOR = 0
    DONTCARE = None
    shim_transfers = [
        DataTransfer( [1], shim_tile(col), [3], MemtilePrmSize,
            [],
            [TransferParams(shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_PRM), MemtilePrmWords, offset=((col * MemtilePrmWords)))]
        ) for col in range(p.AieCols)
    ] + [
        DataTransfer( [1], shim_tile(col), [2], ShimQdqPrmSize,
            [],
            [TransferParams(shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_QDQ_PRM), ShimQdqPrmWords)]
        ) for col in SET_BCAST_COLUMNS
    ] + [
        gen_Qry_shim_data_transfer(p.St, p.Di, DONTCARE, p.Dih, p.QryBytes, p.H, shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_Q_TENSOR), bo_offset_words=0, id="Q", aie_col=col, Num4x4=p.Num4x4)
        for col in SET_BCAST_COLUMNS
    ] + [
        # K and V live after Q in the same buffer: V starts after all columns' K shards.
        DataTransfer( [p.H], shim_tile(col), [1], ShimKvbSize,
            [],
            [access_shim_rm_vert_shard(p.S, p.Di, DONTCARE, p.Dih, p.KeyBytes, p.H, shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_K_TENSOR), bo_offset=(ShimQryWords), id="K", aie_col=col%4),
             access_shim_rm_vert_shard(p.S, p.D, DONTCARE, p.Dh, p.ValBytes, p.H, shim_dma(col, DmaDir.MM2S, SHIM_MM2S_CHAN_ID_V_TENSOR), bo_offset=(ShimQryWords + (ShimKeyWords * p.NumAieCompCols)), id="V", aie_col=col%4)]
        ) for col in range(p.AieCols)
    ] + [
        gen_Out_shim_data_transfer(p.St, p.D, p.Sc, p.Dh, p.OutBytes, p.H, shim_dma(col, DmaDir.S2MM, SHIM_S2MM_CHAN_ID_O_TENSOR), bo_offset_words=(0 if col == 3 else CoreOutWords), id="O", aie_col=col, Num4x4=p.Num4x4)
        for col in SET_OUTPUT_COLUMNS
    ]

    print("H:", p.H)
    print("MemtileOutRepeat:", MemtileOutRepeat)
    print("CoreOutWords:", CoreOutWords)
    print("(St//Sc)//Num4x4* CoreOutWords:", (p.St // p.Sc) // p.Num4x4 * CoreOutWords)

    # One instruction list per core tile (column-major, like the overlay loops).
    instr_dict = {
        AieTile(TileType.Core, col, row): get_core_instrs(col, row)
        for col in range(p.AieCols)
        for row in range(p.AieRows)
    }

    run_layer_compilation(
        shape,
        kernel_names=kernel_names,
        kernel_includes=kernel_includes,
        core_instrs=instr_dict,
        memtile_transfers=memtile_transfers,
        shim_transfers=shim_transfers,
        dma_connections=dma_connections,
        back_end=backend_type,
        core_stack_addr=CoreStackAddr,
        param_channel_id=CORE_UCAST_CHANNEL_ID,
        layer_name='run_dma_layer_config',
        layer_file='dma.hpp',
        casc_dir=CascDir.Vertical,
        core_connections=overlay_8x4_core_stream_bdcast() if p.Num4x4 == 2 else overlay_4x4_core_stream_bdcast())

    return (p.St, p.Di, p.S, p.D, p.Sc, p.Dih, p.Sic, p.Dh)


# super.cc/super.hh/dma.hpp/graph.hpp



## Memtile transfer MM2S:
#[TransferParams(AieDma(Memtile(col), DmaChannel(DmaDir.MM2S, row+1)), (MemtileKeyPadWords // 4),
#                offset=MemtileKeyOffset + (row * (MemtileKeyWords // 4))) for row in range(AieRows)
#                #step=[MemtileKeyOutStep0, MemtileKeyOutStep1, MemtileKeyOutStep2],
#                #wrap=[MemtileKeyOutWrap0, MemtileKeyOutWrap1, MemtileKeyOutWrap2],
#                #padding=[(0, 0), (MemtileKeyPad1before, MemtileKeyPad1after)]) for row in range(AieRows)
#] + 
#[TransferParams(AieDma(Memtile(col), DmaChannel(DmaDir.MM2S, row+1)), (MemtileValPadWords // 4),
#                offset=MemtileValOffset + (row * (MemtileValWords // 4))) for row in range(AieRows)],
#                #step=[MemtileValOutStep0, MemtileValOutStep1, MemtileValOutStep2],
#                #wrap=[MemtileValOutWrap0, MemtileValOutWrap1, MemtileValOutWrap2],
#                #padding=[(0, 0), (0, 0), (MemtileValPad2before, MemtileValPad2after)]) for row in range(AieRows)],


## Backup the memtile transfer
#[
    #DataTransfer(
    #    MemtileQryRepeat, Memtile(col), [MemtileQryPingAddr, MemtileQryPongAddr], CoreQrySize,
        #[TransferParams(memtile_dma(col, DmaDir.S2MM, 0), CoreQryWords,
        #                step=[MemtileQryStep0, MemtileQryStep1, MemtileQryStep2],
        #                wrap=[MemtileQryWrap0, MemtileQryWrap1])],
        #[conv_rm_to_w8_subvolumes(      Sc,    Dih, QryBytes, H, memtile_dma(col, DmaDir.S2MM, 0), memtile_buffer_offset=0)],
        #[TransferParams(memtile_dma(col, DmaDir.MM2S, 0), CoreQryWords)]
    #) for col in range(AieCols)
#] + [
    #DataTransfer(
        #MemtileKeyValRepeat, Memtile(col), [MemtileKvbPingAddr, MemtileKvbPongAddr], MemtileKvbSize,
        #[TransferParams(memtile_dma(col, DmaDir.S2MM, 1), MemtileKeyWords,
        #                offset=MemtileKeyOffset,
        #                step=[MemtileKeyStep0, MemtileKeyStep1, MemtileKeyStep2, MemtileKeyStep3],
        #                wrap=[MemtileKeyWrap0, MemtileKeyWrap1, MemtileKeyWrap2]),
        # TransferParams(memtile_dma(col, DmaDir.S2MM, 1), MemtileValWords,
        #                offset=MemtileValOffset,
        #                step=[MemtileValStep0, MemtileValStep1, MemtileValStep2, MemtileValStep3],
        #                wrap=[MemtileValWrap0, MemtileValWrap1, MemtileValWrap2])],
        #[TransferParams(memtile_dma(col, DmaDir.MM2S, row+1), CoreKeyWords, #memtile_dma(col, DmaDir.MM2S, row+1), CoreKeyWords, #
        #                offset=MemtileKeyOffset + (row * CoreKeyWords)) for row in range(AieRows) ] +
        #[TransferParams(memtile_dma(col, DmaDir.MM2S, row+1), CoreValWords, #memtile_dma(col, DmaDir.MM2S, row+1), CoreValWords, #
        #                offset=MemtileValOffset + (row * CoreValWords)) for row in range(AieRows)],
        #sync_strategy=SyncStrategy.Parallel_1_to_N
    #) for col in range(AieCols)

'''
(MemtileQryStep0, MemtileQryStep1, MemtileQryStep2, _,
 MemtileQryWrap0, MemtileQryWrap1, _,
 ) = conv_rm_to_w8_subvolumes(Sc, Dih, QryBytes, None, H, 0, 1)


(MemtileKeyStep0, MemtileKeyStep1, MemtileKeyStep2, MemtileKeyStep3,
 MemtileKeyWrap0, MemtileKeyWrap1, MemtileKeyWrap2,
 ) = conv_rm_to_w8_subvolumes(AieRows*Sic, Dih, KeyBytes, None, H, 0, 1) #, subv_step=(MemtileKvbWords // AieRows))

if(V_DDR_Tranposed):
    MemtileValOffset = CoreKeyWords * AieRows
    (MemtileValStep0, MemtileValStep1, MemtileValStep2, MemtileValStep3,
    MemtileValWrap0, MemtileValWrap1, MemtileValWrap2,
    ) = conv_rm_to_w8_subvolumes(Dh, Sic, ValBytes, None, H, 0, 1) #, subv_step=(MemtileKvbWords // AieRows))
else:
    MemtileValOffset = CoreKeyWords * AieRows  # or * 1
    (MemtileValStep0, MemtileValStep1, MemtileValStep2, MemtileValStep3,
    MemtileValWrap0, MemtileValWrap1, MemtileValWrap2,
    ) = conv_rm_to_w8_subvolumes(AieRows*Sic, Dh, ValBytes, None, H, 0, 1) #, subv_step=(MemtileKvbWords // AieRows))
'''


#(ShimQryStep0, ShimQryStep1, ShimQryWrap0, ShimQryIterStep) = access_shim_rm_vert_shard(St, Di, Dih, QryBytes) #access_shim_rm_vert_shard(S, Di, Sc, Dih, QryBytes)
#(ShimKeyStep0, ShimKeyStep1, ShimKeyWrap0, ShimKeyIterStep) = access_shim_rm_vert_shard(S, Di, Dih, KeyBytes)
#(ShimValStep0, ShimValStep1, ShimValWrap0, ShimValIterStep) = access_shim_rm_vert_shard(S, D, Sic, Dh, ValBytes)
#(ShimValStep0, ShimValStep1, ShimValStep2, ShimValWrap0, ShimValWrap1, ShimValIterStep) = access_shim_rm_hori_shard((D, S, Dh, Sic, ValBytes)

#ShimValBD = {}
#if(V_DDR_Tranposed):    
    #(ShimValStep0, ShimValStep1, ShimValStep2, ShimValWrap0, ShimValWrap1, ShimValIterStep) = \
    #    access_shim_rm_hori_shard(D, S, Dh, Sic, ValBytes)
    #ShimValBD = {'steps':[ShimValStep0, ShimValStep1, ShimValStep2], 'wraps':[ShimValWrap0, ShimValWrap1], 'iter_step':ShimValIterStep}
#else:
    #(ShimValStep0, ShimValStep1, ShimValWrap0, ShimValIterStep) = \
    #    access_shim_rm_vert_shard(S, D, Dh, ValBytes)
    #ShimValBD = {'steps':[ShimValStep0, ShimValStep1], 'wraps':[ShimValWrap0], 'iter_step':ShimValIterStep}
#(ShimOutStep0, ShimOutStep1, ShimOutWrap0, ShimOutIterStep) = access_shim_rm_vert_shard(S, D, Dh, OutBytes)#access_shim_rm_vert_shard(S, D, Sc, Dh, OutBytes)


#(MemtileKeyOutStep0,MemtileKeyOutStep1, MemtileKeyOutStep2, MemtileKeyOutWrap0, MemtileKeyOutWrap1,
# MemtileKeyOutWrap2, MemtileKeyPad1before, MemtileKeyPad1after) = zero_pad_rows_w8_subvolume(Sic, Dh, min_dim-Sic, KeyBytes)

#(MemtileOutStep0, MemtileOutStep1, MemtileOutStep2, _,
# MemtileOutWrap0, MemtileOutWrap1, _,
# ) = conv_w8_to_rm_subvolumes(Sc, Dh, OutBytes, None, 0, id="O", refmode=1)

#[TransferParams(memtile_dma(col, DmaDir.S2MM, 2), CoreOutWords, step=[MemtileOutStep0, MemtileOutStep1, MemtileOutStep2], wrap=[MemtileOutWrap0, MemtileOutWrap1])],
#[TransferParams(memtile_dma(col, DmaDir.MM2S, 5), CoreOutWords)],    

'''
                CallKernel('run_act_K_preprocess', mha_qkt_sfmx_qdq_params(
                                                           (Sc, Dh), 
                                                           (Dh, Sic), 
                                                           2, 0, 
                                                           core_col_id, core_row_id, 
                                                           S, 
                                                           CoreTdm1Addr, CoreTdm2Addr,
                                                           CoreQdqPingAddr, 
                                                           CoreAct1SumAddr, CoreAct2SumAddr, 
                                                           CoreC0Addr, CoreScratchAddr)
                                                          ),   
'''
'''
CallKernel('run_act_K_preprocess', mha_qkt_sfmx_qdq_params(
                                            (Sc, Dh), 
                                            (Dh, Sic), 
                                            2, 0, 
                                            core_col_id, core_row_id, 
                                            S, 
                                            CoreTdm1Addr, CoreTdm2Addr,
                                            CoreQdqPingAddr, 
                                            CoreAct1SumAddr, CoreAct2SumAddr, 
                                            CoreC0Addr, CoreScratchAddr)
                                            ),
'''
# DISABLED CODE: everything up to the closing ''' is a bare string literal.
# The function definitions inside it (get_core_instrs, gemm_params,
# mha_params, mknl_params, tobfp16_params) are NOT defined by this part of
# the file.
#
# Summary of the disabled code:
#   get_core_instrs(core_col_id, core_row_id)
#     Builds a per-core instruction list.
#     - The core at (AieCols - 1, row 0) additionally configures MM2S
#       channel 0 with CoreOutAddr/CoreOutSize.
#     - That same core wraps each inner Li-loop iteration with
#       Conditional_AcqBuffer / Conditional_RelBuffer on MM2S channel 0.
#     - All other cores only consume S2MM channels 0 (Q) and 1 (KV).
#   gemm_params(zero_acc)
#     1 byte; asserts zero_acc is 0 or 1.
#   mha_params(...)
#     14 little-endian uint16 fields (28 bytes) with hard-coded byte offsets:
#     qry 0, key 4 KiB, val 12 KiB, out 20 KiB, tdm1 24 KiB, tdm2 32 KiB,
#     qkt 40 KiB, sfm == qkt.
#   mknl_params(...)
#     4 uint8 fields. int.to_bytes raises OverflowError for any value > 255.
#   tobfp16_params(...)
#     Ignores all four of its arguments and always returns key_addr and
#     val_addr as two uint16 values (4 bytes).
#
# NOTE(review): get_core_instrs references Conditional_AcqBuffer,
# Conditional_RelBuffer, gen_mha_params, H and Li. None of these is defined in
# the visible part of this file. Verify they exist before re-enabling.
'''
def get_core_instrs(core_col_id:int, core_row_id:int):
    if core_col_id == AieCols - 1 and core_row_id == 0:
        return  [
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreQdqPingAddr, None, CoreQdqPrmSize),
            AcqBuffer(   DmaChannel(DmaDir.S2MM, 0)),
            RelBuffer(   DmaChannel(DmaDir.S2MM, 0)),
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreQryPingAddr, None, CoreQrySize),
            ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOutAddr, None, CoreOutSize),

            Loop(H, [
                ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreKeyPingAddr, None, CoreKeyValSize),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),       # acq KV
                CallKernel('run_act_K_preprocess', gen_mha_params( 2, 0, core_col_id, core_row_id)),
                CallKernel('run_act_V_preprocess', gen_mha_params( 2, 0, core_col_id, core_row_id)),

                Loop(Li, Conditional_AcqBuffer(DmaChannel(DmaDir.MM2S, 0)) + 
                    [   
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),   # acq Q
                    
                    CallKernel('run_qkt_gemm_qdq', gen_mha_params( 0, 1, core_col_id, core_row_id)),
                    CallKernel('run_softmax_int16_to_int16', gen_mha_params( 0, 1, core_col_id, core_row_id)),
                    CallKernel('run_smxv_gemm_qdq', gen_mha_params( 0, 1, core_col_id, core_row_id)),

                    RelBuffer(DmaChannel(DmaDir.S2MM, 0))   # rel Q
                ] + 
                Conditional_RelBuffer(DmaChannel(DmaDir.MM2S, 0))),   # rel O),
                RelBuffer(DmaChannel(DmaDir.S2MM, 1)),       # rel KV
            ])
        ]
    else:
        return [
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreQdqPingAddr, None, CoreQdqPrmSize),
            AcqBuffer(   DmaChannel(DmaDir.S2MM, 0)),
            RelBuffer(   DmaChannel(DmaDir.S2MM, 0)),
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreQryPingAddr, None, CoreQrySize),

            Loop(H, [
                ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreKeyPingAddr, None, CoreKeyValSize),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),       # acq KV
                CallKernel('run_act_K_preprocess', gen_mha_params( 2, 0, core_col_id, core_row_id)),
                CallKernel('run_act_V_preprocess', gen_mha_params( 2, 0, core_col_id, core_row_id)),

                Loop(Li, []+[
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),   # acq Q

                    CallKernel('run_qkt_gemm_qdq', gen_mha_params( 0, 1, core_col_id, core_row_id)),
                    CallKernel('run_softmax_int16_to_int16', gen_mha_params( 0, 1, core_col_id, core_row_id)),
                    CallKernel('run_smxv_gemm_qdq', gen_mha_params( 0, 1, core_col_id, core_row_id)),

                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),   # rel Q
                ]+[]),
                RelBuffer(DmaChannel(DmaDir.S2MM, 1)),       # rel KV
            ])
        ]
def gemm_params(zero_acc: int):
    assert zero_acc in (0, 1)
    return zero_acc.to_bytes(length=1, byteorder='little', signed=False)

def mha_params(m_subv:int, k_subv:int, n_subv: int, l_subv: int, spatialsplit: int, sfmx_mask: bool):

    qry_addr = 0                 # args_params[5];  #0;
    key_addr = qry_addr + 4*1024 # args_params[6];  #qry_addr  +  4*1024;
    val_addr = key_addr + 8*1024 # args_params[7];  #key_addr  +  8*1024;
    out_addr = val_addr + 8*1024 # args_params[8];  #val_addr  +  8*1024;
    tdm1_addr= out_addr + 4*1024 # args_params[9];  #out_addr  +  4*1024;
    tdm2_addr= tdm1_addr+ 8*1024 # args_params[10]  #tdm1_addr +  8*1024;
    qkt_addr = tdm2_addr+ 8*1024 # args_params[11]; #tdm2_addr +  8*1024;
    sfm_addr = qkt_addr + 0 #4*1024       # args_params[12]; #qkt_addr;

    return m_subv.to_bytes(length=2, byteorder='little', signed=False) + \
           k_subv.to_bytes(length=2, byteorder='little', signed=False) + \
           n_subv.to_bytes(length=2, byteorder='little', signed=False) + \
           l_subv.to_bytes(length=2, byteorder='little', signed=False) + \
           spatialsplit.to_bytes(length=2, byteorder='little', signed=False) + \
           qry_addr.to_bytes(length=2, byteorder='little', signed=False) + \
           key_addr.to_bytes(length=2, byteorder='little', signed=False) + \
           val_addr.to_bytes(length=2, byteorder='little', signed=False) + \
           out_addr.to_bytes(length=2, byteorder='little', signed=False) + \
           tdm1_addr.to_bytes(length=2, byteorder='little', signed=False) + \
           tdm2_addr.to_bytes(length=2, byteorder='little', signed=False) + \
           qkt_addr.to_bytes(length=2, byteorder='little', signed=False) + \
           sfm_addr.to_bytes(length=2, byteorder='little', signed=False) + \
           sfmx_mask.to_bytes(length=2, byteorder='little', signed=False)

def mknl_params(m_subv:int, k_subv:int, n_subv: int, l_subv: int):
    return m_subv.to_bytes(length=1, byteorder='little', signed=False) + \
           k_subv.to_bytes(length=1, byteorder='little', signed=False) + \
           n_subv.to_bytes(length=1, byteorder='little', signed=False) + \
           l_subv.to_bytes(length=1, byteorder='little', signed=False)

def tobfp16_params(m_subv:int, k_subv:int, n_subv: int, l_subv: int):
    qry_addr = 0                 # args_params[5];  #0;
    key_addr = qry_addr + 4*1024 # args_params[6];  #qry_addr  +  4*1024;
    val_addr = key_addr + 8*1024 # args_params[7];  #key_addr  +  8*1024;

    return key_addr.to_bytes(length=2, byteorder='little', signed=False) + \
           val_addr.to_bytes(length=2, byteorder='little', signed=False)

'''
# DISABLED CODE: a bare string literal with no runtime effect. It preserves an
# older per-column dma_connections list:
#   - Shim MM2S 0/1 -> Memtile S2MM 0/1
#   - Memtile MM2S 0 -> every core's S2MM 0
#   - Memtile MM2S (1 + row) -> core S2MM 1
#   - core MM2S 0 -> Memtile S2MM (2 + row)
#   - Memtile MM2S 5 -> Shim S2MM 0
# NOTE(review): the "Memtile MM2S 0 -> core S2MM 0" group builds
# AieTile(TileType.Core, row, col), while every other group passes
# (col, row). This looks like swapped arguments. Verify before re-enabling.
'''
dma_connections = [
    DmaConnection(AieDma(AieTile(TileType.Shim,    col, 0), DmaChannel(DmaDir.MM2S, 0)),
                  AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.S2MM, 0)))
    for col in range(AieCols)
] + [
    DmaConnection(AieDma(AieTile(TileType.Shim,    col, 0), DmaChannel(DmaDir.MM2S, 1)),
                  AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.S2MM, 1)))
    for col in range(AieCols)
] + [
    DmaConnection(AieDma(AieTile(TileType.Memtile, col, 0),   DmaChannel(DmaDir.MM2S, 0)),
                  AieDma(AieTile(TileType.Core,    row, col), DmaChannel(DmaDir.S2MM, 0)))
    for col in range(AieCols)
    for row in range(AieRows)
] + [
    DmaConnection(AieDma(AieTile(TileType.Memtile, col, 0),   DmaChannel(DmaDir.MM2S, 1 + row)),
                  AieDma(AieTile(TileType.Core,    col, row), DmaChannel(DmaDir.S2MM, 1)))
    for col in range(AieCols)
    for row in range(AieRows)
] + [
    DmaConnection(AieDma(AieTile(TileType.Core,    col, row), DmaChannel(DmaDir.MM2S, 0)),
                  AieDma(AieTile(TileType.Memtile, col, 0),   DmaChannel(DmaDir.S2MM, 2 + row)))
    for col in range(AieCols)
    for row in range(AieRows)
] + [
    DmaConnection(AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.MM2S, 5)),
                  AieDma(AieTile(TileType.Shim,    col, 0), DmaChannel(DmaDir.S2MM, 0)))
    for col in range(AieCols)
]
'''
