'''
Dummy dataflow used to perform a dry run when processing a combined PDI.
'''
from typing import no_type_check, Union

from dmacompiler import (
    DevGen, set_dev_gen, BackEnd, SyncStrategy,
    OverlayShape, DataTransfer, DmaDir,
    memtile_dma, shim_dma, memory_tile,
    run_layer_compilation,
    generate_transfer_params,
    generate_shim_data_transfer,
)
from utils.utils_common import (
    overlay_3x4_core_stack_addr,
)
from scheduler.common import (
    overlay_3x4_dma_connections,
    overlay_3x4_param_channel_id,
    overlay_3x4_A_ids, overlay_3x4_F_ids,
    prm_memtile_memory, prm_shim_memory,
    prm_memtile_mm2s, prm_memtile_s2mm,
    prm_shim_mm2s, shim_alloc,
    overlay_3x4_col_core_stream_bdcast
)

# Select the AIE4 device generation for dmacompiler. This is a module-level
# call, so it runs at import time, before compile_pdi is ever invoked.
# NOTE(review): presumably this sets global dmacompiler state that the
# compile_pdi helpers below rely on — confirm.
set_dev_gen(DevGen.Aie4)


@no_type_check
def compile_pdi(
    kernel_names: Union[dict, list[str]],
    kernel_includes: list[str],
    backend: BackEnd,
    layer_file_name: str,
) -> None:
    '''Compile the dummy 3x4-overlay dataflow used for the combined-PDI dry run.

    The geometry is fixed at 3 AIE columns x 4 AIE rows. No core
    instructions are emitted. The function builds:

    * one memtile ``DataTransfer`` per column, and
    * one shim MM2S transfer per column.

    It then hands everything to ``run_layer_compilation``.

    Args:
        kernel_names: Forwarded unchanged to ``run_layer_compilation``.
        kernel_includes: Forwarded unchanged to ``run_layer_compilation``.
        backend: Printed to stdout, then forwarded as ``back_end``.
        layer_file_name: Forwarded as ``layer_file``.

    Returns:
        None. All output is produced by ``run_layer_compilation``.
    '''
    # The overlay_3x4_* helpers used below are specific to this geometry,
    # so the dimensions are fixed rather than parameterized.
    aie_cols = 3
    aie_rows = 4
    print(backend)

    overlay_shape = OverlayShape(aie_cols, aie_rows)

    # Dummy dataflow: no core instructions are emitted.
    core_instrs = []

    dummy_shim_alloc = shim_alloc()
    dma_connections = overlay_3x4_dma_connections()

    # These id tables do not depend on the column or row being built.
    # Look them up once instead of once per column / per (column, row).
    f_ids = overlay_3x4_F_ids()
    a_ids = overlay_3x4_A_ids()

    # Each column gets one memtile DataTransfer with:
    #   * a single S2MM transfer param on memtile DMA f_ids[0], and
    #   * one MM2S transfer param per AIE row on memtile DMA a_ids[row],
    # synchronized with the Parallel_1_to_N strategy.
    # NOTE(review): the positional [1], [0], 4096 arguments are presumably
    # repeat count, buffer offset(s) and buffer size. Confirm against the
    # DataTransfer signature.
    memtile_transfers = [
        DataTransfer(
            [1],
            memory_tile(col), [0], 4096,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, f_ids[0]),
                prm_memtile_memory(),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, a_ids[row]),
                prm_memtile_memory(),
                prm_memtile_mm2s(row),
            ) for row in range(aie_rows)],
            sync_strategy=SyncStrategy.Parallel_1_to_N,
        ) for col in range(aie_cols)
    ]

    # Each column gets one shim MM2S transfer. It uses shim DMA index 0 and
    # reads from the dummy shim allocation's prm_buffer_id.
    shimtile_transfers = [
        generate_shim_data_transfer(
            [1],
            shim_dma(col, DmaDir.MM2S, 0), dummy_shim_alloc.prm_buffer_id,
            prm_shim_memory(),
            prm_shim_mm2s(col)
        ) for col in range(aie_cols)
    ]

    # NOTE: the column core stream broadcast is always used here. An earlier
    # revision considered enabling it only when "run_group_norm_qdq" is in
    # kernel_includes (and passing None otherwise).
    core_connections = overlay_3x4_col_core_stream_bdcast()

    run_layer_compilation(
        overlay_shape,
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shimtile_transfers,
        dma_connections,
        core_stack_addr=overlay_3x4_core_stack_addr(),
        param_channel_id=overlay_3x4_param_channel_id(),
        back_end=backend,
        layer_file=layer_file_name,
        core_connections=core_connections
    )
