import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
import math
import struct
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', '..'))
from typing import Dict, Tuple, List

from dmacompiler import \
    OverlayShape, BackEnd, \
    DataTransfer, SyncStrategy, \
    AieTile, TileType, \
    AieDma, DmaDir, TransferParams, core_dma, memtile_dma, shim_dma, DmaChannel,\
    CoreInstr, ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    compute_buffer_size, \
    generate_core_buffer_config, \
    generate_transfer_params, \
    generate_shim_data_transfer, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config

from dataflow_common import \
    overlay_8x4_dma_connections, \
    overlay_stack_addr, \
    clean_overlay, \
    build_sim_overlay, \
    ceildiv, \
    iceil, \
    shim_alloc, \
    prm_shim_memory, \
    prm_shim_mm2s, \
    prm_memtile_memory, \
    prm_memtile_s2mm, \
    prm_memtile_mm2s

from concat_common import \
    gen_transfers, \
    concat_preproc_directives
    
from concat_run_tiler import ConcatDims
# Applied at import time: select DevGen.Aie2p via dmacompiler's set_dev_gen
# and turn on config.ENABLE_BUSY_POLL. Both act on objects imported from
# dmacompiler; presumably they affect every later compilation in this
# process — confirm against dmacompiler.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True


def wgt_shim_memory(dims: ConcatDims) -> str:
    '''Return the shim memory format string for weights: 'Subv:<dims.wgt_subv_size>'.'''
    subv_size = dims.wgt_subv_size
    return 'Subv:{}'.format(subv_size)


def wgt_shim_mm2s() -> str:
    '''Return the shim MM2S tiling format string for weights ('Subv').'''
    tiling = 'Subv'
    return tiling


def wgt_memtile_memory(dims: ConcatDims) -> str:
    '''Return the memtile memory format string for weights: 'Subv:<dims.wgt_subv_size>'.'''
    subv_size = dims.wgt_subv_size
    return 'Subv:{}'.format(subv_size)


def wgt_memtile_s2mm() -> str:
    '''Return the memtile S2MM tiling format string for weights ('Subv').'''
    tiling = 'Subv'
    return tiling


def wgt_memtile_mm2s() -> str:
    '''Return the memtile MM2S tiling format string for weights ('Subv').'''
    tiling = 'Subv'
    return tiling




def Yi_repeat_counts(dims: ConcatDims, idx: int, scale: int) -> List[int]:
    '''
    Return a repeat count of 1 for each of the dims.phase * dims.num_inputs
    (phase, input) entries. `idx` and `scale` are accepted but not used.
    '''
    total_entries = dims.phase * dims.num_inputs
    return [1] * total_entries

def pack_transfers(
    dma: AieDma,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iters: List[int],
    bits_per_elem: int,
    buffer_offset: int = 0,
) -> TransferParams:
    '''
    Merge several (memory format, tiling format) pairs into one TransferParams.

    For each pair, generate_transfer_params is called and the first entry
    (index 0) of each resulting field is taken. Each field value is then
    repeated tiling_iters[i] times, so the packed params hold
    sum(tiling_iters) entries in input order.

    Args:
        dma: DMA the transfers are generated for. Padding is enabled only
            when its channel direction is MM2S.
        memory_fmts: memory-layout format string per entry.
        tiling_fmts: tiling format string per entry (same length).
        tiling_iters: repeat count per entry (same length).
        bits_per_elem: element width, forwarded as bits_per_block.
        buffer_offset: a single scalar offset forwarded unchanged to every
            generate_transfer_params call.

    Returns:
        A TransferParams for `dma` with the repeated length/offset/step/
        wrap/padding fields.
    '''
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iters)

    def pack(items: list) -> list:
        # Repeat items[i] tiling_iters[i] times, preserving order.
        assert len(items) == len(tiling_iters)
        return [item for item, num in zip(items, tiling_iters) for _ in range(num)]

    params = [
        generate_transfer_params(
            dma,
            memory_fmt,
            tiling_fmt,
            bits_per_block=bits_per_elem,
            enable_padding=(dma.channel.dir == DmaDir.MM2S),
            buffer_offset=buffer_offset,
        )
        for memory_fmt, tiling_fmt in zip(memory_fmts, tiling_fmts)
    ]
    return TransferParams(
        dma,
        pack([param.length_i(0) for param in params]),
        offset=pack([param.offset_i(0) for param in params]),
        step=pack([param.step_i(0) for param in params]),
        wrap=pack([param.wrap_i(0) for param in params]),
        padding=pack([param.padding_i(0) for param in params]),
    )

def generate_packed_shim_data_transfer(
    repeat_counts: List[int],
    dma: AieDma,
    shim_buffer_idx: int,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iter_nums: List[int],
    tiling_start_iter: List[int],
    bits_per_elem: int,
    buffer_offset: List[int],
    buffer_size: int,
    max_chain_length: int = 4,
) -> DataTransfer:
    '''
    Reconfigures a BD with different transfer
    params at the shim for poll and re-enqueue.

    For each (memory_fmts[i], tiling_fmts[i]) pair, generate_transfer_params
    is called with use_iter_step=True. It returns a repeat coefficient and a
    chain of transfers. The first entry of every field of every chained
    transfer is packed into a single TransferParams, each repeated
    tiling_iter_nums[k] times.

    Args:
        repeat_counts: per-entry repeat counts. Entry i is multiplied by the
            repeat coefficient of format i. The caller's list is NOT
            modified; a scaled copy is stored in the result.
        dma: shim DMA. Its direction selects write (S2MM) or read (MM2S)
            params.
        shim_buffer_idx: shim buffer index placed in the DataTransfer.
        memory_fmts / tiling_fmts: format strings per entry.
        tiling_iter_nums: repeat count per packed transfer.
        tiling_start_iter: only its length is used (must match the others).
        bits_per_elem: element width, forwarded as bits_per_block.
        buffer_offset: per-format offset forwarded to generate_transfer_params.
        buffer_size: buffer size recorded in the DataTransfer.
        max_chain_length: forwarded to generate_transfer_params.

    Returns:
        The DataTransfer describing the packed shim transfer.
    '''
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iter_nums)
    assert len(tiling_start_iter) == len(tiling_iter_nums)

    def pack(items: list) -> list:
        # Repeat items[k] tiling_iter_nums[k] times, preserving order.
        assert len(items) == len(tiling_iter_nums)
        res = []
        for item, num in zip(items, tiling_iter_nums):
            res += [item] * num
        return res

    params = []
    repeat_coeffs = []
    for i, (memory_fmt, tiling_fmt) in enumerate(zip(memory_fmts, tiling_fmts)):
        repeat_coeff, transfer_chain = generate_transfer_params(
                dma,
                memory_fmt,
                tiling_fmt,
                bits_per_block=bits_per_elem,
                enable_padding=False,
                use_iter_step=True,
                max_chain_length=max_chain_length,
                buffer_offset=buffer_offset[i],
        )
        repeat_coeffs.append(repeat_coeff)
        params.extend(transfer_chain)

    packed_params = TransferParams(
        dma,
        pack([param.length_i(0) for param in params]),
        offset=pack([param.offset_i(0) for param in params]),
        step=pack([param.step_i(0) for param in params]),
        wrap=pack([param.wrap_i(0) for param in params]),
        padding=pack([param.padding_i(0) for param in params]),
        iter_step=pack([param.iter_step_i(0) for param in params]),
        iter_wrap=pack([param.iter_wrap_i(0) for param in params]),
    )
    if dma.channel.dir == DmaDir.S2MM:
        write_params = [packed_params]
        read_params = []
    else:
        read_params = [packed_params]
        write_params = []

    # Scale a private copy: mutating the caller's list would compound the
    # scaling if the same list were passed to another call.
    scaled_repeat_counts = list(repeat_counts)
    for idx in range(len(tiling_start_iter)):
        scaled_repeat_counts[idx] *= repeat_coeffs[idx]

    return DataTransfer(
        scaled_repeat_counts,
        dma.tile, [shim_buffer_idx], buffer_size,
        write_params,
        read_params,
    )

def ifm_shim_memory(dims: ConcatDims) -> List[str]:
    '''
    Build the shim memory format string for every (phase, input) pair.

    Each entry is 'Yi:<input_rows[n]> Xi:<input_cols[n]> Ci:<input_chs_p[n]>'
    for input n. The list is ordered phase-major, then by input index, so it
    has dims.phase * dims.num_inputs entries. The string does not depend on
    the phase; the same per-input strings repeat for each phase.
    '''
    def fmt(n_index: int) -> str:
        return (
            f'Yi:{dims.input_rows[n_index]} '
            f'Xi:{dims.input_cols[n_index]} '
            f'Ci:{dims.input_chs_p[n_index]}'
        )
    return [fmt(n_index) for _ in range(dims.phase) for n_index in range(dims.num_inputs)]



def ifm_shim_mm2s(dims: ConcatDims, col: int, shim_transfer) -> List[str]:
    '''
    Build shim MM2S tiling strings for the IFM, one per (phase, input) pair,
    ordered phase-major, then by input index.

    shim_transfer['shim_ifm'][n_index][col] holds per-phase splits of
    ((Y_start, Y_stop), (X_start, X_stop), (C_start, C_stop)). When that
    column entry is empty, the range 'Yi:0:0 Xi:0:0 Ci:0:0' is emitted.
    '''
    def describe(phase: int, n_index: int) -> str:
        per_col = shim_transfer['shim_ifm'][n_index][col]
        if not per_col:
            return 'Yi:0:0 Xi:0:0 Ci:0:0'
        rng = per_col[phase]
        return (
            f'Yi:{rng[0][0]}:{rng[0][1]} '
            f'Xi:{rng[1][0]}:{rng[1][1]} '
            f'Ci:{rng[2][0]}:{rng[2][1]}'
        )

    result = []
    for phase in range(dims.phase):
        for n_index in range(dims.num_inputs):
            result.append(describe(phase, n_index))
    return result

def ifm_shim_repeat_counts(dims: ConcatDims, idx: int) -> List[int]:
    '''
    Return a repeat count of 1 for each of the dims.phase * dims.num_inputs
    IFM shim entries. `idx` is accepted but not used.
    '''
    total_entries = dims.phase * dims.num_inputs
    return [1] * total_entries

def ofm_shim_repeat_counts(dims: ConcatDims, idx: int) -> List[int]:
    '''
    Return OFM shim repeat counts: within each phase, every input slot gets 0
    except the last, which gets 1. The per-phase pattern is repeated
    dims.phase times. `idx` is accepted but not used.
    '''
    per_phase = [0] * (dims.num_inputs - 1)
    per_phase.append(1)
    return per_phase * dims.phase

def ifm_memtile_memory(dims: ConcatDims, col: int, mt_ifm_transfer) -> List[str]:
    '''
    Build memtile memory format strings for the IFM, one per (phase, input)
    pair, ordered phase-major, then by input index.

    The shape comes from mt_ifm_transfer['mt_ifm_mem']. It is indexed
    [n_index][col][phase] when dims.is_kernel, otherwise [col][phase]. An
    empty shape falls back to 'Yi:1 Xi:1 Ci:1'.
    '''
    def shape_of(phase: int, n_index: int):
        table = mt_ifm_transfer['mt_ifm_mem']
        return table[n_index][col][phase] if dims.is_kernel else table[col][phase]

    result = []
    for phase in range(dims.phase):
        for n_index in range(dims.num_inputs):
            shape = shape_of(phase, n_index)
            yi, xi, ci = (shape[0], shape[1], shape[2]) if shape else (1, 1, 1)
            result.append(f'Yi:{yi} Xi:{xi} Ci:{ci}')
    return result

def ifm_memtile_s2mm(dims: ConcatDims, col: int, mt_ifm_transfer) -> List[str]:
    '''
    Build memtile S2MM tiling strings for the IFM, one per (phase, input)
    pair, ordered phase-major, then by input index.

    Splits come from mt_ifm_transfer['mt_ifm_s2mm'][n_index][col][phase].
    An empty column entry or an empty split yields 'Yi:0:0 Xi:0:0 Ci:0:0'.
    '''
    def describe(phase: int, n_index: int) -> str:
        per_col = mt_ifm_transfer['mt_ifm_s2mm'][n_index][col]
        rng = per_col[phase] if per_col else []
        if not rng:
            return 'Yi:0:0 Xi:0:0 Ci:0:0'
        return (
            f'Yi:{rng[0][0]}:{rng[0][1]} '
            f'Xi:{rng[1][0]}:{rng[1][1]} '
            f'Ci:{rng[2][0]}:{rng[2][1]}'
        )

    return [describe(p, n) for p in range(dims.phase) for n in range(dims.num_inputs)]

def ifm_memtile_mm2s(dims: ConcatDims, col: int, row: int, mt_ifm_transfer) -> List[str]:
    '''
    Build memtile MM2S tiling strings for the IFM, one per (phase, input)
    pair, ordered phase-major, then by input index.

    Only the last input of each phase gets a real range. Every other entry,
    and any entry whose split is missing or empty, is
    'Yi:0:0 Xi:0:0 Ci:0:0'. The table mt_ifm_transfer['mt_ifm_mm2s'] is
    indexed per dataflow:
      * dims.is_kernel:              [n_index][col][row][phase]
      * dims.is_qdq (no kernel):     [col][row][phase]
      * neither (pure DMA dataflow): [col][phase]
    In the pure DMA dataflow with concat_mode == 0, the channel stop is
    replaced by dims.output_ch_p.
    '''
    is_dma_dataflow = not dims.is_qdq and not dims.is_kernel
    empty = 'Yi:0:0 Xi:0:0 Ci:0:0'

    def lookup_split(phase: int, n_index: int):
        # Returns [] when the selected column has no entries.
        table = mt_ifm_transfer['mt_ifm_mm2s']
        if dims.is_kernel:
            if table[n_index][col]:
                return table[n_index][col][row][phase]
        elif dims.is_qdq:
            if table[col]:
                return table[col][row][phase]
        elif table[col]:
            return table[col][phase]
        return []

    def fmt(phase: int, n_index: int) -> str:
        # The lookup runs for every entry (as before), so a malformed table
        # still raises even for inputs that end up with an empty range.
        split = lookup_split(phase, n_index)
        if n_index != dims.num_inputs - 1 or not split:
            return empty
        if is_dma_dataflow and dims.concat_mode == 0:
            c_stop = dims.output_ch_p
        else:
            c_stop = split[2][1]
        return (
            f'Yi:{split[0][0]}:{split[0][1]} '
            f'Xi:{split[1][0]}:{split[1][1]} '
            f'Ci:{split[2][0]}:{c_stop}'
        )

    return [fmt(phase, n_index) for phase in range(dims.phase) for n_index in range(dims.num_inputs)]

def ofm_shim_memory(dims: ConcatDims) -> str:
    '''Return the shim memory format string for the OFM: 'Yo:<rows> Xo:<cols> Co:<output_ch_p>'.'''
    rows, cols, chans = dims.output_row, dims.output_col, dims.output_ch_p
    return 'Yo:{} Xo:{} Co:{}'.format(rows, cols, chans)


def ofm_shim_s2mm(dims: ConcatDims, col: int, shim_transfer) -> List[str]:
    '''
    Build shim S2MM tiling strings for the OFM.

    Each phase's string is repeated dims.num_inputs times, so the list has
    dims.phase * dims.num_inputs entries, like the IFM lists. Splits come
    from shim_transfer['shim_ofm'][col][phase]. When concat_mode == 0, the
    channel stop is replaced by dims.output_ch_p. If the column (or the
    phase split) is empty, the range 'Yo:0:0 Xo:0:0 Co:0:0' is emitted.
    This matches the zero fallback of the other tilers, instead of failing
    with an IndexError on the empty split.
    '''
    def fmt(phase: int) -> str:
        split = []
        if shim_transfer['shim_ofm'][col]:
            split = shim_transfer['shim_ofm'][col][phase]
        if not split:
            return 'Yo:0:0 Xo:0:0 Co:0:0'
        Y_start = split[0][0]
        Y_stop  = split[0][1]
        X_start = split[1][0]
        X_stop  = split[1][1]
        C_start = split[2][0]
        C_stop  = split[2][1] if dims.concat_mode != 0 else dims.output_ch_p
        return f'Yo:{Y_start}:{Y_stop} Xo:{X_start}:{X_stop} Co:{C_start}:{C_stop}'
    return [fmt(phase) for phase in range(dims.phase) for _ in range(dims.num_inputs)]

def ofm_memtile_memory(dims: ConcatDims, col: int, mt_ofm_transfer) -> List[str]:
    '''
    Build memtile memory format strings for the OFM buffer.

    The shape for a phase is mt_ofm_transfer['mt_ofm_mem'][col][phase]. It
    does not depend on the input index, so each phase's string is repeated
    dims.num_inputs times. An empty shape falls back to 'Yi:1 Xi:1 Ci:1'.
    '''
    result = []
    for phase in range(dims.phase):
        for _ in range(dims.num_inputs):
            shape = mt_ofm_transfer['mt_ofm_mem'][col][phase]
            size3 = (shape[0], shape[1], shape[2]) if shape else (1, 1, 1)
            result.append('Yi:{} Xi:{} Ci:{}'.format(*size3))
    return result

def ofm_memtile_s2mm(dims: ConcatDims, col: int, row: int, mt_ofm_transfer) -> List[str]:
    '''
    Build memtile S2MM tiling strings for the OFM coming from core `row`.

    Splits come from mt_ofm_transfer['mt_ofm_s2mm'][col][row][phase]. Each
    phase's string is repeated dims.num_inputs times. An empty column or an
    empty split yields 'Yi:0:0 Xi:0:0 Ci:0:0'.
    '''
    def render(phase: int) -> str:
        per_col = mt_ofm_transfer['mt_ofm_s2mm'][col]
        rng = per_col[row][phase] if per_col else []
        if not rng:
            return 'Yi:0:0 Xi:0:0 Ci:0:0'
        return (
            f'Yi:{rng[0][0]}:{rng[0][1]} '
            f'Xi:{rng[1][0]}:{rng[1][1]} '
            f'Ci:{rng[2][0]}:{rng[2][1]}'
        )

    return [render(phase) for phase in range(dims.phase) for _ in range(dims.num_inputs)]

def ofm_memtile_mm2s(dims: ConcatDims, col: int, mt_ofm_transfer) -> List[str]:
    '''
    Build memtile MM2S tiling strings for the OFM.

    Splits come from mt_ofm_transfer['mt_ofm_mm2s'][col][phase]. Each
    phase's string is repeated dims.num_inputs times. When concat_mode == 0,
    the channel stop is replaced by dims.output_ch_p. An empty column or an
    empty split yields 'Yi:0:0 Xi:0:0 Ci:0:0'.
    '''
    def render(phase: int) -> str:
        per_col = mt_ofm_transfer['mt_ofm_mm2s'][col]
        rng = per_col[phase] if per_col else []
        if not rng:
            return 'Yi:0:0 Xi:0:0 Ci:0:0'
        y_lo, y_hi = rng[0][0], rng[0][1]
        x_lo, x_hi = rng[1][0], rng[1][1]
        c_lo = rng[2][0]
        if dims.concat_mode != 0:
            c_hi = rng[2][1]
        else:
            c_hi = dims.output_ch_p
        return f'Yi:{y_lo}:{y_hi} Xi:{x_lo}:{x_hi} Ci:{c_lo}:{c_hi}'

    return [render(phase) for phase in range(dims.phase) for _ in range(dims.num_inputs)]
def align_core_addr(core_addr: int):
    '''Align a core address with iceil(core_addr, 64). Presumably this rounds up to the next multiple of 64.'''
    return iceil(core_addr, 64)

"""
real concat Kernel
    1. for concat_kernel + qdq_kernel
        1) layer params:  offset_for_b +
                        total_element +
                        concat_kernel_params
    2. for concat_kernel only
        1) layer params:  offset_for_b +
                        total_element +
                        concat_kernel_params
    3. for qdq_kernel only
        1) layer params:  dummy +
                        total_element
 CPU emulate concat Kernel

    1. for concat_kernel + qdq_kernel
        1) layer params: total_element +
                        num_inputs +
                        concat_mode +
                        Yis, Xis, Cis * num_inputs
    2. for concat_kernel only
        1) layer params: total_element +
                        num_inputs +
                        concat_mode +
                        Yis, Xis, Cis * num_inputs

"""
def set_concat_layer_kernel_params(dims: ConcatDims, CoreWgtPingAddr: int):
    '''
    Serialize the per-layer kernel parameter blob for the concat layer.

    The layout is selected from dims.is_kernel, dims.is_qdq,
    dims.concat_mode and dims.num_inputs:

    * Real concat kernel (is_kernel, concat_mode == 0, num_inputs == 2):
      offset_for_b (u32) + total_elements (u32) + the 'IIIIHHH' concat
      struct.
    * CPU-emulated concat kernel (is_kernel, any other mode or count):
      total_elements (u16) + num_inputs (u16) + concat_mode (u16) +
      (Yi, Xi, Ci) as u16 for each of dims.MAX_INPUTS inputs, zero-padded
      past num_inputs.
    * QDQ only (is_qdq and not is_kernel): total_elements (u16) +
      CoreWgtPingAddr (u16) + input_index (u16, = 0) +
      quant_offset (u16, = 2).

    total_elements is the summed Yi*Xi*Ci over all inputs, passed through
    iceil(..., 64). All integers are little-endian.

    Args:
        dims: layer dimensions and configuration.
        CoreWgtPingAddr: core weight buffer address. It is encoded only in
            the QDQ-only layout.

    Returns:
        The parameter bytes.

    Raises:
        AssertionError: if neither the concat kernel nor QDQ is enabled.
        OverflowError: if a value does not fit its encoded width.
    '''
    def u16(value: int) -> bytes:
        return value.to_bytes(length=2, byteorder='little', signed=False)

    def u32(value: int) -> bytes:
        return value.to_bytes(length=4, byteorder='little', signed=False)

    if not dims.is_kernel and not dims.is_qdq:
        # Explicit raise instead of `assert False`, so the check survives
        # `python -O` (which would otherwise make this return None).
        raise AssertionError("There is no Kernel requirements, kernel should be disabled")

    total_elements = iceil(
        sum(dims.Yis[n] * dims.Xis[n] * dims.Cis[n] for n in range(dims.num_inputs)), 64)

    if dims.is_kernel:
        real_kernel_on = dims.concat_mode == 0 and dims.num_inputs == 2
        if real_kernel_on:
            # Computed only here: these fields index Cis[1] and are not
            # needed by the other layouts.
            concat_offset_for_b = dims.Yis[0] * dims.Xis[0] * dims.Cis[0] * dims.ifm_bits // 8
            concat_struct_fields = (
                int(dims.Cis[0] * dims.Xis[0] * dims.Yis[0] * 1),
                int(dims.Cis[1] * dims.Xis[0] * dims.Yis[0] * 1),
                int(dims.Cis[0] * dims.ifm_bits // 8),
                int(dims.Cis[0]),
                int((dims.Cis[1] + 1) * dims.ifm_bits // 8),
                int(dims.Cis[1] - 1),
                int((dims.Cis[0] + 1) * dims.ifm_bits // 8),
            )
            # '<' forces little-endian with standard sizes, matching the
            # to_bytes fields; on little-endian hosts the bytes equal the
            # native 'IIIIHHH' layout (no alignment padding is involved).
            return (u32(concat_offset_for_b)
                    + u32(total_elements)
                    + struct.pack('<IIIIHHH', *concat_struct_fields))

        # CPU-emulated kernel: per-input dims padded to MAX_INPUTS entries.
        pad = [0] * (dims.MAX_INPUTS - dims.num_inputs)
        Yis = dims.Yis + pad
        Xis = dims.Xis + pad
        Cis = dims.Cis + pad
        kernel_params = u16(total_elements) + u16(dims.num_inputs) + u16(dims.concat_mode)
        for n in range(dims.MAX_INPUTS):
            kernel_params += u16(Yis[n]) + u16(Xis[n]) + u16(Cis[n])
        return kernel_params

    # QDQ only: parameters for the run_combined_qdq kernel.
    input_index = 0
    quant_offset = 2
    return (u16(total_elements)
            + u16(CoreWgtPingAddr)
            + u16(input_index)
            + u16(quant_offset))

def compile_dataflow(
    dims: ConcatDims,
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
):

    mt_ifm_transfer, mt_ofm_transfer, shim_transfer = gen_transfers(dims)
    prm_memtile_size = compute_buffer_size(prm_memtile_memory(dims))
    if dims.is_kernel:
        ifm_mem_tiling_max = [mt_ifm_transfer['mt_ifm_mem'][n][0][0]
                              for n in range(dims.num_inputs)]
        ifm_memtile_size = sum(y * x * c for y, x, c in ifm_mem_tiling_max) * dims.ifm_bits // 8
        ofm_mem_tiling_max = mt_ofm_transfer['mt_ofm_mem'][0][0]
        ofm_memtile_size = math.prod(ofm_mem_tiling_max) * dims.ofm_bits // 8
    else:
        ifm_mem_tiling_max = mt_ifm_transfer['mt_ifm_mem'][0][0]
        ifm_memtile_size = math.prod(ifm_mem_tiling_max) * dims.ifm_bits // 8
        ofm_mem_tiling_max = mt_ofm_transfer['mt_ofm_mem'][0][0]
        ofm_memtile_size = math.prod(ofm_mem_tiling_max) * dims.ofm_bits // 8
    if dims.is_kernel or dims.is_qdq:
        wgt_memtile_size = dims.wgt_subv_size
    else:
        wgt_memtile_size = 0

    max_memtile_size = config.MAX_MEMTILE_ADDR + 1
    prm_memtile_addr = 0
    wgt_memtile_addr = prm_memtile_addr + prm_memtile_size
    wgt_memtile_addrs = [wgt_memtile_addr]
    if dims.is_kernel: # both ifm and wgt
        total_ifm_usable_memtile_size = (
            prm_memtile_size +
            ifm_memtile_size * 2 + ofm_memtile_size + wgt_memtile_size    # ifm pingpong and ofm use same address
        )
        if total_ifm_usable_memtile_size <= max_memtile_size and\
                dims.num_inputs <= 3:   # pingpong
            ifm_memtile_addr_ping = wgt_memtile_addr + wgt_memtile_size
            ifm_memtile_addr_pong = ifm_memtile_addr_ping + ifm_memtile_size
            ofm_memtile_addr_ping = ifm_memtile_addr_pong + ifm_memtile_size
            ofm_memtile_addr_pong = None
            ifm_memtile_addrs = [ifm_memtile_addr_ping, ifm_memtile_addr_pong]
            ofm_memtile_addrs = [ofm_memtile_addr_ping]
        else:
            ifm_memtile_addr_ping = wgt_memtile_addr + wgt_memtile_size
            ofm_memtile_addr_ping = ifm_memtile_addr_ping + ifm_memtile_size
            ifm_memtile_addr_pong = None
            ofm_memtile_addr_pong = None
            ifm_memtile_addrs = [ifm_memtile_addr_ping]
            ofm_memtile_addrs = [ofm_memtile_addr_ping]
        assert ofm_memtile_addr_ping + ofm_memtile_size <= max_memtile_size
    else:
        if dims.is_qdq: #both ifm and wgt
            if dims.num_inputs <= 10:
                ifm_memtile_addr_ping = wgt_memtile_addr + wgt_memtile_size
                ifm_memtile_addr_pong = ifm_memtile_addr_ping + ifm_memtile_size
                ofm_memtile_addr_ping = ifm_memtile_addr_pong + ifm_memtile_size
                ofm_memtile_addr_pong = None
                assert ofm_memtile_addr_ping + ofm_memtile_size <= max_memtile_size
                ifm_memtile_addrs = [ifm_memtile_addr_ping, ifm_memtile_addr_pong]
                ofm_memtile_addrs = [ofm_memtile_addr_ping]
            else:
                ifm_memtile_addr_ping = wgt_memtile_addr + wgt_memtile_size
                ifm_memtile_addrs = [ifm_memtile_addr_ping]
                ofm_memtile_addr_ping = ifm_memtile_addr_ping + ifm_memtile_size
                ofm_memtile_addrs = [ofm_memtile_addr_ping]
                assert ofm_memtile_addr_ping + ofm_memtile_size <= max_memtile_size
        else: # bypass kernel
            if dims.num_inputs <= 10:
                ifm_memtile_addr_ping = wgt_memtile_addr + wgt_memtile_size
                ifm_memtile_addr_pong = ifm_memtile_addr_ping + ifm_memtile_size
                ofm_memtile_addr_ping = None
                ofm_memtile_addr_pong = None
                assert ifm_memtile_addr_pong + ifm_memtile_size <= max_memtile_size
                ifm_memtile_addrs = [ifm_memtile_addr_ping, ifm_memtile_addr_pong]
            else:
                ifm_memtile_addr_ping = wgt_memtile_addr + wgt_memtile_size
                ifm_memtile_addrs = [ifm_memtile_addr_ping]

    ifm_memtile_repeat_scale = 1

    # Shim Buffer Allocation
    concat_shim_alloc = shim_alloc()

    # core kernel buffer
    #current Kernel only supporting 2inputs and C-dim concat
    if dims.is_kernel:
        Core_bank_addr = 16384
        run_kernel = 'run_concat' if dims.concat_mode == 0 and dims.num_inputs == 2 else None
        CoreIfmSize = sum(y * x * c for y, x, c in zip(dims.Yis, dims.Xis, dims.Cis)) * \
            dims.ifm_bits // 8
        CoreOfmSize = CoreIfmSize
        CoreWgtPingAddr  = 0
        CoreWgtSize = dims.wgt_subv_size
        if CoreWgtSize + CoreIfmSize * 3 <= overlay_stack_addr():
            CoreIfm_enable_pingpong = True
        else:
            CoreIfm_enable_pingpong = False
        CoreIfmPingAddr = CoreWgtPingAddr + CoreWgtSize
        if CoreIfm_enable_pingpong:
            CoreIfmPongAddr = max(Core_bank_addr,   align_core_addr(CoreIfmPingAddr + CoreIfmSize))
            CoreOfmPingAddr = max(2*Core_bank_addr,   align_core_addr(CoreIfmPongAddr + CoreIfmSize))
            if CoreOfmPingAddr + CoreOfmSize > overlay_stack_addr():
                CoreIfmPongAddr = align_core_addr(CoreIfmPingAddr + CoreIfmSize)
                CoreOfmPingAddr = align_core_addr(CoreIfmPongAddr + CoreIfmSize)
        else:
            CoreIfmPongAddr = None
            CoreOfmPingAddr = max(2*Core_bank_addr,   align_core_addr(CoreIfmPingAddr + CoreIfmSize))
            if CoreOfmPingAddr + CoreOfmSize > overlay_stack_addr():
                CoreOfmPingAddr = align_core_addr(CoreIfmPingAddr + CoreIfmSize)
        assert CoreOfmPingAddr + CoreOfmSize <= overlay_stack_addr()
        CoreOfmPongAddr = None
        CoreWgtPongAddr = None
        Tn = dims.phase

    else:
        Core_bank_addr = 16384
        if dims.is_qdq: # qdq only
            run_kernel = 'run_combined_qdq'
            # run_kernel = "run_concat"
            CoreIfmSize = sum(y * x * c for y, x, c in zip(dims.Yis, dims.Xis, dims.Cis)) * \
                dims.ifm_bits // 8
            CoreOfmSize = CoreIfmSize
            CoreWgtPingAddr  = 0
            CoreWgtSize = dims.wgt_subv_size
            if CoreWgtSize + CoreIfmSize * 3 <= overlay_stack_addr():
                CoreIfm_enable_pingpong = True
            else:
                CoreIfm_enable_pingpong = False
            CoreIfmPingAddr = CoreWgtPingAddr + CoreWgtSize
            if CoreIfm_enable_pingpong:
                CoreIfmPongAddr = max(Core_bank_addr, align_core_addr(CoreIfmPingAddr + CoreIfmSize))
                CoreOfmPingAddr = max(2*Core_bank_addr, align_core_addr(CoreIfmPongAddr + CoreIfmSize))
                if CoreOfmPingAddr + CoreOfmSize > overlay_stack_addr():
                    CoreIfmPongAddr = align_core_addr(CoreIfmPingAddr + CoreIfmSize)
                    CoreOfmPingAddr = align_core_addr(CoreIfmPongAddr + CoreIfmSize)
            else:
                CoreIfmPongAddr = None
                CoreOfmPingAddr = max(2*Core_bank_addr, align_core_addr(CoreIfmPingAddr + CoreIfmSize))
                if CoreOfmPingAddr + CoreOfmSize > overlay_stack_addr():
                    CoreOfmPingAddr = align_core_addr(CoreIfmPingAddr + CoreIfmSize)
            assert CoreOfmPingAddr + CoreOfmSize <= overlay_stack_addr()
            CoreOfmPongAddr = None
            CoreWgtPongAddr = None
            Tn = dims.phase
        else: # it doesn't matter, but to fix the CI lint error
            Tn = dims.phase
            CoreIfmPingAddr = None
            CoreIfmPongAddr = None
            CoreOfmPingAddr = None
            CoreOfmPongAddr = None
            CoreWgtPingAddr = None
            CoreWgtPongAddr = None
            CoreIfmSize = 0
            CoreOfmSize = 0
            CoreWgtSize = 0
            run_kernel = None
    # for core loop, each configration can support up to 1024, exceeding that, it need to re-config the core
    # even the configuration is same.
    X = Tn
    Tx = 1
    T_remain = 0
    while X > 1024:
        X = X //2
    Tx = Tn // X
    T_remain = Tn % X

    """we need to define three Calkernel:
       1. Calkernel -- kernel-concat                  <--is_kernel = True + is_qdq = False
       2. Calkernel -- kernel-concat + kernel-qdq     <--is_kernel = True + is_qdq = True
       2. Calkernel -- kernel-qdq                     <--is_kernel = False + is_qdq = True
    """

    if dims.is_kernel or dims.is_qdq:
        if T_remain == 0:
            core_instrs = [
                ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreWgtPingAddr, CoreWgtPongAddr, CoreWgtSize),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                Loop(Tx, [
                    ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize),
                    ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOfmPingAddr, CoreOfmPongAddr, CoreOfmSize),
                    Loop( X,  [
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        CallKernel(run_kernel, set_concat_layer_kernel_params(dims, CoreWgtPingAddr)),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    ]),
                ]),
            ]
        else:
            core_instrs = [
                ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreWgtPingAddr, CoreWgtPongAddr, CoreWgtSize),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                Loop(Tx, [
                    ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize),
                    ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOfmPingAddr, CoreOfmPongAddr, CoreOfmSize),
                    Loop( X,  [
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        CallKernel(run_kernel, set_concat_layer_kernel_params(dims, CoreWgtPingAddr)),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    ]),
                ]),
                Loop(1, [
                    ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize),
                    ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOfmPingAddr, CoreOfmPongAddr, CoreOfmSize),
                    Loop( T_remain,  [
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        CallKernel(run_kernel, set_concat_layer_kernel_params(dims, CoreWgtPingAddr)),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    ]),
                ])
            ]
    else:
        # NOTE: This was a work around for supoer-kernel bug.
        #       Keeping this for future refernce
        core_instrs = [
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), 0, 0, 0),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            ]
        # core_instrs = []

    memtile_transfers = []
    memtile_prm_transfer = [DataTransfer(
            [1] + [0] * (dims.phase * dims.num_inputs - 1),
            AieTile(TileType.Memtile, col), [prm_memtile_addr], prm_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(dims),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(dims),
                prm_memtile_mm2s(row),
            ) for row in range(dims.aie_rows)],
            sync_strategy=SyncStrategy.Serial_M_to_N,
        ) for col in range(dims.aie_cols)
    ]
    mt_ifm_buffer_offset = 0
    mt_ifm_buffer_offset = []
    if dims.is_kernel:
        mt_mem = [mt_ifm_transfer['mt_ifm_mem'][n][0][0] for n in range(dims.num_inputs)]
        for n in range(dims.num_inputs):
            if n == 0:
                mt_ifm_buffer_offset.append(0)
            else:
                mt_ifm_buffer_offset.append(
                    sum([mt_mem[k][0] * mt_mem[k][1] *
                         mt_mem[k][2] for k in range(n)]) * dims.ifm_bits // 8)

    if dims.is_kernel:
        memtile_ifm_transfer = [
            DataTransfer(
                Yi_repeat_counts(dims, 0, ifm_memtile_repeat_scale),
                AieTile(TileType.Memtile, col), ifm_memtile_addrs, ifm_memtile_size,
                [pack_transfers(
                    memtile_dma(col, DmaDir.S2MM, 0),
                    ifm_memtile_memory(dims, col, mt_ifm_transfer),
                    ifm_memtile_s2mm(dims, col, mt_ifm_transfer),
                    [1] * dims.phase,
                    dims.ifm_bits,
                    buffer_offset=mt_ifm_buffer_offset[idx],
                ) for idx in range(dims.num_inputs)],
                [pack_transfers(
                    memtile_dma(col, DmaDir.MM2S, row),
                    ifm_memtile_memory(dims, col, mt_ifm_transfer),
                    ifm_memtile_mm2s(dims, col, row, mt_ifm_transfer),
                    [1] * dims.phase,
                    dims.ifm_bits,
                    buffer_offset=mt_ifm_buffer_offset[idx],
                ) for row in range(dims.aie_rows) for idx in range(dims.num_inputs)
                ],
                # sync_strategy=SyncStrategy.Parallel_1_to_N,
            ) for col in range(dims.aie_cols)
            ]
    else:
        if dims.is_qdq:
            memtile_ifm_transfer = [
                DataTransfer(
                    Yi_repeat_counts(dims, 0, ifm_memtile_repeat_scale),
                    AieTile(TileType.Memtile, col), ifm_memtile_addrs, ifm_memtile_size,
                    [pack_transfers(
                        memtile_dma(col, DmaDir.S2MM, 0),
                        ifm_memtile_memory(dims, col, mt_ifm_transfer),
                        ifm_memtile_s2mm(dims, col, mt_ifm_transfer),
                        [1] * dims.phase * dims.num_inputs,
                        dims.ifm_bits,
                    )],
                    [pack_transfers(
                        memtile_dma(col, DmaDir.MM2S, row),
                        ifm_memtile_memory(dims, col, mt_ifm_transfer),
                        ifm_memtile_mm2s(dims, col, row, mt_ifm_transfer),
                        [1] * dims.phase * dims.num_inputs,
                        dims.ifm_bits,
                    ) for row in range(dims.aie_rows)],
                    # sync_strategy=SyncStrategy.Parallel_1_to_N,
                ) for col in range(dims.aie_cols)
                ]
        else:
            memtile_ifm_transfer = [
                DataTransfer(
                    Yi_repeat_counts(dims, 0, ifm_memtile_repeat_scale),
                    AieTile(TileType.Memtile, col), ifm_memtile_addrs, ifm_memtile_size,
                    [pack_transfers(
                        memtile_dma(col, DmaDir.S2MM, 0),
                        ifm_memtile_memory(dims, col, mt_ifm_transfer),
                        ifm_memtile_s2mm(dims, col, mt_ifm_transfer),
                        [1] * dims.phase * dims.num_inputs,
                        dims.ifm_bits,
                    )],
                    [pack_transfers(
                        memtile_dma(col, DmaDir.MM2S, 5),
                        ifm_memtile_memory(dims, col, mt_ifm_transfer),
                        ifm_memtile_mm2s(dims, col, 0, mt_ifm_transfer),
                        [1] * dims.phase * dims.num_inputs,
                        dims.ofm_bits,
                    )]
                ) for col in range(dims.aie_cols)
                ]

    memtile_wgt_transfers = [
        DataTransfer(
            [1] + [0] * (dims.phase*dims.num_inputs - 1),
            AieTile(TileType.Memtile, col), wgt_memtile_addrs, wgt_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 1),
                wgt_memtile_memory(dims),
                wgt_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, 4),
                wgt_memtile_memory(dims),
                wgt_memtile_mm2s(),
            )],
        ) for col in range(0, dims.aie_cols, (2 if dims.aie_cols == 8 else 1))
    ] if dims.is_kernel or dims.is_qdq else []

    if dims.is_kernel or dims.is_qdq:
        memtile_ofm_transfer = [
            DataTransfer(
                # Yi_repeat_counts(dims, 0, ifm_memtile_repeat_scale),
                ofm_shim_repeat_counts(dims, 0),
                AieTile(TileType.Memtile, col), ofm_memtile_addrs, ifm_memtile_size,
                [pack_transfers(
                    memtile_dma(col, DmaDir.S2MM, 2 + row),
                    ofm_memtile_memory(dims, col, mt_ofm_transfer),
                    ofm_memtile_s2mm(dims, col, row, mt_ofm_transfer),
                    [1] * dims.phase * dims.num_inputs,
                    dims.ofm_bits,
                ) for row in range(dims.aie_rows)],
                [pack_transfers(
                    memtile_dma(col, DmaDir.MM2S, 5),
                    ofm_memtile_memory(dims, col, mt_ofm_transfer),
                    ofm_memtile_mm2s(dims, col, mt_ofm_transfer),
                    [1] * dims.phase * dims.num_inputs,
                    dims.ofm_bits,
                )],
                # sync_strategy=SyncStrategy.Parallel_N_to_1,
            ) for col in range(dims.aie_cols)
        ]
    else:
        memtile_ofm_transfer = []

    memtile_transfers += (memtile_prm_transfer +
                          memtile_ifm_transfer +
                          memtile_wgt_transfers +
                          memtile_ofm_transfer)

    shim_transfers = []
    shim_prm_transfer =   [ generate_shim_data_transfer(
            [1] + [0] * (dims.phase * dims.num_inputs - 1),
            shim_dma(col, DmaDir.MM2S, 0), concat_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col),
        ) for col in range(dims.aie_cols)
    ]

    shim_ifm_buffer_offset = []
    for n in range(dims.num_inputs):
        if n == 0:
            shim_ifm_buffer_offset.append(0)
        else:
            shim_ifm_buffer_offset.append(
                sum([dims.input_rows[k] * dims.input_cols[k] *
                     dims.input_chs_p[k] for k in range(n)]) * dims.ifm_bits // 8)

    shim_ifm_size = sum(dims.input_rows[n] * dims.input_cols[n] *
                        dims.input_chs_p[n] * dims.ifm_bits // 8 for n in range(dims.num_inputs))

    # please note, this is a special addtion with both size and offset feed into
    shim_ifm_transfer = [generate_packed_shim_data_transfer(
        ifm_shim_repeat_counts(dims, 0),
        shim_dma(col, DmaDir.MM2S, 0), concat_shim_alloc.ifm_buffer_id,
        ifm_shim_memory(dims),
        ifm_shim_mm2s(dims, col, shim_transfer),
        [1] * dims.phase * dims.num_inputs,
        [0] * dims.phase * dims.num_inputs,
        dims.ifm_bits,
        buffer_offset = shim_ifm_buffer_offset * dims.phase,
        buffer_size = shim_ifm_size,
    ) for col in range(dims.aie_cols)]

    shim_wgt_transfers = [generate_shim_data_transfer(
            [1] + [0] * (dims.phase * dims.num_inputs - 1),
            shim_dma(col, DmaDir.MM2S, 1), concat_shim_alloc.wgt_buffer_id,
            wgt_shim_memory(dims),
            wgt_shim_mm2s(),
        ) for col in range(0, dims.aie_cols, (2 if dims.aie_cols == 8 else 1))
    ] if dims.is_kernel or dims.is_qdq else []

    shim_ofm_size = dims.output_row * dims.output_col * dims.output_ch_p * dims.ifm_bits // 8
    shim_ofm_transfer = [
        DataTransfer(
            ofm_shim_repeat_counts(dims, 0),
            AieTile(TileType.Shim, col), [concat_shim_alloc.ofm_buffer_id], shim_ofm_size,
            [pack_transfers(
                shim_dma(col, DmaDir.S2MM, 0),
                [ofm_shim_memory(dims)] * dims.phase * dims.num_inputs,
                ofm_shim_s2mm(dims, col, shim_transfer),
                [1] * dims.phase * dims.num_inputs,
                dims.ifm_bits,
            )],
            [],
            ) for col in range(dims.aie_cols)
        ]

    shim_transfers += (shim_prm_transfer +
                       shim_ifm_transfer +
                       shim_wgt_transfers +
                       shim_ofm_transfer)

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_8x4_dma_connections(),
        back_end,
        param_channel_id = 0
    )

def main():
    """Compile the concat dataflow for a sample configuration and build the sim overlay.

    The sample configuration is:

    * an overlay of 8 columns x 4 rows,
    * 16-bit input and output elements,
    * the ``BackEnd.Adf`` back end,
    * five 16x16 inputs joined along the channel axis (``concat_mode = 0``),
    * the QDQ kernel enabled alongside the concat kernel.

    Other configurations previously kept here as commented-out code:

    * Column/W concat (``concat_mode = 1``) with two inputs:
      rows ``[16, 16]``, cols ``[16, 32]``, chs ``[16, 16]``.
    * Concat-only kernels (no QDQ): names ``['run_concat']``,
      includes ``['super.hh', 'concat/wrapper_concat.cc']``.
    """
    # Overlay geometry; ConcatDims takes columns before rows.
    aie_rows, aie_cols = 4, 8

    # Element widths in bits for the input and output feature maps.
    ifm_bits = ofm_bits = 16

    backend = BackEnd.Adf

    # 0: channel concat; 1: Column/W concat
    concat_mode = 0
    num_inputs = 5
    # All inputs share the same spatial size; only the channel counts differ.
    input_rows = [16] * num_inputs
    input_cols = [16] * num_inputs
    input_chs = [16, 32, 64, 128, 356]

    is_qdq = True

    dims = ConcatDims(
        aie_cols, aie_rows,
        num_inputs, concat_mode,
        input_rows, input_cols, input_chs,
        ifm_bits, ofm_bits,
        is_qdq,
    )

    # The kernel list matches is_qdq = True: the QDQ wrapper is compiled in
    # alongside the concat wrapper.
    kernel_names = ['run_combined_qdq', 'run_concat']
    kernel_includes = ['super.hh', 'qdq/wrapper_qdq.cc', 'concat/wrapper_concat.cc']

    clean_overlay()
    compile_dataflow(dims, backend, kernel_names, kernel_includes)
    # Directives are computed only after compilation, matching the original call order.
    preproc_directives = concat_preproc_directives(dims, backend)
    build_sim_overlay(backend, 'concat_main.cpp', preproc_directives)


if __name__ == '__main__':
    main()
