import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
from typing import Tuple, List

from dmacompiler import \
    OverlayShape, BackEnd, \
    DataTransfer, \
    AieTile, TileType, \
    AieDma, DmaDir, memtile_dma, shim_dma, core_dma, DmaChannel, \
    CoreInstr, ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    compute_buffer_size, \
    TransferParams, generate_transfer_params, \
    generate_shim_data_transfer, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config

from dataflow_common import \
    overlay_8x4_dma_connections, \
    overlay_stack_addr, \
    clean_overlay, \
    build_sim_overlay, \
    ceildiv, \
    shim_alloc, \
    prm_shim_memory, \
    prm_shim_mm2s, \
    prm_memtile_memory, \
    prm_memtile_s2mm, \
    prm_memtile_mm2s
    
from depthtospace_common import \
    DepthToSpace_dims, \
    depthtospace_preproc_directives
# Module-level side effect: select the DevGen.Aie2p device generation for all
# dmacompiler calls made by this script.
set_dev_gen(DevGen.Aie2p)
# NOTE(review): presumably makes generated programs busy-poll rather than
# block while waiting; confirm against dmacompiler's `config` definition.
config.ENABLE_BUSY_POLL = True

def iceil(x: int, d: int) -> int:
    """Round ``x`` up to the next multiple of ``d`` (using ``ceildiv``)."""
    whole_units = ceildiv(x, d)
    return whole_units * d

def pack_transfers(
    dma: AieDma,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iters: List[int],
    bits_per_elem: int,
) -> TransferParams:
    """Merge several (memory format, tiling format) pairs into one TransferParams.

    Pair ``i`` is converted with ``generate_transfer_params`` and the
    first-descriptor fields (``length_i(0)``, ``offset_i(0)``, ...) of the
    result are repeated ``tiling_iters[i]`` times in the packed output.
    Padding is enabled only when ``dma`` is an MM2S channel.
    """
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iters)

    def repeat_each(values: list) -> list:
        # Value i is emitted tiling_iters[i] times, preserving order.
        assert len(values) == len(tiling_iters)
        return [
            value
            for value, count in zip(values, tiling_iters)
            for _ in range(count)
        ]

    per_fmt_params = []
    for idx in range(len(tiling_fmts)):
        per_fmt_params.append(
            generate_transfer_params(
                dma,
                memory_fmts[idx],
                tiling_fmts[idx],
                bits_per_block=bits_per_elem,
                enable_padding=(dma.channel.dir == DmaDir.MM2S),
            )
        )

    def field(getter) -> list:
        return repeat_each([getter(p) for p in per_fmt_params])

    return TransferParams(
        dma,
        field(lambda p: p.length_i(0)),
        offset=field(lambda p: p.offset_i(0)),
        step=field(lambda p: p.step_i(0)),
        wrap=field(lambda p: p.wrap_i(0)),
        padding=field(lambda p: p.padding_i(0)),
    )

def Yi_slice(dims: DepthToSpace_dims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Locate the input rows handled by column ``col`` on iteration ``start_iter``.

    Columns take consecutive ``Yis``-row slices interleaved across the
    ``aie_cols`` columns, so each iteration advances by ``aie_cols * Yis`` rows.

    Returns:
        ``(Yi_start, Yi_stop, Yi_stride, Yi_size)``:
        * ``Yi_start``/``Yi_stop``: unclamped slice bounds. ``Yi_stop`` equals
          ``Yi_start`` once the slice starts beyond ``Yi``.
        * ``Yi_stride``: the per-iteration step.
        * ``Yi_size``: how many rows of ``[Yi_start, Yi_stop)`` lie in ``[0, Yi)``.
    """
    rows = dims.Yis
    stride = dims.aie_cols * rows
    start = col * rows + start_iter * stride
    if start > dims.Yi:
        stop = start
    else:
        stop = start + rows

    def clamp(y: int) -> int:
        return max(0, min(y, dims.Yi))

    return (start, stop, stride, clamp(stop) - clamp(start))

def Yi_split_iters(dims: DepthToSpace_dims, col: int) -> List[Tuple[int, int]]:
    """Group the ``Y_loop`` iterations of column ``col`` into ``(start_iter, num_iters)`` runs.

    An iteration is "interior" when its ``Yi_slice`` lies fully in ``[0, Yi]``.
    Consecutive interior iterations are merged into one run. Every other
    iteration forms a run of length 1.

    NOTE(review): the inner extension loop is not bounded by ``Y_loop``; it
    relies on slices beyond ``Y_loop`` falling outside ``Yi``.
    """
    def is_interior(it: int) -> bool:
        lo, hi, _, _ = Yi_slice(dims, col, it)
        return lo >= 0 and hi <= dims.Yi

    runs = []
    it = 0
    while it < dims.Y_loop:
        count = 1
        if is_interior(it):
            while is_interior(it + count):
                count += 1
        runs.append((it, count))
        it += count
    return runs

def Yi_repeat_counts(dims: DepthToSpace_dims, col: int, scale: int) -> List[int]:
    """Return a ``Y_loop``-long list of repeat counts for column ``col``.

    For each run ``(start_iter, num_iters)`` from ``Yi_split_iters``, the entry
    at ``start_iter`` is ``num_iters * scale``. All other entries are 0.
    """
    counts = [0] * dims.Y_loop
    for first, length in Yi_split_iters(dims, col):
        counts[first] = length * scale
    return counts

def Yo_slice(dims: DepthToSpace_dims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Locate the output rows handled by column ``col`` on iteration ``start_iter``.

    Returns:
        ``(Yo_start, Yo_stop, Yo_stride, Yo_size)``:
        * ``Yo_stop``: clamped to ``Yo``. It equals ``Yo_start`` when the
          slice starts beyond ``Yo``.
        * ``Yo_stride``: ``aie_cols * Yos``.
        * ``Yo_size``: ``Yo_stop - Yo_start``.
    """
    stride = dims.aie_cols * dims.Yos
    start = col * dims.Yos + start_iter * stride
    stop = start if start > dims.Yo else min(start + dims.Yos, dims.Yo)
    return (start, stop, stride, stop - start)

def Yo_split_iters(dims: DepthToSpace_dims, col: int) -> List[Tuple[int, int]]:
    """Group the ``Y_loop`` iterations of column ``col`` into ``(start_iter, num_iters)`` runs.

    Consecutive iterations whose ``Yo_slice`` has the full ``Yos`` rows are
    merged into one run. Every other iteration forms a run of length 1.
    """
    def is_full(it: int) -> bool:
        return Yo_slice(dims, col, it)[3] == dims.Yos

    runs = []
    it = 0
    while it < dims.Y_loop:
        count = 1
        if is_full(it):
            while is_full(it + count):
                count += 1
        runs.append((it, count))
        it += count
    return runs

def ifm_shim_memory(dims: DepthToSpace_dims) -> str:
    """Describe the full input tensor layout in shim (DDR) memory."""
    return 'Yi:{} Xi:{} Ci:{}'.format(dims.Yi, dims.Xi, dims.Ci)

def ifm_shim_mm2s(dims: DepthToSpace_dims, col: int) -> List[str]:
    """Build one shim MM2S tiling format per ``Yi_split_iters`` run of column ``col``.

    The outer ``Yi`` loop steps ``num_iters`` times by the per-iteration
    stride. The inner window is the run's first slice, clamped to ``[0, Yi]``.
    """
    formats = []
    for first, count in Yi_split_iters(dims, col):
        lo, hi, stride, _ = Yi_slice(dims, col, first)
        outer = (
            f'Yi:0:{count * stride}:{stride} '
            f'Xi:0:{dims.Xi}:{dims.Xi} '
            f'Ci:0:{dims.Ci}:{dims.Ci} '
        )
        inner = f'Yi:{max(0, lo)}:{min(hi, dims.Yi)} Xi:0:{dims.Xi} Ci:0:{dims.Ci}'
        formats.append(outer + inner)
    return formats

def ifm_memtile_memory(dims: DepthToSpace_dims, col: int) -> List[str]:
    """Memtile IFM buffer layout for each ``Yi_split_iters`` run of column ``col``.

    The buffer height is the run's in-bounds row count, or ``Yis`` when that
    count is not positive.
    """
    def rows_for(first: int) -> int:
        size = Yi_slice(dims, col, first)[3]
        return size if size > 0 else dims.Yis

    return [
        f'Yi:{rows_for(first)} Xi:{dims.Xi} Ci:{dims.Ci}'
        for first, _ in Yi_split_iters(dims, col)
    ]

def ifm_memtile_s2mm(dims: DepthToSpace_dims, col: int) -> List[str]:
    """Memtile IFM S2MM tiling formats: one per ``Yi_split_iters`` run, covering its in-bounds rows."""
    formats = []
    for first, _ in Yi_split_iters(dims, col):
        rows = Yi_slice(dims, col, first)[3]
        formats.append(f'Yi:0:{rows} Xi:0:{dims.Xi} Ci:0:{dims.Ci}')
    return formats

def ifm_memtile_mm2s(dims: DepthToSpace_dims, col: int, row: int) -> List[str]:
    """Memtile IFM MM2S tiling formats feeding the core at ``row`` of column ``col``.

    The memtile IFM buffer (``Yis x Xi x Ci``) is divided among the column's
    cores as chosen by ``Yi_Xi_split``: ``Yi_split`` bands along Yi times
    ``X_split`` bands along Xi. The core at ``row`` reads
    * Yi band ``row // (aie_rows // Yi_split)``, and
    * Xi band ``row % X_split``.

    The per-core window does not depend on the iteration, so the same format
    is returned once per ``Yi_split_iters(dims, col)`` run. The caller packs
    it with those runs' iteration counts.

    Raises:
        ValueError / AssertionError: propagated from ``Yi_Xi_split`` when
            ``Yis``/``Xi`` cannot be split across the column's cores.
    """
    Yi_split, X_split = Yi_Xi_split(dims)

    # Xi band read by this core.
    Xi_size = dims.Xi // X_split
    Xi_start = (row % X_split) * Xi_size
    Xi_stop = Xi_start + Xi_size

    # Yi band read by this core: rows per core = (elements per core) divided
    # by (band width). Clamp to at least one row.
    slice_size = dims.Xi * dims.Yis // dims.aie_rows
    Yi_rows = slice_size // (dims.Xi // X_split)
    Yi_rows = 1 if Yi_rows == 0 else Yi_rows
    Yi_start = (row // (dims.aie_rows // Yi_split)) * (dims.Yis // Yi_split)
    Yi_stop = Yi_start + Yi_rows

    fmt = (
        f'Yi:0:{dims.Yis}:{dims.Yis} '
        f'Xi:0:{dims.Xi}:{dims.Xi} '
        f'Ci:0:{dims.Ci}:{dims.Ci} '
        f'Yi:{Yi_start}:{Yi_stop} Xi:{Xi_start}:{Xi_stop} Ci:0:{dims.Ci}'
    )
    return [fmt for _ in Yi_split_iters(dims, col)]

def ofm_shim_memory(dims: DepthToSpace_dims) -> str:
    """Describe the full output tensor layout in shim (DDR) memory."""
    return 'Yo:{} Xo:{} Co:{}'.format(dims.Yo, dims.Xo, dims.Co)

def ofm_shim_s2mm(dims: DepthToSpace_dims, col: int) -> List[str]:
    """Build one shim S2MM tiling format per ``Yo_split_iters`` run of column ``col``.

    The outer ``Yo`` loop steps ``num_iters`` times by the per-iteration
    stride. The inner window is the run's first output slice.
    """
    formats = []
    for first, count in Yo_split_iters(dims, col):
        lo, hi, stride, _ = Yo_slice(dims, col, first)
        formats.append(
            f'Yo:0:{stride * count}:{stride} '
            f'Yo:{lo}:{hi} Xo:0:{dims.Xo} Co:0:{dims.Co}'
        )
    return formats

def ofm_memtile_memory(dims: DepthToSpace_dims, col: int) -> List[str]:
    """Memtile OFM buffer layout for each ``Yo_split_iters`` run of column ``col``.

    The buffer height is the run's output row count, or ``Yos`` when that
    count is not positive.
    """
    def rows_for(first: int) -> int:
        size = Yo_slice(dims, col, first)[3]
        return dims.Yos if size <= 0 else size

    return [
        f'Yo:{rows_for(first)} Xo:{dims.Xo} Co:{dims.Co}'
        for first, _ in Yo_split_iters(dims, col)
    ]

def ofm_memtile_mm2s(dims: DepthToSpace_dims, col: int) -> List[str]:
    """Memtile OFM MM2S tiling formats: one per ``Yo_split_iters`` run, covering its output rows."""
    formats = []
    for first, _ in Yo_split_iters(dims, col):
        rows = Yo_slice(dims, col, first)[3]
        formats.append(f'Yo:0:{rows} Xo:0:{dims.Xo} Co:0:{dims.Co}')
    return formats

def ofm_memtile_s2mm(dims: DepthToSpace_dims, col: int, row: int) -> List[str]:
    """Memtile OFM S2MM tiling formats receiving from the core at ``row`` of column ``col``.

    The partition is chosen from the *input* dims by ``Yi_Xi_split``, so the
    output window of each core mirrors its input window:
    * ``Yo_split`` bands along Yo times ``X_split`` bands along Xo;
    * the core at ``row`` writes Yo band ``row // (aie_rows // Yo_split)``
      and Xo band ``row % X_split``.

    The window does not depend on the iteration, so the same format is
    returned once per ``Yi_split_iters(dims, col)`` run, which is what the
    caller uses as ``tiling_iters``.
    NOTE(review): the matching memory formats come from ``Yo_split_iters``;
    this relies on both split functions producing the same number of runs.

    Raises:
        ValueError / AssertionError: propagated from ``Yi_Xi_split``.
    """
    Yo_split, X_split = Yi_Xi_split(dims)

    # Xo band written by this core.
    Xo_size = dims.Xo // X_split
    Xo_start = (row % X_split) * Xo_size
    Xo_stop = Xo_start + Xo_size

    # Yo band written by this core, clamped to at least one row.
    slice_size = dims.Xo * dims.Yos // dims.aie_rows
    Yo_rows = slice_size // (dims.Xo // X_split)
    Yo_rows = 1 if Yo_rows == 0 else Yo_rows
    Yo_start = (row // (dims.aie_rows // Yo_split)) * (dims.Yos // Yo_split)
    Yo_stop = Yo_start + Yo_rows

    fmt = (
        f'Yo:0:{dims.Yos}:{dims.Yos} '
        f'Xo:0:{dims.Xo}:{dims.Xo} '
        f'Yo:{Yo_start}:{Yo_stop} Xo:{Xo_start}:{Xo_stop} Co:0:{dims.Co}'
    )
    return [fmt for _ in Yi_split_iters(dims, col)]

def Yi_Xi_split(dims: DepthToSpace_dims):
    """Choose how a ``Yis``-row memtile slice is divided among a column's cores.

    The slice is cut into ``Yi_split`` bands along Yi and ``X_split`` bands
    along Xi, targeting four cores per column:

    * ``Yis % 4 == 0`` -> ``(4, 1)``
    * ``Yis % 2 == 0`` -> ``(2, 2)``; requires ``Xi % 2 == 0``
    * ``Yis == 1``     -> ``(1, 4)``; requires ``Xi % 4 == 0``

    Returns:
        ``(Yi_split, X_split)``.

    Raises:
        ValueError: if ``Yis`` matches none of the cases above, or ``Xi`` is
            not divisible by the chosen ``X_split``. An explicit raise is used
            rather than ``assert``, so the check survives ``python -O``.
    """
    if dims.Yis % 4 == 0:
        return (4, 1)
    if dims.Yis % 2 == 0:
        Yi_split, X_split = 2, 2
    elif dims.Yis == 1:
        Yi_split, X_split = 1, 4
    else:
        raise ValueError(f'current split Yis:{dims.Yis} can not satisfy the cores in one column split from memtile ')
    if dims.Xi % X_split != 0:
        raise ValueError("make new split, current split can't be evenly split by four cores in a column")
    return (Yi_split, X_split)

def gen_permute_params(N: int, C: int, H: int, W:int, B:int, mode_PCR: int):
    """Serialize the permute kernel's runtime parameters.

    ``N, C, H, W, B, mode_PCR`` are each encoded, in that order, as an
    unsigned 16-bit little-endian integer, giving 12 bytes in total.
    Values outside ``[0, 65535]`` raise ``OverflowError``.
    """
    fields = (N, C, H, W, B, mode_PCR)
    return b''.join(
        value.to_bytes(length=2, byteorder='little', signed=False)
        for value in fields
    )

def compile_dataflow(
    dims: DepthToSpace_dims,
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
):
    """Lay out buffers, build the core/memtile/shim programs and compile the layer.

    Steps:
      1. Place the per-core IFM/OFM buffers (optionally ping-ponged).
      2. Place the memtile param/IFM/OFM buffers (optionally ping-ponged).
      3. Build the core instruction loop that calls the permute kernel.
      4. Build memtile and shim data transfers for params, IFM and OFM.
      5. Pass everything to ``run_layer_compilation``.

    Args:
        dims: layer and tiling description.
        back_end: passed through to ``run_layer_compilation``.
        kernel_names: passed through to ``run_layer_compilation``.
        kernel_includes: passed through to ``run_layer_compilation``.
    """
    permute_shim_alloc = shim_alloc()

    # ---- Core buffers ----
    CoreIfmSize = dims.ifm_subv_elem * dims.ifm_bits // 8
    # NOTE(review): the OFM subvolume is sized from the IFM subvolume, mirroring
    # ofm_memtile_size == ifm_memtile_size below; confirm that this holds.
    CoreOfmSize = dims.ifm_subv_elem * dims.ifm_bits // 8

    core_bank_size = 16384
    CoreIfmPingAddr = 0
    CoreOfmPingAddr = iceil(CoreIfmPingAddr + CoreIfmSize, 64)
    # End of the most recently placed core buffer. Each pong buffer starts at
    # its preferred core_bank_size boundary if that lies past this point,
    # otherwise at the next 64-byte boundary after it. Tracking the end works
    # whether or not the IFM pong buffer exists.
    core_used_end = CoreOfmPingAddr + CoreOfmSize
    if dims.core_ifm_pingpong:
        if core_used_end < 2 * core_bank_size:
            CoreIfmPongAddr = 2 * core_bank_size
        else:
            CoreIfmPongAddr = iceil(core_used_end, 64)
        core_used_end = CoreIfmPongAddr + CoreIfmSize
    else:
        CoreIfmPongAddr = None

    if dims.core_ofm_pingpong:
        if core_used_end < 3 * core_bank_size:
            CoreOfmPongAddr = 3 * core_bank_size
        else:
            CoreOfmPongAddr = iceil(core_used_end, 64)
    else:
        CoreOfmPongAddr = None

    # ---- Memtile buffers: [params][ifm ping][ofm ping][ifm pong?][ofm pong?] ----
    ifm_memtile_size = (dims.Yis * dims.Xi * dims.Ci * dims.ifm_bits) // 8
    ofm_memtile_size = ifm_memtile_size

    param_memtile_addr = 0
    param_memtile_size = compute_buffer_size(prm_memtile_memory(dims))

    ifm_memtile_ping_addr = param_memtile_addr + param_memtile_size
    ofm_memtile_ping_addr = ifm_memtile_ping_addr + ifm_memtile_size
    memtile_used_end = ofm_memtile_ping_addr + ofm_memtile_size

    ifm_memtile_addrs = [ifm_memtile_ping_addr]
    if dims.mt_ifm_pingpong:
        ifm_memtile_addrs.append(memtile_used_end)
        memtile_used_end += ifm_memtile_size

    # OFM pong goes after whatever was placed last, so it also works when the
    # IFM is not ping-ponged.
    ofm_memtile_addrs = [ofm_memtile_ping_addr]
    if dims.mt_ofm_pingpong:
        ofm_memtile_addrs.append(memtile_used_end)

    # ---- Core program ----
    # Number of subvolumes (kernel calls) each core processes.
    Tn = (dims.Yi * dims.Xi * dims.Ci // dims.aie_cols // dims.aie_rows) // dims.ifm_subv_elem

    # NOTE(review): the kernel name is hard-coded rather than taken from
    # kernel_names; main() passes the same name.
    run_kernel = 'run_int16_permute'
    # Per-core width is the layer width divided by the X split chosen for the column.
    _, X_split = Yi_Xi_split(dims)
    width = dims.width // X_split
    height = dims.ifm_subv_elem // (dims.depth * width)
    kernel_params = gen_permute_params(
        dims.batch, dims.depth, height, width, dims.blockSize, dims.perm_mode,
    )

    core_instrs = [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize),
        ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOfmPingAddr, CoreOfmPongAddr, CoreOfmSize),
        Loop(Tn, [
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
            CallKernel(run_kernel, kernel_params),
            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
        ]),
    ]

    # ---- Memtile transfers ----
    memtile_transfers = []
    memtile_param_transfers = [
        DataTransfer(
            [1] + [0] * (dims.Y_loop - 1),
            AieTile(TileType.Memtile, col),
            [param_memtile_addr],
            param_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(dims),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(dims),
                prm_memtile_mm2s(row),
            ) for row in range(dims.aie_rows)],
        ) for col in range(dims.aie_cols)
    ]
    memtile_transfers += memtile_param_transfers

    memtile_ifm_transfers = [
        DataTransfer(
            Yi_repeat_counts(dims, col, 1),
            AieTile(TileType.Memtile, col), ifm_memtile_addrs, ifm_memtile_size,
            [pack_transfers(
                memtile_dma(col, DmaDir.S2MM, 0),
                ifm_memtile_memory(dims, col),
                ifm_memtile_s2mm(dims, col),
                [n for _, n in Yi_split_iters(dims, col)],
                dims.ifm_bits,
            )],
            [
                pack_transfers(
                    memtile_dma(col, DmaDir.MM2S, row),
                    ifm_memtile_memory(dims, col),
                    ifm_memtile_mm2s(dims, col, row),
                    [n for _, n in Yi_split_iters(dims, col)],
                    dims.ifm_bits,
                ) for row in range(dims.aie_rows)
            ],
        ) for col in range(dims.aie_cols)
    ]
    memtile_transfers += memtile_ifm_transfers

    memtile_ofm_transfers = [
        DataTransfer(
            Yi_repeat_counts(dims, col, 1),
            AieTile(TileType.Memtile, col), ofm_memtile_addrs, ofm_memtile_size,
            [
                pack_transfers(
                    memtile_dma(col, DmaDir.S2MM, 2 + row),
                    ofm_memtile_memory(dims, col),
                    ofm_memtile_s2mm(dims, col, row),
                    [n for _, n in Yi_split_iters(dims, col)],
                    dims.ifm_bits,
                ) for row in range(dims.aie_rows)
            ],
            [
                pack_transfers(
                    memtile_dma(col, DmaDir.MM2S, 5),
                    ofm_memtile_memory(dims, col),
                    ofm_memtile_mm2s(dims, col),
                    [n for _, n in Yo_split_iters(dims, col)],
                    dims.ifm_bits,
                )
            ],
        ) for col in range(dims.aie_cols)
    ]
    memtile_transfers += memtile_ofm_transfers

    # ---- Shim transfers ----
    shim_transfers = []
    shim_param_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.MM2S, 0),
            permute_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col),
        ) for col in range(dims.aie_cols)
    ]
    shim_transfers += shim_param_transfers

    shim_ifm_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.MM2S, 0),
            permute_shim_alloc.ifm_buffer_id,
            ifm_shim_memory(dims),
            fmt,
            bits_per_block=dims.ifm_bits,
        ) for col in range(dims.aie_cols) for fmt in ifm_shim_mm2s(dims, col)
    ]
    shim_transfers += shim_ifm_transfers

    shim_ofm_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.S2MM, 0),
            permute_shim_alloc.ofm_buffer_id,
            ofm_shim_memory(dims),
            fmt,
            bits_per_block=dims.ifm_bits,
        ) for col in range(dims.aie_cols) for fmt in ofm_shim_s2mm(dims, col)
    ]
    shim_transfers += shim_ofm_transfers

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_8x4_dma_connections(),
        back_end=back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
    )

def main():
    """Compile the sample depth-to-space layer and build its simulation overlay.

    Configuration:
    * an 8-column x 4-row AIE array;
    * input shape ``[1, 128, 128, 256]``;
    * block size 2, ``"CRD"`` permutation mode and 16-bit activations.
    """
    back_end = BackEnd.Adf
    cols, rows = 8, 4
    dims = DepthToSpace_dims(
        rows,
        cols,
        [1, 128, 128, 256],  # in_shape
        2,                   # blockSize
        "CRD",               # perm_mode
        16,                  # act_bits
    )

    clean_overlay()
    compile_dataflow(
        dims,
        back_end,
        ["run_int16_permute"],
        ['super.hh', "permute/wrapper_permute.cc"],
    )
    directives = depthtospace_preproc_directives(dims, back_end)
    build_sim_overlay(back_end, 'depthtospace_main.cpp', directives)


if __name__ == '__main__':
    main()