"""
    Dataflow for Matrix Addition (Binary) at L2.

    This module implements a tiler-based approach to efficiently execute
    matrix addition within a fused dataflow pipeline.

    The primary entry point in this module is the `compile_dataflow`
    function, which builds the per-core instruction lists and the memtile
    and shim data transfers for the Binary operation and passes them to
    `run_layer_compilation`.
"""

import os
from typing import List, no_type_check, Optional

from dataclasses import dataclass
from tiler.binary_tiler import BinaryL2Dims
from kernel.binary.pack import generate_binary_params

from scheduler.common import (
    L3Alloc, L3Alloc_to_Shim,
    overlay_3x4_F_ids, overlay_3x4_A_ids,
    overlay_3x4_B_ids, overlay_3x4_O_ids,
    overlay_3x4_S_ids, overlay_3x4_dma_connections,
    prm_shim_memory,
    prm_shim_mm2s, prm_memtile_memory,
    prm_memtile_s2mm, prm_memtile_mm2s
)

from utils.utils_common import (
    L2Alloc, log, iceil,
    overlay_3x4_core_stack_addr
)
from buildscripts.common import ScheduleInputs

from dmacompiler import (
    DevGen, set_dev_gen,
    OverlayShape, TransferParams,
    DataTransfer, SyncStrategy,
    AieTile, TileType, DmaChannel,
    DmaDir, memtile_dma, shim_dma,
    ConfigBuffer, AcqBuffer, RelBuffer, CallKernel,
    compute_buffer_size, core_dma,
    generate_transfer_params,
    generate_shim_data_transfer,
    run_layer_compilation, Loop,
    generate_core_buffer_config,
    shim_tile
)

# Module-level side effect: selects DevGen.Aie4 through dmacompiler.set_dev_gen
# as soon as this module is imported, before any function below runs.
set_dev_gen(DevGen.Aie4)

# Absolute path of the directory that contains this file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))


class BinaryMapping:
    '''Define subvolume size, spatial split, and L1 buffer placement.

    The constructor derives the per-core buffer sizes from `dims` and then
    places the IFM ping/pong, OFM ping/pong and weight buffers in core (L1)
    memory.  Two layouts are tried in order:

    1. bank-aligned: each of OFM ping, IFM pong and OFM pong starts no lower
       than the 1st, 2nd and 3rd multiple of the 32 KiB bank size;
    2. compact: the same buffers packed back to back (only 128-byte
       alignment), used when layout 1 would run into the core stack.

    Raises:
        AssertionError: if even the compact layout does not fit below
            `overlay_3x4_core_stack_addr()`.
    '''

    def __init__(self,
                 dims: BinaryL2Dims,
                 kernel_gran: int = 64,
                 kernel_loop_range: int = 8
                 ):
        # Sizes are rounded with iceil() to `minimum_core_subv`
        # (presumably "round up to a multiple of" -- confirm in utils_common).
        self.minimum_core_subv = kernel_gran * kernel_loop_range
        self.core_ofm_transfer_size = dims.max_subvolume * dims.ofm_bytes
        # The IFM buffer is sized for two rounded subvolumes (factor 2).
        self.core_ifm_size = iceil(
            dims.max_subvolume * dims.shape.ifm_bytes, self.minimum_core_subv) * 2
        self.core_wgt_size = dims.wgt_size
        self.core_ofm_size = iceil(
            dims.max_subvolume * dims.ofm_bytes, self.minimum_core_subv)

        core_stack_addr = overlay_3x4_core_stack_addr()

        # Offsets inside the weight region: [wgt | q_buf | dq_buf].
        # They do not depend on which layout is chosen.
        self.core_qbuf_offset = dims.wgt_size
        self.core_dqbuf_offset = dims.wgt_size + dims.q_buf_size
        wgt_region_size = self.core_dqbuf_offset + dims.dq_buf_size

        self._place_l1_buffers(bank_aligned=True)
        if self.core_wgt_ping_addr + wgt_region_size > core_stack_addr:
            # Bank-aligned layout overflows into the stack: fall back to the
            # compact layout.
            log("L1 allocation if loop")
            self._place_l1_buffers(bank_aligned=False)

        # Explicit raise instead of a bare `assert` so the check survives
        # `python -O`; AssertionError is kept so existing callers that catch
        # it keep working.
        wgt_region_end = self.core_wgt_ping_addr + wgt_region_size
        if wgt_region_end > core_stack_addr:
            raise AssertionError(
                f"L1 buffers do not fit below the core stack: "
                f"weight region ends at {wgt_region_end}, "
                f"core_stack_addr is {core_stack_addr}")

        log("core_ifm_size:", self.core_ifm_size)
        log("core_wgt_size:", self.core_wgt_size)
        log("core_ofm_size:", self.core_ofm_size)
        log(f"core_ifm_ping_addr: {self.core_ifm_ping_addr}")
        log(f"core_ofm_ping_addr: {self.core_ofm_ping_addr}")
        log(f"core_ifm_pong_addr: {self.core_ifm_pong_addr}")
        log(f"core_ofm_pong_addr: {self.core_ofm_pong_addr}")
        log(f"core_wgt_ping_addr: {self.core_wgt_ping_addr}")
        log(f"core_qbuf_offset: {self.core_qbuf_offset}")
        log(f"core_dqbuf_offset: {self.core_dqbuf_offset}")

    def _place_l1_buffers(self, bank_aligned: bool) -> None:
        '''Assign the L1 start addresses of all core buffers.

        Buffers are laid out in the order IFM ping, OFM ping, IFM pong,
        OFM pong, weights.  With `bank_aligned` the three middle buffers are
        additionally pushed up to at least bank 1, 2 and 3 respectively.
        '''
        core_alignment = 128
        core_bank_size = 32768
        # 1 KiB gap in front of the weight region: extra padding so that
        # writing the output in vectors does not corrupt the qdq params.
        wgt_guard = 1024

        def place(prev_end: int, bank_index: int) -> int:
            addr = iceil(prev_end, core_alignment)
            if bank_aligned:
                return max(bank_index * core_bank_size, addr)
            return addr

        self.core_ifm_ping_addr = 0
        self.core_ofm_ping_addr = place(
            self.core_ifm_ping_addr + self.core_ifm_size, 1)
        self.core_ifm_pong_addr = place(
            self.core_ofm_ping_addr + self.core_ofm_size, 2)
        self.core_ofm_pong_addr = place(
            self.core_ifm_pong_addr + self.core_ifm_size, 3)
        self.core_wgt_ping_addr = iceil(
            self.core_ofm_pong_addr + self.core_ofm_size, core_alignment) + wgt_guard


@dataclass
class BinaryL2MemoryAllocator:
    """ Allocate L2 memory tiles and addresses for Binary operation fusion from dims and L2Alloc object.

    Reads the buffer locations out of `fusion_params` and exposes them as
    attributes.  The two IFM buffers are normalised so that `ifm_a` is the one
    in the lower memtile column; the address of `ifm_b` is then shifted by
    (column distance * memtile size) so it can be expressed relative to the
    `ifm_a` tile.

    NOTE: the class declares no dataclass fields and defines its own
    __init__, so the helpers generated by @dataclass (e.g. __repr__/__eq__)
    do not reflect the attributes set here.  The decorator is kept unchanged
    for compatibility.

    Args:
        dims: provides `wgt_size` and `ofm_size`.
        fusion_params: L2 locations; `ifm_L2_loc` is indexed [0] and [1] as
            (tile, addr) pairs, `ofm_L2_loc` is a (tile, addr) pair,
            `wgt_l2_loc` entries are indexed (tile, ping_addr, pong_addr) and
            `prm_l2_loc` entries (tile, addr).
        wgt_L1_size: overrides `dims.wgt_size` for the weight memtile buffer
            size when truthy (None or 0 fall back to `dims.wgt_size`).
    """
    def __init__(
        self,
        dims: BinaryL2Dims,
        fusion_params: L2Alloc,
        wgt_L1_size: Optional[int] = None
    ):
        self.prm_memtile_size = compute_buffer_size(prm_memtile_memory())
        self.wgt_memtile_size = dims.wgt_size if not wgt_L1_size else wgt_L1_size
        self.ofm_memtile_size = dims.ofm_size

        self.ifm_a_l2_tile, self.ifm_a_l2_addr = fusion_params.ifm_L2_loc[0]
        self.ifm_b_l2_tile, self.ifm_b_l2_addr = fusion_params.ifm_L2_loc[1]

        # 3 MiB per memtile column.
        memtile_size = (2**20)*3
        stride = abs((self.ifm_b_l2_tile.col -
                     self.ifm_a_l2_tile.col)*memtile_size)
        if self.ifm_a_l2_tile.col > self.ifm_b_l2_tile.col:
            # Make `a` the buffer in the lower column.  Swap the references
            # rather than writing to `.col`: the tile objects belong to the
            # caller's `fusion_params` and may be shared with other entries
            # (e.g. the OFM tile), which in-place mutation would retarget.
            self.ifm_a_l2_tile, self.ifm_b_l2_tile = \
                self.ifm_b_l2_tile, self.ifm_a_l2_tile
            self.ifm_a_l2_addr, self.ifm_b_l2_addr = \
                self.ifm_b_l2_addr, self.ifm_a_l2_addr
        # `b` now lives `stride` bytes above the base of `a`'s tile
        # (stride is 0 when both buffers are in the same column).
        self.ifm_b_l2_addr = self.ifm_b_l2_addr + stride
        log("After adjustment:")
        log(f"ifm_a_l2_tile: {self.ifm_a_l2_tile.col}, ifm_a_l2_addr: {self.ifm_a_l2_addr}")
        log(f"ifm_b_l2_tile: {self.ifm_b_l2_tile.col}, ifm_b_l2_addr: {self.ifm_b_l2_addr}")

        self.ofm_l2_tile, self.ofm_l2_addr = fusion_params.ofm_L2_loc
        # Split the per-entry (tile, ping_addr, pong_addr) records into
        # parallel lists; compile_dataflow indexes them by column.
        self.wgt_L2_alloc_tiles = [entry[0]
                                   for entry in fusion_params.wgt_l2_loc]
        self.wgt_L2_ping_addrs = [entry[1]
                                  for entry in fusion_params.wgt_l2_loc]
        self.wgt_L2_pong_addrs = [entry[2]
                                  for entry in fusion_params.wgt_l2_loc]
        # Same for the (tile, addr) parameter records.
        self.prm_L2_alloc_tiles = [entry[0]
                                   for entry in fusion_params.prm_l2_loc]
        self.prm_L2_addrs = [entry[1] for entry in fusion_params.prm_l2_loc]


def coreid(dims: BinaryL2Dims, c: int, r: int) -> int:
    '''Flatten a (column, row) core coordinate to a 1d index.

    Cores are numbered down each column first:
    index = c * dims.aie_rows + r.
    '''
    cores_per_column = dims.aie_rows
    column_base = c * cores_per_column
    return column_base + r


def write_tiling_expressions(dims):
    """ Build the per-core access expressions for the subvolume tiling.

    Returns a pair (per_core, flat):
      * per_core maps core index -> list of expression strings.  Without a
        partial last iteration each list has one entry covering all
        iterations; with one it has two entries: [full iterations, last
        iteration].
      * flat is the same expressions concatenated in core order.
    Both are empty when the subvolume or the iteration count is zero.

    In the last (partial) iteration, cores at or beyond
    dims.active_cores_last_iter reuse the expression of the last active core.
    """
    subv = dims.max_subvolume
    iterations = dims.total_iterations
    has_partial = dims.partial_last_iter
    active_last = dims.active_cores_last_iter
    core_count = dims.num_cores

    if subv == 0 or iterations == 0:
        return {}, []

    # Elements consumed by all cores together in one iteration.
    sweep = core_count * subv
    inner = f"ELEMS:0:{subv}"

    def strided(core, stop):
        # One subvolume per iteration, starting at this core's slot and
        # advancing by a whole sweep each iteration.
        return f"ELEMS:{core * subv}:{stop}:{sweep} {inner}"

    def single(core, base):
        # Exactly one subvolume: this core's slot inside the last sweep.
        lo = core * subv + base
        hi = (core + 1) * subv + base
        return f"ELEMS:{lo}:{hi}:{subv} {inner}"

    if not has_partial:
        stop = sweep * iterations
        per_core = {core: [strided(core, stop)] for core in range(core_count)}
        return per_core, [exprs[0] for exprs in per_core.values()]

    # Partial last iteration: (iterations - 1) full sweeps, then one more.
    base = sweep * (iterations - 1)

    if active_last > 0:
        idle_expr = single(active_last - 1, base)
    else:
        # Fallback (shouldn't happen with valid tiling)
        log("Fallback happened something wrong")
        idle_expr = f"ELEMS:0:{subv} {inner}"

    per_core = {}
    flat = []
    for core in range(core_count):
        pair = [
            strided(core, base),
            single(core, base) if core < active_last else idle_expr,
        ]
        per_core[core] = pair
        flat += pair

    return per_core, flat


# Shim tiling expressions
def ifm_shim_memory(dims: BinaryL2Dims) -> str:
    '''Memory: IFM Shim.

    Flat extent of Xi * Yi * Ci elements, taken from dims.shape.
    '''
    shape = dims.shape
    num_elems = shape.Xi * shape.Yi * shape.Ci
    return f'ELEMS:{num_elems}'


def ifm_shim_mm2s(dims: BinaryL2Dims) -> str:
    '''Transfer: IFM Shim MM2S.

    Access covering elements 0 .. Xi * Yi * Ci of the IFM shim buffer.
    '''
    shape = dims.shape
    end = shape.Xi * shape.Yi * shape.Ci
    return 'ELEMS:0:{}'.format(end)


def wgt_shim_memory(dims: BinaryL2Dims) -> str:
    '''Memory: WGT Shim.

    Byte extent of dims.wgt_size bytes.
    '''
    nbytes = dims.wgt_size
    return f'Bytes:{nbytes}'


def wgt_shim_mm2s(dims: BinaryL2Dims) -> str:
    '''Transfer: WGT Shim MM2S.

    Access covering bytes 0 .. dims.wgt_size of the weight shim buffer.
    '''
    end = dims.wgt_size
    return 'Bytes:0:{}'.format(end)


def ofm_shim_memory(dims: BinaryL2Dims) -> str:
    '''Memory: OFM Shim.

    Flat extent of Xi * Yi * Ci elements; the OFM extent is computed from
    the same dims.shape fields as the IFM extent.
    '''
    shape = dims.shape
    num_elems = shape.Xi * shape.Yi * shape.Ci
    return f'ELEMS:{num_elems}'


def ofm_shim_s2mm(dims: BinaryL2Dims) -> str:
    '''Transfer: OFM Shim S2MM.

    Access covering elements 0 .. Xi * Yi * Ci of the OFM shim buffer.
    '''
    shape = dims.shape
    end = shape.Xi * shape.Yi * shape.Ci
    return 'ELEMS:0:{}'.format(end)


# Memtile tiling expressions
def ifm_memtile_memory(dims: BinaryL2Dims) -> str:
    '''Memory: IFM Mem.

    Flat extent of Xi * Yi * Ci elements, taken from dims.shape.
    '''
    shape = dims.shape
    num_elems = shape.Xi * shape.Yi * shape.Ci
    return f'ELEMS:{num_elems}'


def ifm_memtile_s2mm(dims: BinaryL2Dims) -> str:
    '''Transfer: IFM Mem S2MM.

    Access covering elements 0 .. Xi * Yi * Ci of the IFM memtile buffer.
    '''
    shape = dims.shape
    end = shape.Xi * shape.Yi * shape.Ci
    return 'ELEMS:0:{}'.format(end)


def _ifm_memtile_mm2s(dims: BinaryL2Dims, ifm_a_l2_addr, ifm_b_l2_addr,
                      col: int, expr_index: int) -> List:
    '''Build the IFM memtile MM2S transfers for every core row of `col`.

    For each row two transfers are emitted on that row's channel
    (overlay_3x4_A_ids()[row]): first from the A buffer offset, then from the
    B buffer offset, both using the same tiling expression.

    expr_index selects which of the core's expressions from
    write_tiling_expressions() is used: 0 = full iterations, 1 = last
    (partial) iteration.
    '''
    per_core, _ = write_tiling_expressions(dims)
    memory_fmt = ifm_memtile_memory(dims)
    channel_ids = overlay_3x4_A_ids()

    transfers = []
    for row in range(dims.aie_rows):
        tiling = per_core[coreid(dims, col, row)][expr_index]
        for buffer_offset in (ifm_a_l2_addr, ifm_b_l2_addr):
            # With use_iter_step=True the result is indexed [1][0] here;
            # presumably [1] is the list of generated params -- confirm
            # against dmacompiler.generate_transfer_params.
            transfers.append(
                generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, channel_ids[row]),
                    memory_fmt,
                    tiling,
                    bits_per_block=dims.ifm_bits,
                    buffer_offset=buffer_offset,
                    enable_padding=True,
                    use_iter_step=True)[1][0]
            )
    return transfers


def ifm_memtile_mm2s_all_cores(dims: BinaryL2Dims, ifm_a_l2_addr, ifm_b_l2_addr, col: int) -> List:
    '''Transfer: IFM Mem MM2S for the full iterations (all cores active).

    Returns a list with two transfer params per core row of `col`
    (A buffer, then B buffer), built from each core's first tiling
    expression.
    '''
    return _ifm_memtile_mm2s(dims, ifm_a_l2_addr, ifm_b_l2_addr, col, 0)


def ifm_memtile_mm2s_partial_cores(dims: BinaryL2Dims, ifm_a_l2_addr, ifm_b_l2_addr, col: int) -> List:
    '''Transfer: IFM Mem MM2S for the partial last iteration.

    Returns a list with two transfer params per core row of `col`
    (A buffer, then B buffer), built from each core's second tiling
    expression.  Only valid when dims.partial_last_iter is set, otherwise
    the second expression does not exist.
    '''
    return _ifm_memtile_mm2s(dims, ifm_a_l2_addr, ifm_b_l2_addr, col, 1)


def ifm_memtile_mm2s_dummy() -> str:
    '''Dummy transfer: IFM Mem MM2S (zero-length access expression).'''
    empty_access = 'ELEMS:0:0'
    return empty_access


def ifm_memtile_s2mm_dummy() -> str:
    '''Dummy transfer: IFM Mem S2MM (zero-length access expression).'''
    empty_access = 'ELEMS:0:0'
    return empty_access


def wgt_memtile_memory(dims: BinaryL2Dims) -> str:
    '''Memory: Wgt Mem.

    Byte extent of dims.wgt_size bytes.
    '''
    nbytes = dims.wgt_size
    return f'Bytes:{nbytes}'


def wgt_memtile_mm2s(dims: BinaryL2Dims) -> str:
    '''Transfer: Wgt Mem MM2S.

    Access covering bytes 0 .. dims.wgt_size of the weight memtile buffer.
    '''
    end = dims.wgt_size
    return 'Bytes:0:{}'.format(end)


def wgt_memtile_s2mm(dims: BinaryL2Dims) -> str:
    '''Transfer: Wgt Mem S2MM.

    Access covering bytes 0 .. dims.wgt_size of the weight memtile buffer.
    '''
    end = dims.wgt_size
    return 'Bytes:0:{}'.format(end)


def ofm_memtile_memory(dims: BinaryL2Dims) -> str:
    '''Memory: OFM Mem.

    Flat extent of Xi * Yi * Ci elements, taken from dims.shape.
    '''
    shape = dims.shape
    num_elems = shape.Xi * shape.Yi * shape.Ci
    return f'ELEMS:{num_elems}'


def ofm_memtile_mm2s(dims: BinaryL2Dims) -> str:
    '''Transfer: OFM Mem MM2S.

    Access covering elements 0 .. Xi * Yi * Ci of the OFM memtile buffer.
    '''
    shape = dims.shape
    end = shape.Xi * shape.Yi * shape.Ci
    return 'ELEMS:0:{}'.format(end)


def ofm_memtile_s2mm(dims: BinaryL2Dims, col: int) -> List:
    '''Transfer: OFM Mem S2MM for every core row of `col`.

    The returned list holds, in this order:
      1. one entry per row for the full iterations (first tiling expression),
         only when dims.full_subvol_iterations > 0;
      2. one entry per row for the partial last iteration (second tiling
         expression), only when dims.partial_last_iter is set.
    Each entry is the unmodified result of generate_transfer_params().
    '''
    tiling_by_core, _ = write_tiling_expressions(dims)

    def s2mm_for(row: int, expr_index: int):
        # Row `row` writes through its own S2MM channel of this column.
        return generate_transfer_params(
            memtile_dma(col, DmaDir.S2MM, overlay_3x4_O_ids()[row]),
            ofm_memtile_memory(dims),
            tiling_by_core[coreid(dims, col, row)][expr_index],
            bits_per_block=dims.ofm_bits,
        )

    full_phase = [
        s2mm_for(row, 0)
        for row in range(dims.aie_rows)
        if dims.full_subvol_iterations > 0
    ]
    partial_phase = [
        s2mm_for(row, 1)
        for row in range(dims.aie_rows)
        if dims.partial_last_iter
    ]
    return full_phase + partial_phase


def ifm_core_memory(dims: BinaryL2Dims):
    '''Memory: IFM Core S2MM 0.

    Two tensors, each of (core_bank_mem_size_software // ifm_bytes) elements.
    '''
    elems_per_tensor = dims.core_bank_mem_size_software // dims.shape.ifm_bytes
    return "TENSOR:2 ELEMS:{}".format(elems_per_tensor)


def ifm_core_s2mm(dims: BinaryL2Dims):
    '''Transfer: IFM Core S2MM 0.

    Access expression whose element range is 0 .. dims.max_subvolume.
    '''
    subvolume = dims.max_subvolume
    return "TENSOR:0:2:1 TENSOR:0:1 ELEMS:0:{}".format(subvolume)


def ofm_memtile_channels() -> list[tuple[int, int]]:
    '''Generate OFM channel allocations as (row, id) pairs.

    Row i is paired with the i-th entry of overlay_3x4_O_ids().
    '''
    return [(row, channel_id)
            for row, channel_id in enumerate(overlay_3x4_O_ids())]


def generate_L2_alloc(
    num_cols=3,
    col_size=3 * 1024 * 1024,   # 3 MB per tile
    prm_size=4096,              # 4 KB per column
    wgt_size=1024,              # ping/pong = 1 KB each
    align=4,
    ofm_overwrites_ifm=1,       # 0 or 1 → which IFM OFM aliases
):
    """
    Generate a standalone L2 allocation over num_cols*col_size bytes.

    The address space is treated as one linear range that is cut into
    columns of `col_size` bytes.  Buffers are placed front to back:
      • num_cols parameter buffers of prm_size bytes each
      • num_cols (ping, pong) weight pairs of wgt_size bytes each
      • the remaining space split into IFM0 and IFM1 (IFM1 starts half of
        the remainder, rounded down to `align`, after IFM0)
      • OFM aliases IFM0 or IFM1, selected by ofm_overwrites_ifm
      • every buffer start is aligned to `align`; buffers may live in any
        column.

    Returns an L2Alloc instance with IFM fill and OFM spill enabled and
    L2 fusion disabled.

    Raises:
        RuntimeError: if no space, or less than `align` bytes per buffer,
            is left for the two IFM buffers.
    """
    total_size = num_cols * col_size
    cur = 0

    def to_tile_off(g_off):
        # Linear offset -> (memtile of the containing column, offset in it).
        col = g_off // col_size
        off = g_off % col_size
        return AieTile(TileType.Memtile, col), off

    # params
    prm_locs = []
    for _ in range(num_cols):
        cur = iceil(cur, align)
        prm_start = cur
        cur += prm_size
        tile, off = to_tile_off(prm_start)
        prm_locs.append([tile, off])

    # weights
    wgt_locs = []
    for _ in range(num_cols):
        cur = iceil(cur, align)
        ping_start = cur
        cur += wgt_size

        cur = iceil(cur, align)
        pong_start = cur
        cur += wgt_size

        tile_ping, off_ping = to_tile_off(ping_start)
        # Only one tile is recorded per pair, so the pong offset must be
        # relative to that same (ping) tile.  When both buffers are in the
        # same column this equals pong_start % col_size; when the pair
        # straddles a column boundary it continues linearly past the end of
        # the ping tile instead of wrapping to an offset that belongs to a
        # different, unrecorded tile.
        off_pong = off_ping + (pong_start - ping_start)
        wgt_locs.append([tile_ping, off_ping, off_pong])

    # IFM0 + IFM1
    cur = iceil(cur, align)
    data_start = cur
    data_end = total_size
    remaining = data_end - data_start
    if remaining <= 0:
        raise RuntimeError("No L2 space left for IFM buffers")

    # make IFM0 size a multiple of align so IFM1 start is aligned
    half = (remaining // 2) // align * align
    if half == 0:
        raise RuntimeError("Insufficient space for two aligned IFM buffers")

    ifm0_start = data_start
    ifm1_start = ifm0_start + half

    tile0, off0 = to_tile_off(ifm0_start)
    tile1, off1 = to_tile_off(ifm1_start)

    ifm_L2_loc = [(tile0, off0), (tile1, off1)]
    # OFM reuses the location (and the tile object) of the selected IFM.
    ofm_tile, ofm_off = ifm_L2_loc[ofm_overwrites_ifm]

    return L2Alloc(
        ifm_L2_loc=ifm_L2_loc,
        ofm_L2_loc=(ofm_tile, ofm_off),
        wgt_l2_loc=wgt_locs,
        prm_l2_loc=prm_locs,
        enable_ifm_fill=True,
        enable_ofm_spill=True,
        enable_L2_fusion=False,
    )


# NOTE: Mypy doesn't correctly infer types with the abstract base construct
# used by the core instruction list, so we disable type checking locally here
@no_type_check
def compile_dataflow(schedule_input: ScheduleInputs):
    '''Compile dataflow with given mapping information.

    Builds the per-core instruction lists, the memtile transfers and the shim
    transfers for the Binary operation and hands them to
    `run_layer_compilation`.

    Every transfer carries a five-entry repeat-count list.  From the
    definitions below the slots are used as:
        [0] IFM fill, [1] IFM fill sync, [2] compute (params, weights,
        IFM->cores, cores->OFM), [3] OFM spill sync, [4] OFM spill.

    Args:
        schedule_input: provides `shape` (BinaryL2Dims), `mapping`
            (BinaryMapping), `L2_alloc`, `L3_alloc` and the kernel/backend
            settings forwarded to `run_layer_compilation`.  When `L2_alloc`
            is None or has L2 fusion disabled, a default allocation from
            `generate_L2_alloc()` is used instead.

    Returns:
        (shim_prm_offset_next_layer, shim_wgt_offset_next_layer): the shim
        parameter and weight offsets advanced by the sizes consumed here.
    '''
    dims: BinaryL2Dims = schedule_input.shape
    mapping: BinaryMapping = schedule_input.mapping
    L2_alloc: L2Alloc | None = schedule_input.L2_alloc
    L3_alloc: L3Alloc | None = schedule_input.L3_alloc
    shim_alloc = L3Alloc_to_Shim(L3_alloc)
    log(f"Shim Allocator: {shim_alloc}")

    # `L2_alloc` may be None (see annotation above); treat that the same as
    # "fusion disabled" instead of failing on the attribute access.
    if L2_alloc is None or not L2_alloc.enable_L2_fusion:
        L2_alloc = generate_L2_alloc()

    memtile_alloc = BinaryL2MemoryAllocator(dims, L2_alloc)

    def generate_core_instructions(col: int, row: int):
        # Double-buffered (ping/pong) IFM input on core S2MM channel 0.
        ifm_config = generate_core_buffer_config(
            core_dma(col, row, DmaDir.S2MM, 0),
            mapping.core_ifm_ping_addr, mapping.core_ifm_pong_addr,
            ifm_core_memory(dims),
            ifm_core_s2mm(dims),
            bits_per_block=dims.ifm_bits,
        )

        return [
            # Weights arrive once on S2MM channel 1 (single buffer, no pong)
            # before the iteration loop starts.
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 1),
                         mapping.core_wgt_ping_addr, None, mapping.core_wgt_size),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
            # Double-buffered OFM output on MM2S channel 0.
            ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), mapping.core_ofm_ping_addr,
                         mapping.core_ofm_pong_addr, mapping.core_ofm_transfer_size),
            ifm_config,
            # One kernel call per iteration, bracketed by acquire/release of
            # the output and input buffers.
            Loop(dims.total_iterations, [
                AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                CallKernel(dims.call_kernel, generate_binary_params(dims, dims.core_bank_mem_size_software,
                                                                    mapping.minimum_core_subv, mapping.core_qbuf_offset,
                                                                    mapping.core_dqbuf_offset)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
            ])
        ]

    # Every core in the overlay runs the same instruction sequence.
    instr_dict = {
        AieTile(TileType.Core, col, row): generate_core_instructions(col, row)
        for col in range(dims.aie_cols)
        for row in range(dims.aie_rows)
    }

    # Repeat-count lists; see the slot description in the docstring.
    repeat_count = [1]
    fill_repeat_count = [1] + [0] + [0] * len(repeat_count) + [0] + [0]
    fill_sync_repeat_count = [0] + [1] + [0] * len(repeat_count) + [0] + [0]
    spill_sync_repeat_count = [0] + [0] + [0] * len(repeat_count) + [1] + [0]
    spill_repeat_count = [0] + [0] + [0] * len(repeat_count) + [0] + [1]
    layer_param_repeat_count = [0] + [0] + repeat_count + [0] + [0]
    ifm_all_cores_repeat = [0] + [0] + \
        [dims.full_subvol_iterations] + [0] + [0]
    ifm_partial_cores_repeat = [0] + [0] + repeat_count + [0] + [0]
    ofm_repeat_count = [0] + [0] + repeat_count + [0] + [0]

    # Sizes consumed from the shim param/weight buffers; used for the
    # next-layer offsets returned at the end.
    shim_prm_size = compute_buffer_size(prm_shim_memory())
    shim_wgt_size = compute_buffer_size(wgt_shim_memory(dims))

    # ---- Memtile transfers -------------------------------------------------
    memtile_transfers = []
    # Layer parameters: one S2MM in per column, fanned out to every core row.
    prm_memtile_transfers = [
        DataTransfer(
            layer_param_repeat_count, memtile_alloc.prm_L2_alloc_tiles[col],
            [memtile_alloc.prm_L2_addrs[col]], memtile_alloc.prm_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                prm_memtile_memory(),
                prm_memtile_s2mm(),)],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                prm_memtile_memory(),
                prm_memtile_mm2s(row)) for row in range(dims.aie_rows)],
            sync_strategy=SyncStrategy.Serial_M_to_N,
        ) for col in range(dims.aie_cols)
    ]
    memtile_transfers += prm_memtile_transfers

    if L2_alloc.enable_ifm_fill:
        log("[INFO] IFM FILL ENABLED")
        # Both IFM operands are filled through the `ifm_a` tile; they differ
        # only in buffer_offset (ifm_a_l2_addr vs ifm_b_l2_addr).
        ifm_fill_transfers_1 = [
            DataTransfer(
                fill_repeat_count,
                memtile_alloc.ifm_a_l2_tile, [0], dims.ifm_size,
                [generate_transfer_params(
                    memtile_dma(memtile_alloc.ifm_a_l2_tile.col,
                                DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                    ifm_memtile_memory(dims),
                    ifm_memtile_s2mm(dims),
                    bits_per_block=dims.ifm_bits,
                    buffer_offset=memtile_alloc.ifm_a_l2_addr,
                    enable_padding=True)],
                [])
        ]
        # NOTE(review): unlike the first fill above (and both shim-side IFM
        # transfers below), this one does not pass enable_padding=True --
        # confirm whether that is intentional.
        ifm_fill_transfers_2 = [
            DataTransfer(
                fill_repeat_count,
                memtile_alloc.ifm_a_l2_tile, [0], dims.ifm_size,
                [generate_transfer_params(
                    memtile_dma(memtile_alloc.ifm_a_l2_tile.col,
                                DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                    ifm_memtile_memory(dims),
                    ifm_memtile_s2mm(dims),
                    bits_per_block=dims.ifm_bits,
                    buffer_offset=memtile_alloc.ifm_b_l2_addr)],
                [])
        ]
        # Zero-length transfer in the fill-sync slot, using a remote barrier.
        ifm_fill_sync_transfer = [
            DataTransfer(
                fill_sync_repeat_count,
                memtile_alloc.ifm_a_l2_tile, [0], dims.ifm_size,
                [generate_transfer_params(
                    memtile_dma(memtile_alloc.ifm_a_l2_tile.col,
                                DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                    ifm_memtile_memory(dims),
                    ifm_memtile_s2mm_dummy()),],
                [],
                sync_strategy=SyncStrategy.Remote_Barrier)
            ]
        memtile_transfers += ifm_fill_transfers_1
        memtile_transfers += ifm_fill_transfers_2
        memtile_transfers += ifm_fill_sync_transfer

    if dims.full_subvol_iterations > 0:
        log(f"[INFO] ALL CORE ITERATION - {dims.full_subvol_iterations}")
        # IFM memtile -> cores for the iterations in which all cores are active.
        ifm_memtile_transfers_all_cores = [
            DataTransfer(
                ifm_all_cores_repeat,
                memtile_alloc.ifm_a_l2_tile, [0], dims.ifm_size,
                [],
                ifm_memtile_mm2s_all_cores(dims, memtile_alloc.ifm_a_l2_addr, memtile_alloc.ifm_b_l2_addr, col)) for col in range(dims.aie_cols)
        ]
        memtile_transfers += ifm_memtile_transfers_all_cores

    if dims.partial_last_iter:
        log("[INFO] PARTIAL CORE ITERATION - 1")
        # IFM memtile -> cores for the single partial last iteration.
        ifm_memtile_transfers_2_partial_cores = [
            DataTransfer(
                ifm_partial_cores_repeat,
                memtile_alloc.ifm_a_l2_tile, [0], dims.ifm_size,
                [],
                ifm_memtile_mm2s_partial_cores(dims, memtile_alloc.ifm_a_l2_addr, memtile_alloc.ifm_b_l2_addr, col)) for col in range(dims.aie_cols)
        ]
        memtile_transfers += ifm_memtile_transfers_2_partial_cores

    # Weights: ping/pong memtile buffer per column, sent out on two MM2S
    # channels (overlay_3x4_B_ids(col)[0] and [1]).
    wgt_memtile_transfers = [
        DataTransfer(
            layer_param_repeat_count, memtile_alloc.wgt_L2_alloc_tiles[col],
            [memtile_alloc.wgt_L2_ping_addrs[col],
                memtile_alloc.wgt_L2_pong_addrs[col]], memtile_alloc.wgt_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[1]),
                wgt_memtile_memory(dims),
                wgt_memtile_s2mm(dims),
                bits_per_block=dims.wgt_bits
            )
            ],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, overlay_3x4_B_ids(col)[0]),
                wgt_memtile_memory(dims),
                wgt_memtile_mm2s(dims),
                bits_per_block=dims.wgt_bits
            ),
                generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, overlay_3x4_B_ids(col)[1]),
                wgt_memtile_memory(dims),
                wgt_memtile_mm2s(dims),
                bits_per_block=dims.wgt_bits
            )
            ]) for col in range(dims.aie_cols)
    ]
    memtile_transfers += wgt_memtile_transfers

    # Core outputs of every column are written into the single OFM L2 buffer.
    ofm_memtile_transfers = [
        DataTransfer(
            ofm_repeat_count, memtile_alloc.ofm_l2_tile,
            [memtile_alloc.ofm_l2_addr], memtile_alloc.ofm_memtile_size,
            ofm_memtile_s2mm(dims, col),
            []) for col in range(dims.aie_cols)
    ]
    memtile_transfers += ofm_memtile_transfers

    ofm_col = memtile_alloc.ofm_l2_tile.col
    if L2_alloc.enable_ofm_spill:
        log("[INFO] OFM SPILL ENABLED")
        # NOTE: This is a dummy transfer to ensure the
        # spill transfer is executed as the last phase
        ofm_spill_sync_transfer = [
            DataTransfer(
                spill_sync_repeat_count, AieTile(TileType.Memtile, ofm_col),
                [memtile_alloc.ofm_l2_addr], memtile_alloc.ofm_memtile_size,
                [TransferParams(memtile_dma(col, DmaDir.S2MM, channel_id), 0, name=f"ofm_spill_sync_{col}")
                 for _, channel_id in ofm_memtile_channels()],
                [],
                sync_strategy=SyncStrategy.Remote_Barrier if col == 0 else SyncStrategy.Default
            ) for col in range(dims.aie_cols)
        ]
        memtile_transfers += ofm_spill_sync_transfer

        # Actual spill: OFM L2 buffer out through the OFM column's MM2S channel.
        ofm_spill_transfer = [
            DataTransfer(
                spill_repeat_count, memtile_alloc.ofm_l2_tile,
                [memtile_alloc.ofm_l2_addr], memtile_alloc.ofm_memtile_size,
                [],
                [generate_transfer_params(
                    memtile_dma(ofm_col, DmaDir.MM2S, overlay_3x4_S_ids(ofm_col)[0]),
                    ofm_memtile_memory(dims),
                    ofm_memtile_mm2s(dims),
                    bits_per_block=dims.ofm_bits)])
            ]
        memtile_transfers += ofm_spill_transfer

    # ---- Shim transfers ----------------------------------------------------
    shim_transfers = []
    prm_shim_transfers = [
        generate_shim_data_transfer(
            layer_param_repeat_count, shim_dma(col, DmaDir.MM2S, 0),
            shim_alloc.prm_xrt_idx,
            prm_shim_memory(),
            prm_shim_mm2s(col),
            buffer_offset=shim_alloc.prm_xrt_offset,
        ) for col in range(dims.aie_cols)
    ]
    shim_transfers += prm_shim_transfers

    if L2_alloc.enable_ifm_fill:
        # Both IFM operands enter through the shim of the `ifm_a` column on
        # MM2S channel 0; they read from ifm_xrt_idx[0] and [1] respectively.
        ifm_shim_transfers_1 = [
            DataTransfer(
                fill_repeat_count,
                shim_tile(memtile_alloc.ifm_a_l2_tile.col),
                [shim_alloc.ifm_xrt_idx[0]], dims.ifm_size,
                [],
                [generate_transfer_params(
                    shim_dma(memtile_alloc.ifm_a_l2_tile.col, DmaDir.MM2S, 0),
                    ifm_shim_memory(dims),
                    ifm_shim_mm2s(dims),
                    dims.ifm_bits,
                    buffer_offset=shim_alloc.ifm_xrt_offset[0],
                    enable_padding=True)],
            )
        ]
        shim_transfers += ifm_shim_transfers_1

        ifm_shim_transfers_2 = [
            DataTransfer(
                fill_repeat_count,
                shim_tile(memtile_alloc.ifm_a_l2_tile.col),
                [shim_alloc.ifm_xrt_idx[1]], dims.ifm_size,
                [],
                [generate_transfer_params(
                    shim_dma(memtile_alloc.ifm_a_l2_tile.col, DmaDir.MM2S, 0),
                    ifm_shim_memory(dims),
                    ifm_shim_mm2s(dims),
                    dims.ifm_bits,
                    buffer_offset=shim_alloc.ifm_xrt_offset[1],
                    enable_padding=True)],
            )
        ]
        shim_transfers += ifm_shim_transfers_2

    # Weights: every column's shim sends the same weight buffer on MM2S channel 1.
    wgt_shim_transfers = [
        DataTransfer(
            layer_param_repeat_count, AieTile(TileType.Shim, col),
            [shim_alloc.wgt_xrt_idx], dims.wgt_size,
            [],
            [generate_transfer_params(
                shim_dma(col, DmaDir.MM2S, 1),
                wgt_shim_memory(dims),
                wgt_shim_mm2s(dims),
                dims.wgt_bits,
                buffer_offset=shim_alloc.wgt_xrt_offset)],
        ) for col in range(dims.aie_cols)
    ]
    shim_transfers += wgt_shim_transfers

    if L2_alloc.enable_ofm_spill:
        # OFM leaves through the shim of the OFM column on S2MM channel 0.
        ofm_shim_transfers = [
            DataTransfer(
                spill_repeat_count, AieTile(TileType.Shim, ofm_col),
                [shim_alloc.ofm_xrt_idx], dims.ofm_size,
                [generate_transfer_params(
                    shim_dma(ofm_col, DmaDir.S2MM, 0),
                    ofm_shim_memory(dims),
                    ofm_shim_s2mm(dims),
                    dims.ofm_bits,
                    enable_padding=True,
                    buffer_offset=shim_alloc.ofm_xrt_offset)],
                [])
        ]
        shim_transfers += ofm_shim_transfers

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        schedule_input.kernel_names,
        schedule_input.kernel_includes,
        instr_dict,
        memtile_transfers,
        shim_transfers,
        overlay_3x4_dma_connections(),
        schedule_input.backend,
        param_channel_id=0,
        core_stack_addr=overlay_3x4_core_stack_addr(),
        layer_file=schedule_input.layer_file_name,
        dma_padding_map=schedule_input.dma_pad,
    )

    shim_prm_offset_next_layer = shim_alloc.prm_xrt_offset + shim_prm_size
    shim_wgt_offset_next_layer = shim_alloc.wgt_xrt_offset + shim_wgt_size

    return shim_prm_offset_next_layer, shim_wgt_offset_next_layer
