import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels'))
from typing import Tuple, List
import math
import struct
import numpy as np

from dmacompiler import (
    OverlayShape, BackEnd,
    DataTransfer, SyncStrategy,
    AieTile, TileType,
    AieDma, DmaDir, memtile_dma, shim_dma, core_dma, DmaChannel,
    CoreInstr, ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop,
    compute_buffer_size,
    TransferParams, generate_transfer_params,
    generate_shim_data_transfer,
    run_layer_compilation,
    set_dev_gen, DevGen, config
    )

from dataflow_common import (
    overlay_8x4_dma_connections,
    overlay_stack_addr,
    clean_overlay,
    build_sim_overlay,
    shim_alloc,
    prm_shim_memory,
    prm_shim_mm2s,
    prm_memtile_memory,
    prm_memtile_s2mm,
    prm_memtile_mm2s,
    ceildiv,
    iceil
    )

from slice_common import (
    SliceDims,
    slice_preproc_directives,
    make_slice_dict,
    split_cost
    )

# Select the Aie2p device generation in dmacompiler for everything compiled below.
set_dev_gen(DevGen.Aie2p)
# Turn on dmacompiler's busy-poll option (exact semantics are defined by the library).
config.ENABLE_BUSY_POLL = True

# NOTE(review): not referenced anywhere else in this module, and its unit
# (elements vs. bytes) is not established here -- confirm before relying on it.
MAX_SUBV_SIZE = 8192


def pack_transfers(
    dma: AieDma,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iters: List[int],
    bits_per_elem: int,
    buffer_offset: int = 0,
) -> TransferParams:
    """Combine several per-entry transfer descriptions into one TransferParams.

    ``generate_transfer_params`` is called once per ``(memory_fmts[i],
    tiling_fmts[i])`` pair, and only its index-0 length/offset/step/wrap/padding
    entries are kept. Each kept entry is then repeated ``tiling_iters[i]`` times,
    in input order, to form the packed parameter lists.

    Args:
        dma: DMA the transfer is programmed on. Padding is requested only when
            its direction is MM2S.
        memory_fmts: Buffer memory-layout format string for each entry.
        tiling_fmts: Access/tiling format string for each entry.
        tiling_iters: How many times each entry is repeated in the result.
        bits_per_elem: Element width, forwarded as ``bits_per_block``.
        buffer_offset: Forwarded unchanged to every ``generate_transfer_params``
            call (an int; defaults to 0).

    Returns:
        A ``TransferParams`` for ``dma`` holding the packed lists.
    """
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iters)

    def pack(items: list) -> list:
        # Expand items[i] into tiling_iters[i] consecutive copies.
        assert len(items) == len(tiling_iters)
        return [item for item, num in zip(items, tiling_iters) for _ in range(num)]

    # Same for every entry, so compute it once.
    enable_padding = dma.channel.dir == DmaDir.MM2S
    params = [
        generate_transfer_params(
            dma,
            memory_fmt,
            tiling_fmt,
            bits_per_block=bits_per_elem,
            enable_padding=enable_padding,
            buffer_offset=buffer_offset,
        )
        for memory_fmt, tiling_fmt in zip(memory_fmts, tiling_fmts)
    ]
    return TransferParams(
        dma,
        pack([param.length_i(0) for param in params]),
        offset=pack([param.offset_i(0) for param in params]),
        step=pack([param.step_i(0) for param in params]),
        wrap=pack([param.wrap_i(0) for param in params]),
        padding=pack([param.padding_i(0) for param in params]),
    )


def Yi_repeat_counts(dims: SliceDims, dims_phase) -> List[int]:
    """Return a repeat count of 1 for each of the ``dims.No * dims_phase`` iterations."""
    total_iterations = dims.No * dims_phase
    return [1] * total_iterations


def ifm_memtile_memory(dims: SliceDims, col: int, mt_ifm_transfer, dims_phase) -> List[str]:
    """Memtile IFM buffer memory format for column ``col``, one per iteration.

    The shape for each phase is read from
    ``mt_ifm_transfer['mt_ifm_mem'][col][phase]`` as ``(Yi, Xi, Ci)``. A falsy
    entry yields a 1x1x1 shape. The lookup is the same for kernel and
    non-kernel layers; the original had two identical branches on
    ``dims.is_kernel``.

    Returns:
        One format string per (No iteration, phase), phase varying fastest.
    """
    def fmt(phase: int) -> str:
        split = mt_ifm_transfer['mt_ifm_mem'][col][phase]
        if split:
            Yi_size, Xi_size, Ci_size = split[0], split[1], split[2]
        else:
            Yi_size = Xi_size = Ci_size = 1
        return f'Ni:{dims.Ni_gran} Yi:{Yi_size} Xi:{Xi_size} Ci:{Ci_size}'

    return [fmt(phase) for _ in range(dims.No) for phase in range(dims_phase)]


def ifm_memtile_s2mm(dims: SliceDims, col: int, mt_ifm_transfer, dims_phase) -> List[str]:
    """Memtile IFM S2MM access ranges for column ``col``, one per iteration.

    Ranges come from ``mt_ifm_transfer['mt_ifm_s2mm'][col][phase]`` as
    ``((Y_start, Y_stop), (X_start, X_stop), (C_start, C_stop))``. An empty
    column list or a falsy phase entry produces an all-zero range.
    Results are ordered by No iteration, then phase.
    """
    def phase_fmt(phase: int) -> str:
        entries = mt_ifm_transfer['mt_ifm_s2mm'][col]
        bounds = entries[phase] if entries else []
        if bounds:
            Y_start, Y_stop = bounds[0][0], bounds[0][1]
            X_start, X_stop = bounds[1][0], bounds[1][1]
            C_start, C_stop = bounds[2][0], bounds[2][1]
        else:
            Y_start = Y_stop = X_start = X_stop = C_start = C_stop = 0
        return f'Ni:0:{dims.Ni_gran} Yi:{Y_start}:{Y_stop} Xi:{X_start}:{X_stop} Ci:{C_start}:{C_stop}'

    formats = []
    for _ in range(dims.No):
        for phase in range(dims_phase):
            formats.append(phase_fmt(phase))
    return formats


def ifm_memtile_mm2s(dims: SliceDims, col: int, row: int, mt_ifm_transfer, dims_phase) -> List[str]:
    """Memtile IFM MM2S access ranges for column ``col``, one per iteration.

    For kernel layers the split table is indexed ``[col][row][phase]``;
    otherwise it is indexed ``[col][phase]`` and ``row`` is unused. For a
    non-kernel layer with ``dims.enable_padding`` on axis 3, the channel stop
    is replaced by ``dims.Cop``. An empty column list or falsy entry gives an
    all-zero range. Results are ordered by No iteration, then phase.
    """
    def phase_fmt(phase: int) -> str:
        kernel_layout = dims.is_kernel
        entries = mt_ifm_transfer['mt_ifm_mm2s'][col]
        if not entries:
            bounds = []
        elif kernel_layout:
            bounds = entries[row][phase]
        else:
            bounds = entries[phase]

        if not bounds:
            Y_start = Y_stop = X_start = X_stop = C_start = C_stop = 0
        else:
            Y_start, Y_stop = bounds[0][0], bounds[0][1]
            X_start, X_stop = bounds[1][0], bounds[1][1]
            C_start = bounds[2][0]
            extend_to_padded = dims.enable_padding and dims.axis == 3 and not dims.is_kernel
            C_stop = dims.Cop if extend_to_padded else bounds[2][1]
        return f'Ni:0:{dims.Ni_gran} Yi:{Y_start}:{Y_stop} Xi:{X_start}:{X_stop} Ci:{C_start}:{C_stop}'

    formats = []
    for _ in range(dims.No):
        formats.extend(phase_fmt(phase) for phase in range(dims_phase))
    return formats


def ofm_memtile_memory(dims: SliceDims, col: int, mt_ofm_transfer, dims_phase) -> List[str]:
    """Memtile OFM buffer memory format for column ``col``, one per iteration.

    Each phase's shape comes from ``mt_ofm_transfer['mt_ofm_mem'][col][phase]``
    as ``(Yi, Xi, Ci)``. When ``dims.kernel_padding`` is set on axis 3, the
    channel extent is replaced by ``dims.Com``. A falsy entry gives a 1x1x1
    shape. Results are ordered by No iteration, then phase.
    """
    def shape_fmt(phase: int) -> str:
        shape = mt_ofm_transfer['mt_ofm_mem'][col][phase]
        if not shape:
            Yi_size = Xi_size = Ci_size = 1
        else:
            Yi_size = shape[0]
            Xi_size = shape[1]
            use_padded_c = dims.kernel_padding and dims.axis == 3
            Ci_size = dims.Com if use_padded_c else shape[2]
        return f'Ni:{dims.Ni_gran} Yi:{Yi_size} Xi:{Xi_size} Ci:{Ci_size}'

    formats = []
    for _ in range(dims.No):
        formats.extend(shape_fmt(phase) for phase in range(dims_phase))
    return formats


def ofm_memtile_s2mm(dims: SliceDims, col: int, row: int, mt_ofm_transfer, dims_phase) -> List[str]:
    """Memtile OFM S2MM access ranges for core row ``row`` of column ``col``.

    Ranges come from ``mt_ofm_transfer['mt_ofm_s2mm'][col][row][phase]``. When
    ``dims.kernel_padding`` is set on axis 3, the channel stop becomes
    ``dims.Cop``. An empty column list or falsy entry gives an all-zero range.
    Results are ordered by No iteration, then phase.
    """
    def phase_fmt(phase: int) -> str:
        col_entries = mt_ofm_transfer['mt_ofm_s2mm'][col]
        bounds = col_entries[row][phase] if col_entries else []
        if not bounds:
            Y_start = Y_stop = X_start = X_stop = C_start = C_stop = 0
        else:
            Y_start, Y_stop = bounds[0][0], bounds[0][1]
            X_start, X_stop = bounds[1][0], bounds[1][1]
            C_start = bounds[2][0]
            padded = dims.kernel_padding and dims.axis == 3
            C_stop = dims.Cop if padded else bounds[2][1]
        return f'Ni:0:{dims.Ni_gran} Yi:{Y_start}:{Y_stop} Xi:{X_start}:{X_stop} Ci:{C_start}:{C_stop}'

    formats = []
    for _ in range(dims.No):
        for phase in range(dims_phase):
            formats.append(phase_fmt(phase))
    return formats


def ofm_memtile_mm2s(dims: SliceDims, col: int, mt_ofm_transfer, dims_phase) -> List[str]:
    """Memtile OFM MM2S access ranges for column ``col``, one per iteration.

    Ranges come from ``mt_ofm_transfer['mt_ofm_mm2s'][col][phase]``. If either
    ``dims.enable_padding`` or ``dims.kernel_padding`` is set on axis 3, the
    channel stop becomes ``dims.Cop``. Empty/falsy entries give an all-zero
    range. Results are ordered by No iteration, then phase.
    """
    def phase_fmt(phase: int) -> str:
        col_entries = mt_ofm_transfer['mt_ofm_mm2s'][col]
        bounds = col_entries[phase] if col_entries else []
        if not bounds:
            Y_start = Y_stop = X_start = X_stop = C_start = C_stop = 0
        else:
            Y_start, Y_stop = bounds[0][0], bounds[0][1]
            X_start, X_stop = bounds[1][0], bounds[1][1]
            C_start = bounds[2][0]
            any_padding = dims.enable_padding or dims.kernel_padding
            C_stop = dims.Cop if any_padding and dims.axis == 3 else bounds[2][1]
        return f'Ni:0:{dims.Ni_gran} Yi:{Y_start}:{Y_stop} Xi:{X_start}:{X_stop} Ci:{C_start}:{C_stop}'

    formats = []
    for _ in range(dims.No):
        formats.extend(phase_fmt(phase) for phase in range(dims_phase))
    return formats


def Ni_slice_stride(dims: SliceDims, n: int, Ni_slice_start: int = 0, Ni_slice_stop: int = 0) -> Tuple[int, int, int]:
    """Return ``(start, stop, size)`` of the n-th ``Ni_gran``-wide chunk along N.

    The chunk begins at ``Ni_slice_start + n * dims.Ni_gran``. A
    ``Ni_slice_stop`` of 0 means "use the default limit", which evaluates to
    ``dims.No``. If the chunk starts past the limit, ``stop`` equals ``start``.
    ``size`` is the part of ``[start, stop)`` that falls inside
    ``[0, limit]``, so it is 0 for chunks outside the slice.
    """
    limit = Ni_slice_stop
    if limit == 0:
        limit = dims.Ni if dims.Ni == dims.No else dims.No
    chunk = dims.Ni_gran
    first = Ni_slice_start + n * chunk
    if first <= limit:
        last = first + chunk
    else:
        last = first

    def clamp(value):
        return max(0, min(value, limit))

    return (first, last, clamp(last) - clamp(first))


def ifm_shim_repeat_counts(dims: SliceDims, dims_phase) -> List[int]:
    """One repeat count of 1 for every (No iteration, phase) pair."""
    return [1] * (dims.No * dims_phase)


def ifm_shim_memory(dims: SliceDims) -> str:
    """Memory format of the whole input tensor in the shim buffer (Ni, Yi, Xi, Ci)."""
    ni, yi, xi, ci = dims.Ni, dims.Yi, dims.Xi, dims.Ci
    return f'Ni:{ni} Yi:{yi} Xi:{xi} Ci:{ci}'


def ofm_shim_memory(dims: SliceDims) -> str:
    """Memory format of the whole output tensor in the shim buffer.

    The channel extent uses ``dims.Cop`` rather than ``dims.Co``.
    """
    channels = dims.Cop
    batch, rows, cols = dims.No, dims.Yo, dims.Xo
    return f'No:{batch} Yo:{rows} Xo:{cols} Co:{channels}'


def ifm_shim_mm2s(dims: SliceDims, col: int, shim_transfer, dims_phase) -> List[str]:
    """Shim IFM MM2S access ranges for column ``col``, one per iteration.

    Spatial/channel ranges come from ``shim_transfer['shim_ifm'][col][phase]``.
    The N range for iteration ``nn`` comes from ``Ni_slice_stride`` using
    ``dims.Ni_slice_start`` / ``dims.Ni_slice_stop``. A column with no entries
    yields an all-zero range, N included. Results are ordered by No iteration
    ``nn``, then phase.
    """
    def phase_fmt(phase: int, nn: int) -> str:
        col_entries = shim_transfer['shim_ifm'][col]
        if not col_entries:
            N_start = N_stop = 0
            Y_start = Y_stop = X_start = X_stop = C_start = C_stop = 0
        else:
            bounds = col_entries[phase]
            N_start, N_stop, _ = Ni_slice_stride(dims, nn, dims.Ni_slice_start, dims.Ni_slice_stop)
            Y_start, Y_stop = bounds[0][0], bounds[0][1]
            X_start, X_stop = bounds[1][0], bounds[1][1]
            C_start, C_stop = bounds[2][0], bounds[2][1]

        return (
            f'Ni:{N_start}:{N_stop} '
            f'Yi:{Y_start}:{Y_stop} '
            f'Xi:{X_start}:{X_stop} '
            f'Ci:{C_start}:{C_stop}'
        )

    formats = []
    for nn in range(dims.No):
        formats.extend(phase_fmt(phase, nn) for phase in range(dims_phase))
    return formats


def ofm_shim_s2mm(dims: SliceDims, col: int, shim_transfer, dims_phase) -> List[str]:
    """Shim OFM S2MM access ranges for column ``col``, one per iteration.

    Spatial/channel ranges come from ``shim_transfer['shim_ofm'][col][phase]``.
    The N range for iteration ``nn`` comes from
    ``Ni_slice_stride(dims, nn, 0, dims.No)``. If either
    ``dims.enable_padding`` or ``dims.kernel_padding`` is set on axis 3, the
    channel stop becomes ``dims.Cop``.

    A column with no entries, or a falsy phase entry, yields an all-zero
    range. Previously this indexed an empty list and raised IndexError; the
    new behavior matches the fallback in ``ifm_shim_mm2s``.

    Returns:
        One format string per (No iteration, phase), phase varying fastest.
    """
    def fmt(phase: int, nn: int) -> str:
        col_splits = shim_transfer['shim_ofm'][col]
        split = col_splits[phase] if col_splits else []
        if split:
            N_start, N_stop, _ = Ni_slice_stride(dims, nn, 0, dims.No)
            Y_start, Y_stop = split[0][0], split[0][1]
            X_start, X_stop = split[1][0], split[1][1]
            C_start = split[2][0]
            pad_channels = (dims.enable_padding or dims.kernel_padding) and dims.axis == 3
            C_stop = dims.Cop if pad_channels else split[2][1]
        else:
            # No OFM work for this column/phase: emit an empty range.
            N_start = N_stop = 0
            Y_start = Y_stop = X_start = X_stop = C_start = C_stop = 0
        return f'No:{N_start}:{N_stop} Yo:{Y_start}:{Y_stop} Xo:{X_start}:{X_stop} Co:{C_start}:{C_stop}'

    return [fmt(phase, nn) for nn in range(dims.No) for phase in range(dims_phase)]


def wgt_shim_memory(dims: SliceDims) -> str:
    """Shim-side memory format for the weight buffer: a single flat ``Subv`` axis."""
    subv_size = dims.wgt_subv_size
    return f'Subv:{subv_size}'


def wgt_shim_mm2s(dims: SliceDims) -> str:
    """Return the shim MM2S access pattern for the weight buffer.

    The pattern is the bare ``Subv`` axis, i.e. the whole ``Subv`` extent
    declared by ``wgt_shim_memory``. The argument is ignored. Note that
    ``compile_dataflow`` passes a column index here, not a ``SliceDims``.
    """
    return 'Subv'


def wgt_memtile_memory(dims: SliceDims) -> str:
    """Memtile memory format for the weight buffer: one flat ``Subv`` axis of ``wgt_subv_size``."""
    extent = dims.wgt_subv_size
    return f'Subv:{extent}'


def wgt_memtile_s2mm(dims: SliceDims) -> str:
    """Return the memtile S2MM access pattern for the weight buffer.

    The pattern is the bare ``Subv`` axis, i.e. the whole ``Subv`` extent
    declared by ``wgt_memtile_memory``. ``dims`` is not used.
    """
    return 'Subv'


def wgt_memtile_mm2s(dims: SliceDims) -> str:
    """Return the memtile MM2S access pattern for the weight buffer.

    The pattern is the bare ``Subv`` axis, i.e. the whole ``Subv`` extent
    declared by ``wgt_memtile_memory``. ``dims`` is not used.
    """
    return 'Subv'

def align_core_addr(core_addr: int):
    """Round ``core_addr`` up to the next multiple of 64 using ``ceildiv``."""
    alignment = 64
    return alignment * ceildiv(core_addr, alignment)


def gen_qdq_params_a8(subv_elems: int, index: int, quant_offset: int, scratch_buffer_addr: int, fixed_point_bits: int, qdq_mode: int, output_addr: int):
    """Serialize the parameter block for the 8-bit-path combined QDQ kernel.

    Layout: nine little-endian unsigned 16-bit words, in this order:
    ``subv_elems``, working-buffer address, ``is_int16`` flag,
    dequant zp/scale element indices (``2*index``, ``2*index + 1``),
    quant zp/scale element indices (``quant_offset``, ``quant_offset + 1``),
    then the dequant/quant enable indices (``quant_offset + 2``, ``+ 3``).
    The "zp"/"scale" naming follows the local variable names.

    The working-buffer address is ``scratch_buffer_addr`` only when
    ``qdq_mode == 2`` (BOTH, per the legend in ``main``) and the data is
    8-bit. Otherwise it is ``output_addr``.

    Raises:
        OverflowError: if any field is negative or does not fit in 16 bits.
    """
    is_int16 = 0 if fixed_point_bits == 8 else 1

    dq_zp_elem_idx = 2 * index
    dq_sc_elem_idx = dq_zp_elem_idx + 1

    q_zp_elem_idx = quant_offset
    q_sc_elem_idx = quant_offset + 1

    dq_enable_idx = quant_offset + 2
    q_enable_idx = quant_offset + 3

    use_scratch = qdq_mode == 2 and not is_int16
    scratch_buffer_addr_core = scratch_buffer_addr if use_scratch else output_addr

    fields = (
        subv_elems,
        scratch_buffer_addr_core,
        is_int16,
        dq_zp_elem_idx,
        dq_sc_elem_idx,
        q_zp_elem_idx,
        q_sc_elem_idx,
        dq_enable_idx,
        q_enable_idx,
    )
    return b''.join(
        value.to_bytes(length=2, byteorder='little', signed=False) for value in fields
    )


def gen_qdq_params(subv_elems: int, qdq_prm_addr: int, index: int, quant_offset: int):
    """Serialize the parameter block for the combined QDQ kernel (non-8-bit path).

    Layout: seven little-endian unsigned 16-bit words: ``subv_elems``, the
    dequant zp/scale element indices (``2*index``, ``2*index + 1``), the quant
    zp/scale indices (``quant_offset``, ``quant_offset + 1``), then the
    dequant/quant enable indices (``quant_offset + 2``, ``+ 3``).
    ``qdq_prm_addr`` is accepted but not written into the block.
    """
    words = [
        subv_elems,
        2 * index,
        2 * index + 1,
        quant_offset,
        quant_offset + 1,
        quant_offset + 2,
        quant_offset + 3,
    ]
    packed = bytearray()
    for word in words:
        packed += word.to_bytes(length=2, byteorder='little', signed=False)
    return bytes(packed)


def _slice_range_struct(dims: SliceDims) -> bytes:
    """Pack the slice-range fields shared by both range-kernel parameter blocks.

    Seven fields are packed with ``'<IIHhHHI'`` (20 bytes): loop_s1, startC,
    num_s1, inc_s1, inc_O1, size_1, mask_1. All are 0 when ``dims.innerC`` is
    falsy. The explicit ``<`` prefix makes the layout independent of the host's
    byte order, matching the little-endian header words written with
    ``int.to_bytes``. On little-endian hosts the bytes are the same as the
    native ``'IIHhHHI'`` format, which needs no padding.
    """
    def ceil(n, d=1):
        # Round n up to a multiple of d (numpy float result).
        if d == 1:
            return np.ceil(n)
        return d * np.ceil(n / d)

    if not dims.innerC:
        return struct.pack('<IIHhHHI', 0, 0, 0, 0, 0, 0, 0)

    output_dtype = dims.ofm_bits // 8  # bytes per output element
    # Round innerC up to a 4-byte boundary (4 // output_dtype elements) only when
    # rows need alignment; otherwise the multiple is 1 (no rounding).
    C_align0 = int(ceil(dims.innerC, max(1, 4 // output_dtype * (dims.row_alignment > 1))))
    loop_s1 = int(ceil((C_align0 * dims.Xis * dims.Yis * dims.Ni_gran) / max(1, 4 / output_dtype)))
    num_s1 = C_align0
    startC = int(dims.startC * output_dtype)
    inc_s1 = int((dims.Cis - C_align0 + 1) * output_dtype)
    row_C = ceil(dims.innerC, dims.row_alignment)
    inc_O1 = int((row_C - C_align0) * output_dtype + 4)
    size_1 = int(ceil(row_C * dims.Xis * dims.Yis * dims.Ni_gran * output_dtype / 64))
    # Keep only the bytes of the valid channels within the last 32-bit word.
    mask_1 = int(0xFFFFFFFF >> ((C_align0 - dims.innerC) * 8 * output_dtype))
    return struct.pack('<IIHhHHI', loop_s1, startC, num_s1, inc_s1, inc_O1, size_1, mask_1)


def kernel_params_range_a8(dims: SliceDims, scratch_buffer_addr: int):
    """Serialize the parameter block for the 8-bit slice kernel.

    Layout: four little-endian u32 words, in order ``scratch_buffer_addr``,
    ``subv_elems``, input bytes/element and output bytes/element. These are
    followed by the 20-byte range struct from ``_slice_range_struct``.
    ``subv_elems`` is ``Cis * Xis * Yis * Ni_gran`` passed through
    ``iceil(..., 64)``, presumably rounded up to a multiple of 64.

    Raises:
        OverflowError: if a header word is negative or exceeds 32 bits.
    """
    subv_elems = int(iceil(dims.Cis * dims.Xis * dims.Yis * dims.Ni_gran, 64))
    input_dtype = dims.ifm_bits // 8
    output_dtype = dims.ofm_bits // 8
    header = b''.join(
        word.to_bytes(length=4, byteorder='little', signed=False)
        for word in (scratch_buffer_addr, subv_elems, input_dtype, output_dtype)
    )
    return header + _slice_range_struct(dims)


def kernel_params_range(dims: SliceDims):
    """Serialize the parameter block for the (non-8-bit) slice kernel.

    Layout: one little-endian u32 ``subv_elems`` word, followed by the 20-byte
    range struct from ``_slice_range_struct``. ``subv_elems`` is
    ``Cis * Xis * Yis * Ni_gran`` passed through ``iceil(..., 64)``.
    """
    subv_elems = int(iceil(dims.Cis * dims.Xis * dims.Yis * dims.Ni_gran, 64))
    header = subv_elems.to_bytes(length=4, byteorder='little', signed=False)
    return header + _slice_range_struct(dims)

def _core_program(dims: SliceDims, dims_phase: int) -> list:
    """Lay out core-local buffers and build the core instruction list (kernel layers only).

    Core memory layout (each start rounded by ``align_core_addr``):
    ``[wgt | ifm ping | ifm pong (if it fits) | scratch | ofm]``.

    Raises:
        ValueError: if ``dims.fixed_point_bits`` is neither 8 nor 16.
        AssertionError: if the buffers do not fit below ``overlay_stack_addr()``.
    """
    if dims.fixed_point_bits == 8:
        run_kernel = 'run_slice_a8' if dims.innerC else 'run_combined_qdq_a8'
    elif dims.fixed_point_bits == 16:
        run_kernel = 'run_slice' if dims.innerC else 'run_combined_qdq'
    else:
        # Previously this left run_kernel unbound and failed later with a NameError.
        raise ValueError(f'unsupported fixed_point_bits: {dims.fixed_point_bits}')

    subv_elems_aligned = iceil(dims.Yis * dims.Xis * dims.Cis, 64)
    CoreIfmSize = dims.Yis * dims.Xis * dims.Cis * dims.ifm_bits // 8
    CoreIfmSubvElemsSize = subv_elems_aligned * dims.ifm_bits // 8
    CoreScratchBufferSize = subv_elems_aligned * 2 if dims.has_scratch_buf else 0

    # Slicing is C-wise, so the OFM subvolume keeps the IFM spatial extent
    # (Yis, Xis); only the channel count differs (Com when padded, else Co).
    Co_core = dims.Com if dims.kernel_padding else dims.Co
    CoreOfmSize = dims.Yis * dims.Xis * Co_core * dims.ofm_bits // 8
    CoreWgtSize = dims.wgt_subv_size

    stack_addr = overlay_stack_addr()
    ifm_pingpong = (
        CoreIfmSubvElemsSize * 2 + CoreScratchBufferSize + CoreOfmSize + CoreWgtSize
        <= stack_addr
    )

    CoreWgtPingAddr = 0
    CoreWgtPongAddr = None
    CoreIfmPingAddr = align_core_addr(CoreWgtPingAddr + CoreWgtSize)
    if ifm_pingpong:
        CoreIfmPongAddr = align_core_addr(CoreIfmPingAddr + CoreIfmSubvElemsSize)
        last_ifm_addr = CoreIfmPongAddr
    else:
        CoreIfmPongAddr = None
        last_ifm_addr = CoreIfmPingAddr
    CoreScratchBufferPingAddr = align_core_addr(last_ifm_addr + CoreIfmSubvElemsSize)
    CoreOfmPingAddr = align_core_addr(CoreScratchBufferPingAddr + CoreScratchBufferSize)
    CoreOfmPongAddr = None
    assert CoreOfmPingAddr + CoreOfmSize <= stack_addr

    if dims.innerC:
        if dims.fixed_point_bits == 8:
            kernel_params = kernel_params_range_a8(dims, CoreScratchBufferPingAddr)
        else:
            kernel_params = kernel_params_range(dims)
    elif dims.is_qdq:
        if dims.fixed_point_bits == 8:
            kernel_params = gen_qdq_params_a8(
                iceil(dims.subv_elem, 64), 0, 2, CoreScratchBufferPingAddr,
                dims.fixed_point_bits, dims.qdq_mode, CoreOfmPingAddr)
        else:
            kernel_params = gen_qdq_params(iceil(dims.subv_elem, 64), CoreWgtPingAddr, 0, 2)
    else:
        kernel_params = []

    # A core loop configuration supports at most 1024 iterations; beyond that
    # the core must be reconfigured even if the configuration is unchanged.
    # Halve the inner count until it fits, repeat it Tx times, and run a tail
    # loop of T_remain iterations when Tn is not a multiple of X.
    Tn = dims.No * dims_phase
    X = Tn
    while X > 1024:
        X //= 2
    Tx, T_remain = divmod(Tn, X)

    def ifm_ofm_loop(outer: int, inner: int) -> Loop:
        # `outer` times: configure the IFM/OFM core buffers, then run the kernel `inner` times.
        return Loop(outer, [
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize),
            ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOfmPingAddr, CoreOfmPongAddr, CoreOfmSize),
            Loop(inner, [
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                CallKernel(run_kernel, kernel_params),
                RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            ]),
        ])

    core_instrs = [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreWgtPingAddr, CoreWgtPongAddr, CoreWgtSize),
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        ifm_ofm_loop(Tx, X),
    ]
    if T_remain:
        core_instrs.append(ifm_ofm_loop(1, T_remain))
    core_instrs.append(RelBuffer(DmaChannel(DmaDir.S2MM, 1)))
    return core_instrs


def compile_dataflow(
    dims: SliceDims,
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
):
    """Generate the core, memtile and shim programs for a slice layer and compile them.

    The per-column/per-phase transfer plan comes from ``split_cost(dims)``.
    When ``dims.is_kernel`` is false, no core program, weight transfers or
    OFM memtile transfer are emitted. In that case the IFM memtile buffer is
    read out on memtile MM2S channel 5.

    Args:
        dims: Slice layer configuration.
        back_end: Backend passed to ``run_layer_compilation``.
        kernel_names: Kernel entry points passed to ``run_layer_compilation``.
        kernel_includes: Kernel sources/headers passed to ``run_layer_compilation``.

    Raises:
        ValueError: for a kernel layer whose ``fixed_point_bits`` is not 8 or 16.
        AssertionError: if the memtile or core buffers do not fit.
    """
    shim_transfer, mt_ifm_transfer, mt_ofm_transfer = split_cost(dims)
    # Number of phases per No iteration, taken from column 0's shim IFM plan.
    dims_phase = len(shim_transfer['shim_ifm'][0])

    # ---- Memtile buffer sizes (bytes) ----
    prm_memtile_size = compute_buffer_size(prm_memtile_memory(dims))
    ifm_mem_tiling_max = mt_ifm_transfer['mt_ifm_mem'][0][0]
    ifm_memtile_size = math.prod(ifm_mem_tiling_max) * dims.ifm_bits // 8
    if dims.is_kernel:
        ofm_mem_tiling_max = mt_ofm_transfer['mt_ofm_mem'][0][0]
        if dims.axis == 3 and dims.kernel_padding:
            # Channel extent is padded to Com, matching ofm_memtile_memory.
            ofm_mem_tiling_max = (ofm_mem_tiling_max[0], ofm_mem_tiling_max[1], dims.Com)
        ofm_memtile_size = math.prod(ofm_mem_tiling_max) * dims.ofm_bits // 8
        wgt_memtile_size = dims.wgt_subv_size
    else:
        ofm_memtile_size = 0  # no OFM memtile buffer for non-kernel layers
        wgt_memtile_size = 0

    # ---- Memtile address map ----
    max_memtile_size = config.MAX_MEMTILE_ADDR + 1
    prm_memtile_addr = 0
    if dims.is_kernel:
        # Layout: [prm | wgt | ifm ping | ifm pong (if it fits) | ofm].
        wgt_memtile_addr_ping = prm_memtile_addr + prm_memtile_size
        ifm_memtile_addr_ping = wgt_memtile_addr_ping + wgt_memtile_size
        pingpong_footprint = (
            prm_memtile_size
            + wgt_memtile_size  # the weight buffer sits in front of the IFM buffers
            + ifm_memtile_size * 2
            + ofm_memtile_size
        )
        if pingpong_footprint <= max_memtile_size:
            ifm_memtile_addr_pong = ifm_memtile_addr_ping + ifm_memtile_size
            ifm_memtile_addrs = [ifm_memtile_addr_ping, ifm_memtile_addr_pong]
            ofm_memtile_addr_ping = ifm_memtile_addr_pong + ifm_memtile_size
        else:
            ifm_memtile_addrs = [ifm_memtile_addr_ping]
            ofm_memtile_addr_ping = ifm_memtile_addr_ping + ifm_memtile_size
        ofm_memtile_addrs = [ofm_memtile_addr_ping]
        assert ofm_memtile_addr_ping + ofm_memtile_size <= max_memtile_size
    else:
        # Layout: [prm | ifm ping | ifm pong].
        wgt_memtile_addr_ping = None
        ifm_memtile_addr_ping = prm_memtile_addr + prm_memtile_size
        ifm_memtile_addr_pong = ifm_memtile_addr_ping + ifm_memtile_size
        assert ifm_memtile_addr_pong + ifm_memtile_size <= max_memtile_size
        ifm_memtile_addrs = [ifm_memtile_addr_ping, ifm_memtile_addr_pong]
        ofm_memtile_addrs = []

    # Shim buffer allocation
    slice_shim_alloc = shim_alloc()

    # ---- Core program ----
    if dims.is_kernel:
        core_instrs = _core_program(dims, dims_phase)
    else:
        # NOTE: This was a workaround for a super-kernel bug.
        #       Keeping this for future reference.
        # core_instrs = [
        #     ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), 0, 0, 0),
        #     AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
        #     RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
        #     ]
        core_instrs = []

    # Weights are loaded on every other column when the array is 8 columns wide.
    wgt_col_step = 2 if dims.aie_cols == 8 else 1

    # ---- Memtile transfers ----
    memtile_prm_transfer = [
        DataTransfer(
            [1] + [0] * (dims.No * dims_phase - 1),
            AieTile(TileType.Memtile, col), [prm_memtile_addr], prm_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(dims),
                prm_memtile_s2mm())],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(dims),
                prm_memtile_mm2s(row),
            ) for row in range(dims.aie_rows)],
            sync_strategy=SyncStrategy.Serial_M_to_N,
        ) for col in range(dims.aie_cols)
    ]

    if dims.is_kernel:
        memtile_ifm_transfer = [
            DataTransfer(
                Yi_repeat_counts(dims, dims_phase),
                AieTile(TileType.Memtile, col), ifm_memtile_addrs, ifm_memtile_size,
                [pack_transfers(
                    memtile_dma(col, DmaDir.S2MM, 0),
                    ifm_memtile_memory(dims, col, mt_ifm_transfer, dims_phase),
                    ifm_memtile_s2mm(dims, col, mt_ifm_transfer, dims_phase),
                    [1] * dims.No * dims_phase,
                    dims.ifm_bits)],
                [pack_transfers(
                    memtile_dma(col, DmaDir.MM2S, row),
                    ifm_memtile_memory(dims, col, mt_ifm_transfer, dims_phase),
                    ifm_memtile_mm2s(dims, col, row, mt_ifm_transfer, dims_phase),
                    [1] * dims.No * dims_phase,
                    dims.ifm_bits,
                ) for row in range(dims.aie_rows)],
                sync_strategy=SyncStrategy.Parallel_1_to_N,
            ) for col in range(dims.aie_cols)
        ]
    else:
        memtile_ifm_transfer = [
            DataTransfer(
                Yi_repeat_counts(dims, dims_phase),
                AieTile(TileType.Memtile, col), ifm_memtile_addrs, ifm_memtile_size,
                [pack_transfers(
                    memtile_dma(col, DmaDir.S2MM, 0),
                    ifm_memtile_memory(dims, col, mt_ifm_transfer, dims_phase),  # pylint: disable=E1121
                    ifm_memtile_s2mm(dims, col, mt_ifm_transfer, dims_phase),    # pylint: disable=E1121
                    [1] * dims.No * dims_phase,
                    dims.ifm_bits)],
                [pack_transfers(
                    memtile_dma(col, DmaDir.MM2S, 5),
                    ifm_memtile_memory(dims, col, mt_ifm_transfer, dims_phase),   # pylint: disable=E1121
                    ifm_memtile_mm2s(dims, col, 0, mt_ifm_transfer, dims_phase),  # pylint: disable=E1121
                    [1] * dims.No * dims_phase,
                    dims.ofm_bits)],
            ) for col in range(dims.aie_cols)
        ]

    memtile_wgt_transfer = [
        DataTransfer(
            [1] + [0] * (dims.No * dims_phase - 1),
            AieTile(TileType.Memtile, col), [wgt_memtile_addr_ping], wgt_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 1),
                wgt_memtile_memory(dims),
                wgt_memtile_s2mm(dims))],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, 4),
                wgt_memtile_memory(dims),
                wgt_memtile_mm2s(dims))],
        ) for col in range(0, dims.aie_cols, wgt_col_step)
    ] if dims.is_kernel else []

    memtile_ofm_transfer = [
        DataTransfer(
            Yi_repeat_counts(dims, dims_phase),
            # Sized with the OFM buffer size that was allocated above
            # (previously ifm_memtile_size was passed here by mistake).
            AieTile(TileType.Memtile, col), ofm_memtile_addrs, ofm_memtile_size,
            [pack_transfers(
                memtile_dma(col, DmaDir.S2MM, 2 + row),
                ofm_memtile_memory(dims, col, mt_ofm_transfer, dims_phase),
                ofm_memtile_s2mm(dims, col, row, mt_ofm_transfer, dims_phase),
                [1] * dims.No * dims_phase,
                dims.ofm_bits,
            ) for row in range(dims.aie_rows)],
            [pack_transfers(
                memtile_dma(col, DmaDir.MM2S, 5),
                ofm_memtile_memory(dims, col, mt_ofm_transfer, dims_phase),
                ofm_memtile_mm2s(dims, col, mt_ofm_transfer, dims_phase),
                [1] * dims.No * dims_phase,
                dims.ofm_bits,
            )],
            sync_strategy=SyncStrategy.Parallel_N_to_1,
        ) for col in range(dims.aie_cols)
    ] if dims.is_kernel else []

    memtile_transfers = (
        memtile_prm_transfer + memtile_ifm_transfer + memtile_wgt_transfer + memtile_ofm_transfer
    )

    # ---- Shim transfers ----
    shim_prm_transfer = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.No * dims_phase - 1),
            shim_dma(col, DmaDir.MM2S, 0), slice_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col)
        ) for col in range(dims.aie_cols)
    ]

    shim_ifm_size = dims.Ni * dims.Yi * dims.Xi * dims.Ci * (dims.ifm_bits // 8)
    shim_ifm_transfer = [
        DataTransfer(
            ifm_shim_repeat_counts(dims, dims_phase),
            AieTile(TileType.Shim, col), [slice_shim_alloc.ifm_buffer_id], shim_ifm_size,
            [],
            [pack_transfers(
                shim_dma(col, DmaDir.MM2S, 0),
                [ifm_shim_memory(dims)] * dims.No * dims_phase,
                ifm_shim_mm2s(dims, col, shim_transfer, dims_phase),
                [1] * dims.No * dims_phase,
                dims.ifm_bits,
            )],
        ) for col in range(dims.aie_cols)
    ]

    shim_wgt_transfer = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.No * dims_phase - 1),
            shim_dma(col, DmaDir.MM2S, 1), slice_shim_alloc.wgt_buffer_id,
            wgt_shim_memory(dims),
            wgt_shim_mm2s(col)
        ) for col in range(0, dims.aie_cols, wgt_col_step)
    ] if dims.is_kernel else []

    # NOTE(review): ofm_shim_memory declares the channel extent as dims.Cop,
    # while this size uses dims.Co -- confirm they agree when padding is enabled.
    shim_ofm_size = dims.No * dims.Yo * dims.Xo * dims.Co * dims.ofm_bits // 8
    shim_ofm_transfer = [
        DataTransfer(
            ifm_shim_repeat_counts(dims, dims_phase),
            AieTile(TileType.Shim, col), [slice_shim_alloc.ofm_buffer_id], shim_ofm_size,
            [pack_transfers(
                shim_dma(col, DmaDir.S2MM, 0),
                [ofm_shim_memory(dims)] * dims.No * dims_phase,
                ofm_shim_s2mm(dims, col, shim_transfer, dims_phase),
                [1] * dims.No * dims_phase,
                dims.ofm_bits
            )],
            [],
        ) for col in range(dims.aie_cols)
    ]

    shim_transfers = shim_prm_transfer + shim_ifm_transfer + shim_wgt_transfer + shim_ofm_transfer

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_8x4_dma_connections(),
        back_end=back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
    )


def main():
    """Compile and build a sample slice layer on an 8x4 array.

    The sample takes a 1x1x64x64 input with 16-bit elements, slices axis 3
    from 1 to 64, and has QDQ disabled.
    """
    back_end = BackEnd.Adf
    aie_cols, aie_rows = 8, 4

    # Slice configuration
    input_shape = [1, 1, 64, 64]
    axis = 3
    out_start, out_stop = 1, 64
    ifm_bits = 16
    is_qdq = 0
    qdq_mode = 0    # 0: DEQUANT; 1: QUANT; 2: BOTH; 3: NONE

    slice_dict = make_slice_dict(input_shape, axis, out_start, out_stop)
    dims = SliceDims(
        aie_rows, aie_cols,
        input_shape, slice_dict,
        axis, ifm_bits,
        out_start, out_stop,
        is_qdq, qdq_mode,
        False,
    )

    kernel_names = ['run_combined_qdq', 'run_slice']
    kernel_includes = ['super.hh', 'qdq/wrapper_qdq.cc', 'slice/wrapper_slice.cc']

    clean_overlay()
    compile_dataflow(dims, back_end, kernel_names, kernel_includes)
    build_sim_overlay(back_end, 'slice_main.cpp', slice_preproc_directives(dims, back_end))

# Script entry point: running this file directly invokes main().
if __name__ == '__main__':
    main()
