"""2D Broadcast operator scheduler for two inputs and one output.
Does not support the first input being broadcasted over C.
Assumes N=1.
If Y>1 and X>1 and C>1, then one input must be (1, 1, C).
"""

import numpy as np
import os
import math
import copy
import itertools
import pytest
from dataclasses import dataclass
from typing import Callable, no_type_check

import dmacompiler
from kernel.broadcast.pack import generate_broadcast_params
from buildscripts.common import ScheduleInputs
from tiler.broadcast_tiler import BroadcastShape
from scheduler.broadcast.folding import pad_to_32_bits
from utils.utils_common import (
    L2Alloc,
    iceil,
    log,
    BaseShape,
    BaseDims,
    BaseMapping,
    ceildiv,
    _core_to_split,
    overlay_3x4_core_stack_addr,
)
from scheduler.common import (
    overlay_3x4_F_ids,
    overlay_3x4_A_ids,
    overlay_3x4_B_ids,
    overlay_3x4_O_ids,
    overlay_3x4_S_ids,
    overlay_3x4_dma_connections,
    overlay_3x4_param_channel_id,
    prm_shim_memory,
    prm_shim_mm2s,
    prm_memtile_memory,
    prm_memtile_s2mm,
    prm_memtile_mm2s,
    ShimAllocator,
    L3Alloc_to_Shim,
)

from dmacompiler import (
    DevGen, set_dev_gen, OverlayShape,
    DataTransfer, SyncStrategy, AieTile,
    TileType, DmaChannel, DmaDir,
    memtile_dma, shim_dma, ConfigBuffer,
    AcqBuffer, RelBuffer, CallKernel,
    compute_buffer_size, core_dma, generate_transfer_params,
    pack_reconfig_transfers, generate_shim_data_transfer,
    run_layer_compilation, Loop, memory_tile,
    generate_core_buffer_config, TransferParams,
    shim_tile
)

# Select the Aie4 device generation in dmacompiler at import time, before the
# configuration values below are read.
set_dev_gen(DevGen.Aie4)

# Absolute directory containing this source file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))

# Overlay grid extents taken from the dmacompiler configuration; the functions
# in this module iterate over every (col, row) pair in this grid.
NUM_COLS = dmacompiler.config.NUM_AIE_COLS
NUM_ROWS = dmacompiler.config.NUM_AIE_ROWS


@dataclass
class Tiling1D:
    """Represents a single dimension of a tiling expression.

    Attributes:
        start (int): Starting index of the tile.
        end (int): Exclusive ending index of the tile.
        stride (int | None): Step size between elements; ``None`` means no loop stride.
        core_offset (int): Offset applied to the core index.
        num_elements (int): Number of elements covered within the tile.
        is_broadcast (bool): Whether this dimension is broadcast.
        num_repeats (int): Number of repeats for this dimension; must remain 1 when not broadcast.
    """

    start: int
    end: int
    stride: int | None
    core_offset: int
    num_elements: int
    is_broadcast: bool
    num_repeats: int

    def __init__(
        self,
        start: int,
        end: int,
        stride: int | None,
        core_offset: int,
        num_elements: int,
        max_dim,
        is_broadcast: bool = False,
        num_repeats: int = 1,
        is_C: bool = False,
    ):
        # Make a dummy repeat if starting out of bounds, these don't get transfered
        if not is_broadcast and start + core_offset >= max_dim or end < 0:
            start, end, stride, core_offset, num_elements = 0, 0, None, 0, 0
        # if is_broadcast and num_repeats == 0:
        #     start, end, stride, core_offset, num_elements = 0, 0, None, 0, 0
        self.start = start
        self.end = 4 if is_broadcast and is_C else end
        self.stride = stride
        self.core_offset = core_offset
        self.num_elements = num_elements
        self.is_broadcast = is_broadcast
        self.num_repeats = num_repeats if is_broadcast else 1
        self.is_C = is_C

    def is_dummy(self):
        return self.num_elements == 0

    def num_loops(self):
        if self.num_elements == 0:
            return 0
        elif self.stride:
            return ceildiv(self.end - self.start, self.stride)
        else:
            return 1

    def strloop(self):
        if self.stride and not (self.stride == self.end - self.start):
            return f"{self.start}:{self.end}:{self.stride}"
        else:
            return ""

    def straccess(self):
        if self.is_C and self.is_broadcast:
            return "0:4"
        elif self.strloop():
            return f"{self.core_offset}:{self.core_offset + self.num_elements}"
        else:
            return f"{self.start + self.core_offset}:{self.start + self.core_offset + self.num_elements}"


@dataclass
class TilingND:
    """A multi-dimensional tiling: one ``Tiling1D`` per axis.

    The string helpers require exactly three axes, named Y, X and C in that
    order.
    """

    tilings: tuple[Tiling1D]

    def is_dummy(self):
        """Return True if at least one axis is a dummy tiling."""
        for axis_tiling in self.tilings:
            if axis_tiling.is_dummy():
                return True
        return False

    def num_loops(self):
        """Return the product of the per-axis loop counts."""
        return math.prod(axis_tiling.num_loops() for axis_tiling in self.tilings)

    def num_repeats(self):
        """Return the product of the per-axis repeat counts."""
        return math.prod(axis_tiling.num_repeats for axis_tiling in self.tilings)

    def is_broadcast(self):
        """Return True if at least one axis is broadcast."""
        for axis_tiling in self.tilings:
            if axis_tiling.is_broadcast:
                return True
        return False

    def strloop(self, suffix: str):
        """Return the space-separated loop expressions of the axes that loop."""
        assert len(self.tilings) == 3
        pieces = []
        for axis, axis_tiling in zip("YXC", self.tilings):
            loop_expr = axis_tiling.strloop()
            if loop_expr:
                pieces.append(f"{axis}{suffix}:{loop_expr}")
        return " ".join(pieces)

    def straccess(self, suffix: str):
        """Return the space-separated access ranges of all three axes."""
        assert len(self.tilings) == 3
        pieces = []
        for axis, axis_tiling in zip("YXC", self.tilings):
            pieces.append(f"{axis}{suffix}:{axis_tiling.straccess()}")
        return " ".join(pieces)

    def str_with_suffix(self, suffix: str):
        """Return the full tiling expression: loop part (if any), then access part.

        ``suffix`` is appended to every axis name; with suffix ``'i'`` the axes
        are written ``Yi``, ``Xi``, ``Ci`` instead of ``Y``, ``X``, ``C``.
        """
        loop_part = self.strloop(suffix)
        access_part = self.straccess(suffix)
        if loop_part:
            return loop_part + " " + access_part
        return access_part

    def __str__(self):
        return self.str_with_suffix("")

    def subv_size(self):  # used for testing
        """Return the product of ``num_elements`` over all axes."""
        return math.prod(axis_tiling.num_elements for axis_tiling in self.tilings)


class BroadcastTilingConfig:
    """OFM information used to generate all tiling expressions.

    Attributes:
        loops: Temporal iteration count per axis, ordered Y, X, C.
        pads: Per axis, True when ``loop * subvolume * split`` does not evenly
            divide the OFM extent (i.e. the mapping over-covers the axis).
        broadcasts: Per axis, True unless both inputs match the OFM extent.
        ofm: Full OFM extents, ordered Y, X, C.
        ofm_repeats: For every ``(col, row)`` core, the loop count of each OFM
            tiling phase.
        split: Spatial split per axis, ordered Y, X, C.
    """

    loops: list[int]
    pads: list[bool]
    broadcasts: list[bool]
    ofm: list[int]
    ofm_repeats: dict[tuple[int, int], list[int]]
    split: list[int]

    def __init__(
        self,
        loops: list[int],
        split: list[int],
        ofm: list[int],
        ofms: list[int],
        ifm: list[list[int]],
    ):
        """Derive padding, broadcast and per-core repeat info from the mapping.

        Args:
            loops: Iteration counts (``mapping.iters[0:3]``).
            split: Spatial split (``mapping.spatial_split[1:4]``).
            ofm: Full OFM extents.
            ofms: OFM subvolume extents.
            ifm: Full extents of the two inputs.

        Raises:
            AssertionError: If the mapping does not cover the OFM on some axis.
        """
        self.loops = loops  # mapping.iters[0:3]
        self.ofm = ofm
        self.pads = [
            bool(dim % (loop * subv * axis_split))
            for (dim, loop, subv, axis_split) in zip(ofm, loops, ofms, split)
        ]
        self.broadcasts = [
            not (dim == ifm[0][i] and dim == ifm[1][i])
            for i, dim in enumerate(ofm)
        ]
        # Sanity check: the mapping must cover the whole OFM on every axis.
        for dim, loop, subv, axis_split in zip(ofm, loops, ofms, split):
            assert dim <= (loop * subv * axis_split), "Mapping does not cover ofm"
        self.split = split
        # tilings() reads loops / pads / broadcasts / split from ``self``, so
        # all of them must be assigned before this point.
        self.ofm_repeats = {}
        for col in range(NUM_COLS):
            for row in range(NUM_ROWS):
                ofm_tilings = tilings(
                    self,
                    ofm,
                    ofms,
                    col=col,
                    row=row,
                )
                self.ofm_repeats[(col, row)] = [
                    tile.num_loops() for tile in ofm_tilings
                ]

    @staticmethod
    def from_shape_and_mapping(shape: BroadcastShape, mapping: BaseMapping):
        """Build a config from a broadcast shape and its mapping."""
        return BroadcastTilingConfig(
            loops=list(mapping.iters[0:3]),
            split=mapping.spatial_split[1:4],
            ofm=shape.ofm,
            ofms=mapping.ofm_subv,
            ifm=shape.ifm,
        )


def single_axis_memory(size) -> str:
    """Format ``size`` as a single-axis ``Bytes:<size>`` memory expression."""
    return "Bytes:{}".format(size)


def ifm_dummy_transfer() -> str:
    """Dummy transfer: IFM Mem MM2S (zero-length range on every input axis)."""
    return " ".join(f"{axis}i:0:0" for axis in "YXC")


def ofm_dummy_transfer() -> str:
    """Return the OFM tiling expression with a zero-length range on every axis."""
    return " ".join(f"{axis}o:0:0" for axis in "YXC")


def ifm_memory(shape: tuple[int, int, int], n_bytes: int) -> str:
    """Return the IFM memory expression for ``shape`` after ``pad_to_32_bits``.

    The shape is first passed through ``pad_to_32_bits(shape, n_bytes)``; the
    first three entries of the result are written as ``Yi:<y> Xi:<x> Ci:<c>``.
    """
    padded = pad_to_32_bits(shape, n_bytes)
    labelled = zip(("Yi", "Xi", "Ci"), (padded[0], padded[1], padded[2]))
    return " ".join(f"{axis}:{extent}" for axis, extent in labelled)


def ifm_s2mm(shape: tuple[int, int, int], n_bytes: int) -> str:
    """Return the IFM transfer expression covering the whole padded ``shape``.

    The shape is first passed through ``pad_to_32_bits(shape, n_bytes)``; the
    first three entries of the result are written as
    ``Yi:0:<y> Xi:0:<x> Ci:0:<c>``.
    """
    padded = pad_to_32_bits(shape, n_bytes)
    labelled = zip(("Yi", "Xi", "Ci"), (padded[0], padded[1], padded[2]))
    return " ".join(f"{axis}:0:{extent}" for axis, extent in labelled)


ifm_mm2s = ifm_s2mm


def ofm_memory(shape: tuple[int, int, int]) -> str:
    """Return the OFM memory expression ``Yo:<y> Xo:<x> Co:<c>`` for ``shape``."""
    return "Yo:{} Xo:{} Co:{}".format(shape[0], shape[1], shape[2])


def ofm_shim_s2mm(shape: tuple[int, int, int]) -> str:
    """Return the OFM transfer expression covering all of ``shape`` from index 0."""
    return "Yo:0:{} Xo:0:{} Co:0:{}".format(shape[0], shape[1], shape[2])


ofm_s2mm = ofm_shim_s2mm


def null_tile_fmt(suffix: str):
    """Return two identical zero-length tiling expressions for axes Y/X/C + ``suffix``."""
    null_phase = " ".join(f"{axis}{suffix}:0:0" for axis in "YXC")
    return [null_phase, null_phase]


def single_axis_transfer(size) -> str:
    """Format ``size`` as a single-axis ``Bytes:0:<size>`` transfer expression."""
    return "Bytes:0:{}".format(size)


# Tiling expressions with a zero-length range on every axis, spelled with the
# input ("i") and output ("o") axis names respectively.
null_ifm_tile_fmt = "Yi:0:0 Xi:0:0 Ci:0:0"
null_ofm_tile_fmt = "Yo:0:0 Xo:0:0 Co:0:0"

# Two null phases each. The fill and spill names are bound to the SAME list
# object, so mutating one in place would also change the other.
null_ifm_fill_fmt = null_ifm_spill_fmt = [null_ifm_tile_fmt] * 2
null_ofm_fill_fmt = null_ofm_spill_fmt = [null_ofm_tile_fmt] * 2


# Weight helpers are plain aliases of the single-axis ("Bytes") formatters.
wgt_shim_memory = single_axis_memory
wgt_shim_mm2s = single_axis_transfer
wgt_memtile_memory = single_axis_memory
wgt_memtile_mm2s = single_axis_transfer
wgt_memtile_s2mm = single_axis_transfer

# Axis names in C, X, Y order.
# NOTE(review): not referenced in the visible part of this file — confirm it is still used.
LOOP_NAME = ["C", "X", "Y"]


def tilings(
    cfg: BroadcastTilingConfig,
    fm: list[int],
    fms: list[int],
    col: int = -1,
    row: int = -1,
) -> list[TilingND]:
    """Compute tilings for a multi-dimensional transfer with optional padding.

    Generates one ``TilingND`` instance per tiling "phase" so the caller can
    describe DMA behaviour for each subvolume. The total amount of work is the
    same across inputs, but the number of phases varies with loop ordering.

    Args:
        cfg (BroadcastTilingConfig): Global tiling configuration derived from
            the mapping.
        fm (list[int]): Full-dimension extents ordered Y, X, C.
        fms (list[int]): Subvolume extents matching ``fm`` ordering.
        col (int, optional): Column index of the target core. Defaults to ``-1``.
        row (int, optional): Row index of the target core. Defaults to ``-1``.

    Returns:
        list[TilingND]: Multi-dimensional transfer descriptors for each tiling
        phase.

    Raises:
        ValueError: If ``col`` or ``row`` is left unspecified.
    """
    assert len(cfg.loops) == len(cfg.pads)
    if col < 0 or row < 0:
        raise ValueError("Specify col and row for tilings()")
    # Reverse from [y x c] to [c x y]: the recursion pops one axis off the end
    # of each list per level. ``list(...)`` both tolerates tuple inputs and
    # yields fresh lists, so the caller's sequences are never mutated.
    loops, pads, broadcasts, rev_fm, rev_fms = (
        list(seq)[::-1]
        for seq in (cfg.loops, cfg.pads, cfg.broadcasts, fm, fms)
    )
    # ``split`` is only indexed (never popped) by the recursion, so a reversed
    # slice of the original sequence is enough.
    returned_tilings, _ = _tilings(
        loops, pads, broadcasts, rev_fm, rev_fms, cfg.split[::-1], col=col, row=row)
    return returned_tilings


def _tilings(
    loops: list[int],
    pads: list[bool],
    broadcasts: list[bool],
    fm: list[int],
    fms: list[int],
    split: list[int],
    col: int = -1,
    row: int = -1,
) -> tuple[list[TilingND], bool]:
    """Recursive worker behind ``tilings()``; handles one axis per call.

    All list arguments are expected in reversed ``[C, X, Y]`` order. Each call
    pops the LAST element of ``loops``, ``pads``, ``broadcasts``, ``fm`` and
    ``fms`` (mutating the lists in place), so the outermost call handles Y and
    the deepest non-empty call handles C. ``split`` is never popped; it is
    indexed by the current level instead.

    Args:
        loops: Remaining per-axis iteration counts.
        pads: Remaining per-axis padding flags.
        broadcasts: Remaining per-axis broadcast flags.
        fm: Remaining full extents.
        fms: Remaining subvolume extents.
        split: Spatial split in ``[C, X, Y]`` order.
        col: Column index of the target core.
        row: Row index of the target core.

    Returns:
        A pair ``(tilings, multi_phase)``. ``tilings`` holds one ``TilingND``
        per phase, with axes ordered from the current axis down to C.
        ``multi_phase`` tells the enclosing level that it cannot fold its loop
        into a single strided tiling.
    """
    if not loops:  # Base case
        # remove this line and uncomment line with False below once https://gitenterprise.xilinx.com/IPSP/dmacompiler/pull/273 is merged
        # NOTE: while this returns True, every level sees are_we_multi_phase == True,
        # so only the first (one-tiling-per-iteration) branch below is ever taken.
        return [], True
        # return [], False
    # 3 for Y, 2 for X, 1 for C (number of axes still to process, this one included)
    loop_level = len(loops)
    loop, pad, broadcast, dim, dim_sub = (
        loops.pop(),
        pads.pop(),
        broadcasts.pop(),
        fm.pop(),
        fms.pop(),
    )
    # split is reversed from [y x c] to [c x y (n)]
    Y_split, X_split, C_split = split[:3][::-1]
    _, Y_idx, X_idx, C_idx = _core_to_split(
        NUM_ROWS, Y_split, X_split, C_split, col, row)
    # index of this core along the axis handled at this level
    core_split = (C_idx, X_idx, Y_idx)[loop_level - 1]
    # Recurse AFTER popping so the inner call sees only the remaining axes.
    inner_tilings, are_we_multi_phase = _tilings(
        loops, pads, broadcasts, fm, fms, split, col=col, row=row)
    # distance between two consecutive iterations of this core along the axis
    step = dim_sub * split[loop_level - 1]
    core_offset = dim_sub * core_split
    # only an axis whose extent is 1 is treated as an actual broadcast
    is_broadcast = broadcast and dim == 1
    is_C = (loop_level - 1) == 0

    # If there's padding over any subloop then we can't re-use this BD,
    # so we create a new BD for each sub-phase
    if are_we_multi_phase:
        new_tilings = [
            Tiling1D(
                start=(i * step),
                end=(i * step) + dim_sub,
                stride=None,
                core_offset=core_offset,
                num_elements=dim_sub,
                max_dim=dim,
                is_broadcast=is_broadcast,
                is_C=is_C,
            )
            for i in range(loop)
        ]
    # If there's no padding over any subloop, then we can re-use this BD
    # If we pad this loop, then we use one phase for loop and another phase for the padding
    elif pad:
        loop_slice = (
            [
                Tiling1D(
                    start=0,
                    end=step * (loop - 1),
                    stride=step,
                    core_offset=core_offset,
                    num_elements=dim_sub,
                    max_dim=dim,
                    is_broadcast=is_broadcast,
                    num_repeats=loop - 1,
                    is_C=is_C,
                )
            ]
            if loop > 1
            else []
        )  # loop phase
        padding_slice = [
            Tiling1D(
                start=step * (loop - 1),
                end=step * loop,
                stride=None,
                core_offset=core_offset,
                num_elements=dim_sub,
                max_dim=dim,
                is_broadcast=is_broadcast,
                is_C=is_C,
            )
        ]  # padding phase
        new_tilings = loop_slice + padding_slice
    # An axis of extent 1 is a single element, no loop and no core offset
    elif dim == 1:
        new_tilings = [Tiling1D(start=0, end=1, stride=None, core_offset=0,
                                num_elements=1, max_dim=dim, is_broadcast=is_broadcast, is_C=is_C)]
    # Otherwise, we can use the same phase for this entire loop!
    else:
        new_tilings = [
            Tiling1D(
                start=0,
                end=step * (loop),
                stride=step,
                core_offset=core_offset,
                num_elements=dim_sub,
                max_dim=dim,
                is_broadcast=is_broadcast,
                num_repeats=loop,
                is_C=is_C,
            )
        ]
    # we do this after to determine whether a transfer is a dummy or not using the
    # non-broadcasted info
    if is_broadcast:
        # This will not create any dummy tilings. Instead, we rely on tiling_cfg.ofm_repeats
        # to set the num_repeats to 0 where appropriate.
        # Only the NUMBER of tilings computed above is kept; each one is replaced
        # by a single-element broadcast tiling.
        new_tilings = [
            Tiling1D(
                start=0,
                end=1,
                stride=None,
                core_offset=0,
                num_elements=1,
                max_dim=dim,
                is_broadcast=is_broadcast,
                is_C=is_C,
                # int * bool: ``loop`` in the single-phase case, 0 when multi-phase
                num_repeats=loop * (not are_we_multi_phase),
            )
            for _ in new_tilings
        ]

    # Create next level of tilings
    if not inner_tilings:  # create innermost dim after hitting end of recursion
        expanded_tilings = [TilingND((tiling,)) for tiling in new_tilings]
    else:  # otherwise, tilings at this dimension are a cross product of tilings for all subdimensions
        # this axis goes first, so the final tuple is ordered (Y, X, C)
        expanded_tilings: list[TilingND] = [
            TilingND(tuple([new_tiling, *_inner_tilings.tilings]))
            for (new_tiling, _inner_tilings) in itertools.product(new_tilings, inner_tilings)
        ]

    # use_iter_step can only work on a single loop. We always apply it to the innermost loop,
    # so if we use_iter_step over the inner dimensions then we must use a separate
    # phase for the outer dimensions, thus the loop > 1.
    are_we_multi_phase = are_we_multi_phase or pad or loop > 1
    return expanded_tilings, are_we_multi_phase


def depad_tiling(
    cfg: BroadcastTilingConfig,
    fm: list[int],
    fms: list[int],
    tiling: TilingND,
    normalize: bool = False,
) -> TilingND:
    """Return a copy of ``tiling`` with out-of-bounds padding trimmed per axis.

    For every axis the number of in-bounds elements is computed from the true
    first index (``start + core_offset``) and the full extent in ``fm``.

    Args:
        cfg: Unused by this function.
        fm: Full extents, in the same axis order as ``tiling.tilings``.
        fms: Subvolume extents, in the same axis order.
        tiling: The tiling to trim; it is not modified.
        normalize: When True, every axis is rewritten to start at 0 with no
            stride and no core offset, keeping only the trimmed element count.
            When False, looping axes (more than one iteration) keep their
            element count and only have ``start`` shifted; single-iteration
            axes are trimmed; dummy axes (zero iterations) are left untouched.

    Returns:
        A new ``TilingND`` built from deep copies of the per-axis tilings.
    """
    trimmed_axes = []
    for axis_tiling, extent, subv_extent in zip(tiling.tilings, fm, fms):
        axis_copy = copy.deepcopy(axis_tiling)
        # Number of elements that really exist once padding is ignored.
        first_index = axis_tiling.start + axis_tiling.core_offset
        valid_elements = min(max(0, extent - first_index), axis_tiling.num_elements)

        if normalize:
            # Normalization works the same for looping and non-looping axes.
            axis_copy.start = 0
            axis_copy.end = valid_elements
            axis_copy.num_elements = valid_elements
            axis_copy.stride = None
            axis_copy.core_offset = 0
            trimmed_axes.append(axis_copy)
            continue

        iterations = axis_tiling.num_loops()
        if iterations > 1:
            # Padding is not removed on an axis we loop over.
            # NOTE(review): ``core_offset`` is kept while ``start`` already
            # includes it — confirm this is intended.
            axis_copy.start = first_index
            assert axis_copy.num_elements == subv_extent
        elif iterations == 1:
            axis_copy.start = first_index
            axis_copy.end = first_index + valid_elements
            axis_copy.stride = None
            axis_copy.core_offset = 0
            axis_copy.num_elements = valid_elements
        trimmed_axes.append(axis_copy)

    return TilingND(tuple(trimmed_axes))


def generate_tiling_exprs(cfg, fm, fms, suffix, enable_fill=False, enable_spill=False, depad=False, normalize=False):
    """Wrapper around tilings() that renders the tiling expressions per core.

    Returns a dict keyed by ``(col, row)`` whose values are lists of tiling
    expression strings, one per phase. With ``depad`` each tiling is first
    passed through ``depad_tiling`` (forwarding ``normalize``). ``enable_fill``
    prepends and ``enable_spill`` appends the two null phases produced by
    ``null_tile_fmt(suffix)``.
    """
    exprs_by_core = {}
    for col, row in itertools.product(range(NUM_COLS), range(NUM_ROWS)):
        core_tilings = tilings(cfg, fm, fms, col=col, row=row)
        if depad:
            core_tilings = [
                depad_tiling(cfg, fm, fms, phase, normalize=normalize)
                for phase in core_tilings
            ]
        phase_exprs = [phase.str_with_suffix(suffix) for phase in core_tilings]

        if enable_fill:
            phase_exprs = null_tile_fmt(suffix) + phase_exprs
        if enable_spill:
            phase_exprs = phase_exprs + null_tile_fmt(suffix)
        exprs_by_core[(col, row)] = phase_exprs
    return exprs_by_core


def generate_iter_steps(cfg, fm, fms, enable_fill=False, enable_spill=False):
    """Return, per ``(col, row)`` core, whether each tiling phase loops.

    Each flag is True when the phase's total loop count exceeds 1.
    ``enable_fill`` prepends and ``enable_spill`` appends two ``False`` entries,
    matching the two null phases added by ``generate_tiling_exprs``.
    """
    flags_by_core = {}
    for col, row in itertools.product(range(NUM_COLS), range(NUM_ROWS)):
        phase_flags = [
            phase.num_loops() > 1
            for phase in tilings(cfg, fm, fms, col=col, row=row)
        ]

        if enable_fill:
            phase_flags = [False, False] + phase_flags
        if enable_spill:
            phase_flags = phase_flags + [False, False]
        flags_by_core[(col, row)] = phase_flags
    return flags_by_core


def generate_ifm_transfers_l2(
    mapping: BaseMapping,
    shape: BaseShape,
    dims: BaseDims,
    tile_cfg: BroadcastTilingConfig,
    l2_alloc: L2Alloc,
    fill_sync_repeats: list[int],
    fill_repeats: list[int],
    shim_alloc: ShimAllocator,
):
    """Build the memtile and shim transfers that feed both IFMs to the cores.

    For every column and each of the two inputs, one memtile MM2S transfer
    parameter set is packed per row. For column 0 only, a fill transfer, a
    zero-size sync dummy transfer and a shim MM2S transfer are also emitted per
    input. Finally, per row, the A and B parameter sets are chained into a
    single memtile ``DataTransfer``; cores whose OFM repeats sum to 0 are
    skipped.

    Args:
        mapping: Supplies ``ifm_subv`` (subvolume shape per input).
        shape: Supplies ``ifm`` (full shape per input), ``ifm_bytes`` and ``ifm_bits``.
        dims: Not used in this function.
        tile_cfg: Supplies ``ofm_repeats`` per ``(col, row)`` and drives the tiling expressions.
        l2_alloc: Supplies ``ifm_L2_loc``, one ``(tile, addr)`` pair per input.
        fill_sync_repeats: Repeat list for the sync dummy transfers.
        fill_repeats: Repeat list for the fill and shim transfers.
        shim_alloc: Supplies ``ifm_xrt_offset`` and ``ifm_xrt_idx`` per input.

    Returns:
        A pair ``(memtile_transfers, shimtile_transfers)`` of lists.
    """
    # running maximum of the subvolume byte size over both inputs
    larger_ifm_subv_size = 0
    # we need to compute chained ifm b addresses relative to ifm a, which we assume is on first tile
    memtile_transfers, shimtile_transfers = [], []
    # per-input dicts keyed by (col, row); both carry the 2 fill + 2 spill null phases
    ab_tilings = [
        generate_tiling_exprs(tile_cfg, shape.ifm[0], mapping.ifm_subv[0], "i", enable_fill=True, enable_spill=True),
        generate_tiling_exprs(tile_cfg, shape.ifm[1], mapping.ifm_subv[1], "i", enable_fill=True, enable_spill=True),
    ]
    use_iter_steps = [
        generate_iter_steps(tile_cfg, shape.ifm[0], mapping.ifm_subv[0], enable_fill=True, enable_spill=True),
        generate_iter_steps(tile_cfg, shape.ifm[1], mapping.ifm_subv[1], enable_fill=True, enable_spill=True),
    ]

    for col in range(NUM_COLS):
        # mem2core_transfer_params[i][row]: packed params of input i for this column
        mem2core_transfer_params: list[list[TransferParams]] = []
        for i, (ifm, ifms, ifm_tilings, (tile, addr,), shim_offset, shim_idx,) in enumerate(
            zip(
                shape.ifm,
                mapping.ifm_subv,
                ab_tilings,
                l2_alloc.ifm_L2_loc,
                shim_alloc.ifm_xrt_offset,
                shim_alloc.ifm_xrt_idx,
            )
        ):
            mem2core_transfer_params.append([])
            # NOTE(review): computed from the shape BEFORE ``ifm`` is rebound with
            # C raised to at least 4 below — confirm the fill size is intended to
            # use the unpadded channel count.
            ifm_size = math.prod(ifm) * shape.ifm_bytes
            larger_ifm_subv_size = max(
                larger_ifm_subv_size, math.prod(ifms) * shape.ifm_bytes)

            # We assume L3 is always padded, but we don't always pad channels to L2
            # if C=1, we need to pad to 4
            padded_ifm = (ifm[0], ifm[1], iceil(ifm[2], 64))
            ifm = (ifm[0], ifm[1], max(ifm[2], 4))

            # we don't generate DataTransfer yet because we need to chain bds for both ifms in a single transfer
            # phase count of core (0, 0) plus the 2 fill + 2 spill null phases
            n_phases = len(tile_cfg.ofm_repeats[(0, 0)]) + 4
            for row in range(NUM_ROWS):
                transfer_params = pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                    [ifm_memory(ifm, shape.ifm_bytes)] * n_phases,
                    ifm_tilings[(col, row)],
                    bits_per_elem=shape.ifm_bits,  # type: ignore
                    use_iter_step=use_iter_steps[i][(col, row)],
                    buffer_offset=[
                        # all addrs need to be w.r.t ifm a, which is on col 0
                        addr
                        + ((dmacompiler.config.MAX_MEMTILE_ADDR + 1) * tile.col)
                    ],
                )
                # pack_reconfig_transfers may return a tuple; element 1 is the one used here
                mem2core_transfer_params[-1].append(
                    transfer_params[1] if isinstance(
                        transfer_params, tuple) else transfer_params
                )

            # only configure fill transfers once
            if col > 0:
                continue

            memtile_transfers.append(
                DataTransfer(  # IFM FILL
                    fill_repeats,
                    tile,
                    [addr],
                    ifm_size,
                    [
                        generate_transfer_params(
                            memtile_dma(tile.col, DmaDir.S2MM,
                                        overlay_3x4_F_ids()[0]),
                            ifm_memory(ifm, shape.ifm_bytes),
                            ifm_s2mm(ifm, shape.ifm_bytes),
                            bits_per_block=shape.ifm_bits,
                        )
                    ],
                    [],
                )
            )
            memtile_transfers.append(  # IFM SYNC DUMMY TRANSFER
                DataTransfer(
                    fill_sync_repeats,
                    tile,
                    [addr],
                    0,
                    [
                        generate_transfer_params(
                            memtile_dma(tile.col, DmaDir.S2MM,
                                        overlay_3x4_F_ids()[0]),
                            ifm_memory(ifm, shape.ifm_bytes),
                            ifm_dummy_transfer(),
                            bits_per_block=shape.ifm_bits,
                        )
                    ],
                    [],
                    sync_strategy=SyncStrategy.Remote_Barrier,
                )
            )

            # shim side of the fill: memory described with C rounded via iceil(.., 64),
            # transfer expressed over the L2 shape
            shimtile_transfers.append(
                generate_shim_data_transfer(
                    fill_repeats,
                    shim_dma(tile.col, DmaDir.MM2S, 0),
                    shim_idx,
                    ifm_memory(padded_ifm, shape.ifm_bytes),
                    ifm_s2mm(ifm, shape.ifm_bytes),
                    bits_per_block=shape.ifm_bits,
                    buffer_offset=shim_offset,
                )
            )

        # Chain BDs after generating tiling exprs across both subvs
        # use ifm a tile as reference, it's on col 0
        tile = l2_alloc.ifm_L2_loc[0][0]
        for row, (bd_a, bd_b) in enumerate(zip(*mem2core_transfer_params)):
            # we have to chain ifm a and b BDs here in one transfer
            # cores with no work at all get no mem-to-core transfer
            if sum(tile_cfg.ofm_repeats[(col, row)]) == 0:
                continue
            memtile_transfers.append(  # MEM to CORE
                DataTransfer(
                    # zero repeats during the 2 fill and 2 spill phases
                    [0, 0] + tile_cfg.ofm_repeats[(col, row)] + [0, 0],
                    tile,
                    [0],
                    larger_ifm_subv_size,
                    [],
                    [bd_a, bd_b],
                )
            )

    return memtile_transfers, shimtile_transfers


def generate_ofm_transfers_l2(
    mapping: BaseMapping,
    shape: BaseShape,
    dims: BaseDims,
    tile_cfg: BroadcastTilingConfig,
    l2_alloc: L2Alloc,
    spill_sync_repeats: list[int],
    spill_repeats: list[int],
    shim_alloc: ShimAllocator,
):
    """Build the memtile and shim transfers that collect and spill the OFM.

    Per active core, one memtile S2MM transfer is emitted. Per column, one sync
    dummy transfer covering all rows is emitted. After the loops, a single
    memtile MM2S spill transfer and the matching shim S2MM transfer are added.

    Args:
        mapping: Supplies ``ofm_subv`` and ``ofm_pad``.
        shape: Supplies ``ofm``, ``ofm_bytes`` and ``ofm_bits``.
        dims: Not used in this function.
        tile_cfg: Supplies ``ofm_repeats`` per ``(col, row)`` and drives the tiling expressions.
        l2_alloc: Supplies ``ofm_L2_loc`` as a ``(tile, addr)`` pair.
        spill_sync_repeats: Repeat list for the per-column sync dummy transfers.
        spill_repeats: Repeat list for the spill and shim transfers.
        shim_alloc: Supplies ``ofm_xrt_idx`` and ``ofm_xrt_offset``.

    Returns:
        A pair ``(memtile_transfers, shimtile_transfers)`` of lists.
    """
    memtile_transfers, shimtile_transfers = [], []
    ofm, ofms = shape.ofm, mapping.ofm_subv
    ofm_size = math.prod(ofm) * shape.ofm_bytes
    tile, addr = l2_alloc.ofm_L2_loc

    # dicts keyed by (col, row); both carry the 2 fill + 2 spill null phases
    ofm_itersteps = generate_iter_steps(
        tile_cfg, ofm, ofms, enable_fill=True, enable_spill=True)
    ofm_tilings = generate_tiling_exprs(
        tile_cfg, ofm, ofms, "o", enable_fill=True, enable_spill=True)

    # phase count of core (0, 0) plus the 2 fill + 2 spill null phases
    n_phases = len(tile_cfg.ofm_repeats[(0, 0)]) + 4
    for col in range(NUM_COLS):
        for row in range(NUM_ROWS):
            # cores with no work at all get no core-to-mem transfer
            if sum(tile_cfg.ofm_repeats[(col, row)]) == 0:
                continue
            transfer_param = pack_reconfig_transfers(
                memtile_dma(col, DmaDir.S2MM, overlay_3x4_O_ids()[row]),
                [ofm_memory(mapping.ofm_pad)] * n_phases,
                ofm_tilings[(col, row)],
                bits_per_elem=shape.ofm_bits,
                use_iter_step=ofm_itersteps[(col, row)],
            )

            memtile_transfers.append(
                DataTransfer(  # core 2 mem
                    # zero repeats during the 2 fill and 2 spill phases
                    [0, 0] + tile_cfg.ofm_repeats[(col, row)] + [0, 0],
                    tile,
                    [addr],
                    ofm_size,
                    [  # MEMTILE S2MM
                        # pack_reconfig_transfers may return a tuple; element 1 is the one used here
                        transfer_param[1] if isinstance(
                            transfer_param, tuple) else transfer_param
                    ],
                    [],  # type: ignore
                )
            )

        # One sync transfer per column, outside the row loop: the comprehension
        # below supplies its own ``row`` and includes every row, idle or not.
        memtile_transfers.append(
            DataTransfer(  # sync before spill
                spill_sync_repeats,
                tile,
                [addr],
                ofm_size,
                [
                    generate_transfer_params(
                        memtile_dma(col, DmaDir.S2MM,
                                    overlay_3x4_O_ids()[row]),
                        ofm_memory(mapping.ofm_pad),
                        ofm_dummy_transfer(),
                        bits_per_block=shape.ofm_bits,  # type: ignore
                    )
                    for row in range(NUM_ROWS)
                ],
                [],
                # only column 0 uses the remote barrier
                sync_strategy=SyncStrategy.Remote_Barrier if col == 0 else SyncStrategy.Default,
            )
        )
    # Spill: read the (padded) OFM buffer out of the memtile ...
    memtile_transfers.append(
        DataTransfer(  # spill mem 2 shim
            spill_repeats,
            tile,
            [addr],
            ofm_size,
            [],
            [
                generate_transfer_params(
                    memtile_dma(tile.col, DmaDir.MM2S, overlay_3x4_S_ids(1)[0]),
                    ofm_memory(mapping.ofm_pad),
                    ofm_shim_s2mm(ofm),
                    bits_per_block=shape.ofm_bits,
                )  # type: ignore
            ],
        )
    )
    # ... and write it through the shim of the same column.
    shimtile_transfers.append(
        DataTransfer(
            spill_repeats,
            AieTile(TileType.Shim, tile.col),
            [shim_alloc.ofm_xrt_idx],
            ofm_size,
            [
                generate_transfer_params(
                    shim_dma(tile.col, DmaDir.S2MM, 0),
                    ofm_memory(ofm),
                    ofm_shim_s2mm(ofm),
                    shape.ofm_bits,
                    buffer_offset=shim_alloc.ofm_xrt_offset,
                )
            ],
            [],
        )
    )

    return memtile_transfers, shimtile_transfers


def generate_ifm_transfers_l3(
    mapping: BaseMapping,
    shape: BaseShape,
    dims: BaseDims,
    tile_cfg: BroadcastTilingConfig,
    l2_alloc: L2Alloc,
    fill_sync_repeats: list[int],
    fill_repeats: list[int],
    shim_alloc: ShimAllocator,
):
    """Build per-core shim and memtile IFM transfers without fill/spill phases.

    For every active core and each of the two inputs, this emits one shim MM2S
    ``DataTransfer`` (using the depadded tiling expressions) and one memtile
    ``DataTransfer`` whose write side uses the normalized tilings and whose
    read side moves the whole subvolume. Cores whose OFM repeats sum to 0 are
    skipped.

    Args:
        mapping: Supplies ``ifm_subv`` (subvolume shape per input).
        shape: Supplies ``ifm`` (full shape per input), ``ifm_bytes`` and ``ifm_bits``.
        dims: Not used in this function.
        tile_cfg: Supplies ``ofm_repeats`` per ``(col, row)`` and drives the tiling expressions.
        l2_alloc: Supplies ``ifm_L2_loc``; only the address of each pair is used.
        fill_sync_repeats: Not used in this function.
        fill_repeats: Not used in this function.
        shim_alloc: Supplies ``ifm_xrt_offset`` and ``ifm_xrt_idx`` per input.

    Returns:
        A pair ``(memtile_transfers, shimtile_transfers)`` of lists.
    """
    # running maximum of the subvolume byte size seen so far
    larger_ifm_subv_bytes = 0
    # we need to compute chained ifm b addresses relative to ifm a, which we assume is on first tile
    # no fill/spill null phases here, so the phase count is taken as-is from core (0, 0)
    n_phases = len(tile_cfg.ofm_repeats[(0, 0)])
    memtile_transfers, shimtile_transfers = [], []
    # depadded tilings in original coordinates (used on the shim side)
    ab_tilings = [
        generate_tiling_exprs(tile_cfg, shape.ifm[0], mapping.ifm_subv[0], "i", enable_fill=False, enable_spill=False, depad=True),
        generate_tiling_exprs(tile_cfg, shape.ifm[1], mapping.ifm_subv[1], "i", enable_fill=False, enable_spill=False, depad=True),
    ]
    # depadded tilings rebased to start at 0 (used on the memtile write side)
    normalized_ab_tilings = [
        generate_tiling_exprs(tile_cfg, shape.ifm[0], mapping.ifm_subv[0], "i", enable_fill=False, enable_spill=False, depad=True, normalize=True),
        generate_tiling_exprs(tile_cfg, shape.ifm[1], mapping.ifm_subv[1], "i", enable_fill=False, enable_spill=False, depad=True, normalize=True),
    ]
    use_iter_steps = [
        generate_iter_steps(tile_cfg, shape.ifm[0], mapping.ifm_subv[0], enable_fill=False, enable_spill=False),
        generate_iter_steps(tile_cfg, shape.ifm[1], mapping.ifm_subv[1], enable_fill=False, enable_spill=False),
    ]

    for col in range(NUM_COLS):
        for row in range(NUM_ROWS):

            # cores with no work at all get no transfers
            if sum(tile_cfg.ofm_repeats[(col, row)]) == 0:
                continue
            for i, (ifm, ifms, (_, addr,), shim_offset, shim_idx,) in enumerate(
                zip(
                    shape.ifm,
                    mapping.ifm_subv,
                    l2_alloc.ifm_L2_loc,
                    shim_alloc.ifm_xrt_offset,
                    shim_alloc.ifm_xrt_idx,
                )
            ):
                ifm_tilings = ab_tilings[i]
                normalized_ifm_tilings = normalized_ab_tilings[i]
                subv_size = math.prod(ifms)
                subv_bytes = subv_size * shape.ifm_bytes
                larger_ifm_subv_bytes = max(larger_ifm_subv_bytes, subv_bytes)
                # We assume L3 is always padded, but we don't always pad channels to L2
                # if C=1, we need to pad to 4
                padded_ifm = (ifm[0], ifm[1], iceil(ifm[2], 64))
                # NOTE(review): ``ifm`` is rebound here but not read again in this loop body.
                ifm = (ifm[0], ifm[1], max(ifm[2], 4))

                # For L3 dataflow, tiling exprs only matter for shim <-> mem transfers.
                shim_transfer_params = pack_reconfig_transfers(
                    shim_dma(col, DmaDir.MM2S, i),  # i is 0 or 1 for ifm a/b
                    [ifm_memory(padded_ifm, shape.ifm_bytes)] * n_phases,
                    ifm_tilings[(col, row)],
                    bits_per_elem=shape.ifm_bits,  # type: ignore
                    use_iter_step=use_iter_steps[i][(col, row)],
                    buffer_offset=[shim_offset],
                )
                # pack_reconfig_transfers may return a tuple; element 1 is the one used here
                if isinstance(shim_transfer_params, tuple):
                    shim_transfer_params = shim_transfer_params[1]

                # NOTE(review): the size passed here is the running maximum so far,
                # so its value depends on iteration order — confirm this is intended.
                shimtile_transfers.append(
                    DataTransfer(
                        tile_cfg.ofm_repeats[(col, row)],
                        shim_tile(col),
                        [shim_idx],
                        larger_ifm_subv_bytes,
                        [],
                        [shim_transfer_params],
                    )
                )

                # each row gets its own subvolume-sized slot at offset subv_bytes * row
                memtile_write_params = pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[i]),
                    [ifm_memory(ifms, shape.ifm_bytes)] * n_phases,
                    # [ifm_mm2s(ifms, shape.ifm_bytes)] * n_phases,
                    normalized_ifm_tilings[(col, row)],
                    bits_per_elem=shape.ifm_bits,  # type: ignore
                    use_iter_step=use_iter_steps[i][(col, row)],
                    buffer_offset=[subv_bytes * row],
                )
                memtile_read_params = pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                    [ifm_memory(ifms, shape.ifm_bytes)] * n_phases,
                    [ifm_mm2s(ifms, shape.ifm_bytes)] * n_phases,
                    bits_per_elem=shape.ifm_bits,  # type: ignore
                    use_iter_step=use_iter_steps[i][(col, row)],
                    buffer_offset=[subv_bytes * row],
                )

                # NOTE(review): unlike shim_transfer_params above, these two results
                # are not unwrapped when pack_reconfig_transfers returns a tuple —
                # confirm DataTransfer accepts them as-is.
                memtile_transfers.append(
                    DataTransfer(
                        tile_cfg.ofm_repeats[(col, row)],
                        memory_tile(col),
                        [addr],
                        subv_bytes,
                        [memtile_write_params],
                        [memtile_read_params],
                    )
                )

    return memtile_transfers, shimtile_transfers


def generate_ofm_transfers_l3(
    mapping: BaseMapping,
    shape: BaseShape,
    dims: BaseDims,
    tile_cfg: BroadcastTilingConfig,
    l2_alloc: L2Alloc,
    spill_sync_repeats: list[int],
    spill_repeats: list[int],
    shim_alloc: ShimAllocator,
):
    """Build the OFM memtile and shim transfers for the L3 dataflow.

    For every (col, row) core that has at least one non-zero OFM repeat, one
    memtile ``DataTransfer`` (core -> memtile -> shim stream) and one shim
    ``DataTransfer`` (stream -> OFM shim buffer) are emitted.

    Args:
        mapping: Supplies the OFM subvolume shape (``ofm_subv``).
        shape: Supplies the full OFM shape and its element width.
        dims: Not read by this function; present for signature parity with
            the other transfer generators.
        tile_cfg: Tiling configuration holding the per-core repeat counts.
        l2_alloc: Only the address half of ``ofm_L2_loc`` is used; the
            memtile is always ``memory_tile(col)``.
        spill_sync_repeats: Not read by this function.
        spill_repeats: Not read by this function.
        shim_alloc: Supplies the OFM XRT buffer index and offset.

    Returns:
        ``(memtile_transfers, shimtile_transfers)`` as two lists of
        ``DataTransfer`` objects.
    """
    full_shape, subv_shape = shape.ofm, mapping.ofm_subv
    full_bytes = math.prod(full_shape) * shape.ofm_bytes
    subv_bytes = math.prod(subv_shape) * shape.ofm_bytes
    # The tile stored in the allocation is ignored; only the address is used.
    _unused_tile, l2_addr = l2_alloc.ofm_L2_loc

    no_fill_no_spill = {"enable_fill": False, "enable_spill": False}
    iter_steps = generate_iter_steps(
        tile_cfg, full_shape, subv_shape, **no_fill_no_spill)
    # Normalized tilings drive the memtile read; the un-normalized ones drive the shim write.
    normalized_tilings = generate_tiling_exprs(
        tile_cfg, full_shape, subv_shape, "o", depad=True, normalize=True, **no_fill_no_spill)
    absolute_tilings = generate_tiling_exprs(
        tile_cfg, full_shape, subv_shape, "o", depad=True, normalize=False, **no_fill_no_spill)

    phase_count = len(tile_cfg.ofm_repeats[(0, 0)])
    l2_transfers, l3_transfers = [], []

    for core in itertools.product(range(NUM_COLS), range(NUM_ROWS)):
        col, row = core
        repeats = tile_cfg.ofm_repeats[core]
        if sum(repeats) == 0:
            # This core never produces output, so it needs no transfers.
            continue

        steps = iter_steps[core]
        # Each row owns its own subvolume-sized slot inside the column's L2 buffer.
        row_slot = subv_bytes * row

        core_to_l2 = pack_reconfig_transfers(
            memtile_dma(col, DmaDir.S2MM, overlay_3x4_O_ids()[row]),
            [ofm_memory(subv_shape)] * phase_count,
            [ofm_s2mm(subv_shape)] * phase_count,
            bits_per_elem=shape.ofm_bits,  # type: ignore
            use_iter_step=steps,
            buffer_offset=[row_slot],
        )
        # Two rows share one spill channel (row // 2).
        l2_to_shim = pack_reconfig_transfers(
            memtile_dma(col, DmaDir.MM2S, overlay_3x4_S_ids(col)[row // 2]),
            [ofm_memory(subv_shape)] * phase_count,
            normalized_tilings[core],
            bits_per_elem=shape.ofm_bits,  # type: ignore
            use_iter_step=steps,
            buffer_offset=[row_slot],
            name="bleep"
        )
        l2_transfers.append(
            DataTransfer(
                repeats,
                memory_tile(col),
                [l2_addr],
                subv_bytes,
                [core_to_l2],
                [l2_to_shim],
            )
        )

        shim_write = pack_reconfig_transfers(
            shim_dma(col, DmaDir.S2MM, row // 2),
            [ofm_memory(full_shape)] * phase_count,
            absolute_tilings[core],
            bits_per_elem=shape.ofm_bits,
            buffer_offset=[shim_alloc.ofm_xrt_offset],
            use_iter_step=steps,
        )
        l3_transfers.append(
            DataTransfer(
                repeats,
                shim_tile(col),
                [shim_alloc.ofm_xrt_idx],
                full_bytes,
                [shim_write],
                [],
            )
        )

    return l2_transfers, l3_transfers


@no_type_check
def compile_dataflow(
    schedule_input: ScheduleInputs,
    generate_ifm_transfers: Callable,
    generate_ofm_transfers: Callable,
) -> tuple:
    """Compile one broadcast layer: core instructions plus all DMA transfers.

    The function fills in default L2 locations when none are usable, builds the
    per-core instruction streams, the parameter/weight/IFM/OFM memtile and shim
    transfers, and hands everything to ``run_layer_compilation``.

    Args:
        schedule_input: Layer description (shape, mapping, L2/L3 allocations,
            kernel names, backend, output file name). ``schedule_input.L2_alloc``
            is overwritten when default L2 locations are substituted.
        generate_ifm_transfers: Callable returning
            ``(memtile_transfers, shimtile_transfers)`` for the inputs.
        generate_ofm_transfers: Callable returning
            ``(memtile_transfers, shimtile_transfers)`` for the output.

    Returns:
        ``(shim_prm_offset_next_layer, shim_wgt_offset_next_layer)``: the
        parameter and weight shim offsets advanced by the sizes consumed here.

    Raises:
        AssertionError: If the default L2 layout does not fit below
            ``MAX_MEMTILE_ADDR``, if the shapes have unexpected types, or if
            ``mapping.wgt_L1_size`` differs from the expected weight size.
    """
    shape, mapping = schedule_input.shape, schedule_input.mapping
    shim_alloc = L3Alloc_to_Shim(schedule_input.L3_alloc)
    l2_alloc = schedule_input.L2_alloc
    is_sub = "sub" in shape.op_name.lower()

    # hard-coded l2 values for onnx, see https://gitenterprise.xilinx.com/IPSP/aie4_models/issues/444
    # these should match test_broadcast.py
    if not l2_alloc or not l2_alloc.enable_L2_fusion:
        if not l2_alloc:
            # set l2 addresses in the case of L3 dataflow
            ifm_a_l2_addr = 20_000
            ifm_subv_sizes = [math.prod(subv) * shape.ifm_bytes for subv in mapping.ifm_subv]
            # IFM A needs one subvolume slot per row; IFM B starts right after it.
            ifm_a_l2_size = ifm_subv_sizes[0] * NUM_ROWS
            ifm_L2_loc = [(memory_tile(0), ifm_a_l2_addr), (memory_tile(1), ifm_a_l2_addr + ifm_a_l2_size)]
            assert ifm_a_l2_addr + sum(ifm_subv_sizes) * NUM_ROWS < dmacompiler.config.MAX_MEMTILE_ADDR
        else:
            # set l2 addresses for reading model data with l2 fusion
            ifm_L2_loc = [(memory_tile(0), 20000), (memory_tile(1), 20000)]

        # NOTE(review): generate_ofm_transfers_l3 uses only the address of ofm_L2_loc and
        # always targets memory_tile(col); confirm this cannot alias the IFM buffers,
        # which also start at 20000.
        l2_alloc = L2Alloc(
            ifm_L2_loc=ifm_L2_loc,
            ofm_L2_loc=(memory_tile(2), 20000),
            wgt_l2_loc=[[memory_tile(0), 4096, 5120], [memory_tile(1), 4096, 5120], [memory_tile(2), 4096, 5120]],
            prm_l2_loc=[[memory_tile(0), 0], [memory_tile(1), 0], [memory_tile(2), 0]],
            enable_ifm_fill=l2_alloc.enable_ifm_fill if l2_alloc else False,
            enable_ofm_spill=l2_alloc.enable_ofm_spill if l2_alloc else False,
            enable_L2_fusion=l2_alloc.enable_L2_fusion if l2_alloc else False,
        )
        # Publish the substituted allocation back to the caller's schedule input.
        schedule_input.L2_alloc = l2_alloc

    log("=== Shim Allocations ===")
    log(f"IFM Shim XRT Idx: {shim_alloc.ifm_xrt_idx}, Offset: {shim_alloc.ifm_xrt_offset}")
    log(f"WGT Shim XRT Idx: {shim_alloc.wgt_xrt_idx}, Offset: {shim_alloc.wgt_xrt_offset}")
    log(f"OFM Shim XRT Idx: {shim_alloc.ofm_xrt_idx}, Offset: {shim_alloc.ofm_xrt_offset}")
    log(f"PRM Shim XRT Idx: {shim_alloc.prm_xrt_idx}, Offset: {shim_alloc.prm_xrt_offset}")

    # log l2 addresses
    log("=== L2 Addresses ===")
    for i, (tile, addr) in enumerate(l2_alloc.ifm_L2_loc):
        log(f"IFM {i} L2 Tile: {tile}, Addr: {addr}")
    log("WGT L2 Tiles and Addrs:")
    for entry in l2_alloc.wgt_l2_loc:
        log(f"  WGT: Tile: {entry[0]}, Ping Addr: {entry[1]}, Pong Addr: {entry[2]}")
    log(f"OFM L2 Tile: {l2_alloc.ofm_L2_loc[0]}, Addr: {l2_alloc.ofm_L2_loc[1]}")
    log("PRM L2 Tiles and Addrs:")
    for entry in l2_alloc.prm_l2_loc:
        log(f"  PRM Tile: {entry[0]}, Addr: {entry[1]}")
    log("=== L1 Addresses ===")
    for name, l1_alloc in schedule_input.mapping.l1_alloc.items():
        log(f"{name} L1 Alloc: {l1_alloc}")

    # Weight buffer layout: [qdq params | Q scratch | DQ scratch].
    QDQ_PARAM_SIZE = 128  # initial qdq parameters
    DQ_BUF_SIZE = Q_BUF_SIZE = 512
    q_buf_offset = (
        QDQ_PARAM_SIZE  # qdq params pad the first 128 bytes of the wgt buffer before the 1024 of scratch space
    )
    dq_buf_offset = q_buf_offset + Q_BUF_SIZE
    WGT_SIZE = QDQ_PARAM_SIZE + Q_BUF_SIZE + DQ_BUF_SIZE
    kernel_gran = 64
    # C is broadcasted on second ifm
    ifm_a_size = math.prod(mapping.ifm_subv[0]) * shape.ifm_bytes
    ifm_b_size = math.prod(mapping.ifm_subv[1]) * shape.ifm_bytes

    assert isinstance(shape.ifm, list)
    assert isinstance(shape.ifm[0], tuple)
    assert isinstance(
        shape.ofm[0], int), f"shape.ofm[0] must be int, got {shape.ofm}"

    dims = BaseDims.from_shape_and_mapping(shape, mapping)
    tiling_cfg = BroadcastTilingConfig.from_shape_and_mapping(shape, mapping)

    # The number of compute phases is taken from the tilings of core (0, 0).
    num_phases = len(
        tilings(
            tiling_cfg,
            shape.ofm,
            mapping.ofm_subv,
            col=0,
            row=0,
        )
    )
    # first start with LayerParameters
    layer_param_repeat_count = [1] + [0] * (num_phases - 1)
    ifm_fill_repeats = [0] * num_phases
    fill_sync_repeat_count = [0] * num_phases
    spill_sync_repeat_count = [0] * num_phases
    ofm_spill_repeats = [0] * num_phases
    if l2_alloc.enable_ifm_fill:
        # Prepend two phases: one for the fill and one for the fill sync.
        layer_param_repeat_count = [0, 0] + layer_param_repeat_count
        ifm_fill_repeats = [1, 0] + ifm_fill_repeats
        fill_sync_repeat_count = [0, 1] + fill_sync_repeat_count
        spill_sync_repeat_count = [0, 0] + spill_sync_repeat_count
        ofm_spill_repeats = [0, 0] + ofm_spill_repeats
    if l2_alloc.enable_ofm_spill:
        # Append two phases: one for the spill sync and one for the spill.
        layer_param_repeat_count = layer_param_repeat_count + [0, 0]
        ifm_fill_repeats = ifm_fill_repeats + [0, 0]
        fill_sync_repeat_count = fill_sync_repeat_count + [0, 0]
        spill_sync_repeat_count = spill_sync_repeat_count + [1, 0]
        ofm_spill_repeats = ofm_spill_repeats + [0, 1]

    log("=== Scheduler Information ===")
    log(f"IFM A Shape: {shape.ifm[0]}, IFM B Shape: {shape.ifm[1]}, OFM Shape: {shape.ofm}")
    log(f"Num Phases: {num_phases}")
    log(f"subvolume IFM A: {mapping.ifm_subv[0]}")
    log(f"subvolume IFM B: {mapping.ifm_subv[1]}")
    log(f"subvolume OFM: {mapping.ofm_subv}")
    log(f"Spatial Split: {mapping.spatial_split}")

    for col in range(NUM_COLS):
        for row in range(NUM_ROWS):
            log(
                f"Core ({col}, {row}) OFM repeats per phase: {tiling_cfg.ofm_repeats[(col, row)]}")
    assert mapping.wgt_L1_size == WGT_SIZE

    # rename L2 info for readability
    prm_memtile_size = compute_buffer_size(prm_memtile_memory())
    wgt_memtile_size = dims.wgt_size if not mapping.wgt_L1_size else mapping.wgt_L1_size
    # Use the keys as the tile index
    wgt_L2_alloc_tiles = [entry[0] for entry in l2_alloc.wgt_l2_loc]
    wgt_L2_ping_addrs = [entry[1] for entry in l2_alloc.wgt_l2_loc]
    wgt_L2_pong_addrs = [entry[2] for entry in l2_alloc.wgt_l2_loc]
    # Use the keys as the tile index
    prm_L2_alloc_tiles = [entry[0] for entry in l2_alloc.prm_l2_loc]
    prm_L2_addrs = [entry[1] for entry in l2_alloc.prm_l2_loc]

    def generate_core_instructions(col: int, row: int):
        """Return the instruction list executed by core ``(col, row)``."""
        iterations = sum(tiling_cfg.ofm_repeats[(col, row)])
        log(f"# iterations for core ({col}, {row}): {iterations}")

        # add temp buffer if going from 8->16 bit
        ifm_a_tmp_size = ifm_a_size * 2 if shape.ifm_bytes == 1 else 0

        ifm_a_config = generate_core_buffer_config(
            core_dma(col, row, DmaDir.S2MM, 0),
            mapping.ifm_L1_ping_addr,
            mapping.ifm_L1_pong_addr,
            ifm_memory(mapping.ifm_subv[0], shape.ifm_bytes),
            ifm_s2mm(mapping.ifm_subv[0], shape.ifm_bytes),
            bits_per_block=shape.ifm_bits,
        )
        dq_padding_offset = 12*32*2
        # IFM B sits after IFM A, its optional temp buffer (both rounded up to 128 bytes)
        # and the dq padding. Computed once so the config and the log cannot disagree.
        ifm_b_l1_offset = iceil(ifm_a_size, 128) + iceil(ifm_a_tmp_size, 128) + dq_padding_offset
        # IFM A and IFM B share core S2MM channel 0; the channel is reconfigured inside the loop.
        ifm_b_config = generate_core_buffer_config(
            core_dma(col, row, DmaDir.S2MM, 0),
            mapping.ifm_L1_ping_addr + ifm_b_l1_offset,
            mapping.ifm_L1_pong_addr + ifm_b_l1_offset,
            ifm_memory(mapping.ifm_subv[1], shape.ifm_bytes),
            ifm_s2mm(mapping.ifm_subv[1], shape.ifm_bytes),
            bits_per_block=shape.ifm_bits,
        )

        log(f"IFM A CORE offsets: ping {mapping.ifm_L1_ping_addr}, pong {mapping.ifm_L1_pong_addr}, size {ifm_a_size}")
        # Report the addresses actually programmed into ifm_b_config.
        log(
            f"IFM B CORE offsets: ping {mapping.ifm_L1_ping_addr + ifm_b_l1_offset}, pong {mapping.ifm_L1_pong_addr + ifm_b_l1_offset}, size {ifm_b_size}"
        )

        log(f"mapping.ofm_L1_ping_addr: {mapping.ofm_L1_ping_addr}")
        ofm_config = generate_core_buffer_config(
            core_dma(col, row, DmaDir.MM2S, 0),
            mapping.ofm_L1_ping_addr,
            mapping.ofm_L1_pong_addr,
            ofm_memory(mapping.ofm_subv),
            ofm_s2mm(mapping.ofm_subv),
            bits_per_block=shape.ofm_bits,
        )
        # Every core receives the weight/qdq buffer once on S2MM channel 1.
        wgt_instrs = [
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 1),
                         mapping.wgt_L1_ping_addr, None, WGT_SIZE),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
        ]

        if not iterations:
            # Idle core: it only consumes the weight buffer.
            return wgt_instrs

        return wgt_instrs + [
            ofm_config,
            Loop(
                iterations,
                [
                    AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                    ifm_a_config,
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    ifm_b_config,
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    CallKernel(
                        shape.call_kernel,
                        generate_broadcast_params(
                            shape.ifm_bytes,
                            shape.ofm_bytes,
                            mapping.ifm_subv,
                            mapping.ofm_subv,
                            kernel_gran,
                            q_buf_offset,
                            dq_buf_offset,
                            shape.has_scalar_broadcast,
                            is_sub,
                            shape.sign_A,
                            shape.sign_W,
                            shape.sign_O,
                        ),
                    ),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                ],
            ),  # type: ignore
        ]

    instr_dict = {
        AieTile(TileType.Core, col, row): generate_core_instructions(col, row)
        for col in range(NUM_COLS)
        for row in range(NUM_ROWS)
    }

    # Layer parameters: shim MM2S channel 0 of every column.
    prm_shimtile_transfers = [
        generate_shim_data_transfer(
            layer_param_repeat_count,
            shim_dma(col, DmaDir.MM2S, 0),
            shim_alloc.prm_xrt_idx,
            prm_shim_memory(),
            prm_shim_mm2s(col),
            buffer_offset=shim_alloc.prm_xrt_offset,
        )
        for col in range(NUM_COLS)
    ]

    # Weights (qdq params + scratch): shim MM2S channel 1 of every column.
    wgt_shimtile_transfers = [
        generate_shim_data_transfer(
            layer_param_repeat_count,
            shim_dma(col, DmaDir.MM2S, 1),
            shim_alloc.wgt_xrt_idx,
            wgt_shim_memory(WGT_SIZE),
            wgt_shim_mm2s(WGT_SIZE),
            buffer_offset=shim_alloc.wgt_xrt_offset,
        )
        for col in range(NUM_COLS)
    ]
    # One write into the memtile, then one read per row.
    prm_memtile_transfers = [
        DataTransfer(
            layer_param_repeat_count,
            prm_L2_alloc_tiles[col],
            [prm_L2_addrs[col]],
            prm_memtile_size,
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                    prm_memtile_memory(),
                    prm_memtile_s2mm(),
                )
            ],
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                    prm_memtile_memory(),
                    prm_memtile_mm2s(row),
                )
                for row in range(NUM_ROWS)
            ],
            sync_strategy=SyncStrategy.Serial_M_to_N,
        )
        for col in range(NUM_COLS)
    ]

    # Weights are written once per column and read out on both B channels of that column.
    wgt_memtile_transfers = [
        DataTransfer(
            layer_param_repeat_count,
            wgt_L2_alloc_tiles[col],
            [wgt_L2_ping_addrs[col], wgt_L2_pong_addrs[col]],
            wgt_memtile_size,
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[1]),
                    wgt_memtile_memory(wgt_memtile_size),
                    wgt_memtile_s2mm(wgt_memtile_size),
                )
            ],
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, overlay_3x4_B_ids(col)[0]),
                    wgt_memtile_memory(wgt_memtile_size),
                    wgt_memtile_s2mm(wgt_memtile_size),
                ),
                generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, overlay_3x4_B_ids(col)[1]),
                    wgt_memtile_memory(wgt_memtile_size),
                    wgt_memtile_s2mm(wgt_memtile_size),
                ),
            ],
        )
        for col in range(NUM_COLS)
    ]

    ifm_mem_transfers, ifm_shim_transfers = generate_ifm_transfers(
        mapping,
        shape,
        dims,
        tiling_cfg,
        l2_alloc,
        fill_sync_repeat_count,
        ifm_fill_repeats,
        shim_alloc,
    )
    ofm_mem_transfers, ofm_shim_transfers = generate_ofm_transfers(
        mapping,
        shape,
        dims,
        tiling_cfg,
        l2_alloc,
        spill_sync_repeat_count,
        ofm_spill_repeats,
        shim_alloc,
    )

    memtile_transfers = prm_memtile_transfers + \
        wgt_memtile_transfers + ifm_mem_transfers + ofm_mem_transfers
    shimtile_transfers = prm_shimtile_transfers + \
        wgt_shimtile_transfers + ifm_shim_transfers + ofm_shim_transfers

    run_layer_compilation(
        OverlayShape(NUM_COLS, NUM_ROWS),
        schedule_input.kernel_names,
        schedule_input.kernel_includes,
        instr_dict,
        memtile_transfers,
        shimtile_transfers,
        overlay_3x4_dma_connections(),
        core_stack_addr=overlay_3x4_core_stack_addr(),
        param_channel_id=overlay_3x4_param_channel_id(),
        back_end=schedule_input.backend,
        layer_file=schedule_input.layer_file_name,
        enable_task_queue_optimization=False,  # there are issues with task queue optimizations for l3
    )

    shim_wgt_size = compute_buffer_size(
        wgt_shim_memory(WGT_SIZE),
    )
    shim_prm_size = compute_buffer_size(prm_shim_memory())

    # Advance the shim offsets past what this layer consumed.
    shim_prm_offset_next_layer = shim_alloc.prm_xrt_offset + shim_prm_size
    shim_wgt_offset_next_layer = shim_alloc.wgt_xrt_offset + shim_wgt_size
    return shim_prm_offset_next_layer, shim_wgt_offset_next_layer


def compile_L2_dataflow(
    schedule_input: ScheduleInputs,
) -> tuple:
    """Compile the layer using the L2 IFM/OFM transfer generators.

    Returns whatever ``compile_dataflow`` returns for ``schedule_input``.
    """
    next_layer_offsets = compile_dataflow(
        schedule_input,
        generate_ifm_transfers=generate_ifm_transfers_l2,
        generate_ofm_transfers=generate_ofm_transfers_l2,
    )
    return next_layer_offsets


def compile_L3_dataflow(
    schedule_input: ScheduleInputs,
) -> tuple:
    """Compile the layer using the L3 IFM/OFM transfer generators.

    Any L2 allocation on ``schedule_input`` is cleared first, so
    ``compile_dataflow`` takes its "no L2 allocation" path.
    Returns whatever ``compile_dataflow`` returns.
    """
    schedule_input.L2_alloc = None
    next_layer_offsets = compile_dataflow(
        schedule_input,
        generate_ifm_transfers=generate_ifm_transfers_l3,
        generate_ofm_transfers=generate_ofm_transfers_l3,
    )
    return next_layer_offsets


# Axis positions within a (Y, X, C) triple.
YDIM = 0
XDIM = 1
CDIM = 2


@pytest.mark.scheduler
class TestBroadcastTilings:
    """Coverage tests: the per-core tilings must touch every element of each tensor."""

    def fill_tensor_with_tiling(self, tensor: np.ndarray, tiling: TilingND):
        """Write 1 into every region of ``tensor`` addressed by ``tiling``.

        A dimension without a stride is visited once (its step is ``end - start``).
        If any dimension ends up with a zero step, nothing is written.
        """
        y, x, c = tiling.tilings  # type: ignore
        axes = (y, x, c)
        steps = [axis.stride if axis.stride else axis.end - axis.start for axis in axes]
        if not all(steps):
            return
        starts_per_axis = [
            range(axis.start + axis.core_offset, axis.end + axis.core_offset, step)
            for axis, step in zip(axes, steps)
        ]
        for yi, xi, ci in itertools.product(*starts_per_axis):
            tensor[
                yi: yi + y.num_elements,
                xi: xi + x.num_elements,
                ci: ci + c.num_elements,
            ] = 1

    def _assert_broadcast_fill(
        self,
        ofm,
        ofms,
        ifm_a,
        ifma_s,
        ifm_b,
        ifmb_s,
        split,
    ):
        """Assert that the tilings of all cores jointly cover the OFM and both IFMs.

        Each phase must either have exactly the subvolume size or be a dummy.
        """
        loops = [
            ceildiv(full, sub * parts)
            for full, sub, parts in zip(ofm, ofms, split)
        ]
        cfg = BroadcastTilingConfig(
            loops=list(loops),
            split=list(split),
            ofm=list(ofm),
            ofms=list(ofms),
            ifm=[list(ifm_a), list(ifm_b)],
        )

        # One coverage mask per tensor, initially all zeros.
        ofm_mask = np.zeros(tuple(ofm))
        ifm_a_mask = np.zeros(tuple(ifm_a))
        ifm_b_mask = np.zeros(tuple(ifm_b))

        for row, col in itertools.product(range(NUM_ROWS), range(NUM_COLS)):  # type: ignore
            out_phases = tilings(cfg, ofm, ofms, col=col, row=row)
            a_phases = tilings(cfg, ifm_a, ifma_s, col=col, row=row)
            b_phases = tilings(cfg, ifm_b, ifmb_s, col=col, row=row)

            assert len(out_phases) == len(
                a_phases) == len(b_phases)

            for out_phase, a_phase, b_phase in zip(out_phases, a_phases, b_phases):
                assert b_phase.subv_size() == math.prod(ifmb_s) or b_phase.is_dummy()
                assert a_phase.subv_size() == math.prod(ifma_s) or a_phase.is_dummy()
                assert out_phase.subv_size() == math.prod(ofms) or out_phase.is_dummy()
                self.fill_tensor_with_tiling(ofm_mask, out_phase)
                self.fill_tensor_with_tiling(ifm_a_mask, a_phase)
                self.fill_tensor_with_tiling(ifm_b_mask, b_phase)

        assert np.array_equal(ofm_mask, np.ones_like(ofm_mask))
        assert np.array_equal(ifm_a_mask, np.ones_like(ifm_a_mask))
        assert np.array_equal(ifm_b_mask, np.ones_like(ifm_b_mask))

    def test_binary_fill(self):
        """Same-shape inputs, once evenly divisible and once needing padding."""
        split = [1, 4, 3]
        for ifm in ([1, 32, 18], [1, 33, 19]):
            ifms = [1, 4, 3]
            # The OFM shares shape and subvolume with both inputs.
            self._assert_broadcast_fill(ifm, ifms, ifm, ifms, ifm, ifms, split)

    def test_broadcast_fill(self):
        """(1, 64, 1) combined with (1, 1, 3072), split 12 ways over C."""
        self._assert_broadcast_fill(
            ofm=[1, 64, 3072],
            ofms=[1, 64, 256],
            ifm_a=[1, 64, 1],
            ifma_s=[1, 64, 1],
            ifm_b=[1, 1, 3072],
            ifmb_s=[1, 1, 256],
            split=[1, 1, 12],
        )

    def test_broadcast_pattern_no_broadcast(self):
        """Both inputs already have the OFM shape."""
        self._assert_broadcast_fill(
            ofm=[1, 64, 3072],
            ofms=[1, 8, 512],
            ifm_a=[1, 64, 3072],
            ifma_s=[1, 8, 512],
            ifm_b=[1, 64, 3072],
            ifmb_s=[1, 8, 512],
            split=[1, 2, 6],
        )

    def test_broadcast_pattern_ifm_a_spatial(self):
        """IFM A has X == 1 and is broadcast along X."""
        self._assert_broadcast_fill(
            ofm=[1, 64, 3072],
            ofms=[1, 8, 512],
            ifm_a=[1, 1, 3072],
            ifma_s=[1, 1, 512],
            ifm_b=[1, 64, 3072],
            ifmb_s=[1, 8, 512],
            split=[1, 2, 6],
        )

    def test_broadcast_pattern_ifm_b_spatial(self):
        """IFM B has X == 1 and is broadcast along X."""
        self._assert_broadcast_fill(
            ofm=[1, 64, 3072],
            ofms=[1, 8, 512],
            ifm_a=[1, 64, 3072],
            ifma_s=[1, 8, 512],
            ifm_b=[1, 1, 3072],
            ifmb_s=[1, 1, 512],
            split=[1, 2, 6],
        )

    def test_broadcast_pattern_channel_b(self):
        """IFM B has C == 1 and is broadcast along the channel dimension."""
        self._assert_broadcast_fill(
            ofm=[1, 64, 3072],
            ofms=[1, 8, 512],
            ifm_a=[1, 64, 3072],
            ifma_s=[1, 8, 512],
            ifm_b=[1, 64, 1],
            ifmb_s=[1, 8, 1],
            split=[1, 2, 6],
        )

    def test_broadcast_pattern_spatial_a_channel_b(self):
        """IFM A is broadcast along X while IFM B is broadcast along the channel."""
        self._assert_broadcast_fill(
            ofm=[1, 64, 3072],
            ofms=[1, 8, 512],
            ifm_a=[1, 1, 3072],
            ifma_s=[1, 1, 512],
            ifm_b=[1, 64, 1],
            ifmb_s=[1, 8, 1],
            split=[1, 2, 6],
        )

    def test_broadcast_pattern_fully_broadcast(self):
        """IFM A has the OFM shape; IFM B is a single element."""
        self._assert_broadcast_fill(
            ofm=[1, 1, 3072],
            ofms=[1, 1, 512],
            ifm_a=[1, 1, 3072],
            ifma_s=[1, 1, 512],
            ifm_b=[1, 1, 1],
            ifmb_s=[1, 1, 1],
            split=[1, 2, 6],
        )


@pytest.mark.scheduler
class TestTilingDepadNormalize:

    def test_loop_depad_1d_normalize(self):
        """A looping, unpadded 1D tiling normalizes to a single stride-free subvolume at 0."""
        fm, fms = (65,), (8,)
        looped = Tiling1D(
            start=0, end=64, num_elements=8, stride=24,
            core_offset=0, max_dim=fm[0], num_repeats=2,
        )
        result = depad_tiling(None, fm, fms, TilingND((looped,)), normalize=True).tilings[0]
        assert result.start == 0, f"got {result.start}"
        assert result.end == 8, f"got {result.end}"
        assert result.num_elements == 8, f"got {result.num_elements}"
        assert not result.stride, f"got {result.stride}"

    def test_pad_depad_1d_normalize(self):
        """A tile overhanging the 65-element dim normalizes to one real element at 0."""
        fm, fms = (65,), (8,)
        overhanging = Tiling1D(
            start=64, end=128, num_elements=8, stride=None,
            core_offset=0, max_dim=fm[0], num_repeats=1,
        )
        result = depad_tiling(None, fm, fms, TilingND((overhanging,)), normalize=True).tilings[0]
        assert not result.is_dummy()
        assert result.start == 0, f"got {result.start}"
        assert result.end == 1, f"got {result.end}"
        assert result.num_elements == 1, f"got {result.num_elements}"
        assert not result.stride, f"got {result.stride}"

    def test_loop_depad_1d_no_normalize(self):
        """Without normalization an unpadded looping tiling keeps its range and stride."""
        fm, fms = (65,), (8,)
        looped = Tiling1D(
            start=0, end=64, num_elements=8, stride=24,
            core_offset=0, max_dim=fm[0], num_repeats=2,
        )
        result = depad_tiling(None, fm, fms, TilingND((looped,)), normalize=False).tilings[0]
        assert result.start == 0, f"got {result.start}"
        assert result.end == 64, f"got {result.end}"
        assert result.num_elements == 8, f"got {result.num_elements}"
        assert result.stride == 24, f"got {result.stride}"

    def test_pad_depad_1d_no_normalize(self):
        """Without normalization an overhanging tile is clipped to [64, 65) in place."""
        fm, fms = (65,), (8,)
        overhanging = Tiling1D(
            start=64, end=128, num_elements=8, stride=None,
            core_offset=0, max_dim=fm[0], num_repeats=1,
        )
        result = depad_tiling(None, fm, fms, TilingND((overhanging,)), normalize=False).tilings[0]
        assert not result.is_dummy()
        assert result.start == 64, f"got {result.start}"
        assert result.end == 65, f"got {result.end}"
        assert result.num_elements == 1, f"got {result.num_elements}"
        assert result.stride is None, f"got {result.stride}"

    def test_coreoffset_1d_no_normalize(self):
        """Without normalization a core offset of 8 is reflected in the depadded start."""
        fm, fms = (65,), (8,)
        shifted = Tiling1D(
            start=0, end=64, num_elements=8, stride=24,
            core_offset=8, max_dim=fm[0], num_repeats=2,
        )
        result = depad_tiling(None, fm, fms, TilingND((shifted,)), normalize=False).tilings[0]
        assert not result.is_dummy()
        assert result.start == 8, f"got {result.start}"
        assert result.end == 64, f"got {result.end}"
        assert result.num_elements == 8, f"got {result.num_elements}"
        assert result.stride == 24, f"got {result.stride}"

    def test_2d_depad_normalize_integrated(self):
        """Confirm a 2D tiling expression depads and normalizes correctly.

        The feature map is (1, 49, 65) with (1, 8, 8) subvolumes, 3 loops over
        each of X and C, and a [1, 3, 4] spatial split. The last phase of core
        (0, 0) is expected to depad/normalize to a single element per
        dimension, and the last phase of core (2, 3) is expected to be a dummy.

        Valid data enclosed in the box, phases separated by periods:

          phase 5                  phase 6
          30 31 32 30 31 32 . 30   31 32
          20 21 22 20 21 22 . 20   21 22
          10 11 12 10 11 12 . 10   11 12
        +------------------------+
        | 00 01 02 00 01 02 . 00 | 01 02
        | .......................|.....
        | phase 3           .    | phase 4
        | 30 31 32 30 31 32 . 30 | 31 32
        | 20 21 22 20 21 22 . 20 | 21 22
        | 10 11 12 10 11 12 . 10 | 11 12
        | 00 01 02 00 01 02 . 00 | 01 02
        | .......................| .....
        |   phase 1         .    | phase 2
        | 30 31 32 30 31 32 . 30 | 31 32
        | 20 21 22 20 21 22 . 20 | 21 22
        | 10 11 12 10 11 12 . 10 | 11 12
        | 00 01 02 00 01 02 . 00 | 01 02
        +------------------------+
        """
        fm, fms = (1, 49, 65), (1, 8, 8)
        cfg = BroadcastTilingConfig(
            loops=[1, 3, 3],
            ofm=list(fm),
            ofms=list(fms),
            split=[1, 3, 4],
            ifm=[list(fm), list(fm)],  # dummy values to initialize broadcasts to false
        )

        tilings_0_0 = tilings(cfg, fm, fms, col=0, row=0)
        # We still don't have https://gitenterprise.xilinx.com/IPSP/dmacompiler/pull/273
        # so we can't check the actual looped phases.
        # Let's just check the last tiling:
        padded_tile = tilings_0_0[-1]
        depad_tile = depad_tiling(None, fm, fms, padded_tile, normalize=True)
        assert not depad_tile.is_dummy()
        y_tiling, x_tiling, c_tiling = depad_tile.tilings

        # Y: expected to normalize to the single range [0, 1) with no stride.
        assert y_tiling.start == 0, f"got {y_tiling.start}"
        assert y_tiling.end == 1, f"got {y_tiling.end}"
        assert y_tiling.num_elements == 1, f"got {y_tiling.num_elements}"
        assert not y_tiling.stride, f"got {y_tiling.stride}"

        # X: a single element expected (presumably 49 - 2 * 3 * 8 = 1 left for
        # the last loop -- TODO confirm against tilings()).
        assert x_tiling.start == 0, f"got {x_tiling.start}"
        assert x_tiling.end == 1, f"got {x_tiling.end}"
        assert x_tiling.num_elements == 1, f"got {x_tiling.num_elements}"
        assert not x_tiling.stride, f"got {x_tiling.stride}"

        # C: a single element expected (presumably 65 - 2 * 4 * 8 = 1 left for
        # the last loop -- TODO confirm against tilings()).
        assert c_tiling.start == 0, f"got {c_tiling.start}"
        assert c_tiling.end == 1, f"got {c_tiling.end}"
        assert c_tiling.num_elements == 1, f"got {c_tiling.num_elements}"
        assert not c_tiling.stride, f"got {c_tiling.stride}"

        # Also check a dummy tiling
        # The last phase of core (2, 3) is expected to be flagged as dummy.
        tilings_2_3 = tilings(cfg, fm, fms, col=2, row=3)
        dummy_tile = tilings_2_3[-1]
        assert dummy_tile.is_dummy(), f"got {dummy_tile}"
