import os
import sys
from typing import List, Union, Optional, Tuple

CURRDIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))

from dmacompiler import (
    BackEnd,
    set_dev_gen, DevGen, config
)

from dataflow_common import ceildiv, iceil, calculate_row_split, overlay_stack_addr, ifloor
from transpose_common import TransposeKernelDims, padding
from dataflow_utils import CommonDims

# Module-level side effects, executed once when this file is imported:
# pass DevGen.Aie2p to dmacompiler's set_dev_gen() and switch on the
# ENABLE_BUSY_POLL flag of the shared dmacompiler `config` object.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True

def C_iters_loop(dims: TransposeKernelDims, is_Y8_split: bool = True) -> List[tuple]:
    """Enumerate every candidate subvolume split that fits in memtile and core memory.

    The function walks a grid of (N split, outer loop count, inner loop count,
    C loop count), derives the shim / memtile / core tile sizes for each
    combination, and records the combinations whose buffers fit.

    NOTE: `dims` is mutated in place while searching (Nim/Nom/Nis, Yim/Yom/Yis,
    Xim/Xom/Xis, Cim/Com/Cis). After the call these fields hold whatever the
    last evaluated iteration wrote, not a chosen candidate.

    Args:
        dims: kernel dimensions; read for the padded input sizes, bit widths,
            permutation and AIE array shape, and used as scratch storage.
        is_Y8_split: True  -> Y is the inner loop and is divided across
                              dims.aie_cols; X is the outer loop.
                     False -> "X8 split": X is the inner loop and is divided
                              across dims.aie_cols; Y is the outer loop.

    Returns:
        A list (possibly empty) of 10-element tuples:
            ((Nim, Nom, Nis), (Yim, Yom, Yis), (Xim, Xom, Xis), (Cim, Com, Cis),
             (N_loop, C_loop, inner_loop, outer_loop),
             is_Y8_split, repeat_scale, lower_loop, over_compute_ratio, ping_pong)
        An empty list is returned immediately when the dimension being split
        across columns is also the one transposed to inner-most and its size
        is not a multiple of 8.

    Raises:
        AssertionError: when perm[3] == 1 and the ifm Xip*Cip step size reaches
            2**20 words, or when Cip is not a multiple of 32 // ifm_bits and
            Xip is not either.

    Definitions (original author's notes):
        1. Nis, Yis, Xis, Cis -- the subv to the core
        2. Nim, Yim, Xim, Cim -- the shim ifm split for each dimension
           Nom, Yom, Xom, Com -- the ofm memtile memory
           1) Nim = Nis, Nom = Nis
           2) if is_Y8_split:
                Yim = ceildiv(dims.Yip, Y_loop * dims.aie_cols)
                Yom = iceil(Yim, dims.aie_rows) if perm[3] != 1 else
                      iceil(iceil(Yim, 32 // dims.ifm_bits), dims.aie_rows)
                Yis = iceil(Yim, dims.aie_rows)
                # --> if X being transposed to inner-most.
                Xim = ceildiv(dims.Xip, X_loop)
                Xis = iceil(Xim, 32// dims.ifm_bits) if perm[3] == 2 else Xim
                Xom = Xis
                #because Ci already max(64, W8)
                Cis = ceildiv(Cip, C_loop)
                Cim = Cis
                Com = Cis
            3) if is_X8_split:
                Yim = ceildiv(dims.Yip, Y_loop)
                Yis = iceil(Yim, 32// dims.ifm_bits) if perm[3] == 1 else Yim
                Yom = Yis
                Xim = ceildiv(dims.Xip, X_loop * dims.aie_cols)
                Xis = ceildiv(Xim, dims.aie_rows) if perm[3] != 2 else
                      iceil(ceildiv(Xim, 32 // dims.ifm_bits), dims.aie_rows)
                Xom = Xis * dims.aie_rows
                Cis = ceildiv(Cip, C_loop)
                Cim = Cis
                Com = Cis
    """
    # The string below is a free-standing design note (a no-op expression
    # statement, not part of the docstring) on how pad/depad is handled.
    """how to do pad and depad concurrently while not have to remember which iteration
        1. for Ci -> Co depad:
            if ceildiv(Cip, Cim) == ceildiv(Cop, Cim):
               the C_loop will gurantee for Ci and Co
               the depadding should within 64
               when last Cop split is 0, handle specially
        2. for Yi/Xi -> Yo/Xo padding
             if ceildiv(Yip, Yim) == ceildiv(Yop, Yim):  # same for X
               the Y_loop will gurantee same for Yi and Yo
               the padding should within 64
               when last Yop split is 0, handle specially

    """
    def verify_padding(total_size: int,
                       split: int,
                       loop_size: int,
                       padding_limit: int = 16) -> bool:
        # Returns True when the padding needed by the last iteration is below
        # `padding_limit`. The size left for the last iteration is
        # total_size - split * (loop_size - 1) * aie_cols; `reminder` (sic) is
        # how far that is from the next multiple of `split`.
        # NOTE(review): when the leftover is an exact multiple of `split` the
        # modulo is 0 and `reminder` equals `split`, so the check fails
        # whenever split >= padding_limit -- confirm this is intended.
        reminder = split - (total_size - split * (loop_size -1) * dims.aie_cols) % split
        if reminder >= padding_limit:
            return False
        else:
            return True
    # NOTEs --- exceptions
    # 1. padding, D0:64, D1:32, D2:16
    # 2. memtile_step size = 2**20 words
    #    1) possible three dims
    #    2) possible two dims.
    #    3) but still there is a chance the padding will cause an invalid wrap issue
    subv_splits = []
    # The dimension split across columns must be a multiple of 8 when it is
    # also the dimension that ends up inner-most; otherwise this split mode
    # yields no candidates.
    if is_Y8_split and dims.perm[3] == 1 and dims.Yip % 8 != 0:
        return []
    if not is_Y8_split and dims.perm[3] == 2 and dims.Xip % 8 != 0:
        return []
    # N_dim_optimize = 1 if dims.perm[0] == 0 or dims.enable_batch else 0
    # N tile candidates: every divisor of Nip when N optimisation is on,
    # otherwise only the fixed N granularity.
    Nis_grid = [n for n in range(1, dims.Nip + 1) if dims.Nip % n == 0] if dims.N_dim_optimize else [dims.Ni_gran]
    # Flags that are True when a step size (counted in 32-bit words) reaches
    # the 2**20-word limit noted above.
    shim_ofm_stepsize_limit_3dim = ((dims.output[1] * dims.output[2] * dims.output[3]) * dims.ofm_bits // 32) >=  2**20
    shim_ofm_stepsize_limit_2dim = ((dims.output[2] * dims.output[3]) * dims.ofm_bits // 32) >=  2**20
    shim_ifm_stepsize_limit_2dim = ((dims.Xip * dims.Cip) * dims.ifm_bits // 32) >=  2**20
    if dims.perm[3] == 1 and shim_ifm_stepsize_limit_2dim:
        # NOTE(review): the message interpolates the boolean flag, not the
        # actual Xip * Cip size -- confirm whether the size was intended.
        assert False, f"Dim-H being transfered to inner-most," + \
                      f" the gran has to be even, but Dim-W * Dim-C {shim_ifm_stepsize_limit_2dim} exceeds shim step-size {2**20}"

    # Memory budgets: memtile space left after per-row layer params and
    # weights; core space below the overlay stack after weights and params.
    memtile_usable_size = config.MAX_MEMTILE_ADDR - (dims.aie_rows * config.MAX_CORE_LAYER_PARAM_SIZE)  - (dims.aie_rows * dims.wgt_subv_size)
    core_usable_size = overlay_stack_addr() - dims.wgt_subv_size - dims.param_subv_size
    # C loop-count candidates: 1 .. ceildiv(Cip, elements per 32-bit word).
    C_loop_grid = list(range(1, ceildiv(dims.Cip, 32 // dims.ifm_bits) + 1))
    if dims.Cip %(32 // dims.ifm_bits) != 0:
        # C is not word-aligned: C cannot be split, and X must be word-aligned.
        C_loop_grid = [1]
        assert dims.Xip % (32 // dims.ifm_bits) == 0, "if C is odd, the W has to be even"
    #NOTE: not split or split with 64
    if dims.C_dePad:
        C_loop_grid = [1, ceildiv(dims.Cip, dims.padding_C)]
    if is_Y8_split:
        # Outer loop runs over X: try 1 and every even count up to
        # ceildiv(Xip, padding_C).
        outer_loop_grid = [1] + list(range(2, ceildiv(dims.Xip, dims.padding_C) + 1, 2))
        # When X lands on an outer output dimension and the ofm step size is
        # at the limit, force one X element per iteration.
        if dims.perm[0] == 2 and dims.Xop != 1 and shim_ofm_stepsize_limit_3dim:
            outer_loop_grid = [dims.Xip]
        if dims.perm[1] == 2 and dims.Xop != 1 and shim_ofm_stepsize_limit_2dim:
            outer_loop_grid = [dims.Xip]
        # Inner loop runs over Y (additionally divided across columns below).
        inner_loop_grid = list(range(1, ceildiv(dims.Yip, 32 // dims.ifm_bits) + 1))
        if dims.perm[0] == 1 and dims.Yop != 1 and shim_ofm_stepsize_limit_3dim:
            inner_loop_grid = [ceildiv(dims.Yip, dims.aie_cols)]
        # NOTE(review): this check uses the *ifm* 2-dim limit while the
        # parallel checks use the *ofm* limit -- confirm intended.
        if dims.perm[1] == 1 and dims.Yop != 1 and shim_ifm_stepsize_limit_2dim:
            inner_loop_grid = [ceildiv(dims.Yip, dims.aie_cols)]
        if dims.Yip != 1 and shim_ifm_stepsize_limit_2dim:
            inner_loop_grid = [ceildiv(dims.Yip, dims.aie_cols)]
    else: #X8 split -outer_loop is Y
        outer_loop_grid = list(range(1, ceildiv(dims.Yip, dims.padding_C) + 1))
        if dims.perm[0] == 1 and dims.Yop != 1 and shim_ofm_stepsize_limit_3dim:
            outer_loop_grid = [dims.Yip]
        if dims.perm[1] == 1 and dims.Yop != 1 and shim_ofm_stepsize_limit_2dim:
            outer_loop_grid = [dims.Yip]
        if dims.Yip != 1 and shim_ifm_stepsize_limit_2dim:
            outer_loop_grid = [dims.Yip]
        inner_loop_grid = list(range(1, ceildiv(dims.Xip, 32 // dims.ifm_bits) + 1))
        if dims.perm[0] == 2 and dims.Xop != 1 and shim_ofm_stepsize_limit_3dim:
            inner_loop_grid = [ceildiv(dims.Xip, dims.aie_cols)]

    # Padded ofm memtile extents; assigned in the outer/inner loop body and
    # used for the memtile size estimate inside the C loop.
    Yom_p = None
    Xom_p = None
    for Nis in Nis_grid:
        dims.Nim = Nis
        dims.Nom = Nis if dims.perm[3] != 0 else dims.Nop #NOTE: we don't do N split in perm[3] == 0 mode
        dims.Nis = Nis if dims.perm[3] != 0 else dims.Nip  #NOTE: we don't do N split in perm[3] == 0 mode
        N_loop = ceildiv(dims.Nip, Nis)
        for outer_loop in outer_loop_grid:
            for inner_loop in inner_loop_grid:
                if is_Y8_split:
                    # ---- outer dimension: X ----
                    dims.Xim = ceildiv(dims.Xip, outer_loop)
                    dims.Xim = iceil(dims.Xim, 32 // dims.ifm_bits) if dims.perm[3] == 2 else dims.Xim
                    # here need to consider both ofm and ifm be 32bits alignment
                    dims.Xim = iceil(dims.Xim, 32 // dims.ofm_bits) if dims.perm[3] == 2 else dims.Xim
                    if (dims.Xip > 1024) and (dims.Xim > dims.Xip) and (dims.Xip % (32 // dims.ifm_bits ) != 0):
                        continue
                    # Rebinds the loop variable to the iteration count implied
                    # by the (possibly rounded-up) tile size.
                    outer_loop = ceildiv(dims.Xip, dims.Xim)
                    dims.Xis = dims.Xim
                    dims.Xom = dims.Xis
                    #NOTE: the padding might happen here, the memtile ofm real transfer length will
                    #      count the padded as well.
                    if dims.perm[3] == 2:
                        Xom_p = padding(dims.Xom)
                    else:
                        Xom_p = dims.Xom
                    # Skip when the last X tile would need 32 or more elements of padding.
                    if dims.Xim - (dims.Xip % dims.Xim) >= 32 and (dims.Xip % dims.Xim) != 0: # 32 is D1 dimension padding capability
                        continue
                    if dims.perm[3] == 2:
                        if dims.Xop % dims.Xim > 64: # 64 is D1 dimension padding capability
                            continue
                    # ---- inner dimension: Y, divided across columns ----
                    Yi_size = ceildiv(dims.Yip, inner_loop * dims.aie_cols)
                    dims.Yim = iceil(Yi_size, 32 //dims.ifm_bits) if dims.perm[3] ==1 else Yi_size
                    if dims.perm[3] == 1 or dims.perm[3] == 0:
                        Yom = iceil(Yi_size, 32 // dims.ofm_bits)
                        Yom = iceil(Yom, dims.aie_rows)
                    else:
                        Yom = iceil(Yi_size, dims.aie_rows)
                    dims.Yom = Yom
                    dims.Yis = iceil(Yi_size, dims.aie_rows)
                    Yom_p = Yom
                else:
                    # ---- outer dimension: Y ----
                    dims.Yim = ceildiv(dims.Yip, outer_loop)
                    dims.Yim = iceil(dims.Yim, 32 // dims.ifm_bits) if dims.perm[3] == 1 else dims.Yim
                    # here need to consider both ofm and ifm be 32bits alignment
                    dims.Yim = iceil(dims.Yim, 32 // dims.ofm_bits) if dims.perm[3] == 1 else dims.Yim
                    # Rebinds the loop variable to the iteration count implied
                    # by the (possibly rounded-up) tile size.
                    outer_loop = ceildiv(dims.Yip, dims.Yim)
                    dims.Yis = dims.Yim
                    dims.Yom = dims.Yis
                    #NOTE: the padding might happen here, the memtile ofm real transfer length will
                    #      count the padded as well.
                    if dims.perm[3] == 1:
                        Yom_p = padding(dims.Yom)
                    else:
                        Yom_p = dims.Yom
                    # Skip when the last Y tile would need 16 or more elements of padding.
                    if dims.Yim - (dims.Yip % dims.Yim) >= 16 and (dims.Yip % dims.Yim) != 0: # 16 is D2 dimension padding capability
                        continue
                    if dims.perm[3] == 1:
                        if dims.Yop % dims.Yim > 64:
                            continue
                    # ---- inner dimension: X, divided across columns then rows ----
                    dims.Xim = ceildiv(dims.Xip, inner_loop * dims.aie_cols)
                    if dims.perm[3] == 2:
                        X_split = ceildiv(dims.Xim, dims.aie_rows)
                        dims.Xis = iceil(X_split, 32 // dims.ifm_bits)
                    else:
                        X_split = ceildiv(dims.Xim, dims.aie_rows)
                        dims.Xis = X_split
                    dims.Xom = dims.Xis * dims.aie_rows
                    Xom_p = dims.Xom
                for C_loop in C_loop_grid:
                    # C extent per iteration: with de-padding a split uses
                    # exactly padding_C; otherwise an even division of Cip.
                    if dims.C_dePad:
                        if C_loop == 1:
                            C_size = ceildiv(dims.Cip, C_loop)
                        else:
                            C_size = dims.padding_C
                    else:
                        C_size = ceildiv(dims.Cip, C_loop)
                    C_size = iceil(C_size, dims.Ci_gran)

                    # Feasibility filters. Y_split / C_split are only used for
                    # these checks; the tile sizes actually recorded are
                    # computed further below.
                    if is_Y8_split:
                        if dims.Yis > C_size: # Y split to core
                            Y_split = ceildiv(dims.Yis, dims.aie_rows)
                            Y_split = iceil(Y_split, 32 //dims.ifm_bits) if dims.perm[3] == 1 else Y_split
                            C_split = C_size
                            # Cis has to consider the A8 Ofm.
                            # C_split = iceil(C_split, 32 //dims.ofm_bits) if dims.perm[3] == 3 else C_split
                            if C_split % (32 //dims.ofm_bits) != 0 and dims.perm[3] ==3:
                                continue
                        else: # C split to core
                            Y_split = dims.Yis
                            C_split = ceildiv(C_size, dims.aie_rows)
                            C_split = iceil(C_split, dims.Ci_gran)
                            if C_split % (32 //dims.ofm_bits) != 0 and dims.perm[3] ==3:
                                continue
                            Yis_temp = ceildiv(dims.Yip, dims.aie_cols * inner_loop)
                            if dims.perm[3] == 1:
                                Yis_temp = iceil(Yis_temp, 32//dims.ifm_bits)
                            Y_size_padding_valid = verify_padding(dims.Yip, Yis_temp, inner_loop, padding_limit = 16)
                            if not Y_size_padding_valid:
                                continue
                    else: #X8 split and #X4->core split
                        # Skip when the last C tile would need 64 or more elements of padding.
                        if C_size - dims.Cip % C_size >= 64 and dims.Cip % C_size != 0:
                            continue


                    if is_Y8_split or (dims.perm[3] == 0 and dims.N_innermost_Y8 == 1):
                        Cim = 64 if (C_loop != 1 and dims.N_innermost_Y8 == 1) else ceildiv(dims.Cip, C_loop)
                        if C_loop > 1:
                            Cim = iceil(Cim, dims.Ci_gran)
                        dims.Cim = iceil(Cim, dims.Ci_gran) if (Cim * dims.Xim) % (dims.Ci_gran) != 0 \
                            else Cim
                        #recalculate the accurate subv
                        Ci_size = Cim
                        Yi_size = ceildiv(dims.Yip, inner_loop * dims.aie_cols)
                        # Split Y across rows when Y is at least as large as C
                        # or C is not a multiple of its granularity; otherwise
                        # split C across rows.
                        is_Y_split = (dims.Yim >= dims.Cim) or (Cim % (dims.Ci_gran) !=0)
                        # Row-0 slice; only Yo_size, Co_size, Yis and Cis of
                        # the 12 returned values are kept.
                        _, _, _ ,_, _, _, dims.Yom, _, Com, dims.Yis, _, dims.Cis = \
                            YXC_slice_mt(dims, Yi_size, Ci_size, 0, is_Y_split = is_Y_split)
                        # dims.subv_shape = [dims.Nis, dims.Yis, dims.Xis, dims.Cis]
                        repeat_scale = N_loop * C_loop * outer_loop * inner_loop
                        dims.Com = Com
                        dims.Xom = dims.Xis
                    else:
                        Cis = ceildiv(dims.Cip, C_loop)
                        dims.Cis = min(dims.Cip, iceil(Cis, dims.Ci_gran))
                        dims.Cim = dims.Cis
                        dims.Com = dims.Cis
                        repeat_scale = N_loop * outer_loop * inner_loop * C_loop
                    #calculate the L2 size
                    # Sizes are in bytes. The ofm estimate uses Yom_p / Xom_p
                    # as computed before the C loop.
                    ifm_memtile_size = (dims.Nim * dims.Yim * dims.Xim * C_size * dims.ifm_bits) // 8
                    ofm_memtile_size = (dims.Nom * Yom_p * Xom_p * C_size * dims.ofm_bits) // 8
                    # ifm is counted twice, ofm once.
                    total_memtile_size = (
                        ifm_memtile_size * 2 +
                        ofm_memtile_size
                    )
                    # calculate the L1 size
                    ifm_core_size = (dims.Nis * dims.Yis * dims.Xis * dims.Cis * dims.ifm_bits) // 8
                    ofm_core_size = (dims.Nom * dims.Yis * dims.Xis * dims.Cis * dims.ofm_bits) // 8
                    scratch_buf = (dims.Nis * dims.Yis * dims.Xis * dims.Cis * dims.scratch_buf_bits) // 8 if dims.has_scratch_buf else 0
                    # Core footprint with a single ifm buffer ...
                    total_core_size_ping = (
                        dims.wgt_subv_size +
                        iceil(ifm_core_size, 64 * dims.ifm_bits) * 1 +
                        iceil(ofm_core_size, 64 * dims.ofm_bits) +
                        iceil(scratch_buf, 64*dims.scratch_buf_bits)
                    )
                    # ... and with two ifm buffers.
                    total_core_size_pingpong = (
                        dims.wgt_subv_size +
                        iceil(ifm_core_size, 64 * dims.ifm_bits) * 2 +
                        iceil(ofm_core_size, 64 * dims.ofm_bits) +
                        iceil(scratch_buf, 64*dims.scratch_buf_bits)
                    )
                    # ping_pong is recorded as a property of the candidate; it
                    # does not decide acceptance.
                    ping_pong = (total_core_size_pingpong < core_usable_size)
                    """N_dim_optimize valid"""
                    #
                    # C loop count times the X loop count (X is the outer loop
                    # in Y8 mode and the inner loop in X8 mode).
                    lower_loop = C_loop * (outer_loop if is_Y8_split else inner_loop)
                    # Accept when the memtile budget is met and the
                    # single-ifm-buffer core footprint fits.
                    if total_memtile_size <= memtile_usable_size and total_core_size_ping < core_usable_size:

                        # Ratio of elements processed over all iterations and
                        # all cores to the elements in the padded input.
                        over_compute_ratio = (repeat_scale * dims.aie_cols * dims.aie_rows * \
                                            dims.Nis * dims.Yis * dims.Xis * dims.Cis) /  \
                                            (dims.Nip * dims.Yip * dims.Xip * dims.Cip)
                        # return C_loop, inner_loop, outer_loop
                        subv_splits.append((
                            (dims.Nim, dims.Nom, dims.Nis),
                            (dims.Yim, dims.Yom, dims.Yis),
                            (dims.Xim, dims.Xom, dims.Xis),
                            (dims.Cim, dims.Com, dims.Cis),
                            (N_loop, C_loop, inner_loop, outer_loop),
                            is_Y8_split,
                            repeat_scale,
                            lower_loop,
                            over_compute_ratio,
                            ping_pong
                        ))
        # assert False, "can't find valid Y iteration!"
    return subv_splits

def YXC_slice_mt(
    dims:TransposeKernelDims,
    Y_size: int,
    C_size: int,
    row:int,
    s2mm: bool = False,
    is_Y_split: bool = True,
    )  -> Tuple[int, int, int, int, int, int, int, int, int, int, int, int]:
    """Compute the Y / X / C slice that one AIE row handles inside a memtile tile.

    Exactly one of Y or C is divided across the `dims.aie_rows` rows; the
    other one and X are passed through whole.

    Args:
        dims: kernel dimensions. Reads Yim, Cim, Xip, aie_rows, ifm_bits, perm,
            Ci_gran, Yi_gran, and (only on the out-of-range fallback path)
            Yis / Cis.
        Y_size: valid Y extent; a value <= 0 is replaced by dims.Yi_gran.
        C_size: valid C extent; a value <= 0 is replaced by dims.Ci_gran.
        row: index of the AIE row the slice is computed for.
        s2mm: when True the slice is always [row * split, (row + 1) * split);
            when False a slice starting at or beyond the valid extent falls
            back to start 0 and stop dims.Yis (or dims.Cis).
        is_Y_split: True divides Y across rows, False divides C across rows.

    Returns:
        A 12-tuple
        (Y_start, Y_stop, X_start, X_stop, C_start, C_stop,
         Yo_size, Xo_size, Co_size, Yis, Xis, Cis).
    """
    def Y_slice(dims: TransposeKernelDims,
                 Y_size: int, row: int, s2mm: bool = False,
                 is_Y: bool = True) -> Tuple[int, int, int, int]:
        """Return (start, stop, total size over all rows, per-row size).

        Used for either Y (is_Y=True) or C (is_Y=False).
        """
        if is_Y:
            per_row = ceildiv(dims.Yim, dims.aie_rows)
            # Y becomes inner-most after the transpose: keep it 32-bit aligned.
            if dims.perm[3] == 1:
                per_row = iceil(per_row, 32 //dims.ifm_bits)
        else:
            # Splitting C: round the per-row share up to the C granularity.
            per_row = iceil(ceildiv(dims.Cim, dims.aie_rows), dims.Ci_gran)

        start = row * per_row
        if not s2mm and start >= Y_size:
            # This row lies entirely beyond the valid extent: fall back to
            # the first slice, sized by the current core subvolume.
            start = 0
            stop = dims.Yis if is_Y else dims.Cis
        else:
            stop = start + per_row
        return (start, stop, per_row * dims.aie_rows, per_row)

    # Non-positive sizes fall back to the dimension's granularity.
    if Y_size <= 0:
        Y_size = dims.Yi_gran
    if C_size <=0:
        C_size = dims.Ci_gran

    if is_Y_split:
        Y_start, Y_stop, Yo_size, Yis = Y_slice(dims, Y_size, row, s2mm)
        # C is not divided: every row sees the whole memtile C extent.
        C_start = 0
        C_stop = Co_size = Cis = dims.Cim
    else:
        # Y is not divided: every row sees the whole memtile Y extent.
        Y_start = 0
        Y_stop = Yo_size = Yis = dims.Yim
        C_start, C_stop, Co_size, Cis = Y_slice(dims, C_size, row, s2mm, is_Y=False)

    # X is never divided here; it is only rounded up to a 32-bit boundary
    # when X is the dimension that ends up inner-most.
    X_start = 0
    if dims.perm[3] == 2:
        X_stop = iceil(dims.Xip, 32 // dims.ifm_bits)
    else:
        X_stop = dims.Xip
    Xo_size = Xis = X_stop

    return(Y_start, Y_stop, X_start, X_stop, C_start, C_stop,
           Yo_size, Xo_size, Co_size, Yis, Xis, Cis)

def sort_subv_cost(subv_splits: List, perm: List, enable_batch: bool = False) -> List:
    """Return the candidate splits ordered from cheapest to most expensive.

    Each candidate is a 10-element tuple whose last six entries are
    (loop counts, is_Y8_split, total loop count, lower loop count,
    over-compute ratio, ping-pong flag). The loop counts are stored as
    (N, C, inner, outer); inner/outer mean (Y, X) for a Y8 split and
    (X, Y) otherwise.

    Ordering, most significant first:
        1. with `enable_batch`, candidates whose loop counts along the two
           inner-most output dimensions (perm[2], perm[3]) are both 1 come
           first; without it this criterion is neutral
        2. lower loop count
        3. total loop count
        4. over-compute ratio
        5. ping-pong flag (False sorts before True)

    The input list is left untouched; a new list is returned.
    """
    def cost_of(candidate):
        (_, _, _, _, loops, y8_flag,
         repeat_total, repeat_lower, overcompute, pingpong) = candidate
        n_cnt, c_cnt, inner_cnt, outer_cnt = loops
        if y8_flag:
            y_cnt, x_cnt = inner_cnt, outer_cnt
        else:
            x_cnt, y_cnt = inner_cnt, outer_cnt
        # Loop counts indexed the same way as `perm`: N, Y, X, C.
        per_dim = [n_cnt, y_cnt, x_cnt, c_cnt]
        if not enable_batch:
            shim_penalty = False
        else:
            shim_penalty = per_dim[perm[2]] != 1 or per_dim[perm[3]] != 1
        return (shim_penalty, repeat_lower, repeat_total, overcompute, pingpong)

    ranked = list(subv_splits)
    ranked.sort(key=cost_of)
    return ranked

def subv_split(dims:TransposeKernelDims):
    """Choose the best subvolume split for `dims` and store it on `dims`.

    Steps:
        1. Decide whether to search a Y8 split (Y divided across columns), an
           X8 split (X divided across columns), or both, from the per-column
           extent of each dimension and the multiple-of-8 constraints.
        2. Collect candidates with C_iters_loop() for every selected mode.
        3. Rank them with sort_subv_cost() and copy the best one into `dims`
           (Nim/Nom/Nis, Yim/Yom/Yis, Xim/Xom/Xis, Cim/Com/Cis, is_Y8_split,
           repeat_scale, ping_pong, N_loop/C_loop/Y_loop/X_loop, subv_shape).

    Args:
        dims: kernel dimensions; mutated in place.

    Returns:
        The same `dims` object.

    Raises:
        IndexError: when no candidate split fits. The exception type is the
            one the former unguarded `sorted_subv[0]` produced; it now
            carries a diagnostic message.
    """
    dims.Nis = dims.Ni_gran

    # ===== compute granularity =====
    # A dimension that ends up inner-most must stay 32-bit aligned in the ofm.
    yo_gran = 1 if dims.perm[3] != 1 else 32 // dims.ofm_bits
    xo_gran = 1 if dims.perm[3] != 2 else 32 // dims.ofm_bits

    # ===== compute split scores =====
    EPS = 1e-2
    y8_split_score = dims.Yip / (dims.aie_cols * yo_gran)
    x8_split_score = dims.Xip / (dims.aie_cols * xo_gran)
    score_gap = y8_split_score - x8_split_score

    # ===== check exception =====
    # C_iters_loop() yields nothing for a mode whose column-split dimension
    # is inner-most and not a multiple of 8.
    y8_exception = dims.perm[3] == 1 and (dims.Yip % 8 != 0)
    x8_exception = dims.perm[3] == 2 and (dims.Xip % 8 != 0)
    use_Y8_split = False
    use_X8_split = False

    # ===== priority logic =====
    if x8_exception:
        use_Y8_split = True
    elif y8_exception:
        use_X8_split = True
    elif abs(score_gap) <= EPS:
        # Scores are tied: search both modes and let the cost sort decide.
        use_Y8_split = True
        use_X8_split = True
    elif score_gap >= EPS:
        use_Y8_split = True
    else:
        use_X8_split = True

    # ===== run C_iters_loop =====
    subv_total_splits = []
    if use_Y8_split:
        subv_total_splits.extend(C_iters_loop(dims, True) or [])  # Y8
    if use_X8_split:
        subv_total_splits.extend(C_iters_loop(dims, False) or [])  # X8

    sorted_subv = sort_subv_cost(subv_total_splits, dims.perm, dims.enable_batch)
    if not sorted_subv:
        # Kept as IndexError so existing callers see the same exception type
        # as before, but with enough context to diagnose the shape.
        raise IndexError(
            "subv_split: no valid subvolume split found for padded input "
            f"[N, Y, X, C]=[{dims.Nip}, {dims.Yip}, {dims.Xip}, {dims.Cip}], "
            f"perm={dims.perm} "
            f"(Y8 searched: {use_Y8_split}, X8 searched: {use_X8_split})"
        )

    ((dims.Nim, dims.Nom, dims.Nis),
     (dims.Yim, dims.Yom, dims.Yis),
     (dims.Xim, dims.Xom, dims.Xis),
     (dims.Cim, dims.Com, dims.Cis),
     (N_loop, C_loop, inner_loop, outer_loop),
     dims.is_Y8_split,
     dims.repeat_scale,
     _,  # lower_loop: only used for ranking
     _,  # over_compute_ratio: only used for ranking
     dims.ping_pong,
    ) = sorted_subv[0]

    # inner/outer are (Y, X) for a Y8 split and (X, Y) for an X8 split.
    if dims.is_Y8_split:
        dims.C_loop, dims.Y_loop, dims.X_loop = C_loop, inner_loop, outer_loop
    else:
        dims.C_loop, dims.X_loop, dims.Y_loop = C_loop, inner_loop, outer_loop
    dims.N_loop = N_loop
    dims.subv_shape = [dims.Nis, dims.Yis, dims.Xis, dims.Cis]

    return dims

def run_tiler(
    aie_cols: int,
    aie_rows: int,
    # ifm_bits: int,
    # ofm_bits: int,
    is_int16: bool,
    is_signed: bool,
    perm: list,
    enable_batch: bool,
    batch_size: int,
    Ni: int,
    Yi: int,
    Xi: int,
    Ci: int,
    No: Optional[int],
    Yo: Optional[int],
    Xo: Optional[int],
    Co: Optional[int],
    # is_qdq: bool = True,
    qdq_mode: int = 2,
) -> TransposeKernelDims:
    """Build the transpose kernel dimensions for one layer and pick its tiling.

    Args:
        aie_cols, aie_rows: shape of the AIE array to tile for.
        is_int16: True selects 16-bit ifm/ofm and a 16-bit transpose; the
            value of `qdq_mode` is then only converted with int() and stored.
        is_signed: stored on the result unchanged.
        perm: 4-entry permutation over (N, Y, X, C); output dimension i is
            taken from input dimension perm[i]. perm[3] is the dimension that
            becomes inner-most.
        enable_batch, batch_size: stored on the result unchanged.
        Ni, Yi, Xi, Ci: input shape from the graph.
        No, Yo, Xo, Co: output shape from the graph; stored on the result.
        qdq_mode: for 8-bit operation, selects the bit widths:
            0 -> ifm 8 / ofm 16 (dq only)
            1 -> ifm 16 / ofm 8 (q only)
            2 -> ifm 8 / ofm 8 with a 16-bit scratch buffer
            3 -> ifm 8 / ofm 8, transpose only

    Returns:
        The TransposeKernelDims object after subv_split() has filled in the
        chosen split.

    Raises:
        AssertionError: when `is_int16` is False and `qdq_mode` is not 0..3.
        IndexError: propagated from subv_split() when no split fits.
    """
    has_scratch_buf = False
    scratch_buf_bits = 8
    ifm_bits = 16
    ofm_bits = 16
    qdq_mode = int(qdq_mode)
    if is_int16:
        ifm_bits = 16
        ofm_bits = ifm_bits
        has_scratch_buf = False
        transpose_bits = 16
        is_int16_transpose = True
    else: # int8
        is_int16_transpose = False  # for now, we don't sequence the int8 transpose kernel and qdq order
        if qdq_mode == 0:  #dq only
            #NOTE: sequence:
            # 1. first do transpose (8bits in) -> 8bits output buff (2nd half);
            # 2. then do dq, from 8bits output buf 2nd half to 16bits out buf
            # scratch buf elem:  0
            ifm_bits = 8
            ofm_bits = 16
            has_scratch_buf = False
            scratch_buf_bits = 8
            transpose_bits = 8
        elif qdq_mode == 1: #q only
            #NOTE: sequence:
            # 1. first do q (16bits input buf) -> 8bits to same buf;
            # 2. then do transpose, from 8bits input buf to 8bits out buf
            # scratch buf elem:  0
            ifm_bits = 16
            ofm_bits = 8
            has_scratch_buf = False # q output use ifm buffer
            scratch_buf_bits = 8
            transpose_bits = 8
        elif qdq_mode == 2:
            #NOTE: sequence:
            # 1. first do dq (8bits input buf) -> 16bits to scratch buf;
            # 2. second do q (16bits scratch buf) -> 8bits to scratch buf;
            # 3. then do transpose, from 8bits scratch buf to 8bits out buf
            # scratch buf elem:  same as ifm
            ifm_bits = 8
            ofm_bits = 8
            has_scratch_buf = True
            scratch_buf_bits = 16
            transpose_bits = 8
        elif qdq_mode == 3:
            #NOTE: sequence:
            # 1. do transpose from 8bits input buf to 8bits output buf
            # scratch buf elem:  0
            ifm_bits = 8
            ofm_bits = 8
            has_scratch_buf = False
            scratch_buf_bits = 8
            transpose_bits = 8
        else:
            # Raised explicitly instead of `assert False`: an assert is
            # removed under `python -O`, after which the function would fail
            # later with a NameError on `transpose_bits`.
            raise AssertionError(f"qdq_mode:{qdq_mode} is not in range(0..3) !")

    # NOTE(review): this binds the TransposeKernelDims *class*, not an
    # instance, so every attribute below is set on the class and persists
    # across calls -- confirm whether `TransposeKernelDims()` was intended.
    localDims = TransposeKernelDims
    localDims.input = [Ni, Yi, Xi, Ci]
    localDims.aie_cols = aie_cols
    localDims.aie_rows = aie_rows
    localDims.ifm_bits = ifm_bits
    localDims.ofm_bits = ofm_bits
    localDims.is_int16 = is_int16
    localDims.is_signed = is_signed
    localDims.has_scratch_buf = has_scratch_buf
    localDims.scratch_buf_bits = scratch_buf_bits
    localDims.transpose_bits = transpose_bits
    localDims.is_int16_transpose = is_int16_transpose # keep this for performance optimization
    localDims.enable_batch = enable_batch
    localDims.batch_size = batch_size
    localDims.Ni = Ni
    localDims.Yi = Yi
    localDims.Xi = Xi
    localDims.Ci = Ci
    localDims.No = No
    localDims.Yo = Yo
    localDims.Xo = Xo
    localDims.Co = Co
    localDims.perm = perm
    localDims.inner_most_gran = 8
    localDims.inner_most_min = 8
    localDims.padding_C = 8
    localDims.qdq_mode = qdq_mode
    CoreqdqPrmSize = 64
    localDims.wgt_subv_size = CoreqdqPrmSize

    # Definitions (original author's notes):
    # 1. Ni/Yi/Xi/Ci  - the input from graph, there is no padded information
    # 2. No/Yo/Xo/Co  - the output from graph, there is no padding information
    # 3. when Ci -> Co:
    #     1) if Co is transposed to outer dimension
    #         Cip = max(64, iceil(Ci, 8))
    #         Cop = Co,  if Cip != Co , depadding
    #     2) if Co stays in the inner-most dimension
    #         Cip = max(64, iceil(Ci, 8))
    #         Cop = Cip
    # 4. when Yi/Xi -> Yo/Xo: ( "/" is or, not divide)
    #     1) if Yo/Xo is transposed to inner-most dimension
    #         Yip/Xip = Yi/Xi
    #         Yop/Xop = max(64, iceil(Yi/Xi, 8)) --> padding.
    #     2) if Yo/Xo stays in the outer dimension
    #         Yip/Xip = Yi/Xi
    #         Yop/Xop = Yi/Xi
    # 5. when Ni -> No:
    #     1) if No is transposed to inner-most dimension
    #         Y8 or X8 dataflow (Kernel dataflow) does NOT support
    #         the Kernel Bypass dataflow can support
    #     2) if No stays in the outer dimension
    #         Nip = Ni
    #         Nop = Ni
    #
    # Granularity for each dimension (original author's notes):
    # 1. input
    #     1) because we always assume the input are max(64, W8), there is no
    #        limitation for the granularity
    # 2. output
    #     1) because the padding will happen in the ifm.mm2s + kernel +
    #        ofm.mm2s, there is no limitation for the granularity
    #     2) to check...

    # C: the input side is always padded; the output side drops the padding
    # (de-pad) when C was padded and does not stay inner-most.
    if localDims.Ci != padding(Ci) and perm[3] != 3:
        localDims.C_dePad = True
        localDims.Cop = Ci
        localDims.Cip = padding(Ci)
    else:
        localDims.C_dePad = False
        localDims.Cip = padding(Ci)
        localDims.Cop = localDims.Cip
    # X / Y / N: the input side is never padded; the output side is padded
    # only for the dimension that becomes inner-most.
    if perm[3] == 2:
        localDims.Xop = padding(Xi)
        localDims.Xip = Xi
    else:
        localDims.Xop = Xi
        localDims.Xip = Xi
    if perm[3] == 1:
        localDims.Yop = padding(Yi)
        localDims.Yip = Yi
    else:
        localDims.Yop = Yi
        localDims.Yip = Yi

    if perm[3] == 0:
        localDims.Nop = padding(Ni)
        localDims.Nip = Ni
    else:
        localDims.Nop = Ni
        localDims.Nip = Ni

    localDims.param_subv_size = config.MAX_CORE_LAYER_PARAM_SIZE
    localDims.inner_most_dim = perm[3]
    localDims.N_innermost_Y8 = False
    # Initial split mode and padding amount; subv_split() below overwrites
    # is_Y8_split with the mode of the candidate it selects.
    if localDims.inner_most_dim == 0:
        if perm[2] == 1:
            if Ni > 1:
                localDims.is_Y8_split = False
            else:
                localDims.is_Y8_split = True
        else:
            localDims.is_Y8_split = False
        localDims.num_padding = 0
        localDims.N_innermost_Y8 = localDims.is_Y8_split
    else:
        if localDims.inner_most_dim == 1:
            localDims.is_Y8_split = False
            localDims.num_padding = localDims.Yop -localDims.Yip
        else:
            localDims.is_Y8_split = True
            localDims.num_padding = localDims.Xop -localDims.Xip
    # N is not split when it becomes inner-most: one tile covers all of N.
    localDims.Ni_gran = 1 if perm[3] != 0 else localDims.Nip
    localDims.N_loop = 1 if perm[3] == 0 else localDims.Nip
    localDims.Yi_gran = 1
    localDims.Xi_gran = 1
    # C granularity: elements per 32-bit word for the narrower of ifm / ofm.
    localDims.Ci_gran = max(ceildiv(32, ifm_bits), ceildiv(32, ofm_bits))

    # Padded output shape, reordered by the permutation.
    output = [localDims.Nop, localDims.Yop, localDims.Xop, localDims.Cop]
    localDims.output = [output[i] for i in perm]
    # NOTE(review): the Y test compares the graph *output* Yo with Yop while
    # the X test compares the *input* Xi with Xop; `Yi` looks intended here --
    # confirm before changing, as it alters when "zp" padding is selected.
    if localDims.Yo != localDims.Yop or localDims.Xi != localDims.Xop:
        localDims.padding = [0, 0, 0, "zp"]
    else:
        localDims.padding = [0, 0, 0, 0]
    localDims.enable_garbage = False

    # transposeDims = common + local
    localDims.N_dim_optimize = 1 if perm[0] == 0 or enable_batch else 0


    localDims = subv_split(localDims)

    return localDims