import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels'))
from typing import Tuple, List

from dmacompiler import \
    OverlayShape, BackEnd, \
    DataTransfer, SyncStrategy, \
    AieTile, TileType, \
    AieDma, DmaDir, memtile_dma, shim_dma, core_dma, DmaChannel, \
    CoreInstr, ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    compute_buffer_size, \
    TransferParams, generate_transfer_params, \
    generate_shim_data_transfer, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config

from dataflow_common import \
    overlay_8x4_dma_connections, \
    overlay_4x4_dma_connections, \
    overlay_stack_addr, \
    clean_overlay, \
    build_sim_overlay, \
    shim_alloc, \
    prm_shim_memory, \
    prm_shim_mm2s, \
    prm_memtile_memory, \
    prm_memtile_s2mm, \
    prm_memtile_mm2s, \
    ceildiv
    
from q_dq_common import \
    QDQDims, \
    q_dq_preproc_directives


# Module-level configuration, applied as a side effect of importing this file:
# select the Aie2p device generation and switch on the ENABLE_BUSY_POLL flag
# of the dmacompiler config object.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True

def pack_transfers(
    dma: AieDma,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iters: List[int],
    bits_per_elem: int,
) -> TransferParams:
    '''Merge several per-format transfers into one TransferParams object.

    One transfer is generated for every (memory format, tiling format) pair.
    Entry 0 of each generated transfer (length, offset, step, wrap, padding)
    is then repeated ``tiling_iters[i]`` times and the repeated entries are
    concatenated in format order.

    Args:
        dma: DMA channel the transfers are generated for. Padding is enabled
            only when the channel direction is MM2S.
        memory_fmts: Buffer layout strings, one per format.
        tiling_fmts: Access pattern strings, one per format.
        tiling_iters: Number of times each format's entry is repeated.
        bits_per_elem: Forwarded as ``bits_per_block`` to the generator.

    Returns:
        A single TransferParams holding the repeated, concatenated entries.
    '''
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iters)

    def repeat_each(values: list) -> list:
        # Expand values[i] into tiling_iters[i] consecutive copies.
        assert len(values) == len(tiling_iters)
        return [
            value
            for value, count in zip(values, tiling_iters)
            for _ in range(count)
        ]

    per_fmt = []
    for memory_fmt, tiling_fmt in zip(memory_fmts, tiling_fmts):
        per_fmt.append(
            generate_transfer_params(
                dma,
                memory_fmt,
                tiling_fmt,
                bits_per_block=bits_per_elem,
                enable_padding=(dma.channel.dir == DmaDir.MM2S),
            )
        )

    # Only entry 0 of each generated transfer is used.
    lengths = repeat_each([p.length_i(0) for p in per_fmt])
    offsets = repeat_each([p.offset_i(0) for p in per_fmt])
    steps = repeat_each([p.step_i(0) for p in per_fmt])
    wraps = repeat_each([p.wrap_i(0) for p in per_fmt])
    paddings = repeat_each([p.padding_i(0) for p in per_fmt])
    return TransferParams(
        dma,
        lengths,
        offset=offsets,
        step=steps,
        wrap=wraps,
        padding=paddings,
    )

def Yi_slice(dims: QDQDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    '''Return the input Y window handled by column `col` at iteration `start_iter`.

    Returns:
        (start, stop, stride, size) where `stride` is the Y distance between
        consecutive iterations of one column and `size` is the number of rows
        of [start, stop) that lie inside [0, dims.total_Y].
    '''
    rows_per_col = dims.Yos
    stride = dims.aie_cols * rows_per_col
    start = col * rows_per_col + start_iter * stride

    # A window that begins past the end of the tensor is reported as empty.
    if start <= dims.total_Y:
        stop = start + dims.Yis
    else:
        stop = start

    def clamp(y: int) -> int:
        return max(0, min(y, dims.total_Y))

    size = clamp(stop) - clamp(start)
    return (start, stop, stride, size)

def Yi_split_iters(dims: QDQDims, col: int) -> List[Tuple[int, int]]:
    '''Group the Y iterations of column `col` into runs of unpadded input windows.

    Consecutive iterations whose input window lies fully inside
    [0, dims.total_Y] are merged into one run. An iteration whose window
    needs padding always forms a run of length 1.

    Returns:
        A list of (start_iter, num_iters) pairs covering exactly the
        iterations 0 .. dims.Y_loop - 1, in order.
    '''
    def is_unpadded(iter_idx: int) -> bool:
        Yi_start, Yi_stop, _, _ = Yi_slice(dims, col, iter_idx)
        return Yi_start >= 0 and Yi_stop <= dims.total_Y

    split = []
    start_iter = 0
    while start_iter < dims.Y_loop:
        num_iters = 1
        if is_unpadded(start_iter):
            # Never look past the last iteration: a run must not report more
            # iterations than the Y loop actually executes.
            while (start_iter + num_iters < dims.Y_loop
                   and is_unpadded(start_iter + num_iters)):
                num_iters += 1
        split.append((start_iter, num_iters))
        start_iter += num_iters
    return split

def Yo_slice(dims: QDQDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    '''Return the output Y window written by column `col` at iteration `start_iter`.

    Returns:
        (start, stop, stride, size) where `stop` is clipped to dims.total_Y
        and `size` is `stop - start`. A window that begins past the end of
        the tensor has size 0.
    '''
    stride = dims.aie_cols * dims.Yos
    start = col * dims.Yos + start_iter * stride
    if start > dims.total_Y:
        stop = start
    else:
        stop = min(start + dims.Yos, dims.total_Y)
    return (start, stop, stride, stop - start)

def Yo_split_iters(dims: QDQDims, col: int) -> List[Tuple[int, int]]:
    '''Group the Y iterations of column `col` into runs of full output windows.

    Consecutive iterations whose output window has exactly dims.Yos rows are
    merged into one run. A partial or empty window always forms a run of
    length 1.

    Returns:
        A list of (start_iter, num_iters) pairs covering exactly the
        iterations 0 .. dims.Y_loop - 1, in order.
    '''
    def is_full_slice(iter_idx: int) -> bool:
        _, _, _, Yo_size = Yo_slice(dims, col, iter_idx)
        return Yo_size == dims.Yos

    split = []
    start_iter = 0
    while start_iter < dims.Y_loop:
        num_iters = 1
        if is_full_slice(start_iter):
            # Never look past the last iteration: a run must not report more
            # iterations than the Y loop actually executes.
            while (start_iter + num_iters < dims.Y_loop
                   and is_full_slice(start_iter + num_iters)):
                num_iters += 1
        split.append((start_iter, num_iters))
        start_iter += num_iters
    return split

def ifm_shim_repeat_counts(dims: QDQDims, idx: int) -> List[int]:
    '''Return per-iteration repeat counts that fire once, in iteration 0.

    The list has dims.Y_loop entries. `idx` is accepted but not used.
    '''
    counts = [0] * dims.Y_loop
    counts[0] = 1
    return counts

def Yi_repeat_counts(dims: QDQDims, col: int, scale: int) -> List[int]:
    '''Return per-iteration repeat counts derived from the input Y runs.

    The list has dims.Y_loop entries. The entry at the first iteration of
    each run returned by Yi_split_iters holds that run's length; all other
    entries are 0. `scale` is accepted but not used.
    '''
    counts = [0] * dims.Y_loop
    for first_iter, run_len in Yi_split_iters(dims, col):
        counts[first_iter] = run_len
    return counts

def ifm_shim_memory(dims: QDQDims) -> str:
    '''Return the layout string of the whole input tensor in shim memory.'''
    y_dim = f'Yi:{dims.total_Y}'
    x_dim = f'Xi:{dims.total_X}'
    c_dim = f'Ci:{dims.subv_elem}'
    return f'{y_dim} {x_dim} {c_dim}'

def ifm_shim_mm2s(dims: QDQDims, col: int) -> List[str]:
    '''Return one shim MM2S access pattern per input Y run of column `col`.

    Each pattern has an outer Yi dimension that steps over the iterations of
    the run and an inner Yi range covering the first window of the run,
    clipped to [0, dims.total_Y].
    '''
    patterns = []
    for first_iter, run_len in Yi_split_iters(dims, col):
        y_begin, y_end, y_stride, _ = Yi_slice(dims, col, first_iter)
        outer = f'Yi:0:{run_len * y_stride}:{y_stride}'
        inner = f'Yi:{max(0, y_begin)}:{min(y_end, dims.total_Y)}'
        patterns.append(f'{outer} {inner} Xi Ci:0:{dims.subv_elem}')
    return patterns

def wgt_shim_memory(dims: QDQDims) -> str:
    '''Return the layout string of the weight buffer in shim memory.'''
    subv_size = dims.wgt_subv_size
    return f'Subv:{subv_size}'

def wgt_shim_mm2s(dims: QDQDims) -> str:
    '''Return the shim MM2S access pattern for the weights (the whole Subv).

    `dims` is accepted but not used.
    '''
    return 'Subv'

def wgt_memtile_memory(dims: QDQDims) -> str:
    '''Return the layout string of the weight buffer in the memtile.'''
    subv_size = dims.wgt_subv_size
    return f'Subv:{subv_size}'

def wgt_memtile_s2mm(dims: QDQDims) -> str:
    '''Return the memtile S2MM access pattern for the weights (the whole Subv).

    `dims` is accepted but not used.
    '''
    return 'Subv'

def wgt_memtile_mm2s(dims: QDQDims) -> str:
    '''Return the memtile MM2S access pattern for the weights (the whole Subv).

    `dims` is accepted but not used.
    '''
    return 'Subv'

def ifm_memtile_memory(dims: QDQDims, col: int) -> List[str]:
    '''Return one memtile buffer layout string per input Y run of column `col`.

    The Yi extent of each layout is the in-bounds size of the run's first
    window. When that size is not positive, dims.Yis is used instead.
    '''
    layouts = []
    for first_iter, _ in Yi_split_iters(dims, col):
        rows = Yi_slice(dims, col, first_iter)[3]
        rows = rows if rows > 0 else dims.Yis
        layouts.append(f'Yi:{rows} Xi:{dims.total_X} Ci:{dims.subv_elem}')
    return layouts


def ifm_memtile_s2mm(dims: QDQDims, col: int) -> List[str]:
    '''Return one memtile S2MM access pattern per input Y run of column `col`.

    Each pattern covers the in-bounds rows of the run's first window across
    the full X and C extents.
    '''
    patterns = []
    for first_iter, _ in Yi_split_iters(dims, col):
        rows = Yi_slice(dims, col, first_iter)[3]
        patterns.append(
            f'Yi:0:{rows} Xi:0:{dims.total_X} Ci:0:{dims.subv_elem}'
        )
    return patterns


# The OFM memtile buffer layout strings are produced by the same function as
# the IFM ones, so the OFM helper is an alias of ifm_memtile_memory.
ofm_memtile_memory = ifm_memtile_memory

def ifm_memtile_mm2s(dims: QDQDims, col: int, row: int) -> List[str]:
    '''Return one memtile MM2S access pattern per input Y run of column `col`.

    Each pattern selects dims.Yis rows and a single X position: position 0
    when dims.total_X is 1, otherwise position `row`.
    '''
    # The X range does not depend on the run, so compute it once.
    x_begin = 0 if dims.total_X == 1 else row
    x_end = x_begin + 1

    patterns = []
    for first_iter, _ in Yi_split_iters(dims, col):
        y_start = Yi_slice(dims, col, first_iter)[0]
        y_begin = min(y_start, 0)
        patterns.append(
            f'Yi:{y_begin}:{y_begin + dims.Yis} Xi:{x_begin}:{x_end} Ci:0:{dims.subv_elem}'
        )
    return patterns

def ofm_memtile_mm2s(dims: QDQDims, col: int) -> List[str]:
    '''Return one memtile MM2S access pattern per output Y run of column `col`.

    Each pattern covers the rows of the run's first output window across the
    full X and C extents.
    '''
    patterns = []
    for first_iter, _ in Yo_split_iters(dims, col):
        rows = Yo_slice(dims, col, first_iter)[3]
        patterns.append(
            f'Yi:0:{rows} Xi:0:{dims.total_X} Ci:0:{dims.subv_elem}'
        )
    return patterns

def ofm_shim_memory(dims: QDQDims) -> str:
    '''Return the layout string of the whole output tensor in shim memory.'''
    y_dim = f'Yo:{dims.total_Y}'
    x_dim = f'Xo:{dims.total_X}'
    c_dim = f'Co:{dims.subv_elem}'
    return f'{y_dim} {x_dim} {c_dim}'

def ofm_shim_s2mm(dims: QDQDims, col: int) -> List[str]:
    '''Return one shim S2MM access pattern per output Y run of column `col`.

    Each pattern has an outer Yo dimension that steps over the iterations of
    the run and an inner Yo range covering the run's first output window.
    '''
    patterns = []
    for first_iter, run_len in Yo_split_iters(dims, col):
        y_begin, y_end, y_stride, _ = Yo_slice(dims, col, first_iter)
        outer = f'Yo:0:{y_stride * run_len}:{y_stride}'
        inner = f'Yo:{y_begin}:{y_end}'
        patterns.append(
            f'{outer} {inner} Xo:0:{dims.total_X} Co:0:{dims.subv_elem}'
        )
    return patterns
    
def iceil(x: int, d: int) -> int:
    '''Round `x` up to a multiple of `d` (ceildiv(x, d) scaled back by d).'''
    num_blocks = ceildiv(x, d)
    return num_blocks * d


def gen_qdq_params_a8(subv_elems: int, index: int, quant_offset: int, scratch_buffer_addr: int, fixed_point_bits: int, qdq_mode: int, output_addr: int) -> bytes:
    '''Serialize the kernel parameter block for the a8 Q/DQ kernel.

    The block is nine unsigned 16-bit little-endian fields, in this order:
    subv_elems, scratch address, is_int16 flag, dequant zero-point index,
    dequant scale index, quant zero-point index, quant scale index,
    dequant-enable index, quant-enable index.

    The scratch address field is `scratch_buffer_addr` only when
    `qdq_mode` is 2 and `fixed_point_bits` is 8; otherwise it is
    `output_addr`.

    Raises:
        OverflowError: if any field does not fit in an unsigned 16-bit value.
    '''
    is_int16 = 1 if fixed_point_bits != 8 else 0

    if qdq_mode == 2 and not is_int16:
        scratch_buffer_addr_core = scratch_buffer_addr
    else:
        scratch_buffer_addr_core = output_addr

    dq_base = 2 * index
    fields = (
        subv_elems,
        scratch_buffer_addr_core,
        is_int16,
        dq_base,            # dequant zero-point element index
        dq_base + 1,        # dequant scale element index
        quant_offset,       # quant zero-point element index
        quant_offset + 1,   # quant scale element index
        quant_offset + 2,   # dequant-enable element index
        quant_offset + 3,   # quant-enable element index
    )
    return b''.join(
        field.to_bytes(length=2, byteorder='little', signed=False)
        for field in fields
    )


def gen_qdq_params(subv_elems: int, qdq_prm_addr: int, index: int, quant_offset: int) -> bytes:
    '''Serialize the kernel parameter block for the 16-bit Q/DQ kernel.

    The block is seven unsigned 16-bit little-endian fields, in this order:
    subv_elems, dequant zero-point index, dequant scale index, quant
    zero-point index, quant scale index, dequant-enable index, quant-enable
    index. `qdq_prm_addr` is accepted but not used.

    Raises:
        OverflowError: if any field does not fit in an unsigned 16-bit value.
    '''
    dq_base = 2 * index
    fields = (
        subv_elems,
        dq_base,            # dequant zero-point element index
        dq_base + 1,        # dequant scale element index
        quant_offset,       # quant zero-point element index
        quant_offset + 1,   # quant scale element index
        quant_offset + 2,   # dequant-enable element index
        quant_offset + 3,   # quant-enable element index
    )
    return b''.join(
        field.to_bytes(length=2, byteorder='little', signed=False)
        for field in fields
    )

def compile_dataflow(
    dims: QDQDims,
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
):
    '''Build the core program and DMA transfers of a Q/DQ layer and compile it.

    The function
      1. lays out the memtile buffers (layer params, Q/DQ weights, IFM and
         OFM), enabling ping-pong for IFM and/or OFM when they fit;
      2. lays out the core buffers (Q/DQ params, IFM, scratch, OFM) and
         builds the core instruction list that runs the kernel dims.Y_loop
         times;
      3. builds the memtile and shim data transfers for params, IFM, weights
         and OFM;
      4. hands everything to run_layer_compilation.

    Args:
        dims: Shape and tiling description of the layer.
        back_end: Back end forwarded to run_layer_compilation.
        kernel_names: Kernel names forwarded to run_layer_compilation.
        kernel_includes: Kernel include files forwarded to
            run_layer_compilation.

    Raises:
        AssertionError: if the memtile or core buffer layout does not fit.
    '''
    slice_shim_alloc = shim_alloc()

    # ------------------------------------------------------------------
    # Memtile layout: [params | weights | ifm ping(/pong) | ofm ping(/pong)]
    # ------------------------------------------------------------------
    ifm_memtile_size = dims.Yis * dims.total_X * dims.subv_elem * (dims.ifm_bits // 8)
    ofm_memtile_size = dims.Yis * dims.total_X * dims.subv_elem * (dims.ofm_bits // 8)

    wgt_memtile_size = dims.wgt_subv_size

    param_memtile_addr = 0
    param_memtile_size = compute_buffer_size(prm_memtile_memory(dims))
    wgt_memtile_addr = param_memtile_addr + param_memtile_size
    wgt_memtile_addrs = [wgt_memtile_addr]

    available_memtile_size = config.MAX_MEMTILE_ADDR - \
        (wgt_memtile_size + param_memtile_size)

    # Prefer double buffering both tensors; fall back to IFM only, then OFM
    # only, then none, depending on what fits in the remaining space.
    if ifm_memtile_size * 2 + (ofm_memtile_size * 2) < available_memtile_size:
        memtile_ifm_pingpong = True
        memtile_ofm_pingpong = True
    elif ifm_memtile_size * 2 + (ofm_memtile_size) < available_memtile_size:
        memtile_ifm_pingpong = True
        memtile_ofm_pingpong = False
    elif ifm_memtile_size + (ofm_memtile_size * 2) < available_memtile_size:
        memtile_ifm_pingpong = False
        memtile_ofm_pingpong = True
    else:
        memtile_ifm_pingpong = False
        memtile_ofm_pingpong = False

    ifm_memtile_ping_addr = wgt_memtile_addr + wgt_memtile_size
    if memtile_ifm_pingpong:
        ifm_memtile_pong_addr = ifm_memtile_ping_addr + ifm_memtile_size
        ifm_memtile_addr = ifm_memtile_pong_addr
        ifm_memtile_addrs = [ifm_memtile_ping_addr, ifm_memtile_pong_addr]
    else:
        ifm_memtile_pong_addr = None
        ifm_memtile_addr = ifm_memtile_ping_addr
        ifm_memtile_addrs = [ifm_memtile_ping_addr]

    # ifm_memtile_addr is the last IFM buffer; the OFM buffers follow it.
    ofm_memtile_ping_addr = ifm_memtile_addr + ifm_memtile_size
    if memtile_ofm_pingpong:
        ofm_memtile_pong_addr = ofm_memtile_ping_addr + ofm_memtile_size
        ofm_memtile_addrs = [ofm_memtile_ping_addr, ofm_memtile_pong_addr]
    else:
        ofm_memtile_pong_addr = None
        ofm_memtile_addrs = [ofm_memtile_ping_addr]

    assert ofm_memtile_addrs[-1] + ofm_memtile_size < config.MAX_MEMTILE_ADDR, \
        'memtile buffers do not fit below config.MAX_MEMTILE_ADDR'

    ifm_memtile_repeat_scale = 1

    # ------------------------------------------------------------------
    # Core layout: [qdq params | ifm ping | scratch | ofm ping | ifm pong | ofm pong]
    # Every buffer start is rounded up to a multiple of 64.
    # ------------------------------------------------------------------
    CoreIfmSize = dims.subv_size_input
    CoreOfmSize = dims.subv_size_output
    ScratchBufferSize = dims.subv_elem * 2
    CoreqdqPrmSize = dims.wgt_subv_size
    CoreQdqPrmPingAddr = 0

    CoreIfmPingAddr = iceil(CoreQdqPrmPingAddr + CoreqdqPrmSize, 64)
    CoreScratchBufferPingAddr = iceil(CoreIfmPingAddr + CoreIfmSize, 64)
    CoreOfmPingAddr = iceil(CoreScratchBufferPingAddr + ScratchBufferSize, 64)

    # First choice: start the pong buffers no lower than address 2 * 16384.
    CoreIfmPongAddr = max(2 * 16384, iceil(CoreOfmPingAddr + CoreOfmSize, 64))
    CoreOfmPongAddr = iceil(CoreIfmPongAddr + CoreIfmSize, 64)

    if CoreOfmPongAddr + CoreOfmSize > overlay_stack_addr():
        # Fallback: pack the pong buffers directly after the ping buffers.
        # The ping addresses are unaffected, so only the pongs are redone.
        CoreIfmPongAddr = iceil(CoreOfmPingAddr + CoreOfmSize, 64)
        CoreOfmPongAddr = iceil(CoreIfmPongAddr + CoreIfmSize, 64)

    assert CoreOfmPongAddr + CoreOfmSize <= overlay_stack_addr(), \
        'core buffers do not fit below the overlay stack address'

    Tn = dims.Y_loop

    # ------------------------------------------------------------------
    # Core program
    # ------------------------------------------------------------------
    if dims.fixed_point_bits == 16:
        run_kernel = 'run_combined_qdq'
        kernel_params = gen_qdq_params(
            iceil(dims.subv_elem, 64), CoreQdqPrmPingAddr, 0, 2)
    else:
        run_kernel = 'run_combined_qdq_a8'
        kernel_params = gen_qdq_params_a8(
            iceil(dims.subv_elem, 64), 0, 2, CoreScratchBufferPingAddr,
            dims.fixed_point_bits, dims.qdq_mode, CoreOfmPingAddr)

    core_instrs = [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize),
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreQdqPrmPingAddr, None, dims.wgt_subv_size),
        ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOfmPingAddr, CoreOfmPongAddr, CoreOfmSize),
        # The Q/DQ parameter buffer is acquired and released once, before
        # the loop.
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
        Loop(Tn, [
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
            CallKernel(run_kernel, kernel_params),
            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
        ])
    ]

    # ------------------------------------------------------------------
    # Memtile transfers
    # ------------------------------------------------------------------
    memtile_transfers = []
    memtile_param_transfers = [
        DataTransfer(
            [1] + [0] * (dims.Y_loop - 1),
            AieTile(TileType.Memtile, col),
            [param_memtile_addr],
            param_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(dims),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(dims),
                prm_memtile_mm2s(row),
            ) for row in range(dims.aie_rows)],
        ) for col in range(dims.aie_cols)
    ]
    memtile_transfers += memtile_param_transfers

    memtile_ifm_transfers = [
        DataTransfer(
            Yi_repeat_counts(dims, col, ifm_memtile_repeat_scale),
            AieTile(TileType.Memtile, col), ifm_memtile_addrs, ifm_memtile_size,
            [pack_transfers(
                memtile_dma(col, DmaDir.S2MM, 0),
                ifm_memtile_memory(dims, col),
                ifm_memtile_s2mm(dims, col),
                [n for _, n in Yi_split_iters(dims, col)],
                dims.ifm_bits,
            )],
            [pack_transfers(
                memtile_dma(col, DmaDir.MM2S, row),
                ifm_memtile_memory(dims, col),
                ifm_memtile_mm2s(dims, col, row),
                [i for _, i in Yi_split_iters(dims, col)],
                dims.ifm_bits,
            ) for row in range(dims.aie_rows)
            ],
        ) for col in range(dims.aie_cols)
    ]
    memtile_transfers += memtile_ifm_transfers

    # With 8 columns the weights are loaded on every second column only.
    memtile_wgt_transfers = [
        DataTransfer(
            [1] + [0] * (dims.Y_loop - 1),
            AieTile(TileType.Memtile, col), wgt_memtile_addrs, wgt_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 1),
                wgt_memtile_memory(dims),
                wgt_memtile_s2mm(dims),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, 4),
                wgt_memtile_memory(dims),
                wgt_memtile_mm2s(dims),
            )],
        ) for col in range(0, dims.aie_cols, (2 if dims.aie_cols == 8 else 1))
    ]
    memtile_transfers += memtile_wgt_transfers

    memtile_ofm_transfers = [
        DataTransfer(
            Yi_repeat_counts(dims, col, ifm_memtile_repeat_scale),
            # Sized with the OFM buffer size so that the declared buffer
            # matches the spacing of ofm_memtile_addrs computed above.
            AieTile(TileType.Memtile, col), ofm_memtile_addrs, ofm_memtile_size,
            [pack_transfers(
                memtile_dma(col, DmaDir.S2MM, 2 + row),
                ifm_memtile_memory(dims, col),
                ifm_memtile_mm2s(dims, col, row),
                [i for _, i in Yi_split_iters(dims, col)],
                dims.ofm_bits,
            ) for row in range(dims.aie_rows)
            ],
            [pack_transfers(
                memtile_dma(col, DmaDir.MM2S, 5),
                ofm_memtile_memory(dims, col),
                ofm_memtile_mm2s(dims, col),
                [n for _, n in Yo_split_iters(dims, col)],
                dims.ofm_bits,)],
            sync_strategy=SyncStrategy.Parallel_N_to_1
        ) for col in range(dims.aie_cols)
    ]
    memtile_transfers += memtile_ofm_transfers

    # ------------------------------------------------------------------
    # Shim transfers
    # ------------------------------------------------------------------
    shim_transfers = []

    shim_param_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.MM2S, 0),
            slice_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col),
        ) for col in range(dims.aie_cols)
    ]
    shim_transfers += shim_param_transfers

    # One shim transfer per input Y run of every column.
    shim_ifm_transfers = [
        generate_shim_data_transfer(
            ifm_shim_repeat_counts(dims, idx),
            shim_dma(col, DmaDir.MM2S, 0),
            slice_shim_alloc.ifm_buffer_id,
            ifm_shim_memory(dims),
            fmt,
            bits_per_block=dims.ifm_bits,
        ) for col in range(dims.aie_cols) for idx, fmt in enumerate(ifm_shim_mm2s(dims, col))
    ]
    shim_transfers += shim_ifm_transfers

    shim_wgt_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.MM2S, 1), slice_shim_alloc.wgt_buffer_id,
            wgt_shim_memory(dims),
            wgt_shim_mm2s(dims),
        ) for col in range(0, dims.aie_cols, (2 if dims.aie_cols == 8 else 1))
    ]
    shim_transfers += shim_wgt_transfers

    # One shim transfer per output Y run of every column.
    shim_ofm_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.S2MM, 0),
            slice_shim_alloc.ofm_buffer_id,
            ofm_shim_memory(dims),
            fmt,
            bits_per_block=dims.ofm_bits,
        ) for col in range(dims.aie_cols) for fmt in ofm_shim_s2mm(dims, col)
    ]
    shim_transfers += shim_ofm_transfers

    if dims.aie_cols == 8:
        overlay = overlay_8x4_dma_connections()
    else:
        overlay = overlay_4x4_dma_connections()

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay,
        back_end=back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
    )