import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
from typing import Tuple, List

from dmacompiler import \
    OverlayShape, BackEnd, \
    DataTransfer, \
    AieTile, TileType, \
    DmaDir, memtile_dma, shim_dma, \
    generate_transfer_params, \
    generate_shim_data_transfer, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config, \
    ConfigBuffer, DmaChannel, AcqBuffer, RelBuffer

from dataflow_common import \
    overlay_8x4_dma_connections, \
    overlay_stack_addr, \
    clean_overlay, \
    build_sim_overlay, \
    ceildiv, \
    shim_alloc, \
    prm_shim_memory, \
    prm_shim_mm2s, \
    prm_memtile_memory, \
    prm_memtile_s2mm, \
    prm_memtile_mm2s
    
from transpose_common import \
    TransposeDims, \
    transpose_preproc_directives
# Module-wide compiler settings, applied as soon as this file is imported.
set_dev_gen(DevGen.Aie2p)
# NOTE(review): presumably makes the generated control code busy-poll for
# completion -- confirm against the dmacompiler `config` module.
config.ENABLE_BUSY_POLL = True

# Upper bound, in elements, on one sub-volume. It caps the chunk size chosen
# by ifm_shard_subv_size and sizes the memtile ping/pong buffers in
# compile_dataflow.
MAX_SUBV_SIZE = 8192

def axis_name(idx: int) -> str:
    """Return the one-character label of axis *idx* (0 -> 'A', 1 -> 'B', ...)."""
    return chr(idx + ord('A'))

def axis_slice(dims, col: int, idx: int) -> Tuple[int, int, int, int]:
    """Return ``(start, stop, split, size)`` for permuted axis *idx* on column *col*.

    The axis extent is ``dims.input[dims.perm[idx]]``. Only the axis at
    permuted index 0 is divided between the ``dims.aie_cols`` columns, in
    chunks of ``ceildiv(extent, dims.aie_cols)``; every other axis is
    returned whole. A column that lies past the end of the split axis gets
    an empty slice (``start == stop == extent``, ``size == 0``).
    """
    extent = dims.input[dims.perm[idx]]

    if idx != 0:
        # Inner axes are never sharded: each column sees the full range.
        return 0, extent, extent, extent

    # NOTE: Index 0 is the outermost traversal in the write direction, so
    # sharding there keeps each column's access pattern as linear as possible.
    chunk = ceildiv(extent, dims.aie_cols)
    first = min(col * chunk, extent)
    last = min(first + chunk, extent)
    return first, last, chunk, last - first

def ifm_shard_total_size(dims, col: int) -> int:
    """Return the number of elements in column *col*'s shard of the input.

    This is the product of the per-axis slice sizes reported by
    ``axis_slice`` over every permuted axis. With no axes the product is 1.
    """
    assert len(dims.input) == len(dims.perm)
    elems = 1
    for axis_idx, _ in enumerate(dims.perm):
        # axis_slice returns (start, stop, split, size); only size is needed.
        elems *= axis_slice(dims, col, axis_idx)[3]
    return elems

def ifm_shard_subv_size(dims, col: int) -> int:
    """Return the sub-volume size, in elements, used for column *col*.

    The result is the largest divisor of the shard's total element count
    that does not exceed ``MAX_SUBV_SIZE``, so the shard splits into a whole
    number of equally sized sub-volumes. An empty shard (0 elements) yields
    ``MAX_SUBV_SIZE`` because every candidate divides 0.
    """
    def largest_divisor_at_most(value: int, limit: int) -> int:
        # Walk down from the limit; 1 always divides, so for limit >= 1 the
        # search terminates with a result.
        candidate = limit
        while candidate > 0:
            if value % candidate == 0:
                return candidate
            candidate -= 1
        assert False
    assert len(dims.input) == len(dims.perm)
    shard_elems = ifm_shard_total_size(dims, col)
    return largest_divisor_at_most(shard_elems, MAX_SUBV_SIZE)

def memtile_repeat_count(dims, col: int, zero_length_bds: bool = False) -> int:
    """Return how many sub-volumes make up column *col*'s shard.

    The count is ``total_size // subv_size``. When *zero_length_bds* is true
    and the shard is empty, 1 is returned instead of 0.
    NOTE(review): presumably so an empty column still issues one zero-length
    transfer -- confirm against the caller's expectations.
    """
    shard_elems = ifm_shard_total_size(dims, col)
    chunk_elems = ifm_shard_subv_size(dims, col)
    # The sub-volume size is chosen as a divisor of the shard size.
    assert (shard_elems % chunk_elems) == 0
    if zero_length_bds and shard_elems == 0:
        return 1
    return shard_elems // chunk_elems

def memtile_memory(dims, col: int) -> str:
    """Return the memtile buffer format string ``Elems:<subv_size>`` for column *col*."""
    return f'Elems:{ifm_shard_subv_size(dims, col)}'

def mem_transfers(dims, col, repeat_count):
    """Return the memtile access string for one sub-volume of column *col*.

    *repeat_count* is an iterable of repeat counts. If they sum to zero, the
    zero-length access ``Elems:0:0`` is returned; otherwise the access spans
    the whole sub-volume, ``Elems:0:<subv_size>``.
    """
    if sum(repeat_count) != 0:
        return f'Elems:0:{ifm_shard_subv_size(dims, col)}'
    return 'Elems:0:0'

def input_shim_memory(input: List[int]):
    """Return the input buffer format string, e.g. ``A:<n0> B:<n1> ...``.

    Axes are listed in their original order, each labelled by ``axis_name``
    of its position and followed by its extent.
    """
    labelled = []
    for position, extent in enumerate(input):
        labelled.append(f'{axis_name(position)}:{extent}')
    return ' '.join(labelled)

def input_shim_mm2s(dims, col: int) -> str:
    """Return the shim access string for column *col*'s shard.

    Axes are emitted in permuted order. Each entry is
    ``<axis label>:<start>:<stop>``, where the label names the original input
    axis (``dims.perm[i]``) and the range comes from ``axis_slice``.
    """
    assert len(dims.input) == len(dims.perm)
    entries = []
    for position, source_axis in enumerate(dims.perm):
        label = axis_name(source_axis)
        lo, hi, _, _ = axis_slice(dims, col, position)
        entries.append(f'{label}:{lo}:{hi}')
    return ' '.join(entries)

def output_shim_memory(dims):
    """Return the output buffer format string.

    Axes appear in permuted order; each keeps the label and extent of the
    input axis it came from (``dims.perm`` indexes into ``dims.input``).
    """
    tokens = []
    for source_axis in dims.perm:
        tokens.append(f'{axis_name(source_axis)}:{dims.input[source_axis]}')
    return ' '.join(tokens)

# The output S2MM access string is built exactly like the input MM2S one:
# both label axes by their original input index and list them in permuted
# order, so the same formatter serves both directions.
output_shim_s2mm = input_shim_mm2s

def compile_dataflow(
    dims: TransposeDims,
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
) -> None:
    """Describe the transpose data movement and pass it to ``run_layer_compilation``.

    Builds three descriptions -- core instructions, memtile transfers and
    shim transfers -- for every column of the overlay and forwards them,
    together with the 8x4 DMA connections, to the layer compiler.

    Args:
        dims: Transpose problem description. This function reads
            ``aie_cols``, ``aie_rows`` and ``act_bits`` directly; the helper
            functions additionally read ``input`` and ``perm``.
        back_end: Back end forwarded to ``run_layer_compilation``.
        kernel_names: Kernel names forwarded to ``run_layer_compilation``.
        kernel_includes: Kernel include files forwarded to
            ``run_layer_compilation``.

    Raises:
        AssertionError: If the overlay is not 8 columns by 4 rows.
    """
    # Only the 8x4 overlay is supported (overlay_8x4_dma_connections below).
    assert (dims.aie_cols, dims.aie_rows) == (8, 4)

    transpose_shim_alloc = shim_alloc()

    # One parameter slot per core row, then a sub-volume buffer sized for
    # MAX_SUBV_SIZE elements of act_bits bits each (bits // 8).
    param_memtile_size = dims.aie_rows * config.MAX_CORE_LAYER_PARAM_SIZE
    tmp_memtile_size = MAX_SUBV_SIZE * dims.act_bits // 8

    # Memtile address map: [params][ping buffer][pong buffer], back to back.
    param_memtile_addr = 0
    tmp_memtile_ping_addr = param_memtile_addr + param_memtile_size
    tmp_memtile_pong_addr = tmp_memtile_ping_addr + tmp_memtile_size

    # The core only configures, acquires and releases S2MM channel 0 with
    # zero-valued arguments.
    # NOTE(review): presumably the cores do no compute here and the data
    # path runs shim -> memtile -> shim only -- confirm.
    core_instrs = [
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), 0, 0, 0),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
        ]

    # First group (per column): layer parameters arrive on S2MM channel 0
    # and leave on one MM2S channel per row.
    # Second group (per column): the data shard streams through the
    # ping/pong buffers, in on S2MM channel 0 and out on MM2S channel 5,
    # repeated once per sub-volume.
    memtile_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Memtile, col),
            [param_memtile_addr],
            param_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(dims),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(dims),
                prm_memtile_mm2s(row),
            ) for row in range(dims.aie_rows)],
        ) for col in range(dims.aie_cols)
    ] + [
        DataTransfer(
            # zero_length_bds=True: an empty shard still gets repeat count 1,
            # while mem_transfers below receives the raw count and emits a
            # zero-length access in that case.
            [memtile_repeat_count(dims, col, True)],
            AieTile(TileType.Memtile, col),
            [tmp_memtile_ping_addr, tmp_memtile_pong_addr],
            tmp_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                memtile_memory(dims,  col),
                mem_transfers(dims, col, [memtile_repeat_count(dims, col)]),
                bits_per_block=dims.act_bits,
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, 5),
                memtile_memory(dims, col),
                mem_transfers(dims, col, [memtile_repeat_count(dims, col)]),
                bits_per_block=dims.act_bits,
            )],
        ) for col in range(dims.aie_cols)
    ]

    # Three groups of per-column shim transfers, in this order:
    #   1. parameters read from the prm buffer on MM2S channel 0,
    #   2. input shard read from the ifm buffer on MM2S channel 0,
    #   3. output shard written to the ofm buffer on S2MM channel 0.
    shim_transfers = [
        generate_shim_data_transfer(
            [1],
            shim_dma(col, DmaDir.MM2S, 0),
            transpose_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col),
        ) for col in range(dims.aie_cols)
    ] + [
        generate_shim_data_transfer(
            [1],
            shim_dma(col, DmaDir.MM2S, 0),
            transpose_shim_alloc.ifm_buffer_id,
            input_shim_memory(dims.input),
            input_shim_mm2s(dims, col),
            bits_per_block=dims.act_bits,
        ) for col in range(dims.aie_cols)
    ] + [
        generate_shim_data_transfer(
            [1],
            shim_dma(col, DmaDir.S2MM, 0),
            transpose_shim_alloc.ofm_buffer_id,
            output_shim_memory(dims),
            output_shim_s2mm(dims, col),
            bits_per_block=dims.act_bits,
        ) for col in range(dims.aie_cols)
    ]

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_8x4_dma_connections(),
        back_end=back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
    )

def main():
    """Compile and build the transpose overlay for a fixed list of input shapes.

    For each shape the overlay is cleaned, the dataflow is compiled, and the
    simulation overlay is built from ``transpose_main.cpp``.
    """
    # The back end is fixed here; BackEnd(int(sys.argv[1])) would select it
    # from the command line instead.
    back_end = BackEnd.Adf
    kernel_names = []
    kernel_includes = ['super.hh']
    aie_cols, aie_rows = 8, 4
    act_bits = 16
    # Swap the two middle axes of every 4-D shape.
    perm = [0, 2, 1, 3]
    shapes = [
        [1, 64, 1, 64],
        [1, 12, 50, 64],
        [1, 50, 12, 64],
        [10, 8, 77, 64],
        [1, 12, 197, 64],
        [1, 197, 12, 64],
        [10, 77, 8, 64],
    ]

    for shape in shapes:
        print("input:", shape)
        dims = TransposeDims(aie_rows, aie_cols, shape, perm, act_bits)

        clean_overlay()
        compile_dataflow(dims, back_end, kernel_names, kernel_includes)
        directives = transpose_preproc_directives(dims, back_end)
        build_sim_overlay(back_end, 'transpose_main.cpp', directives)

# Run the shape sweep only when this file is executed as a script.
if __name__ == '__main__':
    main()