'''
Map conv shapes to the AIE-4 dataflow architecture.
External facing functions are documented below.

    compile_L2_dataflow - given a shape and mapping, compile the data movement
    code to orchestrate the DMA hierarchy
'''
from typing import no_type_check
from dmacompiler import (
    DevGen, set_dev_gen, SyncStrategy,
    OverlayShape, DataTransfer, TransferParams,
    AieDma, DmaDir, DmaChannel,
    core_dma, memtile_dma, shim_dma, core_tile,
    AcqBuffer, RelBuffer,
    run_layer_compilation,
    generate_transfer_params,
    generate_shim_data_transfer,
    generate_core_buffer_config,
    compute_buffer_size,
    CallKernel,
    pack_reconfig_transfers,
    generate_memtile_data_transfer,
    Loop,
    memory_tile,
)

from utils.utils_common import (
    overlay_3x4_core_stack_addr,
    BaseDims, log,
    split_to_mode,
    core_to_split,
    L2Alloc,
)
from scheduler.common import (
    overlay_3x4_dma_connections,
    overlay_3x4_param_channel_id,
    overlay_3x4_A_ids,
    overlay_3x4_F_ids,
    overlay_3x4_O_ids,
    overlay_3x4_S_ids,
    prm_memtile_memory,
    prm_shim_memory,
    prm_memtile_mm2s,
    prm_memtile_s2mm,
    prm_shim_mm2s,
    broadcast_channels,
    TensorSlicer,
    L3Alloc,
    unicast_channels,
    L3Alloc_to_Shim,
)

from scheduler.conv.conv_config_builders import (
    ConvShape,
    ConvMapping,
    ConvDims,
    conv_input
)

from buildscripts.common import ScheduleInputs

from kernel.conv.conv_noqdq_a8w8_params import generate_conv_noqdq_a8w8_params

set_dev_gen(DevGen.Aie4)


def ifm_memtile_channels(dims: ConvDims, col: int) -> list[tuple[int, int]]:
    '''Select the IFM memtile channel allocations (row, id) for column col.

    The spatial split mode of dims picks the allocation: mode 0 uses the
    unicast channels, mode 1 uses the broadcast channels of the column.
    Any other mode raises KeyError.
    '''
    split_mode = split_to_mode(dims)
    # Both candidate allocations are built up front and then indexed by mode.
    by_mode = {0: unicast_channels(), 1: broadcast_channels(col)}
    return by_mode[split_mode]


def wgt_memtile_channels(dims: ConvDims, col: int, enable_wgt_channel_sharing: bool, co_split_idx: int = 0) -> list[tuple[int, int]]:
    '''Select the WGT memtile channel allocations (row, id) for column col.

    The spatial split mode of dims picks the base allocation: mode 0 uses the
    broadcast channels of the column, mode 1 uses the unicast channels.
    Without channel sharing the base allocation is returned unchanged.
    With channel sharing only the channels of rows whose Co split index equals
    the co_split_idx-th distinct Co index of the column are returned.

    Raises ValueError when co_split_idx is not below the number of distinct
    Co indices in the column.
    '''
    split_mode = split_to_mode(dims)
    by_mode = {0: broadcast_channels(col), 1: unicast_channels()}
    candidates = by_mode[split_mode]
    if not enable_wgt_channel_sharing:
        return candidates

    # Distinct Co split indices handled by this column, in ascending order
    unique_co_idxs = sorted(set(Co_split_idxs(dims, col)))
    if not co_split_idx < len(unique_co_idxs):
        raise ValueError(f"co_split_idx {co_split_idx} out of range for unique Co indices {unique_co_idxs}")
    wanted_co = unique_co_idxs[co_split_idx]

    selected: list[tuple[int, int]] = []
    for row in range(dims.aie_rows):
        _, _, _, row_co = core_to_split(dims, col, row)
        if row_co != wanted_co:
            continue
        # Keep the first channel allocated to this row, if there is one
        hit = next(((ch_row, ch_id) for ch_row, ch_id in candidates if ch_row == row), None)
        if hit is not None:
            selected.append(hit)
    return selected


def ofm_memtile_channels() -> list[tuple[int, int]]:
    '''Pair each OFM channel ID of the overlay with its position as (row, id)'''
    return [(row, channel_id) for row, channel_id in enumerate(overlay_3x4_O_ids())]


def ifm_core_channel(dims: ConvDims) -> int:
    '''Pick the core channel ID used for IFM from the spatial split mode.

    Mode 0 maps to channel 0 and mode 1 to channel 1; any other mode
    raises KeyError.
    '''
    split_mode = split_to_mode(dims)
    return {0: 0, 1: 1}[split_mode]


def wgt_core_channel(dims: ConvDims) -> int:
    '''Pick the core channel ID used for WGT from the spatial split mode.

    Mode 0 maps to channel 1 and mode 1 to channel 0 (the opposite of the
    IFM assignment); any other mode raises KeyError.
    '''
    split_mode = split_to_mode(dims)
    return {0: 1, 1: 0}[split_mode]


# Backward compatibility functions - these use the TensorSlicer class
def Yi_slice(dims: BaseDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''Yi slice for core (col, row) at iteration i of Y_loop, delegated to TensorSlicer'''
    return TensorSlicer(dims).Yi_slice(col, row, i)


def Xi_slice(dims: BaseDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''Xi slice for core (col, row) at iteration i of X_loop, delegated to TensorSlicer'''
    return TensorSlicer(dims).Xi_slice(col, row, i)


def Yo_slice(dims: BaseDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''Yo slice for core (col, row) at iteration i of Y_loop, delegated to TensorSlicer'''
    return TensorSlicer(dims).Yo_slice(col, row, i)


def Xo_slice(dims: BaseDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''Xo slice for core (col, row) at iteration i of X_loop, delegated to TensorSlicer'''
    return TensorSlicer(dims).Xo_slice(col, row, i)


def Co_slice(dims: BaseDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''Co slice for core (col, row) at iteration i of Co_loop, delegated to TensorSlicer'''
    return TensorSlicer(dims).Co_slice(col, row, i)


def Co_split_idxs(dims: ConvDims, col: int) -> list[int]:
    '''Collect the Cout split index of every core in a column, ordered by row'''
    per_row: list[int] = []
    for row in range(dims.aie_rows):
        # Only the last field of the core split tuple is the Cout index
        _, _, _, co_idx = core_to_split(dims, col, row)
        per_row.append(co_idx)
    return per_row


def Co_split_size(dims: ConvDims) -> int:
    '''Number of Cout split blocks covered by a single column.

    Computed as the span (max - min + 1) of the Cout split indices of the
    cores in column zero.
    '''
    # NOTE: The Cout split size is regular across all columns,
    # so column zero is representative.
    column_zero = Co_split_idxs(dims, 0)
    highest = max(column_zero)
    lowest = min(column_zero)
    return highest - lowest + 1


def Co_split_offset(dims: ConvDims, col: int, row: int) -> int:
    '''Cout split index of core (col, row) relative to the smallest index in its column'''
    column_idxs = Co_split_idxs(dims, col)
    _, _, _, core_co_idx = core_to_split(dims, col, row)
    return core_co_idx - min(column_idxs)


#
# IFM Memory Formats
#


def ifm_core_memory(dims: ConvDims) -> str:
    '''IFM L1 memory format string: 'Ci:<Cis> Yi:<Yis> Xi:<Xis> Ci:<Ci_gran>' '''
    layout = (
        ('Ci', dims.Cis),
        ('Yi', dims.Yis),
        ('Xi', dims.Xis),
        ('Ci', dims.Ci_gran),
    )
    return ' '.join(f'{axis}:{extent}' for axis, extent in layout)


def ifm_memtile_memory(dims: ConvDims) -> str:
    '''IFM L2 memory format string: 'Yi:<Yi> Xi:<Xi> Ci:<Ci>' '''
    layout = (('Yi', dims.Yi), ('Xi', dims.Xi), ('Ci', dims.Ci))
    return ' '.join(f'{axis}:{extent}' for axis, extent in layout)


def ifm_shim_memory(shape: ConvShape) -> str:
    '''IFM DDR memory format string built from the (Yi, Xi, Ci) triple in shape.ifm'''
    yi, xi, ci = shape.ifm
    return ' '.join((f'Yi:{yi}', f'Xi:{xi}', f'Ci:{ci}'))


#
# IFM Access Patterns
#


def ifm_core_s2mm(dims: ConvDims, col: int, row: int) -> str:
    '''IFM core s2mm access pattern for the tile at (col, row).

    The Yi and Xi extents are the lengths of the iteration-0 slices of the
    core; the Ci terms come straight from dims.Cis and dims.Ci_gran.
    '''
    # NOTE: The core access pattern is required not to change in time,
    # so the iteration-0 slices are sufficient for the size calculation.
    y_lo, y_hi, _ = Yi_slice(dims, col, row, 0)
    x_lo, x_hi, _ = Xi_slice(dims, col, row, 0)
    terms = (
        f'Ci:0:{dims.Cis}:{dims.Ci_gran}',
        f'Yi:0:{y_hi - y_lo}',
        f'Xi:0:{x_hi - x_lo}',
        f'Ci:0:{dims.Ci_gran}',
    )
    return ' '.join(terms)


def ifm_memtile_mm2s(dims: ConvDims, col: int, row: int, i: int, j: int) -> list[str]:
    '''IFM memtile mm2s access pattern chain for core (col, row) at iteration (i, j).

    Returns one access pattern string per Ci segment. When Ci_loop * Cis does
    not exceed Ci there is a single segment covering every Ci iteration.
    Otherwise the traversal is split into up to three segments: the leading
    subvolumes, the in-range part of the final subvolume, and the remaining
    part of the final subvolume up to its padded end.
    '''
    y_lo, y_hi, _ = Yi_slice(dims, col, row, i)
    x_lo, x_hi, _ = Xi_slice(dims, col, row, j)

    total_ci = dims.Ci
    subvol_ci = dims.Cis
    gran = dims.Ci_gran
    ci_iters = dims.Ci_loop

    # Each segment is (inner Ci start, inner Ci stop, Ci step size, repeat count)
    segments: list[tuple[int, int, int, int]] = []
    if ci_iters * subvol_ci <= total_ci:
        # No padding: one regular traversal over all Ci iterations
        segments.append((0, gran, subvol_ci, ci_iters))
    else:
        last_start = subvol_ci * (ci_iters - 1)
        last_data_stop = ((total_ci - 1) // gran) * gran
        last_padded_stop = last_start + subvol_ci
        if ci_iters > 1:
            # traverse unpadded subvolumes
            segments.append((0, gran, subvol_ci, ci_iters - 1))
        if last_data_stop > last_start:
            # traverse unpadded data in final subvolume
            segments.append((last_start, last_start + gran, last_data_stop - last_start, 1))
        if last_padded_stop > last_data_stop:
            # traverse padded data in final subvolume
            segments.append((last_data_stop, last_padded_stop, gran, 1))

    chain: list[str] = []
    for seg_start, seg_stop, step, repeats in segments:
        chain.append(
            f'Ci:0:{repeats * step}:{step} '
            f'Ci:0:{step}:{dims.Ci_gran} '
            f'Yi:{y_lo}:{y_hi} '
            f'Xi:{x_lo}:{x_hi} '
            f'Ci:{seg_start}:{seg_stop}'
        )
    return chain


def ifm_memtile_s2mm(shape: ConvShape) -> str:
    '''IFM memtile s2mm access pattern spanning the full (Yi, Xi, Ci) extents of shape.ifm'''
    # NOTE: this only runs on one column to fill the full IFM buffer
    yi, xi, ci = shape.ifm
    return ' '.join((f'Yi:0:{yi}', f'Xi:0:{xi}', f'Ci:0:{ci}'))


def ifm_shim_mm2s(shape: ConvShape) -> str:
    '''IFM shim mm2s access pattern spanning the full (Yi, Xi, Ci) extents of shape.ifm'''
    # NOTE: this only runs on one column to fill the full IFM buffer
    yi, xi, ci = shape.ifm
    return ' '.join((f'Yi:0:{yi}', f'Xi:0:{xi}', f'Ci:0:{ci}'))


#
# WGT Memory Formats
#


def wgt_core_memory(dims: ConvDims) -> str:
    '''WGT L1 memory format string: a single Subv axis of size dims.wgt_L1_size'''
    subvolume_size = dims.wgt_L1_size
    return 'Subv:' + f'{subvolume_size}'


def wgt_memtile_memory(dims: ConvDims) -> str:
    '''WGT L2 memory format string: Co_split_size(dims) Cob blocks of one subvolume each'''
    blocks = Co_split_size(dims)
    subvolume_size = dims.wgt_L1_size
    return ' '.join((f'Cob:{blocks}', f'Subv:{subvolume_size}'))


def wgt_shim_memory(dims: ConvDims) -> str:
    '''WGT DDR memory format string: 'Cib:<Ci_loop> Cob:<Co_loop * Co_split> Subv:<wgt_L1_size>' '''
    # NOTE: The weights are pre-formatted in blocks according to the subvolume size.
    # Traversal goes first along Cout, then along Cin. This is intentional: the
    # Cout traversal for Co_split can then be folded into a linear access, which
    # is required to keep the shim access pattern within five dimensions.
    cin_blocks = dims.Ci_loop
    cout_blocks = dims.Co_loop * dims.Co_split
    layout = (('Cib', cin_blocks), ('Cob', cout_blocks), ('Subv', dims.wgt_L1_size))
    return ' '.join(f'{axis}:{extent}' for axis, extent in layout)


#
# WGT Access Patterns
#

def wgt_core_s2mm(dims: ConvDims) -> str:
    '''WGT core s2mm access pattern covering the whole subvolume (Subv:0:<wgt_L1_size>)'''
    subvolume_size = dims.wgt_L1_size
    return 'Subv:0:' + f'{subvolume_size}'


def wgt_memtile_s2mm(dims: ConvDims, enable_wgt_channel_sharing: bool, co_split_idx: int = 0) -> str:
    '''WGT memtile s2mm access pattern.

    Without channel sharing the pattern covers every Cob block of the column.
    With channel sharing it covers only the single Cob block at co_split_idx.

    Raises ValueError when channel sharing is enabled and co_split_idx is not
    below Co_split_size(dims).
    '''
    co_split_size = Co_split_size(dims)
    if not enable_wgt_channel_sharing:
        return f'Cob:0:{co_split_size} Subv:0:{dims.wgt_L1_size}'
    if co_split_idx >= co_split_size:
        raise ValueError(f"co_split_idx {co_split_idx} out of range for Co_split_size {co_split_size}")
    # One Cob block, selected by the split index
    first_block = co_split_idx
    return f'Cob:{first_block}:{first_block + 1} Subv:0:{dims.wgt_L1_size}'


def wgt_memtile_mm2s(dims: ConvDims, col: int, row: int, enable_wgt_channel_sharing: bool, co_split_idx: int = 0) -> str:
    '''WGT memtile mm2s access pattern routing to the core at (col, row).

    Without channel sharing the Cob block is chosen by the Cout split offset
    of the core within its column. With channel sharing the Cob block is
    co_split_idx itself, and row does not influence the result.

    Raises ValueError when channel sharing is enabled and co_split_idx is not
    below the number of distinct Co indices in the column.
    '''
    if not enable_wgt_channel_sharing:
        block = Co_split_offset(dims, col, row)
        return f'Cob:{block}:{block + 1} Subv:0:{dims.wgt_L1_size}'

    Co_idxs = Co_split_idxs(dims, col)
    unique_co_idxs = sorted(set(Co_idxs))
    log(f"col {col}, Co_idxs: {Co_idxs}, unique_co_idxs: {unique_co_idxs}, co_split_idx: {co_split_idx}")
    if not co_split_idx < len(unique_co_idxs):
        raise ValueError(f"co_split_idx {co_split_idx} out of range for unique Co indices {unique_co_idxs}")
    # The split index is used directly as the Cob block position
    return f'Cob:{co_split_idx}:{co_split_idx + 1} Subv:0:{dims.wgt_L1_size}'


def wgt_shim_mm2s(dims: ConvDims, col: int, enable_wgt_channel_sharing: bool, co_split_idx: int = 0) -> str:
    '''WGT shim mm2s access pattern for a column.

    Without channel sharing the inner Cob range spans every Cout split index
    of the column (min through max). With channel sharing it selects the
    single Cout split index found at position co_split_idx in the sorted list
    of distinct indices of the column.

    Raises ValueError when channel sharing is enabled and co_split_idx is not
    below the number of distinct Co indices in the column.
    '''
    # Cout split index of every core in this column (may contain duplicates)
    Co_idxs = Co_split_idxs(dims, col)
    if not enable_wgt_channel_sharing:
        Cob_start = min(Co_idxs)
        Cob_stop = max(Co_idxs) + 1
    else:
        unique_co_idxs = sorted(set(Co_idxs))
        # The bound must be the de-duplicated list that is indexed below,
        # otherwise an out-of-range index surfaces as IndexError instead of
        # the documented ValueError.
        if not co_split_idx < len(unique_co_idxs):
            raise ValueError(f"co_split_idx {co_split_idx} out of range for unique Co indices {unique_co_idxs}")
        Cob_start = unique_co_idxs[co_split_idx]
        Cob_stop = Cob_start + 1
    return (
        f'Cob:0:{dims.Co_loop * dims.Co_split}:{dims.Co_split} '
        f'Cib:0:{dims.Ci_loop} '
        f'Cob:{Cob_start}:{Cob_stop} '
        f'Subv:0:{dims.wgt_L1_size}'
    )

#
# OFM Memory Formats
#


def ofm_core_memory(dims: ConvDims) -> str:
    '''OFM L1 memory format string: 'Yo:<Yos> Xo:<Xos> Co:<Cos>'.

    Asserts that dims.Cos equals dims.Co_gran.
    '''
    assert dims.Cos == dims.Co_gran
    layout = (('Yo', dims.Yos), ('Xo', dims.Xos), ('Co', dims.Cos))
    return ' '.join(f'{axis}:{extent}' for axis, extent in layout)


def ofm_memtile_memory(dims: ConvDims) -> str:
    '''OFM L2 memory format string: 'Yo:<Yo> Xo:<Xo> Co:<Co>' '''
    layout = (('Yo', dims.Yo), ('Xo', dims.Xo), ('Co', dims.Co))
    return ' '.join(f'{axis}:{extent}' for axis, extent in layout)


def ofm_shim_memory(shape: ConvShape) -> str:
    '''OFM DDR memory format string built from the (Yo, Xo, Co) triple in shape.ofm'''
    yo, xo, co = shape.ofm
    return ' '.join((f'Yo:{yo}', f'Xo:{xo}', f'Co:{co}'))


#
# OFM Access Patterns
#


def ofm_core_mm2s(dims: ConvDims, col: int, row: int) -> str:
    '''OFM core mm2s access pattern for the tile at (col, row).

    Each axis runs from 0 to the length of the iteration-0 slice of the core.
    '''
    # NOTE: The OFM access pattern is required to be repeatable,
    # so the start/stop of iteration 0 are representative.
    y_lo, y_hi, _ = Yo_slice(dims, col, row, 0)
    x_lo, x_hi, _ = Xo_slice(dims, col, row, 0)
    c_lo, c_hi, _ = Co_slice(dims, col, row, 0)
    extents = (('Yo', y_hi - y_lo), ('Xo', x_hi - x_lo), ('Co', c_hi - c_lo))
    return ' '.join(f'{axis}:0:{extent}' for axis, extent in extents)


def ofm_memtile_s2mm(dims: ConvDims, col: int, row: int) -> str:
    '''OFM memtile s2mm access pattern for the channel fed by the core at (col, row).

    The first three terms step over the Y, X and Co loops using the slice
    strides; the last three terms are the iteration-0 slices of the core.
    '''
    y_lo, y_hi, y_step = Yo_slice(dims, col, row, 0)
    x_lo, x_hi, x_step = Xo_slice(dims, col, row, 0)
    c_lo, c_hi, c_step = Co_slice(dims, col, row, 0)
    loop_terms = [
        f'Yo:0:{dims.Y_loop * y_step}:{y_step}',
        f'Xo:0:{dims.X_loop * x_step}:{x_step}',
        f'Co:0:{dims.Co_loop * c_step}:{c_step}',
    ]
    slice_terms = [
        f'Yo:{y_lo}:{y_hi}',
        f'Xo:{x_lo}:{x_hi}',
        f'Co:{c_lo}:{c_hi}',
    ]
    return ' '.join(loop_terms + slice_terms)


def ofm_memtile_mm2s(shape: ConvShape) -> str:
    '''OFM memtile mm2s access pattern spanning the full (Yo, Xo, Co) extents of shape.ofm'''
    # NOTE: this only runs on one column to spill the full OFM buffer
    yo, xo, co = shape.ofm
    return ' '.join((f'Yo:0:{yo}', f'Xo:0:{xo}', f'Co:0:{co}'))


def ofm_shim_s2mm(shape: ConvShape) -> str:
    '''OFM shim s2mm access pattern spanning the full (Yo, Xo, Co) extents of shape.ofm'''
    # NOTE: this only runs on one column to spill the full OFM buffer
    yo, xo, co = shape.ofm
    return ' '.join((f'Yo:0:{yo}', f'Xo:0:{xo}', f'Co:0:{co}'))


def generate_conv_core_instrs(
    dims: ConvDims,
    mapping: ConvMapping,
    col: int,
    row: int,
    full_iters: bool = True,
) -> list:
    '''Generate conv core instructions for core (col, row).

    The returned list always starts with the IFM, WGT and OFM buffer
    configurations. With full_iters set, the acquire/release steps are written
    out explicitly for all Y_loop * X_loop * Co_loop outer iterations and one
    kernel call is appended at the end. Otherwise the steps are wrapped in
    nested Loop instructions over Y_loop and X_loop, with the kernel call
    placed first inside the X loop body.
    '''
    ifm_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.S2MM, ifm_core_channel(dims)),
        mapping.ifm_L1_ping_addr, mapping.ifm_L1_pong_addr,
        ifm_core_memory(dims),
        ifm_core_s2mm(dims, col, row),
    )
    wgt_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.S2MM, wgt_core_channel(dims)),
        mapping.wgt_L1_ping_addr, mapping.wgt_L1_pong_addr,
        wgt_core_memory(dims),
        wgt_core_s2mm(dims),
    )
    ofm_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.MM2S, 0),
        mapping.ofm_L1_ping_addr, mapping.ofm_L1_pong_addr,
        ofm_core_memory(dims),
        ofm_core_mm2s(dims, col, row),
    )
    buffer_configs = [ifm_config, wgt_config, ofm_config]

    # Channel builders: every call yields a fresh DmaChannel object
    def wgt_in():
        return DmaChannel(DmaDir.S2MM, wgt_core_channel(dims))

    def ifm_in():
        return DmaChannel(DmaDir.S2MM, ifm_core_channel(dims))

    def ofm_out():
        return DmaChannel(DmaDir.MM2S, 0)

    def input_only_step():
        # Acquire then release both input buffers (WGT first, then IFM)
        return [
            AcqBuffer(wgt_in(), disable=True),
            AcqBuffer(ifm_in(), disable=True),
            RelBuffer(wgt_in(), disable=True),
            RelBuffer(ifm_in(), disable=True),
        ]

    def input_and_output_step():
        # Same as the input-only step, with the OFM buffer acquired and released too
        return [
            AcqBuffer(wgt_in(), disable=True),
            AcqBuffer(ifm_in(), disable=True),
            AcqBuffer(ofm_out(), disable=True),
            RelBuffer(wgt_in(), disable=True),
            RelBuffer(ifm_in(), disable=True),
            RelBuffer(ofm_out(), disable=True),
        ]

    def kernel_call():
        return CallKernel(
            'run_conv_noqdq_a8w8',
            generate_conv_noqdq_a8w8_params(
                dims.Ci,
                dims.Yos, dims.Xos, dims.Cos,
                dims.Yis, dims.Xis, dims.Cis,
                dims.Ky, dims.Kx,
                dims.Sy, dims.Sx,
                dims.Y_loop, dims.X_loop, dims.Co_loop, dims.Ci_loop,
                ifm_core_channel(dims),
                full_iters,
            ),
        )

    if full_iters:
        # Fully unrolled schedule: (Ci_loop - 1) input-only steps followed by
        # one input-and-output step, repeated for every outer iteration.
        unrolled = []
        for _ in range(dims.Y_loop * dims.X_loop * dims.Co_loop):
            for _ in range(dims.Ci_loop - 1):
                unrolled.extend(input_only_step())
            unrolled.extend(input_and_output_step())
        return buffer_configs + unrolled + [kernel_call()]

    # NOTE(review): the looped form nests only Y_loop and X_loop, whereas the
    # unrolled form above also multiplies by Co_loop - confirm this is intended.
    return buffer_configs + [
        Loop(dims.Y_loop, [
            Loop(dims.X_loop, [
                kernel_call(),
                Loop(dims.Ci_loop - 1, input_only_step()),
                *input_and_output_step(),
            ]),
        ]),
    ]


def generate_memtile_ifm_mm2s(
    dims: ConvDims,
    col: int,
    row: int,
    channel_id: int,
    name: str = '',
) -> list[TransferParams]:
    '''Generate IFM memtile mm2s access patterns for core (col, row).

    One packed reconfigurable transfer is produced per entry of the access
    pattern chain returned by ifm_memtile_mm2s. Each transfer has
    2 + Y_loop * X_loop + 2 phases: the first two and last two phases use the
    placeholder pattern "Ci:0:0", and the phases in between carry the pattern
    of the matching (Y, X) iteration. The memory format is the same in every
    phase.
    '''
    setup_phases = 2
    teardown_phases = 2
    total_phases = setup_phases + (dims.Y_loop * dims.X_loop) + teardown_phases

    # The chain length is taken from core (0, 0) at iteration (0, 0)
    chain_len = len(ifm_memtile_mm2s(dims, 0, 0, 0, 0))

    # Evaluate the chain once per (Y, X) iteration instead of once per
    # (chain entry, iteration) pair; the result does not depend on the entry.
    iteration_chains = [
        ifm_memtile_mm2s(dims, col, row, i, j)
        for i in range(dims.Y_loop)
        for j in range(dims.X_loop)
    ]
    memory_format = ifm_memtile_memory(dims)

    packed_transfers = []
    for chain_idx in range(chain_len):
        # Fresh lists per transfer so no two transfers share list objects
        memory_formats = [memory_format] * total_phases
        access_patterns = ["Ci:0:0"] * total_phases
        for iteration_idx, chain in enumerate(iteration_chains):
            # Iterations follow the setup phases, in row-major (Y, X) order
            access_patterns[setup_phases + iteration_idx] = chain[chain_idx]
        packed_transfers.append(pack_reconfig_transfers(
            memtile_dma(col, DmaDir.MM2S, channel_id),
            memory_formats,
            access_patterns,
            bits_per_elem=dims.ifm_bits,
            name=name,
        ))

    return packed_transfers


# NOTE: Mypy doesn't correctly infer types with the abstract base construct
# used by the core instruction list, so we disable type checking locally here
@no_type_check
def compile_L2_dataflow(schedule_input: ScheduleInputs) -> tuple:
    '''Compile dataflow with given mapping information'''
    shape: ConvShape = schedule_input.shape
    mapping: ConvMapping = schedule_input.mapping
    L2_alloc: L2Alloc = schedule_input.L2_alloc
    L3_alloc: L3Alloc = schedule_input.L3_alloc
    # NOTE: Check if L3 allocation is provided
    enable_over_compute = shape.enable_over_compute
    conv_shim = L3Alloc_to_Shim(L3_alloc)
    log(f"Shim Allocator: {conv_shim}")

    aie_cols = 3
    aie_rows = 4
    enable_wgt_channel_sharing = True
    log(f"enable_wgt_channel_sharing: {enable_wgt_channel_sharing}")
    if enable_over_compute:
        Yo_pad, Xo_pad, Co_pad = mapping.ofm_pad
        Yi_pad = conv_input(Yo_pad, shape.kernel[0], shape.stride[0])
        Xi_pad = conv_input(Xo_pad, shape.kernel[1], shape.stride[1])
        padded_shape = ConvShape(
            ifm=(Yi_pad, Xi_pad, shape.ifm[2]),
            ofm=(Yo_pad, Xo_pad, Co_pad),
            kernel=shape.kernel,
            stride=shape.stride,
            padding=shape.padding,
            vector_coeff=shape.vector_coeff,
            ifm_bits=shape.ifm_bits,
            ofm_bits=shape.ofm_bits,
            linear_op_type=shape.linear_op_type,
            enable_over_compute=shape.enable_over_compute,
            wgt_bits=shape.wgt_bits,
            bias_bits=shape.bias_bits,
            sign_A=shape.sign_A,
            sign_W=shape.sign_W,
            sign_O=shape.sign_O,
            group=shape.group,
        )
    else:
        padded_shape = ConvShape(
            ifm=shape.ifm,
            ofm=shape.ofm,
            kernel=shape.kernel,
            stride=shape.stride,
            padding=shape.padding,
            vector_coeff=shape.vector_coeff,
            ifm_bits=shape.ifm_bits,
            ofm_bits=shape.ofm_bits,
            linear_op_type=shape.linear_op_type,
            enable_over_compute=shape.enable_over_compute,
            wgt_bits=shape.wgt_bits,
            bias_bits=shape.bias_bits,
            sign_A=shape.sign_A,
            sign_W=shape.sign_W,
            sign_O=shape.sign_O,
            group=shape.group,
        )

    dims = ConvDims(padded_shape, mapping, enable_over_compute=enable_over_compute)
    log(f"ConvMapping: {mapping}")
    log(f"ConvDims: {dims}")
    if dims.Co_split in [2, 4, 6]:
        # NOTE: for these Co_split cases, each column will fetch multiple subvols along Cout dimension.
        # We can parallelize the shim -> memtile weight transfers across multiple streams.
        log("enable Multiple streams for SHIM-WGT")

    wgt_shim_size = compute_buffer_size(wgt_shim_memory(dims))
    prm_shim_size = compute_buffer_size(prm_shim_memory())
    log(f"wgt_shim_size: {wgt_shim_size}")
    log(f"prm_shim_size: {prm_shim_size}")

    ifm_L2_size = compute_buffer_size(ifm_memtile_memory(dims), dims.ifm_bits)
    wgt_L2_size = compute_buffer_size(wgt_memtile_memory(dims))
    ofm_L2_size = compute_buffer_size(ofm_memtile_memory(dims), dims.ofm_bits)
    prm_L2_size = compute_buffer_size(prm_memtile_memory())
    ifm_L2_tile, ifm_L2_addr = L2_alloc.ifm_L2_loc
    ofm_L2_tile, ofm_L2_addr = L2_alloc.ofm_L2_loc
    log(f"wgt_L2_size: {wgt_L2_size}")
    log(f"ifm_L2_tile: {ifm_L2_tile}, ifm_L2_addr: {ifm_L2_addr}, ifm_L2_size: {ifm_L2_size}")
    log(f"ofm_L2_tile: {ofm_L2_tile}, ofm_L2_addr: {ofm_L2_addr}, ofm_L2_size: {ofm_L2_size}")
    wgt_L2_alloc_tiles = [entry[0] for entry in L2_alloc.wgt_l2_loc]
    log(f"wgt_L2_alloc_tiles: {[tile.col for tile in wgt_L2_alloc_tiles]}")
    wgt_L2_ping_addrs = [entry[1] for entry in L2_alloc.wgt_l2_loc]
    log(f"wgt_ping_addrs: {wgt_L2_ping_addrs}")
    wgt_L2_pong_addrs = [entry[2] for entry in L2_alloc.wgt_l2_loc]
    log(f"wgt_L2_pong_addrs: {wgt_L2_pong_addrs}")
    prm_L2_alloc_tiles = [entry[0] for entry in L2_alloc.prm_l2_loc]
    prm_L2_addrs = [entry[1] for entry in L2_alloc.prm_l2_loc]
    log(f"prm_L2_alloc_tiles: {[tile.col for tile in prm_L2_alloc_tiles]}")
    log(f"prm_L2_addr: {prm_L2_addrs}")
    log(f"enable_ifm_fill: {L2_alloc.enable_ifm_fill}")
    log(f"enable_ofm_spill: {L2_alloc.enable_ofm_spill}")

    overlay_shape = OverlayShape(aie_cols, aie_rows)

    dma_connections = overlay_3x4_dma_connections()

    core_instrs = {}
    for col in range(aie_cols):
        for row in range(aie_rows):
            core_instrs[core_tile(col, row)] = generate_conv_core_instrs(
                dims, mapping, col, row, full_iters=True,
            )

    memtile_transfers = [
        DataTransfer(
            [0, 0] + [1 if (i == 0) and (j == 0) else 0
                      for i in range(dims.Y_loop) for j in range(dims.X_loop)] + [0, 0],
            prm_L2_alloc_tiles[col], [prm_L2_addrs[col]], prm_L2_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                prm_memtile_memory(),
                prm_memtile_s2mm(),
                name=f"lp_col_{col}"
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                prm_memtile_memory(),
                prm_memtile_mm2s(row),
                name=f"lp_col_{col}"
            ) for row in range(aie_rows)],
            sync_strategy=SyncStrategy.Serial_M_to_N,
        ) for col in range(aie_cols)
    ] + (([
        generate_memtile_data_transfer(
            [1, 0] + [0 for _ in range(dims.Y_loop) for _ in range(dims.X_loop)] + [0, 0],
            AieDma(memory_tile(1), DmaChannel(DmaDir.S2MM, overlay_3x4_F_ids()[0])),
            ifm_L2_tile, ifm_L2_addr,
            ifm_memtile_memory(dims),
            ifm_memtile_s2mm(shape),
            dims.ifm_bits,
            name="ifm_fill_1"
        )
    ] + [
        # NOTE: This is a dummy transfer to create a sync barrier
        # between IFM s2mm and mm2s for the Shim -> Mem fill transfer
        DataTransfer(
            [0, 1] + [0 for _ in range(dims.Y_loop) for _ in range(dims.X_loop)] + [0, 0],
            ifm_L2_tile, [ifm_L2_addr], 0,
            [TransferParams(
                AieDma(memory_tile(1), DmaChannel(DmaDir.S2MM, overlay_3x4_F_ids()[0])), 0, name=f"ifm_fill_sync_{col}")],
            [],
            sync_strategy=SyncStrategy.Remote_Barrier
        )
    ]) if L2_alloc.enable_ifm_fill else []) + [
        DataTransfer(
            [0, 0] + [dims.Co_loop
                      for _ in range(dims.Y_loop) for _ in range(dims.X_loop)] + [0, 0],
            ifm_L2_tile, [ifm_L2_addr], ifm_L2_size,
            [],
            generate_memtile_ifm_mm2s(dims, col, row, channel_id, name=f"ifm_core_{col}"),
        ) for col in range(dims.aie_cols)
        for row, channel_id in ifm_memtile_channels(dims, col)
    ] + [
        generate_memtile_data_transfer(
            [0, 0] + [1 if (i == 0) and (j == 0) else 0
                      for i in range(dims.Y_loop) for j in range(dims.X_loop)] + [0, 0],
            memtile_dma(col, DmaDir.S2MM, channel_id),
            ofm_L2_tile, ofm_L2_addr,
            ofm_memtile_memory(dims),
            ofm_memtile_s2mm(dims, col, row),
            dims.ofm_bits,
            name=f"ofm_core_{col}"
        ) for col in range(dims.aie_cols)
        for row, channel_id in ofm_memtile_channels()
    ] + (([
        # NOTE: This is a dummy transfer to create a sync barrier
        # between OFM s2mm and mm2s for the Mem -> Shim spill transfer
        DataTransfer(
            [0, 0] + [0 for _ in range(dims.Y_loop) for _ in range(dims.X_loop)] + [1, 0],
            ofm_L2_tile, [ofm_L2_addr], ofm_L2_size,
            [TransferParams(memtile_dma(col, DmaDir.S2MM, channel_id), 0, name=f"ofm_spill_sync_{col}")
             for _, channel_id in ofm_memtile_channels()],
            [],
            sync_strategy=SyncStrategy.Remote_Barrier if col == 0 else SyncStrategy.Default
        ) for col in range(dims.aie_cols)
    ] + [
        generate_memtile_data_transfer(
            [0, 0] + [0 for _ in range(dims.Y_loop) for _ in range(dims.X_loop)] + [0, 1],
            AieDma(memory_tile(1), DmaChannel(DmaDir.MM2S, overlay_3x4_S_ids(1)[0])),
            ofm_L2_tile, ofm_L2_addr,
            ofm_memtile_memory(dims),
            ofm_memtile_mm2s(shape),
            dims.ofm_bits,
            name="ofm_spill_1"
        )
    ]) if L2_alloc.enable_ofm_spill else [])
    wgt_memtile_transfers = []
    for col in range(aie_cols):
        if not enable_wgt_channel_sharing:
            wgt_memtile_transfers.append(
                DataTransfer(
                    [0, 0] + [(dims.Y_loop * dims.X_loop * dims.Co_loop * dims.Ci_loop)
                              if (i == 0) and (j == 0) else 0
                              for i in range(dims.Y_loop) for j in range(dims.X_loop)] + [0, 0],
                    wgt_L2_alloc_tiles[col], [wgt_L2_ping_addrs[col], wgt_L2_pong_addrs[col]], wgt_L2_size,
                    [generate_transfer_params(
                        memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[1]),
                        wgt_memtile_memory(dims),
                        wgt_memtile_s2mm(dims, enable_wgt_channel_sharing),
                    )],
                    [generate_transfer_params(
                        memtile_dma(col, DmaDir.MM2S, channel_id),
                        wgt_memtile_memory(dims),
                        wgt_memtile_mm2s(dims, col, row, enable_wgt_channel_sharing),
                    ) for row, channel_id in wgt_memtile_channels(dims, col, enable_wgt_channel_sharing)],
                    sync_strategy=SyncStrategy.Parallel_1_to_N,
                )
            )
        else:
            co_split_size = Co_split_size(dims)
            available_f_ids = overlay_3x4_F_ids()  # [0, 1, 2, 3] typically
            log(f"col {col}, available_f_ids: {available_f_ids}, co_split_size: {co_split_size}")
            # For each Co split block in this column
            for co_split_idx in range(co_split_size):
                # Special handling for Co_split = 2
                if dims.Co_split == 2:
                    # Use the F_ids at positions 0 and 2 of available_f_ids for Co_split = 2
                    f_id = available_f_ids[0] if co_split_idx == 0 else available_f_ids[2]
                else:
                    # Default behavior for other Co_split values
                    if co_split_idx < len(available_f_ids):
                        f_id = available_f_ids[co_split_idx]
                    else:
                        log(f"Warning: co_split_idx {co_split_idx} exceeds available F_ids")
                        f_id = available_f_ids[0]  # Use first F_id as fallback
                wgt_memtile_transfers.append(
                    DataTransfer(
                        [0, 0] + [(dims.Y_loop * dims.X_loop * dims.Co_loop * dims.Ci_loop)
                                  if (i == 0) and (j == 0) else 0
                                  for i in range(dims.Y_loop) for j in range(dims.X_loop)] + [0, 0],
                        wgt_L2_alloc_tiles[col],
                        [wgt_L2_ping_addrs[col], wgt_L2_pong_addrs[col]],
                        wgt_L2_size,
                        [generate_transfer_params(
                            memtile_dma(col, DmaDir.S2MM, f_id),  # Use specific F_id
                            wgt_memtile_memory(dims),
                            wgt_memtile_s2mm(dims, enable_wgt_channel_sharing, co_split_idx),
                            name=f"wgt_transfer_{col}"
                        )],
                        [generate_transfer_params(
                            memtile_dma(col, DmaDir.MM2S, channel_id),
                            wgt_memtile_memory(dims),
                            wgt_memtile_mm2s(dims, col, row, enable_wgt_channel_sharing, co_split_idx),
                            name=f"wgt_transfer_{col}"
                        ) for row, channel_id in wgt_memtile_channels(dims, col, enable_wgt_channel_sharing, co_split_idx)],
                        sync_strategy=SyncStrategy.Parallel_1_to_N,
                    )
                )
                # NOTE: Added to resolve synchronization issues in L2 Fusion
                wgt_memtile_transfers.append(
                    DataTransfer(
                        [0, 0] + [0 for i in range(dims.Y_loop) for j in range(dims.X_loop)] + [1, 0],
                        wgt_L2_alloc_tiles[col],
                        [wgt_L2_ping_addrs[col], wgt_L2_pong_addrs[col]],
                        wgt_L2_size,
                        [],
                        [generate_transfer_params(
                            memtile_dma(col, DmaDir.MM2S, channel_id),
                            'ELEMES:2',
                            'ELEMES:0:0',
                            name=f"wgt_transfer_sync_{col}"
                        ) for row, channel_id in wgt_memtile_channels(dims, col, enable_wgt_channel_sharing, co_split_idx)],
                    )
                )
    memtile_transfers += wgt_memtile_transfers

    shim_transfers = [
        generate_shim_data_transfer(
            [0, 0] + [1 if (i == 0) and (j == 0) else 0
                      for i in range(dims.Y_loop) for j in range(dims.X_loop)] + [0, 0],
            shim_dma(col, DmaDir.MM2S, 0), conv_shim.prm_xrt_idx,
            prm_shim_memory(),
            prm_shim_mm2s(col),
            buffer_offset=conv_shim.prm_xrt_offset,
            name=f"shim_prm_{col}"
        ) for col in range(aie_cols)
    ] + (([
        generate_shim_data_transfer(
            [1, 0] + [0 for _ in range(dims.Y_loop) for _ in range(dims.X_loop)] + [0, 0],
            shim_dma(1, DmaDir.MM2S, 0), conv_shim.ifm_xrt_idx,
            ifm_shim_memory(shape),
            ifm_shim_mm2s(shape),
            dims.ifm_bits,
            buffer_offset=conv_shim.ifm_xrt_offset,
            name=f"shim_ifm_{col}"
        )
    ]) if L2_alloc.enable_ifm_fill else []) + (([
        generate_shim_data_transfer(
            [0, 0] + [0 for _ in range(dims.Y_loop) for _ in range(dims.X_loop)] + [0, 1],
            shim_dma(1, DmaDir.S2MM, 0), conv_shim.ofm_xrt_idx,
            ofm_shim_memory(shape),
            ofm_shim_s2mm(shape),
            dims.ofm_bits,
            buffer_offset=conv_shim.ofm_xrt_offset,
            name=f"shim_ofm_{col}"
        )
    ]) if L2_alloc.enable_ofm_spill else [])
    wgt_shim_transfers = []
    for col in range(aie_cols):
        if not enable_wgt_channel_sharing:
            wgt_shim_transfers.append(
                generate_shim_data_transfer(
                    [0, 0] + [(dims.Y_loop * dims.X_loop) if (i == 0) and (j == 0) else 0
                              for i in range(dims.Y_loop) for j in range(dims.X_loop)] + [0, 0],
                    shim_dma(col, DmaDir.MM2S, 1), conv_shim.wgt_xrt_idx,
                    wgt_shim_memory(dims),
                    wgt_shim_mm2s(dims, col, enable_wgt_channel_sharing),
                    buffer_offset=conv_shim.wgt_xrt_offset,
                )
            )
        else:
            co_split_size = Co_split_size(dims)
            # For each Co split block in this column
            for co_split_idx in range(co_split_size):
                # Special handling for Co_split = 2
                if dims.Co_split == 2:
                    # Use channels 0 and 2 for Co_split = 2
                    channel_id = 0 if co_split_idx == 0 else 2
                else:
                    # Default behavior for other Co_split values
                    channel_id = co_split_idx
                wgt_shim_transfers.append(
                    generate_shim_data_transfer(
                        [0, 0] + [(dims.Y_loop * dims.X_loop) if (i == 0) and (j == 0) else 0
                                  for i in range(dims.Y_loop) for j in range(dims.X_loop)] + [0, 0],
                        shim_dma(col, DmaDir.MM2S, channel_id), conv_shim.wgt_xrt_idx,  # Use specific channel
                        wgt_shim_memory(dims),
                        wgt_shim_mm2s(dims, col, enable_wgt_channel_sharing, co_split_idx),
                        buffer_offset=conv_shim.wgt_xrt_offset,
                        name=f"shim_wgt_{col}"
                    )
                )
    shim_transfers += wgt_shim_transfers

    run_layer_compilation(
        overlay_shape,
        schedule_input.kernel_names,
        schedule_input.kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        dma_connections,
        core_stack_addr=overlay_3x4_core_stack_addr(),
        param_channel_id=overlay_3x4_param_channel_id(),
        back_end=schedule_input.backend,
        layer_file=schedule_input.layer_file_name,
        dma_padding_map=schedule_input.dma_pad,
    )

    shim_prm_offset_next_layer = conv_shim.prm_xrt_offset + prm_shim_size
    shim_wgt_offset_next_layer = conv_shim.wgt_xrt_offset + wgt_shim_size
    log("shim_prm_offset_next_layer", shim_prm_offset_next_layer)
    log("shim_wgt_offset_next_layer", shim_wgt_offset_next_layer)

    return shim_prm_offset_next_layer, shim_wgt_offset_next_layer
