import sys
import os
import logging
import pdb
from OGOAT.src.Scheduling_Engine.infra import const
import OGOAT.src.Scheduling_Engine.infra.scheduler_utils as utils
from OGOAT.src.Scheduling_Engine.code_gen.template_base import BaseTemplate, BaseDims
import copy
import dataclasses
import math
import re
import numpy as np
import itertools

"""
TODO:
1. code readability
2. Transpose on ifm and ofm
3. row wise and column wise batch matmul
4. subarray support
"""

# Repeat-count limits. None of these names is referenced in the visible part
# of this file, so the notes below are assumptions to be confirmed.
# NOTE(review): presumably the maximum repeat count for a shim DMA transfer.
SHIM_REPEAT_MAX = 64
# NOTE(review): presumably the maximum repeat counts for memtile transfers
# (unicast, broadcast, and a generic value); all three are currently equal.
MEM_REPEAT_MAX_UNICAST = 1024
MEM_REPEAT_MAX_BROADCAST = 1024
MEM_REPEAT_MAX = 1024

"""
    Physical AIE Core Grid (col, row) -> core_id:      
    col:    0   1   2   3   4   5   6   7              
    row ┌───┬───┬───┬───┬───┬───┬───┬───┐              
     0  │ 0 │ 4 │ 8 │12 │16 │20 │24 │28 │              
        ├───┼───┼───┼───┼───┼───┼───┼───┤              
     1  │ 1 │ 5 │ 9 │13 │17 │21 │25 │29 │              
        ├───┼───┼───┼───┼───┼───┼───┼───┤              
     2  │ 2 │ 6 │10 │14 │18 │22 │26 │30 │              
        ├───┼───┼───┼───┼───┼───┼───┼───┤              
     3  │ 3 │ 7 │11 │15 │19 │23 │27 │31 │              
        └───┴───┴───┴───┴───┴───┴───┴───┘              
=======================================================
              M_split=1, N_split=32                     
    Logical Matrix Split (M_idx, N_idx):               
    N_idx:  0   1   2   3   4   5   6   7----::----31  
    M_idx ┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐
       0  │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │:: │:: │31 │
          └───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘
----------------------------------------------------   
              M_split=4, N_split=8:                    
    Logical Matrix Split (M_idx, N_idx):               
    N_idx:  0   1   2   3   4   5   6   7              
    M_idx ┌───┬───┬───┬───┬───┬───┬───┬───┐            
       0  │ 0 │ 4 │ 8 │12 │16 │20 │24 │28 │            
          ├───┼───┼───┼───┼───┼───┼───┼───┤            
       1  │ 1 │ 5 │ 9 │13 │17 │21 │25 │29 │            
          ├───┼───┼───┼───┼───┼───┼───┼───┤            
       2  │ 2 │ 6 │10 │14 │18 │22 │26 │30 │            
          ├───┼───┼───┼───┼───┼───┼───┼───┤            
       3  │ 3 │ 7 │11 │15 │19 │23 │27 │31 │            
          └───┴───┴───┴───┴───┴───┴───┴───┘            
----------------------------------------------------   
              M_split=8, N_split=4:                    
    Logical Matrix Split (M_idx, N_idx):               
    N_idx:  0   1   2   3                              
    M_idx ┌───┬───┬───┬───┐                            
       0  │ 0 │ 1 │ 2 │ 3 │                            
          ├───┼───┼───┼───┤                            
       1  │ 4 │ 5 │ 6 │ 7 │                            
          ├───┼───┼───┼───┤                            
       2  │ 8 │ 9 │10 │11 │                            
          ├───┼───┼───┼───┤                            
       3  │12 │13 │14 │15 │                            
          ├───┼───┼───┼───┤                            
       4  │16 │17 │18 │19 │                            
          ├───┼───┼───┼───┤                            
       5  │20 │21 │22 │23 │                            
          ├───┼───┼───┼───┤                            
       6  │24 │25 │26 │27 │                            
          ├───┼───┼───┼───┤                            
       7  │28 │29 │30 │31 │                            
          └───┴───┴───┴───┘                            
----------------------------------------------------   
                  M_split=32, N_split=1:
    Logical Matrix Split (M_idx, N_idx):
    N_idx:  0  
    M_idx ┌───┐
       0  │ 0 │
          ├───┤
       1  │ 1 │
          ├───┤
       2  │ 2 │
          ├───┤
       3  │ 3 │
          ├───┤
       4  │ 4 │
          ├───┤
       5  │ 5 │
          ├───┤
       6  │ 6 │
          ├───┤
       7  │ 7 │
          ├───┤
       :: │:: │
          ├───┤
       30 │30 │
          ├───┤
       31 │31 │
          └───┘

debug help:
M1N32 col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==4
M4N8  col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==4
M8N4  col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==8
M32N1 col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==8          
"""

def column_to_split(dims, col):
    """Return the [B, M, N] indices handled by each row of column ``col``.

    ``dims.split`` selects one of the supported layouts.  The result holds one
    ``[b_idx, m_idx, n_idx]`` triple per AIE row (``dims.aie_rows`` entries,
    in row order).

    Raises:
        ValueError: if ``dims.split`` matches none of the supported layouts.
    """
    # Each entry pairs a split pattern with the rule that builds the triple
    # for row ``r`` of the requested column.
    layouts = (
        ([1, 1, 1, 32], lambda r: [0, 0, r + col * dims.aie_rows]),
        ([1, 4, 1, 8], lambda r: [0, r, col]),
        ([1, 8, 1, 4], lambda r: [0, col, r]),
        ([1, 32, 1, 1], lambda r: [0, r + col * dims.aie_rows, 0]),
        ([4, 8, 1, 1], lambda r: [r, col, 0]),
        ([32, 1, 1, 1], lambda r: [r + col * dims.aie_rows, 0, 0]),
    )
    for pattern, triple_for_row in layouts:
        if dims.split == pattern:
            return [triple_for_row(r) for r in range(dims.aie_rows)]
    raise ValueError(f"Unsupported split mode: M{dims.M_split}N{dims.N_split}")


def get_seq_dimensions(Tb, Tm, Tk, Tn, seq, itr_idx):
    """Return the (b, m, n, k) widths covered by entry ``itr_idx`` of ``seq``.

    Each entry of ``seq`` is a ``[b, m, n, k]`` index.  For every axis the
    width is derived by comparing the entry with its successor:

    * no successor (last entry): the remaining extent, ``T - index``;
    * successor index is smaller: 1;
    * successor index is larger: the difference between the two;
    * successor index is equal: the full extent of the axis when all inner
      loops are fully covered, otherwise 1.

    Note the parameter order (``Tb, Tm, Tk, Tn``) differs from the column
    order of ``seq`` entries (``b, m, n, k``).

    Returns:
        Tuple ``(b_width, m_width, n_width, k_width)``.
    """
    B_AXIS, M_AXIS, N_AXIS, K_AXIS = range(4)  # column order inside a seq entry
    extents = (Tb, Tm, Tn, Tk)
    has_successor = itr_idx < len(seq) - 1

    def width_along(axis, covers_inner_loops):
        here = seq[itr_idx][axis]
        if not has_successor:
            return extents[axis] - here
        upcoming = seq[itr_idx + 1][axis]
        if upcoming < here:
            return 1
        if upcoming == here:
            return extents[axis] if covers_inner_loops else 1
        return upcoming - here

    k_width = width_along(K_AXIS, True)
    n_width = width_along(N_AXIS, k_width == Tk)
    m_width = width_along(M_AXIS, k_width == Tk and n_width == Tn)
    # The batch axis follows the same rule as 'm' but with inner widths of 0,
    # so an unchanged batch index only yields Tb when Tk and Tn are both 0.
    b_width = width_along(B_AXIS, 0 == Tk and 0 == Tn)
    return b_width, m_width, n_width, k_width

def tiling_xpr_seq_generator(Tb, Tm, Tn, Tk, Bpad_A, Bpad_B, Bpad_Y, Mpad_ifm, Kpad_ifm, Kpad_wgt, Npad_wgt, Munpad, Nunpad, sched):
    """List the loop iterations whose input tiles differ from the previous one.

    In the super-kernel's M/N/K loop-nest, DMA padding on the M, K or N
    dimension can make the tile fed to an iteration differ in shape from the
    tile fed to the iteration just before it.  This function records every
    such iteration, orders them in loop-nest execution order (dictionary
    order of (m, n, k)), replicates them per batch index and computes the
    width each entry covers.

    NOTE:
      - the nesting order of the loop-nest is M, N, K (not M, K, N);
      - an iteration is identified by (m, n, k).

    Returns:
        ``(b_seq, seq_dim)`` where ``b_seq`` is a list of ``[b, m, n, k]``
        entries and ``seq_dim`` holds the matching
        ``(b_width, m_width, n_width, k_width)`` tuples from
        ``get_seq_dimensions``.
    """
    # The first iteration is always recorded.
    points = {(0, 0, 0)}

    # K padding: for every (m, n) the first and the last K iteration see a
    # tile shape different from their predecessor.
    if Tk > 1 and (Kpad_wgt != 0 or Kpad_ifm != 0):
        for m, n in itertools.product(range(Tm), range(Tn)):
            points.add((m, n, 0))
            points.add((m, n, Tk - 1))

    # N padding / un-padding: same idea one level up, K fixed at 0.
    if Tn > 1 and (Npad_wgt != 0 or Nunpad != 0):
        for m in range(Tm):
            points.add((m, Tn - 1, 0))
            points.add((m, 0, 0))

    # Schedule 5 records the start of every M iteration.
    if sched == 5:
        points.update((m, 0, 0) for m in range(Tm))

    # M padding / un-padding: the last M iteration.
    if Mpad_ifm != 0 or Munpad != 0:
        points.add((Tm - 1, 0, 0))

    # The set removed duplicates; tuple ordering is the (m, n, k) dictionary
    # order in which the loop-nest executes.
    ordered = sorted(points)

    if Tb > 1 and (Tm > 1 or Tn > 1 or Tk > 1):
        batch_indices = list(range(Tb))
    elif Tb > 1 and (Bpad_A > 0 or Bpad_B > 0 or Bpad_Y > 0):
        batch_indices = [0, Tb - 1]
    else:
        batch_indices = [0]

    b_seq = [[b, m, n, k] for b in batch_indices for (m, n, k) in ordered]
    seq_dim = [
        get_seq_dimensions(Tb, Tm, Tk, Tn, b_seq, position)
        for position in range(len(b_seq))
    ]
    return b_seq, seq_dim


@dataclasses.dataclass(slots=True)
class GemmDims(BaseDims):
    """GEMM-specific dimension record that extends ``BaseDims``.

    Adds padding, un-padding and batch fields to the inherited ones.  The
    per-field notes are inferred from field names and from how the tiling
    helpers in this file read them; confirm against the code that builds
    this object.
    """
    # Passed as the pad amount for the last M iteration and subtracted from
    # dims.M / dims.M_subv by the activation helpers.
    Mpad_ifm: int
    # Subtracted from dims.K_subv on the last K iteration of the activation.
    Kpad_ifm: int
    # NOTE(review): presumably K / N padding of the weight tensor — confirm.
    Kpad_wgt: int
    Npad_wgt: int
    # NOTE(review): presumably M / N un-padding of the output — confirm.
    Munpad: int
    Nunpad: int
    # NOTE(review): meaning not visible in this part of the file — confirm.
    K_ifmB: int
    # NOTE(review): presumably batch split factor and batch iteration count.
    B_split: int
    B_itr  : int
    # The activation helpers use Bpad_A as a batch extent (e.g.
    # ceil(Bpad_A / split[0])), so despite the "pad" in the name these look
    # like batch sizes of tensors A, B and Y — TODO confirm.
    Bpad_A  : int
    Bpad_B  : int
    Bpad_Y  : int

class tiling_funcs():
    def check_dimension(self, dims, col, d):
        """Count the distinct indices along axis ``d`` within one column.

        ``d`` selects which entry of the ``[B, M, N]`` triples returned by
        ``column_to_split`` is inspected.  ``col`` is accepted but not used:
        column 0 is always the one examined.

        Returns:
            Number of unique values found along axis ``d``.
        """
        column_zero = np.asarray(column_to_split(dims, 0))
        distinct = np.unique(column_zero[:, d])
        return distinct.size

    """
    TODO: This function is better named to_consume_partial_tile().

    Consider matmul C = A x B. At core-tile level, these matrices are tiled
    with Msubv/Ksubv/Nsubv in the M/K/N direction, respectively. Let us call
    these tiled matrices "grids", and use A', B' and C' to denote them for A, B
    and C, respectively. Conceptually, the grid and tile here are akin to CUDA's
    grid and block/tile (note: blocking and tiling are interchangeable in the
    domain of loop-nest optimization).

    The super-kernel, which is the code that ultimately runs on the aie-core,
    contains a loop-nest. Let us call it the super-kernel-loop-nest. It is
    normally 3 deep, with the outer-most, middle and inner-most loops going over
    the M, N and K dimensions of the grid, respectively, as illustrated below:

    ----------------------------------------------------
    # the {M|K|N}_h denote host-padded value at {M|K|N} dimension
    for m = 0 to ceil(M_h/Msubv)      # M' loop over the grid
      for n = 0 to ceil(N_h/Nsubv)    # N' loop over the the grid
        for k = 0 to ceil(K_h/Ksubv)  # K' loop could be fully unrolled
          do_computation
    -----------------------------------------------------

    The index of a loop iteration can be denoted by a 3-tuple: (m, n, k).
    However, an index into a particular nesting level can simply be denoted by
    a scalar when the nesting level is clear from the context.

    This function will return true iff given loop-index of particular loop
    (in the super-kernel-loop-nest) will consume partial-tile.

    e.g. if N_h is not a multiple of Nsubv, the right-most tile is a partial
    tile. To detect this situation, call this function with
      self.is_last(T, iter, pad), where
        - T = trip-count of the N' loop, i.e. ceil(N_h/Nsubv)
        - iter = loop index of the N' loop, in [0, T)
        - pad: 0: no padding, positive: padding is needed

    Note that the "last" in the function name refers to the last-iteration
    of particular loop in the aforementioned loop-nest.
    """
    def is_last(self, T, itr, pad):
        return 1 if itr == (T - 1) and (pad > 0) else 0

    def split_val(self, repeat_val, limit):
        repeat = []
        while True:
            if repeat_val > limit:
                repeat.append(limit)
                repeat_val = repeat_val - limit
            else:
                repeat.append(repeat_val)
                break
        return repeat

    def generate_padding_val(self, subv, pad, col_list, idx, enable=True):
        # return size base on subv, pad value, number of columns and col idx
        # total process size is subv * len(col_list) - pad
        # total process size is then dividied by subv and stored into pad list
        # function returns pad list [column index]
        if enable:
            valid = subv * len(col_list) - pad
            pad_list = []
            for x in col_list:
                if valid > subv:
                    pad_list.append(subv)
                    valid -= subv
                else:
                    pad_list.append(valid)
                    valid -= valid
            return pad_list[idx]
        else:
            return min(subv, subv - pad)
        
    def create_padding_list(self, subv, pad, col_list, idx):
        # return size base on subv, pad value, number of columns and col idx
        # total process size is subv * len(col_list) - pad
        # total process size is then dividied by subv and stored into pad list
        # function returns pad list [column index]
        valid = subv * len(col_list) - pad
        pad_list = []
        for x in col_list:
            if valid > subv:
                pad_list.append(subv)
                valid -= subv
            else:
                pad_list.append(valid)
                valid -= valid
        return pad_list[idx]

    def col_index(self, dims: GemmDims, array: int, col: int) -> int:
        assert 0 <= col < dims.aie_cols
        assert 0 <= array < dims.aie_arrays
        return (array * dims.aie_cols) + col
    
    def reuse_chain_length(
        self,
        reuse_ratio: int,
        num_consumers: int,
        max_chain_length: int = 4
    ) -> int:
        max_lock_value = 63
        for i in range(1, max_chain_length + 1):
            is_valid = (
                ((reuse_ratio % i) == 0) and
                (((reuse_ratio // i) * num_consumers) <= max_lock_value)
            )
            if is_valid: return i
        raise RuntimeError(f'Failed to allocate reuse chain! reuse_ratio = {reuse_ratio}, num_consumers = {num_consumers}')
    
    def get_subvol_order_str(self, perm, h_start: int, h_stop: int, w_start: int, w_stop: int) -> str:
        """
        Helper function to get subvolume order string based on permutation.
        
        Args:
            perm: List containing permute order.
            h_start: Start position for height dimension
            h_stop: Stop position for height dimension
            w_start: Start position for width dimension
            w_stop: Stop position for width dimension
            
        Returns:
            String with permuted subvolume ordering for data transfers
        """
        # Extract M,K or K,N order from rev_perm
        perm_vec = np.array(perm)
        subvol_perm = perm_vec[perm_vec != 0] - 1

        subvol_default_order = np.array([f'H:{h_start}:{h_stop}', f'W:{w_start}:{w_stop}'])
        return ' '.join(subvol_default_order[subvol_perm].tolist())


    def extract_CN_from_perm(self, params, tensor: str) -> tuple[int, int]:
        if tensor == 'A':
            perm = params['info_4d']['permA']
            perm_shape = params['info_4d']['in_act_shape_orig']
        elif tensor == 'B':
            perm = params['info_4d']['permB']
            perm_shape = params['info_4d']['in_wgt_shape_orig']
        elif tensor == 'Y':
            perm = params['info_4d']['permY']
            perm_shape = params['info_4d']['out_act_shape_orig']
        else:
            raise ValueError(f"Invalid tensor type: {tensor}")

        C = perm_shape[perm[0]]
        N = perm_shape[perm[1]]
        return perm, C, N

    def act_shim_memory(self, params) -> str:
        dims = params['dims']
        if params['ActFormat'] == "default":
            H_val = dims.M-dims.Mpad_ifm if dims.aie_cols == len(params['active_col']) else dims.M
            if params['transpose_4d'][0]:
                perm, C, N = self.extract_CN_from_perm(params, 'A')
                mem_default_order = np.array([f'C:{C}', f'N:{N}', f'H:{H_val}', f'W:{dims.K}'])
                rev_permA = np.empty(len(perm), dtype=int)
                rev_permA[list(perm)] = np.arange(len(perm))
            else:
                mem_default_order = np.array([f'N:{dims.Bpad_A}', f'H:{H_val}', f'W:{dims.K}'])
                rev_permA = params['rev_permA']
            return ' '.join(mem_default_order[rev_permA].tolist())

        if params['ActFormat'] == "BFP_A8":
            return f"W:{dims.K} H:{dims.M+dims.Mpad_ifm} W:{params['min_cols']}"

    def act_shim_mm2s_batch(self, params, B_start: int, B_len: int, B_stop: int, B_stride_start: int, B_stride_stop: int, B_stride: int) -> tuple[str, str]:
        dims = params['dims']
        if params['transpose_4d'][0]:
            _, C, N = self.extract_CN_from_perm(params, 'A')
            if dims.split[0] == 32:
                if dims.split[0]//8 >= N:
                    bmm_str1                = f'C:{B_stride_start//N}:{B_stride_stop//N}:{B_stride//N}'
                    bmm_str0                = f'C:{B_start//N}:{B_stop//N} N:0:{N}'
                else: #dims.split[0]//8 < N:
                    bmm_str1                = f'C:{B_stride_start//N+B_start//N}:{B_stride_stop//N}:{B_stride//N} N:{0}:{N}:{N}'
                    bmm_str0                = f'N:{B_start%N}:{B_start%N+B_len}'
            elif dims.split[0] == 4:
                if dims.split[0] > N:
                    bmm_str1                = f'C:{B_stride_start//N}:{B_stride_stop//N}:{B_stride//N}'
                    bmm_str0                = f'C:{B_start//N}:{B_stop//N} N:0:{N}'
                else: #dims.split[0] <= N:
                    bmm_str1                = f'C:{B_stride_start//N}:{B_stride_stop//N}:{max(B_stride//N, 1)} N:{0}:{N}:{dims.split[0]}'
                    bmm_str0                = f'N:{B_start%dims.split[0]}:{B_start%N+B_len}'
            else:
                bmm_str1                    = f'C:{B_stride_start}:{B_stride_stop//N}:{B_stride} N:0:{N}:1'
                bmm_str0                    = f'N:{B_start}:{B_stop}'
        else:
            bmm_str1                    = f'N:{B_stride_start}:{B_stride_stop}:{B_stride}'
            bmm_str0                    = f'N:{B_start}:{B_stop}'
        return bmm_str1, bmm_str0

    def act_shim_mm2s(self, params, col, itr_idx) -> str:
        """Build the shim MM2S access string of the activation for one column.

        The string combines the batch terms from ``act_shim_mm2s_batch``, an
        outer stride term (over W in "stream" mode, over H in "pin" mode) and
        the H/W sub-volume ranges, for tiling-sequence entry ``itr_idx``.

        Args:
            params: scheduling parameters; reads ``dims``, ``tiling_seq``,
                ``tiling_seq_dim``, ``rev_permA``, ``wgt_idx``, ``ifm_idx``,
                ``ActDataFlow``, ``ActFormat`` and ``transpose_4d``.
            col: column index the string is generated for.
            itr_idx: index into ``params['tiling_seq']`` and
                ``params['tiling_seq_dim']``.

        Returns:
            The access string for "stream"/"pin" data flow with the "default"
            format.  Any combination not listed below falls through and
            yields ``None``.

        Raises:
            RuntimeError: if the activation format is "BFP_A8" (unsupported
                for both "stream" and "pin" data flow).
        """
        #M1N32 col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==4
        #M4N8  col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==4
        #M8N4  col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==8
        #M32N1 col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==8
        dims = params['dims']
        b_split = dims.split[0]
        batch_ceil = math.ceil(dims.Bpad_A / b_split)
        # Batch entries missing from the last (partial) batch iteration.
        batch_pad = batch_ceil * b_split - dims.Bpad_A
        b_width, m_width, n_width, k_width = params['tiling_seq_dim'][itr_idx]
        B_itr, M_itr, N_itr, K_itr = params['tiling_seq'][itr_idx]
        last_b = self.is_last(batch_ceil, B_itr, batch_pad)
        last_m = self.is_last(dims.Tm, M_itr, dims.Mpad_ifm)
        last_k = self.is_last(dims.Tk, K_itr, dims.Kpad_ifm)
        col_aie_b = self.check_dimension(dims, 0, 0)
        rev_permA = params['rev_permA']
        Bfactor = int(np.ceil(b_split / len(params['wgt_idx'][1])))        #B32, B8, B4, B1 = [4, ?, 1, 1]
        # NOTE(review): act_memtile_s2mm disables batch padding only for a
        # split of 4, whereas here a split of 1 disables it as well — confirm
        # that the difference is intended.
        Bpadding = b_split not in [1, 4]

        B_stride_start = min(B_itr * b_split, dims.Bpad_A)
        B_stride_stop = B_stride_start + b_width * b_split
        B_stride = b_split
        B_start = col * Bfactor * Bpadding
        B_len = self.generate_padding_val(col_aie_b, last_b * batch_pad, params['ifm_idx'][1], col, Bpadding)
        B_stop = B_start + B_len

        bmm_str1, bmm_str0 = self.act_shim_mm2s_batch(params, B_start, B_len, B_stop, B_stride_start, B_stride_stop, B_stride)

        Hpadding = dims.split[1] // len(params['ifm_idx'][0])                   #M1N32, M4N8, M8N4, M32N1 = [0, 1, 1, 4]
        Hfactor = int(np.ceil(dims.split[1] / len(params['ifm_idx'][0])))       #M1N32, M4N8, M8N4, M32N1 = [1, 1, 1, 4]

        mode = (params['ActDataFlow'], params['ActFormat'])

        if mode == ("stream", "default"):
            H_stride = dims.M_subv * dims.split[1]
            H_start = (Hpadding > 0) * col * dims.M_subv * Hfactor + M_itr * H_stride
            H_stop = H_start + self.generate_padding_val(dims.M_subv * Hfactor, dims.Mpad_ifm * last_m,
                                                         params['ifm_idx'][1], col, Hpadding)

            W_stride = dims.K_subv
            W_stride_start = K_itr * W_stride
            W_stride_stop = W_stride_start + k_width * W_stride
            W_start = K_itr * W_stride * last_k
            W_stop = W_start + min(dims.K_subv - dims.Kpad_ifm * last_k, dims.K)

            subvol_order_str = self.get_subvol_order_str(rev_permA, H_start, H_stop, W_start, W_stop)
            # The strided W term is left out on a padded last K iteration,
            # where W_start/W_stop already address the tile directly.
            k_stride_term = '' if last_k else f'W:{W_stride_start}:{W_stride_stop}:{W_stride} '

            if params['transpose_4d'][0]:
                subvol_terms = subvol_order_str.split(' ')
                return f'{bmm_str1} {k_stride_term}{subvol_terms[0]} {bmm_str0} {subvol_terms[1]}'
            return f'{bmm_str1} {k_stride_term}{bmm_str0} {subvol_order_str}'

        if mode == ("pin", "default"):
            H_stride = dims.M_subv * dims.split[1]
            H_start = (Hpadding > 0) * col * dims.M_subv * Hfactor
            H_stop = H_start + self.generate_padding_val(dims.M_subv * Hfactor, dims.Mpad_ifm * last_m,
                                                         params['ifm_idx'][1], col, Hpadding)
            H_stride_start = M_itr * H_stride
            H_stride_stop = H_stride_start + H_stride * m_width

            W_start = 0
            W_stop = dims.K
            subvol_order_str = self.get_subvol_order_str(rev_permA, H_start, H_stop, W_start, W_stop)
            h_stride_term = f'H:{H_stride_start}:{H_stride_stop}:{H_stride}'

            if params['transpose_4d'][0] and b_split > 1:
                subvol_terms = subvol_order_str.split(' ')
                return f'{bmm_str1} {h_stride_term} {subvol_terms[0]} {bmm_str0} {subvol_terms[1]}'
            return f'{bmm_str1} {h_stride_term} {bmm_str0} {subvol_order_str}'

        if mode == ("stream", "BFP_A8"):
            raise RuntimeError('BFP with Activation stream mode is not supported!')
        if mode == ("pin", "BFP_A8"):
            raise RuntimeError('BFP with Activation pin mode is not supported!')

    def act_memtile_memory(self, params, col, itr_idx) -> str:
        """Return the memtile memory layout string of the activation buffer.

        "stream" + "default": H is this column's valid share of the M
        sub-volume (full size when the share is 0) and W is the K sub-volume,
        reduced by ``Kpad_ifm`` on a padded last K iteration.
        "pin" + "default": H is the smaller of the unpadded M and the
        per-column H stride; W is the whole K.
        Anything else: a fixed layout built from ``M_subv`` and ``K`` minus
        their paddings, with an innermost size of 8.

        ``params['transposeA']`` swaps the roles of H and W in the string.
        """
        #M1N32 col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==4
        #M4N8  col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==4
        #M8N4  col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==8
        #M32N1 col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==8
        _, M_itr, _, K_itr = params['tiling_seq'][itr_idx]
        dims = params['dims']
        pads_m = self.is_last(dims.Tm, M_itr, dims.Mpad_ifm)
        pads_k = self.is_last(dims.Tk, K_itr, dims.Kpad_ifm)
        batch_per_col = self.check_dimension(dims, 0, 0)
        m_groups = len(params['ifm_idx'][0])
        Hpadding = dims.split[1] // m_groups                   #M1N32, M4N8, M8N4, M32N1 = [0, 1, 1, 4]
        Hfactor = int(np.ceil(dims.split[1] / m_groups))       #M1N32, M4N8, M8N4, M32N1 = [1, 1, 1, 4]
        h_full = dims.M_subv * Hfactor

        mode = [params['ActDataFlow'], params['ActFormat']]
        if mode == ["stream", "default"]:
            #M1N32 H is the same across all columns
            #M4N8 and M8N4 pad M across columns with M_subv
            #M32N1 pads H across columns with M_subv * 4
            #K is padded only in the last k iteration
            H = self.generate_padding_val(h_full, dims.Mpad_ifm * pads_m,
                                          params['ifm_idx'][1], col, Hpadding)
            if not H > 0:
                # keep the buffer at full resolution when the channel moves no data
                H = h_full
            W = min(dims.K_subv - dims.Kpad_ifm * pads_k, dims.K)
        elif mode == ["pin", "default"]:
            H = min(dims.M - dims.Mpad_ifm, h_full)
            W = dims.K
        else:
            if params['transposeA']:
                return f'H:{dims.M_subv-dims.Mpad_ifm} W:{dims.K-dims.Kpad_ifm} H:8'
            return f'W:{dims.K-dims.Kpad_ifm} H:{dims.M_subv-dims.Mpad_ifm} W:8'

        N = batch_per_col
        if params['transposeA']:
            return f"N:{N} H:{H} W:{W} H:{params['min_cols']}"
        return f"N:{N} W:{W} H:{H} W:{params['min_cols']}"
        
    def act_memtile_s2mm(self, params, col, itr_idx) -> str:
        """Build the memtile S2MM access string of the activation for one column.

        For the "default" format the string holds an ``N`` (batch) range and
        the H/W sub-volume ranges, all starting at 0.  "stream" mode limits W
        to the K sub-volume (minus ``Kpad_ifm`` on a padded last K
        iteration); "pin" mode covers the whole K, or an empty W range
        (``K:K``) for a column that is not active.

        Args:
            params: scheduling parameters; reads ``dims``, ``tiling_seq``,
                ``ifm_idx``, ``broadcast_idx``, ``active_col``, ``rev_permA``,
                ``ActDataFlow``, ``ActFormat`` and ``transpose_4d``.
            col: column index the string is generated for.
            itr_idx: index into ``params['tiling_seq']``.

        Returns:
            The access string for "stream"/"pin" data flow with the "default"
            format.  Any combination not handled below falls through and
            yields ``None``.

        Raises:
            RuntimeError: if the activation format is "BFP_A8" (unsupported
                for both "stream" and "pin" data flow).
        """
        #M1N32 col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==4
        #M4N8  col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==4
        #M8N4  col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==8
        #M32N1 col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==8
        B_itr, M_itr, _, K_itr = params['tiling_seq'][itr_idx]
        dims = params['dims']
        b_split = dims.split[0]
        # NOTE(review): the iteration count is derived from Bpad_B while the
        # pad amount below uses Bpad_A; act_shim_mm2s uses Bpad_A for both.
        # Left unchanged here — confirm whether Bpad_B is intended.
        batch_ceil = math.ceil(dims.Bpad_B / b_split)
        batch_pad = batch_ceil * b_split - dims.Bpad_A
        last_b = self.is_last(batch_ceil, B_itr, batch_pad)
        last_m = self.is_last(dims.Tm, M_itr, dims.Mpad_ifm)
        last_k = self.is_last(dims.Tk, K_itr, dims.Kpad_ifm)
        col_aie_b = self.check_dimension(dims, 0, 0)
        Bpadding = b_split != 4                                                #B32, B8, B4, B1 = [True, True, False, True]
        Hpadding = dims.split[1] // len(params['ifm_idx'][0])                  #M1N32, M4N8, M8N4, M32N1 = [0, 1, 1, 4]
        Hfactor = int(np.ceil(dims.split[1] / len(params['ifm_idx'][0])))      #M1N32, M4N8, M8N4, M32N1 = [1, 1, 1, 4]
        rev_permA = params['rev_permA']
        # When the activation uses the broadcast index set, every column
        # counts as active.
        active_col = True if params['ifm_idx'] is params['broadcast_idx'] else col in params['active_col']

        data_flow = params['ActDataFlow']
        act_format = params['ActFormat']

        if act_format == "default" and data_flow in ("stream", "pin"):
            N_start = 0
            N_stop = N_start + self.generate_padding_val(col_aie_b, last_b * batch_pad, params['ifm_idx'][1], col, Bpadding)

            H_start = 0
            H_stop = H_start + self.generate_padding_val(dims.M_subv * Hfactor, dims.Mpad_ifm * last_m,
                                                         params['ifm_idx'][1], col, Hpadding)
            if data_flow == "stream":
                W_start = 0
                W_stop = min(dims.K_subv - dims.Kpad_ifm * last_k, dims.K)
            else:
                # An inactive column gets the empty range K:K.
                W_start = 0 if active_col else dims.K
                W_stop = dims.K
            subvol_order_str = self.get_subvol_order_str(rev_permA, H_start, H_stop, W_start, W_stop)

            if params['transpose_4d'][0] and b_split > 1:
                subvol_terms = subvol_order_str.split(' ')
                return f"{subvol_terms[0]} N:{N_start}:{N_stop} {subvol_terms[1]}"
            return f'N:{N_start}:{N_stop} {subvol_order_str}'

        if act_format == "BFP_A8":
            if data_flow == "stream":
                raise RuntimeError('BFP with Activation stream mode is not supported!')
            if data_flow == "pin":
                raise RuntimeError('BFP with Activation pin mode is not supported!')

    def act_memtile_mm2s(self, params, col, row, itr_idx) -> str:
        """Build the activation memtile MM2S dimension-range string.

        The "stream" and "pin" dataflows share the same layout and differ only
        in two places: streaming never reports a zero-height column slice, and
        the strided K range covers one K sub-volume ("stream") versus the whole
        padded K ("pin").

        Args:
            params: Parameter dict. Reads 'ActDataFlow', 'tiling_seq', 'dims',
                'ifm_idx', 'min_cols' and 'transposeA'.
            col: Column index, forwarded to generate_padding_val.
            row: Row index; selects the batch slice and the M sub-volume.
            itr_idx: Index into params['tiling_seq'].

        Returns:
            The descriptor string, or None when params['ActDataFlow'] is
            neither "stream" nor "pin".
        """
        flow = params['ActDataFlow']
        if flow not in ("stream", "pin"):
            return None

        _, M_itr, _, _ = params['tiling_seq'][itr_idx]
        dims = params['dims']
        last_m = self.is_last(dims.Tm, M_itr, dims.Mpad_ifm)
        col_aie_b = self.check_dimension(dims, 0, 0)
        # M1N32, M4N8, M8N4, M32N1 = [1, 1, 1, 4]
        # (M32N1: the memtile handles 4 M sub-volumes per iteration)
        Hfactor = int(np.ceil(dims.split[1] / len(params['ifm_idx'][0])))

        # Batch slice: only offset by row when check_dimension reports > 1.
        N_start = (col_aie_b > 1) * (dims.B_subv * row)
        N_stop = N_start + dims.B_subv

        # Height handled by this column, after padding is taken out.
        H_stride = dims.M_subv * Hfactor
        H_col = self.generate_padding_val(H_stride, dims.Mpad_ifm * last_m,
                                          params['ifm_idx'][1], col, Hfactor > 1)
        if flow == "stream":
            # Streaming always transfers a data size larger than 0;
            # pin mode may legitimately transfer 0 for a column.
            H_col = H_col if H_col > 0 else H_stride
        row_offset = row * dims.M_subv
        H_start = row_offset if row_offset < H_col else 0
        H_stop = H_start + dims.M_subv

        chunk = params['min_cols']
        K_extent = dims.K_subv if flow == "stream" else dims.K + dims.Kpad_ifm

        if params['transposeA']:
            return f'N:{N_start}:{N_stop} H:{H_start}:{H_stop}:{chunk} W:0:{K_extent} H:0:{chunk}'
        return f'N:{N_start}:{N_stop} W:0:{K_extent}:{chunk} H:{H_start}:{H_stop} W:0:{chunk}'
    
    def wgt_shim_memory(self, params, itr_idx) -> str:
        """Return the shim-side memory shape string for the weight / B operand.

        Args:
            params: Parameter dict. Reads 'dims' and 'actxact'; for actxact
                also 'transpose_4d' and, in the non-4D case, 'rev_permB'.
            itr_idx: Unused; kept so existing callers keep working.

        Returns:
            For actxact, the space-joined "<dim>:<size>" entries reordered by
            the reverse permutation of B. Otherwise the fixed
            "N Cols Rows Bytes" weight sub-volume description.
        """
        dims = params['dims']
        if not params['actxact']:
            return f'N:{dims.B} Cols:{dims.wgt_subv_cols} Rows:{dims.wgt_subv_rows} Bytes:{dims.wgt_subv_bytes}'

        if params['transpose_4d'][1]:
            perm, C, N = self.extract_CN_from_perm(params, 'B')
            mem_default_order = np.array([f'C:{C}', f'N:{N}', f'H:{dims.K_ifmB}', f'W:{dims.N - dims.Npad_wgt}'])
            # Invert the permutation: rev_permB[perm[i]] = i.
            rev_permB = np.empty(len(perm), dtype=int)
            rev_permB[list(perm)] = np.arange(len(perm))
        else:
            mem_default_order = np.array([f'N:{dims.Bpad_B}', f'H:{dims.K_ifmB}', f'W:{dims.N - dims.Npad_wgt}'])
            rev_permB = params['rev_permB']
        return ' '.join(mem_default_order[rev_permB].tolist())
    
    def wgt_shim_mm2s_batch(self, params, B_start: int, B_len: int, B_stop: int, B_stride_start: int, B_stride_stop: int, B_stride: int) -> tuple[str, str]:
        """Build the two batch-dimension fragments for the weight shim MM2S string.

        Args:
            params: Parameter dict. Reads 'dims' and 'transpose_4d'[1].
            B_start, B_len, B_stop: Start, length and stop of the per-column
                batch slice.
            B_stride_start, B_stride_stop, B_stride: Strided batch range for
                the current iteration.

        Returns:
            (bmm_str1, bmm_str0): the strided (outer) fragment and the
            per-column (inner) fragment.
        """
        dims = params['dims']
        if not params['transpose_4d'][1]:
            # The C/N extents are only needed for the 4D-transposed layout, so
            # the perm lookup is skipped on this path.
            return f'N:{B_stride_start}:{B_stride_stop}:{B_stride}', f"N:{B_start}:{B_stop}"

        _, C, N = self.extract_CN_from_perm(params, 'B')
        b_split = dims.split[0]
        if b_split == 32:
            if b_split // 4 >= N:
                bmm_str1 = f'C:{B_stride_start//N}:{B_stride_stop//N}:{B_stride//N}'
                # A slice that crosses an N boundary spans two C entries.
                B_offset = 2 if B_start % N + B_len > N else 1
                bmm_str0 = f'C:{B_start//N}:{B_start//N+B_offset} N:{B_start%N}:{min(B_start%N+B_len,N)}'
            else:
                # NOTE(review): the C start is not divided by N here, unlike
                # the matching branch of out_shim_mm2s_batch - confirm intent.
                bmm_str1 = f'C:{B_stride_start}:{B_stride_stop//N}:{B_stride//N} N:{0}:{N}:{b_split}'
                bmm_str0 = f'N:{B_start%b_split}:{((B_stop-1)%b_split)+1}'
        elif b_split == 4:
            if b_split > N:
                # NOTE(review): same un-divided C start as above - confirm.
                bmm_str1 = f'C:{B_stride_start}:{B_stride_stop//N}:{B_stride//N}'
                bmm_str0 = f'C:{B_start//N}:{B_stop//N} N:0:{N}'
            else:
                bmm_str1 = f'C:{B_stride_start//N}:{B_stride_stop//N}:{max(B_stride//N, 1)} N:{0}:{N}:{b_split}'
                bmm_str0 = f'N:{B_start%b_split}:{B_start%N+B_len}'
        else:
            bmm_str1 = f'C:{B_stride_start}:{B_stride_stop//N}:{B_stride} N:0:{N}:1'
            bmm_str0 = f"N:{B_start}:{B_stop}"
        return bmm_str1, bmm_str0

    def wgt_shim_mm2s(self, params, col: int, itr_idx, K_repeat_itr) -> str:
        """Build the weight shim MM2S dimension-range string for one column.

        Args:
            params: Parameter dict. Reads 'dims', 'tiling_seq',
                'tiling_seq_dim', 'wgt_idx', 'actxact', 'WgtDataFlow' and,
                depending on the path, 'WgtFormat', 'rev_permB',
                'transpose_4d' and the optional 'shim_wgt_reengueue'.
            col: Column index.
            itr_idx: Index into params['tiling_seq'] / ['tiling_seq_dim'].
            K_repeat_itr: Offset added to the K iteration index. It also
                shifts the column window in the 'shim_wgt_reengueue' path.

        Returns:
            The descriptor string, or None for an actxact
            (WgtDataFlow, WgtFormat) combination other than
            ("full" | "stream", "default").
        """
        #M1N32 col_aie_m=1, col_aie_n=4 len(params['wgt_idx'][1])==8
        #M4N8  col_aie_m=4, col_aie_n=1 len(params['wgt_idx'][1])==8
        #M8N4  col_aie_m=1, col_aie_n=4 len(params['wgt_idx'][1])==4
        #M32N1 col_aie_m=4, col_aie_n=1 len(params['wgt_idx'][1])==4
        dims = params['dims']
        b_split = dims.split[0]
        n_split = dims.split[3]
        batch_ceil = math.ceil(dims.Bpad_B / b_split)
        batch_pad = batch_ceil * b_split - dims.Bpad_B
        B_itr, _, N_itr, K_itr = params['tiling_seq'][itr_idx]
        b_width, _, n_width, k_width = params['tiling_seq_dim'][itr_idx]
        last_b = self.is_last(batch_ceil, B_itr, batch_pad)
        # last_k is evaluated on the K index *before* K_repeat_itr is added.
        last_k = self.is_last(dims.Tk, K_itr, dims.Kpad_wgt)
        last_n = self.is_last(dims.Tn, N_itr, dims.Npad_wgt)
        Wpadding = n_split // len(params['wgt_idx'][0])                 #M1N32, M4N8, M8N4, M32N1 = [4, 1, 1, 0]
        Wfactor = int(np.ceil(n_split / len(params['wgt_idx'][0])))     #M1N32, M4N8, M8N4, M32N1 = [4, 1, 1, 1]
        Bfactor = int(np.ceil(b_split / len(params['wgt_idx'][1])))     #B32, B8, B4, B1 = [4, ?, 1, 1]
        K_itr += K_repeat_itr

        # Batch fragments shared by every path below.
        Bpadding = b_split != 1                                         #B32, B8, B4, B1 = [True, True, True, False]
        B_start = col * Bfactor * Bpadding
        B_len = self.generate_padding_val(Bfactor, last_b * batch_pad, params['wgt_idx'][1], col, Bpadding)
        B_stop = B_start + B_len
        B_stride_start = min(B_itr * b_split, dims.Bpad_B)
        B_stride_stop = B_stride_start + b_width * b_split
        bmm_str1, bmm_str0 = self.wgt_shim_mm2s_batch(params, B_start, B_len, B_stop,
                                                      B_stride_start, B_stride_stop, b_split)

        if params['actxact']:
            rev_permB = params['rev_permB']
            mode = [params['WgtDataFlow'], params['WgtFormat']]
            if mode not in (["full", "default"], ["stream", "default"]):
                return None
            is_full = mode[0] == "full"
            transposed_4d = params['transpose_4d'][1]

            W_stride = dims.N_subv * n_split
            W_stride_start = N_itr * W_stride
            W_stride_stop = W_stride_start + n_width * W_stride
            w_span = f'W:{W_stride_start}:{W_stride_stop}:{W_stride}'

            if is_full:
                H_stop = dims.K_ifmB
                W_start = min(col * Wpadding * dims.N_subv, dims.N - dims.Npad_wgt)
            else:
                H_stop = min(dims.K_subv - dims.Kpad_wgt * last_k, dims.K_ifmB)
                W_start = col * dims.N_subv * Wpadding
            W_stop = W_start + self.generate_padding_val(dims.N_subv * Wfactor, dims.Npad_wgt * last_n,
                                                         params['wgt_idx'][1], col, Wpadding)
            subvol_order_str = self.get_subvol_order_str(rev_permB, 0, H_stop, W_start, W_stop)

            if is_full:
                if transposed_4d:
                    sub = subvol_order_str.split(' ')
                    return f'{bmm_str1} {w_span} {sub[0]} {bmm_str0} {sub[1]}'
                return f'{bmm_str1} {bmm_str0} {w_span} {subvol_order_str}'

            # stream
            H_stride = dims.K_subv
            H_stride_start = K_itr * H_stride
            H_stride_stop = H_stride_start + k_width * H_stride
            h_span = f'H:{H_stride_start}:{H_stride_stop}:{H_stride}'

            if b_split == 4:
                if transposed_4d:
                    sub = subvol_order_str.split(' ')
                    return f'{w_span} {h_span} {bmm_str1} {sub[0]} {bmm_str0} {sub[1]}'
                return f'{w_span} {h_span} {bmm_str1} {bmm_str0} {subvol_order_str}'

            if transposed_4d:
                # The original built a Warning(...) object and discarded it, so
                # nothing was ever reported; emit it through logging instead.
                logging.warning('Transpose 4D with Wgt stream and Bfactor!=4 may not work as expected!')
                sub = subvol_order_str.split(' ')
                # A separating space was missing after the leading bmm_str1,
                # which fused it with the following 'W:' field.
                # NOTE(review): bmm_str1 still appears twice in this pattern,
                # as in the original; confirm which occurrence is intended.
                return f'{bmm_str1} {w_span} {h_span} {bmm_str1} {sub[0]} {bmm_str0} {sub[1]}'
            return f'{bmm_str1} {bmm_str0} {w_span} {h_span} {subvol_order_str}'

        # actxwgt
        is_stream = params['WgtDataFlow'] == 'stream'
        rows_span = f'Rows:0:{dims.wgt_subv_rows}'
        bytes_span = f'Bytes:0:{dims.wgt_subv_bytes}'

        if is_stream and params.get('shim_wgt_reengueue', False):
            col_start = col * Wfactor + K_repeat_itr * n_split
            return f'{bmm_str1} {bmm_str0} {rows_span} Cols:{col_start}:{col_start + Wfactor} {bytes_span}'

        col_start = col * Wfactor * (Wpadding > 0)
        cols_span = f'Cols:{col_start}:{col_start + Wfactor}'
        col_stride_start = n_split * N_itr
        # Stream covers n_width strides; full mode runs to the last column.
        col_stride_stop = col_stride_start + n_split * n_width if is_stream else dims.wgt_subv_cols
        stride_span = f'Cols:{col_stride_start}:{col_stride_stop}:{n_split}'

        if is_stream and Wfactor > 1:
            # Original author's open question: why does only M1N32 stream
            # need Rows and Cols swapped?
            return f'{bmm_str1} {bmm_str0} {stride_span} {rows_span} {cols_span} {bytes_span}'
        return f'{bmm_str1} {bmm_str0} {stride_span} {cols_span} {rows_span} {bytes_span}'


    def wgt_memtile_memory(self, params, itr_idx) -> str:
        """Return the memtile memory shape string for the weight / B operand.

        Args:
            params: Parameter dict. Reads 'dims', 'wgt_idx', 'actxact' and
                'WgtDataFlow'; for actxact also 'transposeB' and 'min_cols',
                plus 'tiling_seq' when streaming.
            itr_idx: Index into params['tiling_seq']; only consulted for the
                actxact "stream" path.

        Returns:
            The shape string, or None when params['WgtDataFlow'] is neither
            "stream" nor "full".
        """
        #M1N32 col_aie_m=1, col_aie_n=4 len(params['wgt_idx'][1])==8
        #M4N8  col_aie_m=4, col_aie_n=1 len(params['wgt_idx'][1])==8
        #M8N4  col_aie_m=1, col_aie_n=4 len(params['wgt_idx'][1])==4
        #M32N1 col_aie_m=4, col_aie_n=1 len(params['wgt_idx'][1])==4
        dims = params['dims']
        flow = params['WgtDataFlow']
        Wfactor = int(np.ceil(dims.split[3] / len(params['wgt_idx'][0])))   #M1N32, M4N8, M8N4, M32N1 = [4, 1, 1, 1]

        if not params['actxact']:
            if flow == "stream":
                return f'Cols:{Wfactor} Bytes:{dims.wgt_subv_bytes}'
            if flow == "full":
                num_cols = dims.wgt_subv_cols // (dims.split[3] // Wfactor)
                return f'Cols:{num_cols} Rows:{dims.wgt_subv_rows} Bytes:{dims.wgt_subv_bytes}'
            return None

        if flow == "stream":
            _, _, _, K_itr = params['tiling_seq'][itr_idx]
            last_k = self.is_last(dims.Tk, K_itr, dims.Kpad_wgt)
            # NOTE(review): this caps against dims.K while the other weight
            # methods cap against dims.K_ifmB - confirm which is intended.
            H = min(dims.K_subv - dims.Kpad_wgt * last_k, dims.K)
            W = dims.N_subv * Wfactor
        elif flow == "full":
            H = dims.K_ifmB
            W = dims.N // (dims.split[3] // Wfactor)
        else:
            return None

        N = max(dims.split[0] // dims.aie_cols, 1)
        if params['transposeB']:
            return f"N:{N} W:{H} H:{W} W:{params['min_cols']}"
        return f"N:{N} W:{W} H:{H} W:{params['min_cols']}"

    def wgt_memtile_s2mm(self, params, col, itr_idx) -> str:
        """Build the weight memtile S2MM dimension-range string for one column.

        Args:
            params: Parameter dict. Reads 'dims', 'wgt_idx', 'actxact' and
                'WgtDataFlow'; for actxact also 'WgtFormat', 'tiling_seq',
                'tiling_seq_dim', 'transposeB' and 'transpose_4d'.
            col: Column index, forwarded to generate_padding_val.
            itr_idx: Index into params['tiling_seq'] / ['tiling_seq_dim'].

        Returns:
            The descriptor string, or None for an actxact
            (WgtDataFlow, WgtFormat) combination other than
            ("stream" | "full", "default").
        """
        #M1N32 col_aie_m=1, col_aie_n=4 len(params['wgt_idx'][1])==8
        #M4N8  col_aie_m=4, col_aie_n=1 len(params['wgt_idx'][1])==8
        #M8N4  col_aie_m=1, col_aie_n=4 len(params['wgt_idx'][1])==4
        #M32N1 col_aie_m=4, col_aie_n=1 len(params['wgt_idx'][1])==4
        dims = params['dims']
        Wfactor = int(np.ceil(dims.split[3] / len(params['wgt_idx'][0])))   #M1N32, M4N8, M8N4, M32N1 = [4, 1, 1, 1]

        if not params['actxact']:
            bytes_span = f'Bytes:0:{dims.wgt_subv_bytes}'
            if params['WgtDataFlow'] == "stream":
                return f'Cols:0:{Wfactor} {bytes_span}'
            # full mode
            cols_stop = dims.wgt_subv_cols // (dims.split[3] // Wfactor)
            return f'Cols:0:{cols_stop} Rows:0:{dims.wgt_subv_rows} {bytes_span}'

        mode = [params['WgtDataFlow'], params['WgtFormat']]
        if mode not in (["stream", "default"], ["full", "default"]):
            return None
        is_stream = mode[0] == "stream"

        b_split = dims.split[0]
        batch_ceil = math.ceil(dims.Bpad_B / b_split)
        batch_pad = batch_ceil * b_split - dims.Bpad_B
        B_itr, M_itr, N_itr, K_itr = params['tiling_seq'][itr_idx]
        _, _, n_width, _ = params['tiling_seq_dim'][itr_idx]
        last_b = self.is_last(batch_ceil, B_itr, batch_pad)
        last_n = self.is_last(dims.Tn, N_itr, dims.Npad_wgt)
        col_aie_b = self.check_dimension(dims, 0, 0)
        Wpadding = dims.split[3] // len(params['wgt_idx'][0])              #M1N32, M4N8, M8N4, M32N1 = [4, 1, 1, 0]

        local_stop = col_aie_b if b_split == 32 else 1
        N_stop = self.generate_padding_val(local_stop, last_b * batch_pad, params['wgt_idx'][1], col)
        W_stop = self.generate_padding_val(dims.N_subv * Wfactor, dims.Npad_wgt * last_n,
                                           params['wgt_idx'][1], col, Wpadding)

        if is_stream:
            last_k = self.is_last(dims.Tk, K_itr, dims.Kpad_wgt)
            H_stop = min(dims.K_ifmB, dims.K_subv - dims.Kpad_wgt) if last_k else dims.K_subv
            stride_range = None
        else:
            H_stop = dims.K_ifmB
            W_stride = dims.N_subv * Wfactor
            W_stride_start = W_stride * N_itr * last_n
            W_stride_stop = W_stride_start + n_width * W_stride
            stride_range = f'{W_stride_start}:{W_stride_stop}:{W_stride}'
            # In full mode only the M_itr==0 and K_itr==0 iteration transfers
            # data. The original's second gate (Wpadding!=0 or K_itr==0) could
            # never change the value after this one, so it is folded away.
            if not (M_itr == 0 and K_itr == 0):
                W_stop = 0

        # transposeB swaps which extent is emitted under the H and W labels.
        n_span = f'N:0:{N_stop}'
        if params['transposeB']:
            h_span, w_span, stride_axis = f'H:0:{W_stop}', f'W:0:{H_stop}', 'H'
        else:
            h_span, w_span, stride_axis = f'H:0:{H_stop}', f'W:0:{W_stop}', 'W'
        stride_part = [] if stride_range is None else [f'{stride_axis}:{stride_range}']

        if params['transpose_4d'][1]:
            # With transposeB unset this is not expected in full mode; the
            # original kept that combination for debugging.
            pieces = stride_part + [h_span, n_span, w_span]
        else:
            pieces = [n_span] + stride_part + [h_span, w_span]
        return ' '.join(pieces)

    def padding_solver(self, H_start, H_stop, pad_val) -> str:
        """Split an H range into equal chunks that still cover the unpadded part.

        Temporary solution for uneven K-padding handling. Candidate chunk
        sizes are H_stop, H_stop-4, H_stop-8, ... (stopping before H_start).
        A candidate qualifies when it is positive, divides H_stop evenly and
        is at least H_stop - pad_val. The smallest qualifying candidate is
        used.

        Args:
            H_start: Range start (also the exclusive lower end of the search).
            H_stop: Range stop.
            pad_val: Amount of padding contained in the range.

        Returns:
            "H:start:stop" when no smaller chunk qualifies, otherwise
            "H:start:stop:chunk H:start:chunk".

        Raises:
            ValueError: If no candidate qualifies (e.g. H_stop <= H_start).
        """
        smallest_allowed = H_stop - pad_val
        candidates = [
            chunk for chunk in range(H_stop, H_start, -4)
            if chunk > 0 and H_stop % chunk == 0 and chunk >= smallest_allowed
        ]
        H_factor = min(candidates)

        if H_factor == H_stop:
            return f"H:{H_start}:{H_stop}"
        return f"H:{H_start}:{H_stop}:{H_factor} H:{H_start}:{H_factor}"

    def wgt_memtile_mm2s(self, params, row, col, itr_idx) -> str:
        """Build the weight memtile MM2S dimension-range string for one row.

        Note the parameter order is (row, col), unlike act_memtile_mm2s.

        Args:
            params: Parameter dict. Reads 'dims', 'wgt_idx', 'actxact' and
                'WgtDataFlow'; for actxact also 'WgtFormat', 'min_cols' and
                'transposeB', plus 'tiling_seq' / 'tiling_seq_dim' in full
                mode.
            row: Row index; selects the batch entry and the N sub-volume.
            col: Unused; kept so existing callers keep working.
            itr_idx: Index into params['tiling_seq'] / ['tiling_seq_dim'].

        Returns:
            The descriptor string, or None for an actxact
            (WgtDataFlow, WgtFormat) combination other than
            ("stream" | "full", "default").
        """
        #M1N32 col_aie_m=1, col_aie_n=4 len(params['wgt_idx'][1])==8
        #M4N8  col_aie_m=4, col_aie_n=1 len(params['wgt_idx'][1])==8
        #M8N4  col_aie_m=1, col_aie_n=4 len(params['wgt_idx'][1])==4
        #M32N1 col_aie_m=4, col_aie_n=1 len(params['wgt_idx'][1])==4
        dims = params['dims']
        Wfactor = int(np.ceil(dims.split[3] / len(params['wgt_idx'][0])))   #M1N32, M4N8, M8N4, M32N1 = [4, 1, 1, 1]
        # The row only selects a sub-volume when a memtile holds several.
        row_sel = row * (Wfactor > 1)

        if not params['actxact']:
            bytes_span = f'Bytes:0:{dims.wgt_subv_bytes}'
            if params['WgtDataFlow'] == "stream":
                return f'Cols:{row_sel}:{row_sel + 1} {bytes_span}'
            cols_stop = dims.wgt_subv_cols // (dims.split[3] // Wfactor)
            return f'Cols:{row_sel}:{cols_stop}:{Wfactor} Rows:0:{dims.wgt_subv_rows} {bytes_span}'

        mode = [params['WgtDataFlow'], params['WgtFormat']]
        if mode not in (["stream", "default"], ["full", "default"]):
            return None

        col_aie_b = self.check_dimension(dims, 0, 0)
        N_start = row * (col_aie_b > 1)
        n_span = f'N:{N_start}:{N_start + 1}'
        min_cols = params['min_cols']
        W_start = row_sel * dims.N_subv
        W_stop = W_start + dims.N_subv

        if mode[0] == "stream":
            if params['transposeB']:
                return f"{n_span} W:0:{dims.K_subv}:{min_cols} H:{W_start}:{W_stop} W:0:{min_cols}"
            # The padded K range is only split into chunks on the
            # non-transposed path.
            H_str = self.padding_solver(0, dims.K_subv, dims.Kpad_wgt)
            return f"{n_span} W:{W_start}:{W_stop}:{min_cols} {H_str} W:0:{min_cols}"

        # full
        _, _, N_itr, K_itr = params['tiling_seq'][itr_idx]
        _, _, n_width, k_width = params['tiling_seq_dim'][itr_idx]
        last_k = self.is_last(dims.Tk, K_itr, dims.Kpad_wgt)
        last_n = self.is_last(dims.Tn, N_itr, dims.Npad_wgt)

        H_stride = dims.K_subv
        H_stride_start = H_stride * K_itr * last_k
        H_stride_stop = H_stride_start + k_width * H_stride
        K_padded = dims.K_ifmB + dims.Kpad_wgt
        # On the last K iteration the window is the final K_subv of padded K.
        H_start = K_padded - H_stride if last_k else 0
        H_stop = K_padded if last_k else dims.K_subv

        W_stride = dims.N_subv * Wfactor
        W_stride_start = W_stride * N_itr * last_n
        W_stride_stop = W_stride_start + n_width * W_stride

        k_stride_span = f'{H_stride_start}:{H_stride_stop}:{H_stride}'
        n_stride_span = f'{W_stride_start}:{W_stride_stop}:{W_stride}'

        if params['transposeB']:
            # The second strided field is dropped on the last K iteration.
            middle = '' if last_k else f"H:{n_stride_span} "
            return (f"W:{k_stride_span} {n_span} {middle}"
                    f"W:{H_start}:{H_stop}:{min_cols} H:{W_start}:{W_stop} W:0:{min_cols}")
        middle = '' if last_k else f"H:{k_stride_span} "
        return (f"W:{n_stride_span} {n_span} {middle}"
                f"W:{W_start}:{W_stop}:{min_cols} H:{H_start}:{H_stop} W:0:{min_cols}")

    def out_shim_memory(self, params) -> str:
        """Return the shim-side memory shape string for the output (Y) tensor.

        Args:
            params: Parameter dict. Reads 'dims' and 'OfmFormat'; for the
                "default" format also 'transpose_4d' and, in the non-4D case,
                'permY'; for "BFP_A8" also 'min_cols'.

        Returns:
            The shape string, or None for an unrecognised 'OfmFormat'.
        """
        dims = params['dims']
        ofm_format = params['OfmFormat']

        # NOTE(review): the two A8 formats add Munpad while the default format
        # subtracts it - confirm the sign is intended.
        if ofm_format == "BFP_A8":
            return f"W:{dims.N} H:{dims.M+dims.Munpad} W:{params['min_cols']}"
        if ofm_format == "BF16_A8":  # same as default
            return f'H:{dims.M+dims.Munpad} W:{dims.N}'
        if ofm_format != "default":
            return None

        height_entry = f'H:{dims.M-dims.Munpad}'
        width_entry = f'W:{dims.N - dims.Nunpad}'
        if params['transpose_4d'][2]:
            perm, C, N = self.extract_CN_from_perm(params, 'Y')
            entries = np.array([f'C:{C}', f'N:{N}', height_entry, width_entry])
            # Invert the permutation: order[perm[i]] = i.
            order = np.empty(len(perm), dtype=int)
            order[list(perm)] = np.arange(len(perm))
        else:
            entries = np.array([f"N:{dims.Bpad_Y}", height_entry, width_entry])
            # NOTE(review): this indexes with the forward 'permY', whereas
            # wgt_shim_memory uses 'rev_permB' - verify this is deliberate.
            order = params['permY']
        return ' '.join(entries[order].tolist())

    def out_shim_mm2s_batch(self, params, B_start: int, B_len: int, B_stop: int, B_stride_start: int, B_stride_stop: int, B_stride: int) -> tuple[str, str]:
        """Build the two batch-dimension fragments for the output shim string.

        Args:
            params: Parameter dict. Reads 'dims' and 'transpose_4d'[2].
            B_start, B_len, B_stop: Start, length and stop of the per-column
                batch slice.
            B_stride_start, B_stride_stop, B_stride: Strided batch range for
                the current iteration.

        Returns:
            (bmm_str1, bmm_str0): the strided (outer) fragment and the
            per-column (inner) fragment.
        """
        if not params['transpose_4d'][2]:
            return (f'N:{B_stride_start}:{B_stride_stop}:{B_stride}',
                    f'N:{B_start}:{B_stop}')

        batch_split = params['dims'].split[0]
        # NOTE(review): the C/N extents come from operand 'A' although this
        # is the output path (out_shim_memory uses 'Y') - confirm.
        _, C, N = self.extract_CN_from_perm(params, 'A')

        if batch_split == 32:
            if batch_split // 4 >= N:
                # A slice that crosses an N boundary spans two C entries.
                c_span = 2 if B_start % N + B_len > N else 1
                outer = f'C:{B_stride_start//N}:{B_stride_stop//N}:{B_stride//N}'
                inner = f'C:{B_start//N}:{B_start//N+c_span} N:{B_start%N}:{min(B_start%N+B_len,N)}'   #B_start 0,4,8,12,16
            else:
                outer = f'C:{B_stride_start//N}:{B_stride_stop//N}:{32//8} N:{0}:{N}:{N}'
                inner = f'N:{B_start%4}:{B_stop%4}'
            return outer, inner

        if batch_split == 4:
            if batch_split > N:
                outer = f'C:{B_stride_start//N}:{B_stride_stop//N}:{B_stride//N}'
                inner = f'C:{B_start//N}:{B_stop//N} N:0:{N}'
            else:
                outer = f'C:{B_stride_start//N}:{B_stride_stop//N}:{max(B_stride//N, 1)} N:{0}:{N}:{batch_split}'
                inner = f'N:{B_start%batch_split}:{B_start%N+B_len}'
            return outer, inner

        # NOTE(review): the outer fragment starts with 'N:' here, while the
        # matching branch of wgt_shim_mm2s_batch emits 'C:' - confirm.
        outer = f'N:{B_stride_start}:{B_stride_stop//N}:{B_stride} N:0:{N}:1'
        inner = f'N:{B_start}:{B_stop}'
        return outer, inner
    
    def out_shim_s2mm(self, params, col, itr_idx) -> str:
        """Build the shim S2MM access-pattern string for the OFM of one column.

        The result is a space-separated list of ``AXIS:start:stop[:stride]``
        fields whose layout depends on ``params['OfmFormat']``.

        Args:
            params: Template parameter dict. Reads 'dims', 'tiling_seq',
                'tiling_seq_dim', 'ofm_idx', 'broadcast_idx', 'active_col',
                'wgt_idx', 'OfmFormat', and (format dependent) 'ofm_channel',
                'transpose_4d', 'min_cols'.
            col: Column index the pattern is generated for.
            itr_idx: Index into ``params['tiling_seq']`` and
                ``params['tiling_seq_dim']`` selecting the tiling iteration.

        Returns:
            The access-pattern string, or ``None`` when ``OfmFormat`` is not
            one of "default", "BFP_A8" or "BF16_A8".
        """
        #M1N32 col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==4
        #M4N8  col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==4
        #M8N4  col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==8
        #M32N1 col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==8
        dims                        = params['dims']
        batch_ceil                  = math.ceil(dims.Bpad_Y / dims.split[0])
        B_itr, M_itr, N_itr, _K_itr          = params['tiling_seq'][itr_idx]
        b_width, m_width, n_width, _k_width  = params['tiling_seq_dim'][itr_idx]
        # Number of batch slots left over once Bpad_Y is rounded up to a
        # multiple of the batch split.
        batch_tail                  = batch_ceil * dims.split[0] - dims.Bpad_Y
        last_b                      = self.is_last(batch_ceil, B_itr, batch_tail)
        last_m                      = self.is_last(dims.Tm, M_itr, dims.Munpad)
        last_n                      = self.is_last(dims.Tn, N_itr, dims.Nunpad)
        col_aie_b                   = self.check_dimension(dims, 0, 0)
        col_aie_m                   = self.check_dimension(dims, 0, 1)
        col_aie_n                   = self.check_dimension(dims, 0, 2)
        active_col                  = True if params['ofm_idx'] is params['broadcast_idx'] else col in params['active_col']
        Bpadding                    = dims.split[0] not in [1, 4]                                  #B32, B8, B4, B1 = [True, True, False, False]
        Bfactor                     = int(np.ceil(dims.split[0]/len(params['wgt_idx'][1])))        #B32, B8, B4, B1 = [4, ?, 1, 1]

        B_stride_start              = min(B_itr * dims.split[0], dims.Bpad_Y)
        B_stride_stop               = B_stride_start + b_width * dims.split[0]
        B_stride                    = dims.split[0]

        # Multiplying by the bool zeroes the per-column batch offset when
        # batch padding is disabled.
        B_start                     = col * Bfactor * Bpadding
        B_len                       = self.generate_padding_val(col_aie_b, last_b * batch_tail, params['ofm_idx'][1], col, Bpadding)
        B_stop                      = B_start + B_len

        bmm_str1, bmm_str0          = self.out_shim_mm2s_batch(params, B_start, B_len, B_stop, B_stride_start, B_stride_stop, B_stride)

        if params['OfmFormat'] == "default":
            is_unicast              = params['ofm_channel'] == 'unicast'
            is_broadcast            = params['ofm_channel'] == 'broadcast'

            H_stride                = dims.M_subv * dims.split[1]
            H_stride_start          = M_itr * H_stride
            H_stride_stop           = H_stride_start + m_width * H_stride
            # The per-column offset only applies on the unicast channel.
            H_start                 = dims.M_subv * col * col_aie_m * is_unicast
            H_stop                  = H_start + self.generate_padding_val(dims.M_subv*col_aie_m, dims.Munpad*last_m,
                                            params['ofm_idx'][1], col, is_unicast)

            W_stride                = dims.N_subv * dims.split[3]
            W_stride_start          = N_itr * W_stride
            W_stride_stop           = W_stride_start + n_width * W_stride
            # The per-column offset only applies on the broadcast channel.
            W_start                 = dims.N_subv * col * col_aie_n * is_broadcast
            W_stop                  = W_start + self.generate_padding_val(dims.N_subv*col_aie_n, dims.Nunpad*last_n,
                                            params['ofm_idx'][1], col, is_broadcast)

            outer_h                 = f'H:{H_stride_start}:{H_stride_stop}:{H_stride}'
            outer_w                 = f'W:{W_stride_start}:{W_stride_stop}:{W_stride}'
            inner_h                 = f'H:{H_start}:{H_stop}'
            inner_w                 = f'W:{W_start}:{W_stop}'

            batch_split             = dims.split[0] > 1
            if params['transpose_4d'][2]:
                if not batch_split:
                    # Previously a bare ``Warning(...)`` expression, which built
                    # the object and discarded it without reporting anything.
                    logging.getLogger(__name__).warning(
                        'Transpose 4D with Ofm stream and bsplit==1 may not work as expected!')
                # Transposed output: the inner batch field sits between the
                # inner H and inner W fields.
                fields              = [bmm_str1, outer_h, outer_w, inner_h, bmm_str0, inner_w]
            elif batch_split:
                fields              = [bmm_str1, outer_h, outer_w, bmm_str0, inner_h, inner_w]
            else:
                fields              = [bmm_str1, bmm_str0, outer_h, outer_w, inner_h, inner_w]
            return ' '.join(fields)
        if params['OfmFormat'] == "BFP_A8":
            H_stride                = dims.M_subv* dims.aie_rows
            W_stride                = dims.N_subv
            W_start                 = col * W_stride
            W_stop                  = W_start + W_stride
            # An inactive column gets an empty W range (start == stop).
            W_start                 = W_start if active_col else W_stop
            # NOTE(review): the inner H range is always 0..H_stride and does
            # not depend on col - confirm this is intended.
            return                  f"H:0:{dims.M}:{H_stride} " + \
                                    f"W:0:{dims.N}:{dims.N_subv*dims.aie_cols} " + \
                                    f"W:{W_start}:{W_stop}:{params['min_cols']} " + \
                                    f"H:0:{H_stride} W:0:{params['min_cols']}"
        if params['OfmFormat'] == "BF16_A8": #same as default
            H_start                 = 0
            H_stop                  = H_start + (dims.M_subv * dims.aie_rows)
            H_stride                = dims.M_subv * dims.aie_rows
            W_start                 = self.col_index(dims, 0, col) * dims.N_subv
            W_stop                  = W_start + dims.N_subv
            # An inactive column gets an empty W range (start == stop).
            W_start                 = W_start if active_col else W_stop
            W_stride                = dims.N_subv * min(dims.aie_cols, len(params['active_col'])) * dims.aie_arrays
            return                  f"H:0:{dims.M+dims.Munpad}:{H_stride} " + \
                                    f"W:0:{dims.N}:{W_stride} " + \
                                    f"H:{H_start}:{H_stop} " + \
                                    f"W:{W_start}:{W_stop}"
        
    def out_memtile_memory(self, params) -> str:
        """Return the memtile OFM buffer-shape string for ``params['OfmFormat']``.

        Reads 'dims' and 'OfmFormat' (plus 'min_cols' for "BFP_A8") from
        ``params``. Returns ``None`` for an unrecognized format.
        """
        #M1N32 col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==4
        #M4N8  col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==4
        #M8N4  col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==8
        #M32N1 col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==8
        shape = params['dims']
        # check_dimension is queried for column 0 on axes 0, 1, 2; presumably
        # the per-column core count along batch / M / N - TODO confirm.
        batch_factor, m_factor, n_factor = (
            self.check_dimension(shape, 0, axis) for axis in range(3)
        )
        fmt = params['OfmFormat']

        if fmt == "default":
            n_field = f'N:{shape.B_subv*batch_factor}'
            h_field = f'H:{shape.M_subv * m_factor}'
            w_field = f'W:{shape.N_subv * n_factor}'
            return f'{n_field} {h_field} {w_field}'

        if fmt in ("BFP_A8", "BF16_A8"):
            # Both formats start with the same H/W pair; BFP_A8 appends an
            # extra H/W pair.
            head = f"H:{shape.M_subv * shape.aie_rows} W:{shape.N_subv}"
            if fmt == "BF16_A8": #same as default
                return head
            return f"{head} H:{shape.M_subv} W:{params['min_cols']}"

    def out_memtile_s2mm(self, params, row: int, col: int, itr_idx: int) -> str:
        """Build the memtile S2MM access-pattern string for the OFM of one core.

        Args:
            params: Template parameter dict. Reads 'dims', 'OfmFormat' and
                'min_cols'.
            row: Row index; scales the start offset of the row-split axis.
            col: Column index passed to ``check_dimension``.
            itr_idx: Tiling iteration index. Not used by this method; kept so
                the signature matches the other access-pattern builders.

        Returns:
            Space-separated ``AXIS:start:stop[:stride]`` fields, or ``None``
            when ``OfmFormat`` is not "default", "BFP_A8" or "BF16_A8".
        """
        dims = params['dims']
        col_aie_b = self.check_dimension(dims, col, 0)
        col_aie_m = self.check_dimension(dims, col, 1)
        col_aie_n = self.check_dimension(dims, col, 2)
        # Every format writes the innermost W range starting at 0. The original
        # code only assigned W_start in the "default" branch, so the BFP_A8 and
        # BF16_A8 branches raised UnboundLocalError.
        W_start = 0
        if params['OfmFormat'] == "default":
            #M1N32 col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==4
            #M4N8  col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==4
            #M8N4  col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==8
            #M32N1 col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==8
            # The bool factor zeroes the row offset on every axis whose
            # check_dimension value is not greater than 1.
            N_start        = (col_aie_b > 1) * (dims.B_subv * row)
            N_stop         = N_start + dims.B_subv

            H_start        = (col_aie_m > 1) * (dims.M_subv * row)
            H_stop         = H_start + dims.M_subv

            W_stop         = params['min_cols']
            W_stride       = params['min_cols']
            W_stride_start = (col_aie_n > 1) * (dims.N_subv * row)
            W_stride_stop  = W_stride_start + dims.N_subv
            return f"N:{N_start}:{N_stop} " + \
                   f"W:{W_stride_start}:{W_stride_stop}:{W_stride} " + \
                   f"H:{H_start}:{H_stop} " + \
                   f"W:{W_start}:{W_stop}"
        if params['OfmFormat'] in ("BFP_A8", "BF16_A8"):
            # Both formats emit the same pattern: each row writes its own
            # M_subv-high slice.
            H_start = row * dims.M_subv
            H_stop = H_start + dims.M_subv
            return f"W:0:{dims.N_subv}:{params['min_cols']} H:{H_start}:{H_stop} W:{W_start}:{params['min_cols']}"
        
    def out_memtile_mm2s(self, params, col, itr_idx) -> str:
        """Build the memtile MM2S access-pattern string for the OFM of one column.

        Reads 'tiling_seq', 'tiling_seq_dim', 'dims' and 'OfmFormat' from
        ``params``; the "default" format additionally reads 'ofm_idx',
        'ofm_channel' and 'transpose_4d', the "BFP_A8" format reads 'min_cols'.
        Returns ``None`` for an unrecognized format.
        """
        batch_itr, m_itr, n_itr, _k_itr = params['tiling_seq'][itr_idx]
        _b_width, _m_width, _n_width, _k_width = params['tiling_seq_dim'][itr_idx]
        shape = params['dims']
        b_split = shape.split[0]
        batch_groups = math.ceil(shape.Bpad_Y / b_split)
        # Batch slots left over after rounding Bpad_Y up to the batch split.
        batch_tail = batch_groups * b_split - shape.Bpad_Y
        final_b = self.is_last(batch_groups, batch_itr, batch_tail)
        cores_b = self.check_dimension(shape, col, 0)
        cores_m = self.check_dimension(shape, col, 1)
        cores_n = self.check_dimension(shape, col, 2)
        final_m = self.is_last(shape.Tm, m_itr, shape.Munpad)
        final_n = self.is_last(shape.Tn, n_itr, shape.Nunpad)
        pad_batch = not b_split == 4                     #B32, B8, B4, B1 = [True, True, False, True]
        fmt = params['OfmFormat']

        if fmt == "BFP_A8":
            return f"W:0:{shape.N_subv}:{params['min_cols']} H:0:{shape.M_subv*shape.aie_rows} W:{0}:{params['min_cols']}"
        if fmt == "BF16_A8":
            return f'H:0:{shape.M_subv * shape.aie_rows} W:{0}:{shape.N_subv}'
        if fmt != "default":
            return None

        #M1N32 col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==4
        #M4N8  col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==4
        #M8N4  col_aie_m=1, col_aie_n=4 len(params['ifm_idx'][1])==8
        #M32N1 col_aie_m=4, col_aie_n=1 len(params['ifm_idx'][1])==8
        ofm_cols = params['ofm_idx'][1]
        n_lo = 0
        n_hi = self.generate_padding_val(cores_b, final_b*batch_tail, ofm_cols, col, pad_batch)

        h_lo = 0
        h_hi = h_lo + self.generate_padding_val(
            shape.M_subv * cores_m, shape.Munpad*final_m,
            ofm_cols, col, params['ofm_channel'] == 'unicast')

        w_lo = 0
        w_hi = w_lo + self.generate_padding_val(
            shape.N_subv * cores_n, shape.Nunpad*final_n,
            ofm_cols, col, params['ofm_channel'] == 'broadcast')

        n_span = f'N:{n_lo}:{n_hi}'
        h_span = f'H:{h_lo}:{h_hi}'
        w_span = f'W:{w_lo}:{w_hi}'
        # With the 4D transpose flag set, H is emitted ahead of N.
        if params['transpose_4d'][2]:
            ordered = (h_span, n_span, w_span)
        else:
            ordered = (n_span, h_span, w_span)
        return ' '.join(ordered)

class params_funcs(tiling_funcs):
    def set_spatialsplit_params(self, params):
        """Fill in the spatial-split entries of ``params``.

        Runs ``set_kernelmode_prm`` and then ``set_in_ch_mode`` on ``params``.
        """
        for apply_setting in (self.set_kernelmode_prm, self.set_in_ch_mode):
            apply_setting(params)

    def set_kernelmode_prm(self, params):
        """Store the fixed kernel mode (1) under ``params['kernelmode']``."""
        params.update(kernelmode=1)

    def set_in_ch_mode(self, params):
        """Store the fixed input-channel mode (0) under ``params['in_ch_mode']``."""
        params.update(in_ch_mode=0)

    #repeat params
    def decode_data_format(self, dtype):
        """Map a dtype name to its format tag.

        "bfp16" -> "BFP_A8", "bf16" -> "BF16_A8", anything else -> "default".
        """
        known_formats = (("bfp16", "BFP_A8"), ("bf16", "BF16_A8"))
        for dtype_name, format_tag in known_formats:
            if dtype == dtype_name:
                return format_tag
        return "default"

    def extend_rval(self, n, dnom):
        """Distribute ``n`` over the entries of ``dnom`` in proportion to each entry.

        Each share is ``x * n / sum(dnom)`` rounded with ``round`` and cast to
        ``int``; the shares are returned in the order of ``dnom``.
        """
        total = sum(dnom)
        shares = []
        for weight in dnom:
            shares.append(int(round(weight * n / total)))
        return shares

    def update_buffer_alloc_to_params(self, params, buff):
        """Populate ``params`` from a buffer-allocator result.

        ``buff`` is indexed with the ``const.BufAllocator_Idx`` members
        BUFF_ALLOC_PARAM_IDX (a dict of allocator parameters, ``buff_prm``),
        CORE_TILE_ADDR_IDX, CORE_TILE_SIZE_IDX and MEM_TILE_ADDR_IDX (dicts
        merged into ``params``). The method builds ``params['dims']`` as a
        ``GemmDims`` and fills in sizes, data-flow modes, ping-pong flags,
        dtype/format tags, fusion offsets and the schedule id.

        Args:
            params: Dict updated in place.
            buff: Buffer-allocator output, indexable by the enum values above.

        Raises:
            NotImplementedError: activation pin + ping-pong is requested in a
                configuration the scheduler rejects.
            ValueError: the activation or weight data-flow mode is not one of
                "pin", "stream", "full".
        """
        buff_prm = buff[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]
        # Every run of digits in the mode string is one split factor.
        split = [int(x) for x in re.findall(r'\d+', str(buff_prm['mode']))]
        # A 3-entry split has no leading entry; prepend a 1 so split always
        # has four entries.
        split = [1] + split if len(split)==3 else split
        data = (split,
                buff_prm['B'],
                buff_prm['M']+sum(buff_prm['padding'][0]['pad_ifm_y']), 
                buff_prm['K'],
                buff_prm['N']+sum(buff_prm['padding'][1]['pad_wgt_z']),
                1,  # B_subV
                buff_prm['M_subV'], buff_prm['K_subV'], buff_prm['N_subV'],
                buff_prm['aie_rows'], buff_prm['aie_cols'], buff_prm['aie_arrays'],
                buff_prm['ifm_bits'], buff_prm['ofm_bits'], buff_prm['qdq_bytes'],
                buff_prm['outer_loop'], buff_prm['inner_loop'], buff_prm['acc_loop'],
                buff_prm['wgt_subv_rows'], buff_prm['wgt_subv_cols'],
                buff_prm['ifm_core_subv_bytes'], buff_prm['wgt_core_subv_bytes'], 
                buff_prm['ofm_core_subv_bytes'],
                buff_prm['sum_core_subv_bytes'], buff_prm['tdm_core_subv_bytes'],
                buff_prm['wgt_bits'] if buff_prm['actxact'] else const.BITS_PER_BYTE,
                const.BITS_PER_BYTE, buff_prm['bias_bits'],
                buff_prm['tdm_bits'], 
                sum(buff_prm['padding'][0]['pad_ifm_y']), sum(buff_prm['padding'][0]['pad_ifm_z']),
                sum(buff_prm['padding'][1]['pad_wgt_y']), sum(buff_prm['padding'][1]['pad_wgt_z']),
                sum(buff_prm['padding'][2]['pad_ofm_y']), sum(buff_prm['padding'][2]['pad_ofm_z']),
                buff_prm['K_ifmB'], split[0],
                math.ceil(buff_prm['B'] / split[0]),
                buff_prm['Bpad_A'], buff_prm['Bpad_B'], buff_prm['Bpad_Y'])
        params['dims']                 = GemmDims(*data)  # pylint: disable=too-many-function-args
        params['ParamSize']            = const.PARAM_SIZE
        params['CoreBankSize']         = const.CORE_BANK_SIZE
        params['ShimOutBufferIdx']     = 0
        params['ShimActBufferIdx']     = 1
        params['ShimWgtBufferIdx']     = 2
        params['ShimPrmBufferIdx']     = 3
        params.update(buff[const.BufAllocator_Idx.CORE_TILE_ADDR_IDX.value])
        params.update(buff[const.BufAllocator_Idx.CORE_TILE_SIZE_IDX.value])
        params.update(buff[const.BufAllocator_Idx.MEM_TILE_ADDR_IDX.value])
        params['sync_strategy']        = 'SyncStrategy.Parallel_1_to_N'
        params['ping_pong_enable']     = buff_prm['sch_attr'].ping_pong_enable
        params['ActDataFlow']          = buff_prm['sch_attr'].dataflow_mode['ifm']
        params['WgtDataFlow']          = buff_prm['sch_attr'].dataflow_mode['wgt']
        params['actxact']              = buff_prm['actxact']
        params['SHARE_CH_MODE']        = (params['actxact']==False and      #actxwgt only
                                          params['dims'].Tm==1 and          #Tm = 1 only for gemmv application 
                                          params['dims'].split[1] == 1 and  #only M1 split, M4 each column shares the same wgt, can't split the wgts
                                          params['ActDataFlow']=="pin" and  #ifm streaming mode will hold the dma channel active and prevent ch sharing
                                          not (params['dims'].wgt_subv_bytes >= 4096 and params['dims'].Tn > 1)) #re-enqueue will run out of bds
        # Previous (disabled) logic, kept for reference:
        #only enable pinning ping-pong when there is no padding and additional buffer fits in memtile
        #if ((params['ActDataFlow'] == 'pin') and #pin mode
        #   (params['MemtileOfmPingAddr'] + buff_prm['ofm_mem_tile_size'] + buff_prm['ifm_mem_tile_size'] < 512*1024) and #ping pong must fit in memtile
        #   (params['dims'].Tm==1 or params['dims'].Mpad_ifm==0) and #no padding
        #   (params['dims'].Npad_wgt==0 and params['dims'].Nunpad==0) and #no padding
        #   (params['dims'].Tk==1 or (params['dims'].Kpad_wgt==0 and params['dims'].Kpad_ifm==0)) and #no padding 
        #   (params['dims'].Tn>1) and #pin reuse ratio must be larger than 1
        #   (not params['SHARE_CH_MODE'])):
        #    params['ActPingPong']          = True
        #    params['MemtileIfmPongAddr']   = params['MemtileOfmPingAddr'] + buff_prm['ofm_mem_tile_size']
        #    params['task_queue_optimization'] = params['dims'].split[1] > params['dims'].split[3]
        #else:
        #    params['ActPingPong']          = buff_prm['sch_attr'].ping_pong_enable['ifm']
        #    params['task_queue_optimization'] = True
        params['task_queue_optimization'] = True
        params['ActPingPong']          = buff_prm['sch_attr'].ping_pong_enable['ifm']
        if ((params['ActDataFlow'] == 'pin') and (params['ActPingPong'])):
            params['task_queue_optimization'] = params['dims'].split[1] > params['dims'].split[3]
            # NOTE(review): this guard fires when there is no padding and
            # SHARE_CH_MODE is off - confirm that is the intended unsupported set.
            if ((params['dims'].Tm==1 or params['dims'].Mpad_ifm==0) and #no padding
                (params['dims'].Npad_wgt==0 and params['dims'].Nunpad==0) and #no padding
                (params['dims'].Tk==1 or (params['dims'].Kpad_wgt==0 and params['dims'].Kpad_ifm==0)) and #no padding 
                (not params['SHARE_CH_MODE'])):
                # Was ``raise("...")``: raising a str is itself a TypeError
                # and the message was lost.
                raise NotImplementedError("Cannot enable activation pin pingpong, mode unsupported in scheduler")
        params['WgtPingPong']          = buff_prm['sch_attr'].ping_pong_enable['wgt']
        params['OfmPingPong']          = buff_prm['sch_attr'].ping_pong_enable['ofm']
        params['ShimIfmSize']          = buff_prm['ifm_shim_tile_size']
        params['ShimWgtSize']          = buff_prm['wgt_shim_tile_size']
        params['ShimOfmSize']          = buff_prm['ofm_shim_tile_size']
        params['MemtileActSize']       = buff_prm['ifm_mem_tile_size']
        params['MemtileWgtSize']       = buff_prm['wgt_mem_tile_size']
        params['MemtileOutSize']       = buff_prm['ofm_mem_tile_size']
        params['ActDtype']             = buff_prm['ifm'].dtype
        params['WgtDtype']             = buff_prm['wgt'].dtype
        params['OfmDtype']             = buff_prm['ofm'].dtype
        params['Actbits']              = int(buff_prm['ifm'].bytes * const.BITS_PER_BYTE)
        params['Wgtbits']              = int(buff_prm['wgt'].bytes * const.BITS_PER_BYTE)
        params['ActFormat']            = self.decode_data_format(buff_prm['ifm'].dtype)
        params['WgtFormat']            = self.decode_data_format(buff_prm['wgt'].dtype)
        params['OfmFormat']            = self.decode_data_format(buff_prm['ofm'].dtype)
        params['min_cols']             = 8
        params['min_rows']             = 8
        params['op_name']              = buff_prm['orig_op_type']
        params['op_ver']               = buff_prm['op_type']
        params['op_mode']              = buff_prm['op_mode']
        # The last '_'-separated token of op_type is split on 'x' into the
        # act / wgt / out type names; a name containing 'uint' is unsigned.
        type_names                     = buff_prm['op_type'].split('_')[-1].split('x')
        params['sign_act']             = 0 if 'uint' in type_names[0] else 1
        params['sign_wgt']             = 0 if 'uint' in type_names[1] else 1
        params['sign_out']             = 0 if 'uint' in type_names[2] else 1
        params['is_pwla_fused']        = buff_prm['is_pwla_fused']
        params['is_rope_fused']        = buff_prm['is_rope_fused']
        params['is_fused_rope_actxact']= buff_prm['is_fused_rope_actxact']
        params['rope_total_size']      = buff_prm['M'] * buff_prm['N'] * buff_prm['B'] * (buff_prm['ofm_bits'] // 8) if  params['is_rope_fused'] else 0
        #TODO: Check with kyle if this is ok
        if params['is_fused_rope_actxact']:
            params['rope_shim_offset'] = buff_prm['M'] * buff_prm['K'] * buff_prm['B'] * int(buff_prm['B_split']) * (buff_prm['ifm_bits'] // 8)
        elif params['is_rope_fused']:
            params['rope_shim_offset'] = buff_prm['dram_sizes']['wgt']
        else:
            params['rope_shim_offset'] = 0
        params['rope_qdq_size']        = params['CoreQdqSize'] // 2 if  params['is_rope_fused'] else 0
        params['ShimRoPESize']         = params['ShimOfmSize']
        params['is_elew_fused']        = buff_prm['is_elew_fused']
        params['elew_total_size']      = buff_prm['M'] * buff_prm['N'] * buff_prm['B'] * (buff_prm['ofm_bits'] // 8) if  params['is_elew_fused'] else 0
        params['elew_ifmB_shim_offset']= buff_prm['M'] * buff_prm['K'] * buff_prm['B'] * int(buff_prm['B_split']) * (buff_prm['ifm_bits'] // 8) \
                                            if  params['is_elew_fused'] else 0 #TODO: Check with karthik/xiaohan if this is ok
        params['elew_qdq_size']        = params['CoreQdqSize'] // 2 if  params['is_elew_fused'] else 0
        params['ShimElewSize']         = params['ShimOfmSize']
        params['param_ptr']            = {'orig_op_type': buff_prm['orig_op_type'],
                                          'op_type': buff_prm['op_type']}
        params['template_meta_data']   = {'op_type': buff_prm['orig_op_type'],
                                          'mode': buff_prm['mode'],
                                          'overlay': buff_prm['overlay']}
        params['qdq_offset']           = 0 if params['actxact'] else buff_prm['wgt_core_subv_bytes']*buff_prm['wgt_subv_rows']*buff_prm['wgt_subv_cols']*params['dims'].Bpad_A
        # NOTE(review): the fallback is buff_prm['aie_cols'], the same value
        # passed to GemmDims above - confirm consumers of 'active_col' accept
        # that type as well as whatever 'active_col' normally holds.
        params['active_col']           = buff_prm.get('active_col', buff_prm['aie_cols'])
        params['transposeA']           = int(buff_prm['transposeA']) 
        params['transposeB']           = int(buff_prm['transposeB']) 
        params['rev_permA']            = list(map(int, buff_prm['rev_permA']))
        params['rev_permB']            = list(map(int, buff_prm['rev_permB']))
        params['permY']                = list(map(int, buff_prm['permY']))
        params['Bpadding']             = 'B' in buff_prm['mode'] and (params['dims'].B % split[0]) != 0 
        params['Bpad_A']               = params['dims'].Bpad_A
        params['Bpad_B']               = params['dims'].Bpad_B
        params['Bpad_Y']               = params['dims'].Bpad_Y
        params['wgt_shim_offset']      = buff_prm['M'] * buff_prm['K'] *  params['Bpad_A'] * (buff_prm['ifm_bits'] // 8) if params['actxact'] else 0
        # Fall back to the first TDM address when no second one was allocated.
        if params['CoreTdmAddr'][1] is None:
            params['CoreTdmAddr'][1]   = params['CoreTdmAddr'][0]
        params['transpose_4d']         = buff_prm['transpose_4d']
        params['info_4d']              = buff_prm['info_4d']
        if params['ActDataFlow']=="stream" and params['WgtDataFlow']=="stream":
            params['sched'] = 5
        elif params['ActDataFlow']=="pin" and params['WgtDataFlow']=="stream":
            params['sched'] = 2
        else:
            params['sched'] = 1
        # The original ``assert x == "pin" or "stream" or "full"`` was always
        # true because a non-empty string literal is truthy; check membership
        # explicitly instead.
        valid_dataflows = ("pin", "stream", "full")
        if params['ActDataFlow'] not in valid_dataflows:
            raise ValueError(f"Activation mode must be one of {valid_dataflows}, got {params['ActDataFlow']!r}")
        if params['WgtDataFlow'] not in valid_dataflows:
            raise ValueError(f"Weight mode must be one of {valid_dataflows}, got {params['WgtDataFlow']!r}")

    def calc_rep_params(self, params):
        def tilling_itr(params, tiling_exp_rp):
            params['ifm_tiling_iters']     = tiling_exp_rp
            params['wgt_tiling_iters']     = tiling_exp_rp
            params['out_tiling_iters']     = tiling_exp_rp

        def shim_reenqueue(x1, x2, s=1):
            """repeat x1 for x2 times with s as repeat factor"""
            assert s>0, "ShimWgtRepeat should be called only when s > 0"
            x1_len = len(x1) if isinstance(x1, list) else 1
            #NOTE: Added to handle corner cases, need review on this logic.
            s = 1 if s < 1 else s
            ShimWgtRepeat_list = []
            for x in range(x2):
                ShimWgtRepeat_list.append([0]*x2*s*x1_len)
                if x1_len > 1:
                    for y in range(x1_len):
                        ShimWgtRepeat_list[-1][x*x1_len*s+y] = x1[y]
                else:
                    if isinstance(x1, list):
                        ShimWgtRepeat_list[-1][x*s] = x1[0]
                    else:
                        ShimWgtRepeat_list[-1][x*s] = x1
            return ShimWgtRepeat_list

        class rr_val():
            mem_ifm: int; shim_ifm: int; shim_ifm_ol: int; 
            shim_wgt: int; wgt_rcl: int; e_idx: int = None

            def assign_params(self):
                params['MemtileActReuseRatio'] = self.mem_ifm
                params['ShimActOuterLoop']     = self.shim_ifm_ol
                params['ShimActReuseRatio']    = self.shim_ifm

                params['MemtileWgtReuseRatio'] = self.mem_wgt
                params['ShimWgtReuseRatio']    = self.shim_wgt

                params['act_reuse_chain_length'] = self.act_rcl
                params['wgt_reuse_chain_length'] = self.wgt_rcl

        class r_val():
            mem_ifm_pi: int; mem_ifm_po: int; mem_wgt: int; mem_ofm: int; shim_ifm:int; shim_wgt:int; shim_ofm:int
            def __init__(self):  
                self.mem_ifm_pi = []
                self.shim_ifm   = []
                self.shim_wgt   = []
                self.mem_wgt    = []
                self.mem_ofm    = []
                self.shim_ofm   = []

            def maxlen(self):
                try:
                    max_len = max([np.array(x).shape[-1] for x in vars(self).values() if x!=None])
                except ValueError:
                    for x in vars(self).values():
                        if x!=None:
                            print(f"{x}")
                    max_len = 0
                return max_len

            def pad(self): #pad 0s
                maxlen = self.maxlen()
                for key, val in vars(self).items():
                    if val != None:
                        np_val = np.array(val)
                        if np_val.ndim==1 and np_val.size<maxlen:
                            vars(self)[key] = val+[0]*(maxlen-len(val))
                        elif np_val.ndim==2 and np_val[0].size<maxlen:
                            vars(self)[key] = [x+[0]*(maxlen-len(x)) for x in val]

            def split_val(self, repeat_val, limit, zeros=0):
                repeat = []
                if repeat_val == None:
                    return [None]
                while True:
                    if repeat_val > limit:
                        repeat.append(limit)
                        repeat+=[0]*zeros
                        repeat_val = repeat_val - limit

                    else:
                        repeat.append(repeat_val)
                        repeat+=[0]*zeros
                        break
                return repeat

            def extend_rval(self, n, dnom, zeros=0):
                out = []
                for x in dnom:
                    out += [int(round(x * n / sum(dnom)))] + [0]*zeros
                return out

            def assign_params(self, params):
                dims = params['dims']
                rope_tensor_fetch = 2 if params['is_rope_fused'] else 0 #for sin and cos tensor 
                elew_tensor_fetch = 1 if params['is_elew_fused'] else 0 #for ifmB
                additional_tensor_fetch = rope_tensor_fetch if params['is_rope_fused'] else elew_tensor_fetch if params['is_elew_fused'] else 0
                maxlen = self.maxlen() + additional_tensor_fetch  
                params['mem_ifm_pi'] = self.mem_ifm_pi

                params['MemtileParamRepeat']   = [1] + [0]*(maxlen-1)
                params['MemtileQdqPrmRepeat']  = [1] + [0]*(maxlen-1)
                params['ShimParamRepeat']      = [1] + [0]*(maxlen-1)
                params['ShimQdqPrmRepeat']     = [1] + [0]*(maxlen-1)
                params['MemtileActPingRepeat'] = memtile_repeat_optimizer(params, 'Act', self.mem_ifm_pi) + [0]*additional_tensor_fetch
                params['MemtileActPongRepeat'] = None if self.mem_ifm_po is None else memtile_repeat_optimizer(params, 'Act', self.mem_ifm_po) + [0]*additional_tensor_fetch
                params['ShimActRepeat']        = [inner_list     + [0]*additional_tensor_fetch for inner_list in self.shim_ifm]

                params['MemtileWgtRepeat']     = memtile_repeat_optimizer(params, 'Wgt', self.mem_wgt)    + [0]*additional_tensor_fetch
                params['ShimWgtRepeat']        = [inner_list     + [0]*additional_tensor_fetch for inner_list in self.shim_wgt]

                params['MemtileOutRepeat']      = memtile_ofm_repeat_optimizer(params,self.mem_ofm)   + [0]*additional_tensor_fetch
                params['ShimOutRepeat']         = self.shim_ofm  + [0]*additional_tensor_fetch

        def memtile_repeat_optimizer(params, pos, val):
            dims = params['dims']
            uniq_seq = set(tuple(sublist) for sublist in np.array(params['tiling_seq'])[:,1:].tolist())
            reuse_ratio_is_one = params.get('Memtile'+pos+'ReuseRatio',1) == 1
            batch_padding_is_zero = dims.Bpad_A%dims.split[0] !=0 or dims.Bpad_B%dims.split[0] !=0 or dims.Bpad_Y%dims.split[0] !=0
            act_pin_pingpong = pos == 'Act' and params['ActDataFlow'] == 'pin' and params['ActPingPong']
            if sum(val) <= MEM_REPEAT_MAX and \
                    len(uniq_seq)==1 and \
                    reuse_ratio_is_one and \
                    not batch_padding_is_zero and \
                    not act_pin_pingpong and \
                    not params['SHARE_CH_MODE']:
                return [sum(val)] + [0]*(len(val)-1)
            return val

        def memtile_ofm_repeat_optimizer(params, val):
            dims = params['dims']
            uniq_seq = set(tuple(sublist) for sublist in np.array(params['tiling_seq'])[:,1:].tolist())
            batch_padding_is_zero = dims.Bpad_A%dims.split[0] !=0 or dims.Bpad_B%dims.split[0] !=0 or dims.Bpad_Y%dims.split[0] !=0
            if sum(val) <= MEM_REPEAT_MAX and len(uniq_seq)==1 and not batch_padding_is_zero:
                return [sum(val)] + [0]*(len(val)-1)
            return val

        def set_ifm_mem_rp(params, b, m, n, k):
            dims = params['dims']
            if params['ActDataFlow'] == 'stream':
                if [b, m, n, k] in params['tiling_seq']:
                    return ".append(1)"
                if [m, n, k] not in params['tiling_seq']: #TODO. check condition. this will never be true
                    return "[-1] += 1"
            else:
                if [b, m, n, k] in params['tiling_seq']:
                    if n==0 and k==0:       #append first tm iteration
                        return ".append(1)"
                    else:                   #append 0 for n and k iteration
                        return ".append(0)"
                elif n==0 and k ==0:        #accumulate repeat for each tm iteration
                    if params['ActPingPong']:
                        return ".req_append(1)"
                    else:
                        return "[-1] += 1"
            return ""
        
        def set_ifm_shim_rp(params, b, m, n, k):
            dims = params['dims']
            if params['ActDataFlow'] == 'stream':
                if [b, m, n, k] in params['tiling_seq']:
                    return ".append(1)"
                if k == 0 and [m, n, k] not in np.array(params['tiling_seq'])[:,1:].tolist():
                    if params['repeat_mode'] == 'shim_reenqueue':
                        return ".req_append(1)"
                    else:
                        return "[-1] += 1"
            else:
                if [b, m, n, k] in params['tiling_seq']:
                    if n==0 and k==0:
                        return ".append(1)"
                    else:
                        return ".append(0)"
            return ""

        def set_wgt_mem_rp(params, b, m, n, k):
            """Choose the repeat-list operation for the memtile WGT port at tile (b, m, n, k).

            The returned text is a statement suffix that repeat_generator glues
            onto ``r.mem_wgt`` and executes; ``".req_acc(1)"`` is intercepted by
            repeat_generator instead of being executed.

            Returns:
                ".append(1)" / ".append(0)" to start a new entry, "[-1] += 1" or
                ".req_acc(1)" to account for one more repeat, or "".
            """
            dims = params['dims']
            streaming = params['WgtDataFlow'] == 'stream'
            starts_entry = [b, m, n, k] in params['tiling_seq']

            if streaming:
                if starts_entry:
                    return ".append(1)"
                reenqueue = params['repeat_mode'] == 'mem_reenqueue'
                return ".req_acc(1)" if reenqueue else "[-1] += 1"

            if starts_entry:
                # A fresh repeat is needed when K is padded, when N is padded
                # with more than one N tile, or on the very first (m, n, k).
                needs_repeat = (
                    dims.Kpad_wgt > 0
                    or (dims.Npad_wgt > 0 and dims.Tn > 1)
                    or [m, n, k] == [0, 0, 0]
                )
                return ".append(1)" if needs_repeat else ".append(0)"

            # Same (m, n, k) as a listed entry, only the batch index differs.
            listed_mnk = np.array(params['tiling_seq'])[:, 1:].tolist()
            if [m, n, k] in listed_mnk:
                return "[-1] += 1"
            return ""
        
        def set_wgt_shim_rp(params, b, m, n, k):
            """Choose the repeat-list operation for the shim WGT port at tile (b, m, n, k).

            The returned text is a statement suffix that repeat_generator glues
            onto ``r.shim_wgt`` and executes.

            Returns:
                ".append(1)" / ".append(0)" to start a new entry, "[-1] += 1" to
                bump the latest entry, or "" to leave the list untouched.
            """
            dims = params['dims']

            if params['WgtDataFlow'] == 'stream':
                if [b, m, n, k] in params['tiling_seq']:
                    # the last K step carries no repeat of its own unless K is padded
                    last_k_unpadded = k == dims.Tk - 1 and dims.Tk > 1 and dims.Kpad_wgt == 0
                    return ".append(0)" if last_k_unpadded else ".append(1)"
                if k == 0 and n == 0:
                    listed_mnk = np.array(params['tiling_seq'])[:, 1:].tolist()
                    if [m, n, k] not in listed_mnk:
                        return "[-1] += 1"
                return ""

            # full (non-stream) weight flow
            k_padded = dims.Kpad_wgt > 0 or dims.Kpad_ifm > 0
            n_padded = dims.Npad_wgt > 0 and dims.Nunpad > 0
            if [b, m, n, k] not in params['tiling_seq']:
                return ""
            if [m, n, k] == [0, 0, 0]:  # first iteration
                return ".append(1)"
            # N padding alone on the first M row, or N padding with several N tiles
            if n_padded and ((not k_padded and m == 0) or dims.Tn > 1):
                return ".append(1)"
            return ".append(0)"
        
        def set_ofm_mem_rp(params, b, m, n, k):
            """Choose the repeat-list operation for the memtile OFM port at tile (b, m, n, k).

            Only the last K step (k == Tk - 1) contributes a repeat. The
            returned text is a statement suffix that repeat_generator glues
            onto ``r.mem_ofm`` and executes.
            """
            dims = params['dims']
            starts_entry = [b, m, n, k] in params['tiling_seq']
            is_last_k = k == dims.Tk - 1
            if starts_entry:
                return ".append(1)" if is_last_k else ".append(0)"
            return "[-1]+= 1" if is_last_k else ""
        
        def set_ofm_shim_rp(params, b, m, n, k):
            """Choose the repeat-list operation for the shim OFM port at tile (b, m, n, k).

            A listed tile appends 1 on the last K step and the string '0'
            otherwise. For an unlisted tile on the last K step the marker
            "set one" is returned; repeat_generator reacts to that marker by
            turning a trailing '0' placeholder into 1.
            """
            dims = params['dims']
            starts_entry = [b, m, n, k] in params['tiling_seq']
            is_last_k = k == dims.Tk - 1
            if starts_entry:
                return ".append(1)" if is_last_k else ".append('0')"
            return "set one" if is_last_k else ""
        
        def repeat_generator(params, r):
            """Fill the repeat-count lists on ``r`` and return the repeats expected per tiling entry.

            Visits every (b, m, n, k) tile in itertools.product order, asks the
            set_*_rp helpers for an operation suffix per port and applies that
            suffix to the matching list attribute of ``r`` through exec.

            Side effects:
                * sets params['repeat_mode'] to 'default', 'shim_reenqueue' or
                  'mem_reenqueue';
                * mutates and finally rebinds list attributes of ``r``
                  (shim_ifm, shim_wgt, shim_ofm, mem_ifm_pi, mem_ifm_po, ...)
                  and calls ``r.pad()``.

            Returns:
                tiling_exp_rp: one counter per tile found in
                params['tiling_seq'], starting at 1 and incremented every time
                a re-enqueue entry is inserted for that tile.
            """
            dims = params['dims']
            tiling_exp_rp = []

            #check and set max repeat special flag
            params['repeat_mode'] = 'default'
            if params['sched']==5 and dims.Tn > SHIM_REPEAT_MAX:
                params['repeat_mode'] = 'shim_reenqueue'
            elif params['sched'] in [2, 5] and dims.Tm*dims.Tn*dims.Tk > MEM_REPEAT_MAX:
                params['repeat_mode'] = 'mem_reenqueue'

            #iterate through m n k and calculate repeat count values
            for b, m, n, k in itertools.product(range(dims.B_itr), range(dims.Tm), range(dims.Tn), range(dims.Tk)):
                if [b, m, n, k] in params['tiling_seq']:
                    tiling_exp_rp.append(1)

                #shim ifm max repeat handling
                ifm_shim_str = set_ifm_shim_rp(params, b, m, n, k)
                if ifm_shim_str == ".req_append(1)":
                    # re-enqueue: open a new slot (0) on every list attribute of r,
                    # then count this repeat on the shim ifm slot only
                    for x in vars(r).keys():
                        exec('r.' + x + '.append(0)')
                    exec('r.shim_ifm[-1] += 1')
                    tiling_exp_rp[-1] += 1
                else:
                    # an empty suffix makes exec evaluate the bare attribute, i.e. a no-op
                    exec('r.shim_ifm' + ifm_shim_str)

                #wgt memtile max repeat handling
                wgt_mem_dfactor_bool = False
                wgt_mem_str = set_wgt_mem_rp(params, b, m, n, k)
                if wgt_mem_str == ".req_acc(1)":
                    # dfactor is the largest repeat count allowed in one slot:
                    # Tk*Tn when that fits, otherwise the largest multiple of Tk
                    # not exceeding MEM_REPEAT_MAX
                    if dims.Tk*dims.Tn <= MEM_REPEAT_MAX:
                        dfactor        = dims.Tk*dims.Tn
                    else:
                        dfactor        = MEM_REPEAT_MAX//dims.Tk*dims.Tk
                    if r.mem_wgt[-1] >= dfactor:
                        wgt_mem_dfactor_bool = True
                        #append 0 to all ports
                        for x in vars(r).keys():
                            #check if shim ifm op string is '[-1] += 1'
                            if x == 'shim_ifm' and ifm_shim_str == '[-1] += 1':
                                # move the increment applied above into the new slot
                                r.shim_ifm[-1] -= 1
                                r.shim_ifm.append(1)
                            else:
                                exec('r.' + x + '.append(0)')
                        exec('r.mem_wgt[-1] += 1')
                        tiling_exp_rp[-1] += 1
                    else:
                        exec('r.mem_wgt[-1] += 1')
                else:
                    exec('r.mem_wgt'  + wgt_mem_str)

                ifm_mem_str = set_ifm_mem_rp(params, b, m, n, k)
                if ifm_mem_str == ".req_append(1)":
                    if not wgt_mem_dfactor_bool: #already appended 0 in memtile wgt
                        for x in vars(r).keys():
                            exec('r.' + x + '.append(0)')
                            
                        tiling_exp_rp[-1] += 1
                    exec('r.mem_ifm_pi[-1] += 1')
                else:
                    exec('r.mem_ifm_pi' + ifm_mem_str)
                exec('r.shim_wgt'   + set_wgt_shim_rp(params, b, m, n, k))
                exec('r.mem_ofm'    + set_ofm_mem_rp(params, b, m, n, k))

                #ofm shim: only a '0' string appended via set_ofm_shim_rp may be promoted to 1
                ofm_str = set_ofm_shim_rp(params, b, m, n, k)
                if ofm_str == "set one":
                    if r.shim_ofm[-1] == '0':
                        r.shim_ofm[-1] = 1
                else:
                    exec('r.shim_ofm' + ofm_str)
                pass

            r.shim_ofm = [int(x) for x in r.shim_ofm]  # convert to int
            if params['ActPingPong'] and params['ActDataFlow'] == 'pin':
                # alternating 0/1 patterns; po starts at 0, pi starts at 1
                r.mem_ifm_po = [i%2 for i in range(0,len(r.mem_ifm_pi))]
                r.mem_ifm_pi = [i%2 for i in range(1,len(r.mem_ifm_pi)+1)]
            else:
                r.mem_ifm_po = None #no ifm pin pingpong support

            
            #update shim ifm and wgt to include reenqueue streaming support
            if params['sched'] == 1:
                r.shim_ifm     = [r.shim_ifm]
                r.shim_wgt     = [r.shim_wgt]
            elif params['sched'] == 2:
                r.shim_ifm     = [r.shim_ifm]
                # NOTE(review): the key is spelled 'shim_wgt_reengueue' and shim_reenqueue
                # is called here with (Tm, Tn) but with (list, Tm) for sched 5 below --
                # confirm both are intended.
                r.shim_wgt     = shim_reenqueue(dims.Tm, dims.Tn) if params.get('shim_wgt_reengueue', False) else [r.shim_wgt]
            elif params['sched'] == 5:
                # only the first len(shim_ifm)//Tm entries are handed to shim_reenqueue
                r.shim_ifm     = shim_reenqueue(r.shim_ifm[:(len(r.shim_ifm)//dims.Tm)], dims.Tm)
                r.shim_wgt     = [r.shim_wgt]

            #pad after adding reenqueue streaming
            r.pad()
            return tiling_exp_rp

        def reuse_generator(params, r, rr):
            """Write the reuse factors for the selected schedule onto ``rr``.

            Schedules 1, 2 and 5 are handled; any other value of
            params['sched'] leaves ``rr`` untouched. ``r`` is accepted but not
            read.
            """
            dims = params['dims']

            def chain_length(tile_count, channel_key):
                # third argument is 4 for a broadcast channel and 2 otherwise
                return self.reuse_chain_length(tile_count, params['dims'].aie_rows,
                                               4 if params[channel_key] == 'broadcast' else 2)

            if params['sched'] == 1:
                rr.mem_ifm = dims.Tn
                rr.shim_ifm = 1
                rr.shim_ifm_ol = 1
                # weights are not reused when K is padded on both operands over several K tiles
                k_padded_both = dims.Kpad_ifm > 0 and dims.Kpad_wgt > 0 and dims.Tk > 1
                rr.mem_wgt = 1 if k_padded_both else dims.Tm
                rr.shim_wgt = 1
                rr.act_rcl = chain_length(dims.Tn, 'ifm_channel')
                rr.wgt_rcl = chain_length(dims.Tm, 'wgt_channel')
                return

            if params['sched'] == 2:
                rr.mem_ifm = dims.Tn
                rr.shim_ifm = 1
                rr.shim_ifm_ol = 1
                rr.mem_wgt = 1
                rr.shim_wgt = 1
                rr.act_rcl = chain_length(dims.Tn, 'ifm_channel')
                rr.wgt_rcl = dims.Tn
                return

            if params['sched'] == 5:
                rr.mem_ifm = 1
                rr.shim_ifm = 1
                rr.shim_ifm_ol = dims.Tm
                rr.mem_wgt = 1
                rr.shim_wgt = 1
                rr.act_rcl = 1
                rr.wgt_rcl = 1

        dims = params['dims']
        params['tiling_seq'], params['tiling_seq_dim'] = \
            tiling_xpr_seq_generator(dims.B_itr, dims.Tm, dims.Tn, dims.Tk, 
                                     dims.Bpad_A%dims.split[0], dims.Bpad_B%dims.split[0], dims.Bpad_Y%dims.split[0],
                                     dims.Mpad_ifm, dims.Kpad_ifm, 
                                     dims.Kpad_wgt, dims.Npad_wgt, 
                                     dims.Munpad  , dims.Nunpad,
                                     params['sched'])
        r  = r_val()
        rr = rr_val()
        tiling_exp_rp = repeat_generator(params, r)
        reuse_generator(params, r, rr)
        rr.assign_params()
        r.assign_params(params)
        tilling_itr(params, tiling_exp_rp)

    def gen_shim_act_params(self, params):
        """Populate the shim activation DMA entries in ``params``.

        Writes:
            params['act_shim_memory']: the single result of
                ``self.act_shim_memory(params)`` repeated once per tiling entry.
            params['act_shim_mm2s']: nested list indexed
                [tiling entry][column][copy], every element produced by
                ``self.act_shim_mm2s(params, col, itr_idx)``.

        The number of copies per column is 1 when params['ShimActOuterLoop']
        equals 1 (re-enqueue) and len(params['tiling_seq']) otherwise
        (BD chaining).
        """
        single_copy = params['ShimActOuterLoop'] == 1
        params['act_shim_memory'] = [self.act_shim_memory(params)] * len(params['tiling_seq'])
        params['act_shim_mm2s'] = []

        num_entries = len(params['tiling_seq'])
        copies = 1 if single_copy else num_entries
        for itr_idx in range(num_entries):
            params['act_shim_mm2s'].append([
                [self.act_shim_mm2s(params, col, itr_idx) for _ in range(copies)]
                for col in params['ifm_idx'][1]
            ])

    def gen_mem_act_params(self, params):
        """Populate the memtile activation DMA entries in ``params``.

        For every tiling entry one element is appended to each of:
            params['act_memtile_memory']: list over IFM columns,
            params['act_memtile_s2mm']:   list over IFM columns,
            params['act_memtile_mm2s']:   list over IFM columns of lists over
                                          params['row_list'].
        The IFM columns are taken from params['ifm_idx'][1].
        """
        for key in ('act_memtile_memory', 'act_memtile_s2mm', 'act_memtile_mm2s'):
            params[key] = []

        for itr_idx in range(len(params['tiling_seq'])):
            params['act_memtile_memory'].append(
                [self.act_memtile_memory(params, col, itr_idx) for col in params['ifm_idx'][1]])
            params['act_memtile_s2mm'].append(
                [self.act_memtile_s2mm(params, col, itr_idx) for col in params['ifm_idx'][1]])
            params['act_memtile_mm2s'].append(
                [[self.act_memtile_mm2s(params, col, row, itr_idx) for row in params['row_list']]
                 for col in params['ifm_idx'][1]])

    def gen_wgt_params(self, params):
        """Populate the weight DMA entries (shim and memtile) in ``params``.

        For every tiling entry one element is appended to each of:
            params['wgt_shim_memory']:    single value,
            params['wgt_shim_mm2s']:      list over WGT columns of lists over
                                          the indices of params['ShimWgtRepeat'],
            params['wgt_memtile_memory']: one value per WGT column (the call
                                          itself does not receive the column),
            params['wgt_memtile_s2mm']:   list over WGT columns,
            params['wgt_memtile_mm2s']:   list over WGT columns of lists over
                                          params['row_list'].
        The WGT columns are taken from params['wgt_idx'][1].
        """
        # keys are created in this order so the dict layout stays the same
        for key in ('wgt_shim_memory', 'wgt_shim_mm2s', 'wgt_memtile_s2mm',
                    'wgt_memtile_mm2s', 'wgt_memtile_memory'):
            params[key] = []

        for itr_idx in range(len(params['tiling_seq'])):
            params['wgt_shim_memory'].append(self.wgt_shim_memory(params, itr_idx))
            params['wgt_shim_mm2s'].append(
                [[self.wgt_shim_mm2s(params, col, itr_idx, k_repeat_itr)
                  for k_repeat_itr in range(len(params['ShimWgtRepeat']))]
                 for col in params['wgt_idx'][1]])

            params['wgt_memtile_memory'].append(
                [self.wgt_memtile_memory(params, itr_idx) for _ in params['wgt_idx'][1]])
            params['wgt_memtile_s2mm'].append(
                [self.wgt_memtile_s2mm(params, col, itr_idx) for col in params['wgt_idx'][1]])
            # note the (row, col) argument order of wgt_memtile_mm2s
            params['wgt_memtile_mm2s'].append(
                [[self.wgt_memtile_mm2s(params, row, col, itr_idx) for row in params['row_list']]
                 for col in params['wgt_idx'][1]])

    def gen_out_params(self, params):
        """Populate the output DMA entries (memtile and shim) in ``params``.

        For every tiling entry one element is appended to each of:
            params['out_shim_memory']:    single value,
            params['out_shim_s2mm']:      list over OFM columns of
                                          one-element lists,
            params['out_memtile_memory']: single value,
            params['out_memtile_s2mm']:   list over OFM columns of lists over
                                          params['row_list'],
            params['out_memtile_mm2s']:   list over OFM columns.
        The OFM columns are taken from params['ofm_idx'][1].
        """
        # keys are created in this order so the dict layout stays the same
        for key in ('out_memtile_memory', 'out_memtile_s2mm', 'out_memtile_mm2s',
                    'out_shim_memory', 'out_shim_s2mm'):
            params[key] = []

        for itr_idx in range(len(params['tiling_seq'])):
            params['out_shim_memory'].append(self.out_shim_memory(params))
            params['out_shim_s2mm'].append(
                [[self.out_shim_s2mm(params, col, itr_idx)] for col in params['ofm_idx'][1]])
            params['out_memtile_memory'].append(self.out_memtile_memory(params))
            # note the (row, col) argument order of out_memtile_s2mm
            params['out_memtile_s2mm'].append(
                [[self.out_memtile_s2mm(params, row, col, itr_idx) for row in params['row_list']]
                 for col in params['ofm_idx'][1]])
            params['out_memtile_mm2s'].append(
                [self.out_memtile_mm2s(params, col, itr_idx) for col in params['ofm_idx'][1]])

    def gen_dma_tile_params(self, params):
        """Build params['gemm_params'] and then all per-tile DMA parameter lists.

        params['gemm_params'] receives four call strings for the kernel
        parameter helper that matches the data type path:

            * CoreQdqAddr set, actxact and 16-bit activations -> smxv_qdq_params
            * CoreQdqAddr set, 8-bit activations               -> gen_a8w8_gemm_params
            * CoreQdqAddr set, anything else                   -> gemm_qdq_params
            * CoreQdqAddr missing                              -> gemm_params_bfp16

        The four strings differ only in their (first / zero-init, final) flag
        pair: (1, 0), (0, 0), (0, 1), (1, 1). Afterwards the shim/memtile
        activation, weight and output parameter generators are run.
        """
        dims = params['dims']
        # (first or zero-init flag, final flag) for the four emitted variants
        flag_pairs = ((1, 0), (0, 0), (0, 1), (1, 1))

        if params.get('CoreQdqAddr') is not None: #qdq
            is_actxact = params['actxact']
            act_bits = params['dims'].act_bits
            subv = f"({dims.M_subv}, {dims.K_subv}), ({dims.K_subv}, {dims.N_subv})"

            if is_actxact and act_bits == 16:
                mode = params.get('kernelmode', 1)
                # TransposeMode
                #    - 0 = !transposeA && !transposeB
                #    - 1 = transposeA && !transposeB
                #    - 2 = !transposeA && transposeB
                #    - 3 = transposeA && transposeB
                transpose_mode = int(params.get('transposeA', 0)) + (int(params.get('transposeB', 0)) << 1)
                unpack_V = int(params['actxact'] and params['Wgtbits'] == 8)
                buff_addrs = f"{params['CoreTdmAddr'][0]}, {params['CoreTdmAddr'][1]}, {params['CoreQdqAddr'][0]}, {params['CoreIfm_sumAddr'][0]}, {params['CoreIfm_sumAddr'][1]}, {params['CoreBufc0Addr'][0]}, {params['CoreScratchAddr'][0]}"
                signs = f"{params['sign_act']}, {params['sign_wgt']}, {params['sign_out']}"
                params['gemm_params'] = [
                    f"smxv_qdq_params({subv}, {mode}, {transpose_mode}, 0, 0, {first_iter}, {final_iter}, {unpack_V}, {buff_addrs}, {signs})"
                    for first_iter, final_iter in flag_pairs
                ]
            elif act_bits == 8:
                if is_actxact:
                    mode = params.get('kernelmode', 0)
                    op_mode = 27
                else:
                    mode = params.get('kernelmode', 1)
                    op_mode = 11 if params['op_mode'] < 2 else 9
                buff_addrs = f"{params['CoreTdmAddr'][0]}, {params['CoreTdmAddr'][1]}, {params['CoreIfm_sumAddr'][0]}, {params['CoreWgtunpackAddr'][0]}, {params['CoreQdqAddr'][0]}, {params['CoreBufc0Addr'][0]}"
                # transposeB is only forwarded on the actxact path; otherwise a literal 0 is emitted
                transpose_b = params['transposeB'] if is_actxact else 0
                params['gemm_params'] = [
                    f"gen_a8w8_gemm_params({subv}, {zero_init}, {final_iter}, {buff_addrs}, {op_mode}, {mode}, {params['transposeA']}, {transpose_b})"
                    for zero_init, final_iter in flag_pairs
                ]
            else:
                buff_addrs = f"{params['CoreQdqAddr'][0]}, {params['CoreIfm_sumAddr'][0]}, {params['CoreTdmAddr'][0]}, {params['CoreTdmAddr'][1]}"
                # NOTE: Mode0: args_params[1] == 0 => MatA broadcast, MatB unicast
                #       Mode1: args_params[1] == 1 => MatA unicast, MatB broadcast
                mode = params.get('kernelmode', 1)
                op_mode_str = 'OPMode.OP_CONV_ASYM' if params['op_mode'] < 2 else 'OPMode.OP_CONV_SYM'
                int4_wgt = 1 if params['Wgtbits']==4 else 0
                wgt_unpack_addr = params['CoreWgtunpackAddr'][0] if int4_wgt else 0
                tail = f"{int4_wgt}, {wgt_unpack_addr}, {op_mode_str}, {params['sign_act']}, {params['sign_wgt']}, {params['sign_out']}, {params['Wgtbits']}"
                # this helper takes no final flag, so only the zero_acc flag varies
                params['gemm_params'] = [
                    f"gemm_qdq_params({subv}, {zero_acc}, {mode}, {buff_addrs}, {tail})"
                    for zero_acc, _ in flag_pairs
                ]
        else: #bfp16
            supported_subv_list = [
                (16, 80, 80, 0), (16, 64, 80, 0), (128, 64, 16, 0), (32, 128, 32, 0), (64, 80, 64, 0),
                (64, 80, 80, 0), (32, 128, 64, 0), (128, 64, 32, 0), (64, 128, 32, 0), (32, 128, 80, 0),
                (32, 128, 64, 1), (64, 128, 32, 1), (80, 128, 32, 1), (64, 80, 64, 1),
            ]

            transpose_out = 0  #TODO: hardcoded for now
            act_in_bfp16 = params['ActDtype']=='bfp16'
            act_out_bfp  = params['OfmDtype'] =='bfp16'
            gelu_fuse    = 0
            # raises ValueError when the sub-volume is not in the supported list
            sub_vol_id   = supported_subv_list.index((params['dims'].M_subv, params['dims'].K_subv, params['dims'].N_subv, transpose_out))
            if act_out_bfp:
                sub_vol_id += 14  #TODO: sub_vol_id is offset by 14 if bfp16 output is enabled

            # the trailing init flag equals the zero_acc flag in all four variants
            params['gemm_params'] = [
                f"gemm_params_bfp16({zero_acc}, {final_iter}, {sub_vol_id}, {params['CoreTdmAddr'][0]}, {params['CoreTdmAddr'][1]}, {transpose_out}, {act_in_bfp16}, {act_out_bfp}, {gelu_fuse}, {zero_acc})"
                for zero_acc, final_iter in flag_pairs
            ]

        self.gen_shim_act_params(params)
        self.gen_mem_act_params(params)
        self.gen_wgt_params(params)
        self.gen_out_params(params)

    def initialize_params(self, buff, buff_prm):
        """Create a fresh parameter dict seeded from the buffer allocation.

        ``buff`` is forwarded to ``self.update_buffer_alloc_to_params``;
        ``buff_prm`` is accepted but not read here.
        """
        seeded = dict()
        self.update_buffer_alloc_to_params(seeded, buff)
        return seeded

    #main param function
    def gen_dma_params(self, _pipeline_data):
        """Run the whole parameter pipeline and return the resulting dict.

        The buffer allocation is read from
        ``_pipeline_data.info['BuffAllocator']``; the dict created by
        initialize_params is then passed, in order, through overlay_param,
        set_spatialsplit_params, calc_rep_params and gen_dma_tile_params.
        """
        allocator = _pipeline_data.info.get('BuffAllocator')
        allocator_prm = allocator[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]
        dma_params = self.initialize_params(allocator, allocator_prm)
        for stage in (self.overlay_param,
                      self.set_spatialsplit_params,
                      self.calc_rep_params,
                      self.gen_dma_tile_params):
            stage(dma_params)
        return dma_params

class Gemm_base(params_funcs, BaseTemplate):
    def __init__(self, _data):
        """Initialise the base classes and keep the caller-supplied data object.

        Args:
            _data: stored unchanged on ``self.data``.
        """
        super().__init__()
        self.data = _data

    def gen_instr(self, params):
        """Assemble and return the generated source text for one matmul operation.

        The pieces are concatenated in this order: headers, kernel-parameter
        helper definitions (plus the optional fused-op helpers), helper
        functions, data flow, the DMA pattern code, kernel name,
        connection/run/compile section and the main function.
        """
        logging.info("Genrate code for matmul operation")

        #generate dma data flow (produced first, emitted later)
        dma_data_flow = self.dma_pattern_code(params)

        overlay = params['template_meta_data']['overlay']
        act_bits = params['dims'].act_bits
        actxact_16bit = bool(params['actxact']) and act_bits == 16
        eight_bit = act_bits == 8

        # header variant: 2 for actxact/16-bit, 3 for any 8-bit activation, 1 otherwise
        if actxact_16bit:
            header_variant = 2
        elif eight_bit:
            header_variant = 3
        else:
            header_variant = 1
        pieces = [self.gen_headers(overlay, header_variant)]

        if params.get('CoreQdqAddr') is None:
            pieces.append(self.params_instruction_bfp())
        else:
            if actxact_16bit:
                pieces.append(self.params_instruction_actxact())
            elif eight_bit:
                pieces.append(self.params_instruction_a8w8())
            else:
                pieces.append(self.params_instruction())
            if params.get('is_pwla_fused', False):
                pieces.append(self.params_instruction_LUTOPs())
            if params.get('is_rope_fused', False):
                pieces.append(self.params_instruction_RoPE())
            if params.get('is_elew_fused', False):
                pieces.append(self.params_instruction_elew())

        # helper functions are requested according to the names used by the DMA code
        uses_pack = 'pack_transfers' in dma_data_flow
        uses_packed_shim = 'generate_packed_shim_data_transfer' in dma_data_flow
        pieces.append(self.gen_helper_func_instr(uses_pack, uses_packed_shim))
        pieces.append(self.gen_dataflow())
        pieces.append(dma_data_flow)
        pieces.append(self.gen_kernel_name(params['op_name'], params['op_ver']))
        pieces.append(self.gen_connection_run_compile(
            overlay, params['dims'].aie_cols, params['dims'].aie_rows,
            params['CoreStackAddr'][0], params['param_channel_id'], params['task_queue_optimization']))
        pieces.append(self.gen_main_func())

        return "".join(pieces)

    def params_instruction(self):
        """Return the source text of the generated ``gemm_qdq_params`` helper.

        The text is emitted verbatim into the generated script; gen_instr uses
        it on the QDQ path that is neither actxact/16-bit nor 8-bit. The
        generated helper forwards its arguments, wrapped in a GemmSubvDims
        with fixed granularities, to generate_layer_kernel_params.
        """
        return f"""
def gemm_qdq_params(
    input: Tuple[int, int],
    output: Tuple[int, int],
    zero_acc: bool,
    mode: int,
    qdq_addr: int,
    ifmsum_addr : int,
    tdm_1_addr: int,
    tdm_2_addr: int,
    int4_wgt: int = 0,
    wgt_unpack_addr: int = 0,
    qdq_mode:  OPMode = OPMode.OP_CONV_ASYM,
    sign_act: int = 0,
    sign_wgt: int = 0,
    sign_out: int = 0,
    wgt_bits: int = 8,
) -> bytes:
    Y_gran = 1
    X_gran = 8
    Co_gran = 8
    Ci_gran = 8
    size_bytes = 2
    stride_efficiency = 0.5
    mem_align = 64
    M, K = input
    K, N = output
    Ky, Kx = 1, 1
    Sy, Sx = 1, 1
    op_mode = qdq_mode
    params_blob = generate_layer_kernel_params(
        zero_acc,
        mode,
        qdq_addr,
        ifmsum_addr,
        tdm_1_addr,
        tdm_2_addr,
        GemmSubvDims(
            1,
            1, Y_gran,
            M, X_gran,
            N, Co_gran,
            K, Ci_gran,
            Ky, Kx,
            Sy, Sx,
            op_mode,
            size_bytes,
            stride_efficiency,
            mem_align,
            sign_act,
            sign_wgt,
            sign_out,
        ),
        int4_wgt,
        wgt_unpack_addr,
        wgt_bits
    )
    return params_blob
"""
    def params_instruction_a8w8(self):
        """Return the source text of the generated ``gen_a8w8_gemm_params`` helper.

        gen_instr emits this text on the QDQ path with 8-bit activations. This
        is an f-string template: doubled braces become literal braces and
        ``\\\\x00`` becomes the escape ``\\x00`` in the generated source. The
        generated helper concatenates a layer-parameter header with the kernel
        parameters, which it pads to 166 bytes.
        """
        return f"""
def gen_a8w8_gemm_params(
    input : Tuple,
    output : Tuple,
    zero_init : int,
    final_tdm_iter : int,
    core_tdm1 : int,
    core_tdm2 : int,
    core_ifmsum : int,
    core_scratchbuf : int,
    core_qdq_addr : int,
    core_c0_addr : int,
    op_mode : int,
    mode : int,
    transposeA: int,
    transposeB: int
) -> bytes:
    Msubv, Ksubv = input
    _, Nsubv = output
    dummy = 0
    templates = {{
        "has_dwc": 0,
        "has_conv": 0,
        "has_sum": 0 if op_mode == 9 else 1,
        "has_vector_coeffs": 0
    }}
    parameters = {{
        "subvolume": {{
            "H": 1,
            "W": Msubv,
            "Co": Nsubv,
            "Ci": Ksubv,
            "Kh": 1,
            "Kw": 1,
            "Sh": 1,
            "Sw": 1,
            "Dh": 1,
            "Dw": 1
        }},
        "op_mode": "sym" if op_mode == 9 else "AxA" if op_mode == 27 else "asym",
        "dtype": {{
            "I0": "uint8",
            "I1": "uint8",
            "O0": "uint8"
        }},
        "transpose": {{
            "I0": transposeA,
            "I1": transposeB
        }},
        "quantization_coeffs": {{
            "shift_res": 0,
            "zp_wght" : 0,
            "vector_coeffs": -1 if op_mode == 27 else 0,
            "qdq_c0": 0,
            "qdq_c1": 0,
            "qdq_c2": 0,
            "qdq_c3": 0
        }}
    }}
    reserved = 0
    kernel_params = params(templates, parameters, 1, Msubv)
    kernel_param_padding = b'\\x00' * (166 - len(kernel_params))
    kernel_param_blob = kernel_params + kernel_param_padding
    layer_param_blob = (
        zero_init.to_bytes(length= 1, byteorder='little', signed=False)
        + final_tdm_iter.to_bytes(length=1, byteorder='little', signed=False)
        + (Ksubv * Nsubv).to_bytes(length=2, byteorder='little', signed=False)
        + core_tdm1.to_bytes(length = 2, byteorder='little', signed=False)
        + core_tdm2.to_bytes(length = 2, byteorder='little', signed=False)
        + core_ifmsum.to_bytes(length = 2, byteorder='little', signed=False)
        + core_scratchbuf.to_bytes(length = 2, byteorder='little', signed=False)
        + core_qdq_addr.to_bytes(length = 2, byteorder='little', signed=False)
        + core_c0_addr.to_bytes(length = 2, byteorder='little', signed=False)
        + op_mode.to_bytes(length = 2, byteorder='little', signed=False)
        + mode.to_bytes(length = 2, byteorder='little', signed=False)
        + reserved.to_bytes(length = 2, byteorder='little', signed=False)
    )
    bin_blob = layer_param_blob + kernel_param_blob
    return bin_blob
    
"""
    def params_instruction_RoPE(self):
        """Return the source text of the generated ``rope_layer_params`` helper.

        gen_instr emits this text when params['is_rope_fused'] is set. The
        generated helper packs eight 2-byte little-endian unsigned fields; its
        last field, do_neg, is 1 for col < 4 and 0 otherwise.
        """
        return f"""
def rope_layer_params(n: int, Msubv: int, Nsubv: int, qdq_addr: int, tdm_addr1: int, tdm_addr2: int, fused_op_flag: int, row: int, col: int):
        do_neg = 1 if col < 4 else 0
        return (n.to_bytes(length=2, byteorder='little', signed=False)
                +Msubv.to_bytes(length=2, byteorder='little', signed=False)
                +Nsubv.to_bytes(length=2, byteorder='little', signed=False)
                +qdq_addr.to_bytes(length=2, byteorder='little', signed=False)
                +tdm_addr1.to_bytes(length=2, byteorder='little', signed=False)
                +tdm_addr2.to_bytes(length=2, byteorder='little', signed=False)
                +fused_op_flag.to_bytes(length=2, byteorder='little', signed=False)
                +do_neg.to_bytes(length=2, byteorder='little', signed=False)
                )
"""
    def params_instruction_elew(self):
        """Return the source text of the generated ``matadd_params`` helper.

        gen_instr emits this text when params['is_elew_fused'] is set. The
        generated helper packs eight 2-byte little-endian unsigned fields; its
        last field, do_neg, is always 0.
        """
        return f"""
def matadd_params(n: int, Msubv: int, Nsubv: int, qdq_addr: int, tdm_addr1: int, tdm_addr2: int, fused_op_flag: int, row: int, col: int):
        do_neg = 0
        return (n.to_bytes(length=2, byteorder='little', signed=False)
                +Msubv.to_bytes(length=2, byteorder='little', signed=False)
                +Nsubv.to_bytes(length=2, byteorder='little', signed=False)
                +qdq_addr.to_bytes(length=2, byteorder='little', signed=False)
                +tdm_addr1.to_bytes(length=2, byteorder='little', signed=False)
                +tdm_addr2.to_bytes(length=2, byteorder='little', signed=False)
                +fused_op_flag.to_bytes(length=2, byteorder='little', signed=False)
                +do_neg.to_bytes(length=2, byteorder='little', signed=False)
                )
"""
    def params_instruction_LUTOPs(self):
        """Return the source text of the generated ``LUTOPs_qdq_params`` helper.

        gen_instr emits this text when params['is_pwla_fused'] is set. The
        generated helper packs its eight arguments as 2-byte little-endian
        unsigned fields in declaration order.
        """
        return f"""
def LUTOPs_qdq_params(qdq_prm_addr: int, lutad_addr: int, lutcd_addr: int, num_elements: int, tdm1_addr: int, tdm2_addr: int, fused_op: int, is_int16: int):
          return ( qdq_prm_addr.to_bytes(length=2, byteorder='little', signed=False)
                    + lutad_addr.to_bytes(length=2, byteorder='little', signed=False)
                    + lutcd_addr.to_bytes(length=2, byteorder='little', signed=False)
                    + num_elements.to_bytes(length=2, byteorder='little', signed=False)
                    + tdm1_addr.to_bytes(length=2, byteorder='little', signed=False)
                    + tdm2_addr.to_bytes(length=2, byteorder='little', signed=False)
                    + fused_op.to_bytes(length=2, byteorder='little', signed=False)
                    + is_int16.to_bytes(length=2, byteorder='little', signed=False)
                    )
"""
    def params_instruction_actxact(self):
        """Return the source text of the generated ``smxv_qdq_params`` helper.

        gen_instr emits this text on the QDQ path with actxact and 16-bit
        activations. The generated helper forwards its arguments, wrapped in a
        MhaSubvDims with fixed granularities and op_mode OPMode.OP_SUM, to
        generate_layer_kernel_params.
        """
        return """
def smxv_qdq_params(
    input: Tuple[int, int],
    output: Tuple[int, int],
    mha_mode: int,
    transpose_mode : int,
    col_id : int, 
    row_id : int,
    first_tdm_iter : int, 
    final_tdm_iter : int,
    unpack_matV : int,
    core_tdm1_addr : int,
    core_tdm2_addr : int,
    core_qdq_addr : int,
    core_act1_addr : int,
    core_act2_addr : int,
    core_C0_addr : int,
    core_scratch_addr : int,
    sign_act: int = 0,
    sign_wgt: int = 0,
    sign_out: int = 0,
) -> bytes:
    Y_gran = 1
    X_gran = 8
    Co_gran = 8
    Ci_gran = 8
    size_bytes = 2
    stride_efficiency = 0.5
    mem_align = 64
    M, K = input
    K, N = output
    Ky, Kx = 1, 1
    Sy, Sx = 1, 1
    op_mode = OPMode.OP_SUM
    params_blob = generate_layer_kernel_params(
        mha_mode,
        transpose_mode,
        col_id,
        row_id,
        first_tdm_iter,
        final_tdm_iter,
        unpack_matV,
        core_tdm1_addr,
        core_tdm2_addr,
        core_qdq_addr,
        core_act1_addr,
        core_act2_addr,
        core_C0_addr,
        core_scratch_addr,
        MhaSubvDims(
            1,
            1, Y_gran,
            M, X_gran,
            K, Ci_gran,
            N, Co_gran,
            Ky, Kx,
            Sy, Sx,
            op_mode,
            size_bytes,
            stride_efficiency,
            mem_align,
            sign_act,
            sign_wgt,
            sign_out,
        )
    )
    return params_blob
"""
    def params_instruction_bfp(self):
        return """
def gemm_params_bfp16(zero_acc: int, final_iter: int, sub_vol: int, CoreTdm1Addr:int, CoreTdm2Addr: int, transpose: int,
            act_in_bfp16: int, act_out_bfp: int, gelu_fuse: int, init_gemm_param:int):
        assert zero_acc in (0, 1)
        assert final_iter in (0, 1)
        assert transpose in (0, 1)
        assert act_in_bfp16 in (0, 1)
        assert act_out_bfp in (0, 1)
        assert gelu_fuse == 0
        #global init_gemm_param
        param_list = (
                zero_acc.to_bytes(length=2, byteorder='little', signed=False)
                + final_iter.to_bytes(length=2, byteorder='little', signed=False)
                + sub_vol.to_bytes(length=2, byteorder='little', signed=False)
                + CoreTdm1Addr.to_bytes(length=2, byteorder='little', signed=False)
                + CoreTdm2Addr.to_bytes(length=2, byteorder='little', signed=False)
                + transpose.to_bytes(length=2, byteorder='little', signed=False)
                + act_in_bfp16.to_bytes(length=2, byteorder='little', signed=False)
                + act_out_bfp.to_bytes(length=2, byteorder='little', signed=False)
                + init_gemm_param.to_bytes(length=2, byteorder='little', signed=False)
                + gelu_fuse.to_bytes(length=2, byteorder='little', signed=False)
            )
        #init_gemm_param = 0 
        return param_list 
"""

    def gen_helper_func_instr(self, packtransfer, packshimtransfer):
        """Return the helper-function source text selected by the two flags.

        Args:
            packtransfer: When truthy, include ``gen_pack_TransferParams()``.
            packshimtransfer: When truthy, include
                ``generate_packed_shim_data_transfer()``.

        Returns:
            The selected snippets concatenated in the order above, or an
            empty string when neither flag is set.
        """
        snippets = []
        if packtransfer:
            snippets.append(self.gen_pack_TransferParams())
        if packshimtransfer:
            snippets.append(self.generate_packed_shim_data_transfer())
        return "".join(snippets)

    def split_share_ch_tiling(self, mm2s_list, ch):
        """Return ``mm2s_list`` with each entry's ``Cols`` range adjusted for one shared channel.

        Every entry is a space-separated tiling string that must contain a
        three-field ``Cols:<a>:<b>`` token. The first such token is rewritten:

        * ``ch == 0``: the third field is decreased by 2.
        * ``ch == 1``: the second field is increased by 2.
        * any other ``ch``: the token is left unchanged.

        Args:
            mm2s_list: A (possibly nested, rectangular) list of tiling strings.
            ch: Shared-channel index selecting which bound is adjusted.

        Returns:
            A new list with the same nesting as ``mm2s_list``.

        Raises:
            ValueError: If an entry contains no ``Cols:<a>:<b>`` token.
        """
        tiling = np.array(mm2s_list)
        rewritten = []
        for fmt in tiling.flatten().tolist():
            tokens = fmt.split(' ')
            cols_token = next(
                (tok for tok in tokens
                 if tok.count(':') == 2 and tok.split(':')[0] == 'Cols'),
                None,
            )
            if cols_token is None:
                raise ValueError(
                    f"no 'Cols:<start>:<stop>' token found in tiling format {fmt!r}"
                )
            fields = cols_token.split(':')
            if ch == 0:
                fields[2] = str(int(fields[2]) - 2)
            elif ch == 1:
                fields[1] = str(int(fields[1]) + 2)
            new_token = ':'.join(fields)
            # Rewrite by token equality rather than substring replacement so
            # that a Cols token at the very end of the string (no trailing
            # space) is updated too, and longer tokens that merely end in the
            # same text are left alone.
            rewritten.append(
                ' '.join(new_token if tok == cols_token else tok for tok in tokens)
            )
        return np.array(rewritten).reshape(tiling.shape).tolist()

    def gen_memtile_instr(self, params):
        """Build the memtile transfer section of the generated script.

        The returned text is assembled from commented sections, in order:
        flow statistics, parameter transfers, QDQ parameter transfers (only
        when ``params['CoreQdqSize']`` is present and not ``None``),
        activation transfers, weight transfers, a combined activation+weight
        transfer when both channels are ``'unicast_headpercore'``, OFM
        transfers, and finally SIN/COS transfers (``is_rope_fused``) or IFMB
        transfers (``is_elew_fused``).

        Args:
            params: Mapping of scheduling parameters. ``params['dims']`` must
                expose the padding and bit-width attributes read below.

        Returns:
            The concatenated source text for the memtile section.

        Side effects:
            In ``unicast_headpercore`` mode with ``ActPingPong`` enabled,
            ``params['MemtileIfmPongAddr']`` is overwritten in place.
        """
        memtile_transfers  = "    #MEMTILE FLOW DEFINITION\n"
        memtile_transfers += self.memtile_stats(params)
        memtile_transfers += "    memtile_transfers = []\n"
        memtile_transfers += "    #MEMTILE PARAM TRANSFERS\n"
        memtile_transfers += self.memtile_prm_pattern(params)
        if params.get('CoreQdqSize') is not None:
            memtile_transfers += "    #MEMTILE QDQ PARAM TRANSFERS\n"
            memtile_transfers += self.memtile_qdq_pattern(params)

        memtile_transfers += f"    #MEMTILE ACTIVATION TRANSFERS, {params['ActDataFlow'].upper()}\n"
        # Activation repeat counts: ping only when there is no pong repeat
        # (None or empty), otherwise the element-wise sum of ping and pong.
        if not params['MemtileActPongRepeat']:
            memtile_act_repeat = params['MemtileActPingRepeat']
        else:
            memtile_act_repeat = [sum(x) for x in zip(params['MemtileActPingRepeat'], params['MemtileActPongRepeat'])]

        enable_memtile_padding = params['dims'].Mpad_ifm > 0 or params['dims'].Kpad_ifm > 0
        # NOTE(review): the activation transfer takes its channel id from the
        # 'MemtileWgtchid' key and the weight transfer from 'MemtileActchid';
        # confirm the key names are not swapped.
        mem_act = self.MemTransfers(params.get('MemtileWgtchid', 1) , memtile_act_repeat               , params['MemtileActSize']    ,
                                    params['act_memtile_memory']    , params['act_memtile_mm2s']       , params['act_memtile_s2mm']  ,
                                    params['act_reuse_chain_length'], params['MemtileActReuseRatio']   , params['MemtileIfmPingAddr'],
                                    params['MemtileIfmPongAddr']    , params['ActPingPong']            , params['dims'].act_bits     ,
                                    'Default'                       , enable_memtile_padding           , params['ifm_tiling_iters'])
        if params['ifm_channel'] == 'broadcast':
            mem_act.ChannelId = params.get('MemtileWgtchid', 0)
            if not params['ActPingPong'] or params['ActDataFlow'] in ['stream']: #non ifm pingpong or streaming mode
                memtile_transfers += mem_act.dma_broadcast_pattern(params['ifm_idx'], params['row_list'])
            else: #ifm pin pingpong
                # Ping and pong are emitted as two separate patterns that
                # reuse the same transfer object.
                mem_act.PingPongEnable = False
                mem_act.Repeat = params['MemtileActPingRepeat']
                memtile_transfers += mem_act.dma_broadcast_pattern(params['ifm_idx'], params['row_list'])
                # Guard against a missing pong repeat, as the unicast path
                # does; the repeat computation above already allows None.
                if params['MemtileActPongRepeat'] is not None and sum(params['MemtileActPongRepeat'])>0:
                    mem_act.Repeat = params['MemtileActPongRepeat']
                    mem_act.PingAddr = params['MemtileIfmPongAddr']
                    memtile_transfers += mem_act.dma_broadcast_pattern(params['ifm_idx'], params['row_list'])
        elif params['ifm_channel'] == 'unicast':
            if not params['ActPingPong'] or params['ActDataFlow'] in ['stream']: #non ifm pingpong or streaming mode
                mem_act.sync_strategy = params.get('sync_strategy', 'SyncStrategy.Parallel_1_to_N')
                memtile_transfers += mem_act.dma_unicast_pattern(params['ifm_idx'], params['row_list'])
            else: #ifm pin pingpong
                mem_act.PingPongEnable = False
                mem_act.sync_strategy = params.get('sync_strategy', 'SyncStrategy.Parallel_1_to_N')
                mem_act.Repeat = params['MemtileActPingRepeat']
                memtile_transfers += mem_act.dma_unicast_pattern(params['ifm_idx'], params['row_list'])
                if params['MemtileActPongRepeat'] is not None and sum(params['MemtileActPongRepeat'])>0:
                    mem_act.Repeat = params['MemtileActPongRepeat']
                    mem_act.PingAddr = params['MemtileIfmPongAddr']
                    memtile_transfers += mem_act.dma_unicast_pattern(params['ifm_idx'], params['row_list'])

        enable_memtile_padding = params['dims'].Kpad_wgt > 0 or params['dims'].Npad_wgt > 0
        memtile_transfers += f"    #MEMTILE WEIGHT TRANSFERS, {params['WgtDataFlow'].upper()}\n"
        mem_wgt = self.MemTransfers(params.get('MemtileActchid', 0) , params['MemtileWgtRepeat']       , params['MemtileWgtSize']    ,
                                    params['wgt_memtile_memory']    , params['wgt_memtile_mm2s']       , params['wgt_memtile_s2mm']  ,
                                    params['wgt_reuse_chain_length'], params['MemtileWgtReuseRatio']   , params['MemtileWgtPingAddr'],
                                    params['MemtileWgtPongAddr']    , params['WgtPingPong']            , params['dims'].wgt_bits     ,
                                    'Default'                       , enable_memtile_padding           , params['wgt_tiling_iters'])
        if params['wgt_channel'] == 'broadcast':
            memtile_transfers += mem_wgt.dma_broadcast_pattern(params['wgt_idx'], params['row_list'])
        elif  params['wgt_channel'] == 'unicast':
            mem_wgt.sync_strategy = 'SyncStrategy.Parallel_1_to_N'
            if params['SHARE_CH_MODE']:
                assert params['actxact'] == False, "Shared channel mode not supported for actxact"
                # Two patterns are emitted: the first two rows with the ch-0
                # Cols split, the remaining rows with the ch-1 split.
                # NOTE(review): if 'MemtileWgtchid' is set, both halves get
                # the same channel id; only the defaults (0 and 1) differ.
                mem_wgt.ChannelId = params.get('MemtileWgtchid', 0)
                mem_wgt.S2MM_Format = self.split_share_ch_tiling(params['wgt_memtile_s2mm'], 0)
                memtile_transfers += mem_wgt.dma_unicast_pattern(params['wgt_idx'], params['row_list'][:2])
                mem_wgt.ChannelId = params.get('MemtileWgtchid', 1)
                mem_wgt.S2MM_Format = self.split_share_ch_tiling(params['wgt_memtile_s2mm'], 1)
                memtile_transfers += mem_wgt.dma_unicast_pattern(params['wgt_idx'], params['row_list'][2:])
            else:
                mem_wgt.ChannelId = params.get('MemtileWgtchid', 1)
                memtile_transfers += mem_wgt.dma_unicast_pattern(params['wgt_idx'], params['row_list'])

        if params['ifm_channel'] == params['wgt_channel'] == 'unicast_headpercore':
            # NOTE(review): the original condition tested 'ActPingPong' twice;
            # the second operand may have been meant to be 'WgtPingPong'.
            # Behaviour is kept as it was (depends on 'ActPingPong' only).
            if params['ActPingPong']:
                params['MemtileIfmPongAddr'] = params['MemtileIfmPingAddr'] + params['MemtileActSize'] + params['MemtileWgtSize']

            # Activation and weight formats are concatenated into a single
            # two-BD chain; the second buffer starts at MemtileActSize.
            mem_act = self.MemTransfers(params.get('MemtileWgtchid', 1) , memtile_act_repeat               , params['MemtileActSize']    ,
                                        params['act_memtile_memory'] + params['wgt_memtile_memory'],
                                        params['act_memtile_mm2s'] + params['wgt_memtile_mm2s'],
                                        params['act_memtile_s2mm'] + params['wgt_memtile_s2mm'],
                                        params['act_reuse_chain_length'], params['MemtileActReuseRatio']   , params['MemtileIfmPingAddr'],
                                        params['MemtileIfmPongAddr']    , params['ActPingPong']      , params['dims'].act_bits     ,
                                        'Default'                       , True                             , params['ifm_tiling_iters']  ,
                                        buffer_offset = [0, params['MemtileActSize']],
                                        bd_chain_length = 2)
            memtile_transfers += mem_act.dma_unicast_pattern(params['ifm_idx'], params['row_list'])

        memtile_transfers += "    #MEMTILE OFM TRANSFERS\n"
        mem_ofm = self.MemTransfers(5                              , params['MemtileOutRepeat']       , params['MemtileOutSize']    ,
                                    params['out_memtile_memory']   , params['out_memtile_mm2s']       , params['out_memtile_s2mm']  ,
                                    1                              , 1                                , params['MemtileOfmPingAddr'],
                                    params['MemtileOfmPongAddr']   , False                            , params['dims'].out_bits     ,
                                    'Default' if params.get('is_rope_fused', False) else 'SyncStrategy.Parallel_N_to_1' , False                            , params['out_tiling_iters'])

        memtile_transfers += mem_ofm.dma_unicast_out_pattern(params['ofm_idx'], params['row_list'])
        memtile_transfers += "\n"
        if  params.get('is_rope_fused', False):
            memtile_transfers += "    #MEMTILE SIN TRANSFERS\n"
            # The SIN/COS memory format is the first OFM memory format with
            # its tokens reversed and a trailing 'W:8' token appended.
            ofm_memory_fmt   = params['out_memtile_memory'][0].split()
            reversed_ofm_memory_fmt = ofm_memory_fmt[::-1]+['W:8']
            sin_cos_memory_fmt   = [' '.join(reversed_ofm_memory_fmt)]#params['out_memtile_memory']
            # The OFM stream formats are reused with the directions swapped.
            sin_cos_memtile_mm2s = [params['out_memtile_s2mm'][0]]*8
            sin_cos_memtile_s2mm = params['out_memtile_mm2s']
            # SIN repeats once in the second-to-last phase only.
            mem_sin = self.MemTransfers(params.get('MemtileActchid', 1), (len(params['MemtileOutRepeat'])-2)*[0]+[1,0], params['MemtileRoPESize'] // 2    ,
                                        sin_cos_memory_fmt, sin_cos_memtile_mm2s , sin_cos_memtile_s2mm  ,
                                        1                 , 1                    , params['MemtileRoPEAddr'],
                                        None, False, params['dims'].out_bits     ,
                                       'Default' , False    , params['out_tiling_iters'])
            memtile_transfers += mem_sin.dma_unicast_pattern(params['rope_idx'], params['row_list'])
            memtile_transfers += "\n"
            memtile_transfers += "    #MEMTILE COS TRANSFERS\n"
            # COS repeats once in the last phase and lives in the second half
            # of the RoPE buffer.
            # NOTE(review): the argument that is the padding flag for the
            # other transfers is params['ActPingPong'] here but False for
            # SIN; confirm this is intentional.
            mem_cos = self.MemTransfers(params.get('MemtileActchid', 1)       , (len(params['MemtileOutRepeat'])-1)*[0]+[1]       , params['MemtileRoPESize'] // 2   ,
                                        sin_cos_memory_fmt, sin_cos_memtile_mm2s             , sin_cos_memtile_s2mm  ,
                                        1                 , 1                                , params['MemtileRoPEAddr'] + params['MemtileRoPESize'] //2,
                                        None, False,  params['dims'].out_bits,
                                        'Default' , params['ActPingPong']       , params['out_tiling_iters'])
            memtile_transfers += mem_cos.dma_unicast_pattern(params['rope_idx'], params['row_list'])
            memtile_transfers += "\n"
        elif  params.get('is_elew_fused', False):
            memtile_transfers += "    #MEMTILE IFMB TRANSFERS\n"
            # Same format derivation as the RoPE branch, applied to IFMB.
            ofm_memory_fmt   = params['out_memtile_memory'][0].split()
            reversed_ofm_memory_fmt = ofm_memory_fmt[::-1]+['W:8']
            ifmB_memory_fmt   = [' '.join(reversed_ofm_memory_fmt)]#params['out_memtile_memory']
            ifmB_memtile_mm2s = [params['out_memtile_s2mm'][0]]*8
            ifmB_memtile_s2mm = params['out_memtile_mm2s']
            mem_ifmB = self.MemTransfers(params.get('MemtileActchid', 1), (len(params['MemtileOutRepeat'])-1)*[0]+[1], params['MemtileifmBSize'] // 2    ,
                                        ifmB_memory_fmt, ifmB_memtile_mm2s , ifmB_memtile_s2mm  ,
                                        1                 , 1                    , params['MemtileifmBAddr'],
                                        None, False, params['dims'].out_bits     ,
                                       'Default' , False    , params['out_tiling_iters'])
            memtile_transfers += mem_ifmB.dma_unicast_pattern(params['elew_idx'], params['row_list'])
            memtile_transfers += "\n"

        return memtile_transfers

    def gen_shim_instr(self, params):
        """Build the shim transfer section of the generated script.

        The returned text is assembled from commented sections, in order:
        flow statistics, parameter transfers, QDQ parameter transfers (only
        when ``params['CoreQdqSize']`` is present and not ``None``),
        activation transfers, weight transfers (one pattern, or two when
        ``SHARE_CH_MODE`` is set), OFM transfers, and finally SIN/COS
        transfers (``is_rope_fused``) or IFMB transfers (``is_elew_fused``).

        Args:
            params: Mapping of scheduling parameters. ``params['dims']`` must
                expose ``act_bits``, ``wgt_bits`` and ``out_bits``.

        Returns:
            The concatenated source text for the shim section.
        """
        shim_transfers  = "    #SHIM FLOW DEFINITION\n"
        shim_transfers += self.shim_stats(params)
        shim_transfers += "    shim_transfers = []\n"
        shim_transfers += "    #SHIM PARAM TRANSFERS\n"
        shim_param_ch   = params['unicast_idx'][0] if params.get('ShimParamMode', 'unicast') == 'unicast' else params['broadcast_idx'][0]
        shim_transfers += "".join([self.shim_prm_pattern(params, col) for col in shim_param_ch])
        # Treat a missing 'CoreQdqSize' like None, matching the memtile
        # generator, which is driven by the same params mapping.
        if params.get('CoreQdqSize') is not None:
            shim_transfers += "    #SHIM QDQ PARAM TRANSFERS\n"
            shim_qdq_ch     = params['unicast_idx'][0] if params.get('ShimQdqMode', 'broadcast') == 'unicast' else params['broadcast_idx'][0]
            shim_transfers += "".join([self.shim_qdq_pattern(params, col) for col in shim_qdq_ch])

        shim_transfers += f"    #SHIM ACTIVATION TRANSFERS, {params['ActDataFlow'].upper()}\n"
        shim_act = self.ShimTransfers(0, 'MM2S', params['ShimActRepeat'], params['ShimActBufferIdx'], params['ShimIfmSize'], params['act_shim_memory'],  params['act_shim_mm2s'],
                                      0, params['dims'].act_bits, params['ShimActOuterLoop'], params['ifm_tiling_iters'])
        shim_transfers += shim_act.dma_pattern(params['ifm_idx'])

        shim_transfers += f"    #SHIM WEIGHT TRANSFERS, {params['WgtDataFlow'].upper()}\n"
        # In actxact mode the "weights" are read from the activation buffer.
        wgt_buffer_idx = params['ShimActBufferIdx'] if params['actxact'] else params['ShimWgtBufferIdx']
        if params['SHARE_CH_MODE']:
            assert params['actxact'] == False, "Shared channel mode not supported for actxact"
            # The weight stream is split across channels 0 and 1, each with
            # its own adjusted Cols range.
            wgt_streams = [
                (0, self.split_share_ch_tiling(params['wgt_shim_mm2s'], 0)),
                (1, self.split_share_ch_tiling(params['wgt_shim_mm2s'], 1)),
            ]
        else:
            wgt_streams = [(1, params['wgt_shim_mm2s'])]
        for wgt_channel, wgt_mm2s in wgt_streams:
            shim_wgt = self.ShimTransfers(wgt_channel, 'MM2S', params['ShimWgtRepeat'], wgt_buffer_idx, params['ShimWgtSize'], params['wgt_shim_memory'],  wgt_mm2s,
                                          params.get('wgt_shim_offset', 0), params['dims'].wgt_bits, len(params['ShimWgtRepeat']), params['wgt_tiling_iters'])
            shim_transfers += shim_wgt.dma_pattern(params['wgt_idx'])

        shim_transfers += "    #SHIM OFM TRANSFERS\n"
        shim_ofm = self.ShimTransfers(0, 'S2MM', params['ShimOutRepeat'], params['ShimOutBufferIdx'], params['ShimOfmSize'], params['out_shim_memory'],  params['out_shim_s2mm'],
                                      0, params['dims'].out_bits, 1, params['out_tiling_iters'])
        shim_transfers += shim_ofm.dma_pattern(params['ofm_idx'])

        if  params.get('is_rope_fused', False):
            qdq_size_multiple = 0 if params['is_fused_rope_actxact'] else 2 #for rope and matmul's qdq
            sin_cos_offset = params['rope_shim_offset']+params['rope_qdq_size']*qdq_size_multiple
            shim_transfers += "    #SHIM SIN TRANSFERS IFM, STREAM\n"
            RoPE_Buffer_Idx =  params['ShimActBufferIdx'] if params['is_fused_rope_actxact'] else params['ShimWgtBufferIdx']
            # SIN repeats once in the second-to-last phase only and reuses
            # the OFM memory and stream formats.
            shim_sin = self.ShimTransfers(1, 'MM2S', (len(params['ShimOutRepeat'])-2)*[0]+[1,0], RoPE_Buffer_Idx, params['ShimRoPESize'], params['out_shim_memory'],  params['out_shim_s2mm'],
                                          sin_cos_offset, params['dims'].out_bits, 1, params['out_tiling_iters'])

            def shuffle_entries(entries):
                """Swap the first and second halves of ``entries``."""
                half = len(entries) // 2
                return entries[half:] + entries[:half]

            # SIN uses the rope indices with the halves of each entry swapped.
            sin_rope_idx = [shuffle_entries(entry) for entry in params['rope_idx']]

            shim_transfers += shim_sin.dma_pattern(sin_rope_idx)

            shim_transfers += "    #SHIM COS TRANSFERS IFM, STREAM\n"
            # COS repeats once in the last phase and starts rope_total_size
            # after the SIN offset.
            shim_cos = self.ShimTransfers(1, 'MM2S', (len(params['ShimOutRepeat'])-1)*[0]+[1], RoPE_Buffer_Idx, params['ShimRoPESize'], params['out_shim_memory'],  params['out_shim_s2mm'],
                                          sin_cos_offset+params['rope_total_size'], params['dims'].out_bits, 1, params['out_tiling_iters'])
            shim_transfers += shim_cos.dma_pattern(params['rope_idx'])

        elif  params.get('is_elew_fused', False):
            ifmB_offset = params['elew_ifmB_shim_offset']
            shim_transfers += "    #SHIM IFMB TRANSFERS IFM, STREAM\n"
            Elew_Buffer_Idx =  params['ShimActBufferIdx']
            shim_ifmB = self.ShimTransfers(1, 'MM2S', (len(params['ShimOutRepeat'])-1)*[0]+[1], Elew_Buffer_Idx, params['ShimElewSize'], params['out_shim_memory'],  params['out_shim_s2mm'],
                                          ifmB_offset, params['dims'].out_bits, 1, params['out_tiling_iters'])
            shim_transfers += shim_ifmB.dma_pattern(params['elew_idx'])

        return shim_transfers
        
    def dma_pattern_code(self, params):
        """Return the core, memtile and shim transfer sections, concatenated in that order."""
        # Generators run in the same order as their output appears.
        core_section, memtile_section, shim_section = (
            self.gen_core_instr(params),
            self.gen_memtile_instr(params),
            self.gen_shim_instr(params),
        )
        return core_section + memtile_section + shim_section
