import os
# Directory containing this script; the sys.path entries below are relative to it.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
# Make the parent directory and ../../dmacompiler importable so the project
# imports that follow resolve without being installed as packages.
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
from typing import Tuple, List

from dmacompiler import \
    OverlayShape, BackEnd, \
    DataTransfer, \
    AieTile, TileType, \
    AieDma, DmaDir, memtile_dma, shim_dma, \
    TransferParams, generate_transfer_params, \
    CoreInstr, ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    generate_shim_data_transfer, DmaChannel, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config

from dataflow_common import \
    overlay_8x4_dma_connections, \
    overlay_4x4_dma_connections, \
    overlay_stack_addr, \
    clean_overlay, \
    build_sim_overlay, \
    shim_alloc, \
    prm_shim_memory, \
    prm_shim_mm2s, \
    prm_memtile_memory, \
    prm_memtile_s2mm, \
    prm_memtile_mm2s
    
from resize_common import \
    ResizeDims, \
    resize_preproc_directives
# Select the DevGen.Aie2p device generation in dmacompiler at import time,
# i.e. before any transfer parameters below are generated.
set_dev_gen(DevGen.Aie2p)
# Turn on dmacompiler's ENABLE_BUSY_POLL config flag (its exact effect is
# defined inside dmacompiler, not in this file).
config.ENABLE_BUSY_POLL = True

def pack_transfers(
    dma: AieDma,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iters: List[int],
    bits_per_elem: int,
) -> TransferParams:
    """Merge several per-run transfer descriptions into one TransferParams.

    For each k, ``(memory_fmts[k], tiling_fmts[k])`` is turned into transfer
    parameters with ``generate_transfer_params``. The ``*_i(0)`` length,
    offset, step, wrap and padding of that result are then repeated
    ``tiling_iters[k]`` times. The repeated runs are concatenated in order
    into a single TransferParams for ``dma``.

    Padding generation is enabled only when ``dma`` is an MM2S channel.
    The three input lists must have the same length (asserted).
    """
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iters)

    def expand(per_fmt_values: list) -> list:
        # Repeat entry k exactly tiling_iters[k] times, keeping the order.
        assert len(per_fmt_values) == len(tiling_iters)
        expanded = []
        for value, count in zip(per_fmt_values, tiling_iters):
            expanded.extend([value] * count)
        return expanded

    per_fmt = []
    for mem_fmt, tile_fmt in zip(memory_fmts, tiling_fmts):
        per_fmt.append(
            generate_transfer_params(
                dma,
                mem_fmt,
                tile_fmt,
                bits_per_block=bits_per_elem,
                enable_padding=(dma.channel.dir == DmaDir.MM2S),
            )
        )

    lengths = expand([p.length_i(0) for p in per_fmt])
    offsets = expand([p.offset_i(0) for p in per_fmt])
    steps = expand([p.step_i(0) for p in per_fmt])
    wraps = expand([p.wrap_i(0) for p in per_fmt])
    paddings = expand([p.padding_i(0) for p in per_fmt])
    return TransferParams(
        dma,
        lengths,
        offset=offsets,
        step=steps,
        wrap=wraps,
        padding=paddings,
    )

def Yi_slice(dims: ResizeDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Return ``(start, stop, stride, size)`` of the input-row window for one column/iteration.

    Each column owns a group of ``dims.Yos`` rows. One iteration across all
    ``dims.aie_cols`` columns advances by ``stride = aie_cols * Yos`` rows.
    The window nominally spans ``dims.Yis`` rows starting at ``start``. If
    ``start`` is already beyond ``dims.Yi``, the window collapses to empty
    (``stop == start``). ``size`` counts only the rows of ``[start, stop)``
    that fall inside ``[0, dims.Yi)``.
    """
    stride = dims.aie_cols * dims.Yos
    start = col * dims.Yos + start_iter * stride
    if start > dims.Yi:
        stop = start
    else:
        stop = start + dims.Yis

    def clamp(row: int) -> int:
        # Clip a row index into the valid input range [0, Yi].
        return max(0, min(row, dims.Yi))

    return (start, stop, stride, clamp(stop) - clamp(start))

def Yi_split_iters(dims: ResizeDims, col: int) -> List[Tuple[int, int]]:
    """Group the ``dims.Y_loop`` iterations of column ``col`` into ``(start_iter, num_iters)`` runs.

    Consecutive iterations whose input window (see ``Yi_slice``) lies fully
    inside ``[0, dims.Yi)`` are merged into one run so they can share a
    single transfer description. An iteration whose window would need
    padding forms a run of length 1 by itself.

    Runs never extend past ``dims.Y_loop``, so the ``num_iters`` values
    always sum to exactly ``dims.Y_loop`` (an empty list when
    ``Y_loop <= 0``). Previously a run of padding-free iterations could grow
    beyond the loop count.
    """
    def can_iterate(start_iter: int, num_iters: int) -> bool:
        # A run must stay inside the loop bounds ...
        if start_iter + num_iters > dims.Y_loop:
            return False
        # ... and its last window must need no padding. Windows move
        # monotonically, so the last one is the binding constraint.
        Yi_start, Yi_stop, _, _ = Yi_slice(dims, col, start_iter + num_iters - 1)
        return 0 <= Yi_start and Yi_stop <= dims.Yi

    split = []
    curr_iters = 0
    while curr_iters < dims.Y_loop:
        num_iters = 1
        if can_iterate(curr_iters, num_iters):
            while can_iterate(curr_iters, num_iters + 1):
                num_iters += 1
        split.append((curr_iters, num_iters))
        curr_iters += num_iters
    return split

def Yo_slice(dims: ResizeDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Return ``(start, stop, stride, size)`` of the output-row block for one column/iteration.

    Each column owns ``dims.Yos`` rows per iteration, and an iteration
    advances by ``stride = aie_cols * Yos`` rows. ``stop`` is clipped to
    ``dims.Yo``. If ``start`` is already beyond ``dims.Yo``, the block is
    empty (``stop == start``). ``size`` is ``stop - start``.
    """
    stride = dims.aie_cols * dims.Yos
    first_row = start_iter * stride + col * dims.Yos
    last_row = first_row if first_row > dims.Yo else min(first_row + dims.Yos, dims.Yo)
    return (first_row, last_row, stride, last_row - first_row)


def ifm_shim_repeat_counts(y_loop, idx: int) -> List[int]:
    """Return a ``y_loop``-long repeat-count list with a 1 first and 0 elsewhere.

    ``idx`` is accepted but not used. Raises IndexError when ``y_loop`` is
    not positive.
    """
    counts = [0] * y_loop
    counts[0] = 1
    return counts

def Yi_repeat_counts(dims: ResizeDims, col: int, scale: int) -> List[int]:
    """Return per-iteration repeat counts (length ``dims.Y_loop``) for column ``col``.

    The entry at the first iteration of each run from ``Yi_split_iters`` is
    set to ``run_length * scale``. Every other entry is 0.
    """
    counts = [0] * dims.Y_loop
    for first_iter, run_len in Yi_split_iters(dims, col):
        counts[first_iter] = run_len * scale
    return counts


def ifm_shim_memory(dims: ResizeDims) -> str:
    """Return the memory-format string ``'Yi:<Yi> Xi:<Xi> Ci:<Ci>'`` for the shim IFM buffer."""
    return 'Yi:{} Xi:{} Ci:{}'.format(dims.Yi, dims.Xi, dims.Ci)

def ifm_shim_mm2s(dims: ResizeDims, col: int) -> List[str]:
    """Build shim MM2S (IFM read) tiling formats for column ``col``, one per run.

    Runs come from ``Yi_split_iters``. Each format has:
      * an outer loop over Xi, then Yi stepping by the iteration stride for
        ``num_iters * stride`` rows, then Ci;
      * an inner window of rows ``[max(0, start), min(stop, Yi))`` from the
        run's first ``Yi_slice``, with the full Xi and Ci extents.
    """
    def tiling(first_iter: int, run_len: int) -> str:
        lo, hi, stride, _ = Yi_slice(dims, col, first_iter)
        outer = (
            f'Xi:0:{dims.Xi}:{dims.Xi} '
            f'Yi:0:{run_len * stride}:{stride} '
            f'Ci:0:{dims.Ci}:{dims.Ci} '
        )
        window = f'Yi:{max(0, lo)}:{min(hi, dims.Yi)} Xi:0:{dims.Xi} Ci:0:{dims.Ci}'
        return outer + window

    return [tiling(first, count) for first, count in Yi_split_iters(dims, col)]

def ifm_memtile_memory(dims: ResizeDims, col: int) -> List[str]:
    """Return one memtile memory-format string per run of column ``col``.

    The Yi extent is the valid row count of the run's first window (see
    ``Yi_slice``). It falls back to ``dims.Yis`` when that window has no
    valid rows.
    """
    formats = []
    for first_iter, _ in Yi_split_iters(dims, col):
        rows = Yi_slice(dims, col, first_iter)[3]
        rows = rows if rows > 0 else dims.Yis
        formats.append(f'Yi:{rows} Xi:{dims.Xi} Ci:{dims.Ci}')
    return formats

def ifm_memtile_s2mm(dims: ResizeDims, col: int) -> List[str]:
    """Return one memtile S2MM tiling format per run of column ``col``.

    Each format covers the valid Yi rows of the run's first window (which
    may be 0 for a fully out-of-range window) with the full Xi and Ci extents.
    """
    valid_rows = (Yi_slice(dims, col, first)[3] for first, _ in Yi_split_iters(dims, col))
    return [f'Yi:0:{rows} Xi:0:{dims.Xi} Ci:0:{dims.Ci}' for rows in valid_rows]

# The OFM-side memtile buffer uses exactly the same per-run layout strings as
# the IFM side, so the same function object is reused under a second name.
ofm_memtile_memory = ifm_memtile_memory

def ofm_memtile_mm2s(dims: ResizeDims, col: int) -> List[str]:
    """Return one memtile MM2S tiling format per run of column ``col``.

    The strings are identical to those of ``ifm_memtile_s2mm``: the valid Yi
    rows of each run's first window with the full Xi and Ci extents. The
    function therefore delegates instead of duplicating the formatting.
    """
    return ifm_memtile_s2mm(dims, col)

def ofm_shim_memory(dims: ResizeDims) -> str:
    """Return the memory-format string ``'Yo:<Yo> Xo:<Xo> Co:<Co>'`` for the shim OFM buffer."""
    return 'Yo:{} Xo:{} Co:{}'.format(dims.Yo, dims.Xo, dims.Co)


def ofm_shim_s2mm(dims: ResizeDims, col: int) -> List[str]:
    """Build the shim S2MM (OFM write) tiling formats for column ``col``.

    Let ``ni = dims.num_interpolations``. One format is emitted for every
    iteration ``it`` in the first run returned by ``Yi_split_iters`` and
    every phase ``i`` in ``range(ni ** 2)``, in iteration-major order.
    Each format:
      * has an outer Yo dimension starting at ``Yo_stride * ni * it`` with
        step ``Yo_stride * ni``;
      * selects the single output row ``Yo_start_inner + i // ni`` inside
        that block, where ``Yo_start_inner`` is iteration 0's block start
        scaled by ``ni``;
      * selects every ``ni``-th Xo column starting at ``i % ni``, with the
        full Co extent.

    NOTE(review): only ``split[0]`` (the first run) is used; iterations in
    later runs produce no formats here. Confirm this is intended.
    Raises IndexError if ``Yi_split_iters`` returns no runs.
    """
    ni = dims.num_interpolations
    num_iters = Yi_split_iters(dims, col)[0][1]

    # The inner window base is the same for every format (taken from
    # iteration 0), so compute it once. The outer Yo dimension supplies the
    # per-iteration offset.
    Yo_start_inner, _, _, _ = Yo_slice(dims, col, 0)
    Yo_start_inner = Yo_start_inner * ni
    Yo_stop_inner = Yo_start_inner + ni

    def fmt(start_iter: int, num_iters: int, i: int) -> str:
        _, _, Yo_stride, _ = Yo_slice(dims, col, start_iter)
        return (
            f'Xo:0:{dims.Xo}:{dims.Xo} '
            f'Yo:{Yo_stride * ni * start_iter}:{Yo_stride * num_iters * ni * (start_iter + 1)}:{Yo_stride * ni} '
            f'Co:0:{dims.Co}:{dims.Co} '
            f'Yo:{Yo_start_inner + (i // ni)}:{Yo_stop_inner}:{ni} Xo:{i % ni}:{dims.Xo}:{ni} Co:0:{dims.Co}'
        )

    return [fmt(it, 1, i) for it in range(num_iters) for i in range(ni ** 2)]


def ofm_shim_s2mm_test(dims: ResizeDims, col: int, i: int) -> List[str]:
    """Debug variant of the shim S2MM OFM formats for one interpolation phase ``i``.

    Returns exactly two formats, built from iterations 0 and 1 of column
    ``col`` (``ni = dims.num_interpolations``). Both select the row
    ``base * ni + i // ni`` of their iteration's ``Yo_slice`` block and every
    ``ni``-th Xo column starting at ``i % ni``. They differ only in the outer
    Yo dimension:
      * first:  ``Yo:0:<stride>:<stride * ni>``
      * second: ``Yo:<stride>:<stride>:<stride * ni>``

    NOTE(review): originally marked "ToBedebug". In the second format the
    outer Yo start equals the stop, which looks suspicious; confirm intent.
    """
    ni = dims.num_interpolations

    def inner_pick(start_iter: int):
        # Return this iteration's Yo stride and the inner (Yo, Xo, Co) pattern.
        yo_base, _, yo_stride, _ = Yo_slice(dims, col, start_iter)
        yo_first = yo_base * ni
        yo_last = yo_first + ni
        inner = (
            f'Yo:{yo_first + (i // ni)}:{yo_last}:{ni} '
            f'Xo:{i % ni}:{dims.Xo}:{ni} Co:0:{dims.Co}'
        )
        return yo_stride, inner

    def assemble(outer_yo: str, inner: str) -> str:
        return f'Xo:0:{dims.Xo}:{dims.Xo} ' + outer_yo + f'Co:0:{dims.Co}:{dims.Co} ' + inner

    stride_a, inner_a = inner_pick(0)
    first = assemble(f'Yo:0:{stride_a}:{stride_a * ni} ', inner_a)

    stride_b, inner_b = inner_pick(1)
    second = assemble(f'Yo:{stride_b}:{stride_b}:{stride_b * ni} ', inner_b)

    return [first, second]

def is_2d_list(variable):
    """Return True iff ``variable`` is a non-empty list whose every element is a list."""
    if not isinstance(variable, list) or not variable:
        return False
    return all(isinstance(item, list) for item in variable)

def transfer_consolidate(
    dims: ResizeDims,
    prm_memtile_transfers: list,
    ifm_memtile_transfers: list,
    prm_shim_transfers: list,
    ifm_shim_transfers: list,
    ofm_shim_transfers: list,
    ):
    """Fold per-column transfer lists into one transfer per column and equalize lengths.

    The IFM memtile, IFM shim and OFM shim inputs may hold several
    DataTransfer objects for the same column. ``consolidate`` merges each
    such group into a single transfer per column by concatenating the
    ``repeat_counts`` and the per-repeat read/write parameter lists.
    ``match_lenth`` then pads every group to the longest ``repeat_counts``
    length found. This includes the parameter transfers, which are padded
    but not consolidated.

    Args:
        dims: Layer dimensions; only ``dims.aie_cols`` is read here.
        prm_memtile_transfers: Memtile parameter transfers; padded in place only.
        ifm_memtile_transfers: Memtile IFM transfers to consolidate.
        prm_shim_transfers: Shim parameter transfers; padded in place only.
        ifm_shim_transfers: Shim IFM transfers to consolidate.
        ofm_shim_transfers: Shim OFM transfers to consolidate.

    Returns:
        ``(ifm_mt_consolidated, ifm_shim_consolidated, ofm_shim_consolidated)``.
        Each is a list of length ``dims.aie_cols``.

    Side effects:
        Mutates the input transfer objects and their parameter objects in
        place. Each consolidated entry IS the first input transfer matched
        for its column, not a copy.
    """
    def fmt(transfer: DataTransfer, repeat_len: int):
        """Expand one read/write parameter object into per-repeat lists.

        Called with elements of ``.read_params`` / ``.write_params`` (not
        DataTransfer objects, despite the annotation). Returns
        ``(length, offset, padding, step, wrap)``, each holding one entry per
        repeat.
        """
        tmp = transfer._length
        # An int length/offset is broadcast; a list is returned as the SAME
        # object, so later `+=` in consolidate extends it in place.
        length = ([tmp]* repeat_len if isinstance(tmp, int) else tmp)
        tmp = transfer._offset
        offset = ([tmp]* repeat_len if isinstance(tmp, int) else tmp)
        a_list = transfer._step
        # `step` alone decides whether padding/step/wrap are already per-repeat.
        # NOTE(review): this local shadows the module-level is_2d_list and,
        # unlike it, evaluates True for an empty list.
        is_2d_list = isinstance(a_list, list) and all(isinstance(sublist, list) for sublist in a_list)
        if is_2d_list:
            assert len(a_list) == repeat_len
        tmp = transfer._padding
        # Otherwise make an independent shallow copy per repeat.
        padding = (tmp if is_2d_list else [tmp[:] for _ in range(repeat_len)])
        tmp = transfer._step
        step = (tmp if is_2d_list else [tmp[:] for _ in range(repeat_len)])
        tmp = transfer._wrap
        wrap = (tmp if is_2d_list else [tmp[:] for _ in range(repeat_len)])
        return(length, offset, padding, step, wrap)

    def match_lenth(
                prm_memtile_transfers: list,
                ifm_memtile_transfers: list,
                prm_shim_transfers: list,
                ifm_shim_transfers: list,
                ofm_shim_transfers: list,
    ):
        """Pad every transfer group in place to the longest repeat_counts length.

        Only ``transfers[0]`` of each group is used to measure the length.
        ``repeat_counts`` is padded with 0. List-valued length/offset and 2-D
        padding/step/wrap are padded by repeating their first entry.
        """
        consolidated_len =0
        for _, transfers in enumerate([prm_memtile_transfers, ifm_memtile_transfers, prm_shim_transfers, ifm_shim_transfers, ofm_shim_transfers]):
            assert all(len(transfers[n].read_params) == len(transfers[0].read_params) for n in range(len(transfers)))
            consolidated_len = max(len(transfers[0].repeat_counts), consolidated_len)
        for _, transfers in enumerate([prm_memtile_transfers, ifm_memtile_transfers, prm_shim_transfers, ifm_shim_transfers, ofm_shim_transfers]):
            if (consolidated_len > len(transfers[0].repeat_counts)) :
                for transfer in transfers:
                    # repeat_counts is extended only after all params, so the
                    # pad amount below is the same for every param.
                    for param in transfer.read_params + transfer.write_params:
                        if isinstance(param._length, list):
                            param._length += [param._length[0]] * (consolidated_len - len(transfer.repeat_counts) )
                        if isinstance(param._offset, list):
                            param._offset += [param._offset[0]] * (consolidated_len - len(transfer.repeat_counts) )
                        # NOTE(review): `[x] * k` appends k references to the SAME
                        # inner list; mutating one padded entry would change all.
                        if is_2d_list(param._padding):
                            param._padding += [param._padding[0]] * (consolidated_len - len(transfer.repeat_counts) )
                        if is_2d_list(param._step):
                            param._step += [param._step[0]] * (consolidated_len - len(transfer.repeat_counts) )
                        if is_2d_list(param._wrap):
                            param._wrap += [param._wrap[0]] * (consolidated_len - len(transfer.repeat_counts) )
                    transfer.repeat_counts += [0] * (consolidated_len - len(transfer.repeat_counts) )

    def consolidate(transfer_params: list):
        """Merge all transfers of each column into that column's first transfer.

        Assumes ``transfer_params`` is ordered by column (col 0's transfers
        first, then col 1's, ...). A column whose first transfer is not found
        in that order stays ``None`` in the result.
        """
        assert all(len(transfer_params[n].read_params) == len(transfer_params[0].read_params) for n in range(len(transfer_params)))
        assert all(len(transfer_params[n].write_params) == len(transfer_params[0].write_params) for n in range(len(transfer_params)))
        assert all(isinstance(transfer_params[n].repeat_counts, list) for n in range(len(transfer_params)))
        transfer_consolidated = [None] * dims.aie_cols
        col_track = 0
        # Pass 1: pick the first transfer of each column as the merge target
        # and normalise its params to per-repeat lists.
        for _, transfer in enumerate(transfer_params):
            if col_track == transfer.tile.col and col_track < dims.aie_cols:
                transfer_consolidated[col_track] = transfer
                repeat_len = len(transfer_consolidated[col_track].repeat_counts)
                for parms in (transfer_consolidated[col_track].read_params + transfer_consolidated[col_track].write_params):
                    length, offset, padding, step, wrap = fmt(parms, repeat_len)
                    parms._length = length
                    parms._offset = offset
                    parms._padding = padding
                    parms._step = step
                    parms._wrap = wrap
                col_track += 1
        col_track =[0] * dims.aie_cols
        # Pass 2: skip each column's first transfer (it is the target) and
        # append every later transfer's repeat counts and params onto it.
        for _, transfer in enumerate(transfer_params):
            col_index = transfer.tile.col
            if col_track[col_index] == 0:
                col_track[col_index] = 1
            else:
                transfer_consolidated[col_index].repeat_counts += transfer.repeat_counts
                repeat_len = len(transfer.repeat_counts)
                for parms, consolidated in zip((transfer.read_params + transfer.write_params),
                                               (transfer_consolidated[col_index].read_params + transfer_consolidated[col_index].write_params)):
                    length, offset, padding, step, wrap = fmt(parms, repeat_len)
                    consolidated._length += length
                    consolidated._offset += offset
                    consolidated._padding += padding
                    consolidated._step += step
                    consolidated._wrap += wrap
        return transfer_consolidated

    ifm_mt_consolidated = consolidate(ifm_memtile_transfers)
    ifm_shim_consolidated = consolidate(ifm_shim_transfers)
    ofm_shim_consolidated = consolidate(ofm_shim_transfers)

    # The parameter transfers are passed unconsolidated; they are only padded.
    match_lenth(prm_memtile_transfers,
                ifm_mt_consolidated,
                prm_shim_transfers,
                ifm_shim_consolidated,
                ofm_shim_consolidated)

    return (ifm_mt_consolidated, ifm_shim_consolidated, ofm_shim_consolidated)

def compile_dataflow(
    dims: ResizeDims,
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
):
    """Generate all memtile/shim transfers for the resize layer and compile it.

    The memtile buffer is laid out as a parameter region at address 0
    (``aie_rows * MAX_CORE_LAYER_PARAM_SIZE`` bytes) followed by the IFM
    buffer. The IFM memtile, IFM shim and OFM shim transfers are
    consolidated per column (``transfer_consolidate``) before being handed to
    ``run_layer_compilation``.

    Raises:
        ValueError: if ``dims.aie_cols`` is neither 4 nor 8. Only those
            overlays are available, and ``overlay`` would otherwise be unbound.
    """
    if dims.aie_cols == 8:
        overlay = overlay_8x4_dma_connections()
    elif dims.aie_cols == 4:
        overlay = overlay_4x4_dma_connections()
    else:
        raise ValueError(
            f'Unsupported aie_cols={dims.aie_cols}: only 4- and 8-column overlays are available'
        )

    nni_shim_alloc = shim_alloc()

    ifm_bytes = dims.ifm_bits // 8

    # Memtile layout: [param region][IFM buffer].
    param_memtile_size = dims.aie_rows * config.MAX_CORE_LAYER_PARAM_SIZE
    memtile_size = dims.Nis * dims.Yis * dims.Xis * dims.Cis * ifm_bytes

    param_memtile_addr = 0
    memtile_ping_addr = param_memtile_addr + param_memtile_size

    ifm_memtile_repeat_scale = 1
    # NOTE(review): the core program only configures/acquires/releases S2MM
    # channel 0 and calls no kernel; confirm this is intended.
    core_instrs = [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), 0, 0, 0),
        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
    ]

    memtile_transfers = []
    # Layer parameters: loaded once (first iteration only), then fanned out
    # to every row over MM2S channel `row`.
    memtile_param_transfers = [
        DataTransfer(
            [1] + [0] * (dims.Y_loop - 1),
            AieTile(TileType.Memtile, col),
            [param_memtile_addr],
            param_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(dims),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(dims),
                prm_memtile_mm2s(row),
            ) for row in range(dims.aie_rows)],
        ) for col in range(dims.aie_cols)
    ]

    memtile_ifm_transfers = [
        DataTransfer(
            Yi_repeat_counts(dims, col, ifm_memtile_repeat_scale),
            AieTile(TileType.Memtile, col),
            [memtile_ping_addr],
            memtile_size,
            [
                pack_transfers(
                    memtile_dma(col, DmaDir.S2MM, 0),
                    ifm_memtile_memory(dims, col),
                    ifm_memtile_s2mm(dims, col),
                    [n for _, n in Yi_split_iters(dims, col)],
                    dims.ifm_bits,
                )
            ],
            [
                # NOTE(review): MM2S channel 5 is hard-coded; confirm it
                # matches the overlay's routing.
                pack_transfers(
                    memtile_dma(col, DmaDir.MM2S, 5),
                    ifm_memtile_memory(dims, col),
                    ofm_memtile_mm2s(dims, col),
                    [n for _, n in Yi_split_iters(dims, col)],
                    dims.ifm_bits,
                )
            ],
            # Presumably one reuse per interpolation phase (num_interpolations ** 2).
            reuse_ratio=(dims.num_interpolations ** 2)
        ) for col in range(dims.aie_cols)
    ]

    shim_transfers = []
    shim_param_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.MM2S, 0),
            nni_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col),
        ) for col in range(dims.aie_cols)
    ]

    # One shim IFM transfer per (column, run); consolidated per column below.
    shim_ifm_transfers = [
        generate_shim_data_transfer(
            ifm_shim_repeat_counts(dims.Y_loop, idx),
            shim_dma(col, DmaDir.MM2S, 0),
            nni_shim_alloc.ifm_buffer_id,
            ifm_shim_memory(dims),
            fmt,
            bits_per_block=dims.ifm_bits,
        ) for col in range(dims.aie_cols) for idx, fmt in enumerate(ifm_shim_mm2s(dims, col))
    ]

    # One shim OFM transfer per (column, format); consolidated per column below.
    shim_ofm_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.S2MM, 0),
            nni_shim_alloc.ofm_buffer_id,
            ofm_shim_memory(dims),
            fmt,
            bits_per_block=dims.ifm_bits,
        ) for col in range(dims.aie_cols)
        for fmt in ofm_shim_s2mm(dims, col)
    ]

    (memtile_ifm_consolidated_transfers,
     shim_ifm_consolidated_transfers,
     shim_ofm_consolidated_transfers) = transfer_consolidate(
        dims,
        memtile_param_transfers, memtile_ifm_transfers,
        shim_param_transfers, shim_ifm_transfers, shim_ofm_transfers)
    memtile_transfers += memtile_param_transfers
    memtile_transfers += memtile_ifm_consolidated_transfers

    shim_transfers += shim_param_transfers
    shim_transfers += shim_ifm_consolidated_transfers
    shim_transfers += shim_ofm_consolidated_transfers

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay,
        back_end=back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
    )

def main():
    """Compile the resize dataflow for one fixed configuration and build the sim overlay.

    Configuration: 8 AIE columns x 4 rows; a 16 x 16 input with 1280
    channels; num_interpolations = 4; 16-bit input elements; ADF back end.
    """
    back_end = BackEnd.Adf
    # No kernel names are passed; 'super.hh' is the only kernel include.
    kernel_names = []
    kernel_includes = ['super.hh']

    # Array shape.
    aie_rows = 4
    aie_cols = 8
    # Input tensor extents: height, width, channels.
    h_in, w_in, c_in = 16, 16, 1280
    num_interpolations = 4
    ifm_bits = 16
    # Data-type flags for ResizeDims (presumably int16 on, bfloat16 off; confirm).
    int_16, bfloat_16 = 1, 0

    dims = ResizeDims(
        aie_rows, aie_cols,
        h_in, w_in, c_in,
        num_interpolations,
        ifm_bits,
        int_16, bfloat_16,
    )

    clean_overlay()
    compile_dataflow(dims, back_end, kernel_names, kernel_includes)
    directives = resize_preproc_directives(dims, back_end)
    build_sim_overlay(back_end, 'resize_main.cpp', directives)

# Run the default configuration only when executed as a script, not on import.
if __name__ == '__main__':
    main()