import os
# Absolute directory containing this file; used as the anchor for the
# sys.path entries added below.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
# Extend the import search path with the parent directory and with the
# 'dmacompiler' directory two levels up. These must run before the
# project imports that follow (presumably this is where `dataflow_common`
# and `dmacompiler` are found -- confirm against the repository layout).
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
from typing import List, Union, Dict

from dmacompiler import \
    OverlayShape, BackEnd, DataTransfer, AieTile, TileType, \
    DmaDir,  memtile_dma, shim_dma, CascDir, \
    generate_transfer_params, \
    generate_shim_data_transfer, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config 

from dataflow_common import \
    overlay_4x4_dma_connections, \
    overlay_8x4_dma_connections, \
    overlay_stack_addr, \
    overlay_4x4_core_stream_bdcast, \
    overlay_8x4_core_stream_bdcast

# Module-level configuration, applied as a side effect of importing this file:
# select DevGen.Aie2p as the device generation and set the ENABLE_BUSY_POLL
# flag on the `config` object imported from dmacompiler.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True


def prm_shim_memory(aie_cols: int, aie_rows: int, prm_subv_size: int) -> str:
    """Return the shim-side memory descriptor string for the parameter buffer.

    The result has the form ``'Col:<aie_cols> Row:<aie_rows> Param:<prm_subv_size>'``
    and is passed to ``generate_shim_data_transfer`` by ``compile_dataflow``.
    """
    template = 'Col:{} Row:{} Param:{}'
    return template.format(aie_cols, aie_rows, prm_subv_size)

def prm_shim_mm2s(col: int) -> str:
    """Return the shim MM2S access descriptor string for a single column.

    The result has the form ``'Col:<col>:<col + 1> Row Param'``, i.e. the
    ``Col`` field carries the range ``col`` to ``col + 1`` while ``Row`` and
    ``Param`` are given without a range.
    """
    col_end = col + 1
    return 'Col:{}:{} Row Param'.format(col, col_end)

def prm_memtile_memory(aie_rows: int, prm_subv_size: int) -> str:
    """Return the memtile-side memory descriptor string for the parameter buffer.

    The result has the form ``'Row:<aie_rows> Param:<prm_subv_size>'`` and is
    used by ``compile_dataflow`` for both the S2MM and MM2S memtile transfers.
    """
    template = 'Row:{} Param:{}'
    return template.format(aie_rows, prm_subv_size)

def prm_memtile_s2mm() -> str:
    """Return the memtile S2MM access descriptor string.

    This is the constant ``'Row Param'``: both fields appear without a range.
    """
    return 'Row Param'

def prm_memtile_mm2s(row: int) -> str:
    """Return the memtile MM2S access descriptor string for a single row.

    The result has the form ``'Row:<row>:<row + 1> Param'``, i.e. the ``Row``
    field carries the range ``row`` to ``row + 1``.
    """
    row_end = row + 1
    return 'Row:{}:{} Param'.format(row, row_end)

def compile_dataflow(kernel_names: Union[List[str], Dict[str, int]], kernel_includes: List[str], overlay: str, disable_fastPM: bool = False):
    """Build the parameter-distribution transfers for an overlay and compile them.

    For every column of the selected overlay this creates:

    * one memtile ``DataTransfer`` that receives on S2MM channel
      ``param_channel_id`` and sends on one MM2S channel per row
      (channel index == row index), and
    * one shim MM2S transfer on channel ``param_channel_id`` reading from
      buffer id ``prm_buffer_id``.

    The transfers are then handed to ``run_layer_compilation`` together with
    the overlay's DMA connections and core stream broadcast connections,
    using the ``TxnHostPatch`` back end, vertical cascade direction and
    ``'dma.hpp'`` as the layer file.

    Args:
        kernel_names: Kernel names, forwarded unchanged to
            ``run_layer_compilation``.
        kernel_includes: Kernel include files, forwarded unchanged to
            ``run_layer_compilation``.
        overlay: Overlay selector; ``'4x4'`` (4 columns x 4 rows) or
            ``'8x4'`` (8 columns x 4 rows).
        disable_fastPM: When true, sets ``config.ENABLE_FAST_PM`` to
            ``False`` before compiling. The flag lives on the shared
            ``config`` object and is not restored afterwards.

    Raises:
        AssertionError: If ``overlay`` is not one of the supported values.
    """
    if disable_fastPM:
        config.ENABLE_FAST_PM = False

    # Resolve everything that depends on the overlay in one place. The core
    # stream broadcast builder is only *selected* here and is invoked at the
    # run_layer_compilation call below.
    if overlay == '4x4':
        aie_cols, aie_rows = 4, 4
        overlay_dma_connections = overlay_4x4_dma_connections()
        core_stream_bdcast = overlay_4x4_core_stream_bdcast
    elif overlay == '8x4':
        aie_cols, aie_rows = 8, 4
        overlay_dma_connections = overlay_8x4_dma_connections()
        core_stream_bdcast = overlay_8x4_core_stream_bdcast
    else:
        # Raised explicitly rather than via `assert False`, which is removed
        # under `python -O` and would let execution continue with the overlay
        # dimensions unbound.
        raise AssertionError(f"unsupported overlay. {overlay}")

    param_channel_id = 0
    prm_subv_size = config.MAX_CORE_LAYER_PARAM_SIZE
    prm_buffer_id = 3
    core_instrs = []

    # One memtile transfer per column: a single S2MM input fanned out to one
    # MM2S output per row. The memtile buffer is sized to hold one parameter
    # sub-volume per row.
    # NOTE(review): the leading [1] and the [0] argument are presumably the
    # repeat counts and buffer address list -- confirm against DataTransfer.
    memtile_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Memtile, col), [0], aie_rows * prm_subv_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, param_channel_id),
                prm_memtile_memory(aie_rows, prm_subv_size),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(aie_rows, prm_subv_size),
                prm_memtile_mm2s(row),
            ) for row in range(aie_rows)],
        ) for col in range(aie_cols)
    ]
    # One shim MM2S transfer per column, each selecting its own column slice
    # of the parameter buffer.
    shim_transfers = [
        generate_shim_data_transfer(
            [1],
            shim_dma(col, DmaDir.MM2S, param_channel_id), prm_buffer_id,
            prm_shim_memory(aie_cols, aie_rows, prm_subv_size),
            prm_shim_mm2s(col),
        ) for col in range(aie_cols)
    ]

    run_layer_compilation(
        OverlayShape(aie_cols, aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_dma_connections,
        back_end=BackEnd.TxnHostPatch,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=param_channel_id,
        layer_file='dma.hpp',
        casc_dir=CascDir.Vertical,
        core_connections=core_stream_bdcast(),
    )

def main():
    """Script entry point: compile the '8x4' overlay with no kernels and the 'super.hh' include."""
    compile_dataflow(
        kernel_names=[],
        kernel_includes=['super.hh'],
        overlay='8x4',
    )

# Run the compilation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
