'''
Map conv shapes to the AIE-4 dataflow architecture.
External facing functions are documented below.

    compile_DWC_L3_dataflow - given a shape and mapping, compile the data movement
    code to orchestrate the DMA hierarchy
'''
from typing import no_type_check
from typing import List

from dmacompiler import (
    DevGen, set_dev_gen, SyncStrategy,
    OverlayShape, DataTransfer,
    DmaDir,
    memtile_dma, shim_dma, core_tile, memory_tile,
    shim_tile,
    run_layer_compilation,
    generate_transfer_params,
    generate_shim_data_transfer,
    compute_buffer_size,
    pack_reconfig_transfers,
)

from utils.utils_common import (
    overlay_3x4_core_stack_addr,
    log,
    split_to_mode,
)

from scheduler.common import (
    overlay_3x4_dma_connections,
    overlay_3x4_param_channel_id,
    overlay_3x4_A_ids,
    overlay_3x4_F_ids,
    overlay_3x4_O_ids,
    overlay_3x4_S_ids,
    prm_memtile_memory,
    prm_shim_memory,
    prm_memtile_mm2s,
    prm_memtile_s2mm,
    prm_shim_mm2s,
    unicast_channels,
    broadcast_channels,
    L3Alloc,
    ShimAllocator,
    L3Alloc_to_Shim,
    TensorSlicer
)

from scheduler.conv.conv_config_builders import (
    ConvShape,
    ConvMapping,
    ConvDims,
)

from scheduler.conv.conv_common import (
    ConvDataFlowRepeats,
    convL2Memory,
    ifm_shim_memory,
    ofm_shim_memory,
    Yi_slice,
    Xi_slice,
    Co_split_idxs,
    Co_split_size,
    Co_split_offset,
    Yo_split_size,
    Xo_split_size,
    Yo_split_offset,
    Xo_split_offset,
    Yi_slice_per_column,
    Yo_slice_per_column,
    Xi_slice_per_column,
    Xo_slice_per_column,
    Co_slice_per_column,
)

from scheduler.dwc.dwc_common import (
    DwcL2MemoryAllocator,
    dwc_wgt_shim_memory,
    generate_dwc_core_instrs,
)

from buildscripts.common import ScheduleInputs

# Select the AIE-4 device generation in dmacompiler. This runs as an
# import-time side effect of this module.
# NOTE(review): presumably this must happen before any dmacompiler helper
# below is used - confirm against the dmacompiler documentation.
set_dev_gen(DevGen.Aie4)


def generate_dwc_repeats(
    dims: ConvDims,
    slicer: TensorSlicer,
    is_repeat_too_high: bool,
) -> ConvDataFlowRepeats:
    '''
    Derive the per-column DMA repeat lists for the DWC dataflow.

    Every (Y, X) loop iteration is one base phase. Each base phase adds one
    entry to every repeat list, plus a second entry when the last value
    returned by slicer.Co_split_iters() is non-zero (Co pad / depad phase).

    Args:
        dims: supplies aie_cols, X_loop and Y_loop.
        slicer: supplies Co_split_iters(); its first entry is the L2 repeat
            count of the main phase, its second entry the L2 repeat count of
            the pad / depad phase.
        is_repeat_too_high: not supported yet; a truthy value raises.

    Returns:
        ConvDataFlowRepeats whose L2 MM2S and L3 repeat dicts are keyed by
        column index. The L2 S2MM repeat dicts are passed as empty dicts.

    Raises:
        ValueError: if is_repeat_too_high is set.
    '''
    if is_repeat_too_high:
        raise ValueError("repeat too high is not supported yet !!!")

    base_phases = dims.X_loop * dims.Y_loop
    # Co_split_iters() takes neither a column nor a phase argument, so it is
    # queried (and logged) once rather than once per (column, phase) pair.
    Co_phases = slicer.Co_split_iters()
    log(f"Co_phases: {Co_phases}")
    has_pad_phase = Co_phases[-1] != 0

    # Repeat entries contributed by a single base phase.
    l2_pattern = [Co_phases[0]]
    shim_active_pattern = [1]
    shim_wgt_pattern = [1]
    if has_pad_phase:
        l2_pattern.append(Co_phases[1])
        shim_active_pattern.append(1)  # Co pad / depad phase
        # WGT are preformatted and the whole tensor is scheduled in a single
        # phase, so the extra shim phase is inactive for weights.
        shim_wgt_pattern.append(0)

    ifm_L2_mm2s_repeats = {}
    ofm_L2_mm2s_repeats = {}
    wgt_L2_mm2s_repeats = {}
    ifm_L3_mm2s_repeats = {}
    wgt_L3_mm2s_repeats = {}
    ofm_L3_s2mm_repeats = {}
    for col in range(dims.aie_cols):
        # List multiplication yields a fresh list each time, so no two
        # tensors or columns share a list object.
        ifm_L2_mm2s_repeats[col] = l2_pattern * base_phases
        ofm_L2_mm2s_repeats[col] = l2_pattern * base_phases
        wgt_L2_mm2s_repeats[col] = l2_pattern * base_phases
        ifm_L3_mm2s_repeats[col] = shim_active_pattern * base_phases
        ofm_L3_s2mm_repeats[col] = shim_active_pattern * base_phases
        wgt_L3_mm2s_repeats[col] = shim_wgt_pattern * base_phases

        # Sanity check: all repeat lists of a column must describe the same
        # number of phases.
        lengths = [
            len(ifm_L2_mm2s_repeats[col]),
            len(ofm_L2_mm2s_repeats[col]),
            len(wgt_L2_mm2s_repeats[col]),
            len(ifm_L3_mm2s_repeats[col]),
            len(wgt_L3_mm2s_repeats[col]),
            len(ofm_L3_s2mm_repeats[col]),
        ]
        assert all(length == lengths[0] for length in lengths), \
            f"Column {col} repeat arrays have mismatched lengths: {lengths}"

    return ConvDataFlowRepeats(
        ifm_L2_s2mm_repeats={},
        ifm_L2_mm2s_repeats=ifm_L2_mm2s_repeats,
        ifm_L3_mm2s_repeats=ifm_L3_mm2s_repeats,
        ofm_L2_s2mm_repeats={},
        ofm_L2_mm2s_repeats=ofm_L2_mm2s_repeats,
        ofm_L3_s2mm_repeats=ofm_L3_s2mm_repeats,
        wgt_L2_s2mm_repeats={},
        wgt_L2_mm2s_repeats=wgt_L2_mm2s_repeats,
        wgt_L3_mm2s_repeats=wgt_L3_mm2s_repeats,
    )


def ifm_memtile_channels(dims: ConvDims, col: int) -> list[tuple[int, int]]:
    '''
    Pick the IFM memtile channel allocation, as (row, id) pairs, for a column.

    The mode returned by split_to_mode(dims) selects the allocation:
    mode 0 uses unicast_channels(), mode 1 uses broadcast_channels(col).
    Any other mode raises KeyError.
    '''
    selected_mode = split_to_mode(dims)
    allocations_by_mode = {
        0: unicast_channels(),
        1: broadcast_channels(col),
    }
    return allocations_by_mode[selected_mode]


def wgt_memtile_channels(dims: ConvDims, col: int) -> list[tuple[int, int]]:
    '''
    Pick the WGT memtile channel allocation, as (row, id) pairs, for a column.

    The mapping is the mirror image of the IFM one: mode 0 (from
    split_to_mode(dims)) uses broadcast_channels(col), mode 1 uses
    unicast_channels(). Any other mode raises KeyError.
    '''
    selected_mode = split_to_mode(dims)
    allocations_by_mode = {
        0: broadcast_channels(col),
        1: unicast_channels(),
    }
    return allocations_by_mode[selected_mode]


def ofm_memtile_channels() -> list[tuple[int, int]]:
    '''
    Pair every OFM channel id with its position in overlay_3x4_O_ids().

    Returns a list of (row, id) tuples, where row is the index of the id in
    the sequence returned by overlay_3x4_O_ids().
    '''
    return [(row, channel_id) for row, channel_id in enumerate(overlay_3x4_O_ids())]


#####################################################
# IFM memory and tiling Formats
#####################################################

def ifm_memtile_memory(dims: ConvDims, col: int, slicer: TensorSlicer) -> List[str]:
    '''
    Define IFM L2 data order and shape.

    Produces one memory format per (Y, X) loop iteration, in Y-major order.
    When the last value of slicer.Co_split_iters() is non-zero (Co pad /
    depad phase), each iteration contributes a second, identical entry.

    Args:
        dims: supplies Y_loop, X_loop and Cos.
        col: column whose Yi / Xi slice sizes are used.
        slicer: supplies Co_split_iters().

    Returns:
        List of 'Yi:<size> Xi:<size> Ci:<size>' strings covering every
        Y and X iteration.
    '''
    mem_fmt = []
    # Neither value depends on the loop indices, so compute them once.
    Co_size = dims.Cos * Co_split_size(dims)
    has_pad_phase = slicer.Co_split_iters()[-1] != 0
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            _, _, Yi_size_col_0 = Yi_slice_per_column(dims, 0, y_iter)
            _, _, Xi_size_col_0 = Xi_slice_per_column(dims, 0, x_iter)
            _, _, Yi_size_col = Yi_slice_per_column(dims, col, y_iter)
            _, _, Xi_size_col = Xi_slice_per_column(dims, col, x_iter)
            # A column that is not working on any input indices reports a
            # size of 0, but a memory format cannot have a zero dimension.
            # Fall back to the size of column 0, which always has work as it
            # is the first tiled segment.
            Yi_size = Yi_size_col_0 if Yi_size_col == 0 else Yi_size_col
            Xi_size = Xi_size_col_0 if Xi_size_col == 0 else Xi_size_col
            phase_fmt = f'Yi:{Yi_size} Xi:{Xi_size} Ci:{Co_size}'
            mem_fmt.append(phase_fmt)
            if has_pad_phase:
                mem_fmt.append(phase_fmt)
    # Return only after every Y iteration has been visited, so the list
    # length matches the number of repeat phases.
    return mem_fmt


def ifm_memtile_mm2s(dims: ConvDims, col: int, row: int, slicer: TensorSlicer) -> List[str]:
    '''
    Build the IFM memtile MM2S (read) tiling formats for one (col, row).

    One format is produced per (Y, X) loop iteration. When the last value of
    slicer.Co_split_iters() is non-zero (Co pad / depad phase) the same
    format is emitted a second time for that iteration.

    Raises:
        ValueError: if no format was produced at all.
    '''
    fmts = []
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            # Tensor-level start of the shard owned by this column.
            col_y_base, _, _ = Yi_slice_per_column(dims, col, y_iter)
            col_x_base, _, _ = Xi_slice_per_column(dims, col, x_iter)

            # Tensor-level Y start for this (col, row), turned into an
            # offset relative to the column shard. A start at or past the
            # end of the tensor is pinned to offset 0.
            tensor_y, _, _ = Yi_slice(dims, col, row, y_iter)
            if tensor_y >= dims.Yi:
                rel_y = 0
            else:
                rel_y = tensor_y - col_y_base

            # Same derivation for X.
            tensor_x, _, _ = Xi_slice(dims, col, row, x_iter)
            if tensor_x >= dims.Xi:
                rel_x = 0
            else:
                rel_x = tensor_x - col_x_base

            co_begin = Co_split_offset(dims, col, row) * dims.Cos
            co_end = co_begin + dims.Cos
            tiling = (
                f'Ci:{co_begin}:{co_end}:{dims.Co_gran} '
                f'Yi:{rel_y}:{rel_y + dims.Yis} '
                f'Xi:{rel_x}:{rel_x + dims.Xis} '
                f'Ci:0:{dims.Co_gran}'
            )
            fmts.append(tiling)
            if slicer.Co_split_iters()[-1] != 0:
                # The pad / depad phase reuses the same access pattern.
                fmts.append(tiling)
    if not fmts:
        raise ValueError("IFM memtile MM2S tiling format list is empty !!!")
    return fmts


def ifm_memtile_s2mm(dims: ConvDims, col: int, slicer: TensorSlicer) -> List[str]:
    '''
    Build the IFM memtile S2MM (write) tiling formats for one column.

    Each (Y, X) loop iteration yields a format sized by the column's Yi / Xi
    shard and by the Co slice of Co iteration 0. When the last value of
    slicer.Co_split_iters() is non-zero, a second format sized by the Co
    slice of the final Co iteration follows.
    '''
    fmts = []
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            _, _, y_extent = Yi_slice_per_column(dims, col, y_iter)
            _, _, x_extent = Xi_slice_per_column(dims, col, x_iter)
            _, _, c_extent = Co_slice_per_column(dims, col, 0)
            spatial = f'Yi:0:{y_extent} Xi:0:{x_extent}'
            fmts.append(f'{spatial} Ci:0:{c_extent}')
            if slicer.Co_split_iters()[-1] == 0:
                continue
            # Co pad / depad phase: sized by the last Co iteration.
            _, _, c_extent_tail = Co_slice_per_column(dims, col, dims.Co_loop - 1)
            fmts.append(f'{spatial} Ci:0:{c_extent_tail}')
    return fmts


def generate_ifm_memtile_data_transfers(
    dims: ConvDims,
    dwc_l2_alloc: convL2Memory,
    dwc_repeats: ConvDataFlowRepeats,
    slicer: TensorSlicer,
) -> List[DataTransfer]:
    '''
    Generate one IFM memtile DataTransfer per column.

    Each transfer uses the IFM ping / pong L2 buffers, a single S2MM channel
    (the first id of overlay_3x4_F_ids()) and one MM2S channel per entry of
    ifm_memtile_channels(dims, col). All format lists are asserted to have
    the same length as the column's ifm_L2_mm2s_repeats list.
    '''
    log(f"INFO IFM L2 conv repeats.ifm_L2_mm2s_repeats: {dwc_repeats.ifm_L2_mm2s_repeats}")
    transfers = []
    for col in range(dims.aie_cols):
        l2_layout = ifm_memtile_memory(dims, col, slicer)
        repeats = dwc_repeats.ifm_L2_mm2s_repeats[col]
        phase_count = len(repeats)
        assert len(l2_layout) == phase_count, \
            f"Column {col} IFM L2 memory fmts and repeats length mismatch: " \
            f"{len(l2_layout)} != {phase_count}"

        s2mm_fmts = ifm_memtile_s2mm(dims, col, slicer)
        # The number of formats must match the number of phases.
        assert len(s2mm_fmts) == phase_count, \
            f"Column {col} fmts and repeats length mismatch: " \
            f"{len(s2mm_fmts)} != {phase_count}"

        # Read tiling formats, keyed by core row.
        mm2s_fmts_by_row = {}
        for row in range(dims.aie_rows):
            row_fmts = ifm_memtile_mm2s(dims, col, row, slicer)
            assert len(row_fmts) == phase_count, \
                f"Column: {col}, row: {row}, ifm memtile mm2s phases of length: {phase_count} " \
                f"mismatch with fmts length: {len(row_fmts)}"
            mm2s_fmts_by_row[row] = row_fmts

        tile = memory_tile(col)
        buffer_addrs = [dwc_l2_alloc.ifm_ping_addr, dwc_l2_alloc.ifm_pong_addr]
        buffer_size = dwc_l2_alloc.ifm_size
        s2mm_side = [
            pack_reconfig_transfers(
                memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                l2_layout,
                s2mm_fmts,
                bits_per_elem=dims.ifm_bits,
                name=f"IFM_memtile_S2MM_col{col}",
            )
        ]
        mm2s_side = []
        for row, channel_id in ifm_memtile_channels(dims, col):
            mm2s_side.append(
                pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.MM2S, channel_id),
                    l2_layout,
                    mm2s_fmts_by_row[row],
                    bits_per_elem=dims.ifm_bits,
                    name=f"IFM_memtile_MM2S_col{col}_row{row}",
                )
            )
        transfers.append(
            DataTransfer(
                repeats,
                tile, buffer_addrs, buffer_size,
                s2mm_side,
                mm2s_side,
                sync_strategy=SyncStrategy.Parallel_1_to_N,
            )
        )
    return transfers


def ifm_shimtile_mm2s(dims: ConvDims, col: int, slicer: TensorSlicer) -> List[str]:
    '''
    Build the IFM shim MM2S (read) tiling formats for one column.

    Per (Y, X) loop iteration the first format covers the Co range from Co
    iteration 0 up to the last full Co iteration, traversed in steps of
    cos_stride. When the last value of slicer.Co_split_iters() is non-zero,
    a second format covering the Co slice of the final Co iteration follows.
    '''
    fmts = []
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            y_lo, y_hi, _ = Yi_slice_per_column(dims, col, y_iter)
            x_lo, x_hi, _ = Xi_slice_per_column(dims, col, x_iter)
            split_size = Co_split_size(dims)
            has_pad_phase = slicer.Co_split_iters()[-1] != 0
            # With a pad phase the final Co iteration is handled separately.
            if has_pad_phase:
                last_full_iter = dims.Co_loop - 2
            else:
                last_full_iter = dims.Co_loop - 1
            # Main Co transfer with striding
            c_lo, _, _ = Co_slice_per_column(dims, col, 0)
            _, c_hi, _ = Co_slice_per_column(dims, col, last_full_iter)
            stride = min(c_hi - c_lo, split_size * dims.Cos)
            fmts.append(
                f'Ci:{c_lo}:{c_hi}:{stride} '
                f'Yi:{y_lo}:{y_hi} '
                f'Xi:{x_lo}:{x_hi} '
                f'Ci:0:{stride}'
            )
            if has_pad_phase:
                tail_lo, tail_hi, _ = Co_slice_per_column(dims, col, dims.Co_loop - 1)
                fmts.append(
                    f'Yi:{y_lo}:{y_hi} '
                    f'Xi:{x_lo}:{x_hi} '
                    f'Ci:{tail_lo}:{tail_hi}'
                )
    return fmts


def generate_ifm_shimtile_data_transfers(
    dims: ConvDims,
    dwc_repeats: ConvDataFlowRepeats,
    dwc_shim: ShimAllocator,
    slicer: TensorSlicer,
) -> List[DataTransfer]:
    '''
    Generate one IFM shim tile DataTransfer per column.

    Each transfer has no S2MM side and a single MM2S transfer on shim DMA
    channel 0. The same DDR memory format (from ifm_shim_memory) is repeated
    for every phase of the column's ifm_L3_mm2s_repeats list.
    '''
    ddr_layout = ifm_shim_memory(dims)
    ddr_ifm_size = compute_buffer_size(ddr_layout)
    transfers = []
    for col in range(dims.aie_cols):
        repeats = dwc_repeats.ifm_L3_mm2s_repeats[col]
        # One copy of the DDR layout per phase.
        layout_per_phase = [ddr_layout] * len(repeats)
        mm2s_fmts = ifm_shimtile_mm2s(dims, col, slicer)
        assert len(mm2s_fmts) == len(repeats), \
            f"Column: {col} IFM shim tile MM2S phases of length {len(repeats)} " \
            f"Does not match read formats of length {len(mm2s_fmts)}"
        tile = shim_tile(col)
        s2mm_side = []  # the shim never writes the IFM
        mm2s_side = [
            pack_reconfig_transfers(
                shim_dma(col, DmaDir.MM2S, 0),
                layout_per_phase,
                mm2s_fmts,
                bits_per_elem=dims.ifm_bits,
                buffer_offset=[dwc_shim.ifm_xrt_offset],
                name=f"ifm_shimtile_mm2s_col_{col}"
            )
        ]
        transfers.append(
            DataTransfer(
                repeats,
                tile, [dwc_shim.ifm_xrt_idx], ddr_ifm_size,
                s2mm_side,
                mm2s_side,
            )
        )
    return transfers

#####################################################
# WGT memory and tiling Formats
#####################################################


def dwc_wgt_memtile_memory(dims: ConvDims, col: int, slicer: TensorSlicer) -> list[str]:
    '''
    Define the WGT L2 memory format for every phase.

    The same 'Cob:<n> Subv:<size>' layout is emitted once per (Y, X) loop
    iteration, and twice when the last value of slicer.Co_split_iters() is
    non-zero (partial subvolume on the Co dimension). col is not used.
    '''
    del col  # the layout does not depend on the column
    layout = f'Cob:{Co_split_size(dims)} Subv:{dims.wgt_L1_size}'
    layouts = []
    for _ in range(dims.Y_loop):
        for _ in range(dims.X_loop):
            copies = 2 if slicer.Co_split_iters()[-1] != 0 else 1
            layouts.extend([layout] * copies)
    return layouts


def dwc_wgt_memtile_s2mm(dims: ConvDims, col: int, slicer: TensorSlicer) -> list[str]:
    '''
    Define the WGT memtile S2MM (write) tiling format for every phase.

    The same 'Cob:0:<n> Subv:0:<size>' access is emitted once per (Y, X)
    loop iteration, and twice when the last value of slicer.Co_split_iters()
    is non-zero (partial subvolume on the Co dimension). col is not used.
    '''
    del col  # the write pattern does not depend on the column
    access = f'Cob:0:{Co_split_size(dims)} Subv:0:{dims.wgt_L1_size}'
    fmts = []
    for _ in range(dims.Y_loop):
        for _ in range(dims.X_loop):
            copies = 2 if slicer.Co_split_iters()[-1] != 0 else 1
            fmts.extend([access] * copies)
    return fmts


def dwc_wgt_memtile_mm2s(dims: ConvDims, col: int, row: int, slicer: TensorSlicer) -> list[str]:
    '''
    Define the WGT memtile MM2S (read) tiling format for one (col, row).

    Each of the X_loop * Y_loop base phases reads the single Cob block at
    the offset given by Co_split_offset(dims, col, row). The format is
    emitted a second time per phase when the last value of
    slicer.Co_split_iters() is non-zero (partial subvolume on Co).
    '''
    block = Co_split_offset(dims, col, row)
    access = f'Cob:{block}:{block + 1} Subv:{0}:{dims.wgt_L1_size}'
    fmts = []
    phase_total = dims.X_loop * dims.Y_loop
    for _ in range(phase_total):
        copies = 2 if slicer.Co_split_iters()[-1] != 0 else 1
        fmts.extend([access] * copies)
    return fmts


def generate_dwc_wgt_memtile_data_transfers(
    dims: ConvDims,
    slicer: TensorSlicer,
    dwc_l2_alloc: convL2Memory,
    dwc_repeats: ConvDataFlowRepeats,
) -> List[DataTransfer]:
    '''
    Generate one WGT memtile DataTransfer per column.

    Each transfer uses the WGT ping / pong L2 buffers, a single S2MM channel
    (the second id of overlay_3x4_F_ids()) and one MM2S channel per entry of
    wgt_memtile_channels(dims, col).

    Every format list is checked against the repeat list of the column
    being built (wgt_L2_mm2s_repeats[col]), so a mismatch is reported for
    the column it actually belongs to.

    Args:
        dims: supplies aie_cols and aie_rows.
        slicer: forwarded to the WGT format builders.
        dwc_l2_alloc: supplies wgt_ping_addr, wgt_pong_addr and wgt_size.
        dwc_repeats: supplies wgt_L2_mm2s_repeats keyed by column.

    Returns:
        List with one DataTransfer per column, in column order.
    '''
    data_transfers = []
    for col in range(dims.aie_cols):
        repeats = dwc_repeats.wgt_L2_mm2s_repeats[col]
        mem_fmt = dwc_wgt_memtile_memory(dims, col, slicer)
        assert len(mem_fmt) == len(repeats), \
            f"Column {col} WGT L2 memory fmts and repeats length mismatch: " \
            f"{len(mem_fmt)} != {len(repeats)}"
        write_tiling_fmt = dwc_wgt_memtile_s2mm(dims, col, slicer)
        assert len(write_tiling_fmt) == len(repeats), \
            f"Column {col} fmts and repeats length mismatch: " \
            f"{len(write_tiling_fmt)} != {len(repeats)}"
        # Read tiling formats, keyed by core row.
        read_tiling_fmt = {}
        for row in range(dims.aie_rows):
            per_row_fmt = dwc_wgt_memtile_mm2s(dims, col, row, slicer)
            assert len(per_row_fmt) == len(repeats), \
                f"Column {col}, row {row} tiling fmts and repeats length mismatch: " \
                f"{len(per_row_fmt)} != {len(repeats)}"
            read_tiling_fmt[row] = per_row_fmt
        data_transfers.append(
            DataTransfer(
                repeats,
                memory_tile(col), [dwc_l2_alloc.wgt_ping_addr, dwc_l2_alloc.wgt_pong_addr], dwc_l2_alloc.wgt_size,
                [
                    pack_reconfig_transfers(
                        memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[1]),
                        mem_fmt,
                        write_tiling_fmt,
                        name=f"DWC_WGT_memtile_s2mm_col_{col}",
                    )
                ],
                [
                    pack_reconfig_transfers(
                        memtile_dma(col, DmaDir.MM2S, channel_id),
                        mem_fmt,
                        read_tiling_fmt[row],
                        name=f"DWC_WGT_memtile_mm2s_col_{col}"
                    ) for row, channel_id in wgt_memtile_channels(dims, col)
                ],
                sync_strategy=SyncStrategy.Parallel_1_to_N,
            )
        )
    return data_transfers


def dwc_wgt_shimtile_mm2s(dims: ConvDims, col: int, slicer: TensorSlicer) -> list[str]:
    '''
    Define the WGT shim MM2S (read) tiling format for one column.

    Per (Y, X) loop iteration the format strides over all
    Co_loop * Co_split Cob blocks in steps of Co_split and, within each
    step, reads the Cob range spanned by Co_split_idxs(dims, col). When the
    last value of slicer.Co_split_iters() is non-zero, an empty
    'Subv:0:0' placeholder format follows for the pad / depad phase.
    '''
    fmts = []
    for _ in range(dims.Y_loop):
        for _ in range(dims.X_loop):
            col_blocks = Co_split_idxs(dims, col)
            first_block = min(col_blocks)
            end_block = max(col_blocks) + 1
            fmts.append(
                f'Cob:0:{dims.Co_loop * dims.Co_split}:{dims.Co_split} '
                f'Cob:{first_block}:{end_block} '
                f'Subv:0:{dims.wgt_L1_size}'
            )
            if slicer.Co_split_iters()[-1] != 0:
                # Dummy phase for depad / pad
                fmts.append(f'Subv:{0}:{0}')
    return fmts


def generate_dwc_wgt_shimtile_data_transfers(
    dims: ConvDims,
    dwc_repeats: ConvDataFlowRepeats,
    dwc_shim: ShimAllocator,
    slicer: TensorSlicer,
) -> List[DataTransfer]:
    '''
    Generate one DWC WGT shim tile DataTransfer per column.

    Each transfer has no S2MM side and a single MM2S transfer on shim DMA
    channel 1. The same DDR memory format (from dwc_wgt_shim_memory) is
    repeated for every phase of the column's wgt_L3_mm2s_repeats list.
    '''
    ddr_layout = dwc_wgt_shim_memory(dims)
    ddr_wgt_size = compute_buffer_size(ddr_layout)
    transfers = []
    for col in range(dims.aie_cols):
        repeats = dwc_repeats.wgt_L3_mm2s_repeats[col]
        # One copy of the DDR layout per phase.
        layout_per_phase = [ddr_layout] * len(repeats)
        mm2s_fmts = dwc_wgt_shimtile_mm2s(dims, col, slicer)
        assert len(mm2s_fmts) == len(repeats), \
            f"Column: {col} WGT shim tile MM2S phases of length {len(repeats)} " \
            f"Does not match read formats of length {len(mm2s_fmts)}"
        tile = shim_tile(col)
        s2mm_side = []  # the shim never writes the weights
        mm2s_side = [
            pack_reconfig_transfers(
                shim_dma(col, DmaDir.MM2S, 1),
                layout_per_phase,
                mm2s_fmts,
                buffer_offset=[dwc_shim.wgt_xrt_offset],
                name=f"wgt_shimtile_mm2s_col_{col}"
            )
        ]
        transfers.append(
            DataTransfer(
                repeats,
                tile, [dwc_shim.wgt_xrt_idx], ddr_wgt_size,
                s2mm_side,
                mm2s_side,
            )
        )
    return transfers


#####################################################
# OFM memory and tiling Formats
#####################################################


def ofm_memtile_memory(dims: ConvDims, col: int, slicer: TensorSlicer) -> List[str]:
    '''
    Define the OFM L2 memory format for every phase.

    The shard extent along each axis is the split size times the subvolume
    size (Yos / Xos / Cos). The resulting layout is emitted once per (Y, X)
    loop iteration, and twice when the last value of
    slicer.Co_split_iters() is non-zero (partial subvolume on Co).
    col does not influence the result.
    '''
    y_extent = Yo_split_size(dims) * dims.Yos
    x_extent = Xo_split_size(dims) * dims.Xos
    c_extent = Co_split_size(dims) * dims.Cos
    layouts = []
    for _ in range(dims.Y_loop):
        for _ in range(dims.X_loop):
            layout = f'Yo:{y_extent} Xo:{x_extent} Co:{c_extent}'
            copies = 2 if slicer.Co_split_iters()[-1] != 0 else 1
            layouts.extend([layout] * copies)
    return layouts


def ofm_memtile_s2mm(dims: ConvDims, col: int, row: int, slicer: TensorSlicer) -> List[str]:
    '''
    Define the OFM memtile S2MM (write) tiling format for one (col, row).

    The write window starts at the (col, row) split offset along Yo, Xo and
    Co, scaled by the subvolume size, and spans one subvolume per axis. The
    format is emitted once per (Y, X) loop iteration, and twice when the
    last value of slicer.Co_split_iters() is non-zero.
    '''
    fmts = []
    for _ in range(dims.Y_loop):
        for _ in range(dims.X_loop):
            y_begin = Yo_split_offset(dims, col, row) * dims.Yos
            y_end = y_begin + dims.Yos
            x_begin = Xo_split_offset(dims, col, row) * dims.Xos
            x_end = x_begin + dims.Xos
            c_begin = Co_split_offset(dims, col, row) * dims.Cos
            c_end = c_begin + dims.Cos
            access = (
                f'Co:{c_begin}:{c_end}:{dims.Co_gran} '
                f'Yo:{y_begin}:{y_end} '
                f'Xo:{x_begin}:{x_end} '
                f'Co:0:{dims.Co_gran}'
            )
            # A partial subvolume on Co reuses the same access pattern.
            copies = 2 if slicer.Co_split_iters()[-1] != 0 else 1
            fmts.extend([access] * copies)
    return fmts


def ofm_memtile_mm2s(dims: ConvDims, col: int, slicer: TensorSlicer) -> List[str]:
    '''
    Define the OFM memtile MM2S (read) tiling formats for one column.

    Each (Y, X) loop iteration yields a format sized by the column's Yo / Xo
    shard and by the Co slice of Co iteration 0. When the last value of
    slicer.Co_split_iters() is non-zero, a second format sized by the Co
    slice of the final Co iteration follows.
    '''
    fmts = []
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            _, _, y_extent = Yo_slice_per_column(dims, col, y_iter)
            _, _, x_extent = Xo_slice_per_column(dims, col, x_iter)
            _, _, c_extent = Co_slice_per_column(dims, col, 0)
            spatial = f'Yo:0:{y_extent} Xo:0:{x_extent}'
            fmts.append(f'{spatial} Co:0:{c_extent}')
            if slicer.Co_split_iters()[-1] == 0:
                continue
            # Co pad / depad phase: sized by the last Co iteration.
            _, _, c_extent_tail = Co_slice_per_column(dims, col, dims.Co_loop - 1)
            fmts.append(f'{spatial} Co:0:{c_extent_tail}')
    return fmts


def generate_ofm_memtile_data_transfers(
    dims: ConvDims,
    dwc_l2_alloc: convL2Memory,
    dwc_repeats: ConvDataFlowRepeats,
    slicer: TensorSlicer,
) -> List[DataTransfer]:
    '''
    Generate one OFM memtile DataTransfer per column.

    Each transfer uses the OFM ping / pong L2 buffers, one S2MM channel per
    entry of ofm_memtile_channels() and a single MM2S channel (the first id
    of overlay_3x4_S_ids(col)). All format lists are asserted to have the
    same length as the column's ofm_L2_mm2s_repeats list.
    '''
    log(f"INFO OFM L2 conv repeats.ofm_L2_mm2s_repeats: {dwc_repeats.ofm_L2_mm2s_repeats}")
    transfers = []
    for col in range(dims.aie_cols):
        l2_layout = ofm_memtile_memory(dims, col, slicer)
        repeats = dwc_repeats.ofm_L2_mm2s_repeats[col]
        phase_count = len(repeats)
        assert len(l2_layout) == phase_count, \
            f"Column {col} OFM L2 memory fmts and repeats length mismatch: " \
            f"{len(l2_layout)} != {phase_count}"

        # Write tiling formats, keyed by core row.
        s2mm_fmts_by_row = {}
        for row in range(dims.aie_rows):
            row_fmts = ofm_memtile_s2mm(dims, col, row, slicer)
            assert len(row_fmts) == phase_count, \
                f"Column: {col}, row: {row}, ofm memtile s2mm phases of length: {phase_count} " \
                f"mismatch with fmts length: {len(row_fmts)}"
            s2mm_fmts_by_row[row] = row_fmts

        mm2s_fmts = ofm_memtile_mm2s(dims, col, slicer)
        # The number of formats must match the number of phases.
        assert len(mm2s_fmts) == phase_count, \
            f"Column {col} fmts and repeats length mismatch: " \
            f"{len(mm2s_fmts)} != {phase_count}"

        tile = memory_tile(col)
        buffer_addrs = [dwc_l2_alloc.ofm_ping_addr, dwc_l2_alloc.ofm_pong_addr]
        buffer_size = dwc_l2_alloc.ofm_size
        s2mm_side = []
        for row, channel_id in ofm_memtile_channels():
            s2mm_side.append(
                pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.S2MM, channel_id),
                    l2_layout,
                    s2mm_fmts_by_row[row],
                    bits_per_elem=dims.ofm_bits,
                    name=f"OFM_memtile_S2MM_col{col}_row{row}",
                )
            )
        mm2s_side = [
            pack_reconfig_transfers(
                memtile_dma(col, DmaDir.MM2S, overlay_3x4_S_ids(col)[0]),
                l2_layout,
                mm2s_fmts,
                bits_per_elem=dims.ofm_bits,
                name=f"OFM_memtile_MM2S_col{col}",
            )
        ]
        transfers.append(
            DataTransfer(
                repeats,
                tile, buffer_addrs, buffer_size,
                s2mm_side,
                mm2s_side,
                sync_strategy=SyncStrategy.Parallel_N_to_1,
            )
        )
    return transfers


def ofm_shimtile_s2mm(dims: ConvDims, col: int, slicer: TensorSlicer) -> List[str]:
    '''
    Define the OFM shim S2MM (write) tiling formats for one column.

    Per (Y, X) loop iteration the first format covers the Co range from Co
    iteration 0 up to the last full Co iteration, traversed in steps of
    cos_stride. When the last value of slicer.Co_split_iters() is non-zero,
    a second format covering the Co slice of the final Co iteration follows.
    '''
    fmts = []
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            y_lo, y_hi, _ = Yo_slice_per_column(dims, col, y_iter)
            x_lo, x_hi, _ = Xo_slice_per_column(dims, col, x_iter)
            split_size = Co_split_size(dims)
            has_pad_phase = slicer.Co_split_iters()[-1] != 0
            # With a pad phase the final Co iteration is handled separately.
            if has_pad_phase:
                last_full_iter = dims.Co_loop - 2
            else:
                last_full_iter = dims.Co_loop - 1
            # Main Co transfer with striding
            c_lo, _, _ = Co_slice_per_column(dims, col, 0)
            _, c_hi, _ = Co_slice_per_column(dims, col, last_full_iter)
            stride = min(c_hi - c_lo, split_size * dims.Cos)
            fmts.append(
                f'Co:{c_lo}:{c_hi}:{stride} '
                f'Yo:{y_lo}:{y_hi} '
                f'Xo:{x_lo}:{x_hi} '
                f'Co:0:{stride}'
            )
            if has_pad_phase:
                # Co pad / depad phase: the slice of the last Co iteration.
                tail_lo, tail_hi, _ = Co_slice_per_column(dims, col, dims.Co_loop - 1)
                fmts.append(
                    f'Yo:{y_lo}:{y_hi} '
                    f'Xo:{x_lo}:{x_hi} '
                    f'Co:{tail_lo}:{tail_hi}'
                )
    return fmts


def generate_ofm_shimtile_data_transfers(
    dims: ConvDims,
    dwc_repeats: ConvDataFlowRepeats,
    dwc_shim: ShimAllocator,
    slicer: TensorSlicer,
) -> List[DataTransfer]:
    '''
    Generate one OFM shim tile DataTransfer per column.

    Each transfer has a single S2MM transfer on shim DMA channel 0 and no
    MM2S side. The same DDR memory format (from ofm_shim_memory) is repeated
    for every phase of the column's ofm_L3_s2mm_repeats list.
    '''
    ddr_layout = ofm_shim_memory(dims)
    ddr_ofm_size = compute_buffer_size(ddr_layout)
    transfers = []
    for col in range(dims.aie_cols):
        repeats = dwc_repeats.ofm_L3_s2mm_repeats[col]
        # One copy of the DDR layout per phase.
        layout_per_phase = [ddr_layout] * len(repeats)
        s2mm_fmts = ofm_shimtile_s2mm(dims, col, slicer)
        assert len(s2mm_fmts) == len(repeats), \
            f"Column: {col} OFM shim tile S2MM phases of length {len(repeats)} "\
            f"Does not match write formats of length {len(s2mm_fmts)}"
        tile = shim_tile(col)
        s2mm_side = [
            pack_reconfig_transfers(
                shim_dma(col, DmaDir.S2MM, 0),
                layout_per_phase,
                s2mm_fmts,
                bits_per_elem=dims.ofm_bits,
                buffer_offset=[dwc_shim.ofm_xrt_offset],
                name=f"ofm_shimtile_s2mm_col_{col}"
            )
        ]
        mm2s_side = []  # the shim never reads the OFM back
        transfers.append(
            DataTransfer(
                repeats,
                tile, [dwc_shim.ofm_xrt_idx], ddr_ofm_size,
                s2mm_side,
                mm2s_side,
            )
        )
    return transfers


#####################################################
# NOTE: Mypy doesn't correctly infer types with the abstract base construct
# used by the core instruction list, so we disable type checking locally here
@no_type_check
def compile_DWC_L3_dataflow(schedule_input: ScheduleInputs) -> tuple:
    '''
    DWC dataflow
    Note we ONLY do streaming for all input output tensors
    as there is no reuse or accumulation unlike conv

    Builds the per-core instruction lists and the memtile / shim tile data
    transfers for a 3 column x 4 row overlay, then passes them to
    run_layer_compilation.

    Returns the pair (shim_prm_offset_next_layer, shim_wgt_offset_next_layer):
    the current param / weight xrt offsets advanced by the size of this
    layer's param / weight shim buffers.

    Raises ValueError if split_to_mode(dims) is neither 0 nor 1.
    '''
    shape: ConvShape = schedule_input.shape
    mapping: ConvMapping = schedule_input.mapping
    log(f"Scheduler shape: {shape}")
    log(f"Scheduler mapping: {mapping}")
    L3_alloc: L3Alloc = schedule_input.L3_alloc

    aie_cols, aie_rows = 3, 4
    overlay_shape = OverlayShape(aie_cols, aie_rows)
    dims = ConvDims(shape, mapping)
    slicer = TensorSlicer(dims)
    dwc_shim = L3Alloc_to_Shim(L3_alloc)

    # The L2 allocator is sized from the first entry of column 0's formats
    prm_l2_fmt = prm_memtile_memory()
    ifm_l2_fmt = ifm_memtile_memory(dims, 0, slicer)[0]
    wgt_l2_fmt = dwc_wgt_memtile_memory(dims, 0, slicer)[0]
    ofm_l2_fmt = ofm_memtile_memory(dims, 0, slicer)[0]
    dwc_l2_alloc = DwcL2MemoryAllocator(
        dims,
        prm_l2_fmt,
        ifm_l2_fmt,
        wgt_l2_fmt,
        ofm_l2_fmt,
        ifm_double_buffer=True,
        wgt_double_buffer=True,
        ofm_double_buffer=True,
    )
    # log the L2 memory allocation details
    log(f"INFO dwc_l2_alloc.param_addr: {dwc_l2_alloc.param_addr} dwc_l2_alloc.param_size: {dwc_l2_alloc.param_size}")
    log(f"INFO dwc_l2_alloc.ifm_ping_addr: {dwc_l2_alloc.ifm_ping_addr} dwc_l2_alloc.ifm_L2_size: {dwc_l2_alloc.ifm_size}")
    log(f"INFO dwc_l2_alloc.ifm_pong_addr: {dwc_l2_alloc.ifm_pong_addr} dwc_l2_alloc.ifm_L2_size: {dwc_l2_alloc.ifm_size}")
    log(f"INFO dwc_l2_alloc.wgt_ping_addr: {dwc_l2_alloc.wgt_ping_addr} dwc_l2_alloc.wgt_L2_size: {dwc_l2_alloc.wgt_size}")
    log(f"INFO dwc_l2_alloc.wgt_pong_addr: {dwc_l2_alloc.wgt_pong_addr} dwc_l2_alloc.wgt_L2_size: {dwc_l2_alloc.wgt_size}")
    log(f"INFO dwc_l2_alloc.ofm_ping_addr: {dwc_l2_alloc.ofm_ping_addr} dwc_l2_alloc.ofm_L2_size: {dwc_l2_alloc.ofm_size}")
    log(f"INFO dwc_l2_alloc.ofm_pong_addr: {dwc_l2_alloc.ofm_pong_addr} dwc_l2_alloc.ofm_L2_size: {dwc_l2_alloc.ofm_size}")

    dma_connections = overlay_3x4_dma_connections()

    # Only modes 0 and 1 are recognised; reject anything else up front
    data_stream_mode = split_to_mode(dims)
    if data_stream_mode not in (0, 1):
        raise ValueError(f"Unknown data stream mode: {data_stream_mode}")
    if data_stream_mode == 0:
        log("INFO: IFM unicast / WGT broadcast mode")
    else:
        log("INFO: IFM broadcast / WGT unicast mode")

    dwc_dataflow_repeats = generate_dwc_repeats(dims, slicer, is_repeat_too_high=False)
    log(f"dwc_dataflow_repeats: {dwc_dataflow_repeats}")

    def param_phase_repeats(col):
        # Repeat count 1 in the first phase and 0 in every later phase,
        # with one entry per IFM L2 MM2S phase of this column
        num_phases = len(dwc_dataflow_repeats.ifm_L2_mm2s_repeats[col])
        return [1] + [0] * (num_phases - 1)

    # One instruction list per core tile of the overlay
    core_instrs_dict = {}
    for col in range(aie_cols):
        for row in range(aie_rows):
            instrs = generate_dwc_core_instrs(
                dims,
                mapping,
                shape.linear_op_type,
                col, row,
            )
            core_instrs_dict[core_tile(col, row)] = instrs

    # Parameter path through each memtile: one S2MM in, one MM2S out per row
    memtile_transfers = []
    for col in range(aie_cols):
        repeats = param_phase_repeats(col)
        tile = memory_tile(col)
        param_buffer_addrs = [dwc_l2_alloc.param_addr]
        param_buffer_size = dwc_l2_alloc.param_size
        param_s2mm = [generate_transfer_params(
            memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[2]),
            prm_memtile_memory(),
            prm_memtile_s2mm(),
        )]
        param_mm2s = [generate_transfer_params(
            memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
            prm_memtile_memory(),
            prm_memtile_mm2s(row),
        ) for row in range(aie_rows)]
        memtile_transfers.append(
            DataTransfer(
                repeats,
                tile, param_buffer_addrs, param_buffer_size,
                param_s2mm,
                param_mm2s,
            )
        )

    # Parameter path out of each shim tile on MM2S channel 2
    shimtile_transfers = [
        generate_shim_data_transfer(
            param_phase_repeats(col),
            shim_dma(col, DmaDir.MM2S, 2), dwc_shim.prm_xrt_idx,
            prm_shim_memory(),
            prm_shim_mm2s(col),
            buffer_offset=dwc_shim.prm_xrt_offset
        ) for col in range(aie_cols)
    ]

    # Tensor transfers are appended in IFM, WGT, OFM order: memtile first, then shim
    memtile_transfers.extend(
        generate_ifm_memtile_data_transfers(dims, dwc_l2_alloc, dwc_dataflow_repeats, slicer))
    memtile_transfers.extend(
        generate_dwc_wgt_memtile_data_transfers(dims, slicer, dwc_l2_alloc, dwc_dataflow_repeats))
    memtile_transfers.extend(
        generate_ofm_memtile_data_transfers(dims, dwc_l2_alloc, dwc_dataflow_repeats, slicer))
    shimtile_transfers.extend(
        generate_ifm_shimtile_data_transfers(dims, dwc_dataflow_repeats, dwc_shim, slicer))
    shimtile_transfers.extend(
        generate_dwc_wgt_shimtile_data_transfers(dims, dwc_dataflow_repeats, dwc_shim, slicer))
    shimtile_transfers.extend(
        generate_ofm_shimtile_data_transfers(dims, dwc_dataflow_repeats, dwc_shim, slicer))

    run_layer_compilation(
        overlay_shape,
        schedule_input.kernel_names,
        schedule_input.kernel_includes,
        core_instrs_dict,
        memtile_transfers,
        shimtile_transfers,
        dma_connections,
        core_stack_addr=overlay_3x4_core_stack_addr(),
        param_channel_id=overlay_3x4_param_channel_id(),
        back_end=schedule_input.backend,
        layer_file=schedule_input.layer_file_name,
    )

    # Advance the param / weight offsets past this layer's shim buffers
    wgt_buffer_size = compute_buffer_size(dwc_wgt_shim_memory(dims))
    prm_buffer_size = compute_buffer_size(prm_shim_memory())
    shim_prm_offset_next_layer = dwc_shim.prm_xrt_offset + prm_buffer_size
    shim_wgt_offset_next_layer = dwc_shim.wgt_xrt_offset + wgt_buffer_size
    log("shim_prm_offset_next_layer", shim_prm_offset_next_layer)
    log("shim_wgt_offset_next_layer", shim_wgt_offset_next_layer)
    return shim_prm_offset_next_layer, shim_wgt_offset_next_layer
