import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels', 'transpose'))
from typing import List, Tuple, Optional, Type
from transpose_kernel_params import TransposeSubvDims, setup_kernel_params
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list

from dmacompiler import \
    OverlayShape, BackEnd, \
    DataTransfer, SyncStrategy, \
    AieTile, TileType, \
    AieDma, DmaDir, memtile_dma, shim_dma, core_dma, DmaChannel, \
    CoreInstr, ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    compute_buffer_size, \
    TransferParams, generate_transfer_params, \
    generate_shim_data_transfer, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config

from dataflow_common import \
    overlay_8x4_dma_connections, \
    overlay_stack_addr, \
    ceildiv, \
    clean_overlay, \
    build_sim_overlay, \
    shim_alloc, \
    prm_shim_memory, \
    prm_shim_mm2s, \
    prm_memtile_memory, \
    prm_memtile_s2mm, \
    prm_memtile_mm2s

from transpose_common import \
    TransposeKernelDims, \
    iceil, \
    transpose_preproc_directives

from transpose_utils import \
    pack_transfers, \
    Yi_slice_stride, \
    Ni_slice_stride, \
    Ci_slice_stride, \
    YiXi_slice, \
    YoXo_slice, \
    Co_slice_stride, \
    YXC_slice_mt, \
    Xi_slice_stride_shim, \
    Xi_slice_stride_mt



# Module-level configuration, applied as a side effect of importing this file.
# Select the Aie2p device generation for the dmacompiler package.
set_dev_gen(DevGen.Aie2p)
# Assign None to dmacompiler's ENABLE_BUSY_POLL option.
# NOTE(review): presumably None means "leave busy polling at the compiler's
# default" -- confirm against the dmacompiler config documentation.
config.ENABLE_BUSY_POLL = None



def wgt_shim_memory(dims: TransposeKernelDims) -> str:
    """Return the shim-side memory format of the weight buffer.

    The result is ``'Subv:<dims.wgt_subv_size>'``.
    """
    subv_size = dims.wgt_subv_size
    return 'Subv:{}'.format(subv_size)

def wgt_shim_mm2s() -> str:
    """Return the shim MM2S tiling format of the weight transfer.

    The format names the ``Subv`` dimension without any range.
    """
    tiling = 'Subv'
    return tiling

def wgt_memtile_memory(dims: TransposeKernelDims) -> str:
    """Return the memtile-side memory format of the weight buffer.

    The result is ``'Subv:<dims.wgt_subv_size>'``.
    """
    subv_size = dims.wgt_subv_size
    return 'Subv:{}'.format(subv_size)

def wgt_memtile_s2mm() -> str:
    """Return the memtile S2MM tiling format of the weight transfer.

    The format names the ``Subv`` dimension without any range.
    """
    tiling = 'Subv'
    return tiling

def wgt_memtile_mm2s() -> str:
    """Return the memtile MM2S tiling format of the weight transfer.

    The format names the ``Subv`` dimension without any range.
    """
    tiling = 'Subv'
    return tiling

def ifm_shim_repeat_counts(dims: TransposeKernelDims) -> List[int]:
    """Return the repeat-count list used for the IFM shim transfer.

    Every entry is 1.  The list has ``loop_exclude_N * B_loop`` entries when
    ``dims.N_dim_optimize`` is truthy and ``repeat_scale * B_loop`` entries
    otherwise.
    """
    if not dims.N_dim_optimize:
        phase_count = dims.repeat_scale * dims.B_loop
        return [1 for _ in range(phase_count)]
    per_batch = [1] * dims.loop_exclude_N
    return per_batch * dims.B_loop

def ofm_shim_repeat_counts(dims: TransposeKernelDims, scale: int) -> List[int]:
    """Return the repeat-count list used for the OFM shim transfer.

    Every entry is 1.  The list has ``loop_exclude_N * B_loop`` entries when
    ``dims.N_dim_optimize`` is truthy and ``repeat_scale * B_loop`` entries
    otherwise.  ``scale`` is accepted but not used by this function.
    """
    if not dims.N_dim_optimize:
        phase_count = dims.repeat_scale * dims.B_loop
        return [1 for _ in range(phase_count)]
    per_batch = [1] * dims.loop_exclude_N
    return per_batch * dims.B_loop

def ifm_memtile_repeat_counts(dims: TransposeKernelDims, col: int, scale: int) -> List[int]:
    """Return the repeat-count list used for the memtile data transfers.

    With ``dims.N_dim_optimize`` truthy, every entry is ``N_loop * scale`` and
    the list has ``loop_exclude_N * B_loop`` entries.  Otherwise every entry
    is 1 and the list has ``repeat_scale * B_loop`` entries.  ``col`` is
    accepted but not used by this function.
    """
    if not dims.N_dim_optimize:
        phase_count = dims.repeat_scale * dims.B_loop
        return [1 for _ in range(phase_count)]
    per_batch = [dims.N_loop * scale] * dims.loop_exclude_N
    return per_batch * dims.B_loop

def ifm_shim_memory(dims: TransposeKernelDims):
    """Return the shim-side memory format of the whole input tensor.

    The string lists the B, N, Y, X and C extents taken from
    ``dims.batch_size``, ``dims.Nip``, ``dims.Yip``, ``dims.Xip`` and
    ``dims.Cip``, in that order.
    """
    extents = (
        ("B", dims.batch_size),
        ("N", dims.Nip),
        ("Y", dims.Yip),
        ("X", dims.Xip),
        ("C", dims.Cip),
    )
    return " ".join(f"{tag}:{extent}" for tag, extent in extents)

def ifm_memtile_memory(dims: TransposeKernelDims, col: int):
    """Build the IFM memtile buffer layouts of column ``col``.

    Returns a list of ``"N:.. Y:.. X:.. C:.."`` strings, one per
    (batch, N, Y, X, C) loop iteration.  The N loop collapses to a single
    iteration when ``dims.N_dim_optimize`` is truthy.  X is iterated outside
    Y when ``dims.perm[3] == 0 and dims.N_innermost_Y8 == 1``; otherwise Y is
    iterated outside X.  C is always the innermost loop.
    """
    def layout(y_idx: int, x_idx: int, c_idx: int):
        _, _, y_col_len = Yi_slice_stride(dims, col, y_idx, dims.ifm_bits)
        _, _, x_col_len = Xi_slice_stride_shim(dims, col, x_idx)
        _, _, c_len = Ci_slice_stride(dims, c_idx)
        _, _, y_iter_len = YiXi_slice(dims, y_idx, not dims.is_Y8_split)
        _, _, x_iter_len = YiXi_slice(dims, x_idx, not dims.is_Y8_split)
        # A non-positive slice length is replaced by a fixed extent so the
        # described buffer never has an empty dimension.
        if y_col_len <= 0:
            y_col_len = dims.Yis
        if x_col_len <= 0:
            x_col_len = dims.Xi_gran
        if c_len <= 0:
            c_len = dims.Cis
        # Y8 split: Y comes from the per-column slice, X from the iteration
        # slice.  Otherwise the roles of the two sources are swapped.
        if dims.is_Y8_split:
            y_len, x_len = y_col_len, x_iter_len
        else:
            y_len, x_len = y_iter_len, x_col_len
        return f"N:{dims.Nim} Y:{y_len} X:{x_len} C:{c_len}"

    if dims.perm[3] == 0 and dims.N_innermost_Y8 == 1:
        yx_order = [(yn, xn) for xn in range(dims.X_loop) for yn in range(dims.Y_loop)]
    else:
        yx_order = [(yn, xn) for yn in range(dims.Y_loop) for xn in range(dims.X_loop)]
    n_iters = 1 if dims.N_dim_optimize else dims.N_loop
    return [
        layout(yn, xn, cn)
        for _ in range(dims.B_loop)
        for _ in range(n_iters)
        for yn, xn in yx_order
        for cn in range(dims.C_loop)
    ]


def ofm_memtile_memory(dims: TransposeKernelDims, col: int):
    """Build the OFM memtile buffer layouts.

    Returns a list with one layout string per (batch, N, Y, X, C) loop
    iteration.  The N loop collapses to a single iteration when
    ``dims.N_dim_optimize`` is truthy.  Each string holds the N, Y, X and C
    extents reordered by ``dims.perm``.

    The layout does not depend on the iteration indices (every slice helper
    is queried at index 0), so it is computed once and repeated.  ``col`` is
    accepted but not used by this function.
    """
    n_iters = 1 if dims.N_dim_optimize else dims.N_loop
    # len(range(k)) reproduces the trip count of ``for _ in range(k)``,
    # including 0 for non-positive k.
    total = 1
    for trips in (dims.B_loop, n_iters, dims.Y_loop, dims.X_loop, dims.C_loop):
        total *= len(range(trips))
    if total == 0:
        return []

    _, _, Y_size = Yi_slice_stride(dims, 0, 0, dims.ofm_bits)
    _, _, X_size = Xi_slice_stride_shim(dims, 0, 0)
    _, _, C_size = Ci_slice_stride(dims, 0)

    # X share of one row, rounded up to a multiple of (32 // ofm_bits)
    # elements when X ends up as the last output dimension (perm[3] == 2).
    x_per_row = ceildiv(X_size, dims.aie_rows)
    if dims.perm[3] == 2:
        x_per_row = iceil(x_per_row, 32 // dims.ofm_bits)
    padded_x_memory = x_per_row * dims.aie_rows

    is_Y_split = (Y_size >= C_size) or (C_size % (32 // dims.ifm_bits) != 0)
    _, _, _, _, _, _, Yo_split, _, Co_split, _, _, _ = \
        YXC_slice_mt(dims, Y_size, C_size, 0, is_Y_split=is_Y_split)

    if dims.is_Y8_split:
        # X is not split across cores for the Y8 transfer.
        y_len, x_len, c_len = Yo_split, dims.Xis, Co_split
    else:
        # Y is not split across cores in this path.
        y_len, x_len, c_len = dims.Yis, padded_x_memory, C_size

    dim_list = [f"N:{dims.Nom}",
                f"Y:{y_len}",
                f"X:{x_len}",
                f"C:{c_len}"]
    layout = " ".join(dim_list[i] for i in dims.perm)
    return [layout] * total

def ofm_memtile_s2mm(dims: TransposeKernelDims, col: int, row: int) -> List[str]:
    """Build the OFM memtile S2MM tilings of core ``row`` in column ``col``.

    Returns a list with one tiling string per (batch, N, Y, X, C) loop
    iteration.  The N loop collapses to a single iteration when
    ``dims.N_dim_optimize`` is truthy.  X is iterated outside Y when
    ``dims.perm[3] == 0 and dims.N_innermost_Y8 == 1``; otherwise Y is
    iterated outside X.  Each string holds ``start:stop`` ranges for N, Y, X
    and C reordered by ``dims.perm``.  Only the X iteration index influences
    the tiling.
    """
    def tiling(xn: int) -> str:
        _, _, Y_size = Yi_slice_stride(dims, 0, 0, dims.ofm_bits)
        X_start, X_stop, _ = Xi_slice_stride_mt(
            dims, col, row, xn, dims.ofm_bits, not dims.is_Y8_split)
        _, _, C_size = Ci_slice_stride(dims, 0)
        # YXC combined split:
        # 1. Because of zero padding, the dimension that is transposed to
        #    the innermost position is not split.
        # 2. For the Y8 split across columns:
        #    1) Y can be innermost after the transpose, but needs no padding;
        #    2) X can be innermost after the transpose (e.g. X=3); it is very
        #       hard to split X across rows while keeping zero padding;
        #    3) a transpose that leaves C innermost does not go through this
        #       dataflow routine.
        is_Y_split = (Y_size >= C_size) or (C_size % (32 // dims.ifm_bits) != 0)
        Y_start, Y_stop, _, _, C_start, C_stop, _, _, _, _, _, _ = \
            YXC_slice_mt(dims, Y_size, C_size, row, True, is_Y_split=is_Y_split)
        if dims.is_Y8_split:
            # X is not split across cores for the Y8 split, so each core
            # writes the full X sub-volume and its own Y/C share.
            y_rng = (Y_start, Y_stop)
            x_rng = (0, dims.Xis)
            c_rng = (C_start, C_stop)
        else:
            # Y is not split across cores for the X8 split.
            y_rng = (0, dims.Yis)
            x_rng = (X_start, X_stop)
            c_rng = (0, C_size)
        dim_list = [f"N:0:{dims.Nom}",
                    f"Y:{y_rng[0]}:{y_rng[1]}",
                    f"X:{x_rng[0]}:{x_rng[1]}",
                    f"C:{c_rng[0]}:{c_rng[1]}"]
        return " ".join(dim_list[i] for i in dims.perm)

    if dims.perm[3] == 0 and dims.N_innermost_Y8 == 1:
        yx_order = [(yn, xn) for xn in range(dims.X_loop) for yn in range(dims.Y_loop)]
    else:
        yx_order = [(yn, xn) for yn in range(dims.Y_loop) for xn in range(dims.X_loop)]
    n_iters = 1 if dims.N_dim_optimize else dims.N_loop
    return [
        tiling(xn)
        for _ in range(dims.B_loop)
        for _ in range(n_iters)
        for _, xn in yx_order
        for _ in range(dims.C_loop)
    ]

def ofm_memtile_mm2s(dims: TransposeKernelDims, col: int) -> List[str]:
    """Build the OFM memtile MM2S tilings of column ``col``.

    Returns a list with one tiling string per (batch, N, Y, X, C) loop
    iteration.  The N loop collapses to a single iteration when
    ``dims.N_dim_optimize`` is truthy.  X is iterated outside Y when
    ``dims.perm[3] == 0 and dims.N_innermost_Y8 == 1``; otherwise Y is
    iterated outside X.  Each string holds ``0:stop`` ranges for N, Y, X and
    C reordered by ``dims.perm``.
    """
    def tiling(yn: int, xn: int, cn: int) -> str:
        _, _, Y_size = Yi_slice_stride(dims, col, yn, dims.ofm_bits)
        _, _, X_size = Xi_slice_stride_shim(dims, col, xn)
        _, _, C_len = Co_slice_stride(dims, cn)
        _, _, Y_iter_len = YoXo_slice(dims, yn, not dims.is_Y8_split)
        # NOTE: rounding the Y length up to a multiple of (32 // ofm_bits)
        # when perm[3] != 0 in the non-Y8 path used to happen here and is
        # currently disabled; the reason was never written down.
        _, _, X_iter_len = YoXo_slice(dims, xn, not dims.is_Y8_split)
        # Non-positive per-column sizes are clamped to an empty range.
        Y_stop = 0 if Y_size <= 0 else Y_size
        X_stop = 0 if X_size <= 0 else X_size
        if dims.is_Y8_split:
            y_hi, x_hi = Y_stop, X_iter_len
        else:
            y_hi, x_hi = Y_iter_len, X_stop
        dim_list = [f"N:0:{dims.Nom}",
                    f"Y:0:{y_hi}",
                    f"X:0:{x_hi}",
                    f"C:0:{C_len}"]
        return " ".join(dim_list[i] for i in dims.perm)

    if dims.perm[3] == 0 and dims.N_innermost_Y8 == 1:
        yx_order = [(yn, xn) for xn in range(dims.X_loop) for yn in range(dims.Y_loop)]
    else:
        yx_order = [(yn, xn) for yn in range(dims.Y_loop) for xn in range(dims.X_loop)]
    n_iters = 1 if dims.N_dim_optimize else dims.N_loop
    return [
        tiling(yn, xn, cn)
        for _ in range(dims.B_loop)
        for _ in range(n_iters)
        for yn, xn in yx_order
        for cn in range(dims.C_loop)
    ]

def ofm_memtile_mm2s_no_kernel(dims: TransposeKernelDims, col: int) -> List[str]:
    """Build the memtile MM2S output tilings of column ``col`` for the no-kernel path.

    Returns a list with one tiling string per (batch, N, Y, X, C) loop
    iteration.  The N loop collapses to a single iteration when
    ``dims.N_dim_optimize`` is truthy.  X is iterated outside Y when
    ``dims.perm[3] == 0 and dims.N_innermost_Y8 == 1``; otherwise Y is
    iterated outside X.  Each string holds ``0:stop`` ranges for N, Y, X and
    C reordered by ``dims.perm``.
    """
    def tiling(yn: int, xn: int, cn: int) -> str:
        _, _, Y_size = Yi_slice_stride(dims, col, yn, dims.ofm_bits)
        _, _, X_size = Xi_slice_stride_shim(dims, col, xn)
        _, _, C_len = Co_slice_stride(dims, cn)
        _, _, Y_iter_len = YoXo_slice(dims, yn, not dims.is_Y8_split)
        # NOTE: rounding the Y length up to a multiple of (32 // ofm_bits)
        # when perm[3] != 0 in the non-Y8 path used to happen here and is
        # currently disabled; the reason was never written down.
        _, _, X_iter_len = YoXo_slice(dims, xn, not dims.is_Y8_split)
        # Non-positive per-column sizes are clamped to an empty range.
        Y_stop = 0 if Y_size <= 0 else Y_size
        X_stop = 0 if X_size <= 0 else X_size
        if dims.is_Y8_split:
            y_hi, x_hi = Y_stop, X_iter_len
        else:
            y_hi, x_hi = Y_iter_len, X_stop
        dim_list = [f"N:0:{dims.Nom}",
                    f"Y:0:{y_hi}",
                    f"X:0:{x_hi}",
                    f"C:0:{C_len}"]
        return " ".join(dim_list[i] for i in dims.perm)

    if dims.perm[3] == 0 and dims.N_innermost_Y8 == 1:
        yx_order = [(yn, xn) for xn in range(dims.X_loop) for yn in range(dims.Y_loop)]
    else:
        yx_order = [(yn, xn) for yn in range(dims.Y_loop) for xn in range(dims.X_loop)]
    n_iters = 1 if dims.N_dim_optimize else dims.N_loop
    return [
        tiling(yn, xn, cn)
        for _ in range(dims.B_loop)
        for _ in range(n_iters)
        for yn, xn in yx_order
        for cn in range(dims.C_loop)
    ]


def ofm_shim_memory(dims: TransposeKernelDims):
    """Return the shim-side memory format of the whole output tensor.

    The string starts with the batch extent (``dims.batch_size``), followed
    by the N, Y, X and C output extents (``dims.Nop``, ``dims.Yop``,
    ``dims.Xop``, ``dims.Cop``) reordered by ``dims.perm``.
    """
    # Extents in the original (untransposed) N, Y, X, C order.
    nyxc = (
        ("N", dims.Nop),
        ("Y", dims.Yop),
        ("X", dims.Xop),
        ("C", dims.Cop),
    )
    # The batch dimension always comes first; the rest follow the permutation.
    tokens = [f"B:{dims.batch_size}"]
    for axis in dims.perm:
        tag, extent = nyxc[axis]
        tokens.append(f"{tag}:{extent}")
    return " ".join(tokens)

def ofm_shim_s2mm(dims: TransposeKernelDims, col: int) -> List[str]:
    """Build the OFM shim S2MM tilings of column ``col``.

    Returns a list with one tiling string per (batch, N, Y, X, C) loop
    iteration.  The N loop collapses to a single iteration when
    ``dims.N_dim_optimize`` is truthy.  X is iterated outside Y when
    ``dims.perm[3] == 0 and dims.N_innermost_Y8 == 1``; otherwise Y is
    iterated outside X.  Each string starts with the batch range and then
    holds ``start:stop`` ranges for N, Y, X and C reordered by ``dims.perm``.

    Raises:
        AssertionError: if ``dims.perm`` contains an axis other than 0..3.
    """
    def tiling(bn: int, nn: int, yn: int, xn: int, cn: int) -> str:
        Y_start, Y_stop, _ = Yi_slice_stride(dims, col, yn, dims.ofm_bits)
        X_start, X_stop, _ = Xi_slice_stride_shim(dims, col, xn)
        N_start, N_stop, _ = Ni_slice_stride(dims, nn)
        if dims.perm[3] == 0:
            # N is the last output dimension: always write the full N range.
            N_start = 0
            N_stop = dims.Nop
        C_start, C_stop, _ = Co_slice_stride(dims, cn)
        Yo_start, Yo_stop, _ = YoXo_slice(dims, yn, not dims.is_Y8_split)
        Xo_start, Xo_stop, _ = YoXo_slice(dims, xn, not dims.is_Y8_split)
        # Consolidate the Y8 and X8 splits: the split dimension takes its
        # range from the per-column slice (clamped to the tensor extent), the
        # other one from the iteration slice.
        if dims.is_Y8_split:
            y_rng = (Y_start, min(dims.Yop, Y_stop))
            x_rng = (Xo_start, Xo_stop)
        else:
            y_rng = (Yo_start, Yo_stop)
            x_rng = (X_start, min(dims.Xop, X_stop))
        if dims.batch_to_repeat:
            b_rng = (0, dims.batch_size)
        else:
            b_rng = (bn, bn + 1)
        if dims.N_dim_optimize:
            n_rng = (0, dims.Nop)
        else:
            n_rng = (N_start, min(dims.Nop, N_stop))
        axis_map = {
            0: ("N", n_rng[0], n_rng[1]),
            1: ("Y", y_rng[0], y_rng[1]),
            2: ("X", x_rng[0], x_rng[1]),
            3: ("C", C_start, C_stop),
        }
        # The batch dimension always comes first, with an explicit step of 1.
        parts = [f"B:{b_rng[0]}:{b_rng[1]}:1"]
        for axis in dims.perm:
            try:
                name, start, stop = axis_map[axis]
            except KeyError as err:
                raise AssertionError(
                    f"Current data path doesn't support perm: {dims.perm} "
                    f"in Kernel mode, try data path only mode"
                ) from err
            parts.append(f"{name}:{start}:{stop}")
        return " ".join(parts)

    if dims.perm[3] == 0 and dims.N_innermost_Y8 == 1:
        yx_order = [(yn, xn) for xn in range(dims.X_loop) for yn in range(dims.Y_loop)]
    else:
        yx_order = [(yn, xn) for yn in range(dims.Y_loop) for xn in range(dims.X_loop)]
    n_iters = 1 if dims.N_dim_optimize else dims.N_loop
    return [
        tiling(bn, nn, yn, xn, cn)
        for bn in range(dims.B_loop)
        for nn in range(n_iters)
        for yn, xn in yx_order
        for cn in range(dims.C_loop)
    ]

def ifm_memtile_s2mm(dims: TransposeKernelDims, col: int) -> List[str]:
    """Build the IFM memtile S2MM tilings of column ``col``.

    Returns a list with one ``"N:0:.. Y:0:.. X:0:.. C:0:.."`` string per
    (batch, N, Y, X, C) loop iteration.  The N loop collapses to a single
    iteration when ``dims.N_dim_optimize`` is truthy.  X is iterated outside
    Y when ``dims.perm[3] == 0 and dims.N_innermost_Y8 == 1``; otherwise Y is
    iterated outside X.
    """
    def tiling(yn: int, xn: int, cn: int) -> str:
        _, _, Y_size = Yi_slice_stride(dims, col, yn, dims.ifm_bits)
        _, _, X_size = Xi_slice_stride_shim(dims, col, xn)
        _, _, C_len = Ci_slice_stride(dims, cn)
        _, _, Y_iter_len = YiXi_slice(dims, yn, not dims.is_Y8_split)
        _, _, X_iter_len = YiXi_slice(dims, xn, not dims.is_Y8_split)
        # When this column's share of the split dimension is empty, every
        # range collapses to 0:0.
        if (Y_size <= 0 and dims.is_Y8_split) or (X_size <= 0 and not dims.is_Y8_split):
            return "N:0:0 Y:0:0 X:0:0 C:0:0"
        if dims.is_Y8_split:
            y_hi, x_hi = Y_size, X_iter_len
        else:
            y_hi, x_hi = Y_iter_len, X_size
        return f"N:0:{dims.Nim} Y:0:{y_hi} X:0:{x_hi} C:0:{C_len}"

    if dims.perm[3] == 0 and dims.N_innermost_Y8 == 1:
        yx_order = [(yn, xn) for xn in range(dims.X_loop) for yn in range(dims.Y_loop)]
    else:
        yx_order = [(yn, xn) for yn in range(dims.Y_loop) for xn in range(dims.X_loop)]
    n_iters = 1 if dims.N_dim_optimize else dims.N_loop
    return [
        tiling(yn, xn, cn)
        for _ in range(dims.B_loop)
        for _ in range(n_iters)
        for yn, xn in yx_order
        for cn in range(dims.C_loop)
    ]


def ifm_shim_mm2s(dims: TransposeKernelDims, col: int) -> List[str]:
    """Build the IFM shim MM2S tilings of column ``col``.

    Returns a list with one ``"B:.. N:.. Y:.. X:.. C:.."`` string per
    (batch, N, Y, X, C) loop iteration.  The N loop collapses to a single
    iteration when ``dims.N_dim_optimize`` is truthy.  X is iterated outside
    Y when ``dims.perm[3] == 0 and dims.N_innermost_Y8 == 1``; otherwise Y is
    iterated outside X.
    """
    def tiling(bn: int, nn: int, yn: int, xn: int, cn: int) -> str:
        Y_start, Y_stop, _ = Yi_slice_stride(dims, col, yn, dims.ifm_bits)
        X_start, X_stop, _ = Xi_slice_stride_shim(dims, col, xn)
        N_start, N_stop, _ = Ni_slice_stride(dims, nn)
        C_start, C_stop, _ = Ci_slice_stride(dims, cn)
        Yi_start, Yi_stop, _ = YiXi_slice(dims, yn, not dims.is_Y8_split)
        Xi_start, Xi_stop, _ = YiXi_slice(dims, xn, not dims.is_Y8_split)
        # Consolidate the Y8 and X8 splits: the split dimension takes its
        # range from the per-column slice (clamped to the tensor extent), the
        # other one from the iteration slice.
        if dims.is_Y8_split:
            y_lo, y_hi = Y_start, min(dims.Yip, Y_stop)
            x_lo, x_hi = Xi_start, Xi_stop
        else:
            y_lo, y_hi = Yi_start, Yi_stop
            x_lo, x_hi = X_start, min(dims.Xip, X_stop)
        if dims.batch_to_repeat:
            b_lo, b_hi = 0, dims.batch_size
        else:
            b_lo, b_hi = bn, bn + 1
        # With the N optimisation the whole N range is read in one transfer.
        if dims.N_dim_optimize:
            n_lo, n_hi = 0, dims.Nip
        else:
            n_lo, n_hi = N_start, min(dims.Nip, N_stop)
        return (
            f"B:{b_lo}:{b_hi}:1 "
            f"N:{n_lo}:{n_hi} "
            f"Y:{y_lo}:{y_hi} X:{x_lo}:{x_hi} C:{C_start}:{C_stop}"
        )

    if dims.perm[3] == 0 and dims.N_innermost_Y8 == 1:
        yx_order = [(yn, xn) for xn in range(dims.X_loop) for yn in range(dims.Y_loop)]
    else:
        yx_order = [(yn, xn) for yn in range(dims.Y_loop) for xn in range(dims.X_loop)]
    n_iters = 1 if dims.N_dim_optimize else dims.N_loop
    return [
        tiling(bn, nn, yn, xn, cn)
        for bn in range(dims.B_loop)
        for nn in range(n_iters)
        for yn, xn in yx_order
        for cn in range(dims.C_loop)
    ]

def ifm_memtile_mm2s(dims: TransposeKernelDims, col: int,  row: int) -> List[str]:
    """Build the IFM memtile MM2S tilings of core ``row`` in column ``col``.

    Returns a list with one ``"N:0:.. Y:.. X:.. C:.."`` string per
    (batch, N, Y, X, C) loop iteration.  The N loop collapses to a single
    iteration when ``dims.N_dim_optimize`` is truthy.  X is iterated outside
    Y when ``dims.perm[3] == 0 and dims.N_innermost_Y8 == 1``; otherwise Y is
    iterated outside X.
    """
    def tiling(yn: int, xn: int, cn: int) -> str:
        _, _, Y_size = Yi_slice_stride(dims, col, yn, dims.ifm_bits)
        X_start, X_stop, _ = Xi_slice_stride_mt(dims, col, row, xn, dims.ifm_bits)
        _, _, C_size = Ci_slice_stride(dims, cn if dims.is_Y8_split else 0)
        # YXC combined split:
        # 1. Because of zero padding, the dimension that is transposed to
        #    the innermost position is not split.
        # 2. For the Y8 split across columns:
        #    1) Y can be innermost after the transpose, but needs no padding;
        #    2) X can be innermost after the transpose (e.g. X=3); it is very
        #       hard to split X across rows while keeping zero padding;
        #    3) a transpose that leaves C innermost does not go through this
        #       dataflow routine.
        is_Y_split = (dims.Yim >= dims.Cim) or (dims.Cim % dims.Ci_gran != 0)
        Y_start, Y_stop, _, _, C_start, C_stop, _, _, _, _, _, _ = \
            YXC_slice_mt(dims, Y_size, C_size, row, is_Y_split=is_Y_split)
        if dims.is_Y8_split:
            # Each core reads the full X sub-volume and its own Y/C share.
            y_lo, y_hi = Y_start, Y_stop
            x_lo, x_hi = 0, dims.Xis
            c_lo, c_hi = C_start, C_stop
        else:
            # Each core reads its own X share and the full Y/C extents.
            y_lo, y_hi = 0, dims.Yis
            x_lo, x_hi = X_start, X_stop
            c_lo, c_hi = 0, C_size
        return f"N:0:{dims.Nis} Y:{y_lo}:{y_hi} X:{x_lo}:{x_hi} C:{c_lo}:{c_hi}"

    if dims.perm[3] == 0 and dims.N_innermost_Y8 == 1:
        yx_order = [(yn, xn) for xn in range(dims.X_loop) for yn in range(dims.Y_loop)]
    else:
        yx_order = [(yn, xn) for yn in range(dims.Y_loop) for xn in range(dims.X_loop)]
    n_iters = 1 if dims.N_dim_optimize else dims.N_loop
    return [
        tiling(yn, xn, cn)
        for _ in range(dims.B_loop)
        for _ in range(n_iters)
        for yn, xn in yx_order
        for cn in range(dims.C_loop)
    ]

def transpose_core_instrs(
    dims: TransposeKernelDims,
    inner_loop: int,
    ifm_ping_addr: int,
    ifm_pong_addr: int,
    ofm_ping_addr: int,
    ofm_pong_addr: int,
    core_ifm_size: int,
    core_ofm_size: int,
    core_wgt_addr: int,

    ifm_config: Optional[ConfigBuffer] = None,
    wgt_config: Optional[ConfigBuffer] = None,
    ofm_config: Optional[ConfigBuffer] = None,
) -> List[Type[CoreInstr]]:
    """Build the core instruction list that runs the transpose kernel.

    The list configures the weight buffer on S2MM channel 1, acquires and
    releases it once, and then runs ``inner_loop`` kernel iterations.  Each
    iteration acquires the MM2S and S2MM channel-0 buffers, calls the kernel
    and releases both buffers.

    ``inner_loop`` is halved until it is at most 1024 to get the per-round
    iteration count; the IFM/OFM buffer configuration is re-issued at the
    start of every round.  Any remainder runs in one extra round.

    The ``ifm_config``, ``wgt_config`` and ``ofm_config`` arguments are
    accepted but not used: the three buffer configurations are always built
    from the address and size arguments.
    """
    kernel_name = 'run_transpose' if dims.is_int16_transpose else 'run_transpose_a8'

    def kernel_loop(count: int, subv_shape: list, permute: list):
        n_ext = subv_shape[0]
        y_ext = subv_shape[1]
        x_ext = subv_shape[2]
        c_ext = subv_shape[3]
        body = [
            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            CallKernel(
                kernel_name,
                setup_kernel_params(
                    TransposeSubvDims(
                        n_ext, y_ext, x_ext, c_ext, permute,
                        dims.transpose_bits, dims.scratch_buf,
                        dims.is_int16, dims.perm[3] == 0,
                    )
                ),
            ),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
        ]
        return Loop(count, body)

    # For a core loop, each configuration can support up to 1024 iterations;
    # beyond that the core must be re-configured even when the configuration
    # is unchanged.
    per_round = inner_loop
    while per_round > 1024:
        per_round //= 2
    full_rounds, leftover = divmod(inner_loop, per_round)

    ifm_cfg = ConfigBuffer(
        DmaChannel(DmaDir.S2MM, 0),
        ifm_ping_addr, ifm_pong_addr, core_ifm_size
    )
    ofm_cfg = ConfigBuffer(
        DmaChannel(DmaDir.MM2S, 0),
        ofm_ping_addr, ofm_pong_addr, core_ofm_size
    )
    wgt_cfg = ConfigBuffer(
        DmaChannel(DmaDir.S2MM, 1),
        core_wgt_addr, None, dims.wgt_subv_size
    )

    instrs = [
        wgt_cfg,
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
        Loop(full_rounds, [
            ifm_cfg,
            ofm_cfg,
            kernel_loop(per_round, dims.subv_shape, dims.perm),
        ]),
    ]
    if leftover != 0:
        # Iterations that do not fill a whole round run in one extra round.
        instrs.append(
            Loop(1, [
                ifm_cfg,
                ofm_cfg,
                kernel_loop(leftover, dims.subv_shape, dims.perm),
            ])
        )
    return instrs
def garbage_memory(dims) -> str:
    """Return the memory format of the garbage transfer.

    The result is ``'Mem:<size>'`` where ``size`` is twice
    ``dims.ifm_memtile_size``.
    """
    doubled_size = dims.ifm_memtile_size * 2
    return 'Mem:{}'.format(doubled_size)

def garbage_tiling(dims) -> str:
    """Return the tiling format of the garbage transfer.

    The format names the ``Mem`` dimension without any range.  ``dims`` is
    accepted but not used by this function.
    """
    tiling = 'Mem'
    return tiling



def compile_dataflow(
    dims: TransposeKernelDims,
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
):
    transpose_shim_alloc = shim_alloc()

    param_memtile_size = dims.aie_rows * config.MAX_CORE_LAYER_PARAM_SIZE
    # Nis = dims.Ni_gran
    ifm_memtile_size = (dims.Nis * dims.Yim * dims.Xim * dims.Cim * dims.ifm_bits) // 8
    ofm_memtile_size = (dims.Nom * dims.Yom * dims.Xom * dims.Com * dims.ofm_bits) // 8
    dims.ifm_memtile_size = ifm_memtile_size

    wgt_memtile_addr = 0
    wgt_memtile_addrs = [wgt_memtile_addr]
    param_memtile_addr = wgt_memtile_addr + dims.wgt_subv_size
    usable_memtile_size = config.MAX_MEMTILE_ADDR - param_memtile_size - dims.wgt_subv_size
    total_memtile_size =  (ifm_memtile_size * 2) + ofm_memtile_size

    if total_memtile_size < usable_memtile_size:
        ifm_memtile_ping_addr = param_memtile_addr + param_memtile_size
        ifm_memtile_pong_addr = ifm_memtile_ping_addr + ifm_memtile_size
        ifm_memtile_addrs = [ifm_memtile_ping_addr, ifm_memtile_pong_addr]
    else:
        ifm_memtile_ping_addr = param_memtile_addr + param_memtile_size
        ifm_memtile_addrs = [ifm_memtile_ping_addr]

    ofm_memtile_ping_addr = ifm_memtile_addrs[-1] + ifm_memtile_size
    ofm_memtile_addrs = [ofm_memtile_ping_addr]
    assert ofm_memtile_ping_addr + ofm_memtile_size < config.MAX_MEMTILE_ADDR

    sctrach_buf = dims.Nis * dims.Yis * dims.Xis * dims.Cis * dims.scratch_buf_bits //8 if dims.has_scratch_buf else 0
    CoreIfmSize = dims.Nis * dims.Yis * dims.Xis * dims.Cis * dims.ifm_bits //8
    CoreOfmSize = dims.Nom * dims.Yis * dims.Xis * dims.Cis * dims.ofm_bits //8
    usable_core_size = overlay_stack_addr() - dims.param_subv_size - dims.wgt_subv_size
    total_core_size =  iceil(sctrach_buf, 64 *dims.scratch_buf_bits) + (iceil(CoreIfmSize, 64* dims.ifm_bits) * 2) + iceil(CoreOfmSize, 64 *dims.ofm_bits)
    CoreWgtAddr = 0
    CoreScratchBufAddr = iceil(CoreWgtAddr + dims.wgt_subv_size, 64 * dims.scratch_buf_bits)
    CoreIfmPingAddr = iceil(CoreScratchBufAddr + sctrach_buf, 64 * dims.ifm_bits)

    if total_core_size < usable_core_size:
        CoreOfmPingAddr = iceil(CoreIfmPingAddr + CoreIfmSize, 64 * dims.ifm_bits)
        CoreIfmPongAddr = max(2 * 16384, iceil(CoreOfmPingAddr + CoreOfmSize, 64 * dims.ofm_bits))
        CoreOfmPongAddr = None
        if CoreIfmPingAddr + CoreIfmSize > overlay_stack_addr():
            CoreIfmPongAddr = iceil(CoreIfmPingAddr + CoreIfmSize, 64 * dims.ifm_bits)
            CoreOfmPingAddr = iceil(CoreIfmPongAddr + CoreIfmSize, 64 * dims.ofm_bits)
        assert CoreOfmPingAddr + CoreOfmSize <= overlay_stack_addr()
        assert dims.ping_pong == True
    else:
        CoreOfmPingAddr = max(2 * 16384, iceil(CoreIfmPingAddr + CoreIfmSize, 64 * dims.ofm_bits))
        CoreIfmPongAddr = None
        CoreOfmPongAddr = None
        if CoreOfmPingAddr + CoreOfmSize > overlay_stack_addr():
            CoreOfmPingAddr = iceil(CoreIfmPingAddr + CoreIfmSize, 64 * dims.ofm_bits)
            CoreIfmPongAddr = None
            CoreOfmPongAddr = None
        assert CoreOfmPingAddr + CoreOfmSize <= overlay_stack_addr()

    dims.num_loop =[1 * dims.repeat_scale * dims.batch_size] * dims.aie_cols

    dims.scratch_buf = CoreScratchBufAddr

    dims.loop_exclude_N = dims.Y_loop * dims.X_loop * dims.C_loop
    # dims.N_dim_optimize = 1 if dims.perm[0] == 0 else 0

    batch_to_repeat = True
    dims.batch_to_repeat = batch_to_repeat
    dims.B_loop = 1 if batch_to_repeat else dims.batch_size
    dims.batch_repeat = dims.batch_size if batch_to_repeat else 1

    for attr, value in vars(dims).items():
        print(f"{attr}: {value}")

    dims.is_kernel = False if dims.perm[3] == 3 and dims.qdq_mode ==3 else True

    core_instrs = {}
    if dims.is_kernel:
        for col in range(dims.aie_cols):
            num_loop_col = dims.num_loop[col] # this is the loop count
            for row in range(dims.aie_rows):
                core_instrs[AieTile(TileType.Core, col, row)] = transpose_core_instrs(
                    dims,
                    num_loop_col,
                    CoreIfmPingAddr,
                    CoreIfmPongAddr,
                    CoreOfmPingAddr,
                    CoreOfmPongAddr,
                    CoreIfmSize,
                    CoreOfmSize,
                    CoreWgtAddr,
                )
    else:
        core_instrs = []

    memtile_transfers = []
    memtile_param_transfers = [
        DataTransfer(
            [1] +  [0] * ((dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop -1),
            AieTile(TileType.Memtile, col),
            [param_memtile_addr],
            param_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(dims),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(dims),
                prm_memtile_mm2s(row),
            ) for row in range(dims.aie_rows)],
        ) for col in range(dims.aie_cols)
    ]
    if dims.enable_garbage:
        memtile_garbage_transfers = [
            DataTransfer(
                [1] +  [0] * ((dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop -1),
                AieTile(TileType.Memtile, col),
                [ofm_memtile_ping_addr],
                dims.ifm_memtile_size * 2,
                [generate_transfer_params(
                    memtile_dma(col, DmaDir.S2MM, 0),
                    garbage_memory(dims),
                    garbage_tiling(dims),
                )],
                [],
            ) for col in range(dims.aie_cols)
        ]
        memtile_transfers += memtile_garbage_transfers
    memtile_transfers += memtile_param_transfers
    memtile_ifm_transfer = [
        DataTransfer(
            ifm_memtile_repeat_counts(dims, col, dims.batch_repeat),
            AieTile(TileType.Memtile, col), ifm_memtile_addrs, ifm_memtile_size,
            [pack_transfers(
                memtile_dma(col, DmaDir.S2MM, 0),
                ifm_memtile_memory(dims, col),
                ifm_memtile_s2mm(dims, col),
                [1] * (dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop,
                dims.ifm_bits,
            )],
            [pack_transfers(
                memtile_dma(col, DmaDir.MM2S, row),
                ifm_memtile_memory(dims, col),
                ifm_memtile_mm2s(dims, col, row),
                [1] * (dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop,
                dims.ifm_bits,
            ) for row in range(dims.aie_rows)
            ],
            sync_strategy=SyncStrategy.Parallel_1_to_N,
        ) for col in range(dims.aie_cols)
    ] if dims.is_kernel else []
    memtile_transfers += memtile_ifm_transfer

    memtile_wgt_transfers = [
        DataTransfer(
            [1] + [0] * ((dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop -1),
            AieTile(TileType.Memtile, col), wgt_memtile_addrs, dims.wgt_subv_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 1),
                wgt_memtile_memory(dims),
                wgt_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, 4),
                wgt_memtile_memory(dims),
                wgt_memtile_mm2s(),
            )],
        ) for col in range(0, dims.aie_cols, (2 if dims.aie_cols == 8 else 1))
    ] if dims.is_kernel else []
    memtile_transfers += memtile_wgt_transfers

    memtile_ofm_transfer =[
        DataTransfer(
            ifm_memtile_repeat_counts(dims, col, dims.batch_repeat),
            AieTile(TileType.Memtile, col), ofm_memtile_addrs, ofm_memtile_size,
            [pack_transfers(
                memtile_dma(col, DmaDir.S2MM, 2 + row),
                ofm_memtile_memory(dims, col),
                ofm_memtile_s2mm(dims, col, row),
                [1] * (dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop,
                dims.ofm_bits,
            ) for row in range(dims.aie_rows)],
            [pack_transfers(
                memtile_dma(col, DmaDir.MM2S, 5),
                ofm_memtile_memory(dims, col),
                ofm_memtile_mm2s(dims, col),
                [1] * (dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop,
                dims.ofm_bits,
            )],
            sync_strategy=SyncStrategy.Parallel_N_to_1,

        ) for col in range(dims.aie_cols)
    ] if dims.is_kernel else []
    memtile_transfers += memtile_ofm_transfer

    memtile_no_kernel_transfer = [
         DataTransfer(
            ifm_memtile_repeat_counts(dims, col, dims.batch_repeat),
            AieTile(TileType.Memtile, col), ifm_memtile_addrs, ifm_memtile_size,
            [pack_transfers(
                memtile_dma(col, DmaDir.S2MM, 0),
                ifm_memtile_memory(dims, col),
                ifm_memtile_s2mm(dims, col),
                [1] * (dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop,
                dims.ifm_bits,
            )],
            [pack_transfers(
                memtile_dma(col, DmaDir.MM2S, 5),
                ifm_memtile_memory(dims, col),
                ofm_memtile_mm2s_no_kernel(dims, col),
                [1] * (dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop,
                dims.ofm_bits,
            )],
            # sync_strategy=SyncStrategy.Parallel_1_to_N,
        ) for col in range(dims.aie_cols)
    ] if not dims.is_kernel else []

    memtile_transfers += memtile_no_kernel_transfer

    shim_transfers = []
    shim_param_transfers = [
        generate_shim_data_transfer(
            [1] +  [0] * ((dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop -1),
            shim_dma(col, DmaDir.MM2S, 0),
            transpose_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col),
        ) for col in range(dims.aie_cols)
    ]
    if dims.enable_garbage:
        shim_garbage_transfers = [
            generate_shim_data_transfer(
                [1] +  [0] * ((dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop -1),
                shim_dma(col, DmaDir.MM2S, 0),
                transpose_shim_alloc.ifm_buffer_id,
                garbage_memory(dims),
                garbage_tiling(dims),
            ) for col in range(dims.aie_cols)
        ]
        shim_transfers += shim_garbage_transfers

    shim_transfers += shim_param_transfers
    shim_ifm_size =  dims.batch_size * dims.Nip * dims.Yip * dims.Xip * dims.Cip * dims.ifm_bits // 8
    shim_ifm_transfer = [
    DataTransfer(
        ifm_shim_repeat_counts(dims),
        AieTile(TileType.Shim, col), [transpose_shim_alloc.ifm_buffer_id], shim_ifm_size,
        [],
        [pack_transfers(
            shim_dma(col, DmaDir.MM2S, 0),
            [ifm_shim_memory(dims)] * (dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop,
            ifm_shim_mm2s(dims, col),
            [1] * (dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop,
            dims.ifm_bits,
        ) ] ,
        ) for col in range(dims.aie_cols)
    ]

    shim_transfers += shim_ifm_transfer

    shim_wgt_transfers = [generate_shim_data_transfer(
            [1] + [0] * ((dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop -1),
            shim_dma(col, DmaDir.MM2S, 1), shim_alloc().wgt_buffer_id,
            wgt_shim_memory(dims),
            wgt_shim_mm2s(),
        ) for col in range(0, dims.aie_cols, (2 if dims.aie_cols == 8 else 1))
    ] if dims.is_kernel else []
    shim_transfers += shim_wgt_transfers

    shim_ofm_size =  dims.batch_size * dims.Nop * dims.Yop * dims.Xop * dims.Cop * dims.ofm_bits // 8
    shim_ofm_transfer = [
    DataTransfer(
        ofm_shim_repeat_counts(dims, 1),
        AieTile(TileType.Shim, col), [transpose_shim_alloc.ofm_buffer_id], shim_ofm_size,
        [pack_transfers(
            shim_dma(col, DmaDir.S2MM, 0),
            [ofm_shim_memory(dims)] * (dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop,
            ofm_shim_s2mm(dims, col),
            [1] * (dims.repeat_scale if not dims.N_dim_optimize else dims.loop_exclude_N) * dims.B_loop,
            dims.ofm_bits,
        ) ] ,
        [],
        ) for col in range(dims.aie_cols)
    ]
    shim_transfers += shim_ofm_transfer

    """NOTE keep below for local PM_size test"""
    # kernel_names = {'run_transpose':kernel_func_list.index('run_transpose')} if dims.is_int16_transpose else \
    #                {'run_transpose_a8':kernel_func_list.index('run_transpose')}

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_8x4_dma_connections(),
        back_end=back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
        enable_debug_print=True
    )


def main():
    """Compile and build the transpose dataflow for a fixed list of test cases.

    The back end is selected by the first command-line argument, which is
    converted to ``int`` and passed to ``BackEnd``. For every
    ``(input shape, permutation)`` pair in the local table the function
    builds a ``TransposeKernelDims``, cleans the overlay, compiles the
    dataflow, and builds the simulation overlay from ``transpose_main.cpp``.

    Exits with a usage message if the back-end argument is missing.
    """
    if len(sys.argv) < 2:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit(f'usage: {os.path.basename(__file__)} <back_end (int)>')
    back_end = BackEnd(int(sys.argv[1]))
    kernel_names = ['run_transpose', 'run_transpose_a8']
    kernel_includes = ['super.hh', 'transpose/wrapper_transpose.cc']
    aie_cols, aie_rows = 8, 4
    # Each entry is [input shape, permutation]; the trailing comments record
    # the Nis/Yis/Xis/Cis values noted by the original author for that case.
    inputs = [
        [[1, 14, 14, 768], [0, 3, 1, 2]], # Nis, Yis, Xis, Cis = 1, 1, 4, 768
        [[1, 197, 12, 64], [0, 3, 1, 2]], # Nis, Yis, Xis, Cis = 1, 1, 4, 64
        [[1, 1, 768, 196], [0, 1, 3, 2]], # Nis, Yis, Xis, Cis = 1, 2, 1, 64
        [[1, 1, 768, 196//4], [0, 1, 3, 2]], # Nis, Yis, Xis, Cis = 1, 1, 192, 49
        [[1, 1, 768//4, 196], [0, 1, 3, 2]], # Nis, Yis, Xis, Cis = 1, 1, 48, 196
        [[3, 64, 1, 64], [0, 2, 3, 1]], # Nis, Yis, Xis, Cis = 1, 2, 1, 64
        [[16, 49, 3, 256], [2, 0, 1, 3]], # Nis, Yis, Xis, Cis = 1, 1, 1, 256
        [[1, 4096, 1, 64], [0, 2, 3, 1]], # Nis, Yis, Xis, Cis = 1, 8, 1, 64
        ]
    ifm_bits = 16
    # `input_shape` rather than `input`, so the builtin input() is not shadowed.
    for input_shape, perm in inputs:
        print("input:", input_shape)
        dims = TransposeKernelDims(
            aie_rows, aie_cols,
            input_shape, perm,
            ifm_bits
        )
        # Clean the overlay before each compile, as the original loop does.
        clean_overlay()
        compile_dataflow(
            dims,
            back_end,
            kernel_names,
            kernel_includes
        )
        build_sim_overlay(back_end, 'transpose_main.cpp', transpose_preproc_directives(dims, back_end))

# Run the test-case sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
