import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels'))
from typing import Tuple, List

from dmacompiler import \
    OverlayShape, BackEnd, \
    DataTransfer, SyncStrategy, \
    AieTile, TileType, \
    AieDma, DmaDir, memtile_dma, shim_dma, core_dma, DmaChannel, \
    CoreInstr, ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    compute_buffer_size, \
    TransferParams, generate_transfer_params, \
    generate_shim_data_transfer, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config

from dataflow_common import \
    overlay_8x4_dma_connections, \
    overlay_stack_addr, \
    clean_overlay, \
    build_sim_overlay, \
    shim_alloc, \
    prm_shim_memory, \
    prm_shim_mm2s, \
    prm_memtile_memory, \
    prm_memtile_s2mm, \
    prm_memtile_mm2s
    
from slice_neg_common import \
    SliceDims, \
    slice_neg_preproc_directives
# Select the AIE2P device generation in dmacompiler for everything compiled below.
set_dev_gen(DevGen.Aie2p)
# NOTE(review): presumably makes generated control code busy-poll rather than
# block; confirm against dmacompiler's `config` documentation.
config.ENABLE_BUSY_POLL = True

# Presumably a maximum subvolume size (units not shown here); not referenced
# elsewhere in this file.
MAX_SUBV_SIZE = 8192

def pack_transfers(
    dma: AieDma,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iters: List[int],
    bits_per_elem: int,
) -> TransferParams:
    """Combine several (memory, tiling) format pairs into one ``TransferParams``.

    Each pair ``(memory_fmts[k], tiling_fmts[k])`` is turned into transfer
    parameters with ``generate_transfer_params``. The first entry of each of
    their length/offset/step/wrap/padding lists is then repeated
    ``tiling_iters[k]`` times, in order, to form the packed result.
    Padding is requested only when ``dma`` is an MM2S channel.

    Raises:
        AssertionError: if the three input lists differ in length.
    """
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iters)

    def repeat_each(values: list) -> list:
        # values[k] appears tiling_iters[k] times, keeping the original order.
        assert len(values) == len(tiling_iters)
        return [value for value, times in zip(values, tiling_iters) for _ in range(times)]

    per_pair = []
    for mem_fmt, tile_fmt in zip(memory_fmts, tiling_fmts):
        per_pair.append(
            generate_transfer_params(
                dma,
                mem_fmt,
                tile_fmt,
                bits_per_block=bits_per_elem,
                enable_padding=(dma.channel.dir == DmaDir.MM2S),
            )
        )

    return TransferParams(
        dma,
        repeat_each([p.length_i(0) for p in per_pair]),
        offset=repeat_each([p.offset_i(0) for p in per_pair]),
        step=repeat_each([p.step_i(0) for p in per_pair]),
        wrap=repeat_each([p.wrap_i(0) for p in per_pair]),
        padding=repeat_each([p.padding_i(0) for p in per_pair]),
    )

def Yi_slice(dims: SliceDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Locate the input-row window handled by ``col`` at Y iteration ``start_iter``.

    Columns take consecutive ``Yos``-row blocks, so one Y iteration advances
    by ``aie_cols * Yos`` rows.

    Returns:
        ``(start, stop, stride, size)``:
        - ``[start, stop)`` is the requested ``Yis``-row window. ``stop == start``
          once ``start`` lies beyond ``dims.Yi``.
        - ``stride`` is the per-iteration row step.
        - ``size`` is the number of window rows that fall inside ``[0, dims.Yi)``.
    """
    stride = dims.aie_cols * dims.Yos
    start = col * dims.Yos + start_iter * stride
    if start > dims.Yi:
        stop = start
    else:
        stop = start + dims.Yis
    # Clamp both ends into [0, Yi] before measuring the in-bounds extent.
    lo = max(0, min(start, dims.Yi))
    hi = max(0, min(stop, dims.Yi))
    return (start, stop, stride, hi - lo)

def Yi_split_iters(dims: SliceDims, col: int) -> List[Tuple[int, int]]:
    """Group column ``col``'s Y iterations into ``(start_iter, num_iters)`` runs.

    Consecutive iterations whose input window (see ``Yi_slice``) lies entirely
    inside ``[0, dims.Yi)`` are merged into one run, so they can share a single
    transfer description. An iteration that needs padding gets a run of its
    own. The runs cover ``range(dims.Y_loop)`` exactly once, in order.
    """
    def can_iterate(start_iter: int, num_iters: int) -> bool:
        # True when the last iteration of the candidate run needs no padding.
        Yi_start, Yi_stop, _, _ = Yi_slice(dims, col, start_iter + num_iters - 1)
        return not ((Yi_start < 0) or (Yi_stop > dims.Yi))

    split = []
    curr_iters = 0
    while curr_iters < dims.Y_loop:
        start_iter = curr_iters
        num_iters = 1
        if can_iterate(start_iter, num_iters):
            # Grow the run, but never past the last scheduled iteration;
            # otherwise the runs would describe more than Y_loop iterations.
            while (start_iter + num_iters < dims.Y_loop
                   and can_iterate(start_iter, num_iters + 1)):
                num_iters += 1
        split.append((start_iter, num_iters))
        curr_iters += num_iters
    return split

def Yo_slice(dims: SliceDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Locate the output-row window written by ``col`` at Y iteration ``start_iter``.

    Returns:
        ``(start, stop, stride, size)``:
        - ``[start, stop)`` is clipped to ``dims.Yo``, and is empty once
          ``start`` lies beyond ``dims.Yo``.
        - ``stride`` is ``aie_cols * Yos``.
        - ``size`` is ``stop - start``.
    """
    stride = dims.aie_cols * dims.Yos
    start = col * dims.Yos + start_iter * stride
    if start > dims.Yo:
        stop = start
    else:
        stop = min(start + dims.Yos, dims.Yo)
    return (start, stop, stride, stop - start)

def Yo_split_iters(dims: SliceDims, col: int) -> List[Tuple[int, int]]:
    """Group column ``col``'s Y iterations into ``(start_iter, num_iters)`` runs.

    Consecutive iterations whose output window (see ``Yo_slice``) is a full
    ``Yos`` rows are merged into one run. A partial or empty window gets a run
    of its own. The runs cover ``range(dims.Y_loop)`` exactly once, in order.
    """
    def can_iterate(start_iter: int, num_iters: int) -> bool:
        # True when the last iteration of the candidate run is a full slice.
        _, _, _, Yo_size = Yo_slice(dims, col, start_iter + num_iters - 1)
        return Yo_size == dims.Yos

    split = []
    curr_iters = 0
    while curr_iters < dims.Y_loop:
        start_iter = curr_iters
        num_iters = 1
        if can_iterate(start_iter, num_iters):
            # Grow the run, but never past the last scheduled iteration;
            # otherwise the runs would describe more than Y_loop iterations.
            while (start_iter + num_iters < dims.Y_loop
                   and can_iterate(start_iter, num_iters + 1)):
                num_iters += 1
        split.append((start_iter, num_iters))
        curr_iters += num_iters
    return split

def ifm_shim_repeat_counts(dims: SliceDims, idx: int) -> List[int]:
    """Return shim IFM repeat counts: one repetition at iteration 0, none after.

    ``idx`` (the position of the shim format for its column) is accepted but
    ignored, so every format gets identical counts.
    NOTE(review): confirm that not staggering by ``idx`` is intended.
    Raises ``IndexError`` when ``dims.Y_loop`` is 0.
    """
    counts = [0] * dims.Y_loop
    counts[0] = 1
    return counts

def Yi_repeat_counts(dims: SliceDims, col: int, scale: int) -> List[int]:
    """Return per-iteration repeat counts for column ``col``.

    Entry ``k`` is ``num_iters * scale`` when a run from ``Yi_split_iters``
    starts at iteration ``k``, and ``0`` otherwise.
    """
    counts = [0] * dims.Y_loop
    for run_start, run_len in Yi_split_iters(dims, col):
        counts[run_start] = run_len * scale
    return counts

def ifm_shim_memory(dims: SliceDims) -> str:
    """Return the shim IFM memory format covering the whole ``h_in x w_in`` input."""
    rows, cols = dims.h_in, dims.w_in
    return 'Yi:{} Xi:{}'.format(rows, cols)

def ifm_shim_mm2s(dims: SliceDims, col: int) -> List[str]:
    """Return one shim IFM MM2S tiling format per Y run of column ``col``.

    Each format has these dimensions, in order:
    - an outer ``Xi`` dimension over the full ``w_in`` width;
    - a ``Yi`` dimension stepping ``num_iters`` iterations of
      ``aie_cols * Yos`` rows;
    - the run's in-bounds row window, restricted to columns
      ``[w_out_start, w_out_stop)``.
    """
    formats = []
    for first_iter, count in Yi_split_iters(dims, col):
        y_lo, y_hi, y_step, _ = Yi_slice(dims, col, first_iter)
        outer = f'Xi:0:{dims.w_in}:{dims.w_in} Yi:0:{count * y_step}:{y_step} '
        inner = f'Yi:{max(0, y_lo)}:{min(y_hi, dims.Yi)} Xi:{dims.w_out_start}:{dims.w_out_stop}'
        formats.append(outer + inner)
    return formats

def wgt_shim_memory(dims: SliceDims) -> str:
    """Return the shim weight memory format: one ``Subv`` dimension of ``wgt_subv_size``."""
    return 'Subv:{}'.format(dims.wgt_subv_size)

def wgt_shim_mm2s(dims: SliceDims) -> str:
    """Return the shim weight MM2S tiling format: the whole ``Subv`` dimension.

    ``dims`` is not consulted. Weight transfers are currently not generated by
    ``compile_dataflow`` in this file.
    """
    # Plain literal: the original f-string had no placeholders (ruff F541).
    return 'Subv'

def wgt_memtile_memory(dims: SliceDims) -> str:
    """Return the memtile weight memory format: one ``Subv`` dimension of ``wgt_subv_size``."""
    return 'Subv:{}'.format(dims.wgt_subv_size)

def wgt_memtile_s2mm(dims: SliceDims) -> str:
    """Return the memtile weight S2MM tiling format: the whole ``Subv`` dimension.

    ``dims`` is not consulted.
    """
    # Plain literal: the original f-string had no placeholders (ruff F541).
    return 'Subv'

def wgt_memtile_mm2s(dims: SliceDims) -> str:
    """Return the memtile weight MM2S tiling format: the whole ``Subv`` dimension.

    ``dims`` is not consulted.
    """
    # Plain literal: the original f-string had no placeholders (ruff F541).
    return 'Subv'

def ifm_memtile_memory(dims: SliceDims, col: int) -> List[str]:
    """Return the memtile IFM buffer format for each Y run of column ``col``.

    Each buffer is ``rows x w_out``, where ``rows`` is the in-bounds row count
    of the run's first iteration. When that count is not positive, ``rows``
    falls back to ``dims.Yis``.
    """
    formats = []
    for first_iter, _ in Yi_split_iters(dims, col):
        in_bounds = Yi_slice(dims, col, first_iter)[3]
        rows = in_bounds if in_bounds > 0 else dims.Yis
        formats.append(f'Yi:{rows} Xi:{dims.w_out}')
    return formats

def ifm_memtile_s2mm(dims: SliceDims, col: int) -> List[str]:
    """Return the memtile IFM S2MM write format for each Y run of column ``col``.

    Each format writes the in-bounds rows of the run's first iteration across
    all ``w_out`` columns.
    NOTE(review): when that row count is 0, this writes 0 rows, while
    ``ifm_memtile_memory`` sizes the buffer with ``Yis`` rows. Confirm that
    this mismatch is intended.
    """
    formats = []
    for first_iter, _ in Yi_split_iters(dims, col):
        *_, rows = Yi_slice(dims, col, first_iter)
        formats.append(f'Yi:0:{rows} Xi:0:{dims.w_out}')
    return formats

# The OFM memtile buffer uses exactly the same per-run layout as the IFM buffer,
# so the IFM memory-format builder is reused under an OFM name.
ofm_memtile_memory = ifm_memtile_memory

def ifm_memtile_mm2s(dims: SliceDims, col: int, row: int) -> List[str]:
    """Return the memtile IFM MM2S formats that feed core ``row`` of column ``col``.

    The memtile IFM tile (``Yis`` rows by ``Xi`` columns) is divided among
    the cores of the column:

    * ``Yis % 4 == 0``: rows split four ways, columns not split.
    * ``Yis % 2 == 0``: rows split two ways, columns split two ways.
    * ``Yis == 1``: columns split four ways.

    The core's window does not depend on the Y iteration, so one identical
    format is returned per run from ``Yi_split_iters``. The split factors
    assume four cores per column.
    NOTE(review): other ``dims.aie_rows`` values are not handled.

    Raises:
        ValueError: if ``dims.Yis`` is odd and greater than 1.
        AssertionError: if ``dims.Xi`` is not divisible by the column split.
    """
    splits = Yi_split_iters(dims, col)
    if not splits:
        # Nothing to describe. The split rules below are not evaluated, so no
        # error is raised for an empty schedule.
        return []

    # Elements of the tile that each core receives.
    slice_size = dims.Xi * dims.Yis // dims.aie_rows

    if dims.Yis % 4 == 0:
        X_split, Yi_split = 1, 4
    elif dims.Yis % 2 == 0:
        X_split, Yi_split = 2, 2
        assert dims.Xi % 2 == 0, "make new split, current split can't evenly splited by four cores in a column"
    elif dims.Yis == 1:
        X_split, Yi_split = 4, 1
        assert dims.Xi % 4 == 0, "make new split, current split can't evenly splited by four cores in a column"
    else:
        raise ValueError(f'current split Yis{dims.Yis} can not satisfy the cores in one column split from memtile ')

    # Column window: this core's 1/X_split share of the Xi columns.
    Xi_size = dims.Xi // X_split
    Xi_start = (row % X_split) * Xi_size
    Xi_stop = Xi_start + Xi_size

    # Row window: cores sharing a row band start at the same row. The band
    # height is the per-core element count divided by the column width, and
    # is at least one row.
    Yi_start = (row // (dims.aie_rows // Yi_split)) * (dims.Yis // Yi_split)
    Yi_rows = slice_size // Xi_size or 1
    Yi_stop = Yi_start + Yi_rows

    window = (
        f'Yi:0:{dims.Yis}:{dims.Yis} '
        f'Xi:0:{dims.Xi}:{dims.Xi} '
        f'Yi:{Yi_start}:{Yi_stop} Xi:{Xi_start}:{Xi_stop}'
    )
    return [window] * len(splits)

def ofm_memtile_mm2s(dims: SliceDims, col: int) -> List[str]:
    """Return the memtile OFM MM2S read format for each output Y run of ``col``.

    Each format reads the first iteration's ``Yo_size`` rows across all
    ``w_out`` columns. The dimensions are labelled ``Yi``/``Xi`` to match the
    OFM buffer layout, which reuses ``ifm_memtile_memory``.
    """
    formats = []
    for first_iter, _ in Yo_split_iters(dims, col):
        rows = Yo_slice(dims, col, first_iter)[3]
        formats.append(f'Yi:0:{rows} Xi:0:{dims.w_out}')
    return formats

def ofm_shim_memory(dims: SliceDims) -> str:
    """Return the shim OFM memory format covering the whole ``h_out x w_out`` output."""
    rows, cols = dims.h_out, dims.w_out
    return 'Yo:{} Xo:{}'.format(rows, cols)

def ofm_shim_s2mm(dims: SliceDims, col: int) -> List[str]:
    """Return one shim OFM S2MM tiling format per output Y run of column ``col``.

    Each format has an outer ``Yo`` dimension stepping ``num_iters``
    iterations of ``aie_cols * Yos`` rows. It then writes the first
    iteration's output row window across all ``w_out`` columns.
    """
    result = []
    for first_iter, count in Yo_split_iters(dims, col):
        lo, hi, step, _ = Yo_slice(dims, col, first_iter)
        span = step * count
        result.append(f'Yo:0:{span}:{step} Yo:{lo}:{hi} Xo:0:{dims.w_out}')
    return result

def gen_params(subv_elems: int, qdq_prm_addr: int):
    """Serialise the kernel call parameters into 4 bytes.

    ``subv_elems`` comes first, then ``qdq_prm_addr``. Each is encoded as an
    unsigned 16-bit little-endian integer.

    Raises:
        OverflowError: if either value is negative or does not fit in 16 bits.
    """
    fields = (subv_elems, qdq_prm_addr)
    return b''.join(
        value.to_bytes(length=2, byteorder='little', signed=False)
        for value in fields
    )

def compile_dataflow(
    dims: SliceDims,
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
):
    """Describe the slice/negate dataflow and compile it with ``run_layer_compilation``.

    The data path is:
    - the shim streams the IFM (restricted to the
      ``[w_out_start, w_out_stop)`` column window) into each column's memtile;
    - the memtile distributes tiles to the column's cores, which loop over
      ``run_int16_negative``;
    - the results are gathered back through the memtile into the shim OFM
      buffer.

    Args:
        dims: Layer and overlay dimensions.
        back_end: Back end forwarded to ``run_layer_compilation``.
        kernel_names: Kernel names forwarded to ``run_layer_compilation``.
        kernel_includes: Kernel includes forwarded to ``run_layer_compilation``.
    """
    # NOTE: overlay_8x4_dma_connections() is used below, and ifm_memtile_mm2s()
    # splits each tile across four cores. An 8x4 overlay is therefore
    # effectively assumed, even though this explicit check is disabled.
    # assert (dims.aie_cols, dims.aie_rows) == (8, 4)

    slice_shim_alloc = shim_alloc()

    # ---- Memtile memory map: [params | IFM ping | IFM pong | OFM ping | OFM pong]
    # No weight buffer is reserved, because no weight transfers are generated
    # below.
    ifm_memtile_size = (
        dims.h_in * (dims.w_out_stop - dims.w_out_start) * dims.ifm_bits
    ) // dims.aie_cols // dims.num_splits // 8
    # ofm_memtile_memory is ifm_memtile_memory, so the OFM buffer has the same size.
    ofm_memtile_size = ifm_memtile_size

    param_memtile_addr = 0
    param_memtile_size = compute_buffer_size(prm_memtile_memory(dims))

    ifm_memtile_ping_addr = param_memtile_addr + param_memtile_size
    ifm_memtile_pong_addr = ifm_memtile_ping_addr + ifm_memtile_size
    ofm_memtile_addr_ping = ifm_memtile_pong_addr + ifm_memtile_size
    ofm_memtile_addr_pong = ofm_memtile_addr_ping + ofm_memtile_size

    ifm_memtile_addrs = [ifm_memtile_ping_addr, ifm_memtile_pong_addr]
    ofm_memtile_addrs = [ofm_memtile_addr_ping, ofm_memtile_addr_pong]

    ifm_memtile_repeat_scale = 1

    # ---- Core memory map: [QDQ prm ping | QDQ prm pong | IFM ping | IFM pong | OFM ping | OFM pong]
    CoreIfmSize = dims.ifm_subv_elem * dims.ifm_bits // 8
    CoreOfmSize = dims.ifm_subv_elem * dims.ifm_bits // 8
    CoreqdqPrmSize = dims.wgt_subv_size

    CoreQdqPrmPingAddr = 0
    CoreQdqPrmPongAddr = CoreQdqPrmPingAddr + CoreqdqPrmSize

    CoreIfmPingAddr = CoreQdqPrmPongAddr + CoreqdqPrmSize
    CoreIfmPongAddr = CoreIfmPingAddr + CoreIfmSize
    CoreOfmPingAddr = CoreIfmPongAddr + CoreIfmSize
    CoreOfmPongAddr = CoreOfmPingAddr + CoreOfmSize

    # Kernel calls per core: this core's share of the h_in x w_out elements,
    # taken ifm_subv_elem elements at a time.
    Tn = (dims.h_in * dims.w_out // dims.aie_cols // dims.aie_rows) // dims.ifm_subv_elem

    run_kernel = 'run_int16_negative'

    # Each loop iteration acquires an IFM and an OFM buffer, runs the kernel,
    # and releases both. Only the QDQ ping address is passed to the kernel.
    # NOTE(review): no ConfigBuffer is issued for the QDQ region here; confirm
    # how it is populated.
    core_instrs = [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize),
        ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOfmPingAddr, CoreOfmPongAddr, CoreOfmSize),
        Loop(Tn, [
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
            CallKernel(run_kernel, gen_params(dims.ifm_subv_elem, CoreQdqPrmPingAddr)),
            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
        ])
    ]

    memtile_transfers = []

    # Parameters: received once (iteration 0) on S2MM 0, then sent to every core row.
    memtile_param_transfers = [
        DataTransfer(
            [1] + [0] * (dims.Y_loop - 1),
            AieTile(TileType.Memtile, col),
            [param_memtile_addr],
            param_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(dims),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(dims),
                prm_memtile_mm2s(row),
            ) for row in range(dims.aie_rows)],
        ) for col in range(dims.aie_cols)
    ]
    memtile_transfers += memtile_param_transfers

    # IFM: shim -> memtile on S2MM 0, memtile -> core `row` on MM2S `row`.
    memtile_ifm_transfers = [
        DataTransfer(
            Yi_repeat_counts(dims, col, ifm_memtile_repeat_scale),
            AieTile(TileType.Memtile, col), ifm_memtile_addrs, ifm_memtile_size,
            [pack_transfers(
                memtile_dma(col, DmaDir.S2MM, 0),
                ifm_memtile_memory(dims, col),
                ifm_memtile_s2mm(dims, col),
                [n for _, n in Yi_split_iters(dims, col)],
                dims.ifm_bits,
            )],
            [
                pack_transfers(
                    memtile_dma(col, DmaDir.MM2S, row),
                    ifm_memtile_memory(dims, col),
                    ifm_memtile_mm2s(dims, col, row),
                    [i for _, i in Yi_split_iters(dims, col)],
                    dims.ifm_bits,
                ) for row in range(dims.aie_rows)
            ],
        ) for col in range(dims.aie_cols)
    ]
    memtile_transfers += memtile_ifm_transfers

    memtile_wgt_transfers = []
    memtile_transfers += memtile_wgt_transfers

    # OFM: core `row` -> memtile on S2MM 2 + row, using the same tiling the
    # core was fed with; memtile -> shim on MM2S 5.
    memtile_ofm_transfers = [
        DataTransfer(
            Yi_repeat_counts(dims, col, ifm_memtile_repeat_scale),
            AieTile(TileType.Memtile, col), ofm_memtile_addrs, ofm_memtile_size,
            [
                pack_transfers(
                    memtile_dma(col, DmaDir.S2MM, 2 + row),
                    ofm_memtile_memory(dims, col),
                    ifm_memtile_mm2s(dims, col, row),
                    [i for _, i in Yi_split_iters(dims, col)],
                    dims.ifm_bits,
                ) for row in range(dims.aie_rows)
            ],
            [
                pack_transfers(
                    memtile_dma(col, DmaDir.MM2S, 5),
                    ofm_memtile_memory(dims, col),
                    ofm_memtile_mm2s(dims, col),
                    [n for _, n in Yo_split_iters(dims, col)],
                    dims.ifm_bits,
                )
            ],
        ) for col in range(dims.aie_cols)
    ]
    memtile_transfers += memtile_ofm_transfers

    shim_transfers = []

    # Parameters and IFM are both queued on shim MM2S channel 0.
    shim_param_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.MM2S, 0),
            slice_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col),
        ) for col in range(dims.aie_cols)
    ]
    shim_transfers += shim_param_transfers

    shim_ifm_transfers = [
        generate_shim_data_transfer(
            ifm_shim_repeat_counts(dims, idx),
            shim_dma(col, DmaDir.MM2S, 0),
            slice_shim_alloc.ifm_buffer_id,
            ifm_shim_memory(dims),
            fmt,
            bits_per_block=dims.ifm_bits,
        ) for col in range(dims.aie_cols) for idx, fmt in enumerate(ifm_shim_mm2s(dims, col))
    ]
    shim_transfers += shim_ifm_transfers

    shim_wgt_transfers = []
    shim_transfers += shim_wgt_transfers

    shim_ofm_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.S2MM, 0),
            slice_shim_alloc.ofm_buffer_id,
            ofm_shim_memory(dims),
            fmt,
            bits_per_block=dims.ifm_bits,
        ) for col in range(dims.aie_cols) for fmt in ofm_shim_s2mm(dims, col)
    ]
    shim_transfers += shim_ofm_transfers

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_8x4_dma_connections(),
        back_end=back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
    )

def main():
    """Compile the example slice/negate layer and build its simulation overlay.

    The example uses an 8-column x 4-row overlay and a 4000 x 5000 input with
    16-bit elements. The column window read from the input is [1300, 2708).
    """
    backend = BackEnd.Adf
    kernels = ['run_int16_negative']
    includes = ['super.hh', 'qdq/wrapper_qdq.cc']

    dims = SliceDims(
        4, 8,          # aie_rows, aie_cols
        4000, 5000,    # h_in, w_in
        1300, 2708,    # w_out_start, w_out_stop
        16,            # ifm_bits
    )

    clean_overlay()
    compile_dataflow(dims, backend, kernels, includes)
    build_sim_overlay(backend, 'slice_neg_main.cpp', slice_neg_preproc_directives(dims, backend))

# Run the example compilation only when this file is executed as a script.
if __name__ == '__main__':
    main()