'''
Map conv shapes to the AIE-4 dataflow architecture.
External facing functions are documented below.

    compile_L3_dataflow - given a shape and mapping, compile the data movement
    code to orchestrate the DMA hierarchy
'''
from typing import no_type_check
from typing import List, Optional, Union

from dmacompiler import (
    DevGen, set_dev_gen, SyncStrategy,
    OverlayShape, DataTransfer,
    DmaDir,
    memtile_dma, shim_dma, core_tile, memory_tile,
    shim_tile,
    run_layer_compilation,
    generate_transfer_params,
    generate_shim_data_transfer,
    compute_buffer_size,
    pack_reconfig_transfers,
    compute_reuse_chain_length,
)
from utils.utils_common import (
    overlay_3x4_core_stack_addr,
    log,
    split_to_mode,
)

from scheduler.common import (
    overlay_3x4_dma_connections,
    overlay_3x4_param_channel_id,
    overlay_3x4_A_ids,
    overlay_3x4_F_ids,
    overlay_3x4_O_ids,
    overlay_3x4_S_ids,
    overlay_3x4_B_ids,
    prm_memtile_memory,
    prm_shim_memory,
    prm_memtile_mm2s,
    prm_memtile_s2mm,
    prm_shim_mm2s,
    unicast_channels,
    broadcast_channels,
    L3Alloc,
    LinearOpType,
    ShimAllocator,
    L3Alloc_to_Shim
)

from scheduler.conv.conv_config_builders import (
    ConvShape,
    ConvMapping,
    ConvDims,
)

from scheduler.conv.conv_common import (
    ConvDataFlowRepeats,
    ConvL2MemoryAllocator,
    convL2Memory,
    ifm_shim_memory,
    map_shim_ch_memtile_ch,
    wgt_shim_memory,
    ofm_shim_memory,
    Yi_slice,
    Xi_slice,
    Co_split_idxs,
    Co_split_size,
    Co_split_offset,
    Yo_split_size,
    Xo_split_size,
    Yo_split_offset,
    Xo_split_offset,
    generate_conv_core_instrs,
    Yi_slice_per_column,
    Yo_slice_per_column,
    Xi_slice_per_column,
    Xo_slice_per_column,
    Co_slice_per_column,
)

from buildscripts.common import ScheduleInputs

# Select DevGen.Aie4 as the dmacompiler device generation. This call sits at
# module level, so it executes as a side effect of importing this file.
set_dev_gen(DevGen.Aie4)


def generate_conv_repeats(
    dims: ConvDims,
    ifm_L2_strategy: str = 'pin',
    wgt_L2_strategy: str = 'stream',
    ofm_L2_strategy: str = 'stream',
    is_Co_depad: bool = False,
    is_wgt_repeat_high: bool = False,
    gemm_mode: str = 'wgt'
) -> ConvDataFlowRepeats:
    """Build the per-column repeat-count lists for the IFM, WGT and OFM transfers.

    Nine repeat tables are produced: L2 S2MM, L2 MM2S and L3 for each of IFM,
    WGT and OFM. For every table one pattern (a list of repeat counts) is
    picked from a dictionary of pattern generators and assigned to each column
    in ``range(dims.aie_cols)``. The number of base phases is
    ``dims.X_loop * dims.Y_loop``; the Co-depad, Ci-pad and repeat-count
    overflow patterns emit more than one entry per base phase.

    Notes:
        - Every column of a table maps to the *same* list object, and the
          MM2S and S2MM tables of one tensor share that list as well, so
          mutating one entry in place is visible through all of them.
        - "Ci padding" here means ``dims.Ci % dims.Cis != 0`` and
          ``dims.Ci > dims.Ci_gran``; it only changes the patterns when
          ``gemm_mode == 'act'``.

    Args:
        dims: Conv dimensions. Reads ``X_loop``, ``Y_loop``, ``Ci``, ``Cis``,
            ``Ci_gran``, ``Ci_loop``, ``Co_loop``, ``Yi``, ``Y_split`` and
            ``aie_cols``.
        ifm_L2_strategy: One of ``'full'``, ``'pin'``, ``'stream'``.
        wgt_L2_strategy: One of ``'stream'``, ``'pin'``.
        ofm_L2_strategy: Only ``'stream'`` is accepted.
        is_Co_depad: Select the Co-depad patterns.
        is_wgt_repeat_high: Select the ``wgt_repeat_count_overflow`` patterns
            for OFM and (for 'full'/'pin') IFM.
        gemm_mode: ``'act'`` or ``'wgt'``. Any other value leaves the
            ``'co_depad'`` / ``'wgt_co_depad'`` generators unregistered, so a
            Co-depad request with such a mode fails with ``KeyError``.

    Returns:
        A ``ConvDataFlowRepeats`` holding the nine ``{column: repeat list}``
        dictionaries.

    Raises:
        ValueError: If a strategy string is not in the supported set.
        AssertionError: If the nine repeat lists of a column do not all have
            the same length.
    """
    # Validate strategies
    valid_strategies = {
        'ifm_L2': ['full', 'pin', 'stream'],
        'wgt_L2': ['stream', 'pin'],
        'ofm_L2': ['stream']
    }
    if ifm_L2_strategy not in valid_strategies['ifm_L2']:
        raise ValueError(f"Unknown IFM L2 strategy: {ifm_L2_strategy}. "
                         f"Supported strategies are {valid_strategies['ifm_L2']}")
    if wgt_L2_strategy not in valid_strategies['wgt_L2']:
        raise ValueError(f"Unknown WGT L2 strategy: {wgt_L2_strategy}. "
                         f"Supported strategies are {valid_strategies['wgt_L2']}")
    if ofm_L2_strategy not in valid_strategies['ofm_L2']:
        raise ValueError(f"Unknown OFM L2 strategy: {ofm_L2_strategy}. "
                         f"Supported strategies are {valid_strategies['ofm_L2']}")
    # Result tables, each filled below as {column index: list of repeat counts}.
    ifm_L2_s2mm_repeats = {}
    ifm_L2_mm2s_repeats = {}
    ifm_L3_mm2s_repeats = {}
    ofm_L2_s2mm_repeats = {}
    ofm_L2_mm2s_repeats = {}
    ofm_L3_s2mm_repeats = {}
    wgt_L2_s2mm_repeats = {}
    wgt_L2_mm2s_repeats = {}
    wgt_L3_mm2s_repeats = {}
    # One base phase per (Y iteration, X iteration) pair.
    base_phases = dims.X_loop * dims.Y_loop
    log(f"INFO: is_Co_depad: {is_Co_depad}")
    log(f"INFO: is_wgt_repeat_high: {is_wgt_repeat_high}")

    # Ci padding: Ci is not a whole number of Cis and is larger than Ci_gran.
    ci_pad = dims.Ci % dims.Cis != 0 and dims.Ci > dims.Ci_gran

    def create_repeat_dict(pattern: list) -> dict:
        """Create a dictionary mapping each column to the given pattern."""
        # NOTE: all columns reference the same list object (no copy is made).
        return {col: pattern for col in range(dims.aie_cols)}

    # The generators are lambdas so that only the selected pattern is built.
    # OFM repeat pattern generators
    # Entries per base phase: 'default' 1, 'co_depad' 2, 'ci_pad' 2 * Co_loop.
    l2_ofm_pattern_generators = {
        'default': lambda: [dims.Co_loop] * base_phases,
        'wgt_repeat_count_overflow': lambda: ([dims.Co_loop] + [0] * (dims.Co_loop-1)) * base_phases,
        'co_depad': lambda: [dims.Co_loop-1, 1] * base_phases,
        'co_depad_wgt_repeat_count_overflow': lambda: ([dims.Co_loop-1, 1] + [0] * (dims.Co_loop-2)) * base_phases,
        'ci_pad': lambda: [0, 1] * dims.Co_loop * base_phases,  # or [dims.Co_loop, 0] * base_phases
    }
    # WGT repeat pattern generators
    # 'default' either repeats Ci_loop * Co_loop in every base phase, or folds
    # all base phases into the first entry and pads the rest with zeros.
    l2_wgt_pattern_generators = {
        'default': lambda: [dims.Ci_loop * dims.Co_loop] * base_phases
        if dims.Yi % dims.Y_split != 0 and gemm_mode == 'act' else [dims.Ci_loop * dims.Co_loop * base_phases] + [0] * (base_phases - 1),
        'wgt_repeat_count_overflow': lambda: [dims.Ci_loop * dims.Co_loop] * base_phases,
        # co_depad differs act vs wgt added below
        'co_depad_wgt_repeat_count_overflow': lambda: ([dims.Ci_loop * dims.Co_loop] + [0]) * base_phases,
        'phase_ci_co_xy': lambda: [dims.Ci_loop] * dims.Co_loop * base_phases,
        'ci_pad': lambda: [dims.Ci_loop-1, 1] * dims.Co_loop * base_phases,
        'pin': lambda: [dims.Co_loop] * base_phases if not is_Co_depad else [dims.Co_loop - 1, 1] * base_phases,
    }
    # IFM repeat pattern generators
    l2_ifm_pattern_generators = {
        'default': lambda: [1] * base_phases,
        'wgt_repeat_count_overflow': lambda: ([1] + [0] * dims.Co_loop) * base_phases,
        'co_depad': lambda: ([1] + [0]) * base_phases,
        # co_depad_wgt_repeat_count_overflow is not needed?
        # 'co_depad_wgt_repeat_count_overflow': lambda: ([1, 0] * (dims.Co_loop)) * base_phases,
        'act_ci_pad': lambda: ([1, 0] + [0, 0] * (dims.Co_loop-1)) * base_phases,
        # NOTE(review): the stream test is Ci % Ci_loop, whereas ci_pad above
        # uses Ci % Cis — confirm this difference is intended.
        'stream': lambda: ([dims.Ci_loop] if dims.Ci % dims.Ci_loop == 0
                           else [dims.Ci_loop-1, 1]) * base_phases
    }
    # L3 repeat pattern generators
    l3_pattern_generators = {
        'ifm_default': lambda: [1] * base_phases,
        'ifm_co_depad': lambda: [1, 0] * base_phases,
        'ifm_ci_pad': lambda: ([1, 0] + [0, 0] * (dims.Co_loop-1)) * base_phases,
        'wgt_default': lambda: [1]*base_phases if (dims.Yi > 1 and gemm_mode == 'act') else [base_phases] + [0] * (base_phases - 1),
        'wgt_ci_pad': lambda: [1, 0] * dims.Co_loop * base_phases,
        'ofm_default': lambda: [1] * base_phases,
        'ofm_co_depad': lambda: [1, 1] * base_phases,
        'ofm_ci_pad': lambda: [0, 1] * dims.Co_loop * base_phases
    }

    # The WGT Co-depad patterns depend on the gemm mode and are registered
    # here. For any gemm_mode other than 'act'/'wgt' neither key exists.
    if gemm_mode == 'act':
        l2_wgt_pattern_generators.update({
            'co_depad': lambda: [dims.Ci_loop * (dims.Co_loop - 1), dims.Ci_loop] * base_phases,
        })
        l3_pattern_generators.update({
            'wgt_co_depad': lambda: [1, 1] * base_phases if (dims.Yi > 1 and gemm_mode == 'act') else [1, 1] + [0] * (base_phases - 2),
        })
    elif gemm_mode == 'wgt':
        l2_wgt_pattern_generators.update({
            'co_depad': lambda: [(dims.Ci_loop * dims.Co_loop) * base_phases] + [0] * (2 * base_phases - 1),
        })
        l3_pattern_generators.update({
            'wgt_co_depad': lambda: [1, 0] * base_phases,
        })

    # NOTE(review): with is_wgt_repeat_high the OFM/IFM overflow patterns have
    # Co_loop / Co_loop + 1 entries per base phase, while the WGT and L3
    # patterns selected below do not depend on is_wgt_repeat_high. It looks
    # like the length check at the end rejects most such inputs — confirm the
    # intended pairing.
    # Generate OFM repeats
    # Ci padding (act mode only) takes precedence over the depad/overflow flags.
    if ofm_L2_strategy == 'stream':
        if ci_pad and gemm_mode == 'act':
            pattern = l2_ofm_pattern_generators['ci_pad']()
        elif is_Co_depad and not is_wgt_repeat_high:
            pattern = l2_ofm_pattern_generators['co_depad']()
        elif not is_Co_depad and is_wgt_repeat_high:
            pattern = l2_ofm_pattern_generators['wgt_repeat_count_overflow']()
        elif is_Co_depad and is_wgt_repeat_high:
            pattern = l2_ofm_pattern_generators['co_depad_wgt_repeat_count_overflow']()
        else:
            pattern = l2_ofm_pattern_generators['default']()
        # MM2S and S2MM share the same pattern list.
        ofm_L2_mm2s_repeats = create_repeat_dict(pattern)
        ofm_L2_s2mm_repeats = create_repeat_dict(pattern)
    # Generate WGT repeats
    # 32768 is presumably the largest repeat count one entry may carry —
    # TODO confirm against the DMA specification.
    if wgt_L2_strategy == 'stream':
        if ci_pad and gemm_mode == 'act':
            pattern = l2_wgt_pattern_generators['ci_pad']()
        elif (dims.Ci_loop * dims.Co_loop * base_phases) < 32768:
            # No need to split the WGT L2 repeats
            if is_Co_depad:
                pattern = l2_wgt_pattern_generators['co_depad']()
            else:
                pattern = l2_wgt_pattern_generators['default']()
        elif dims.Ci_loop * dims.Co_loop < 32768:
            # Split Y and X into different phases
            if is_Co_depad:
                pattern = l2_wgt_pattern_generators['co_depad_wgt_repeat_count_overflow']()
            else:
                pattern = l2_wgt_pattern_generators['wgt_repeat_count_overflow']()
        else:
            # Even Ci_loop * Co_loop is too large: one entry of Ci_loop per Co iteration.
            pattern = l2_wgt_pattern_generators['phase_ci_co_xy']()
        wgt_L2_mm2s_repeats = create_repeat_dict(pattern)
        wgt_L2_s2mm_repeats = create_repeat_dict(pattern)
    elif wgt_L2_strategy == 'pin':
        # Pin Ci strategy: Ci is pinned, only Co iterations cause new fetches
        pattern = l2_wgt_pattern_generators['pin']()
        wgt_L2_mm2s_repeats = create_repeat_dict(pattern)
        wgt_L2_s2mm_repeats = create_repeat_dict(pattern)
    # Generate IFM repeats
    if ifm_L2_strategy in ['full', 'pin']:
        if ci_pad and gemm_mode == 'act':
            # NOTE: Only the act x act gemm mode, Ci padding results in an additional phase
            # In case of act x wgt, the Ci padding is handled as a chained BD in the same phase.
            pattern = l2_ifm_pattern_generators['act_ci_pad']()
        elif is_wgt_repeat_high:
            pattern = l2_ifm_pattern_generators['wgt_repeat_count_overflow']()
        elif is_Co_depad:
            pattern = l2_ifm_pattern_generators['co_depad']()
        else:
            pattern = l2_ifm_pattern_generators['default']()
        ifm_L2_mm2s_repeats = create_repeat_dict(pattern)
        ifm_L2_s2mm_repeats = create_repeat_dict(pattern)
    elif ifm_L2_strategy == 'stream':
        pattern = l2_ifm_pattern_generators['stream']()
        ifm_L2_mm2s_repeats = create_repeat_dict(pattern)
        ifm_L2_s2mm_repeats = create_repeat_dict(pattern)
    # Generate L3 repeats
    # Precedence: act-mode Ci padding, then pinned weights in act mode, then
    # Co depad, then the default patterns.
    if ci_pad and gemm_mode == 'act':
        ifm_L3_mm2s_repeats = create_repeat_dict(l3_pattern_generators['ifm_ci_pad']())
        wgt_L3_mm2s_repeats = create_repeat_dict(l3_pattern_generators['wgt_ci_pad']())
        ofm_L3_s2mm_repeats = create_repeat_dict(l3_pattern_generators['ofm_ci_pad']())
    elif wgt_L2_strategy == 'pin' and gemm_mode == 'act':
        # Pin Ci: fetch once per Y/X phase
        # NOTE(review): with is_Co_depad the IFM pattern here stays
        # 'ifm_default' (1 entry per base phase) while WGT/OFM use 2 entries
        # per base phase — confirm the length check below is expected to pass.
        ifm_L3_mm2s_repeats = create_repeat_dict(l3_pattern_generators['ifm_default']())
        wgt_L3_mm2s_repeats = create_repeat_dict([1, 0] * base_phases if is_Co_depad else [1] * base_phases)
        ofm_L3_s2mm_repeats = create_repeat_dict(l3_pattern_generators['ofm_default']() if not is_Co_depad else l3_pattern_generators['ofm_co_depad']())
    elif is_Co_depad:
        ifm_L3_mm2s_repeats = create_repeat_dict(l3_pattern_generators['ifm_co_depad']())
        wgt_L3_mm2s_repeats = create_repeat_dict(l3_pattern_generators['wgt_co_depad']())
        ofm_L3_s2mm_repeats = create_repeat_dict(l3_pattern_generators['ofm_co_depad']())
    else:
        ifm_L3_mm2s_repeats = create_repeat_dict(l3_pattern_generators['ifm_default']())
        wgt_L3_mm2s_repeats = create_repeat_dict(l3_pattern_generators['wgt_default']())
        ofm_L3_s2mm_repeats = create_repeat_dict(l3_pattern_generators['ofm_default']())
    # Validate all repeat arrays have the same length

    def validate_repeat_lengths(col: int) -> list[int]:
        """Return the lengths of the nine repeat lists of a column.

        Despite the name, no check happens here; the caller asserts that all
        returned lengths are equal.
        """
        return [
            len(ifm_L2_mm2s_repeats[col]),
            len(ifm_L2_s2mm_repeats[col]),
            len(ofm_L2_mm2s_repeats[col]),
            len(ofm_L2_s2mm_repeats[col]),
            len(wgt_L2_mm2s_repeats[col]),
            len(wgt_L2_s2mm_repeats[col]),
            len(ifm_L3_mm2s_repeats[col]),
            len(wgt_L3_mm2s_repeats[col]),
            len(ofm_L3_s2mm_repeats[col])
        ]
    # NOTE: this is an assert, so the check disappears when Python runs with -O.
    for col in range(dims.aie_cols):
        lengths = validate_repeat_lengths(col)
        assert all(length == lengths[0] for length in lengths), \
            f"Column {col} repeat arrays have mismatched lengths: {lengths}"

    return ConvDataFlowRepeats(
        ifm_L2_s2mm_repeats=ifm_L2_s2mm_repeats,
        ifm_L2_mm2s_repeats=ifm_L2_mm2s_repeats,
        ifm_L3_mm2s_repeats=ifm_L3_mm2s_repeats,
        ofm_L2_s2mm_repeats=ofm_L2_s2mm_repeats,
        ofm_L2_mm2s_repeats=ofm_L2_mm2s_repeats,
        ofm_L3_s2mm_repeats=ofm_L3_s2mm_repeats,
        wgt_L2_s2mm_repeats=wgt_L2_s2mm_repeats,
        wgt_L2_mm2s_repeats=wgt_L2_mm2s_repeats,
        wgt_L3_mm2s_repeats=wgt_L3_mm2s_repeats,
    )


def ifm_memtile_channels(dims: ConvDims, col: int) -> list[tuple[int, int]]:
    '''
    Pick the IFM channel allocations (row, id) for a column.

    The split mode returned by split_to_mode(dims) selects the allocation:
    mode 0 uses the unicast channels, mode 1 uses the broadcast channels of
    `col`. Any other mode raises KeyError.
    '''
    split_mode = split_to_mode(dims)
    # Both candidates are built up front, unicast first, before the lookup.
    unicast = unicast_channels()
    broadcast = broadcast_channels(col)
    return {0: unicast, 1: broadcast}[split_mode]


def wgt_memtile_channels(
    dims: ConvDims,
    col: int,
    weights_channel_list: Optional[List[int]] = None,
    channel_idx: int = 0,
) -> list[tuple[int, int]]:
    '''
    Map spatial split and column to WGT channel allocations (row, id)

    The channels of the active mode (mode 0: broadcast channels of `col`,
    mode 1: unicast channels) are divided into len(weights_channel_list)
    equally sized, contiguous groups; the group at `channel_idx` is returned.

    Args:
        dims: ConvDims containing the spatial split configuration
        col: Column index
        weights_channel_list: List of shim channels used for weight fetching
            (e.g., [1, 2]). Only its length is used here. Defaults to [1]
            when omitted or None.
        channel_idx: Index in weights_channel_list indicating which fetch this is
    Returns:
        List of (row, channel_id) tuples for this fetch
    Raises:
        ValueError: If the mode is neither 0 nor 1, if weights_channel_list is
            empty, if the channels cannot be divided evenly between the
            fetches, or if channel_idx is out of range.
    '''
    mode = split_to_mode(dims)
    # Get all channels based on mode
    if mode == 0:
        all_channels = broadcast_channels(col)
        mode_name = "broadcast"
    elif mode == 1:
        all_channels = unicast_channels()
        mode_name = "unicast"
    else:
        raise ValueError(f"Unknown mode: {mode}. Expected 0 (broadcast) or 1 (unicast).")
    # A None default avoids sharing one mutable list object between calls.
    if weights_channel_list is None:
        weights_channel_list = [1]
    num_channel_splits = len(weights_channel_list)
    # Reject an empty list explicitly instead of failing with ZeroDivisionError.
    if num_channel_splits == 0:
        raise ValueError("weights_channel_list must contain at least one shim channel.")
    num_channels = len(all_channels)
    # Validate that channels are evenly divisible by channel splits
    if num_channels % num_channel_splits != 0:
        raise ValueError(
            f"Number of {mode_name} channels ({num_channels}) is not divisible "
            f"by num_channel_splits ({num_channel_splits}). "
            f"Channels must be evenly distributed across splits."
        )
    channels_per_split = num_channels // num_channel_splits
    # Validate channel_idx
    if not 0 <= channel_idx < num_channel_splits:
        raise ValueError(
            f"Invalid channel_idx {channel_idx}. Must be in range [0, {num_channel_splits - 1}]."
        )
    # Return the contiguous slice of channels that belongs to this fetch
    start_idx = channel_idx * channels_per_split
    return all_channels[start_idx:start_idx + channels_per_split]


def ofm_memtile_channels() -> list[tuple[int, int]]:
    '''Pair every OFM channel id from overlay_3x4_O_ids() with its position, as (row, id).'''
    return [(row, channel_id) for row, channel_id in enumerate(overlay_3x4_O_ids())]


#####################################################
# IFM memory and tiling Formats
#####################################################

def ifm_memtile_memory(dims: ConvDims, col: int, L2_strategy: str = 'pin') -> Union[str, List[str]]:
    '''
    Define IFM L2 data order and shape

    Args:
        dims: Conv dimensions (reads Yi, Xi, Ci, Yis, Xis, Cis, Y_loop, X_loop)
        col: Column index
        L2_strategy: 'full', 'stream' or 'pin'
    Returns:
        - 'full':   a single string describing the whole Yi/Xi/Ci tensor
        - 'stream': a single string describing one Yis/Xis/Cis subvolume
        - 'pin':    a list with one string per (y_iter, x_iter) phase
        NOTE(review): 'full' and 'stream' return a bare str, not a list.
        generate_ifm_memtile_data_transfers iterates over the result, which
        for a str yields single characters — confirm those strategies never
        reach that caller before relying on it.
    Raises:
        ValueError: If L2_strategy is not one of the supported strategies.
    '''
    if L2_strategy == 'full':
        return f'Yi:{dims.Yi} Xi:{dims.Xi} Ci:{dims.Ci}'
    if L2_strategy == 'stream':
        return f'Yi:{dims.Yis} Xi:{dims.Xis} Ci:{dims.Cis}'
    if L2_strategy != 'pin':
        raise ValueError(f"Unknown IFM L2_strategy: {L2_strategy}. "
                         "Supported strategies are 'full', 'pin', 'stream'")
    mem_fmt = []
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            # Column 0's slice size is the fallback used whenever this
            # column's own slice size is 0 for the phase.
            _, _, Yi_size_col_0 = Yi_slice_per_column(dims, 0, y_iter)
            _, _, Xi_size_col_0 = Xi_slice_per_column(dims, 0, x_iter)
            _, _, Yi_size_col = Yi_slice_per_column(dims, col, y_iter)
            _, _, Xi_size_col = Xi_slice_per_column(dims, col, x_iter)
            Yi_size = Yi_size_col_0 if Yi_size_col == 0 else Yi_size_col
            Xi_size = Xi_size_col_0 if Xi_size_col == 0 else Xi_size_col
            mem_fmt.append(f'Yi:{Yi_size} Xi:{Xi_size} Ci:{dims.Ci}')
    return mem_fmt


def _ifm_mm2s_full_ci_fmt(dims: ConvDims, yi_start: int, yi_stop: int,
                          xi_start: int, xi_stop: int) -> str:
    '''Build the IFM L2 MM2S read format for the whole-Cis part of Ci over one Yi/Xi window.'''
    # Ci extent made of whole Cis subvolumes, but never less than one Cis.
    Ci_stop = max((dims.Ci // dims.Cis) * dims.Cis, dims.Cis)
    return (
        f'Ci:0:{Ci_stop}:{dims.Ci_gran} '
        f'Yi:{yi_start}:{yi_stop} '
        f'Xi:{xi_start}:{xi_stop} '
        f'Ci:0:{dims.Ci_gran}'
    )


def _ifm_mm2s_residual_ci_fmt(dims: ConvDims, yi_start: int, yi_stop: int,
                              xi_start: int, xi_stop: int) -> str:
    '''Build the IFM L2 MM2S read format for the Ci remainder left after the whole-Cis part.'''
    # First Ci index that is not covered by a whole Cis subvolume.
    Ci_start = (dims.Ci // dims.Cis) * dims.Cis
    Ci_padded = dims.Cis * dims.Ci_loop
    remain_ci = dims.Ci - Ci_start
    if remain_ci >= dims.Ci_gran:
        # Remaining Ci is greater than or equal to Ci_gran
        return (
            f'Ci:{Ci_start}:{Ci_padded}:{dims.Ci_gran} '
            f'Yi:{yi_start}:{yi_stop} '
            f'Xi:{xi_start}:{xi_stop} '
            f'Ci:0:{dims.Ci_gran}'
        )
    # Remaining Ci is less than Ci_gran
    return (
        f'Ci:0:{dims.Cis}:{dims.Ci_gran} '
        f'Yi:{yi_start}:{yi_stop} '
        f'Xi:{xi_start}:{xi_stop} '
        f'Ci:{Ci_start}:{Ci_start + dims.Ci_gran}:{remain_ci} '
        f'Ci:0:{remain_ci}'
    )


def ifm_memtile_mm2s(
    dims: ConvDims,
    col: int,
    row: int,
    is_Co_depad: bool,
    L2_strategy: str = 'pin',
    gemm_mode: str = 'wgt'
) -> List[List[str]]:
    '''
    Define IFM L2 MM2S data order and shape

    One group of read formats is produced per (y_iter, x_iter) phase, in that
    loop order. Two lists are filled in parallel: the "full" list reads the
    part of Ci made of whole Cis subvolumes, the "partial" list reads the Ci
    remainder when there is one.

    Phase layout:
        - gemm_mode == 'act' with Ci padding (Ci % Cis != 0 and Ci > Ci_gran):
          each list gets the real format followed by 'Ci:0:0', then a dummy
          'Xi:0:0', 'Ci:0:0' pair for every remaining Co iteration.
        - gemm_mode == 'act' without Ci padding: the full format, plus a dummy
          'Xi:0:0' when is_Co_depad. Nothing is added to the partial list.
        - any other gemm_mode: the full format, plus 'Xi:0:0' when
          is_Co_depad; when Ci % Cis != 0 and Ci exceeds the whole-Cis extent
          the partial list gets the remainder format, plus 'Xi:0:0' when
          is_Co_depad.

    Args:
        dims: Conv dimensions
        col: Column index
        row: Row index
        is_Co_depad: Whether a Co depad phase alternates with each real phase
        L2_strategy: Only 'pin' is supported here
        gemm_mode: 'act' selects the act layout; any other value selects the
            layout described last above
    Returns:
        [read_fmt_full_iters, read_fmt_partial_iters]
    Raises:
        ValueError: If L2_strategy is 'full' or 'stream', or if no read format
            was generated at all.
    '''
    if L2_strategy in ['full', 'stream']:
        raise ValueError(f"IFM L2 MM2S is not supported for {L2_strategy} strategy.")
    read_fmt_full_iters: List[str] = []
    read_fmt_partial_iters: List[str] = []
    # The Ci conditions do not depend on the phase, so evaluate them once.
    ci_has_remainder = dims.Ci % dims.Cis != 0
    ci_full_stop = max((dims.Ci // dims.Cis) * dims.Cis, dims.Cis)
    act_ci_pad = gemm_mode == 'act' and ci_has_remainder and dims.Ci > dims.Ci_gran
    wgt_ci_residual = gemm_mode != 'act' and ci_has_remainder and dims.Ci > ci_full_stop
    # Each y_iter, x_iter corresponds to a phase in the transfer
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            # Tensor-level start of the shard pinned in this column
            yi_shim_shard_start, _, _ = Yi_slice_per_column(dims, col, y_iter)
            xi_shim_shard_start, _, _ = Xi_slice_per_column(dims, col, x_iter)
            # Tensor-level start for this column, row and iteration, turned
            # into an offset relative to the pinned shard. A start at or past
            # the tensor edge is clamped to offset 0.
            Yi_start, _, _ = Yi_slice(dims, col, row, y_iter)
            Yis_start = 0 if Yi_start >= dims.Yi else Yi_start - yi_shim_shard_start
            Yis_stop = Yis_start + dims.Yis
            Xi_start, _, _ = Xi_slice(dims, col, row, x_iter)
            Xis_start = 0 if Xi_start >= dims.Xi else Xi_start - xi_shim_shard_start
            Xis_stop = Xis_start + dims.Xis
            window = (Yis_start, Yis_stop, Xis_start, Xis_stop)

            read_fmt_full_iters.append(_ifm_mm2s_full_ci_fmt(dims, *window))
            if act_ci_pad:
                # Ci padding phases peel the Co loop out: the real formats
                # belong to the first Co iteration only.
                read_fmt_full_iters.append('Ci:0:0')
                read_fmt_partial_iters.append(_ifm_mm2s_residual_ci_fmt(dims, *window))
                read_fmt_partial_iters.append('Ci:0:0')
                # IFM data is reused for the remaining Co iterations, which
                # therefore only get dummy accesses.
                for _ in range(1, dims.Co_loop):
                    read_fmt_full_iters.extend(('Xi:0:0', 'Ci:0:0'))
                    read_fmt_partial_iters.extend(('Xi:0:0', 'Ci:0:0'))
                continue
            if is_Co_depad:
                # A Co depad phase is added: alternate a dummy access with
                # the real format.
                read_fmt_full_iters.append('Xi:0:0')
            if wgt_ci_residual:
                # Extra BD for the Ci remainder, with padding
                read_fmt_partial_iters.append(_ifm_mm2s_residual_ci_fmt(dims, *window))
                if is_Co_depad:
                    read_fmt_partial_iters.append('Xi:0:0')
    if not read_fmt_partial_iters and not read_fmt_full_iters:
        raise ValueError("IFM L2 MM2S read format is empty.")
    log(f"INFO: IFM L2 MM2S read_fmt_full_iters length for column {col}, row {row}: {read_fmt_full_iters}")
    log(f"INFO: IFM L2 MM2S read_fmt_partial_iters length for column {col}, row {row}: {read_fmt_partial_iters}")
    return [read_fmt_full_iters, read_fmt_partial_iters]


def ifm_memtile_s2mm(dims: ConvDims, col: int, is_Co_depad: bool, L2_strategy: str = 'pin', gemm_mode: str = 'wgt') -> List[str]:
    '''
    Define IFM L2 S2MM data order and shape

    'full' and 'stream' yield a single write format. 'pin' yields one real
    write format per (y_iter, x_iter) phase, each followed by the dummy
    formats required by the phase structure:
        - act mode with Ci padding: 'Ci:0:0', then an 'Xi:0:0', 'Ci:0:0' pair
          for every remaining Co iteration
        - otherwise: 'Xi:0:0' when is_Co_depad, nothing when not
    Raises ValueError for any other strategy.
    '''
    if L2_strategy == 'full':
        return [f'Yi:0:{dims.Yi} Xi:0:{dims.Xi} Ci:0:{dims.Ci}']
    if L2_strategy == 'stream':
        return [f'Yi:0:{dims.Yis} Xi:0:{dims.Xis} Ci:0:{dims.Cis}']
    if L2_strategy != 'pin':
        raise ValueError(f"Unknown L2_strategy: {L2_strategy}. "
                         "Supported strategies are 'full', 'pin', 'stream'")

    # Dummy formats that follow every real write format.
    has_ci_pad_phases = gemm_mode == 'act' and dims.Ci % dims.Cis != 0 and dims.Ci > dims.Ci_gran
    if has_ci_pad_phases:
        # Ci padding peels the Co loop out; the IFM data is reused for the
        # remaining Co iterations, which only get dummy formats.
        dummy_tail = ['Ci:0:0'] + ['Xi:0:0', 'Ci:0:0'] * (dims.Co_loop - 1)
    elif is_Co_depad:
        dummy_tail = ['Xi:0:0']
    else:
        dummy_tail = []

    write_fmt = []
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            _, _, yi_extent = Yi_slice_per_column(dims, col, y_iter)
            _, _, xi_extent = Xi_slice_per_column(dims, col, x_iter)
            write_fmt.append(f'Yi:0:{yi_extent} Xi:0:{xi_extent} Ci:0:{dims.Ci}')
            write_fmt.extend(dummy_tail)
    log(f"INFO: IFM L2 S2MM write_fmt length for column {col}: {write_fmt}")
    return write_fmt


def ifm_shimtile_mm2s(dims: ConvDims, col: int, gemm_mode: str = 'wgt') -> List[str]:
    '''
    Define IFM shim MM2S data order and shape

    Emits, per (y_iter, x_iter) phase, the read format of this column's Yi/Xi
    shard over the full Ci range, optionally followed by one dummy format:
    'Xi:0:0' when a Co depad phase exists, otherwise 'Ci:0:0' when a Ci pad
    phase exists in act mode.
    '''
    shim_reads = []
    for y_idx in range(dims.Y_loop):
        for x_idx in range(dims.X_loop):
            yi_lo, yi_hi, _ = Yi_slice_per_column(dims, col, y_idx)
            xi_lo, xi_hi, _ = Xi_slice_per_column(dims, col, x_idx)
            phase_fmts = [f'Yi:{yi_lo}:{yi_hi} Xi:{xi_lo}:{xi_hi} Ci:0:{dims.Ci}']
            # NOTE: if both Co depad and Ci pad ? (only the Co depad dummy is emitted)
            padded_co = dims.Co_loop * dims.Co_split * dims.Cos
            if dims.Co < padded_co and dims.Co_loop > 1:
                # Co depad phase: dummy access alternating with the real fmt
                phase_fmts.append('Xi:0:0')
            elif dims.Ci % dims.Cis != 0 and dims.Ci > dims.Ci_gran and gemm_mode == 'act':
                # Ci pad phase: dummy access alternating with the real fmt
                phase_fmts.append('Ci:0:0')
            shim_reads.extend(phase_fmts)
    return shim_reads


def generate_ifm_memtile_data_transfers(
    dims: ConvDims,
    conv_l2_alloc: convL2Memory,
    conv_repeats: ConvDataFlowRepeats,
    ifm_L2_strategy: str = 'pin',
    is_Co_depad: bool = False,
    gemm_mode: str = 'wgt',
) -> List[DataTransfer]:
    '''Generate IFM memory tile (L2) data transfers, one per AIE column.

    For each column a single ``DataTransfer`` is built with one S2MM fill
    transfer and a list of MM2S transfers (one per channel returned by
    ``ifm_memtile_channels``, repeated ``reuse_bd_chain`` times).

    Args:
        dims: conv dimensions; ``Co_loop`` is used as the buffer reuse ratio.
        conv_l2_alloc: L2 allocation; provides ``ifm_ping_addr`` and ``ifm_size``.
        conv_repeats: per-column repeat lists for the IFM L2 S2MM/MM2S phases.
        ifm_L2_strategy: L2 strategy forwarded to the IFM format helpers.
        is_Co_depad: True when an extra Co depad phase follows each base phase.
        gemm_mode: 'wgt' or 'act'; 'act' may add Ci padding phases.

    Returns:
        List of ``DataTransfer`` objects, one per column.

    Raises:
        AssertionError: if a generated format list and its repeat list differ
            in length.
    '''
    data_transfers = []
    log(f"INFO IFM L2 conv repeats.ifm_L2_s2mm_repeats: {conv_repeats.ifm_L2_s2mm_repeats}")
    log(f"INFO IFM L2 conv repeats.ifm_L2_mm2s_repeats: {conv_repeats.ifm_L2_mm2s_repeats}")
    # Each filled buffer is reused once per Co iteration
    reuse_ratio = dims.Co_loop
    # All rows consume the buffer in split mode 0, otherwise two consumers
    num_consumers = dims.aie_rows if split_to_mode(dims) == 0 else 2
    log(f"INFO: IFM L2 split mode: {split_to_mode(dims)}")
    reuse_bd_chain = 1
    log(f"INFO: IFM MM2S num_consumers: {num_consumers}")
    total_reuse = reuse_ratio * num_consumers
    adjusted_reuse_ratio = reuse_ratio
    if (total_reuse) > 63:
        # If the reuse ratio is too high and the lock values overflow,
        # we need to split the reuse ratio into multiple BD chains
        # 63 is the maximum value for a lock in AIE-4
        reuse_bd_chain = compute_reuse_chain_length(reuse_ratio,
                                                    num_consumers,
                                                    max_chain_length=8)
        # The per-transfer reuse ratio shrinks by the chain length; the MM2S
        # transfer list below is repeated reuse_bd_chain times instead.
        adjusted_reuse_ratio = reuse_ratio // reuse_bd_chain
    log(f"INFO: Reuse BD chain length: {reuse_bd_chain}")
    for col in range(dims.aie_cols):
        # Full IFM L2 memory transfer
        mem_fmt = ifm_memtile_memory(dims, col, ifm_L2_strategy)
        # Per-phase memory formats; the same list is passed to both the S2MM
        # and the MM2S pack_reconfig_transfers calls below.
        formated_mem_fmt = []
        for fmt in mem_fmt:
            # Here we check if there are any additional phases introduced
            # Apart from base phases, in such cases, we need to repeat the mem_fmt
            # The format itself does not change, as the shape of the pinned tile does not change
            if dims.Ci % dims.Cis != 0 and dims.Ci > dims.Ci_gran and gemm_mode == 'act':
                # NOTE: If Ci padding is needed in case of act x act gemm mode
                # An additional phase is added for the padded Ci
                # Which means Co_loop becomes a multiplying factor for the Ci phases
                # So we need to repeat the mem_fmt for Co_loop times
                log(f"INFO: Adding Ci padding phases to IFM L2 mem fmt for column {col}")
                for _ in range(dims.Co_loop):
                    # Phase covering the first Ci_loop - 1 iterations (real fmt)
                    formated_mem_fmt.append(fmt)
                    # Phase covering the padded Ci iteration (same memory fmt)
                    formated_mem_fmt.append(fmt)
            else:
                formated_mem_fmt.append(fmt)
                if is_Co_depad:
                    # if there is a Co depad phase added, repeat the memory
                    # fmt so it alternates with the base phase
                    formated_mem_fmt.append(fmt)
        assert len(formated_mem_fmt) == len(conv_repeats.ifm_L2_mm2s_repeats[col]), \
            f"Column {col} IFM L2 memory fmts and repeats length mismatch: " \
            f"{len(formated_mem_fmt)} != {len(conv_repeats.ifm_L2_mm2s_repeats[col])}"
        write_fmt = ifm_memtile_s2mm(dims, col, is_Co_depad, ifm_L2_strategy, gemm_mode)
        # Check if the number of phases matches the number of fmts
        assert len(write_fmt) == len(conv_repeats.ifm_L2_s2mm_repeats[col]), \
            f"Column {col} fmts and repeats length mismatch: " \
            f"{len(write_fmt)} != {len(conv_repeats.ifm_L2_s2mm_repeats[col])}"

        # Per-row MM2S read formats: element 0 of the helper result is stored
        # as the "full iteration" list, element 1 as the "partial" list.
        read_fmt_full_iters = {}
        read_fmt_partial = {}
        for row in range(dims.aie_rows):
            read_fmt_list_results = ifm_memtile_mm2s(
                dims, col, row, is_Co_depad, ifm_L2_strategy, gemm_mode
            )
            read_fmt_list_per_row_full_iters = read_fmt_list_results[0]
            read_fmt_list_per_partial_iters = read_fmt_list_results[1]

            # Empty lists are allowed; non-empty ones must have one fmt per phase
            if len(read_fmt_list_per_row_full_iters) > 0:
                assert len(read_fmt_list_per_row_full_iters) == len(conv_repeats.ifm_L2_mm2s_repeats[col]), \
                    f"Column {col}, row {row} fmts and repeats length mismatch: " \
                    f"{len(read_fmt_list_per_row_full_iters)} != {len(conv_repeats.ifm_L2_mm2s_repeats[col])}"
            if len(read_fmt_list_per_partial_iters) > 0:
                assert (len(read_fmt_list_per_partial_iters)) == len(conv_repeats.ifm_L2_mm2s_repeats[col]), \
                    f"Column {col}, row {row} fmts and repeats length mismatch: " \
                    f"{len(read_fmt_list_per_partial_iters)} != {len(conv_repeats.ifm_L2_mm2s_repeats[col])}"

            read_fmt_full_iters[row] = read_fmt_list_per_row_full_iters
            read_fmt_partial[row] = read_fmt_list_per_partial_iters

        # Generate the data transfer for the IFM L2 memory tile
        log(f"INFO: IFM L2 MM2S fmt for column \
              {len(read_fmt_full_iters)} {col}: {read_fmt_full_iters}")
        log(f"INFO: IFM L2 MM2S partial fmt for column \
               {len(read_fmt_partial)}  {col}: {read_fmt_partial}")
        mm2s_transfers = []
        for _ in range(reuse_bd_chain):
            # Build format lists for one complete iteration
            for row, channel_id in ifm_memtile_channels(dims, col):
                mm2s_transfers.append(
                    pack_reconfig_transfers(
                        memtile_dma(col, DmaDir.MM2S, channel_id),
                        formated_mem_fmt,
                        read_fmt_full_iters[row],
                        bits_per_elem=dims.ifm_bits
                    )
                )
            # Partial iteration formats for all rows (if they exist)
            if any(len(read_fmt_partial[row]) > 0 for row in range(dims.aie_rows)):
                for row, channel_id in ifm_memtile_channels(dims, col):
                    mm2s_transfers.append(
                        pack_reconfig_transfers(
                            memtile_dma(col, DmaDir.MM2S, channel_id),
                            formated_mem_fmt,
                            read_fmt_partial[row],
                            bits_per_elem=dims.ifm_bits
                        )
                    )
        # One DataTransfer per column: a single S2MM fill on the first F
        # channel id, consumed by all MM2S transfers collected above.
        data_transfers.append(
            DataTransfer(
                conv_repeats.ifm_L2_s2mm_repeats[col],
                memory_tile(col), [conv_l2_alloc.ifm_ping_addr], conv_l2_alloc.ifm_size,
                [
                    pack_reconfig_transfers(
                        memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                        formated_mem_fmt,
                        write_fmt,
                        bits_per_elem=dims.ifm_bits
                    )
                ],
                mm2s_transfers,
                reuse_ratio=adjusted_reuse_ratio,
                sync_strategy=SyncStrategy.Parallel_1_to_N
            )
        )

    return data_transfers


@no_type_check
def generte_ifm_shimtile_data_transfers(
    dims: ConvDims,
    conv_repeats: ConvDataFlowRepeats,
    conv_shim: ShimAllocator,
    is_Co_depad: bool = False,
    gemm_mode: str = 'wgt',
) -> List[DataTransfer]:
    '''Generate the IFM shim tile (L3) data transfers, one per AIE column.

    Each column gets a single MM2S transfer on shim DMA channel 0. The read
    format list is padded with dummy ``Ci:0:0`` entries until it has one
    entry per repeat phase of that column.
    '''
    del is_Co_depad  # accepted for signature parity; not used here
    if gemm_mode == 'wgt':
        xrt_offset = conv_shim.ifm_xrt_offset
        xrt_idx = conv_shim.ifm_xrt_idx
    else:
        # Otherwise the allocator fields are indexable; pick the "ifm0" entry
        xrt_offset = conv_shim.ifm_xrt_offset[dims.ifm_to_xrt_idx["ifm0"]]
        xrt_idx = conv_shim.ifm_xrt_idx[0]
    ddr_fmt = ifm_shim_memory(dims)
    ddr_size = compute_buffer_size(ddr_fmt)

    transfers = []
    for col in range(dims.aie_cols):
        phase_fmts = ifm_shimtile_mm2s(dims, col, gemm_mode)
        log(f"INFO: IFM shim tile memory pattern: {ddr_fmt}")
        log(f"INFO: IFM shim tile MM2S fmt for column {col}: {phase_fmts}")
        phase_repeats = conv_repeats.ifm_L3_mm2s_repeats[col]

        # Pad with dummy accesses so every repeat phase has a format
        shortfall = len(phase_repeats) - len(phase_fmts)
        if shortfall > 0:
            phase_fmts.extend(['Ci:0:0'] * shortfall)
        assert len(phase_fmts) == len(phase_repeats), \
            f"Column {col} fmts and repeats length mismatch: " \
            f"{len(phase_fmts)} != {len(phase_repeats)}"

        transfers.append(
            DataTransfer(
                phase_repeats,
                shim_tile(col), [xrt_idx], ddr_size,
                [],
                [
                    pack_reconfig_transfers(
                        shim_dma(col, DmaDir.MM2S, 0),
                        [ddr_fmt] * len(phase_repeats),
                        phase_fmts,
                        bits_per_elem=dims.ifm_bits,
                        buffer_offset=[xrt_offset]
                    )
                ]
            )
        )
    return transfers


#####################################################
# qdq_param memory and access Formats
#####################################################

def qdq_param_memtile_memory(dims: ConvDims) -> str:
    '''Return the qdq_param L2 layout: one flat ``Param`` dimension of ``dims.qdq_param_size``.'''
    param_len = dims.qdq_param_size
    return f'Param:{param_len}'


def qdq_param_memtile_s2mm(dims: ConvDims) -> str:
    '''Return the qdq_param L2 S2MM access: the whole ``Param`` range ``0:qdq_param_size``.'''
    param_len = dims.qdq_param_size
    return f'Param:0:{param_len}'


def qdq_param_memtile_mm2s(dims: ConvDims, col: int, row: int) -> str:
    '''Return the qdq_param L2 MM2S access; identical for every (col, row).'''
    del col, row  # the access pattern does not depend on the target core
    param_len = dims.qdq_param_size
    return f'Param:0:{param_len}'


def qdq_param_shim_mm2s(dims: ConvDims, col: int) -> str:
    '''Return the qdq_param shim MM2S access; identical for every column.'''
    del col  # the access pattern does not depend on the column
    param_len = dims.qdq_param_size
    return f'Param:0:{param_len}'


def qdq_param_shim_memory(dims: ConvDims) -> str:
    '''Return the qdq_param DDR layout: one flat ``Param`` dimension of ``dims.qdq_param_size``.'''
    param_len = dims.qdq_param_size
    return f'Param:{param_len}'


#####################################################
# WGT memory and tiling Formats
#####################################################

def wgt_memtile_memory_wgt(
    dims: ConvDims,
    is_Co_depad: bool,
    L2_strategy: str = 'stream',
) -> list[str]:
    '''Return the WGT L2 memory format for every phase ('wgt' gemm mode).

    The same ``Cob x Subv`` format is emitted once per (Y, X) phase, and a
    second time per phase when a Co depad phase is present.

    Raises:
        ValueError: if ``L2_strategy`` is not 'stream'.
    '''
    if L2_strategy != 'stream':
        raise ValueError(f"Unknown L2_strategy: {L2_strategy}. Supported strategies are 'stream'.")
    block_fmt = f'Cob:{Co_split_size(dims)} Subv:{dims.wgt_L1_size}'
    # A Co depad phase doubles the number of entries per base phase
    copies_per_phase = 2 if is_Co_depad else 1
    return [
        block_fmt
        for _ in range(dims.Y_loop)
        for _ in range(dims.X_loop)
        for _ in range(copies_per_phase)
    ]


def wgt_memtile_memory_act(dims: ConvDims, L2_strategy: str = 'stream') -> list[str]:
    '''Return the WGT L2 memory layout(s) for 'act' gemm mode.

    The result has one entry, except when Ci needs a padding phase
    (``Ci`` not a multiple of ``Cis``), in which case a second layout for
    the ``Ci_gran``-wide tail is appended.
    '''
    yi = f'Yi:{Yo_split_size(dims)}'  # NOTE: Yi == Yo in BMM
    co = f'Co:{Co_split_size(dims) * dims.Cos}'
    ci_gran = f'Ci:{dims.Ci_gran}'
    ci_slice = f'Ci:{dims.Cis}'
    transposed = dims.transpose_wgts

    def layout(*parts):
        return ' '.join(parts)

    if L2_strategy == 'pin':
        # Pin Ci strategy: keep the full Ci dimension in L2 and stream Co.
        # The layout is Yi x Ci x Co whether or not weights are transposed.
        return [layout(yi, f'Ci:{dims.Ci_orig}', co)]

    # Layout of a single Ci_gran-wide slice
    gran_layout = layout(yi, co, ci_gran) if transposed else layout(yi, ci_gran, co)

    # Case 1: Ci equals Ci_gran (no slicing needed)
    if dims.Ci == dims.Ci_gran:
        return [gran_layout]

    # Layout of a Cis-wide slice
    slice_layout = layout(yi, ci_slice, co, ci_gran) if transposed else layout(yi, ci_slice, co)

    # Case 2: Ci evenly divisible by Cis
    if dims.Ci % dims.Cis == 0:
        return [slice_layout]

    # Case 3: Ci padding needed (not evenly divisible)
    return [slice_layout, gran_layout]


def wgt_memtile_memory(
    dims: ConvDims,
    is_Co_depad: bool,
    L2_strategy: str = 'stream',
    gemm_mode: str = 'wgt',
) -> list[str]:
    '''Return the WGT L2 memory format(s) for the requested gemm mode.

    Raises:
        ValueError: if ``gemm_mode`` is neither 'act' nor 'wgt'.
    '''
    log(f"DEBUG: WGT L2 memory is_Co_depad: {is_Co_depad}, L2_strategy: {L2_strategy}, gemm_mode: {gemm_mode}")
    if gemm_mode not in ('act', 'wgt'):
        raise ValueError(f"Unknown gemm_mode: {gemm_mode}. Supported modes are 'act' and 'wgt'.")
    if gemm_mode == 'wgt':
        return wgt_memtile_memory_wgt(dims, is_Co_depad, L2_strategy)
    # 'act' mode does not take the Co depad flag
    return wgt_memtile_memory_act(dims, L2_strategy)


def wgt_memtile_s2mm_wgt(
    dims: ConvDims,
    is_Co_depad: bool,
    L2_strategy: str = 'stream',
    weights_channel_list: 'List[int] | None' = None,
    channel_idx: int = 0,
) -> list[str]:
    '''Return the WGT L2 S2MM write format for every phase ('wgt' gemm mode).

    The Co blocks of a column are divided evenly (integer division) between
    the fill channels; the channel at position ``channel_idx`` writes its
    own contiguous ``Cob`` range. The same format is emitted once per (Y, X)
    phase, and twice per phase when a Co depad phase is present.

    Args:
        dims: conv dimensions.
        is_Co_depad: True when an extra Co depad phase follows each base phase.
        L2_strategy: only 'stream' is supported.
        weights_channel_list: fill channels sharing the Co blocks; ``None``
            means the single-channel default ``[1]``.
        channel_idx: position of this channel within ``weights_channel_list``.

    Raises:
        ValueError: if ``L2_strategy`` is not 'stream'.
    '''
    if L2_strategy != 'stream':
        raise ValueError(f"ERROR: Unknown L2_strategy: {L2_strategy}. "
                         "Supported strategies are 'stream'.")
    if weights_channel_list is None:
        # Avoid a shared mutable default argument
        weights_channel_list = [1]

    # Since weights are preformatted in L3 with the necessary Cin and Cout
    # paddings, we only stream same-size chunks to L2 without any unique
    # read/write formats per phase.
    Cob_size = Co_split_size(dims)
    Cob_size_per_channel = Cob_size // len(weights_channel_list)
    log(f"DEBUG: WGT L2 S2MM Cob_size: {Cob_size}, Cob_size_per_channel: {Cob_size_per_channel}, "
        f"weights_channel_list: {weights_channel_list}, channel_idx: {channel_idx}")
    Co_start = channel_idx * Cob_size_per_channel
    Co_end = Co_start + Cob_size_per_channel
    _fmt = f'Cob:{Co_start}:{Co_end} Subv:0:{dims.wgt_L1_size}'

    write_fmt = []
    for _ in range(dims.Y_loop):
        for _ in range(dims.X_loop):
            write_fmt.append(_fmt)
            if is_Co_depad:
                # if there is a Co depad phase added, add another fmt
                write_fmt.append(_fmt)
    log(f"DEBUG: WGT L2 S2MM write_fmt: {write_fmt}")
    return write_fmt


def wgt_memtile_s2mm_act(
    dims: ConvDims,
    col: int,
    is_Co_depad: bool,
    L2_strategy: str = 'stream',
) -> list[str]:
    '''Return the WGT L2 S2MM write formats for one column ('act' gemm mode).

    With the 'pin' strategy the full ``Ci_orig`` range is written once per
    (Y, X) phase. Any other strategy value takes the streaming path, which
    writes either ``Cis``-wide or ``Ci_gran``-wide Ci slices; when Ci needs
    a padding phase, the Co loop is peeled into phases and ``is_Co_depad``
    is not consulted.
    '''
    def gran_slice_fmt(yi_size, co_size):
        # Write of a single Ci_gran-wide slice
        if dims.transpose_wgts:
            return f'Yi:0:{yi_size} Co:0:{co_size} Ci:0:{dims.Ci_gran}'
        return f'Yi:0:{yi_size} Ci:0:{dims.Ci_gran} Co:0:{co_size} '

    def cis_slice_fmt(yi_size, co_size, tail):
        # Write of a Cis-wide slice; `tail` reproduces the trailing blank
        # that only some of the non-transposed formats carry
        if dims.transpose_wgts:
            return f'Yi:0:{yi_size} Ci:0:{dims.Cis}:{dims.Ci_gran} Co:0:{co_size} Ci:0:{dims.Ci_gran}'
        return f'Yi:0:{yi_size} Ci:0:{dims.Cis} Co:0:{co_size}' + tail

    formats = []

    if L2_strategy == 'pin':
        # Pin Ci strategy: write the full Ci dimension once per Y/X phase
        for y_idx in range(dims.Y_loop):
            for _ in range(dims.X_loop):
                _, _, yi_size = Yi_slice_per_column(dims, col, y_idx)
                _, _, co_first = Co_slice_per_column(dims, col, 0)
                formats.append(f'Yi:0:{yi_size} Ci:0:{dims.Ci_orig} Co:0:{co_first}')
                if is_Co_depad:
                    _, _, co_last = Co_slice_per_column(dims, col, dims.Co_loop - 1)
                    formats.append(f'Yi:0:{yi_size} Ci:0:{dims.Ci_orig} Co:0:{co_last}')
        return formats

    for y_idx in range(dims.Y_loop):
        for _ in range(dims.X_loop):
            _, _, yi_size = Yi_slice_per_column(dims, col, y_idx)

            # CASE 1: Ci padding (e.g. Ci = 128*X + 64): the Co loop is
            # peeled, each Co iteration gets a Cis phase and a Ci_gran phase
            if dims.Ci % dims.Cis != 0 and dims.Ci > dims.Ci_gran:
                for co_idx in range(dims.Co_loop):
                    _, _, co_size = Co_slice_per_column(dims, col, co_idx)
                    formats.append(cis_slice_fmt(yi_size, co_size, ''))
                    formats.append(gran_slice_fmt(yi_size, co_size))
                continue

            # CASE 2: no Ci padding: one fmt for the first Co iteration and,
            # with Co depad, one more for the last Co iteration
            _, _, co_first = Co_slice_per_column(dims, col, 0)
            _, _, co_last = Co_slice_per_column(dims, col, dims.Co_loop - 1)
            co_sizes = [co_first, co_last] if is_Co_depad else [co_first]
            for co_size in co_sizes:
                if dims.Ci == dims.Ci_gran:
                    formats.append(gran_slice_fmt(yi_size, co_size))
                else:
                    formats.append(cis_slice_fmt(yi_size, co_size, ' '))
    return formats


def wgt_memtile_s2mm(
    dims: ConvDims,
    col: int,
    is_Co_depad: bool,
    L2_strategy: str = 'stream',
    gemm_mode: str = 'wgt',
    weights_channel_list: 'List[int] | None' = None,
    channel_idx: int = 0,
) -> List[str]:
    '''Return the WGT L2 S2MM write formats for the requested gemm mode.

    Args:
        dims: conv dimensions.
        col: AIE column (only used in 'act' mode).
        is_Co_depad: True when an extra Co depad phase follows each base phase.
        L2_strategy: L2 strategy forwarded to the mode-specific helper.
        gemm_mode: 'act' or 'wgt'.
        weights_channel_list: fill channels (only used in 'wgt' mode);
            ``None`` means the single-channel default ``[1]``.
        channel_idx: position of this channel within ``weights_channel_list``
            (only used in 'wgt' mode).

    Raises:
        ValueError: if ``gemm_mode`` is neither 'act' nor 'wgt'.
    '''
    if gemm_mode == 'act':
        return wgt_memtile_s2mm_act(dims, col, is_Co_depad, L2_strategy)
    if gemm_mode == 'wgt':
        # Resolve the default here so no mutable default object is shared
        channels = [1] if weights_channel_list is None else weights_channel_list
        return wgt_memtile_s2mm_wgt(dims, is_Co_depad, L2_strategy, channels, channel_idx)
    raise ValueError(f"Unknown gemm_mode: {gemm_mode}. Supported modes are 'act' and 'wgt'.")


def wgt_memtile_mm2s_wgt(
    dims: ConvDims,
    col: int,
    row: int,
    is_Co_depad: bool,
    L2_strategy: str = 'stream',
) -> List[str]:
    '''Return the WGT L2 MM2S read formats for one core ('wgt' gemm mode).

    Weights are pre-formatted, so the same sub-volume (the core's single Co
    block) is fetched in every iteration. How many phases are emitted
    depends on whether the loop-count products stay below 32768.

    Raises:
        ValueError: if ``L2_strategy`` is not 'stream'.
    '''
    if L2_strategy != 'stream':
        raise ValueError(f"Unknown L2_strategy: {L2_strategy}. "
                         "Supported strategies are 'stream'.")

    phase_count = dims.X_loop * dims.Y_loop
    # Ci_pad / Co_depad phases do not change the weight fetch pattern, so a
    # single format based on this core's Co split offset is enough
    cob = Co_split_offset(dims, col, row)
    fetch_fmt = f'Cob:{cob}:{cob + 1} Subv:0:{dims.wgt_L1_size}'
    idle_fmt = 'Cob:0:0 Subv:0:0'
    inner_iters = dims.Ci_loop * dims.Co_loop

    patterns = []
    if inner_iters * dims.X_loop * dims.Y_loop < 32768:
        log("INFO: WGT L2 MM2S: No need to split the WGT L2 repeats, all repeats fit in a single phase")
        for phase_idx in range(phase_count):
            # Only the first phase carries the real fetch
            patterns.append(fetch_fmt if phase_idx == 0 else idle_fmt)
            if is_Co_depad:
                # A depad phase yields a unique fmt ONLY on the OFM transfer;
                # for the pre-formatted weights it is a dummy access here.
                patterns.append(idle_fmt)
    elif inner_iters < 32768:
        log("INFO: WGT L2 MM2S: Split X_loop and Y_loop into different phases")
        for _ in range(dims.Y_loop):
            for _ in range(dims.X_loop):
                patterns.append(fetch_fmt)
                if is_Co_depad:
                    patterns.append(idle_fmt)
    else:
        log("INFO: WGT L2 MM2S: Split Co_loop, X_loop and Y_loop into different phases")
        patterns.extend(
            fetch_fmt
            for _ in range(dims.Y_loop)
            for _ in range(dims.X_loop)
            for _ in range(dims.Co_loop)
        )
    return patterns


def wgt_memtile_mm2s_act(dims: ConvDims, col: int, row: int, is_Co_depad: bool, L2_strategy: str = 'stream') -> List[str]:
    """Return the WGT L2 MM2S access patterns for one core ('act' gemm mode).

    Args:
        dims: conv dimensions.
        col: AIE column of the consuming core.
        row: AIE row of the consuming core.
        is_Co_depad: True when an extra Co depad phase follows each base phase.
        L2_strategy: 'pin' reads Ci slices from a pinned buffer; any other
            value takes the streaming path.

    Returns:
        The access pattern list, one entry per phase (the pin path may emit
        two entries per phase when a Ci tail exists).
    """
    base_phases = dims.X_loop * dims.Y_loop
    total_iters = dims.Ci_loop * dims.Co_loop
    access_pattern: List[str] = []

    # The helpers close over col/row instead of taking them as positional
    # arguments, so a call site cannot pass them in the wrong order.
    def yi_window(y_iter):
        """Yi start/stop of this core relative to the column's shim shard."""
        yi_shim_shard_start, _, _ = Yi_slice_per_column(dims, col, y_iter)
        Yi_start, _, _ = Yi_slice(dims, col, row, y_iter)
        Yis_start = 0 if Yi_start >= dims.Yi else Yi_start - yi_shim_shard_start
        Yis_stop = Yis_start + dims.Yis  # NOTE: Yis = 1
        return Yis_start, Yis_stop

    def gen_fmt(y_iter):
        """Append the streaming-strategy read format for one phase."""
        Yis_start, Yis_stop = yi_window(y_iter)
        Cob_offset = Co_split_offset(dims, col, row)
        if dims.transpose_wgts:
            access_pattern.append(
                f'Yi:{Yis_start}:{Yis_stop} '
                f'Ci:0:{dims.Cis}:{dims.Ci_gran} '
                f'Co:{Cob_offset*dims.Cos}:{(Cob_offset+1)*dims.Cos} '
                f'Ci:0:{dims.Ci_gran}'
            )
        else:
            access_pattern.append(
                f'Yi:{Yis_start}:{Yis_stop} '
                f'Co:{Cob_offset*dims.Cos}:{(Cob_offset+1)*dims.Cos}:{dims.Co_gran_wgt} '
                f'Ci:0:{dims.Cis} '
                f'Co:0:{dims.Co_gran_wgt} '
            )

    def gen_fmt_pin(y_iter):
        """Append the pin-strategy read format(s): Ci slices from the pinned buffer."""
        Yis_start, Yis_stop = yi_window(y_iter)
        Cob_offset = Co_split_offset(dims, col, row)
        # Read from pinned Ci buffer, iterate over the full Cis-wide slices
        access_pattern.append(
            f'Yi:{Yis_start}:{Yis_stop} '
            f'Ci:0:{max(dims.Ci_orig//dims.Cis, 1)*dims.Cis}:{dims.Cis} '
            f'Co:{Cob_offset*dims.Cos}:{(Cob_offset+1)*dims.Cos}:{dims.Co_gran_wgt} '
            f'Ci:0:{dims.Cis} '
            f'Co:0:{dims.Co_gran_wgt} '
        )
        if dims.Ci != dims.Ci_orig and dims.Ci_orig > dims.Cis:
            # Extra pattern for the Ci slice that follows the full slices
            access_pattern.append(
                f'Yi:{Yis_start}:{Yis_stop} '
                f'Co:{Cob_offset*dims.Cos}:{(Cob_offset+1)*dims.Cos}:{dims.Co_gran_wgt} '
                f'Ci:{(dims.Ci_orig//dims.Cis)*dims.Cis}:{(dims.Ci_orig//dims.Cis+1)*dims.Cis} '
                f'Co:0:{dims.Co_gran_wgt} '
            )

    if L2_strategy == 'pin':
        # Pin Ci strategy: Ci is pinned in L2, iterate over Ci slices for each Co
        for y_iter in range(dims.Y_loop):
            for _ in range(dims.X_loop):
                gen_fmt_pin(y_iter)
                if is_Co_depad:
                    gen_fmt_pin(y_iter)
        log(f"INFO: WGT L2 MM2S (pin) fmt col {col} row {row}: {access_pattern}")
        return access_pattern

    # pylint: disable-next=R1702
    if dims.Ci % dims.Cis != 0 and dims.Ci > dims.Ci_gran:
        # Ci padding is needed: the first Ci_loop-1 iterations and the last
        # one are peeled into separate phases, which in turn peels Co_loop,
        # X_loop and Y_loop. This is necessary for functional correctness.
        for y_iter in range(dims.Y_loop):
            for _ in range(dims.X_loop):
                for _ in range(dims.Co_loop):
                    gen_fmt(y_iter)
                    gen_fmt(y_iter)
    else:
        # NOTE(review): `dims.Yi < 1` looks like it effectively disables the
        # single-phase path - confirm whether that is intentional.
        if total_iters * base_phases < 32768 and dims.Yi < 1:
            log("INFO: WGT L2 MM2S: single phase (no split)")
            gen_fmt(0)
            if is_Co_depad:
                gen_fmt(0)
            access_pattern.extend(['Co:0:0 Ci:0:0'] * (base_phases - 1))
        elif total_iters < 32768:
            log("INFO: WGT L2 MM2S: split over X/Y loops")
            for y_iter in range(dims.Y_loop):
                for _ in range(dims.X_loop):
                    gen_fmt(y_iter)
                    if is_Co_depad:
                        gen_fmt(y_iter)
        else:
            log("INFO: WGT L2 MM2S: split over Co + X/Y loops")
            # No Ci padding can occur here (this is the `else` of that test),
            # so exactly one format is emitted per Co iteration.
            for y_iter in range(dims.Y_loop):
                for _ in range(dims.X_loop):
                    for _ in range(dims.Co_loop):
                        gen_fmt(y_iter)
        log(f"INFO: WGT L2 MM2S fmt col {col} row {row}: {access_pattern}")
    return access_pattern


def wgt_memtile_mm2s(
    dims: ConvDims,
    col: int,
    row: int,
    is_Co_depad: bool,
    L2_strategy: str = 'stream',
    gemm_mode: str = 'wgt',
) -> List[str]:
    '''Return the WGT L2 MM2S read formats for the requested gemm mode.

    Raises:
        ValueError: if ``gemm_mode`` is neither 'act' nor 'wgt'.
    '''
    if gemm_mode not in ('act', 'wgt'):
        raise ValueError(f"Unknown gemm_mode: {gemm_mode}. Supported modes are 'act' and 'wgt'.")
    # Both mode-specific helpers take the same argument list
    build_fmts = wgt_memtile_mm2s_act if gemm_mode == 'act' else wgt_memtile_mm2s_wgt
    return build_fmts(dims, col, row, is_Co_depad, L2_strategy)


def wgt_shimtile_mm2s_wgt(
    dims: ConvDims,
    col: int,
    L2_strategy: str = 'stream',
    weights_channel_list: 'List[int] | None' = None,
    channel_idx: int = 0,
) -> List[str]:
    '''Return the WGT shim MM2S read formats for one column ('wgt' gemm mode).

    The unique Co blocks of the column are divided evenly between the fill
    channels; the channel at position ``channel_idx`` reads its own
    ``Cob`` range. The same real format is emitted once per (Y, X) phase,
    followed by a dummy access when a Co depad phase is present.

    Args:
        dims: conv dimensions.
        col: AIE column.
        L2_strategy: only 'stream' is supported.
        weights_channel_list: fill channels sharing the Co blocks; ``None``
            means the single-channel default ``[1]``.
        channel_idx: position of this channel within ``weights_channel_list``.

    Raises:
        ValueError: if ``L2_strategy`` is not 'stream', or if the number of
            unique Co blocks is not divisible by the number of channels.
    '''
    if L2_strategy != 'stream':
        raise ValueError(f"Unknown L2_strategy: {L2_strategy}. "
                         "Supported strategies are 'stream'.")
    if weights_channel_list is None:
        # Avoid a shared mutable default argument
        weights_channel_list = [1]

    # The block-to-channel assignment is the same for every phase, so it is
    # computed (and logged) once instead of once per (Y, X) iteration.
    Co_idxs = Co_split_idxs(dims, col)
    log(f"DEBUG: WGT shim tile MM2S Co_idxs: {Co_idxs}, weights_channel_list: {weights_channel_list}, channel_idx: {channel_idx}")
    # Sets have no defined iteration order; sort so the positional slice
    # below hands each channel a deterministic, ascending group of blocks.
    unique_co_blocks = sorted(set(Co_idxs))
    num_channels = len(weights_channel_list)
    num_blocks = len(unique_co_blocks)
    # Validate that blocks can be evenly distributed
    if num_blocks % num_channels != 0:
        raise ValueError(
            f"Number of Co blocks ({num_blocks}) in column {col} is not divisible "
            f"by number of channels ({num_channels}). "
            f"Co blocks must be evenly distributed across channels."
        )
    blocks_per_channel = num_blocks // num_channels
    start_idx = channel_idx * blocks_per_channel
    channel_co_blocks = unique_co_blocks[start_idx:start_idx + blocks_per_channel]
    # NOTE: Traversal along Cob within a column will always have stride one,
    # so only the [min, max + 1) range is needed, not an explicit stride.
    Cob_start = min(channel_co_blocks)
    Cob_stop = max(channel_co_blocks) + 1
    log(f"DEBUG: Channel {channel_idx} handles Co blocks: {channel_co_blocks}, range: [{Cob_start}:{Cob_stop})")

    real_fmt = (
        f'Cob:0:{dims.Co_loop * dims.Co_split}:{dims.Co_split} '
        f'Cib:0:{dims.Ci_loop} '
        f'Cob:{Cob_start}:{Cob_stop} '
        f'Subv:0:{dims.wgt_L1_size}'
    )
    # A Co depad phase exists when Co is smaller than the padded Co extent
    has_co_depad = (dims.Co < dims.Co_loop * dims.Co_split * dims.Cos) and (dims.Co_loop > 1)
    phase_fmts = [real_fmt, 'Cob:0:0 Subv:0:0'] if has_co_depad else [real_fmt]

    read_fmt = []
    for _ in range(dims.Y_loop):
        for _ in range(dims.X_loop):
            read_fmt.extend(phase_fmts)
    return read_fmt


def wgt_shimtile_mm2s_act(dims: ConvDims, col: int, is_Co_depad: bool, L2_strategy: str = 'stream') -> List[str]:
    '''Return the WGT shim MM2S read formats for one column ('act' gemm mode).

    Args:
        dims: conv dimensions.
        col: AIE column.
        is_Co_depad: True when an extra Co depad phase follows each base phase.
        L2_strategy: 'pin' fetches the full ``Ci_orig`` range once per (Y, X)
            phase; 'stream' fetches per-phase Ci/Co slices.

    Raises:
        ValueError: if ``L2_strategy`` is neither 'pin' nor 'stream'.
    '''
    # pylint: disable-next=R1702
    read_fmt = []
    if L2_strategy == 'pin':
        # Pin Ci strategy: Fetch full Ci once per Y/X phase
        for y_iter in range(dims.Y_loop):
            for _ in range(dims.X_loop):
                Yi_shard_start, Yi_shard_stop, _ = Yi_slice_per_column(dims, col, y_iter)
                # Co range spans from the first Co iteration's start to the
                # last Co iteration's stop
                Co_shard_start, _, _ = Co_slice_per_column(dims, col, 0)
                _, Co_shard_stop, _ = Co_slice_per_column(dims, col, dims.Co_loop - 1)
                # Fetch full Ci dimension, all Co blocks for this column
                read_fmt.append(
                    f'Yi:{Yi_shard_start}:{Yi_shard_stop} '
                    f'Co:{Co_shard_start}:{Co_shard_stop}:{dims.Co_gran} '
                    f'Ci:0:{dims.Ci_orig} '
                    f'Co:0:{dims.Co_gran} '
                )
                if is_Co_depad:
                    # Dummy format for depad phase - no additional fetch needed
                    read_fmt.append('Ci:0:0')
        log(f"INFO: WGT shim tile MM2S (pin) fmt for col {col}: {read_fmt}")
        return read_fmt
    # pylint: disable-next=R1702
    if L2_strategy == 'stream':
        for y_iter in range(dims.Y_loop):
            for _ in range(dims.X_loop):
                Yi_shard_start, Yi_shard_stop, _ = Yi_slice_per_column(dims, col, y_iter)
                if dims.Ci % dims.Cis != 0 and dims.Ci > dims.Ci_gran:
                    # Additional Ci padding phases: one real fetch per Co
                    # iteration, each followed by a dummy access
                    for Co_iter in range(dims.Co_loop):
                        Co_start, Co_stop, _ = Co_slice_per_column(dims, col, Co_iter)
                        if dims.transpose_wgts:
                            read_fmt.append(
                                f'Yi:{Yi_shard_start}:{Yi_shard_stop} '
                                f'Ci:0:{dims.Ci}:{dims.Ci_gran} '
                                f'Co:{Co_start}:{Co_stop} '
                                f'Ci:0:{dims.Ci_gran} '
                            )
                        else:
                            read_fmt.append(
                                f'Yi:{Yi_shard_start}:{Yi_shard_stop} '
                                f'Ci:0:{dims.Ci} '
                                f'Co:{Co_start}:{Co_stop} '
                            )
                        read_fmt.append('Ci:0:0')
                else:
                    # NO Ci padding phases
                    # Index of the last Co iteration covered by the main
                    # fetch; with Co depad the final iteration is fetched
                    # separately below
                    full_iters = dims.Co_loop - 1
                    co_idxs = Co_split_size(dims)
                    if is_Co_depad:
                        full_iters = dims.Co_loop - 2
                    Co_shard_start, _, _ = Co_slice_per_column(dims, col, 0)
                    _, Co_shard_stop, _ = Co_slice_per_column(dims, col, full_iters)
                    Co_stride = dims.Co_split * dims.Cos
                    Cos_stride = min(Co_shard_stop - Co_shard_start, co_idxs * dims.Cos)
                    if dims.transpose_wgts:
                        read_fmt.append(
                            f'Co:{Co_shard_start}:{Co_shard_stop}:{Co_stride} '
                            f'Yi:{Yi_shard_start}:{Yi_shard_stop} '
                            f'Ci:0:{dims.Ci}:{dims.Ci_gran} '
                            f'Co:0:{Cos_stride} '
                            f'Ci:0:{dims.Ci_gran} '
                        )
                    else:
                        read_fmt.append(
                            f'Co:{Co_shard_start}:{Co_shard_stop}:{Co_stride} '
                            f'Yi:{Yi_shard_start}:{Yi_shard_stop} '
                            f'Ci:0:{dims.Ci} '
                            f'Co:0:{Cos_stride} '
                        )
                    if is_Co_depad:
                        log("INFO: WGT shim tile MM2s Co_pad")
                        # Separate fetch for the last Co iteration
                        Co_shard_start, Co_shard_stop, _ = Co_slice_per_column(dims, col, dims.Co_loop - 1)
                        if dims.transpose_wgts:
                            read_fmt.append(
                                f'Yi:{Yi_shard_start}:{Yi_shard_stop} '
                                f'Ci:0:{dims.Ci}:{dims.Ci_gran} '
                                f'Co:{Co_shard_start}:{Co_shard_stop} '
                                f'Ci:0:{dims.Ci_gran} '
                            )
                        else:
                            # Clamp the stride to the width of the last Co slice
                            last_Co = Co_shard_stop - Co_shard_start
                            Co_stride = min(Co_stride, last_Co)
                            read_fmt.append(
                                f'Co:{Co_shard_start}:{Co_shard_stop}:{Co_stride} '
                                f'Yi:{Yi_shard_start}:{Yi_shard_stop} '
                                f'Ci:0:{dims.Ci} '
                                f'Co:0:{Co_stride} '
                            )
        log(f"INFO: WGT shim tile MM2S fmt for col {col}: {read_fmt}")
        return read_fmt
    # Both strategies handled above are listed so the message matches the code
    raise ValueError(f"Unknown L2_strategy: {L2_strategy}. "
                     "Supported strategies are 'pin' and 'stream'.")


def wgt_shimtile_mm2s(
    dims: ConvDims,
    col: int,
    is_Co_depad: bool,
    L2_strategy: str = 'stream',
    gemm_mode: str = 'wgt',
    weights_channel_list: List[int] = [1],
    channel_idx: int = 0,
) -> List[str]:
    '''Return the WGT shim MM2S tiling formats for one column.

    Thin dispatcher on gemm_mode:
      'act' -> wgt_shimtile_mm2s_act(dims, col, is_Co_depad, L2_strategy)
      'wgt' -> wgt_shimtile_mm2s_wgt(dims, col, L2_strategy,
                                     weights_channel_list, channel_idx)

    is_Co_depad is only forwarded in 'act' mode; weights_channel_list and
    channel_idx are only forwarded in 'wgt' mode.

    Raises ValueError for any other gemm_mode.
    '''
    # Reject unsupported modes up front so the two real paths read flat.
    if gemm_mode not in ('act', 'wgt'):
        raise ValueError(f"Unknown gemm_mode: {gemm_mode}. Supported modes are 'act' and 'wgt'.")
    if gemm_mode == 'wgt':
        return wgt_shimtile_mm2s_wgt(dims, col, L2_strategy, weights_channel_list, channel_idx)
    return wgt_shimtile_mm2s_act(dims, col, is_Co_depad, L2_strategy)


@no_type_check
def generate_wgt_memtile_data_transfers_wgt(
    dims: ConvDims,
    conv_l2_alloc: convL2Memory,
    conv_repeats: ConvDataFlowRepeats,
    wgt_L2_strategy: str = 'stream',
    is_Co_depad: bool = False,
    weights_channel_list: List[int] = [1],
) -> List[DataTransfer]:
    '''Build the WGT memory tile (L2) data transfers for gemm_mode 'wgt'.

    One DataTransfer is appended per (column, weights channel) pair. Each has
      * a single S2MM fill on the memtile channel that map_shim_ch_memtile_ch
        returns for the shim channel, and
      * one MM2S read per (row, channel_id) pair yielded by
        wgt_memtile_channels for that column / channel index.

    The buffer uses the wgt_ping_addr / wgt_pong_addr pair and wgt_size from
    conv_l2_alloc. Format lists are asserted to have as many entries as the
    matching repeat lists in conv_repeats.
    '''
    transfers = []
    log(f"INFO WGT L2 conv repeats.wgt_L2_s2mm_repeats: {conv_repeats.wgt_L2_s2mm_repeats}")
    log(f"INFO WGT L2 conv repeats.wgt_L2_mm2s_repeats: {conv_repeats.wgt_L2_mm2s_repeats}")
    for col in range(dims.aie_cols):
        for channel_idx, fill_channel in enumerate(weights_channel_list):
            # Layout of the full WGT L2 buffer; checked against column 0's repeats.
            l2_layout = wgt_memtile_memory(dims, is_Co_depad, wgt_L2_strategy, gemm_mode='wgt')
            layout_phases = len(l2_layout)
            col0_phases = len(conv_repeats.wgt_L2_s2mm_repeats[0])
            assert layout_phases == col0_phases, (
                f"Column {col} WGT L2 memory fmts and repeats length mismatch: "
                f"{layout_phases} != {col0_phases}"
            )

            # Fill (S2MM) formats: one entry per S2MM phase of this column.
            fill_fmts = wgt_memtile_s2mm(
                dims, col, is_Co_depad, wgt_L2_strategy, 'wgt', weights_channel_list, channel_idx
            )
            fill_repeats = conv_repeats.wgt_L2_s2mm_repeats[col]
            assert len(fill_fmts) == len(fill_repeats), (
                f"Column {col} fmts and repeats length mismatch: "
                f"{len(fill_fmts)} != {len(fill_repeats)}"
            )

            # Drain (MM2S) formats, keyed by row.
            drain_fmts = {}
            for row in range(dims.aie_rows):
                row_fmts = wgt_memtile_mm2s(dims, col, row, is_Co_depad, wgt_L2_strategy, 'wgt')
                drain_repeats = conv_repeats.wgt_L2_mm2s_repeats[col]
                assert len(row_fmts) == len(drain_repeats), (
                    f"Column {col}, row {row} tiling fmts and repeats length mismatch: "
                    f"{len(row_fmts)} != "
                    f"{len(drain_repeats)}"
                )
                drain_fmts[row] = row_fmts

            tile = memory_tile(col)
            buffer_addrs = [conv_l2_alloc.wgt_ping_addr, conv_l2_alloc.wgt_pong_addr]
            buffer_size = conv_l2_alloc.wgt_size
            s2mm_side = [
                pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.S2MM, map_shim_ch_memtile_ch(fill_channel)),
                    l2_layout,
                    fill_fmts,
                )
            ]
            mm2s_side = [
                pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.MM2S, channel_id),
                    l2_layout,
                    drain_fmts[target_row],
                )
                for target_row, channel_id in wgt_memtile_channels(
                    dims, col, weights_channel_list, channel_idx
                )
            ]
            transfers.append(
                DataTransfer(
                    fill_repeats,
                    tile, buffer_addrs, buffer_size,
                    s2mm_side,
                    mm2s_side,
                    sync_strategy=SyncStrategy.Parallel_1_to_N,
                )
            )
    return transfers


@no_type_check
def generate_wgt_memtile_data_transfers_act(
    dims: ConvDims,
    conv_l2_alloc: ConvL2MemoryAllocator,
    conv_repeats: ConvDataFlowRepeats,
    wgt_L2_strategy: str = 'stream',
    is_Co_depad: bool = False,
) -> List[DataTransfer]:
    '''Generate WGT memory tile data transfers for gemm_mode 'act'.

    One DataTransfer is produced per column. It has a single S2MM fill on
    channel overlay_3x4_F_ids()[1] and one MM2S read per (row, channel_id)
    pair from wgt_memtile_channels(dims, col).

    Memory patterns:
      * When Ci is not a multiple of Cis and Ci > Ci_gran, the per-phase
        memory patterns alternate between mem_pattern[0] and mem_pattern[1].
      * Otherwise mem_pattern[0] is used for every phase.
      The phase count is taken from column 0 of wgt_L2_s2mm_repeats.

    Buffer addresses:
      * 'pin' strategy uses only wgt_ping_addr.
      * Any other strategy uses the wgt_ping_addr / wgt_pong_addr pair.

    If wgt_memtile_mm2s returns exactly twice as many formats as there are
    memory patterns, the even-indexed formats feed the main MM2S transfers
    and the odd-indexed ones feed a second ("partial") set of MM2S transfers
    appended after the main ones.
    '''
    data_transfers = []
    mem_pattern = wgt_memtile_memory(dims, is_Co_depad, wgt_L2_strategy, 'act')
    log(f"INFO: WGT L2 memory pattern: {mem_pattern}")
    n_phases = len(conv_repeats.wgt_L2_s2mm_repeats[0])
    if dims.Ci % dims.Cis != 0 and dims.Ci > dims.Ci_gran:
        mem_pattern_list = [mem_pattern[0], mem_pattern[1]] * (n_phases // 2)
    else:
        mem_pattern_list = [mem_pattern[0]] * n_phases
    if wgt_L2_strategy == 'pin':
        wgt_addr = [conv_l2_alloc.wgt_ping_addr]
    else:
        wgt_addr = [conv_l2_alloc.wgt_ping_addr, conv_l2_alloc.wgt_pong_addr]
    for col in range(dims.aie_cols):
        # Full WGT L2 memory transfer
        write_access_pattern = wgt_memtile_s2mm(dims, col, is_Co_depad, wgt_L2_strategy, 'act')
        assert len(write_access_pattern) == len(conv_repeats.wgt_L2_s2mm_repeats[col]), \
            f"Column {col} access patterns and repeats length mismatch: " \
            f"{len(write_access_pattern)} != {len(conv_repeats.wgt_L2_s2mm_repeats[col])}"

        read_access_pattern = {}
        read_access_pattern_partial = {}
        for row in range(dims.aie_rows):
            all_read_fmts = wgt_memtile_mm2s(dims, col, row, is_Co_depad, wgt_L2_strategy, 'act')
            if len(all_read_fmts) == 2 * len(mem_pattern_list):
                # Interleaved list: even indices are the main phases,
                # odd indices are the partial phases.
                read_fmt_list = all_read_fmts[0::2]
                read_access_pattern_partial[row] = all_read_fmts[1::2]
            else:
                read_fmt_list = all_read_fmts
                read_access_pattern_partial[row] = []
            read_access_pattern[row] = read_fmt_list
            assert len(read_fmt_list) == len(conv_repeats.wgt_L2_mm2s_repeats[col]), \
                f"Column {col}, row {row} access patterns and repeats length mismatch: " \
                f"{len(read_fmt_list)} != " \
                f"{len(conv_repeats.wgt_L2_mm2s_repeats[col])}"

        # These logs run after the row loop, so they report the per-row
        # dictionaries for the whole column instead of a leftover row index.
        log(f"DEBUG: WGT L2 MM2S col {col} read_fmt_list per row {read_access_pattern}")
        log(f"DEBUG: WGT L2 MM2S col {col} read_fmt_list_partial per row {read_access_pattern_partial}")
        log(f"DEBUG: WGT L2 S2MM col {col} write_access_pattern {write_access_pattern}")

        mm2s_transfers = []
        # Full iteration formats for all rows
        for row, channel_id in wgt_memtile_channels(dims, col):
            mm2s_transfers.append(
                pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.MM2S, channel_id),
                    mem_pattern_list,
                    read_access_pattern[row],
                    bits_per_elem=dims.wgt_bits
                )
            )
        # Partial iteration formats for all rows (only if any row has them)
        if any(read_access_pattern_partial[row] for row in range(dims.aie_rows)):
            for row, channel_id in wgt_memtile_channels(dims, col):
                mm2s_transfers.append(
                    pack_reconfig_transfers(
                        memtile_dma(col, DmaDir.MM2S, channel_id),
                        mem_pattern_list,
                        read_access_pattern_partial[row],
                        bits_per_elem=dims.wgt_bits
                    )
                )

        per_col_data_transfer = DataTransfer(
            conv_repeats.wgt_L2_s2mm_repeats[col],
            memory_tile(col), wgt_addr, conv_l2_alloc.wgt_size,
            [
                pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[1]),
                    mem_pattern_list,
                    write_access_pattern,
                    bits_per_elem=dims.wgt_bits
                )
            ],
            mm2s_transfers,
            sync_strategy=SyncStrategy.Parallel_1_to_N
        )
        data_transfers.append(per_col_data_transfer)
    return data_transfers


def generate_wgt_memtile_data_transfers(
    dims: ConvDims,
    conv_l2_alloc: ConvL2MemoryAllocator,
    conv_repeats: ConvDataFlowRepeats,
    wgt_L2_strategy: str = 'stream',
    is_Co_depad: bool = False,
    gemm_mode: str = 'wgt',
    weights_channel_list: List[int] = [1],
) -> List[DataTransfer]:
    '''Dispatch WGT memory tile data transfer generation on gemm_mode.

    'wgt' forwards to generate_wgt_memtile_data_transfers_wgt (including
    weights_channel_list); 'act' forwards to
    generate_wgt_memtile_data_transfers_act (which takes no channel list).

    Raises ValueError for any other gemm_mode.
    '''
    if gemm_mode == 'wgt':
        return generate_wgt_memtile_data_transfers_wgt(
            dims, conv_l2_alloc, conv_repeats, wgt_L2_strategy, is_Co_depad,
            weights_channel_list
        )
    if gemm_mode != 'act':
        raise ValueError(f"Unknown gemm_mode: {gemm_mode}. Supported modes are 'act' and 'wgt'.")
    return generate_wgt_memtile_data_transfers_act(
        dims, conv_l2_alloc, conv_repeats, wgt_L2_strategy, is_Co_depad
    )


@no_type_check
def generate_wgt_shimtile_data_transfers_wgt(
    dims: ConvDims,
    conv_repeats: ConvDataFlowRepeats,
    wgt_xrt_idx: int,
    wgt_xrt_offset: int,
    wgt_L2_strategy: str = 'stream',
    is_Co_depad: bool = False,
    weights_channel_list: List[int] = [1],    # NOTE: Default to use only mm2s1 for weights
) -> List[DataTransfer]:
    '''Build the WGT shim tile data transfers for gemm_mode 'wgt'.

    One MM2S-only DataTransfer is appended per (column, shim channel) pair.
    Every transfer reads from buffer index wgt_xrt_idx at wgt_xrt_offset and
    uses the shim memory format from wgt_shim_memory(dims, 'wgt') for all
    phases. The read formats are asserted to have as many entries as
    conv_repeats.wgt_L3_mm2s_repeats[col].
    '''
    # NOTE: The intent is to use 2 channels per column to bring in weights,
    # specifically mm2s1 and mm2s2, as they fall into 2 separate external
    # memory controllers, thereby maximizing parallel fetches and DDR BW usage.
    transfers = []
    for col in range(dims.aie_cols):
        shim_layout = wgt_shim_memory(dims, 'wgt')
        log(f"INFO: WGT shim tile memory pattern: {shim_layout}")
        ddr_wgt_size = compute_buffer_size(shim_layout)
        for channel_idx, channel in enumerate(weights_channel_list):
            # Read formats for this column / channel, one per MM2S phase.
            phase_fmts = wgt_shimtile_mm2s(
                dims, col, is_Co_depad, wgt_L2_strategy, 'wgt', weights_channel_list, channel_idx
            )
            assert len(phase_fmts) == len(conv_repeats.wgt_L3_mm2s_repeats[col]), (
                f"Column {col} fmts and repeats length mismatch: "
                f"{len(phase_fmts)} != {len(conv_repeats.wgt_L3_mm2s_repeats[col])}"
            )
            transfers.append(
                DataTransfer(
                    conv_repeats.wgt_L3_mm2s_repeats[col],
                    shim_tile(col), [wgt_xrt_idx], ddr_wgt_size,
                    [],
                    [
                        pack_reconfig_transfers(
                            shim_dma(col, DmaDir.MM2S, channel),
                            [shim_layout] * len(conv_repeats.wgt_L3_mm2s_repeats[col]),
                            phase_fmts,
                            buffer_offset=[wgt_xrt_offset],
                        )
                    ],
                )
            )
    return transfers


@no_type_check
def generate_wgt_shimtile_data_transfers_act(
    dims: ConvDims,
    conv_repeats: ConvDataFlowRepeats,
    conv_shim: L3Alloc,
    wgt_L2_strategy: str = 'stream',
    is_Co_depad: bool = False,
) -> List[DataTransfer]:
    '''Generate WGT shim tile data transfers for gemm_mode 'act'.

    One MM2S-only DataTransfer on shim channel 1 is produced per column,
    using the shim memory pattern from wgt_shim_memory(dims, 'act') for every
    phase and the "ifm1" entry of conv_shim.ifm_xrt_offset as buffer offset.

    Packing is attempted first without iteration stepping. If that raises,
    it is retried with use_iter_step enabled for every phase; in that case
    pack_reconfig_transfers is unpacked as a 2-tuple and each repeat count is
    multiplied by dims.Co_loop. The fallback is logged so it is visible.
    '''
    data_transfers = []
    mem_pattern = wgt_shim_memory(dims, 'act')
    ddr_wgt_size = compute_buffer_size(mem_pattern)
    log(f"INFO: WGT shim tile memory pattern: {mem_pattern}")
    conv_shim_wgt_xrt_offset = conv_shim.ifm_xrt_offset[dims.ifm_to_xrt_idx["ifm1"]]
    for col in range(dims.aie_cols):
        # Full WGT L2 memory transfer
        read_fmt = wgt_shimtile_mm2s(dims, col, is_Co_depad, wgt_L2_strategy, 'act')
        log(f"DEBUG: WGT shim tile MM2S col {col} read_fmt {read_fmt}")
        col_repeats = conv_repeats.wgt_L3_mm2s_repeats[col]
        n_phases = len(col_repeats)
        assert len(read_fmt) == n_phases, \
            f"Column {col} access patterns and repeats length mismatch: " \
            f"{len(read_fmt)} != {n_phases}"

        try:
            transfer_params_list = pack_reconfig_transfers(
                shim_dma(col, DmaDir.MM2S, 1),
                [mem_pattern] * n_phases,
                read_fmt,
                buffer_offset=[conv_shim_wgt_xrt_offset] * n_phases,
                bits_per_elem=dims.wgt_bits,
            )
            rc_list = col_repeats
        except Exception as exc:  # pylint: disable=broad-exception-caught
            # Deliberate best-effort fallback: retry with iteration stepping,
            # but leave a trace of why the first attempt was rejected.
            log(f"INFO: WGT shim tile MM2S col {col}: reconfig packing failed ({exc}); "
                "retrying with use_iter_step")
            _, transfer_params_list = pack_reconfig_transfers(
                shim_dma(col, DmaDir.MM2S, 1),
                [mem_pattern] * n_phases,
                read_fmt,
                buffer_offset=[conv_shim_wgt_xrt_offset] * n_phases,
                bits_per_elem=dims.wgt_bits,
                use_iter_step=[True] * n_phases,
            )
            rc_list = [count * dims.Co_loop for count in col_repeats]
        # NOTE(review): the buffer index is the hard-coded ifm_xrt_idx[1] while
        # the offset above is looked up via dims.ifm_to_xrt_idx["ifm1"] --
        # confirm both always refer to the same buffer.
        per_col_data_transfer = DataTransfer(
            rc_list,
            shim_tile(col), [conv_shim.ifm_xrt_idx[1]], ddr_wgt_size,
            [],
            [
                transfer_params_list
            ]
        )
        data_transfers.append(per_col_data_transfer)
    return data_transfers


def generate_wgt_shimtile_data_transfers(
    dims: ConvDims,
    conv_repeats: ConvDataFlowRepeats,
    conv_shim: ShimAllocator,
    wgt_L2_strategy: str = 'stream',
    is_Co_depad: bool = False,
    gemm_mode: str = 'wgt',
    weights_channel_split: List[int] = [1],     # NOTE: Default uses mm2s1 only for weights
) -> List[DataTransfer]:
    '''Dispatch WGT shim tile data transfer generation on gemm_mode.

    'wgt' forwards conv_shim.wgt_xrt_idx / conv_shim.wgt_xrt_offset and
    weights_channel_split to generate_wgt_shimtile_data_transfers_wgt;
    'act' forwards conv_shim itself to
    generate_wgt_shimtile_data_transfers_act.

    Raises ValueError for any other gemm_mode.
    '''
    if gemm_mode == 'wgt':
        return generate_wgt_shimtile_data_transfers_wgt(
            dims,
            conv_repeats,
            conv_shim.wgt_xrt_idx,
            conv_shim.wgt_xrt_offset,
            wgt_L2_strategy,
            is_Co_depad,
            weights_channel_split,
        )
    if gemm_mode != 'act':
        raise ValueError(f"Unknown gemm_mode: {gemm_mode}. Supported modes are 'act' and 'wgt'.")
    return generate_wgt_shimtile_data_transfers_act(
        dims, conv_repeats, conv_shim, wgt_L2_strategy, is_Co_depad
    )


#####################################################
# OFM memory and tiling Formats
#####################################################


def ofm_memtile_memory(dims: ConvDims, L2_strategy: str = 'stream') -> str:
    '''Return the OFM L2 memory format string "Yo:<n> Xo:<n> Co:<n>".

    Each extent is the matching *_split_size(dims) multiplied by the
    per-core subvolume size (dims.Yos, dims.Xos, dims.Cos).

    Raises ValueError unless L2_strategy is 'stream'.
    '''
    if L2_strategy != 'stream':
        raise ValueError(f"Unknown L2_strategy: {L2_strategy}. "
                         "Supported strategies are 'stream'.")
    yo_blocks = Yo_split_size(dims)
    xo_blocks = Xo_split_size(dims)
    co_blocks = Co_split_size(dims)
    extents = (
        ('Yo', yo_blocks * dims.Yos),
        ('Xo', xo_blocks * dims.Xos),
        ('Co', co_blocks * dims.Cos),
    )
    return ' '.join(f'{axis}:{extent}' for axis, extent in extents)


def ofm_memtile_s2mm(dims: ConvDims, col: int, row: int, is_Co_depad: bool, L2_strategy: str = 'stream', gemm_mode: str = 'wgt') -> list[str]:
    '''Return the OFM L2 S2MM tiling formats for the core at (col, row).

    For every (Y_loop, X_loop) iteration the same core window
    "Yo:a:b Xo:c:d Co:e:f" is emitted, where each start is the matching
    *_split_offset(dims, col, row) times the subvolume size.

      * 'act' mode with Ci padding (Ci % Cis != 0 and Ci > Ci_gran):
        Co_loop pairs of ('Co:0:0', window).
      * otherwise: the window once, and a second time when is_Co_depad.

    Raises ValueError unless L2_strategy is 'stream'.
    '''
    if L2_strategy != 'stream':
        raise ValueError(f"Unknown L2_strategy: {L2_strategy}. "
                         "Supported strategies are 'stream'.")
    tiling_fmts = []
    for _ in range(dims.Y_loop):
        for _ in range(dims.X_loop):
            # Window written by this core inside the column's L2 buffer.
            yo_lo = Yo_split_offset(dims, col, row) * dims.Yos
            yo_hi = yo_lo + dims.Yos
            xo_lo = Xo_split_offset(dims, col, row) * dims.Xos
            xo_hi = xo_lo + dims.Xos
            co_lo = Co_split_offset(dims, col, row) * dims.Cos
            co_hi = co_lo + dims.Cos
            window = f'Yo:{yo_lo}:{yo_hi} Xo:{xo_lo}:{xo_hi} Co:{co_lo}:{co_hi}'
            if dims.Ci % dims.Cis != 0 and dims.Ci > dims.Ci_gran and gemm_mode == 'act':
                for _ in range(dims.Co_loop):
                    tiling_fmts.extend(('Co:0:0', window))
            elif is_Co_depad:
                # The depad phase adds one more phase, but it reuses the exact
                # same tiling fmt because the full subvolume is drained to L2.
                tiling_fmts.extend((window, window))
            else:
                tiling_fmts.append(window)
    log(f"INFO: OFM memtile S2MM fmt col {col} row {row}: {tiling_fmts}")
    return tiling_fmts


def ofm_memtile_mm2s(
    dims: ConvDims,
    col: int,
    is_Co_depad: bool,
    L2_strategy: str = 'stream',
    gemm_mode: str = 'wgt'
) -> List[str]:
    '''Return the OFM L2 MM2S tiling formats for one column.

    For every (y_iter, x_iter) pair the formats read "Yo:0:<y> Xo:0:<x>
    Co:0:<c>", with the sizes taken from the third element returned by the
    *_slice_per_column helpers.

      * 'act' mode with Ci padding (Ci % Cis != 0 and Ci > Ci_gran):
        for each Co iteration a dummy 'Co:0:0' followed by that iteration's
        format.
      * otherwise: one format using Co iteration 0, plus one using the last
        Co iteration when is_Co_depad.

    Raises ValueError unless L2_strategy is 'stream'.
    '''
    if L2_strategy != 'stream':
        raise ValueError(f"Unknown L2_strategy: {L2_strategy}. Supported strategies are 'stream'")
    phase_fmts = []
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            _, _, yo_extent = Yo_slice_per_column(dims, col, y_iter)
            _, _, xo_extent = Xo_slice_per_column(dims, col, x_iter)
            spatial = f'Yo:0:{yo_extent} Xo:0:{xo_extent}'
            if gemm_mode == 'act' and dims.Ci % dims.Cis != 0 and dims.Ci > dims.Ci_gran:
                # Ci padding: a dummy phase precedes every Co iteration.
                for co_iter in range(dims.Co_loop):
                    phase_fmts.append('Co:0:0')
                    _, _, co_extent = Co_slice_per_column(dims, col, co_iter)
                    phase_fmts.append(f'{spatial} Co:0:{co_extent}')
                continue
            # Standard case: a single format sized by the first Co iteration.
            _, _, co_extent = Co_slice_per_column(dims, col, 0)
            phase_fmts.append(f'{spatial} Co:0:{co_extent}')
            if is_Co_depad:
                # Extra depad phase sized by the last Co iteration.
                _, _, co_extent_last = Co_slice_per_column(dims, col, dims.Co_loop - 1)
                phase_fmts.append(f'{spatial} Co:0:{co_extent_last}')

    log(f"INFO: OFM memtile MM2S fmt for col {col}: {phase_fmts}")
    return phase_fmts


def ofm_shimtile_s2mm(
    dims: ConvDims,
    col: int,
    is_Co_depad: bool,
    L2_strategy: str = 'stream',
    gemm_mode: str = 'wgt'
) -> List[str]:
    '''Return the OFM shim S2MM tiling formats for one column.

    For every (y_iter, x_iter) pair:
      * 'act' mode with Ci padding (Ci % Cis != 0 and Ci > Ci_gran):
        for each Co iteration a dummy 'Co:0:0' followed by
        "Yo:a:b Xo:c:d Co:e:f" for that iteration's Co slice.
      * otherwise: one strided format
        "Co:start:stop:stride Yo:a:b Xo:c:d Co:0:inner" spanning Co iteration
        0 up to the last full iteration, plus an unstrided format for the
        last Co iteration when is_Co_depad.

    Raises ValueError unless L2_strategy is 'stream'.
    '''
    if L2_strategy != 'stream':
        raise ValueError(f"Unknown L2_strategy: {L2_strategy}. Supported strategies are 'stream'")
    phase_fmts = []
    ci_pad_phases = (
        dims.Ci % dims.Cis != 0
        and dims.Ci > dims.Ci_gran
        and gemm_mode == 'act'
    )
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            yo_lo, yo_hi, _ = Yo_slice_per_column(dims, col, y_iter)
            xo_lo, xo_hi, _ = Xo_slice_per_column(dims, col, x_iter)
            spatial = f'Yo:{yo_lo}:{yo_hi} Xo:{xo_lo}:{xo_hi}'
            if ci_pad_phases:
                # Ci padding: a dummy phase precedes every Co iteration.
                for co_iter in range(dims.Co_loop):
                    phase_fmts.append('Co:0:0')
                    co_lo, co_hi, _ = Co_slice_per_column(dims, col, co_iter)
                    phase_fmts.append(f'{spatial} Co:{co_lo}:{co_hi}')
                continue
            # Standard case: one strided transfer over the full Co iterations.
            blocks_per_col = Co_split_size(dims)
            if is_Co_depad:
                last_full_iter = dims.Co_loop - 2
            else:
                last_full_iter = dims.Co_loop - 1
            co_lo, _, _ = Co_slice_per_column(dims, col, 0)
            _, co_hi, _ = Co_slice_per_column(dims, col, last_full_iter)
            outer_stride = dims.Co_split * dims.Cos
            inner_extent = min(co_hi - co_lo, blocks_per_col * dims.Cos)
            phase_fmts.append(
                f'Co:{co_lo}:{co_hi}:{outer_stride} {spatial} Co:0:{inner_extent}'
            )
            if is_Co_depad:
                # Depad phase: the last Co iteration, written without striding.
                co_lo, co_hi, _ = Co_slice_per_column(dims, col, dims.Co_loop - 1)
                phase_fmts.append(f'{spatial} Co:{co_lo}:{co_hi}')
    return phase_fmts


@no_type_check
def generate_ofm_memtile_data_transfers(
    dims: ConvDims,
    conv_l2_alloc: convL2Memory,
    conv_repeats: ConvDataFlowRepeats,
    ofm_L2_strategy: str = 'stream',
    is_Co_depad: bool = False,
    gemm_mode: str = 'wgt'
) -> List[DataTransfer]:
    '''Generate OFM memory tile data transfers.

    One DataTransfer is produced per column. It has one S2MM write per
    (row, channel_id) pair from ofm_memtile_channels() and a single MM2S read
    on channel overlay_3x4_S_ids(col)[0]. The buffer uses the
    ofm_ping_addr / ofm_pong_addr pair and ofm_size from conv_l2_alloc, and
    the same memory format from ofm_memtile_memory for every phase.

    Per-row write formats are asserted against ofm_L2_s2mm_repeats[col] and
    the read formats against ofm_L2_mm2s_repeats[col].
    '''
    data_transfers = []
    log(f"Co blocks per column: {[Co_split_size(dims) for _ in range(dims.aie_cols)]}")
    log(f"INFO OFM L2 conv repeats.ofm_L2_s2mm_repeats: {conv_repeats.ofm_L2_s2mm_repeats}")
    mem_fmt = ofm_memtile_memory(dims, ofm_L2_strategy)
    for col in range(dims.aie_cols):
        # Full OFM L2 memory transfer: collect the S2MM formats of every row.
        write_tiling_fmt = {}
        for row in range(dims.aie_rows):
            write_fmt_per_row = ofm_memtile_s2mm(dims, col, row, is_Co_depad, ofm_L2_strategy, gemm_mode)
            log(f"INFO: OFM L2 S2MM fmt for column {col}, row {row}: write_fmt : {write_fmt_per_row}")
            # Check if the number of phases matches the number of fmts
            assert len(write_fmt_per_row) == len(conv_repeats.ofm_L2_s2mm_repeats[col]), \
                f"OFM L2 Column {col} row{row} write fmts and repeats length mismatch: " \
                f"{len(write_fmt_per_row)} != {len(conv_repeats.ofm_L2_s2mm_repeats[col])}"
            write_tiling_fmt[row] = write_fmt_per_row

        read_tiling_fmt = ofm_memtile_mm2s(dims, col, is_Co_depad, ofm_L2_strategy, gemm_mode)
        assert len(read_tiling_fmt) == len(conv_repeats.ofm_L2_mm2s_repeats[col]), \
            f"OFM L2 Column {col}, read fmts and repeats length mismatch: " \
            f"fmt_list length {len(read_tiling_fmt)} != repeat_list length {len(conv_repeats.ofm_L2_mm2s_repeats[col])}"
        # Column summary: log the full row -> formats mapping. The row loop
        # variable is not used here, since it only holds the last row.
        log(f"INFO: OFM L2 S2MM fmt for column {col}: write_fmt : {write_tiling_fmt}")
        log(f"INFO: OFM L2 MM2S fmt for column {col}: read_fmt : {read_tiling_fmt}")

        per_col_data_transfer = DataTransfer(
            conv_repeats.ofm_L2_s2mm_repeats[col],
            memory_tile(col), [conv_l2_alloc.ofm_ping_addr, conv_l2_alloc.ofm_pong_addr], conv_l2_alloc.ofm_size,
            [
                pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.S2MM, channel_id),
                    [mem_fmt for _ in range(len(conv_repeats.ofm_L2_s2mm_repeats[col]))],
                    write_tiling_fmt[src_row],
                    bits_per_elem=dims.ofm_bits
                ) for src_row, channel_id in ofm_memtile_channels()
            ],
            [
                pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.MM2S, overlay_3x4_S_ids(col)[0]),
                    [mem_fmt for _ in range(len(read_tiling_fmt))],
                    read_tiling_fmt,
                    bits_per_elem=dims.ofm_bits
                )
            ],
            sync_strategy=SyncStrategy.Parallel_N_to_1,
        )
        data_transfers.append(per_col_data_transfer)
    return data_transfers


@no_type_check
def generate_ofm_shimtile_data_transfers(
    dims: ConvDims,
    conv_repeats: ConvDataFlowRepeats,
    conv_shim: ShimAllocator,
    is_Co_depad: bool,
    ofm_L2_strategy: str = 'stream',
    gemm_mode: str = 'wgt'
) -> List[DataTransfer]:
    '''Build the OFM shim tile data transfers.

    One S2MM-only DataTransfer on shim channel 0 is appended per column. It
    writes to buffer index conv_shim.ofm_xrt_idx at conv_shim.ofm_xrt_offset
    and uses the shim memory format from ofm_shim_memory(dims) for every
    phase. The write formats are asserted to have as many entries as
    conv_repeats.ofm_L3_s2mm_repeats[col].
    '''
    shim_layout = ofm_shim_memory(dims)
    ddr_ofm_size = compute_buffer_size(shim_layout)
    transfers = []
    for col in range(dims.aie_cols):
        # Full OFM L2 memory transfer
        log(f"INFO: OFM shim tile memory fmt: {shim_layout}")

        s2mm_fmts = ofm_shimtile_s2mm(dims, col, is_Co_depad, ofm_L2_strategy, gemm_mode)
        log(f"INFO: OFM shim tile S2MM fmt for column {col}: write_fmt : {s2mm_fmts}")

        assert len(s2mm_fmts) == len(conv_repeats.ofm_L3_s2mm_repeats[col]), (
            f"Column {col} fmts and repeats length mismatch: "
            f"{len(s2mm_fmts)} != {len(conv_repeats.ofm_L3_s2mm_repeats[col])}"
        )

        transfers.append(
            DataTransfer(
                conv_repeats.ofm_L3_s2mm_repeats[col],
                shim_tile(col), [conv_shim.ofm_xrt_idx], ddr_ofm_size,
                [
                    pack_reconfig_transfers(
                        shim_dma(col, DmaDir.S2MM, 0),
                        [shim_layout] * len(conv_repeats.ofm_L3_s2mm_repeats[col]),
                        s2mm_fmts,
                        bits_per_elem=dims.ofm_bits,
                        buffer_offset=[conv_shim.ofm_xrt_offset],
                    )
                ],
                [],
            )
        )
    return transfers


# NOTE: Mypy doesn't correctly infer types with the abstract base construct
# used by the core instruction list, so we disable type checking locally here
@no_type_check
def compile_L3_dataflow(schedule_input: ScheduleInputs) -> tuple:
    '''Compile the L3 dataflow for the given shape and mapping.

    Derives the conv dimensions, L2 strategies and repeat counts, builds the
    per-core instruction lists and the memtile / shim data transfers
    (layer params, QDQ params in 'act' gemm mode, IFM, WGT, OFM) and passes
    everything to run_layer_compilation.

    Args:
        schedule_input: provides shape, mapping, L3_alloc, kernel_names,
            kernel_includes, backend, layer_file_name and dma_pad.

    Returns:
        A tuple (shim_prm_offset_next_layer, shim_wgt_offset_next_layer):
        this layer's param / weight shim offsets advanced by the size of the
        param shim buffer and of the weight (or QDQ param) shim buffer.

    Raises:
        ValueError: if split_to_mode(dims) returns anything other than 0 or 1.
    '''
    shape: ConvShape = schedule_input.shape
    mapping: ConvMapping = schedule_input.mapping
    L3_alloc: L3Alloc = schedule_input.L3_alloc
    aie_cols = 3
    aie_rows = 4
    if shape.linear_op_type in [LinearOpType.gemm_A16A16_v2, LinearOpType.gemm_A16A16_v1]:
        gemm_mode = 'act'
    else:
        gemm_mode = 'wgt'

    overlay_shape = OverlayShape(aie_cols, aie_rows)
    dims = ConvDims(shape, mapping)
    ifm_L2_strategy = mapping.ifm_L2_strategy
    # Weights are pinned in L2 only when Ci differs from Ci_orig and the
    # weights are not transposed; otherwise they are streamed.
    wgt_L2_strategy = 'pin' if dims.Ci != dims.Ci_orig and not dims.transpose_wgts else 'stream'
    ofm_L2_strategy = mapping.ofm_L2_strategy

    # NOTE: This is the list of shim MM2S channels used to fetch weights.
    # A single channel streams the weights by default; the list MUST contain
    # the channel numbers. In 'wgt' mode with a Co_split of 2, 4, 6 or 12,
    # shim MM2S channels 1 and 2 are both used.
    weights_channel_list = [1]
    if gemm_mode == 'wgt' and dims.Co_split in [2, 4, 6, 12]:
        weights_channel_list = [1, 2]

    # Convert the L3 allocation into the shim allocator that supplies the
    # XRT buffer indices and offsets used by the shim transfers below.
    conv_shim = L3Alloc_to_Shim(L3_alloc)
    log(f"Shim Allocator: {conv_shim}")
    # Check if Co_depad is needed as a new phase
    # This is used to re-align the phases and the fmts
    is_Co_depad = (dims.Co < dims.Co_loop * dims.Co_split * dims.Cos) and (dims.Co_loop > 1)
    # With pinning of IFM, Ci_pad can be handled with chaining
    # Will not incur additional phase
    # Not necessary to re-align the phases and the fmts
    is_Ci_padded = dims.Ci < dims.Ci_loop * dims.Cis
    is_Ci_pad_phase = is_Ci_padded and dims.Ci_loop > 1
    log(f"INFO: is_Ci_padded: {is_Ci_padded}")
    log(f"INFO: is_Ci_pad_phase: {is_Ci_pad_phase}")
    # If the weight repeat is too high, we need to split the WGT L2 repeats into multiple phases
    # This is used to re-align the phases and the fmts
    is_wgt_repeat_high = dims.Ci_loop * dims.Co_loop > 32768
    # NOTE: The ifm memtile memory returns a list of memory formats, one per phase,
    # as the amount of data held in L2 varies with the phase.
    # Here we find the memory format that requires the largest buffer size
    # and use that format for buffer allocation. On equal sizes the format
    # seen last is kept.
    ifm_L2_mem_fmt_for_buff_alloc = ''
    ifm_L2_size = 0
    for col in range(dims.aie_cols):
        for fmt in ifm_memtile_memory(dims, col, ifm_L2_strategy):
            fmt_size = compute_buffer_size(fmt, dims.ifm_bits)
            if fmt_size >= ifm_L2_size:
                ifm_L2_size = fmt_size
                ifm_L2_mem_fmt_for_buff_alloc = fmt

    # A buffer is double buffered exactly when its L2 strategy is 'stream'
    conv_l2_alloc = ConvL2MemoryAllocator(
        dims,
        prm_memtile_memory(),
        ifm_L2_mem_fmt_for_buff_alloc,
        wgt_memtile_memory(dims, is_Co_depad, wgt_L2_strategy, gemm_mode)[0],
        ofm_memtile_memory(dims, ofm_L2_strategy),
        ifm_double_buffer=(ifm_L2_strategy == 'stream'),
        wgt_double_buffer=(wgt_L2_strategy == 'stream'),
        ofm_double_buffer=(ofm_L2_strategy == 'stream'),
        gemm_mode=gemm_mode
    )

    # log the L2 memory allocation details (logged once; the allocator is
    # not modified afterwards in this function)
    log(f"INFO conv_l2_alloc.param_addr: {conv_l2_alloc.param_addr} "
        f"conv_l2_alloc.param_size: {conv_l2_alloc.param_size}")
    log(f"INFO conv_l2_alloc.ifm_ping_addr: {conv_l2_alloc.ifm_ping_addr} "
        f"conv_l2_alloc.ifm_L2_size: {conv_l2_alloc.ifm_size}")
    log(f"INFO conv_l2_alloc.ifm_pong_addr: {conv_l2_alloc.ifm_pong_addr} "
        f"conv_l2_alloc.ifm_L2_size: {conv_l2_alloc.ifm_size}")
    log(f"INFO conv_l2_alloc.wgt_ping_addr: {conv_l2_alloc.wgt_ping_addr} "
        f"conv_l2_alloc.wgt_L2_size: {conv_l2_alloc.wgt_size}")
    log(f"INFO conv_l2_alloc.wgt_pong_addr: {conv_l2_alloc.wgt_pong_addr} "
        f"conv_l2_alloc.wgt_L2_size: {conv_l2_alloc.wgt_size}")
    log(f"INFO conv_l2_alloc.ofm_ping_addr: {conv_l2_alloc.ofm_ping_addr} "
        f"conv_l2_alloc.ofm_L2_size: {conv_l2_alloc.ofm_size}")
    log(f"INFO conv_l2_alloc.ofm_pong_addr: {conv_l2_alloc.ofm_pong_addr} "
        f"conv_l2_alloc.ofm_L2_size: {conv_l2_alloc.ofm_size}")
    dma_connections = overlay_3x4_dma_connections()
    data_stream_mode = split_to_mode(dims)
    if data_stream_mode == 0:
        log("INFO: IFM unicast / WGT broadcast mode")
    elif data_stream_mode == 1:
        log("INFO: IFM broadcast / WGT unicast mode")
    else:
        raise ValueError(f"Unknown data stream mode: {data_stream_mode}")
    log(f"INFO: ConvMapping: {mapping}")
    log(f"Compiling L3 dataflow for dims: {dims}")
    log(f"IFM L2 strategy: {ifm_L2_strategy}")
    log(f"WGT L2 strategy: {wgt_L2_strategy}")
    log(f"OFM L2 strategy: {ofm_L2_strategy}")
    conv_repeats = generate_conv_repeats(dims,
                                         ifm_L2_strategy,
                                         wgt_L2_strategy,
                                         ofm_L2_strategy,
                                         is_Co_depad,
                                         is_wgt_repeat_high,
                                         gemm_mode)
    # log the conv_repeats for debugging
    for col in range(dims.aie_cols):
        log(f"INFO conv_repeats.ifm_L2_s2mm_repeats[{col}]: {conv_repeats.ifm_L2_s2mm_repeats[col]}")
        log(f"INFO conv_repeats.ifm_L2_mm2s_repeats[{col}]: {conv_repeats.ifm_L2_mm2s_repeats[col]}")
        log(f"INFO conv_repeats.wgt_L2_s2mm_repeats[{col}]: {conv_repeats.wgt_L2_s2mm_repeats[col]}")
        log(f"INFO conv_repeats.wgt_L2_mm2s_repeats[{col}]: {conv_repeats.wgt_L2_mm2s_repeats[col]}")
        log(f"INFO conv_repeats.ofm_L2_s2mm_repeats[{col}]: {conv_repeats.ofm_L2_s2mm_repeats[col]}")
        log(f"INFO conv_repeats.ofm_L2_mm2s_repeats[{col}]: {conv_repeats.ofm_L2_mm2s_repeats[col]}")
        log(f"INFO conv_repeats.ifm_L3_mm2s_repeats[{col}]: {conv_repeats.ifm_L3_mm2s_repeats[col]}")
        log(f"INFO conv_repeats.wgt_L3_mm2s_repeats[{col}]: {conv_repeats.wgt_L3_mm2s_repeats[col]}")
        log(f"INFO conv_repeats.ofm_L3_s2mm_repeats[{col}]: {conv_repeats.ofm_L3_s2mm_repeats[col]}")

    # Determine full_iters based on gemm_mode, the op type and the padding phases.
    # For gemm_mode == 'act', full_iters is always True.
    full_iters = True
    if gemm_mode == 'wgt':
        full_iters = (shape.linear_op_type != LinearOpType.conv_A8W8_noqdq) and not is_Co_depad and not is_Ci_pad_phase
        if dims.Xi // dims.Xis >= 36:
            full_iters = False
    log(f"INFO: full_iters set to {full_iters}")

    core_instrs = {}
    for col in range(dims.aie_cols):
        for row in range(dims.aie_rows):
            core_instrs[core_tile(col, row)] = generate_conv_core_instrs(
                dims, mapping, shape.linear_op_type, col, row, full_iters, gemm_mode)

    def one_shot_repeats(col):
        '''Repeat list with 1 in the first entry and 0 in all others,
        as long as the column's IFM L2 S2MM repeat list.'''
        return [1] + [0] * (len(conv_repeats.ifm_L2_s2mm_repeats[col]) - 1)

    # Layer params: one memtile S2MM in, one memtile MM2S out per core row
    memtile_transfers = [
        DataTransfer(
            one_shot_repeats(col),
            memory_tile(col), [conv_l2_alloc.param_addr], conv_l2_alloc.param_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[2]),
                prm_memtile_memory(),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                prm_memtile_memory(),
                prm_memtile_mm2s(row),
            ) for row in range(aie_rows)],
        ) for col in range(aie_cols)
    ]

    # Layer params are read from the param XRT buffer on shim MM2S channel 2
    shimtile_transfers = [
        generate_shim_data_transfer(
            one_shot_repeats(col),
            shim_dma(col, DmaDir.MM2S, 2), conv_shim.prm_xrt_idx,
            prm_shim_memory(),
            prm_shim_mm2s(col),
            buffer_offset=conv_shim.prm_xrt_offset
        ) for col in range(aie_cols)
    ]

    if gemm_mode == 'act':
        # QDQ params are only transferred in 'act' mode. They are read from
        # the weight XRT buffer on shim MM2S channel 1 and go right after the
        # layer-param transfers in both lists.
        # NOTE(review): both memtile MM2S channels below use identical
        # qdq_param_memtile_mm2s(dims, 0, 0) arguments - confirm intentional.
        memtile_transfers += [
            DataTransfer(
                one_shot_repeats(col),
                memory_tile(col), [conv_l2_alloc.qdq_addr], conv_l2_alloc.qdq_size,
                [generate_transfer_params(
                    memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[1]),
                    qdq_param_memtile_memory(dims),
                    qdq_param_memtile_s2mm(dims),
                )],
                [
                    generate_transfer_params(
                        memtile_dma(col, DmaDir.MM2S, overlay_3x4_B_ids(col)[0]),
                        qdq_param_memtile_memory(dims),
                        qdq_param_memtile_mm2s(dims, 0, 0),
                    ),
                    generate_transfer_params(
                        memtile_dma(col, DmaDir.MM2S, overlay_3x4_B_ids(col)[1]),
                        qdq_param_memtile_memory(dims),
                        qdq_param_memtile_mm2s(dims, 0, 0),
                    ),
                ],
                sync_strategy=SyncStrategy.Parallel_1_to_N,
            ) for col in range(aie_cols)
        ]

        shimtile_transfers += [
            generate_shim_data_transfer(
                one_shot_repeats(col),
                shim_dma(col, DmaDir.MM2S, 1), conv_shim.wgt_xrt_idx,
                qdq_param_shim_memory(dims),
                qdq_param_shim_mm2s(dims, col),
                buffer_offset=conv_shim.wgt_xrt_offset,
            ) for col in range(aie_cols)
        ]

    ifm_l2_transfers = generate_ifm_memtile_data_transfers(
        dims,
        conv_l2_alloc,
        conv_repeats,
        ifm_L2_strategy,
        is_Co_depad,
        gemm_mode
    )

    wgt_l2_transfers = generate_wgt_memtile_data_transfers(
        dims,
        conv_l2_alloc,
        conv_repeats,
        wgt_L2_strategy,
        is_Co_depad,
        gemm_mode,
        weights_channel_list,
    )

    ifm_l3_transfers = generte_ifm_shimtile_data_transfers(
        dims,
        conv_repeats,
        conv_shim,
        is_Co_depad,
        gemm_mode
    )

    wgt_l3_transfers = generate_wgt_shimtile_data_transfers(
        dims,
        conv_repeats,
        conv_shim,
        wgt_L2_strategy,
        is_Co_depad,
        gemm_mode,
        weights_channel_list,
    )

    ofm_l2_transfers = generate_ofm_memtile_data_transfers(
        dims,
        conv_l2_alloc,
        conv_repeats,
        ofm_L2_strategy,
        is_Co_depad,
        gemm_mode
    )
    ofm_l3_transfers = generate_ofm_shimtile_data_transfers(
        dims,
        conv_repeats,
        conv_shim,
        is_Co_depad,
        ofm_L2_strategy,
        gemm_mode
    )

    # Final list order: params, (QDQ params), IFM, WGT, OFM
    memtile_transfers += ifm_l2_transfers
    memtile_transfers += wgt_l2_transfers
    memtile_transfers += ofm_l2_transfers
    shimtile_transfers += ifm_l3_transfers
    shimtile_transfers += wgt_l3_transfers
    shimtile_transfers += ofm_l3_transfers

    run_layer_compilation(
        overlay_shape,
        schedule_input.kernel_names,
        schedule_input.kernel_includes,
        core_instrs,
        memtile_transfers,
        shimtile_transfers,
        dma_connections,
        core_stack_addr=overlay_3x4_core_stack_addr(),
        param_channel_id=overlay_3x4_param_channel_id(),
        back_end=schedule_input.backend,
        layer_file=schedule_input.layer_file_name,
        dma_padding_map=schedule_input.dma_pad,
    )

    # In 'act' mode the weight XRT buffer carries the QDQ params, so its
    # size is the QDQ param size rather than the weight shim buffer size.
    if gemm_mode == 'act':
        wgt_shim_size = dims.qdq_param_size
    else:
        wgt_shim_size = compute_buffer_size(wgt_shim_memory(dims))
    prm_shim_size = compute_buffer_size(prm_shim_memory())
    log(f"wgt_shim_size: {wgt_shim_size}")
    log(f" prm_shim_size: {prm_shim_size}")
    shim_prm_offset_next_layer = conv_shim.prm_xrt_offset + prm_shim_size
    shim_wgt_offset_next_layer = conv_shim.wgt_xrt_offset + wgt_shim_size
    # Single formatted message, consistent with every other log call here
    log(f"shim_prm_offset_next_layer {shim_prm_offset_next_layer}")
    log(f"shim_wgt_offset_next_layer {shim_wgt_offset_next_layer}")

    return shim_prm_offset_next_layer, shim_wgt_offset_next_layer
