import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels','mha_qdq'))

from dataclasses import dataclass, field
from mha_params import generate_layer_kernel_params, MhaSubvDims, OPMode
from dataflow.mha.parameters import deserialize_parameters, Parameters
from OGOAT.src.Scheduling_Engine.schedules.BufferAllocatorResult import BufferAllocations
from typing import Tuple, List, ClassVar, Dict
from dataflow.dataflow_common import overlay_8x4_dma_connections, overlay_8x4_core_stream_bdcast, overlay_4x4_core_stream_bdcast, overlay_4x4_dma_connections, overlay_stack_addr, overlay_stack_size

from dmacompiler import OverlayShape, DataTransfer, TransferParams, SyncStrategy, BackEnd, \
    DmaChannel, DmaDir, AieDma, AieTile, TileType, DmaConnection, CascDir, \
    ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    generate_transfer_params,\
    shim_dma, memtile_dma,\
    TransferParams,\
    run_overlay_deadlock_check, \
    compile_dma_layer_config, \
    shim_dma, memtile_dma, core_dma, \
    run_layer_compilation, \
    generate_shim_data_transfer


def ceildiv(x: int, d: int) -> int:
    """Return ``x`` divided by ``d``, rounded up towards positive infinity.

    Raises ZeroDivisionError when ``d`` is zero.
    """
    quotient, remainder = divmod(x, d)
    # divmod floors the quotient, so a non-zero remainder means the exact
    # result lies strictly above `quotient`.
    if remainder:
        quotient += 1
    return quotient


def iceil(x: int, d: int) -> int:
    """Round ``x`` up to a multiple of ``d``.

    For a positive ``d`` the result is the smallest multiple of ``d`` that
    is greater than or equal to ``x``.
    """
    multiples = ceildiv(x, d)
    return multiples * d


def mha_qkt_sfmx_qdq_params(
    input: Tuple[int, int],
    output: Tuple[int, int],
    mha_mode: int,
    multi_core: int,
    col_id: int,
    row_id: int,
    Sin_kv: int,
    core_tdm1_addr: int,
    core_tdm2_addr: int,
    core_qdq_addr: int,
    core_act1_addr: int,
    core_act2_addr: int,
    core_C0_addr: int,
    core_scratch_addr: int,
) -> bytes:
    """
    Build the kernel parameter blob for one core kernel invocation.

    `input` is unpacked as (M, K) and `output` as (K, N). The mode, core
    coordinates, `Sin_kv` and the core buffer addresses are forwarded
    unchanged to `generate_layer_kernel_params`, together with a
    `MhaSubvDims` describing the sub-volume; its result is returned as is.
    """
    # Fixed sub-volume description used by this layer.
    y_granularity = 1
    x_granularity = 8
    co_granularity = 8
    ci_granularity = 8
    bytes_per_elem = 2
    stride_eff = 0.5
    alignment = 64
    kernel_y, kernel_x = 1, 1
    stride_y, stride_x = 1, 1
    reduce_mode = OPMode.OP_SUM

    # Only M is taken from `input`; the K handed to the kernel is the one
    # carried by `output` (the K of `input` is unpacked and discarded).
    m_dim, _ = input
    k_dim, n_dim = output

    subv_dims = MhaSubvDims(
        1,
        1, y_granularity,
        m_dim, x_granularity,
        k_dim, ci_granularity,
        n_dim, co_granularity,
        kernel_y, kernel_x,
        stride_y, stride_x,
        reduce_mode,
        bytes_per_elem,
        stride_eff,
        alignment,
    )
    return generate_layer_kernel_params(
        mha_mode,
        multi_core,
        col_id,
        row_id,
        Sin_kv,
        core_tdm1_addr,
        core_tdm2_addr,
        core_qdq_addr,
        core_act1_addr,
        core_act2_addr,
        core_C0_addr,
        core_scratch_addr,
        subv_dims,
    )


@dataclass(frozen=True)
class Mha2p1Parameters:
    """
    Definition and initialization of the internal parameters needed for the
    creation of the data transfers for the MHA operator.
    The dataclass is frozen and the object cannot be modified after creation
    as the parameters values are not supposed to change once computed.

    Instances are normally built with `compute_internal_parameters`, which
    derives every field from the tiler output dictionary.
    """
    # Overlay shape: number of columns / rows reported by the tiler.
    AieCols: int
    AieRows: int

    # Number of 4x4 groups in the overlay (2 for 8 columns, 1 for 4 columns)
    # and the number of columns in each group (AieCols // Num4x4).
    Num4x4: int
    NumAieCompCols: int

    # Element sizes in bytes, read from "in_q_bytes" and "out_ofm_bytes".
    IfmBytes: int
    OutBytes: int

    # Full extents read from the layer info: in_q_shape[0] and in_k_shape[1].
    Sin_q: int
    Sin_kv: int

    # Core sub-volume extents: q sub-volume dim 0 divided by Num4x4, and
    # k sub-volume dim 1.
    Sq: int
    Skv: int

    # Iteration counts (core iters * memtile iters) for q, k and ofm.
    # Tm and To are additionally multiplied by Num4x4.
    Tm: int
    Tn: int
    To: int

    # Constant parameters
    QdqNodes: ClassVar[int] = 6
    QdqPrm: ClassVar[int] = 16
    QdqPrmBytes: ClassVar[int] = 4
    bits_per_byte: ClassVar[int] = 8
    bytes_per_word: ClassVar[int] = 4
    Dh: ClassVar[int] = 64
    H: ClassVar[int] = 1
    TdmBytes: ClassVar[int] = 4
    C0Bytes: ClassVar[int] = 8

    @staticmethod
    def compute_internal_parameters(tiler_output: Dict) -> "Mha2p1Parameters":
        """
        Derive all internal parameters from the tiler output dictionary.

        Reads the "overlay_info", "layer_info", "core_tile_params" and
        "mem_tile_params" entries of `tiler_output`.

        Raises:
            AssertionError: if the overlay is not 8x4 or 4x4, or if the k
                sub-volume does not tile Sin_kv exactly
                (Skv * AieRows * NumAieCompCols != Sin_kv).
            KeyError: if an expected entry is missing from `tiler_output`.
        """
        overlay_shape = tiler_output["overlay_info"]["shape"]
        AieCols = overlay_shape["col"]
        AieRows = overlay_shape["row"]

        # Explicit raises instead of `assert` so that the validation is not
        # stripped when Python runs with -O; AssertionError is kept so that
        # existing callers observe the same exception type.
        if AieRows != 4:
            raise AssertionError(f"AieRows={AieRows} not supported")
        if AieCols == 8:
            Num4x4 = 2
        elif AieCols == 4:
            Num4x4 = 1
        else:
            raise AssertionError(f"Unsupported overlay: {AieCols}x{AieRows}")

        NumAieCompCols = AieCols // Num4x4

        layer_info = tiler_output["layer_info"]
        IfmBytes = layer_info["in_q_bytes"]
        OutBytes = layer_info["out_ofm_bytes"]

        Sin_q = layer_info["in_q_shape"][0]
        Sin_kv = layer_info["in_k_shape"][1]

        core_params = tiler_output["core_tile_params"]
        Sq = core_params["subvols"]['q'][0] // Num4x4
        Skv = core_params["subvols"]['k'][1]

        # The k sub-volumes of all rows and columns of one 4x4 group must
        # cover Sin_kv exactly.
        if Skv * AieRows * NumAieCompCols != Sin_kv:
            raise AssertionError(
                f"Skv={Skv} * AieRows={AieRows} * NumAieCompCols={NumAieCompCols} "
                f"does not match Sin_kv={Sin_kv}"
            )

        core_iters = core_params["iters"]
        mem_iters = tiler_output["mem_tile_params"]["iters"]
        Tm = (core_iters['q'][0] * mem_iters['q'][0]) * Num4x4
        Tn = core_iters['k'][0] * mem_iters['k'][0]
        To = (core_iters['ofm'][0] * mem_iters['ofm'][0]) * Num4x4

        return Mha2p1Parameters(AieCols=AieCols,
                                AieRows=AieRows,
                                Num4x4=Num4x4,
                                NumAieCompCols=NumAieCompCols,
                                IfmBytes=IfmBytes,
                                OutBytes=OutBytes,
                                Sin_q=Sin_q,
                                Sin_kv=Sin_kv,
                                Sq=Sq,
                                Skv=Skv,
                                Tm=Tm,
                                Tn=Tn,
                                To=To)


def generate_standalone_buffer_allocation(tiler_output: Dict) -> BufferAllocations:
    """
    WARNING: The purpose of this function is to provide a way to hardcode the buffer allocation
    and test the dataflow of the mha operator without going through the scheduler and the buffer
    allocator. It should only be used for testing and not for the product part.

    Core buffers are laid out in this order, each start address rounded up to
    64 bytes: q, k, tdm (two halves), qdq, act1_sum, act2_sum, c0, scratch.
    The "ofm" core buffer is placed at the address of the first tdm half.
    Memtile buffers are packed back to back in this order: prm, q ping,
    k ping, q pong, k pong, ofm ping, ofm pong, qdq.

    Args:
        tiler_output: tiler output dictionary; used to compute the internal
            parameters and to read the memtile q / k / ofm buffer sizes.

    Returns:
        A BufferAllocations object holding the core and memtile allocations.

    Raises:
        AssertionError: if the ofm core buffer would run into the second
            tdm half.
    """
    p = Mha2p1Parameters.compute_internal_parameters(tiler_output)

    CoreAlignSize = 64

    # Compute the sizes of the cores buffer allocations
    CoreQdqPrmSize = (p.QdqNodes * p.QdqPrm * p.QdqPrmBytes)
    CoreQrySize    = (p.Sq  * p.Dh   * p.IfmBytes) * p.Num4x4
    CoreKeySize    = (p.Skv * p.Dh   * p.IfmBytes)
    CoreOutSize    = (p.Sq  * p.Skv  * p.OutBytes)
    CoreTdmSize    = (p.Sq  * p.Skv  * p.TdmBytes)
    CoreAct1SumSize = iceil(p.TdmBytes * p.Sq, 256)
    CoreAct2SumSize = iceil(p.Skv * p.TdmBytes, 256)
    CoreC0Size = p.Skv * p.C0Bytes

    # Compute the addresses of the cores buffer allocations
    CoreQryPingAddr    = 0
    CoreKeyPingAddr    = iceil(CoreQryPingAddr + CoreQrySize, CoreAlignSize)
    # The tdm area is split in two halves of CoreTdmSize // 2 bytes each.
    CoreTdm1Addr       = iceil(CoreKeyPingAddr + CoreKeySize, CoreAlignSize)
    CoreTdm2Addr       = iceil(CoreTdm1Addr + CoreTdmSize // 2, CoreAlignSize)
    # The output buffer aliases the first tdm half and must not reach the second.
    CoreOutPingAddr = CoreTdm1Addr
    assert CoreOutPingAddr + CoreOutSize <= CoreTdm2Addr, (
        f"ofm buffer ({CoreOutSize} bytes at {CoreOutPingAddr}) "
        f"overlaps the second tdm half at {CoreTdm2Addr}"
    )
    CoreQdqPingAddr    = iceil(CoreTdm2Addr + CoreTdmSize // 2, CoreAlignSize)
    CoreAct1SumAddr    = iceil(CoreQdqPingAddr + CoreQdqPrmSize, CoreAlignSize)
    CoreAct2SumAddr    = iceil(CoreAct1SumAddr + CoreAct1SumSize, CoreAlignSize)
    CoreC0Addr         = iceil(CoreAct2SumAddr + CoreAct2SumSize, CoreAlignSize)
    CoreScratchAddr    = iceil(CoreC0Addr + CoreC0Size, CoreAlignSize)

    res = BufferAllocations()
    res.add_core_allocation(buffer_name="q", size=CoreQrySize, ping_addresses=[CoreQryPingAddr])
    res.add_core_allocation(buffer_name="k", size=CoreKeySize, ping_addresses=[CoreKeyPingAddr])
    res.add_core_allocation(buffer_name="tdm", size=CoreTdmSize // 2, ping_addresses=[CoreTdm1Addr, CoreTdm2Addr])
    res.add_core_allocation(buffer_name="ofm", size=CoreOutSize, ping_addresses=[CoreOutPingAddr])
    res.add_core_allocation(buffer_name="act1_sum", size=CoreAct1SumSize, ping_addresses=[CoreAct1SumAddr])
    res.add_core_allocation(buffer_name="act2_sum", size=CoreAct2SumSize, ping_addresses=[CoreAct2SumAddr])
    res.add_core_allocation(buffer_name="c0", size=CoreC0Size, ping_addresses=[CoreC0Addr])
    res.add_core_allocation(buffer_name="qdq", size=CoreQdqPrmSize, ping_addresses=[CoreQdqPingAddr])

    # FIXME: scratch and stack sizes are not needed for the computation so we provide a fake data for testing
    res.add_core_allocation(buffer_name="scratch", size=0, ping_addresses=[CoreScratchAddr])
    res.add_core_allocation(buffer_name="stack", size=overlay_stack_size(), ping_addresses=[overlay_stack_addr()])

    # Compute the sizes of the memtile buffers allocations
    MemtilePrmSize    = 1024 * (p.AieRows)
    MemtileQdqPrmSize = CoreQdqPrmSize
    MemtileQrySize   = tiler_output["mem_tile_params"]["sizes"]['q']
    MemtileKeySize   = tiler_output["mem_tile_params"]["sizes"]['k']
    MemtileOutSize   = tiler_output["mem_tile_params"]["sizes"]['ofm']

    # Compute the addresses of the memtile buffers allocations
    MemtilePrmPingAddr   = 0
    MemtileQPingAddr   = MemtilePrmPingAddr + MemtilePrmSize
    MemtileKPingAddr   = MemtileQPingAddr + MemtileQrySize
    MemtileQPongAddr   = MemtileKPingAddr + MemtileKeySize
    MemtileKPongAddr   = MemtileQPongAddr + MemtileQrySize
    MemtileOutPingAddr   = MemtileKPongAddr + MemtileKeySize
    MemtileOutPongAddr   = MemtileOutPingAddr + MemtileOutSize
    MemtileQdqPingAddr   = MemtileOutPongAddr + MemtileOutSize

    res.add_mem_allocation(buffer_name="q", size=MemtileQrySize, ping_addresses=[
                           MemtileQPingAddr], pong_addresses=[MemtileQPongAddr])
    res.add_mem_allocation(buffer_name="k", size=MemtileKeySize, ping_addresses=[
                           MemtileKPingAddr], pong_addresses=[MemtileKPongAddr])
    res.add_mem_allocation(buffer_name="ofm", size=MemtileOutSize, ping_addresses=[
                           MemtileOutPingAddr], pong_addresses=[MemtileOutPongAddr])
    res.add_mem_allocation(
        buffer_name="qdq", size=MemtileQdqPrmSize, ping_addresses=[MemtileQdqPingAddr])
    res.add_mem_allocation(
        buffer_name="prm", size=MemtilePrmSize, ping_addresses=[MemtilePrmPingAddr])

    return res


def generate_dataflow_mha(
    p: Mha2p1Parameters,
    buffer_alloc: BufferAllocations,
    CodeBackend: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
) -> Tuple[int, int, int, int]:
    """
    Build the core instruction lists and the memtile / shim data transfers
    of the MHA layer and pass them to `run_layer_compilation`.

    Args:
        p: internal parameters of the layer.
        buffer_alloc: core and memtile buffer allocations, queried by name
            ("q", "k", "ofm", "tdm", "qdq", "act1_sum", "act2_sum", "c0",
            "scratch" for cores; "prm", "qdq", "q", "k", "ofm" for memtiles).
        CodeBackend: forwarded to `run_layer_compilation` as `back_end`.
        kernel_names: forwarded to `run_layer_compilation`.
        kernel_includes: forwarded to `run_layer_compilation`.

    Returns:
        The tuple (p.Sin_q, p.Dh, p.Sin_kv, p.Sq).

    Raises:
        AssertionError: if the core "ofm" buffer is larger than the core
            "tdm" buffer size, or if the memtile "k" ping/pong buffer does
            not start right after the memtile "q" ping/pong buffer.
    """

    # Extract Core Buffers Allocations:
    # Bytes of layer parameters sent to each core; the memtile "prm" buffer is
    # sliced in chunks of this size, one per row (see memtile_transfers below).
    CorePrmSize    = 1024
    CoreQdqPrmSize = (p.QdqNodes * p.QdqPrm * p.QdqPrmBytes)
    CoreQrySize    = buffer_alloc.get_core_alloc("q").ping.size
    CoreKeySize    = buffer_alloc.get_core_alloc("k").ping.size
    CoreOutSize    = buffer_alloc.get_core_alloc("ofm").ping.size
    CoreTdmSize    = buffer_alloc.get_core_alloc("tdm").ping.size

    CoreQryPingAddr    = buffer_alloc.get_core_alloc("q").ping.addresses[0]
    CoreKeyPingAddr    = buffer_alloc.get_core_alloc("k").ping.addresses[0]
    # The "tdm" allocation carries two ping addresses, one per tdm half.
    CoreTdm1Addr       = buffer_alloc.get_core_alloc("tdm").ping.addresses[0]
    CoreTdm2Addr       = buffer_alloc.get_core_alloc("tdm").ping.addresses[1]

    # The output buffer is shared with the first half of tdm in the
    # implementation of the qkt kernel; make sure that it can fit and will
    # not overflow.
    CoreOutPingAddr    = CoreTdm1Addr
    assert CoreOutSize <= CoreTdmSize

    CoreQdqPingAddr    = buffer_alloc.get_core_alloc("qdq").ping.addresses[0]
    CoreAct1SumAddr    = buffer_alloc.get_core_alloc("act1_sum").ping.addresses[0]
    CoreAct2SumAddr    = buffer_alloc.get_core_alloc("act2_sum").ping.addresses[0]
    CoreC0Addr         = buffer_alloc.get_core_alloc("c0").ping.addresses[0]
    CoreScratchAddr    = buffer_alloc.get_core_alloc("scratch").ping.addresses[0]

    # Extract Memtile Buffers Allocations:
    MemtilePrmSize = buffer_alloc.get_mem_alloc("prm").ping.size
    MemtileQdqPrmSize = buffer_alloc.get_mem_alloc("qdq").ping.size
    MemtileQrySize   = buffer_alloc.get_mem_alloc('q').ping.size
    MemtileKeySize   = buffer_alloc.get_mem_alloc("k").ping.size
    MemtileOutSize   = buffer_alloc.get_mem_alloc("ofm").ping.size
    # q and k are handled as one contiguous memtile buffer (q first, then k).
    MemtileQkvSize   = MemtileQrySize + MemtileKeySize

    MemtilePrmPingAddr   = buffer_alloc.get_mem_alloc("prm").ping.addresses[0]
    MemtileQkvPingAddr   = buffer_alloc.get_mem_alloc("q").ping.addresses[0]
    MemtileQkvPongAddr   = buffer_alloc.get_mem_alloc("q").pong.addresses[0]

    # Currently the memory layout expected for the data transfer is:
    # MemtileQPing -> MemtileKPing -> MemtileQPong -> MemtileKPong
    # assert if that's not the case as it could invalidate the data transfer
    # below.
    MemtileKeyPingAddr   = buffer_alloc.get_mem_alloc("k").ping.addresses[0]
    assert MemtileKeyPingAddr == MemtileQkvPingAddr + MemtileQrySize

    MemtileKeyPongAddr   = buffer_alloc.get_mem_alloc("k").pong.addresses[0]
    assert MemtileKeyPongAddr == MemtileQkvPongAddr + MemtileQrySize

    MemtileOutPingAddr   = buffer_alloc.get_mem_alloc("ofm").ping.addresses[0]
    MemtileOutPongAddr   = buffer_alloc.get_mem_alloc("ofm").pong.addresses[0]
    MemtileQdqPingAddr   = buffer_alloc.get_mem_alloc("qdq").ping.addresses[0]

    ShimQdqPrmSize = MemtileQdqPrmSize

    # Sizes converted from bytes to words (p.bytes_per_word bytes each) for
    # the TransferParams below.
    CorePrmWords    = CorePrmSize // p.bytes_per_word
    MemtilePrmWords    = MemtilePrmSize // p.bytes_per_word
    MemtileQdqPrmWords = MemtileQdqPrmSize // p.bytes_per_word
    ShimQdqPrmWords = ShimQdqPrmSize // p.bytes_per_word

    def Memtile(col: int):
        """Return the memtile of column `col`."""
        return AieTile(TileType.Memtile, col, 0)

    def Shimtile(col: int):
        """Return the shim tile of column `col`."""
        return AieTile(TileType.Shim, col, 0)

    def get_core_instrs(core_col_id: int, core_row_id: int):
        """
        Return the instruction list of the core identified by
        (core_col_id, core_row_id).

        core_row_id must be in [2, 6) and core_col_id in [0, 8); these ids
        are only forwarded to the kernel parameter blobs.
        NOTE(review): rows presumably start at 2 because rows 0 and 1 are the
        shim and memtile rows - TODO confirm.
        """
        assert 2 <= core_row_id < 6
        assert 0 <= core_col_id < 8
        return [
            # One-off buffer on S2MM channel 1 targeting the QDQ parameter
            # buffer, acquired and released before the main loop.
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreQdqPingAddr, None, CoreQdqPrmSize),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
            Loop(p.H, [
            # S2MM channel 1 is re-targeted to the query buffer for the loop.
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreQryPingAddr, None, CoreQrySize),
            ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOutPingAddr, None, CoreOutSize),  # output is shared with TDM buffer
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreKeyPingAddr, None, CoreKeySize),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),  # key unicast
            CallKernel('run_qtxk', mha_qkt_sfmx_qdq_params(
                                                            (p.Sq, p.Dh),
                                                            (p.Dh, p.Skv),
                                                            2, 0,
                                                            core_col_id, core_row_id,
                                                            p.Sin_kv,
                                                            CoreTdm1Addr, CoreTdm2Addr,
                                                            CoreQdqPingAddr,
                                                            CoreAct1SumAddr, CoreAct2SumAddr,
                                                            CoreC0Addr, CoreScratchAddr)
                                                            ),     # 2/3 - K pre-processing

            # The key buffer stays acquired for the whole inner loop and is
            # released only after it.
            Loop(p.Tm * p.Tn, [    # loop over M split and N split
                AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),  # output
                AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),  # query broadcast
                CallKernel('run_qtxk', mha_qkt_sfmx_qdq_params(
                                                            (p.Sq, p.Dh),
                                                            (p.Dh, p.Skv),
                                                            0, 0,
                                                            core_col_id, core_row_id,
                                                            p.Sin_kv,
                                                            CoreTdm1Addr, CoreTdm2Addr,
                                                            CoreQdqPingAddr,
                                                            CoreAct1SumAddr, CoreAct2SumAddr,
                                                            CoreC0Addr, CoreScratchAddr)),     # 0 - standalone QKt, not fused

                RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                # The softmax kernel receives the column index within its
                # 4x4 group (core_col_id % NumAieCompCols) and multi_core=1.
                CallKernel('run_sfmx', mha_qkt_sfmx_qdq_params(
                                                            (p.Sq, p.Dh),
                                                            (p.Dh, p.Skv),
                                                            0, 1,
                                                            (core_col_id % p.NumAieCompCols), core_row_id,
                                                            p.Sin_kv,
                                                            CoreTdm1Addr, CoreTdm2Addr,
                                                            CoreQdqPingAddr,
                                                            CoreAct1SumAddr, CoreAct2SumAddr,
                                                            CoreC0Addr, CoreScratchAddr)),
                RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
            ]),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            ]),
        ]

    shape = OverlayShape(p.AieCols, p.AieRows)

    # Instruction lists are generated column by column with kernel row ids
    # 2 .. AieRows + 1, so the entry for (col, row) sits at col * AieRows + row.
    core_instrs_array = []
    for col in range(p.AieCols):
        for row in range(2, p.AieRows + 2):
            core_instrs_array.append(get_core_instrs(col, row))

    # Map each core tile (rows indexed from 0 here) to its instruction list.
    instr_dict = {}
    for col in range(p.AieCols):
        for row in range(p.AieRows):
            instr_dict[AieTile(TileType.Core, col, row)] = core_instrs_array[col * p.AieRows + row]

    # Query transfers: one per 4x4 group (columns 0, Num4x4, ...), repeated
    # Tm * H times over the q/k ping and pong buffers; written through
    # S2MM channel 0 and read through MM2S channel 4.
    mem_qt = [

            DataTransfer(

                [p.Tm * p.H],

                AieTile(TileType.Memtile, col), [MemtileQkvPingAddr, MemtileQkvPongAddr], MemtileQkvSize,

                [generate_transfer_params( 
                    memtile_dma(col, DmaDir.S2MM, 0),
                    memory_format=f"Yq:{p.Num4x4*p.Sq} Xq:{p.Dh} Yq:{p.Sq} Xq:8",
                    tiling_format=f"Yq:0:{p.Num4x4*p.Sq} Xq:0:{p.Dh}",
                    bits_per_block=p.IfmBytes * p.bits_per_byte)],

                [generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, 4),
                    memory_format=f"Yq:{p.Num4x4*p.Sq} Xq:{p.Dh} Yq:{p.Sq} Xq:8",
                    ## TODO: add logic to keep half of the data in 4x4 from mem to core
                    tiling_format=f"Yq:0:{p.Num4x4*p.Sq}:{p.Sq} Xq:0:{p.Dh}:8 Yq:0:{p.Sq}  Xq:0:8",
                    bits_per_block=p.IfmBytes * p.bits_per_byte)],
                )

                for col in range(0, p.AieCols, p.Num4x4)]

    ## TODO: This version of DMACompiler does not support buffer offset in the transfer params.
    # Key transfers: one per column, repeated H times on the ping buffer only;
    # written through S2MM channel 1 and read through one MM2S channel per
    # row. Both sides pass buffer_offset=MemtileQrySize, i.e. the size of the
    # q region that precedes k in the shared q/k buffer.
    mem_kt = [

            DataTransfer(

                [1 * p.H],

                AieTile(TileType.Memtile, col), [MemtileQkvPingAddr], MemtileQkvSize,

                [generate_transfer_params( 
                    memtile_dma(col, DmaDir.S2MM, 1),
                    memory_format=f"Yk:{p.AieRows*p.Skv} Xk:{p.Dh} Yk:{p.Skv} Xk:8",
                    tiling_format=f"Yk:0:{p.Skv*p.AieRows} Xk:0:{p.Dh}",
                    bits_per_block=p.IfmBytes * p.bits_per_byte,
                    buffer_offset= MemtileQrySize)], ## Bytes or Words? ## TODO: This offset might need fixing

                [generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, row),
                    memory_format=f"Yk:{p.AieRows*p.Skv} Xk:{p.Dh} Yk:{p.Skv} Xk:8",
                    tiling_format=f"Xk:0:{p.Dh}:8 Yk:{row*p.Skv}:{(row+1)*p.Skv} Xk:0:8",
                    bits_per_block=p.IfmBytes * p.bits_per_byte,
                    buffer_offset= MemtileQrySize)

                    for row in range(p.AieRows)],

                )

                for col in range(p.AieCols)]

    # Output transfers: one per column, repeated To * H times over the ofm
    # ping and pong buffers; each row writes its Skv-wide slice through S2MM
    # channel row + 2 and the whole buffer is read through MM2S channel 5.
    mem_ot = [

            DataTransfer(

                [p.To * p.H],

                AieTile(TileType.Memtile, col), [MemtileOutPingAddr, MemtileOutPongAddr], MemtileOutSize,

                [generate_transfer_params(
                    memtile_dma(col, DmaDir.S2MM, row+2),
                    memory_format   = f"Xo:{p.AieRows*p.Skv} Yo:{p.Sq} Xo:{p.Skv}",
                    tiling_format = f"Xo:{row*p.Skv}:{(row+1)*p.Skv}:8 Yo:0:{p.Sq} Xo:0:8",
                    bits_per_block = p.OutBytes * p.bits_per_byte)

                    for row in range(p.AieRows)],

                [generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, 5),
                    memory_format= f"Xo:{p.AieRows*p.Skv} Yo:{p.Sq} Xo:{p.Skv}",
                    tiling_format= f"Yo:0:{p.Sq} Xo:0:{p.AieRows*p.Skv}",
                    bits_per_block=p.OutBytes*p.bits_per_byte)],

                )

                for col in range(p.AieCols)]

    # Full memtile transfer list: layer parameters first (received on S2MM
    # channel 0, then split in CorePrmWords slices, one MM2S channel per row),
    # then the QDQ parameters (one transfer per 4x4 group, out on MM2S
    # channel 4), then the key, query and output transfers built above.
    memtile_transfers = [
        DataTransfer(
            [1], Memtile(col), [MemtilePrmPingAddr], MemtilePrmSize,

            [TransferParams(AieDma(Memtile(col), DmaChannel(DmaDir.S2MM, 0)), MemtilePrmWords)],

            [TransferParams(AieDma(Memtile(col), DmaChannel(DmaDir.MM2S, row)), CorePrmWords, offset=(row * CorePrmWords))
             for row in range(p.AieRows)]

        ) for col in range(p.AieCols)
    ] + [
        DataTransfer(
            [1], Memtile(col), [MemtileQdqPingAddr], MemtileQdqPrmSize,

            [TransferParams(AieDma(Memtile(col), DmaChannel(DmaDir.S2MM, 0)), MemtileQdqPrmWords)],

            [TransferParams(AieDma(Memtile(col), DmaChannel(DmaDir.MM2S, 4)), MemtileQdqPrmWords)]
        )
        for col in range(0, p.AieCols, p.Num4x4)

    ] + mem_kt + mem_qt + mem_ot

    ## TODO: the arithmetic in the tiling formats below could probably be simplified
    # Shim transfers for key (MM2S channel 1, every column), query (MM2S
    # channel 0, one per 4x4 group) and output (S2MM channel 0, every column).
    # NOTE(review): the integer passed after the AieDma (1, 1 and 0) is
    # presumably a shim buffer index - TODO confirm against
    # generate_shim_data_transfer.
    shim_kt = [ generate_shim_data_transfer([1], AieDma(Shimtile(col), DmaChannel(DmaDir.MM2S, 1)), 1,
                                memory_format=f"Yk:{p.H*(p.Sin_q+p.Sin_kv)} Xk:{p.Dh}",
                                tiling_format=f"Yk:{p.H* p.Sin_q}:{p.H*(p.Sin_q+p.Sin_kv)}:{p.Sin_kv} Yk:{((col%p.NumAieCompCols) * (p.Sin_kv // p.NumAieCompCols) )}:{(((col%p.NumAieCompCols)+1) * (p.Sin_kv // p.NumAieCompCols))} Xk:0:{p.Dh}",
                                bits_per_block= p.IfmBytes* p.bits_per_byte, elements_per_block = 1) for col in range(p.AieCols) ]

    shim_qt = [ generate_shim_data_transfer([1], AieDma(Shimtile(col), DmaChannel(DmaDir.MM2S, 0)), 1,
                                memory_format=f"Yq:{p.H*(p.Sin_q+p.Sin_kv)} Xq:{p.Dh}",
                                tiling_format=f"Yq:0:{p.H*p.Sin_q}:{p.Sin_q} Yq:0:{p.Sin_q}:{p.Sin_q//p.AieCols} Yq:0:{p.Sin_q//p.AieCols} Xq:0:{p.Dh}",
                                bits_per_block=p.IfmBytes * p.bits_per_byte, elements_per_block = 1) for col in range(0, p.AieCols, p.Num4x4) ]
    
    shim_ot_te = [ generate_shim_data_transfer([1], AieDma(Shimtile(col), DmaChannel(DmaDir.S2MM, 0)), 0,
                                memory_format=f"Yo:{p.H*p.Sin_q} Xo:{p.Sin_kv}",
                                tiling_format=f"Yo:0:{p.H*p.Sin_q}:{p.Sin_q} Yo:0:{p.Sin_q}:{p.Num4x4*p.Sq} Yo:{p.Sq*(col//p.NumAieCompCols)}:{p.Sq*(col//p.NumAieCompCols+1)} Xo:{(p.Sin_kv // p.NumAieCompCols)*(col%p.NumAieCompCols)}:{(p.Sin_kv // p.NumAieCompCols)*((col%p.NumAieCompCols)+1)}", # Xo span per column = Sin_kv // NumAieCompCols
                                bits_per_block= p.OutBytes * p.bits_per_byte, elements_per_block = 1) for col in range(p.AieCols) ]
    
    # Full shim transfer list: layer parameters (one MemtilePrmWords slice per
    # column, offset by col * MemtilePrmWords), QDQ parameters (one per 4x4
    # group), then the key, query and output transfers built above.
    # NOTE(review): the [3] and [2] address lists are presumably shim buffer
    # indices rather than byte addresses - TODO confirm.
    shim_transfers = [
        DataTransfer(
            [1], Shimtile(col), [3], MemtilePrmSize,
            [],
            [TransferParams(AieDma(Shimtile(col), DmaChannel(DmaDir.MM2S, 0)), MemtilePrmWords, offset=((col * MemtilePrmWords)))]
        ) for col in range(p.AieCols)
    ] + [
        DataTransfer(
            [1], Shimtile(col), [2], ShimQdqPrmSize,
            [],
            [TransferParams(AieDma(Shimtile(col), DmaChannel(DmaDir.MM2S, 0)), ShimQdqPrmWords)]
        ) for col in range(0, p.AieCols, p.Num4x4)

    ] + shim_kt + shim_qt + shim_ot_te

    # Pick the 8x4 or 4x4 overlay connection tables depending on the number
    # of 4x4 groups.
    dma_connections = overlay_8x4_dma_connections() if p.Num4x4 == 2 else overlay_4x4_dma_connections()

    run_layer_compilation(
        shape,
        kernel_names,
        kernel_includes,
        instr_dict,
        memtile_transfers,
        shim_transfers,
        dma_connections=dma_connections,
        back_end=CodeBackend,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
        layer_file='dma.hpp',
        casc_dir = CascDir.Vertical,
        core_connections = overlay_8x4_core_stream_bdcast() if p.Num4x4 == 2 else overlay_4x4_core_stream_bdcast()
    )

    return(p.Sin_q, p.Dh, p.Sin_kv, p.Sq)
