import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels'))
from typing import Tuple, List
import math
import struct
import numpy as np

from dmacompiler import (
    OverlayShape, BackEnd,
    DataTransfer, SyncStrategy,
    AieTile, TileType,
    AieDma, DmaDir, memtile_dma, shim_dma, core_dma, DmaChannel,
    CoreInstr, ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop,
    compute_buffer_size,
    TransferParams, generate_transfer_params,
    pack_reconfig_transfers,
    generate_shim_data_transfer,
    run_layer_compilation,
    set_dev_gen, DevGen, config
    )

from dataflow_common import (
    overlay_8x4_dma_connections,
    overlay_stack_addr,
    clean_overlay,
    build_sim_overlay,
    shim_alloc,
    prm_shim_memory,
    prm_shim_mm2s,
    prm_memtile_memory,
    prm_memtile_s2mm,
    prm_memtile_mm2s,
    ceildiv,
    iceil
    )

from pad_common import (
    PadDims,
    pad_preproc_directives,
    # make_slice_dict,
    # split_cost
    )

# Select the Aie2p device generation in dmacompiler; this runs at import time.
set_dev_gen(DevGen.Aie2p)
# Switch on dmacompiler's ENABLE_BUSY_POLL option (also an import-time side effect).
config.ENABLE_BUSY_POLL = True

# Not referenced anywhere else in this file.
# NOTE(review): presumably an upper bound on a sub-volume size; units unconfirmed.
MAX_SUBV_SIZE = 8192

def should_do_padding(dim_size: int, size_per_iter: int, spatial_factor: int,
                      col: int, loop: int, iter: int):
    """Return True when (col, iter) holds the last piece of a dimension.

    Args:
        dim_size: total extent of the dimension.
        size_per_iter: extent handled by one column in one iteration.
        spatial_factor: number of columns sharing the dimension; 1 means the
            dimension is only split across iterations.
        col: column being queried.
        loop: iteration count, only consulted when spatial_factor is 1.
        iter: iteration being queried.

    Returns:
        bool: whether this (col, iter) pair is the one that receives padding.
    """
    # Pure temporal split: the final iteration always owns the tail.
    if spatial_factor == 1:
        return iter == loop - 1

    # Elements consumed by all columns together in one iteration.
    span_per_iter = size_per_iter * spatial_factor
    whole_iters, leftover = divmod(dim_size, span_per_iter)

    if leftover == 0:
        # Evenly divisible: the tail sits in the last column of the last
        # full iteration.
        final_iter = whole_iters - 1
        final_col = spatial_factor - 1
    else:
        # A partial iteration follows the full ones.
        final_iter = whole_iters
        filled_cols, partial = divmod(leftover, size_per_iter)
        # With a partially filled column that column is the tail; otherwise
        # it is the last completely filled one.
        final_col = filled_cols if partial else filled_cols - 1

    return (iter == final_iter) and (col == final_col)

def spatial_split(
    dim_size: int, dim_split: int, dim_gran: int,
    loop:int,
    split_factor, n_iter: int,
    col: int,
    should_do_pad: bool = False,
    pad_elem: int = 0,
    ) -> Tuple[int, int, int]:
    """Compute the slice of one dimension handled by a column in one iteration.

    Args:
        dim_size (int): full extent of the dimension.
        dim_split (int): extent of a single piece; must be a multiple of
            dim_gran.
        dim_gran (int): granularity the piece size has to respect.
        loop (int): loop count for the dimension (not used in the arithmetic).
        split_factor: number of columns the dimension is spread over; when it
            is 1 the column index is ignored.
        n_iter (int): iteration index.
        col (int): column index.
        should_do_pad (bool, optional): when True, pad_elem is added to both
            the stop and the size. Defaults to False.
        pad_elem (int, optional): padding amount. Defaults to 0.

    Returns:
        Tuple[int, int, int]: (start, stop, size). stop is capped at dim_size
        before padding; size counts only the part of [start, stop) that lies
        inside [0, dim_size], plus the padding if requested.
    """
    lane = 0 if split_factor == 1 else col
    assert dim_split % dim_gran == 0, f"the split should meet granunity, re=check tiler"

    # Every iteration advances by one piece per participating column.
    begin = lane * dim_split + n_iter * (split_factor * dim_split)
    end = min(begin + dim_split, dim_size)

    def clip(position):
        # Restrict a coordinate to the valid range of the dimension.
        return max(0, min(position, dim_size))

    extent = clip(end) - clip(begin)
    if should_do_pad:
        end += pad_elem
        extent += pad_elem
    return (begin, end, extent)


def row_split(
    dim_size: int,
    dim_split: int,
    dim_gran: int,
    split_factor,
    row: int,
    s2mm: bool = True,
    ) -> Tuple[int, int, int]:
    """Compute the slice of one dimension assigned to a row.

    Args:
        dim_size (int): extent available to split.
        dim_split (int): extent of one row's piece; must be a multiple of
            dim_gran.
        dim_gran (int): granularity the piece size has to respect.
        split_factor: number of rows the dimension is spread over; when it is
            1 the row index is ignored.
        row (int): row index.
        s2mm (bool): True keeps the plain row offset. False (the mm2s case)
            moves a row whose offset is at or past dim_size onto the last
            dim_split elements instead.

    Returns:
        Tuple[int, int, int]: (start, stop, size), where size counts only the
        part of [start, stop) that lies inside [0, dim_size].
    """
    lane = row if split_factor != 1 else 0
    assert dim_split % dim_gran == 0, f"the split should meet granunity, re=check tiler"

    begin = lane * dim_split
    if not s2mm and begin >= dim_size:
        # Nothing left for this row: re-read the final piece instead.
        begin, end = dim_size - dim_split, dim_size
    else:
        end = begin + dim_split

    extent = max(0, min(end, dim_size)) - max(0, min(begin, dim_size))
    return (begin, end, extent)



def Yi_repeat_counts(dims: PadDims, dims_phase) -> List[int]:
    """Return a repeat count of 1 for each of the `dims_phase` phases.

    `dims` is accepted for signature symmetry but is not read.
    """
    return [1] * dims_phase


def ifm_memtile_memory(dims: PadDims, col: int) -> List[str]:
    """List the IFM memtile buffer shape for every loop phase of column `col`.

    One 'Ni:.. Yi:.. Xi:.. Ci:..' string is produced per (N, Y, X, C) loop
    step, with the C loop varying fastest. Each field is the size component
    returned by spatial_split for that axis.
    """
    def extents(steps):
        # (dim_size, dim_split, dim_gran, loop) arguments for N, Y, X, C.
        axes = (
            (dims.Ni, dims.Nim, dims.Nis_gran, dims.N_loop),
            (dims.Yi, dims.Yim, dims.Yis_gran, dims.Y_loop),
            (dims.Xi, dims.Xim, dims.Xis_gran, dims.X_loop),
            (dims.Ci, dims.Cim, dims.Cis_gran, dims.C_loop),
        )
        return [
            spatial_split(total, chunk, gran, loops,
                          dims.spatial_split_mode[axis], step, col)[2]
            for axis, ((total, chunk, gran, loops), step)
            in enumerate(zip(axes, steps))
        ]

    layouts = []
    for i_n in range(dims.N_loop):
        for i_y in range(dims.Y_loop):
            for i_x in range(dims.X_loop):
                for i_c in range(dims.C_loop):
                    n_len, y_len, x_len, c_len = extents((i_n, i_y, i_x, i_c))
                    layouts.append(
                        f'Ni:{n_len} Yi:{y_len} Xi:{x_len} Ci:{c_len}')
    return layouts

def ofm_memtile_memory(dims: PadDims, col: int) -> List[str]:
    """List the OFM memtile buffer shape for every loop phase.

    Each entry multiplies the per-axis sizes Nis/Yis/Xis/Cis by the matching
    row_split_mode factor. The entry does not depend on the loop indices or on
    `col`, so the same string is repeated once per (N, Y, X, C) loop step.
    """
    def describe() -> str:
        factors = dims.row_split_mode
        n_len = dims.Nis * factors[0]
        y_len = dims.Yis * factors[1]
        x_len = dims.Xis * factors[2]
        c_len = dims.Cis * factors[3]
        return f'Ni:{n_len} Yi:{y_len} Xi:{x_len} Ci:{c_len}'

    layouts = []
    for _ in range(dims.N_loop):
        for _ in range(dims.Y_loop):
            for _ in range(dims.X_loop):
                layouts.extend(describe() for _ in range(dims.C_loop))
    return layouts


def ifm_memtile_s2mm(dims: PadDims, col: int) -> List[str]:
    """List the IFM memtile S2MM access pattern for every loop phase of `col`.

    One 'Ni:0:n Yi:0:y Xi:0:x Ci:0:c' string is produced per (N, Y, X, C)
    loop step, with the C loop varying fastest. Each upper bound is the size
    component returned by spatial_split for that axis.
    """
    def extents(steps):
        # (dim_size, dim_split, dim_gran, loop) arguments for N, Y, X, C.
        axes = (
            (dims.Ni, dims.Nim, dims.Nis_gran, dims.N_loop),
            (dims.Yi, dims.Yim, dims.Yis_gran, dims.Y_loop),
            (dims.Xi, dims.Xim, dims.Xis_gran, dims.X_loop),
            (dims.Ci, dims.Cim, dims.Cis_gran, dims.C_loop),
        )
        return [
            spatial_split(total, chunk, gran, loops,
                          dims.spatial_split_mode[axis], step, col)[2]
            for axis, ((total, chunk, gran, loops), step)
            in enumerate(zip(axes, steps))
        ]

    patterns = []
    for i_n in range(dims.N_loop):
        for i_y in range(dims.Y_loop):
            for i_x in range(dims.X_loop):
                for i_c in range(dims.C_loop):
                    n_len, y_len, x_len, c_len = extents((i_n, i_y, i_x, i_c))
                    patterns.append(
                        f'Ni:0:{n_len} Yi:0:{y_len} Xi:0:{x_len} Ci:0:{c_len}')
    return patterns

def ifm_memtile_mm2s(dims: PadDims, col: int, row: int) -> List[str]:
    """List the IFM memtile MM2S access pattern of (`col`, `row`) per phase.

    For each (N, Y, X, C) loop step (C varying fastest) the size of the
    column's piece is taken from spatial_split. That piece is then divided
    among rows with row_split(..., s2mm=False). The resulting start:stop pairs
    are rendered as 'Ni:a:b Yi:c:d Xi:e:f Ci:g:h'.
    """
    labels = ('Ni', 'Yi', 'Xi', 'Ci')

    def window(steps) -> str:
        # (dim_size, dim_split, dim_gran, loop) arguments for spatial_split.
        column_axes = (
            (dims.Ni, dims.Nim, dims.Nis_gran, dims.N_loop),
            (dims.Yi, dims.Yim, dims.Yis_gran, dims.Y_loop),
            (dims.Xi, dims.Xim, dims.Xis_gran, dims.X_loop),
            (dims.Ci, dims.Cim, dims.Cis_gran, dims.C_loop),
        )
        extents = [
            spatial_split(total, chunk, gran, loops,
                          dims.spatial_split_mode[axis], step, col)[2]
            for axis, ((total, chunk, gran, loops), step)
            in enumerate(zip(column_axes, steps))
        ]
        # (dim_split, dim_gran) arguments for row_split.
        row_axes = (
            (dims.Nis, dims.Nis_gran),
            (dims.Yis, dims.Yis_gran),
            (dims.Xis, dims.Xis_gran),
            (dims.Cis, dims.Cis_gran),
        )
        bounds = [
            row_split(extent, chunk, gran,
                      dims.row_split_mode[axis], row, s2mm=False)[:2]
            for axis, (extent, (chunk, gran))
            in enumerate(zip(extents, row_axes))
        ]
        return ' '.join(f'{tag}:{lo}:{hi}' for tag, (lo, hi) in zip(labels, bounds))

    return [
        window((i_n, i_y, i_x, i_c))
        for i_n in range(dims.N_loop)
        for i_y in range(dims.Y_loop)
        for i_x in range(dims.X_loop)
        for i_c in range(dims.C_loop)
    ]




def ofm_memtile_s2mm(dims: PadDims, col: int, row: int) -> List[str]:
    """List the OFM memtile S2MM access pattern of (`col`, `row`) per phase.

    For each (N, Y, X, C) loop step (C varying fastest) the size of the
    column's piece is taken from spatial_split. Each row's offset inside that
    piece comes from row_split(..., s2mm=True). The resulting start:stop pairs
    are rendered as 'Ni:a:b Yi:c:d Xi:e:f Ci:g:h'.
    """
    labels = ('Ni', 'Yi', 'Xi', 'Ci')

    def window(steps) -> str:
        # (dim_size, dim_split, dim_gran, loop) arguments for spatial_split.
        column_axes = (
            (dims.Ni, dims.Nim, dims.Nis_gran, dims.N_loop),
            (dims.Yi, dims.Yim, dims.Yis_gran, dims.Y_loop),
            (dims.Xi, dims.Xim, dims.Xis_gran, dims.X_loop),
            (dims.Ci, dims.Cim, dims.Cis_gran, dims.C_loop),
        )
        extents = [
            spatial_split(total, chunk, gran, loops,
                          dims.spatial_split_mode[axis], step, col)[2]
            for axis, ((total, chunk, gran, loops), step)
            in enumerate(zip(column_axes, steps))
        ]
        # (dim_split, dim_gran) arguments for row_split.
        row_axes = (
            (dims.Nis, dims.Nis_gran),
            (dims.Yis, dims.Yis_gran),
            (dims.Xis, dims.Xis_gran),
            (dims.Cis, dims.Cis_gran),
        )
        bounds = [
            row_split(extent, chunk, gran,
                      dims.row_split_mode[axis], row, s2mm=True)[:2]
            for axis, (extent, (chunk, gran))
            in enumerate(zip(extents, row_axes))
        ]
        return ' '.join(f'{tag}:{lo}:{hi}' for tag, (lo, hi) in zip(labels, bounds))

    return [
        window((i_n, i_y, i_x, i_c))
        for i_n in range(dims.N_loop)
        for i_y in range(dims.Y_loop)
        for i_x in range(dims.X_loop)
        for i_c in range(dims.C_loop)
    ]

def ofm_memtile_mm2s(dims: PadDims, col: int) -> List[str]:
    """List the OFM memtile MM2S access pattern for every loop phase of `col`.

    One 'Ni:0:n Yi:0:y Xi:0:x Ci:0:c' string is produced per (N, Y, X, C)
    loop step, with the C loop varying fastest. Each upper bound is the
    spatial_split size for that axis, enlarged by the axis padding (pad_N,
    pad_Y, pad_X, pad_C) when should_do_padding selects this column and
    iteration.
    """
    def padded_extent(total, chunk, gran, loops, pad, axis, step):
        mode = dims.spatial_split_mode[axis]
        pad_here = should_do_padding(total, chunk, mode, col, loops, step)
        return spatial_split(total, chunk, gran, loops, mode,
                             step, col, pad_here, pad)[2]

    def fmt(i_n: int, i_y: int, i_x: int, i_c: int) -> str:
        n_len = padded_extent(dims.Ni, dims.Nim, dims.Nis_gran, dims.N_loop,
                              dims.pad_N, 0, i_n)
        y_len = padded_extent(dims.Yi, dims.Yim, dims.Yis_gran, dims.Y_loop,
                              dims.pad_Y, 1, i_y)
        x_len = padded_extent(dims.Xi, dims.Xim, dims.Xis_gran, dims.X_loop,
                              dims.pad_X, 2, i_x)
        c_len = padded_extent(dims.Ci, dims.Cim, dims.Cis_gran, dims.C_loop,
                              dims.pad_C, 3, i_c)
        return f'Ni:0:{n_len} Yi:0:{y_len} Xi:0:{x_len} Ci:0:{c_len}'

    patterns = []
    for i_n in range(dims.N_loop):
        for i_y in range(dims.Y_loop):
            for i_x in range(dims.X_loop):
                for i_c in range(dims.C_loop):
                    patterns.append(fmt(i_n, i_y, i_x, i_c))
    return patterns





def ifm_shim_repeat_counts(dims: PadDims) -> List[int]:
    """Return a repeat count of 1 for every phase (product of dims.loop)."""
    phases = math.prod(dims.loop)
    return [1] * phases

def wgt_shim_repeat_counts(dims: PadDims) -> List[int]:
    """Return repeat counts that fire once in the first phase only.

    The list holds a leading 1 followed by one 0 for each remaining phase,
    where the phase count is the product of dims.loop.
    """
    phases = math.prod(dims.loop)
    counts = [0] * (phases - 1)
    counts.insert(0, 1)
    return counts


def ifm_shim_memory(dims: PadDims) -> str:
    """Describe the IFM shim buffer: dims.input[0..2] for N/Y/X and dims.Cip for C."""
    shape = dims.input
    return 'Ni:{} Yi:{} Xi:{} Ci:{}'.format(shape[0], shape[1], shape[2], dims.Cip)


def ofm_shim_memory(dims: PadDims) -> str:
    """Describe the OFM shim buffer: dims.output[0..2] for N/Y/X and dims.Cop for C."""
    shape = dims.output
    return 'Ni:{} Yi:{} Xi:{} Ci:{}'.format(shape[0], shape[1], shape[2], dims.Cop)


def ifm_shim_mm2s(dims: PadDims, col: int) -> List[str]:
    """List the IFM shim MM2S access pattern for every loop phase of `col`.

    One 'Ni:a:b Yi:c:d Xi:e:f Ci:g:h' string is produced per (N, Y, X, C)
    loop step, with the C loop varying fastest. Each start:stop pair is taken
    from spatial_split for that axis.
    """
    labels = ('Ni', 'Yi', 'Xi', 'Ci')

    def window(steps) -> str:
        # (dim_size, dim_split, dim_gran, loop) arguments for N, Y, X, C.
        axes = (
            (dims.Ni, dims.Nim, dims.Nis_gran, dims.N_loop),
            (dims.Yi, dims.Yim, dims.Yis_gran, dims.Y_loop),
            (dims.Xi, dims.Xim, dims.Xis_gran, dims.X_loop),
            (dims.Ci, dims.Cim, dims.Cis_gran, dims.C_loop),
        )
        bounds = [
            spatial_split(total, chunk, gran, loops,
                          dims.spatial_split_mode[axis], step, col)[:2]
            for axis, ((total, chunk, gran, loops), step)
            in enumerate(zip(axes, steps))
        ]
        return ' '.join(f'{tag}:{lo}:{hi}' for tag, (lo, hi) in zip(labels, bounds))

    patterns = []
    for i_n in range(dims.N_loop):
        for i_y in range(dims.Y_loop):
            for i_x in range(dims.X_loop):
                for i_c in range(dims.C_loop):
                    patterns.append(window((i_n, i_y, i_x, i_c)))
    return patterns


def ofm_shim_s2mm(dims: PadDims, col: int) -> List[str]:
    """List the OFM shim S2MM access pattern for every loop phase of `col`.

    One 'Ni:a:b Yi:c:d Xi:e:f Ci:g:h' string is produced per (N, Y, X, C)
    loop step, with the C loop varying fastest. Each start:stop pair is taken
    from spatial_split, with the stop extended by the axis padding (pad_N,
    pad_Y, pad_X, pad_C) when should_do_padding selects this column and
    iteration.
    """
    def padded_bounds(total, chunk, gran, loops, pad, axis, step):
        mode = dims.spatial_split_mode[axis]
        pad_here = should_do_padding(total, chunk, mode, col, loops, step)
        return spatial_split(total, chunk, gran, loops, mode,
                             step, col, pad_here, pad)[:2]

    def fmt(i_n: int, i_y: int, i_x: int, i_c: int) -> str:
        n_lo, n_hi = padded_bounds(dims.Ni, dims.Nim, dims.Nis_gran, dims.N_loop,
                                   dims.pad_N, 0, i_n)
        y_lo, y_hi = padded_bounds(dims.Yi, dims.Yim, dims.Yis_gran, dims.Y_loop,
                                   dims.pad_Y, 1, i_y)
        x_lo, x_hi = padded_bounds(dims.Xi, dims.Xim, dims.Xis_gran, dims.X_loop,
                                   dims.pad_X, 2, i_x)
        c_lo, c_hi = padded_bounds(dims.Ci, dims.Cim, dims.Cis_gran, dims.C_loop,
                                   dims.pad_C, 3, i_c)
        return (f'Ni:{n_lo}:{n_hi} Yi:{y_lo}:{y_hi} '
                f'Xi:{x_lo}:{x_hi} Ci:{c_lo}:{c_hi}')

    patterns = []
    for i_n in range(dims.N_loop):
        for i_y in range(dims.Y_loop):
            for i_x in range(dims.X_loop):
                for i_c in range(dims.C_loop):
                    patterns.append(fmt(i_n, i_y, i_x, i_c))
    return patterns

def wgt_shim_memory(dims: PadDims) -> str:
    """Describe the weight shim buffer as a single 'Subv' axis of dims.wgt_subv_size."""
    return 'Subv:{}'.format(dims.wgt_subv_size)

def wgt_shim_mm2s(dims: PadDims) -> str:
    """Return the weight shim MM2S access pattern; `dims` is not read."""
    return 'Subv'

def wgt_memtile_memory(dims: PadDims) -> str:
    """Describe the weight memtile buffer as a single 'Subv' axis of dims.wgt_subv_size."""
    return 'Subv:{}'.format(dims.wgt_subv_size)

def wgt_memtile_s2mm(dims: PadDims) -> str:
    """Return the weight memtile S2MM access pattern; `dims` is not read."""
    return 'Subv'

def wgt_memtile_mm2s(dims: PadDims) -> str:
    """Return the weight memtile MM2S access pattern; `dims` is not read."""
    return 'Subv'

def align_core_addr(core_addr: int):
    """Return ceildiv(core_addr, 64) * 64.

    NOTE(review): presumably this rounds the address up to a 64-unit
    boundary; that relies on ceildiv (from dataflow_common) being a ceiling
    division.
    """
    alignment = 64
    return ceildiv(core_addr, alignment) * alignment

def gen_qdq_params_a8(subv_elems: int, index: int, quant_offset: int,
                      scratch_buffer_addr: int, fix_point_bits: int,
                      qdq_mode: int, output_addr: int):
    """Pack the kernel parameters for the 'a8' QDQ kernel.

    The result is nine unsigned 16-bit little-endian fields, in order:
    subv_elems, buffer address, is_int16 flag, dequant zero-point index,
    dequant scale index, quant zero-point index, quant scale index,
    dequant-enable index, quant-enable index.

    The buffer address is scratch_buffer_addr only when qdq_mode is 2 and
    fix_point_bits is 8; in every other case it is output_addr.

    Raises:
        OverflowError: if any field does not fit in an unsigned 16-bit value.
    """
    is_int16 = 0 if fix_point_bits == 8 else 1
    use_scratch = qdq_mode == 2 and not is_int16
    buffer_addr = scratch_buffer_addr if use_scratch else output_addr

    dequant_base = 2 * index
    fields = (
        subv_elems,
        buffer_addr,
        is_int16,
        dequant_base,        # dequant zero point
        dequant_base + 1,    # dequant scale
        quant_offset,        # quant zero point
        quant_offset + 1,    # quant scale
        quant_offset + 2,    # dequant enable
        quant_offset + 3,    # quant enable
    )
    return b''.join(
        value.to_bytes(length=2, byteorder='little', signed=False)
        for value in fields
    )


def gen_qdq_params(subv_elems: int, qdq_prm_addr: int, index: int, quant_offset: int):
    """Pack the kernel parameters for the 16-bit QDQ kernel.

    The result is seven unsigned 16-bit little-endian fields, in order:
    subv_elems, dequant zero-point index, dequant scale index, quant
    zero-point index, quant scale index, dequant-enable index, quant-enable
    index. `qdq_prm_addr` is accepted but not encoded.

    Raises:
        OverflowError: if any field does not fit in an unsigned 16-bit value.
    """
    dequant_base = 2 * index
    fields = (
        subv_elems,
        dequant_base,        # dequant zero point
        dequant_base + 1,    # dequant scale
        quant_offset,        # quant zero point
        quant_offset + 1,    # quant scale
        quant_offset + 2,    # dequant enable
        quant_offset + 3,    # quant enable
    )
    return b''.join(
        value.to_bytes(length=2, byteorder='little', signed=False)
        for value in fields
    )

def compile_dataflow(
    dims: PadDims,
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
):
    """Assemble the core, memtile and shim programs for a pad layer and compile them.

    The function lays out the memtile and core buffers, builds the core
    instruction list (empty when dims.qdq_mode == 3) and the per-column
    memtile and shim data transfers, then passes everything to
    run_layer_compilation.

    Args:
        dims: layer description supplying buffer sizes, loop counts, split
            modes, bit widths and the QDQ / ping-pong flags read below.
        back_end: forwarded to run_layer_compilation.
        kernel_names: forwarded to run_layer_compilation.
        kernel_includes: forwarded to run_layer_compilation.

    Raises:
        AssertionError: if the memtile buffers end beyond
            config.MAX_MEMTILE_ADDR or the core buffers end beyond
            overlay_stack_addr().
    """
    # Shim Buffer Allocation
    pad_shim_alloc = shim_alloc()

    # One phase per combined (N, Y, X, C) loop step.
    dims_phase = math.prod(dims.loop)

    prm_memtile_size = compute_buffer_size(prm_memtile_memory(dims))
    wgt_subv_size = dims.wgt_subv_size

    ifm_memtile_size = dims.ifm_memtile_size
    ofm_memtile_size = dims.ofm_memtile_size
    ifm_core_size = dims.ifm_core_size
    ofm_core_size = dims.ofm_core_size

    # Memtile layout: params | weights | IFM ping[/pong] | OFM ping[/pong].
    prm_memtile_addr = 0
    wgt_memtile_addr_ping = prm_memtile_addr + prm_memtile_size
    if dims.ping_pong:
        ifm_memtile_addr_ping = wgt_memtile_addr_ping + wgt_subv_size
        ifm_memtile_addr_pong = ifm_memtile_addr_ping + ifm_memtile_size
        ofm_memtile_addr_ping = ifm_memtile_addr_pong + ifm_memtile_size
        ofm_memtile_addr_pong = ofm_memtile_addr_ping + ofm_memtile_size
        ifm_memtile_addrs = [ifm_memtile_addr_ping, ifm_memtile_addr_pong]
        ofm_memtile_addrs = [ofm_memtile_addr_ping, ofm_memtile_addr_pong]
        assert ofm_memtile_addr_pong + ofm_memtile_size <= config.MAX_MEMTILE_ADDR
    else:
        ifm_memtile_addr_ping = wgt_memtile_addr_ping + wgt_subv_size
        ofm_memtile_addr_ping = ifm_memtile_addr_ping + ifm_memtile_size
        ifm_memtile_addrs = [ifm_memtile_addr_ping]
        ofm_memtile_addrs = [ofm_memtile_addr_ping]
        assert ofm_memtile_addr_ping + ofm_memtile_size <= config.MAX_MEMTILE_ADDR

    # Core layout (only when a kernel runs): weights | scratch | IFM | OFM,
    # each start passed through align_core_addr.
    if dims.qdq_mode != 3:
        CoreWgtPingAddr = 0
        scratch_buf_addr = align_core_addr(CoreWgtPingAddr + wgt_subv_size)
        if dims.ping_pong:
            CoreIfmPingAddr = align_core_addr(scratch_buf_addr + dims.scratch_buf_size)
            CoreIfmPongAddr = align_core_addr(CoreIfmPingAddr + ifm_core_size)
            CoreOfmPingAddr = align_core_addr(CoreIfmPongAddr + ifm_core_size)
            CoreOfmPongAddr = align_core_addr(CoreOfmPingAddr + ofm_core_size)
            assert CoreOfmPongAddr + ofm_core_size <= overlay_stack_addr()
        else:
            CoreIfmPingAddr = align_core_addr(scratch_buf_addr + dims.scratch_buf_size)
            CoreOfmPingAddr = align_core_addr(CoreIfmPingAddr + ifm_core_size)
            CoreIfmPongAddr = None
            CoreOfmPongAddr = None
            assert CoreOfmPingAddr + ofm_core_size <= overlay_stack_addr()
    else:
        CoreWgtPingAddr = None
        scratch_buf_addr = None
        CoreIfmPingAddr = None
        CoreIfmPongAddr = None
        CoreOfmPingAddr = None
        CoreOfmPongAddr = None

    Tn = dims_phase

    # NOTE(review): a single core loop configuration is said to support at
    # most 1024 iterations, after which the core must be re-configured even
    # if the configuration is unchanged. That limit is NOT enforced here:
    # Loop(Tn, ...) below is issued unsplit. Confirm Tn <= 1024 for the
    # supported shapes or add the chunking.

    # QDQ to be added.
    run_kernel = 'run_combined_qdq' if dims.fix_point_bits == 16 else 'run_combined_qdq_a8'
    if dims.qdq_mode != 3:
        kernel_params = gen_qdq_params(iceil(math.prod(dims.core_subv), 64),
                                    CoreWgtPingAddr, 0, 2) if dims.fix_point_bits == 16 else \
                        gen_qdq_params_a8(iceil(math.prod(dims.core_subv), 64), 0, 2,
                                        scratch_buf_addr, dims.fix_point_bits,
                                        dims.qdq_mode, CoreOfmPingAddr)
    # kernel_params is only defined when qdq_mode != 3; the conditional
    # expression below evaluates the instruction list only in that case.
    core_instrs = [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreIfmPingAddr, CoreIfmPongAddr, ifm_core_size),
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreWgtPingAddr, None, wgt_subv_size),
        ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOfmPingAddr, CoreOfmPongAddr, ofm_core_size),
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
        Loop(Tn, [
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
            CallKernel(run_kernel, kernel_params),
            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
        ])
    ] if dims.qdq_mode != 3 else []

    # Parameters: written once (first phase only) and fanned out to every row.
    memtile_prm_transfer = [
        DataTransfer(
            [1] + [0] * (dims_phase - 1),
            AieTile(TileType.Memtile, col), [prm_memtile_addr], prm_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(dims),
                prm_memtile_s2mm())],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(dims),
                prm_memtile_mm2s(row),
            ) for row in range(dims.aie_rows)],
            sync_strategy=SyncStrategy.Serial_M_to_N,
            ) for col in range(dims.aie_cols)
        ]

    # Without QDQ the IFM buffer is read back directly on MM2S channel 5 with
    # the OFM pattern; with QDQ it is distributed to the rows instead.
    memtile_ifm_transfer = [
        DataTransfer(
            Yi_repeat_counts(dims, dims_phase),
            AieTile(TileType.Memtile, col), ifm_memtile_addrs, ifm_memtile_size,
            [pack_reconfig_transfers(
                memtile_dma(col, DmaDir.S2MM, 0),
                ifm_memtile_memory(dims, col),  # pylint: disable=E1121
                ifm_memtile_s2mm(dims, col),    # pylint: disable=E1121
                [1] * dims_phase,
                dims.ifm_bits)],
            [pack_reconfig_transfers(
                memtile_dma(col, DmaDir.MM2S, 5),
                ifm_memtile_memory(dims, col),   # pylint: disable=E1121
                ofm_memtile_mm2s(dims, col),  # pylint: disable=E1121
                [1] * dims_phase,
                dims.ofm_bits)]
            ) for col in range(dims.aie_cols)
        ] if not dims.is_qdq else [
            DataTransfer(
                Yi_repeat_counts(dims, dims_phase),
                AieTile(TileType.Memtile, col), ifm_memtile_addrs, ifm_memtile_size,
                [pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.S2MM, 0),
                    ifm_memtile_memory(dims, col),
                    ifm_memtile_s2mm(dims, col),
                    [1] * dims_phase,
                    dims.ifm_bits)],
                [pack_reconfig_transfers(
                    memtile_dma(col, DmaDir.MM2S, row),
                    ifm_memtile_memory(dims, col),
                    ifm_memtile_mm2s(dims, col, row),
                    [1] * dims_phase,
                    dims.ifm_bits
                    ) for row in range(dims.aie_rows)],
                sync_strategy=SyncStrategy.Parallel_1_to_N,
                ) for col in range(dims.aie_cols)
            ]

    # Weights are only moved when QDQ is enabled; with 8 columns every second
    # column gets a transfer.
    memtile_wgt_transfer = [] if not dims.is_qdq else [
        DataTransfer(
            [1] + [0] * (dims_phase - 1),
            AieTile(TileType.Memtile, col), [wgt_memtile_addr_ping], dims.wgt_subv_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 1),
                wgt_memtile_memory(dims),
                wgt_memtile_s2mm(dims))],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, 4),
                wgt_memtile_memory(dims),
                wgt_memtile_mm2s(dims))],
            ) for col in range(0, dims.aie_cols, (2 if dims.aie_cols==8 else 1))
        ]

    memtile_ofm_transfer = [] if not dims.is_qdq else[
        DataTransfer(
            Yi_repeat_counts(dims, dims_phase),
            # The buffer size must be the OFM size: ofm_memtile_addrs are
            # spaced by ofm_memtile_size in the layout above.
            AieTile(TileType.Memtile, col), ofm_memtile_addrs, ofm_memtile_size,
            [pack_reconfig_transfers(
                memtile_dma(col, DmaDir.S2MM, 2 + row),
                ofm_memtile_memory(dims, col),
                ofm_memtile_s2mm(dims, col, row),
                [1] * dims_phase,
                dims.ofm_bits,
            ) for row in range(dims.aie_rows)],
            [pack_reconfig_transfers(
                memtile_dma(col, DmaDir.MM2S, 5),
                ofm_memtile_memory(dims, col),
                ofm_memtile_mm2s(dims, col),
                [1] * dims_phase,
                dims.ofm_bits,
            )],
            sync_strategy=SyncStrategy.Parallel_N_to_1,
            ) for col in range(dims.aie_cols)
        ]

    memtile_transfers = (memtile_prm_transfer + memtile_ifm_transfer
                         + memtile_wgt_transfer + memtile_ofm_transfer)

    shim_prm_transfer = [
        generate_shim_data_transfer(
            [1] + [0] * (dims_phase - 1),
            shim_dma(col, DmaDir.MM2S, 0), pad_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col)
            ) for col in range(dims.aie_cols)
        ]

    # Size in bytes of the whole IFM tensor as stored in the shim buffer.
    shim_ifm_size = dims.input[0] * dims.input[1] * dims.input[2] * dims.Cip * (dims.ifm_bits // 8)
    shim_ifm_transfer = [
        DataTransfer(
            ifm_shim_repeat_counts(dims),
            AieTile(TileType.Shim, col), [pad_shim_alloc.ifm_buffer_id], shim_ifm_size,
            [],
            [pack_reconfig_transfers(
                shim_dma(col, DmaDir.MM2S, 0),
                [ifm_shim_memory(dims)] * dims_phase,
                ifm_shim_mm2s(dims, col),
                [1] * dims_phase,
                dims.ifm_bits,
                )] ,
            ) for col in range(dims.aie_cols)
        ]

    shim_wgt_transfer = [
        generate_shim_data_transfer(
            [1] + [0] * (dims_phase - 1),
            shim_dma(col, DmaDir.MM2S, 1), pad_shim_alloc.wgt_buffer_id,
            wgt_shim_memory(dims),
            wgt_shim_mm2s(dims)
            ) for col in range(0, dims.aie_cols, (2 if dims.aie_cols==8 else 1) )
    ] if dims.is_qdq else []

    # Size in bytes of the whole OFM tensor as stored in the shim buffer.
    shim_ofm_size = dims.output[0] * dims.output[1] * dims.output[2] * dims.Cop * dims.ofm_bits // 8
    shim_ofm_transfer = [
        DataTransfer(
            ifm_shim_repeat_counts(dims),
            AieTile(TileType.Shim, col), [pad_shim_alloc.ofm_buffer_id], shim_ofm_size,
            [pack_reconfig_transfers(
                shim_dma(col, DmaDir.S2MM, 0),
                [ofm_shim_memory(dims)] * dims_phase,
                ofm_shim_s2mm(dims, col),
                [1] * dims_phase,
                dims.ofm_bits
                )],
            [],
            ) for col in range(dims.aie_cols)
        ]

    shim_transfers = (shim_prm_transfer + shim_ifm_transfer
                      + shim_wgt_transfer + shim_ofm_transfer)

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_8x4_dma_connections(),
        back_end=back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
        enable_debug_print=True
    )


def main():
    """Compile and build a sample pad layer: 1x32x32x768 padded to 1x35x35x768.

    The overlay is cleaned, the dataflow is compiled for an 8x4 array with
    the Adf back end, and the simulation overlay is built from
    'slice_main.cpp'.
    """
    target = BackEnd.Adf
    kernels = ['run_combined_qdq']
    includes = ['super.hh', 'qdq/wrapper_qdq.cc']

    in_shape = [1, 32, 32, 768]
    out_shape = [1, 35, 35, 768]
    in_bits = 16
    out_bits = 16

    dims = PadDims(
        aie_cols=8,
        aie_rows=4,
        input=in_shape,
        Cip=iceil(in_shape[3], 8),
        output=out_shape,
        Cop=iceil(out_shape[3], 8),
        # Per-axis growth from input to output.
        pad_dims=[out_len - in_len for in_len, out_len in zip(in_shape, out_shape)],
        in_gran=[1, 1, 1, 32 // in_bits],
        out_gran=[1, 1, 1, 32 // out_bits],
        pad_limit=[0, 16, 32, 64],  # for aie_2p
        ifm_bits=in_bits,
        ofm_bits=out_bits,
        wgt_subv_size=64,
        qdq_mode=2,  # 0: DEQUANT; 1: QUANT; 2: BOTH; 3: NONE
        fix_point_bits=16,
        param_subv_size=1024,
        ifm_memtile_size=49152,
        ofm_memtile_size=49152,
        ifm_core_size=12288,
        ofm_core_size=12288,
        scratch_buf_size=0,
        core_subv=[1, 1, 8, 768],
        mt_subv=[1, 4, 8, 768],
        loop=[1, 1, 4, 1],
        ping_pong=True,
        spatial_split_mode=[1, 8, 1, 1],
        row_split_mode=[1, 4, 1, 1],
    )

    clean_overlay()
    compile_dataflow(dims, target, kernels, includes)
    build_sim_overlay(target, 'slice_main.cpp', pad_preproc_directives(dims, target))


if __name__ == '__main__':
    main()
