import sys
import os
infra_path = os.path.dirname(os.path.abspath(__file__)) + "/infra/"
sys.path.append(infra_path)
import logging
import math
import numpy as np
import copy
from enum import Enum
import re

import const
import dataclasses
from template_base import BaseTemplate
import gen_kernel_param

import scheduler_utils as utils


@dataclasses.dataclass(slots=True)
class MataddDims():
    M_subv : int
    aie_rows : int
    aie_cols : int
    aie_arrays : int
    outer_loop : int
    inner_loop : int
    ifmA_bits  : int
    ifmB_bits  : int
    ofm_bits   : int
    qdq_bytes  : int
    
def getByteCount(dtype):
    """Return the number of bytes used to store one element of ``dtype``.

    Args:
        dtype: dtype name such as ``"int8"`` or ``"uint16"``. Anything that
            compares equal to one of the supported names is accepted.

    Returns:
        1, 2 or 4.

    Raises:
        ValueError: if ``dtype`` is not one of the supported names. The result
            is used as a divisor/multiplier by callers, so an explicit error
            here replaces the unrelated ``TypeError`` that an implicit ``None``
            would cause later.
    """
    if dtype in ("uint8", "int8"):
        return 1
    # "float32" is deliberately sized as 2 bytes: f32 should be considered as bf16.
    if dtype in ("uint16", "int16", "float32"):
        return 2
    if dtype in ("uint32", "int32"):
        return 4
    raise ValueError(f"Unsupported dtype for byte count: {dtype!r}")

def decompose(X):
    """Factor ``X`` into ``(M, N)`` such that ``M * N == X``.

    ``N`` is searched over the powers of two ``2**p`` for
    ``p in range(ceil(log2(X)))`` in increasing order; the first candidate
    that divides ``X`` exactly, is a multiple of 32, and leaves both ``M`` and
    ``N`` below 128 is returned.

    Raises:
        RuntimeError: if no candidate satisfies all constraints.
    """
    exponent_count = math.ceil(math.log(X, 2))
    for exponent in range(exponent_count):
        width = 2 ** exponent
        height = X // width
        # Skip candidates that do not reproduce X exactly.
        if height * width != X:
            continue
        # Both factors must stay below 128.
        if height >= 128 or width >= 128:
            continue
        # The power-of-two factor must be a multiple of 32.
        if width % 32 != 0:
            continue
        return height, width
    raise RuntimeError(
        "Unable to find M and N value to recreate needed subvolume size"
    )

def equalize_loop_length(X, fix_len = 0):
    """Right-pad every list in ``X`` with zeros up to a common length.

    Args:
        X: sequence of lists.
        fix_len: target length; ``0`` (default) means "use the longest list
            in ``X``".

    Returns:
        ``(target_length, padded_lists)``. Lists already longer than the
        target are copied unchanged (never truncated). Inputs are not mutated.
    """
    target_len = fix_len if fix_len != 0 else max(map(len, X))
    padded = []
    for seq in X:
        missing = target_len - len(seq)
        # A negative count yields an empty pad, so long lists pass through.
        padded.append(seq + [0] * missing)
    return target_len, padded

def check_limit_loop_length(X, divisor=0):
    """Split the leading count of each entry into chunks of at most 1024.

    Args:
        X: sequence of non-empty sequences. Only element 0 (the count to be
            split) and element 1 (if present) of each entry are used.
        divisor: when non-zero, the chunk limit is rounded down from 1024 to
            the nearest multiple of ``divisor``.

    Returns:
        A list with one list per entry: as many full-limit chunks as fit in
        the leading count, then the non-zero remainder (if any), then the
        entry's second element (if any).
    """
    limit = 1024
    if divisor != 0:
        limit = int((np.floor(limit / divisor)) * divisor)

    result = []
    for entry in X:
        count = entry[0]
        tail = entry[1:]

        pieces = [limit] * (count // limit)
        leftover = count % limit

        if leftover:
            pieces.append(leftover)
        # Only the first trailing element is carried over.
        if tail:
            pieces.append(tail[0])

        result.append(pieces)
    return result


class Add_SupportedOps(Enum):
    """Op-select codes passed to the kernel as a layer parameter (uint16 data).

    Encoding, as reflected by the member values below:
      * bits 0-3 (first nibble): op type - 0 Add, 1 Mul, 2 Sub, 3 Div
      * bits 4-7 (second nibble): 0x1 = broadcast, 0x3 = innermost broadcast,
        0x0 = plain elementwise
      * bit 8 and above (third nibble onwards): fusion (cascade)
    """

    # Plain elementwise.
    ELW_ADD = 0x0000
    ELW_MUL = 0x0001
    ELW_SUB = 0x0002  # Not supported yet
    ELW_DIV = 0x0003  # Not supported yet

    # Broadcast: 0b0001 00xx.
    BCAST_ADD = 0x0010
    BCAST_MUL = 0x0011
    BCAST_SUB = 0x0012  # Not supported yet
    BCAST_DIV = 0x0013  # Not supported yet

    # Innermost broadcast: 0b0011 00xx.
    INNERMOST_BCAST_ADD = 0x0030
    INNERMOST_BCAST_MUL = 0x0031
    INNERMOST_BCAST_SUB = 0x0032  # Not supported yet
    INNERMOST_BCAST_DIV = 0x0033  # Not supported yet

    # Cascading: 0b0001 xxxx 00xx. Only Add is defined; the Mul/Sub/Div
    # cascade codes (0x0101, 0x0102, 0x0103) are not supported yet.
    CASCADE_ADD = 0x0100

class BroadCast():
    def __init__(self):
        """Create an instance without setting any attributes.

        The attributes read by the generator methods (kernel name, access
        patterns, loop counts, ...) are assigned by ``gen_dma_transfers``.
        """
        return None
    
    def gen_core_instrs_func(self):
        """Return Python source text defining the per-core instruction builders.

        The returned string defines two functions for the generated script:
        ``gen_core_instrs(...)`` (full instruction list with an outer loop and
        a nested inner loop around ``CallKernel``) and
        ``gen_empty_core_instrs(qdq_addr, qdq_bytes)`` (QDQ buffer handling
        only).

        ``self.kernel_name`` is interpolated into both ``CallKernel`` calls, so
        it must already be set (``gen_dma_transfers`` assigns it) before this
        method is called.

        Returns:
            str: the source text; nothing is executed here.
        """
        # First template: the only interpolated value is self.kernel_name.
        code = f"""
def gen_core_instrs(op, Msubv, Nsubv, qdq_addr, qdq_bytes, ifma_addr, ifma_size, ifmb_addr, ifmb_size, ofm_addr, ofm_size, outer_loop, inner_loop):
    return [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), qdq_addr, None, qdq_bytes), #CoreQdqAddr, QdqSize
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
        ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), ofm_addr[0], ofm_addr[1], ofm_size), #CoreOutPingAddr, CoreOutPongAddr, CoreOutSize
        Loop(outer_loop, [
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), ifma_addr[0], ifma_addr[1], ifma_size),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), ifmb_addr[0], ifmb_addr[1], ifmb_size),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
            CallKernel('{self.kernel_name}', kernel_params=matadd_params(op, Msubv, Nsubv, qdq_addr, ifma_addr[0], 0, 0, 0, 0)),
            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            Loop(inner_loop - 1, [
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                CallKernel('{self.kernel_name}', kernel_params=matadd_params(op, Msubv, Nsubv, qdq_addr, ifma_addr[0], 0, 0, 0, 1)),
                RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            ]) if inner_loop - 1 > 0 else Loop(1, [])
        ])
    ] 
"""
        # Second template: contains no placeholders; appended verbatim.
        code += f"""
def gen_empty_core_instrs(qdq_addr, qdq_bytes):
    return [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), qdq_addr, None, qdq_bytes), #CoreQdqAddr, QdqSize
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1))
    ]
    
"""
        return code

    def gen_dma_transfers(self, params):
        """Derive access patterns and loop counts for a broadcast elementwise op.

        Reads sizes, shapes and mode flags from ``params`` and stores the
        resulting shim / memtile access-pattern strings, per-column loop lists,
        the kernel name and the op-select code on ``self`` for the other
        ``gen_*`` methods. The pattern strings have the form
        ``"N:a:b H:c:d W:e:f"``; how they are interpreted is decided by the
        consumer of the generated code (not visible in this method).

        Two top-level cases are handled, each with a ``'pin'`` and a non-pin
        variant selected by ``params['ifmA_mode']``:
          * ``params['inner_dim_is_1']`` truthy: ifmA is laid out as H x 8.
          * otherwise: ifmA is a row of width ``ifmA_orig // outer_loop`` and
            ifmB is an H x W block over it.

        Args:
            params: dict. Keys read include ``op_ver``, ``op_name``, ``dims``,
                the ``Shim*``/``Memtile*``/``Core*`` size entries, the three
                dtype entries, the three scale factors, ``in_ifmA_shape``,
                ``in_ifmB_shape``, ``ifmA_mode``, ``ifmA_param_type``,
                ``ifmB_param_type``, ``inner_dim_is_1``, ``multi_ch_batch_bcast``,
                ``core_iters`` and ``ping_pong_enable``.

        Side effects on ``params``:
            ``reenqueues`` and ``bd_chain`` are (re)set, and
            ``ping_pong_enable['ifmA']`` is forced to ``False``. In the
            ``inner_dim_is_1`` case ``CoreIfmaSize``, ``CoreIfmbSize`` and
            ``CoreOfmSize`` are overwritten.

        Raises:
            AssertionError: if a core-tile element count is not a multiple of
                its scale factor.
            RuntimeError: if ``params['op_name']`` is not Add, Mul, Sub or Div.
        """
        # Any op_ver containing "uint8" selects the 8-bit kernel; everything
        # else falls back to the 16-bit kernel.
        if "uint8" in params["op_ver"]:
            self.kernel_name = 'run_a8a8_matadd_qdq'
            self.bcast_dtype = 'a8a8'
        else:
            self.kernel_name = 'run_a16a16_matadd_qdq'
            self.bcast_dtype = 'a16a16'
        dims                 = params['dims']
        # Byte sizes are converted to element counts. ifmA sizes are divided by
        # the WgtDtype width and ifmB sizes by the ActDtype width.
        # NOTE(review): shimOfmsize, memtileOfm and coretileOfm are not read
        # again after being computed/normalised in this method.
        shimActAsize         = params['ShimActSizeIfmA'] // getByteCount(params['WgtDtype'])
        shimActBsize         = params['ShimActSizeIfmB'] // getByteCount(params['ActDtype'])
        shimOfmsize          = params['ShimOutSize'] // getByteCount(params['OfmDtype'])
        
        memtileActA          = params['MemtileActSizeIfmA'] // getByteCount(params['WgtDtype'])
        memtileActB          = params['MemtileActSizeIfmB'] // getByteCount(params['ActDtype'])
        memtileOfm           = params['MemtileOutSize'] // getByteCount(params['OfmDtype'])
        
        coretileActA         = (params['CoreIfmaSize'] // getByteCount(params['WgtDtype']))
        coretileActB         = (params['CoreIfmbSize'] // getByteCount(params['ActDtype']))
        coretileOfm          = (params['CoreOfmSize'] // getByteCount(params['OfmDtype']))
        
        # The core-tile counts must divide evenly by their scale factors before
        # being scaled down below.
        assert(coretileActA % params['ifmA_scale_factor'] == 0)
        assert(coretileActB % params['ifmB_scale_factor'] == 0)
        assert(coretileOfm  % params['ofm_scale_factor'] == 0)

        coretileActA         = coretileActA // params['ifmA_scale_factor']
        coretileActB         = coretileActB // params['ifmB_scale_factor']
        coretileOfm          = coretileOfm  // params['ofm_scale_factor']

        # Defaults; the branches below switch bd_chain / reenqueues on when needed.
        params['reenqueues'] = False
        params['bd_chain']   = False
        params["ping_pong_enable"]["ifmA"] = False
        self.multi_ch_batch_bcast    = params['multi_ch_batch_bcast']

        # NOTE(review): neither chain has an else branch, so a param type other
        # than 'act'/'const' leaves self.ifmA_const / self.ifmB_const unset.
        if params['ifmA_param_type'] == 'act':
            self.ifmA_const = 0
        elif params['ifmA_param_type'] == 'const':
            self.ifmA_const = 1

        if params['ifmB_param_type'] == 'act':
            self.ifmB_const = 0
        elif params['ifmB_param_type'] == 'const':
            self.ifmB_const = 1

        # Op-select code. Only Add and Mul are upgraded to the INNERMOST_BCAST
        # variant when inner_dim_is_1 is set; Sub and Div always use BCAST_*.
        if params["op_name"] == "Add":
            self.op_select = Add_SupportedOps.BCAST_ADD
            if params["inner_dim_is_1"]:
                self.op_select = Add_SupportedOps.INNERMOST_BCAST_ADD
        elif params["op_name"] == "Mul":
            self.op_select = Add_SupportedOps.BCAST_MUL
            if params["inner_dim_is_1"]:
                self.op_select = Add_SupportedOps.INNERMOST_BCAST_MUL
        elif params["op_name"] == "Sub":
            self.op_select = Add_SupportedOps.BCAST_SUB
        elif params["op_name"] == "Div":
            self.op_select = Add_SupportedOps.BCAST_DIV
        else:
            raise RuntimeError("Unsupported op type!")

        # Per-column containers (one slot per AIE column), filled in below.
        self.act_shim_mm2s1      = [[] for _ in range(dims.aie_cols)]
        self.act_shim_mm2s2      = [[] for _ in range(dims.aie_cols)]
        self.out_shim_s2mm       = [[] for _ in range(dims.aie_cols)]
        self.act_memtile_s2mm1   = [[] for _ in range(dims.aie_cols)]
        self.act_memtile_s2mm2   = [[] for _ in range(dims.aie_cols)]
        self.out_memtile_mm2s    = [[] for _ in range(dims.aie_cols)]
        self.core_loops          = [[] for _ in range(dims.aie_cols)]
        self.inner_loop          = [[] for _ in range(dims.aie_cols)]
        self.outer_loop          = [[] for _ in range(dims.aie_cols)]
        
        outer_loop               = dims.outer_loop
        inner_loop               = dims.inner_loop

        # Total element counts of the two original input shapes.
        ifmA_orig                = np.prod(params['in_ifmA_shape']).astype(int)
        ifmB_orig                = np.prod(params['in_ifmB_shape']).astype(int)

        ifm_mode                 = params['ifmA_mode']
        batch_transfer           = f"N:0:{outer_loop}:1"

        # Columns that actually receive data / need a second (remainder) enqueue.
        self.valid_cols          = []
        self.reenqueue_cols      = []

        ifm_mm2s1                = ""
        if params["inner_dim_is_1"]:
            if ifm_mode == 'pin':
                # ifmA is viewed as H x 8 (padding == 8); the H range is split
                # across the AIE columns.
                padding               = 8
                ifmA_act_H            = ifmA_orig // padding
                ifmB_act_W            = ifmB_orig // ifmA_act_H

                self.act_shim_memory1 = f"N:1 H:{ifmA_act_H} W:8"
                self.act_shim_memory2 = f"N:1 H:{ifmA_act_H} W:{ifmB_act_W}"
                self.out_shim_memory  = f"N:1 H:{ifmA_act_H} W:{ifmB_act_W}"

                ifm_transfer_H        = memtileActA  // padding
                ifmB_core_H           = coretileActA // padding
                ifmB_core_W           = coretileActB // ifmB_core_H

                if inner_loop == outer_loop:
                    params['bd_chain']          = True
                    self.act_memtile_memory1    = f"H:{ifm_transfer_H} W:{ifmB_core_W}"
                    self.ifmB_buff_offset       = ifm_transfer_H * ifmB_core_W * getByteCount(params['WgtDtype'])
                else:
                    self.act_memtile_memory1    = f"H:{ifm_transfer_H} W:8"
                self.act_memtile_memory2    = f"H:{ifm_transfer_H} W:{ifmB_core_W}"
                self.out_memtile_memory     = f"H:{ifm_transfer_H} W:{ifmB_core_W}"

                for col in range(dims.aie_cols):
                    self.valid_cols.append(col)
                    self.inner_loop[col]   = [inner_loop]
                    # H slice [start, end) owned by this column.
                    start                  = col * shimActAsize // padding // dims.aie_cols
                    end                    = (col + 1) * shimActAsize // padding // dims.aie_cols
                    # A column starting at/after the end of the data is dropped
                    # from valid_cols and given an empty slice.
                    if start >= ifmA_act_H:
                        self.valid_cols      = self.valid_cols[:-1]
                        self.core_loops[col] = 0
                        start                = 0
                        end                  = 0
                    if end > ifmA_act_H:
                        end = ifmA_act_H

                    self.act_shim_mm2s1[col] = [f"N:0:1 H:{start}:{end} W:0:8"]
                    self.act_shim_mm2s2[col] = [f"N:0:1 H:{start}:{end} W:0:{ifmB_act_W}"]
                    self.out_shim_s2mm[col]  = [f"N:0:1 H:{start}:{end} W:0:{ifmB_act_W}"]
                    # NOTE(review): both branches assign the same s2mm1 pattern;
                    # the only difference is the (repeated) ifmB_buff_offset.
                    if inner_loop == outer_loop:
                        self.act_memtile_s2mm1[col] = [f"H:0:{ifm_transfer_H} W:0:8"]
                        self.ifmB_buff_offset       = ifm_transfer_H * ifmB_core_W * getByteCount(params['WgtDtype'])
                    else:
                        self.act_memtile_s2mm1[col] = [f"H:0:{ifm_transfer_H} W:0:8"]
                    self.act_memtile_s2mm2[col] = [f"H:0:{ifm_transfer_H} W:0:{ifmB_act_W}"]
                    self.out_memtile_mm2s[col]  = [f"H:0:{ifm_transfer_H} W:0:{ifmB_act_W}"]
                    self.inner_loop[col] = [inner_loop]
                    self.outer_loop[col] = [outer_loop]
                    # Slice not evenly covered by outer_loop full transfers:
                    # use N full transfers plus one remainder transfer, and mark
                    # the column for re-enqueue.
                    if ((end - start) % outer_loop) != 0 or ((end - start) < (ifm_transfer_H * outer_loop)):
                        self.inner_loop[col]        = [((end-start) // ifm_transfer_H), 1]
                        self.outer_loop[col]        = [((end-start) // ifm_transfer_H), 1]
                        pending_transfer_memtile_H  = (end - start) - (((end-start) // ifm_transfer_H) * ifm_transfer_H)
                        self.act_memtile_s2mm1[col] = [f"H:0:{ifm_transfer_H} W:0:8", f"H:0:{pending_transfer_memtile_H} W:0:8"]
                        self.act_memtile_s2mm2[col] = [f"H:0:{ifm_transfer_H} W:0:{ifmB_act_W}", f"H:0:{pending_transfer_memtile_H} W:0:{ifmB_act_W}"]
                        self.out_memtile_mm2s[col]  = [f"H:0:{ifm_transfer_H} W:0:{ifmB_act_W}", f"H:0:{pending_transfer_memtile_H} W:0:{ifmB_act_W}"]
                        self.reenqueue_cols.append(col)
                        params['reenqueues']        = True
                # Memtile <-> core patterns: each AIE row gets its own H band.
                self.act_memtile_mm2s1              = [
                    [
                        f"H:{row * ifmB_core_H}:{(row+1) * ifmB_core_H} W:0:8"
                    ] for row in range(dims.aie_rows)
                ]
                self.act_memtile_mm2s2              = [
                    [
                        f"H:{(row * ifmB_core_H)}:{((row+1) * ifmB_core_H)} W:0:{ifmB_core_W}"
                    ] for row in range(dims.aie_rows)
                ]
                self.out_memtile_s2mm               = [
                    [
                        f"H:{row * ifmB_core_H}:{(row+1) * ifmB_core_H} W:0:{ifmB_core_W}"
                    ] for row in range(dims.aie_rows)
                ]

                self.core_inner_loop            = [1 for _ in self.inner_loop]
                self.core_outer_loop            = [np.sum(x) * params['core_iters'][1] for x in self.outer_loop]

                self.Msubv                      = ifmB_core_H // params['core_iters'][1]
                self.Nsubv                      = ifmB_core_W

            else:
                # Non-pin mode: every column reads the full ifmA H range and the
                # ifmB W range is split across the columns instead.
                padding               = 8
                ifmA_act_H            = ifmA_orig // padding
                ifmB_act_W            = ifmB_orig // ifmA_act_H

                self.act_shim_memory1 = f"N:1 H:{ifmA_act_H} W:8"
                self.act_shim_memory2 = f"N:1 H:{ifmA_act_H} W:{ifmB_act_W}"
                self.out_shim_memory  = f"N:1 H:{ifmA_act_H} W:{ifmB_act_W}"

                ifm_transfer_H        = memtileActA  // padding
                ifmB_core_H           = coretileActA // padding
                ifmB_core_W           = coretileActB // ifmB_core_H

                ifmB_transfer_W       = memtileActB // ifm_transfer_H

                if inner_loop == outer_loop:
                    params['bd_chain']          = True
                    self.act_memtile_memory1    = f"H:{ifm_transfer_H} W:{ifmB_transfer_W}"
                    self.ifmB_buff_offset       = ifm_transfer_H * ifmB_transfer_W * getByteCount(params['WgtDtype'])
                else:
                    self.act_memtile_memory1    = f"H:{ifm_transfer_H} W:8"
                self.act_memtile_memory2    = f"H:{ifm_transfer_H} W:{ifmB_transfer_W}"
                self.out_memtile_memory     = f"H:{ifm_transfer_H} W:{ifmB_transfer_W}"

                for col in range(dims.aie_cols):
                    self.valid_cols.append(col)
                    self.inner_loop[col]   = [inner_loop]
                    # W slice [start, end) of ifmB owned by this column.
                    start                  = col * ifmB_transfer_W
                    end                    = (col + 1) * ifmB_transfer_W
                    if start >= ifmB_act_W:
                        self.valid_cols      = self.valid_cols[:-1]
                        self.core_loops[col] = 0
                        start                = 0
                        end                  = 0
                    if end > ifmB_act_W:
                        end = ifmB_act_W
                    
                    self.act_shim_mm2s1[col] = [f"N:0:1 H:0:{ifmA_act_H} W:0:8"]
                    self.act_shim_mm2s2[col] = [f"N:0:1 H:0:{ifmA_act_H} W:{start}:{end}"]
                    self.out_shim_s2mm[col]  = [f"N:0:1 H:0:{ifmA_act_H} W:{start}:{end}"]
                    self.act_memtile_s2mm1[col] = [f"H:0:{ifm_transfer_H} W:0:8"]
                    self.act_memtile_s2mm2[col] = [f"H:0:{ifm_transfer_H} W:0:{end - start}"]
                    self.out_memtile_mm2s[col]  = [f"H:0:{ifm_transfer_H} W:0:{end - start}"]
                    self.inner_loop[col] = [inner_loop]
                    self.outer_loop[col] = [outer_loop]
                    # The requested loop count would overrun the data: clamp it
                    # to the number of full transfers and, if rows are left
                    # over, add one remainder transfer with a re-enqueue.
                    if (ifm_transfer_H * outer_loop) > ifmA_act_H:
                        self.inner_loop[col] = [ifmA_act_H // ifm_transfer_H]
                        self.outer_loop[col] = [ifmA_act_H // ifm_transfer_H]
                        if ifmA_act_H % ifm_transfer_H != 0:
                            self.inner_loop[col] = [ifmA_act_H // ifm_transfer_H, 1]
                            self.outer_loop[col] = [ifmA_act_H // ifm_transfer_H, 1]

                            pending_transfer_H   = ifmA_act_H % ifm_transfer_H

                            self.act_memtile_s2mm1[col] = [f"H:0:{ifm_transfer_H} W:0:8", f"H:0:{pending_transfer_H} W:0:8"]
                            self.act_memtile_s2mm2[col] = [f"H:0:{ifm_transfer_H} W:0:{end - start}", f"H:0:{pending_transfer_H} W:0:{end - start}"]
                            self.out_memtile_mm2s[col]  = [f"H:0:{ifm_transfer_H} W:0:{end - start}", f"H:0:{pending_transfer_H} W:0:{end - start}"]

                            self.reenqueue_cols.append(col)
                            params['reenqueues']        = True
                # Memtile <-> core patterns: every row sees the full H range of
                # ifmA; ifmB / ofm are split along W per row.
                self.act_memtile_mm2s1              = [
                    [
                        f"H:0:{ifm_transfer_H} W:0:8"
                    ] for row in range(dims.aie_rows)
                ]
                self.act_memtile_mm2s2              = [
                    [
                        f"H:0:{ifm_transfer_H} W:{row * ifmB_core_W}:{(row + 1) * ifmB_core_W}"
                    ] for row in range(dims.aie_rows)
                ]
                self.out_memtile_s2mm               = [
                    [
                        f"H:0:{ifm_transfer_H} W:{row * ifmB_core_W}:{(row + 1) * ifmB_core_W}"
                    ] for row in range(dims.aie_rows)
                ]

                self.core_inner_loop            = [1 for _ in self.inner_loop]
                self.core_outer_loop            = [np.sum(x) * params['core_iters'][1] for x in self.outer_loop]

                self.Msubv                      = ifm_transfer_H // params['core_iters'][1]
                self.Nsubv                      = ifmB_core_W

            #Dividing by scale factor in the core instruction, so multiplying the scale factor here
            params['CoreIfmaSize']          = self.Msubv * 8 * getByteCount(params['WgtDtype']) * params['ifmA_scale_factor']
            params['CoreIfmbSize']          = self.Nsubv * self.Msubv * getByteCount(params['ActDtype']) * params['ifmB_scale_factor']
            params['CoreOfmSize']           = self.Nsubv * self.Msubv * getByteCount(params['OfmDtype']) * params['ofm_scale_factor']
        else:
            # General broadcast case: ifmA is one row of width
            # ifmA_orig // outer_loop per batch entry, ifmB is H rows of it.
            self.act_shim_memory1    = f"N:{outer_loop} W:{ifmA_orig // outer_loop}"
            if ifm_mode == 'pin':
                self.act_shim_memory2    = f"N:{outer_loop} H:{ifmB_orig // ifmA_orig} W:{ifmA_orig // outer_loop}"
                self.out_shim_memory     = copy.copy(self.act_shim_memory2)
            else:
                self.act_shim_memory2    = f"N:{outer_loop} H:{shimActBsize // shimActAsize} W:{ifmA_orig // outer_loop}"
                self.out_shim_memory     = copy.copy(self.act_shim_memory2)

            self.act_memtile_memory1     = f"H:1 W:{memtileActA}"
            self.act_memtile_memory2     = f"H:{memtileActB // memtileActA} W:{memtileActA}"
            self.out_memtile_memory      = f"H:{memtileActB // memtileActA} W:{memtileActA}"

            # When outer_loop == inner_loop, multi-channel batch broadcast is
            # switched off for this layer and the row split uses divide_loop = 1.
            divide_loop                  = inner_loop
            if outer_loop == inner_loop and params['multi_ch_batch_bcast']:
                self.multi_ch_batch_bcast = False
                divide_loop               = 1
            # NOTE(review): only 'pin' and 'stream' are handled inside the loop;
            # any other ifm_mode leaves ifmB_start_H / ifmB_end_H undefined
            # when the shim patterns are built at the end of each iteration.
            for col in range(dims.aie_cols):
                self.valid_cols.append(col)
                self.outer_loop[col] = [outer_loop]
                if ifm_mode == 'pin':
                    # Pin mode: the ifmB rows (H) are split across the columns.
                    # NOTE(review): cols_memory and rows_memory are not read
                    # again in this method.
                    ifmB_col             = ifmA_orig // outer_loop
                    ifmB_row             = ifmB_orig // ifmA_orig
                    cols_memory          = f"H:{ifmB_row}"
                    rows_memory          = f"W:{ifmB_col}"
                    padded_B_rows        = shimActBsize // shimActAsize
                    ifmA_transfer        = ifmA_orig // outer_loop
                    ifmB_start_H         = col * padded_B_rows
                    ifmB_end_H           = (col + 1) * padded_B_rows
                    if ifmB_end_H > ifmB_row:
                        ifmB_end_H = ifmB_row
                    ifm_mm2s1            = f"W:0:{ifmA_transfer}"
                    # NOTE(review): this uses '>' whereas the inner_dim_is_1
                    # branches use '>=' for the same "column is past the data"
                    # test; with start == ifmB_row the column stays valid with
                    # an empty H range — confirm which is intended.
                    if ifmB_start_H > ifmB_row:
                        ifmB_start_H     = 0
                        ifmB_end_H       = 0
                        ifmA_transfer    = 0
                        batch_transfer   = f"N:0:0"
                        self.valid_cols  = self.valid_cols[:-1]
                        self.core_loops[col] = 0
                    loop_cnt             = (padded_B_rows * outer_loop) // (memtileActB // memtileActA)
                    self.inner_loop[col] = [loop_cnt]
                    # NOTE(review): this overwrites the core_loops[col] = 0
                    # assigned just above for an out-of-range column.
                    self.core_loops[col] = divide_loop
                    self.act_memtile_s2mm1[col] = [f"W:0:{ifmA_transfer}"]
                    if (ifmB_end_H - ifmB_start_H) % divide_loop != 0:
                        # Rows do not split evenly by divide_loop: shrink the
                        # memtile height until it divides the row count.
                        memtile_H        = memtileActB // memtileActA
                        loop_mulitplier  = 1 if not self.multi_ch_batch_bcast else outer_loop
                        odd_split        = False
                        if (ifmB_end_H - ifmB_start_H) % memtile_H != 0:
                            odd_split    = True
                            while ((ifmB_end_H - ifmB_start_H) % memtile_H != 0):
                                memtile_H -= 1
                        # Only a height of 1 divides the row count: fall back to
                        # full-height transfers plus one remainder transfer.
                        if memtile_H == 1 and odd_split:
                            memtile_H                   = memtileActB // memtileActA
                            self.core_loops[col]        = ((ifmB_end_H - ifmB_start_H) // memtile_H) + 1
                            self.inner_loop[col]        = [((ifmB_end_H - ifmB_start_H) // memtile_H) * loop_mulitplier, loop_mulitplier]
                            pending_transfer_memtile_H  = ((ifmB_end_H - ifmB_start_H) % memtile_H)
                            self.act_memtile_s2mm2[col] = [f"H:0:{memtile_H} W:0:{ifmA_transfer}", f"H:0:{pending_transfer_memtile_H} W:0:{ifmA_transfer}"]
                            self.out_memtile_mm2s[col]  = [f"H:0:{memtile_H} W:0:{ifmA_transfer}", f"H:0:{pending_transfer_memtile_H} W:0:{ifmA_transfer}"]
                            self.reenqueue_cols.append(col)
                            params['reenqueues']        = True
                        else:
                            self.core_loops[col]        = ((ifmB_end_H - ifmB_start_H) // memtile_H)
                            self.inner_loop[col]        = [self.core_loops[col] * loop_mulitplier]
                            self.act_memtile_s2mm2[col] = [f"H:0:{memtile_H} W:0:{ifmA_transfer}"]
                            self.out_memtile_mm2s[col]  = [f"H:0:{memtile_H} W:0:{ifmA_transfer}"]
                        # NOTE(review): with inner_loop == outer_loop the s2mm2
                        # pattern is replaced by a single entry, discarding a
                        # two-entry (remainder) list built above — confirm.
                        if inner_loop == outer_loop:
                            params['bd_chain']          = True
                            self.act_memtile_s2mm1[col] = [f"H:0:1 W:0:{ifmA_transfer}"]
                            self.act_memtile_s2mm2[col] = [f"H:0:{memtile_H} W:0:{ifmA_transfer}"]
                            self.act_memtile_memory1    = f"H:1 W:{memtileActA}"
                            self.act_memtile_memory2    = f"H:{(memtileActB // memtileActA)} W:{memtileActA}"
                            self.ifmB_buff_offset       = memtileActA * getByteCount(params['WgtDtype'])
                    else:
                        self.act_memtile_s2mm2[col]     = [f"H:0:{(ifmB_end_H - ifmB_start_H) // divide_loop} W:0:{ifmA_transfer}"]
                        self.out_memtile_mm2s[col]      = [f"H:0:{(ifmB_end_H - ifmB_start_H) // divide_loop} W:0:{ifmA_transfer}"]
                        if inner_loop == outer_loop:
                            params['bd_chain']          = True
                            self.act_memtile_s2mm1[col] = [f"H:0:1 W:0:{ifmA_transfer}"]
                            self.act_memtile_s2mm2[col] = [f"H:0:{((ifmB_end_H - ifmB_start_H) // divide_loop)} W:0:{ifmA_transfer}"]
                            self.act_memtile_memory1    = f"H:1 W:{memtileActA}"
                            self.act_memtile_memory2    = f"H:{(memtileActB // memtileActA)} W:{memtileActA}"
                            self.ifmB_buff_offset       = memtileActA * getByteCount(params['WgtDtype'])
                    # Memtile <-> core patterns (rebuilt on every column
                    # iteration; they do not depend on col).
                    self.act_memtile_mm2s1              = [
                        [
                            f"H:0:1 W:0:{coretileActA}" if outer_loop == inner_loop else f"W:0:{coretileActA}"
                        ] for row in range(dims.aie_rows)
                    ]
                    memtile_H = (coretileActB // coretileActA) * params['core_iters'][1]
                    self.act_memtile_mm2s2              = [
                        [
                            f"H:{row * memtile_H}:{(row+1) * memtile_H} W:0:{coretileActA}"
                        ] for row in range(dims.aie_rows)
                    ]
                    self.out_memtile_s2mm               = [
                        [
                            f"H:{row * memtile_H}:{(row+1) * memtile_H} W:0:{coretileActA}"
                        ] for row in range(dims.aie_rows)
                    ]
                elif ifm_mode == 'stream':
                    # Stream mode: the ifmA width (W) is split across the
                    # columns; every column covers all ifmB rows.
                    startIfmA            = col * shimActAsize
                    endIfmA              = (col + 1) * shimActAsize
                    ifmB_start_H         = 0
                    ifmB_end_H           = shimActBsize // shimActAsize
                    cols_memory          = f"H:{ifmB_end_H}"
                    rows_memory          = f"W:{ifmA_orig // outer_loop}"
                    # NOTE(review): '>' here as well; startIfmA equal to the
                    # row width yields a valid column with an empty W range.
                    if startIfmA > ifmA_orig // outer_loop:
                        ifm_mm2s1        = "W:0:0"
                        ifmB_end_H       = 0
                        batch_transfer   = "N:0:0"
                        ifmA_transfer    = 0
                        self.valid_cols  = self.valid_cols[:-1]
                        self.core_loops[col] = 0
                    else:
                        if endIfmA > ifmA_orig // outer_loop:
                            endIfmA      = ifmA_orig // outer_loop
                        ifm_mm2s1        = f"W:{startIfmA}:{endIfmA}"
                        ifmA_transfer    = endIfmA - startIfmA
                    loop_cnt             = ((shimActBsize // shimActAsize) * outer_loop) // (memtileActB // memtileActA)
                    self.inner_loop[col] = [loop_cnt]
                    self.act_memtile_s2mm1[col] = [f"W:0:{ifmA_transfer}"]
                    self.act_memtile_s2mm2[col] = [f"H:0:{(memtileActB // memtileActA)} W:0:{ifmA_transfer}"]
                    self.out_memtile_mm2s[col]  = [f"H:0:{(memtileActB // memtileActA)} W:0:{ifmA_transfer}"]
                    if outer_loop == inner_loop:
                        params['bd_chain']      = True
                        self.act_memtile_s2mm1[col] = [f"H:0:1 W:0:{ifmA_transfer}"]
                        self.act_memtile_s2mm2[col] = [f"H:0:{(memtileActB // memtileActA)} W:0:{ifmA_transfer}"]
                        self.act_memtile_memory1    = f"H:1 W:{memtileActA}"
                        self.act_memtile_memory2    = f"H:{(memtileActB // memtileActA)} W:{memtileActA}"
                        self.ifmB_buff_offset       = memtileActA * getByteCount(params['WgtDtype'])
                    # Memtile <-> core patterns: each row gets its own W band.
                    self.act_memtile_mm2s1          = [
                        [
                            f"H:0:1 W:{row * coretileActA}:{(row + 1) * coretileActA}"
                        ] for row in range(dims.aie_rows)
                    ]
                    self.act_memtile_mm2s2          = [
                        [
                            f"H:0:{memtileActB // memtileActA} W:{row * coretileActA}:{(row + 1) * coretileActA}"
                        ] for row in range(dims.aie_rows)
                    ]
                    self.out_memtile_s2mm           = [
                        [
                            f"H:0:{memtileActB // memtileActA} W:{row * coretileActA}:{(row + 1) * coretileActA}"
                        ] for row in range(dims.aie_rows)
                    ]
                # Shim-side patterns for this column (shared by both modes).
                self.act_shim_mm2s1[col] = [f"{batch_transfer} {ifm_mm2s1}"]
                self.act_shim_mm2s2[col] = [f"{batch_transfer} H:{ifmB_start_H}:{ifmB_end_H} {ifm_mm2s1}"]
                self.out_shim_s2mm[col]  = [f"{batch_transfer} H:{ifmB_start_H}:{ifmB_end_H} {ifm_mm2s1}"]
                # Multi-channel batch broadcast: store *source text* of the form
                # "f'N:{start}:{start+1} ...' for start in range(outer_loop)".
                # The "{start}" placeholders are left unevaluated here;
                # gen_shim_ifm places this text inside "[...]" in the generated
                # code.
                if self.multi_ch_batch_bcast:
                    self.act_shim_mm2s1[col]  = ["f\'N:{start}:{start+1}" + f" {ifm_mm2s1}\'" + f" for start in range({outer_loop})"]
                    self.act_shim_mm2s2[col]  = ["f\'N:{start}:{start+1}" + f" H:{ifmB_start_H}:{ifmB_end_H} {ifm_mm2s1}\'" + f" for start in range({outer_loop})"]
                    self.out_shim_s2mm[col]   = ["f\'N:{start}:{start+1}" + f" H:{ifmB_start_H}:{ifmB_end_H} {ifm_mm2s1}\'" + f" for start in range({outer_loop})"]

            self.core_inner_loop = [[(np.sum(a) * params['core_iters'][1]) // outer_loop] for a in self.inner_loop]
            # NOTE(review): self.outer_loop is a list of lists, so this '*'
            # repeats the list core_iters[0] times rather than scaling the
            # counts (the inner_dim_is_1 branches scale each entry) — confirm.
            self.core_outer_loop = self.outer_loop * params['core_iters'][0]

            self.Msubv            = coretileActB // coretileActA
            self.Nsubv            = coretileActA

            # Multi-channel batch broadcast: split each column's first loop
            # count into outer_loop equal entries and force re-enqueue.
            if self.multi_ch_batch_bcast:
                params['reenqueues'] = True
                self.inner_loop = [[a[0] // outer_loop] * outer_loop for a in self.inner_loop]
                self.outer_loop = [[a[0] // outer_loop] * outer_loop for a in self.outer_loop]

        # Zero-pad all per-column loop lists to one common length; the shim
        # repeat vectors are 1 for the first entry and 0 for the rest.
        self.max_loop_len, self.inner_loop  = equalize_loop_length(self.inner_loop)
        _, self.outer_loop                  = equalize_loop_length(self.outer_loop, self.max_loop_len)
        self.ShimActRepeat    = [1] + [0] * (self.max_loop_len - 1)
        self.ShimOutRepeat    = [1] + [0] * (self.max_loop_len - 1)
        self.ShimParamRepeat  = [1] + [0] * (self.max_loop_len - 1)
        self.ShimQdqPrmRepeat = [1] + [0] * (self.max_loop_len - 1)

        
    
    def gen_core_instrs(self, params):
        """Return source text that fills ``core_instrs`` for every core tile.

        One assignment line is emitted per (col, row) pair, column-major.
        Columns listed in ``self.valid_cols`` call ``gen_core_instrs(...)``
        with the op code, subvolume sizes, buffer addresses, sizes divided by
        their scale factors, and the summed outer/inner core loop counts for
        that column; all other columns call ``gen_empty_core_instrs(...)``.

        Args:
            params: dict providing ``dims`` (``aie_cols``, ``aie_rows``,
                ``qdq_bytes``), ``CoreQdqAddr``, the ``CoreIfma*``/``CoreIfmb*``/
                ``CoreOfm*`` address and size entries and the scale factors.

        Returns:
            str: the generated lines, each indented by four spaces.
        """
        emitted = ["    core_instrs = {}\n"]
        for col in range(params['dims'].aie_cols):
            for row in range(params['dims'].aie_rows):
                target = f"    core_instrs[AieTile(TileType.Core, {col}, {row})] = "
                if col in self.valid_cols:
                    call = (
                        f"gen_core_instrs({self.op_select.value}, {self.Msubv}, {self.Nsubv}, "
                        f"{params['CoreQdqAddr'][0]}, {params['dims'].qdq_bytes}, "
                        f"{params['CoreIfmaAddr']}, {params['CoreIfmaSize'] // params['ifmA_scale_factor']}, "
                        f"{params['CoreIfmbAddr']}, {params['CoreIfmbSize'] // params['ifmB_scale_factor']}, "
                        f"{params['CoreOfmAddr']}, {params['CoreOfmSize'] // params['ofm_scale_factor']}, "
                        f"{np.sum(self.core_outer_loop[col])}, {np.sum(self.core_inner_loop[col])})\n"
                    )
                else:
                    # Column carries no data: only the QDQ buffer is handled.
                    call = f"gen_empty_core_instrs({params['CoreQdqAddr'][0]}, {params['dims'].qdq_bytes})\n"
                emitted.append(target + call)
        return "".join(emitted)
    
    #Shim transfers
    def gen_shim_prm_pattern(self, params):
        """Return source text appending one parameter shim transfer per column.

        Each column transfers ``params['ParamSize']`` at offset
        ``col * params['ParamSize']`` out of a buffer whose total size is
        ``ParamSize * aie_cols * aie_rows``. The DMA channel id comes from
        ``params['ShimParamChannelId']`` and defaults to 0.

        Returns:
            str: one ``shim_transfers += [...]`` line per column ('' if there
            are no columns).
        """
        dims = params['dims']
        per_col_size = params['ParamSize']
        total_bytes = per_col_size * dims.aie_cols * dims.aie_rows
        return "".join(
            f"    shim_transfers += [DataTransfer({self.ShimParamRepeat}, AieTile(TileType.Shim, {col}), [{params['ShimPrmBufferIdx']}], {total_bytes}, [], [TransferParams(shim_dma({col}, DmaDir.MM2S, {params.get('ShimParamChannelId', 0)}),{per_col_size}, offset={col * per_col_size})])]\n"
            for col in range(dims.aie_cols)
        )
    
    def gen_shim_qdq_prm_pattern(self, params):
        """Emit the shim MM2S ``DataTransfer`` lines that fetch the QDQ parameters.

        One line is produced for every second column (0, 2, 4, ...).  The QDQ
        block is addressed right after ``ShimWgtSize`` (default 0) in the buffer
        selected by ``ShimWgtBufferIdx``; the transfer length and offset are
        emitted divided by 4.

        Args:
            params: dict read for ``dims``, ``ShimWgtBufferIdx`` and optionally
                ``ShimWgtSize``.

        Returns:
            str: the generated lines.
        """
        dims = params['dims']
        pieces = []
        for col in range(0, dims.aie_cols, 2):
            wgt_bytes = params.get('ShimWgtSize', 0)
            qdq_bytes = dims.qdq_bytes
            pieces.append(
                f"    shim_transfers += [DataTransfer({self.ShimParamRepeat}, AieTile(TileType.Shim, {col}), [{params['ShimWgtBufferIdx']}], {wgt_bytes + qdq_bytes}, [], [TransferParams(shim_dma({col}, DmaDir.MM2S, 0), {qdq_bytes // 4}, offset={wgt_bytes // 4})])]\n"
            )
        return "".join(pieces)
    
    def gen_shim_ifm(self, params):
        """Emit the shim MM2S transfer lines for both input feature maps.

        IFM1 uses DMA channel 0 and IFM2 uses channel 1.  By default both read
        from ``ShimActBufferIdx`` with IFM2 placed after IFM1 (offset = number
        of ``in_ifmA_shape`` elements times the byte count of ``WgtDtype``).
        When ``self.ifmA_const`` (checked first) or ``self.ifmB_const`` is set,
        that input is read from ``ShimWgtBufferIdx`` at offset
        ``ShimWgtSize + qdq_bytes`` and the other input starts at offset 0.
        A ``buffer_offset=`` argument is only emitted for a non-zero offset.

        Args:
            params: dict read for ``dims``, ``in_ifmA_shape``, ``WgtDtype``,
                ``ShimActBufferIdx``, ``ShimWgtBufferIdx`` and optionally
                ``ShimWgtSize``.

        Returns:
            str: the IFM1 section followed by the IFM2 section.
        """
        dims = params['dims']
        off_a = 0
        off_b = np.prod(params['in_ifmA_shape']).astype(int) * getByteCount(params["WgtDtype"])
        repeat_cnt = dims.outer_loop
        buf_a = params['ShimActBufferIdx']
        buf_b = params['ShimActBufferIdx']

        if self.ifmA_const:
            buf_a = params['ShimWgtBufferIdx']
            off_a = params.get('ShimWgtSize',0) + dims.qdq_bytes
            off_b = 0
        elif self.ifmB_const:
            buf_b = params['ShimWgtBufferIdx']
            off_a = 0
            off_b = params.get('ShimWgtSize',0) + dims.qdq_bytes

        def transfer_line(col, channel, buff_idx, memory_fmt, access_fmts, bits, offset):
            # One generated statement; the packed form is used for multi-channel batch broadcast.
            if self.multi_ch_batch_bcast:
                head = f"    shim_transfers += [generate_packed_shim_data_transfer([1] * {repeat_cnt}, shim_dma({col}, DmaDir.MM2S, {channel}), {buff_idx}, ['{memory_fmt}'] * {repeat_cnt}, [{access_fmts[col][0]}], [1] * {repeat_cnt}, {bits}"
            else:
                head = f"    shim_transfers += [generate_shim_data_transfer({self.ShimActRepeat}, shim_dma({col}, DmaDir.MM2S, {channel}), {buff_idx}, '{memory_fmt}', '{access_fmts[col][0]}', {bits}"
            if offset == 0:
                return head + ")]\n"
            return head + f", buffer_offset={offset})]\n"

        parts = ["    #SHIM ACTIVATION TRANSFERS IFM1\n"]
        for col in self.valid_cols:
            parts.append(transfer_line(col, 0, buf_a, self.act_shim_memory1, self.act_shim_mm2s1, dims.ifmA_bits, off_a))
        parts.append("    #SHIM ACTIVATION TRANSFERS IFM2\n")
        for col in self.valid_cols:
            parts.append(transfer_line(col, 1, buf_b, self.act_shim_memory2, self.act_shim_mm2s2, dims.ifmB_bits, off_b))
        return "".join(parts)
    
    def gen_shim_ofm(self, params):
        """Emit one shim S2MM (channel 0) output transfer line per valid column.

        With ``self.multi_ch_batch_bcast`` set, the packed transfer form
        repeated ``dims.outer_loop`` times is generated; otherwise a plain
        ``generate_shim_data_transfer`` using ``self.ShimOutRepeat``.

        Args:
            params: dict read for ``dims`` and ``ShimOutBufferIdx``.

        Returns:
            str: the generated lines.
        """
        dims = params['dims']
        repeat_cnt = dims.outer_loop

        def line_for(col):
            if self.multi_ch_batch_bcast:
                return f"    shim_transfers += [generate_packed_shim_data_transfer([1] * {repeat_cnt}, shim_dma({col}, DmaDir.S2MM, 0), {params['ShimOutBufferIdx']}, ['{self.out_shim_memory}'] * {repeat_cnt}, [{self.out_shim_s2mm[col][0]}], [1] * {repeat_cnt}, {dims.ofm_bits})]\n"
            return f"    shim_transfers += [generate_shim_data_transfer({self.ShimOutRepeat}, shim_dma({col}, DmaDir.S2MM, 0), {params['ShimOutBufferIdx']}, '{self.out_shim_memory}', '{self.out_shim_s2mm[col][0]}', {dims.ofm_bits})]\n"

        return "".join(line_for(col) for col in self.valid_cols)
    #Shim transfers--------------------->>>

    #Mem transfers
    def gen_memtile_prm_pattern(self, params):
        code = f"""    memtile_transfers += [
            DataTransfer({self.ShimParamRepeat}, AieTile(TileType.Memtile, col), [{params['MemtilePrmAddr']}], {params['ParamSize'] * 4},
                [TransferParams(memtile_dma(col, DmaDir.S2MM, {params.get('ShimParamChannelId', 0)}), {params['ParamSize']})],
                [TransferParams(memtile_dma(col, DmaDir.MM2S, row), {params['ParamSize'] // 4}, offset=row * {params['ParamSize'] // 4})
                for row in range({params['dims'].aie_rows})]
            ) for col in range({params['dims'].aie_cols})
        ]
"""     
        code += f"""    memtile_transfers += [
            DataTransfer({self.ShimQdqPrmRepeat}, AieTile(TileType.Memtile, col), [{params['MemtileQdqAddr']}], {params['dims'].qdq_bytes},
                [TransferParams(memtile_dma(col, DmaDir.S2MM, {params.get('ShimQdqChannelId',0)}), {params['dims'].qdq_bytes//4})],
                [TransferParams(memtile_dma(col, DmaDir.MM2S, 4), {params['dims'].qdq_bytes//4})]
            ) for col in range(0, {params['dims'].aie_cols}, 2)
        ]
"""
        return code
    
    def gen_memtile_ifm(self, params):
        """Emit the memtile ``DataTransfer`` text for both input feature maps.

        For every column in ``self.valid_cols`` one of two layouts is generated:

        * ``params['bd_chain']`` truthy: a single ``DataTransfer`` that carries
          both inputs.  S2MM channel 0 receives IFM A, channel 1 receives IFM B,
          and the IFM B entries carry ``buffer_offset=self.ifmB_buff_offset``.
        * otherwise: two separate ``DataTransfer`` entries (IFM A, then IFM B),
          each using ``SyncStrategy.Parallel_1_to_N``.

        Columns listed in ``self.reenqueue_cols`` use ``pack_transfers`` with two
        access patterns (index 0 and 1 of the per-column pattern list) instead
        of a single ``generate_transfer_params``.

        Args:
            params: dict read for the ``MemtileIfma*/MemtileIfmb*`` ping/pong
                addresses, ``ping_pong_enable``, ``bd_chain``, ``dims``,
                ``MemtileActSizeIfmA`` and ``MemtileActSizeIfmB``.

        Returns:
            str: the generated source text for all valid columns.
        """
        # Buffer address lists: [ping, pong] when ping-pong is enabled for that
        # input, otherwise just [ping].
        IfmAAddr = (
            [params["MemtileIfmaPingAddr"], params["MemtileIfmaPongAddr"]]
            if params["ping_pong_enable"]["ifmA"]
            else [params["MemtileIfmaPingAddr"]]
        )
        IfmBAddr = (
            [params["MemtileIfmbPingAddr"], params["MemtileIfmbPongAddr"]]
            if params["ping_pong_enable"]["ifmB"]
            else [params["MemtileIfmbPingAddr"]]
        )
        code = ""
        for col in self.valid_cols:
            if params['bd_chain']:
                # Combined layout: both inputs share one DataTransfer whose size
                # is MemtileActSizeIfmA + MemtileActSizeIfmB.
                if col in self.reenqueue_cols:
                    s2mm_transfer  = f"pack_transfers(memtile_dma({col}, DmaDir.S2MM, 0), ['{self.act_memtile_memory1}', '{self.act_memtile_memory1}'], ['{self.act_memtile_s2mm1[col][0]}', '{self.act_memtile_s2mm1[col][1]}'], [1, 1], {params['dims'].ifmA_bits}),\n"
                    s2mm_transfer += f"                     pack_transfers(memtile_dma({col}, DmaDir.S2MM, 1), ['{self.act_memtile_memory2}', '{self.act_memtile_memory2}'], ['{self.act_memtile_s2mm2[col][0]}', '{self.act_memtile_s2mm2[col][1]}'], [1, 1], {params['dims'].ifmB_bits}, buffer_offset={self.ifmB_buff_offset})"
                else:
                    s2mm_transfer  = f"generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, 0), '{self.act_memtile_memory1}', '{self.act_memtile_s2mm1[col][0]}', {params['dims'].ifmA_bits}),\n"
                    s2mm_transfer += f"                     generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, 1), '{self.act_memtile_memory2}', '{self.act_memtile_s2mm2[col][0]}', {params['dims'].ifmB_bits}, buffer_offset={self.ifmB_buff_offset})"
                # One MM2S entry per row for each input; every entry ends with a
                # trailing comma so the two groups can be concatenated directly.
                mm2s_transfer1      = "\n                     ".join([f"generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, {row}), '{self.act_memtile_memory1}', '{self.act_memtile_mm2s1[row][0]}', {params['dims'].ifmA_bits})," for row in range(params['dims'].aie_rows)])
                mm2s_transfer2      = "\n                     ".join([f"generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, {row}), '{self.act_memtile_memory2}', '{self.act_memtile_mm2s2[row][0]}', {params['dims'].ifmB_bits}, buffer_offset={self.ifmB_buff_offset})," for row in range(params['dims'].aie_rows)])

                mm2s_transfer       = mm2s_transfer1 + "\n                     " + mm2s_transfer2

                # The combined transfer repeats according to self.inner_loop[col]
                # and is addressed with the IFM A address list only.
                code += f"""    memtile_transfers += [
                DataTransfer({self.inner_loop[col]}, AieTile(TileType.Memtile, {col}), {IfmAAddr}, {params['MemtileActSizeIfmA'] + params['MemtileActSizeIfmB']},
                    [{s2mm_transfer}],
                    [{mm2s_transfer}]
                )]
"""
            else:
                # Split layout.  Only the IFM B receive side switches to
                # pack_transfers for reenqueue columns; IFM A always uses a
                # single generate_transfer_params.
                s2mm2_transfer = f"generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, 1), '{self.act_memtile_memory2}', '{self.act_memtile_s2mm2[col][0]}', {params['dims'].ifmB_bits})"
                if col in self.reenqueue_cols:
                    s2mm2_transfer  = f"pack_transfers(memtile_dma({col}, DmaDir.S2MM, 1), ['{self.act_memtile_memory2}', '{self.act_memtile_memory2}'], ['{self.act_memtile_s2mm2[col][0]}', '{self.act_memtile_s2mm2[col][1]}'], [1, 1], {params['dims'].ifmB_bits})"
                mm2s_transfer_1 = ",\n                     ".join([f"generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, {row}), '{self.act_memtile_memory1}', '{self.act_memtile_mm2s1[row][0]}', {params['dims'].ifmA_bits})"
                    for row in range(params["dims"].aie_rows)])
                # IFM A repeats according to self.outer_loop[col] ...
                code += f"""    memtile_transfers += [
                DataTransfer({self.outer_loop[col]}, AieTile(TileType.Memtile, {col}), {IfmAAddr}, {params['MemtileActSizeIfmA']},
                    [generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, 0), '{self.act_memtile_memory1}', '{self.act_memtile_s2mm1[col][0]}', {params['dims'].ifmA_bits})],
                    [{mm2s_transfer_1}],
                    sync_strategy=SyncStrategy.Parallel_1_to_N)]\n"""

                mm2s_transfer_2 = ",\n                     ".join([f"generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, {row}), '{self.act_memtile_memory2}', '{self.act_memtile_mm2s2[row][0]}', {params['dims'].ifmB_bits})"
                    for row in range(params["dims"].aie_rows)])
                # ... while IFM B repeats according to self.inner_loop[col].
                code += f"""    memtile_transfers += [
                DataTransfer({self.inner_loop[col]}, AieTile(TileType.Memtile, {col}), {IfmBAddr}, {params['MemtileActSizeIfmB']},
                    [{s2mm2_transfer}],
                    [{mm2s_transfer_2}],
                    sync_strategy=SyncStrategy.Parallel_1_to_N)]\n"""

        return code

    def gen_memtile_ofm(self, params):
        ofm_Addr = [params['MemtileOfmPingAddr']]
        code = ""
        for col in self.valid_cols:
            if col in self.reenqueue_cols:
                mm2s_transfer = f"pack_transfers(memtile_dma({col}, DmaDir.MM2S, 5), ['{self.out_memtile_memory}', '{self.out_memtile_memory}'], ['{self.out_memtile_mm2s[col][0]}', '{self.out_memtile_mm2s[col][1]}'], [1, 1], {params['dims'].ofm_bits})"
            else:
                mm2s_transfer = f"generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, 5), '{self.out_memtile_memory}', '{self.out_memtile_mm2s[col][0]}', {params['dims'].ofm_bits})"
            s2mm_transfer = "\n                     ".join([f"generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, {2 + row}), '{self.out_memtile_memory}', '{self.out_memtile_s2mm[row][0]}', {params['dims'].ofm_bits})," for row in range(params['dims'].aie_rows)])

            code += f"""    memtile_transfers += [
                DataTransfer({self.inner_loop[col]}, AieTile(TileType.Memtile, {col}), {ofm_Addr}, {params["MemtileOutSize"]},
                    [{s2mm_transfer}],
                    [{mm2s_transfer}],
                    sync_strategy=SyncStrategy.Parallel_N_to_1
                )]
"""
        return code

    #Mem transfers--------------------->>>

    #get combined shim transfers
    def gen_shim_transfers(self, params):
        """Assemble the complete shim-transfer section of the generated code.

        Starts a fresh ``shim_transfers`` list and appends, in order, the
        parameter, QDQ-parameter, input and output transfer text produced by
        the corresponding ``gen_shim_*`` methods.

        Args:
            params: dict forwarded unchanged to every ``gen_shim_*`` method.

        Returns:
            str: the concatenated generated source text.
        """
        sections = [
            "\n    shim_transfers = []\n",
            "    #SHIM PARAM TRANSFERS\n",
            self.gen_shim_prm_pattern(params),
            "    #SHIM QDQ PARAM TRANSFERS\n",
            self.gen_shim_qdq_prm_pattern(params),
            self.gen_shim_ifm(params),
            "    #SHIM ACTIVATION TRANSFERS OFM\n",
            self.gen_shim_ofm(params),
        ]
        return "".join(sections)
    
    #get combined mem transfers
    def gen_mem_transfers(self, params):
        """Assemble the complete memtile-transfer section of the generated code.

        Starts a fresh ``memtile_transfers`` list and appends, in order, the
        parameter, input and output transfer text produced by the
        corresponding ``gen_memtile_*`` methods.

        Args:
            params: dict forwarded unchanged to every ``gen_memtile_*`` method.

        Returns:
            str: the concatenated generated source text.
        """
        sections = [
            "    memtile_transfers = []\n",
            "    #MEMTILE PARAM TRANSFERS\n",
            self.gen_memtile_prm_pattern(params),
            "    #MEMTILE ACTIVATION TRANSFERS IFMs\n",
            self.gen_memtile_ifm(params),
            "    #MEMTILE ACTIVATION TRANSFERS OFMs\n",
            self.gen_memtile_ofm(params),
        ]
        return "".join(sections)

    def gen_dma_pattern_code(self, params):
        """Return the full generated DMA pattern: core instructions, memtile, then shim.

        The three sections are separated by a single newline.

        Args:
            params: dict forwarded unchanged to the three section generators.

        Returns:
            str: the concatenated generated source text.
        """
        core_section = self.gen_core_instrs(params)
        mem_section = self.gen_mem_transfers(params)
        shim_section = self.gen_shim_transfers(params)
        return "\n".join([core_section, mem_section, shim_section])

class EleWise():
    def __init__(self):
        pass

    def elw_core_instr(self):
        return f"""
def gen_core_instrs(op_select, Msubv, Nsubv, qdq_addr, qdq_bytes, ifm_addr, ifm_size, ofm_addr, ofm_size, loop_cnt):
    return [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), qdq_addr, None, qdq_bytes), #CoreQdqAddr, QdqSize
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), ifm_addr[0], ifm_addr[1], ifm_size), #CoreActPingAddr, CoreActPongAddr, CoreActSize
        ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), ofm_addr[0], ofm_addr[1], ofm_size), #CoreOutPingAddr, CoreOutPongAddr, CoreOutSize
        Loop(loop_cnt, [
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
            CallKernel('{self.kernel_name}', kernel_params=matadd_params(op_select, Msubv, Nsubv, qdq_addr, ifm_addr[0], 0, 0, 0, 0)),
            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
        ])
    ]
"""
    
    def cascade_core_instr(self):
        return f"""
def gen_core_instrs(op_select, Msubv, Nsubv, qdq_addr, qdq_bytes, ifm_addr, ifm_size, ofm_addr, ofm_size, outer_loop, inner_loop):
    return [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), qdq_addr, None, qdq_bytes), #CoreQdqAddr, QdqSize
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
        ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), ofm_addr[0], ofm_addr[1], ofm_size), #CoreOutPingAddr, CoreOutPongAddr, CoreOutSize
        Loop(outer_loop, [
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), ifm_addr[0], None, ifm_size * 2), #CoreActPingAddr, None, CoreActSize
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
            CallKernel('{self.kernel_name}', kernel_params=matadd_params(op_select, Msubv, Nsubv, qdq_addr, ifm_addr[0], 0, 0, 0, 0)), # 2 batch data
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), ifm_addr[0], ifm_addr[1], ifm_size), #CoreActPingAddr, CoreActPongAddr, CoreActSize
            Loop(inner_loop - 3, [
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                CallKernel('{self.kernel_name}', kernel_params=matadd_params(op_select, Msubv, Nsubv, qdq_addr, ifm_addr[0], 0, 0, 0, 1)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            ]),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            CallKernel('{self.kernel_name}', kernel_params=matadd_params(op_select, Msubv, Nsubv, qdq_addr, ifm_addr[0], 0, 0, 0, 2)), # last batch data
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            RelBuffer(DmaChannel(DmaDir.MM2S, 0))
        ])
    ]
"""
    
    def empty_core_instr(self):
        return f"""
def gen_empty_core_instrs(qdq_addr, qdq_bytes):
    return [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), qdq_addr, None, qdq_bytes), #CoreQdqAddr, QdqSize
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1))
    ]
"""
    
    def gen_core_instrs_func(self):
        """Return the generated helper definitions needed by the core instructions.

        Picks the cascade variant of ``gen_core_instrs`` when ``self.op_select``
        is ``CASCADE_ADD`` (which asserts that ``self.elw_dtype`` is
        ``"a16a16"``) and the element-wise variant otherwise, then appends the
        ``gen_empty_core_instrs`` definition.

        Returns:
            str: source text of the two generated functions.
        """
        if self.op_select == Add_SupportedOps.CASCADE_ADD:
            assert (self.elw_dtype == "a16a16"), "The A8A8 Cascade Add is not supported in the kernel now"
            active_def = self.cascade_core_instr()
        else:
            active_def = self.elw_core_instr()
        return active_def + self.empty_core_instr()

    def gen_dma_transfers(self, params):
        """Compute the tiling state consumed by the ``gen_*`` emitters of this class.

        Reads sizes, shapes and the op description from ``params`` and stores on
        ``self``: the kernel name / dtype tag, the selected op, the
        access-pattern strings for the shim and memtile transfers of every
        column, the per-column loop counts and the shim repeat lists.

        ``params['reenqueues']`` is overwritten: it ends up True exactly when at
        least one column is listed in ``self.reenqueue_cols``.

        Args:
            params: dict.  Keys read here include ``op_ver``, ``op_name``,
                ``numBatches``, ``dims``, the ``ShimActSize*`` / ``ShimOutSize``
                / ``MemtileActSize*`` / ``MemtileOutSize`` / ``Core*Size``
                entries, ``WgtDtype`` / ``ActDtype`` / ``OfmDtype``, the three
                ``*_scale_factor`` entries, ``ifmA_param_type`` /
                ``ifmB_param_type`` and the ``in_ifmA_shape`` /
                ``in_ifmB_shape`` / ``out_ofm_shape`` shapes.

        Raises:
            RuntimeError: if ``params['op_name']`` is not one of Add, Mul, Sub,
                Div (``decompose`` may raise it as well).
            AssertionError: if a core tile size is not divisible by its scale
                factor, or, for non-cascade ops, if the outer/inner loop counts
                or the two input sizes differ.
        """
        if "uint8" in params["op_ver"]:
            self.kernel_name = 'run_a8a8_matadd_qdq'
            self.elw_dtype   = 'a8a8'
        else:
            self.kernel_name = 'run_a16a16_matadd_qdq'
            self.elw_dtype   = 'a16a16'
        dims                 = params['dims']
        # Byte sizes are converted to element counts.  Note that IFM A is sized
        # with WgtDtype and IFM B with ActDtype.
        # shimActAsize and shimOfmsize are computed but not used further below.
        shimActAsize         = params['ShimActSizeIfmA'] // getByteCount(params['WgtDtype'])
        shimActBsize         = params['ShimActSizeIfmB'] // getByteCount(params['ActDtype'])
        shimOfmsize          = params['ShimOutSize'] // getByteCount(params['OfmDtype'])

        memtileActA          = params['MemtileActSizeIfmA'] // getByteCount(params['WgtDtype'])
        memtileActB          = params['MemtileActSizeIfmB'] // getByteCount(params['ActDtype'])
        memtileOfm           = params['MemtileOutSize'] // getByteCount(params['OfmDtype'])

        coretileActA         = (params['CoreIfmaSize'] // getByteCount(params['WgtDtype']))
        coretileActB         = (params['CoreIfmbSize'] // getByteCount(params['ActDtype']))
        coretileOfm          = (params['CoreOfmSize'] // getByteCount(params['OfmDtype']))

        assert(coretileActA % params['ifmA_scale_factor'] == 0)
        assert(coretileActB % params['ifmB_scale_factor'] == 0)
        assert(coretileOfm  % params['ofm_scale_factor'] == 0)

        coretileActA         = coretileActA // params['ifmA_scale_factor']
        coretileActB         = coretileActB // params['ifmB_scale_factor']
        coretileOfm          = coretileOfm  // params['ofm_scale_factor']

        params['reenqueues'] = False

        self.Msubv, self.Nsubv                 = decompose(dims.M_subv)

        # The *_const flags are only assigned for the 'act' / 'const' values.
        if params['ifmA_param_type'] == 'act':
            self.ifmA_const = 0
        elif params['ifmA_param_type'] == 'const':
            self.ifmA_const = 1

        if params['ifmB_param_type'] == 'act':
            self.ifmB_const = 0
        elif params['ifmB_param_type'] == 'const':
            self.ifmB_const = 1

        # An "Add" with numBatches set is treated as a cascade add.
        if params["op_name"] == "Add":
            if params["numBatches"] is not None:
                self.op_select = Add_SupportedOps.CASCADE_ADD
            else:
                self.op_select = Add_SupportedOps.ELW_ADD
        elif params["op_name"] == "Mul":
            self.op_select = Add_SupportedOps.ELW_MUL
        elif params["op_name"] == "Sub":
            self.op_select = Add_SupportedOps.ELW_SUB
        elif params["op_name"] == "Div":
            self.op_select = Add_SupportedOps.ELW_DIV
        else:
            raise RuntimeError("Unsupported op type!")

        is_cascade = self.op_select == Add_SupportedOps.CASCADE_ADD

        self.act_shim_mm2s1      = [[] for _ in range(dims.aie_cols)]
        self.act_shim_mm2s2      = [[] for _ in range(dims.aie_cols)]
        self.out_shim_s2mm       = [[] for _ in range(dims.aie_cols)]
        self.act_memtile_s2mm1   = [[] for _ in range(dims.aie_cols)]
        self.act_memtile_s2mm2   = [[] for _ in range(dims.aie_cols)]
        self.out_memtile_mm2s    = [[] for _ in range(dims.aie_cols)]
        self.core_loops          = [[] for _ in range(dims.aie_cols)]
        self.mem_loops           = [[] for _ in range(dims.aie_cols)]

        outer_loop               = dims.outer_loop
        inner_loop               = dims.inner_loop

        if not is_cascade:
            assert(outer_loop == inner_loop) #Elwadd both loops should be same
        else:
            outer_loop           = inner_loop
            self.casc_inner_loop = np.prod(params['numBatches'])

        # Total element counts; for batched input IFM B is counted per batch.
        ifmA_orig    = np.prod(params['in_ifmA_shape']).astype(int)
        ifmB_orig    = np.prod(params['in_ifmB_shape']).astype(int) if params["numBatches"] is None else np.prod(params['in_ifmB_shape']).astype(int) // params['numBatches'][0]
        ofm_orig     = np.prod(params['out_ofm_shape']).astype(int)

        self.act_shim_memory1 = f"W:{ifmA_orig}"
        self.act_shim_memory2 = f"W:{ifmB_orig}" if params["numBatches"] is None else f"N:{params['numBatches'][0]} W:{ifmB_orig}"
        self.out_shim_memory  = f"W:{ofm_orig}"

        self.act_memtile_memory1 = f"H:1 W:{memtileActA}"
        self.act_memtile_memory2 = f"H:1 W:{memtileActB}"
        self.out_memtile_memory  = f"W:{memtileOfm}"

        self.ifmB_buff_offset    = memtileActA * getByteCount(params['WgtDtype']) if not is_cascade else 0

        self.valid_cols          = []
        self.reenqueue_cols      = []

        if not is_cascade:
            assert(ifmA_orig == ifmB_orig)

        for col in range(dims.aie_cols):
            # Element range [startIfm, endIfm) handled by this column.
            startIfm     = col * shimActBsize
            endIfm       = (col+1) * shimActBsize

            self.mem_loops[col]             = [outer_loop]
            self.core_loops[col]            = outer_loop

            self.valid_cols.append(col)

            if startIfm >= ifmB_orig:
                # The column starts past the end of the data: no work for it.
                startIfm = 0
                endIfm   = 0
                self.core_loops[col] = 0
                self.valid_cols = self.valid_cols[:-1]

            if endIfm > ifmB_orig:
                endIfm = ifmB_orig

            if is_cascade:
                self.act_shim_mm2s1[col]    = ["W:0:0"]
                self.act_shim_mm2s2[col]    = [f"N:0:{params['numBatches'][0]}:1 W:{startIfm}:{endIfm}"]
                self.out_shim_s2mm[col]     = [f"W:{startIfm}:{endIfm}"]
            else:
                self.act_shim_mm2s1[col]    = [f"W:{startIfm}:{endIfm}"]
                self.act_shim_mm2s2[col]    = [f"W:{startIfm}:{endIfm}"]
                self.out_shim_s2mm[col]     = [f"W:{startIfm}:{endIfm}"]

            self.act_memtile_s2mm1[col]     = [f"H:0:1 W:0:{(endIfm - startIfm) // outer_loop}"] if not is_cascade else ["W:0:0"]
            self.act_memtile_s2mm2[col]     = [f"H:0:1 W:0:{(endIfm - startIfm) // outer_loop}"]
            self.out_memtile_mm2s[col]      = [f"W:0:{(endIfm - startIfm) // outer_loop}"]

            if (endIfm - startIfm) % outer_loop != 0:
                # The column's span does not split evenly over outer_loop:
                # tentatively plan full memtile-sized transfers plus one final
                # transfer for the remainder.
                params['reenqueues']        = True
                self.reenqueue_cols.append(col)
                self.mem_loops[col]         = [(endIfm - startIfm) // memtileActB, 1]
                self.core_loops[col]        = np.sum(self.mem_loops[col])
                pending_transfer            = (endIfm - startIfm) % memtileActB
                actual_transfer             = ((shimActBsize * (col + 1)) - startIfm) // outer_loop
                if pending_transfer == 0:
                    # No remainder after all: undo the tentative reenqueue for
                    # this column only.
                    self.mem_loops[col]         = [(endIfm - startIfm) // memtileActB]
                    self.core_loops[col]        = np.sum(self.mem_loops[col])
                    self.act_memtile_s2mm1[col] = [f"H:0:1 W:0:{actual_transfer}"] if not is_cascade else ["W:0:0"]
                    self.act_memtile_s2mm2[col] = [f"H:0:1 W:0:{actual_transfer}"]
                    self.out_memtile_mm2s[col]  = [f"W:0:{actual_transfer}"]
                    self.reenqueue_cols         = self.reenqueue_cols[:-1]
                    # Keep the flag set when an earlier column still needs a
                    # reenqueue; forcing False here would contradict
                    # reenqueue_cols.
                    params['reenqueues']        = bool(self.reenqueue_cols)
                else:
                    self.act_memtile_s2mm1[col] = [f"H:0:1 W:0:{actual_transfer}", f"H:0:1 W:0:{pending_transfer}"] if not is_cascade else ["W:0:0"]
                    self.act_memtile_s2mm2[col] = [f"H:0:1 W:0:{actual_transfer}", f"H:0:1 W:0:{pending_transfer}"]
                    self.out_memtile_mm2s[col]  = [f"W:0:{actual_transfer}", f"W:0:{pending_transfer}"]

        if is_cascade:
            self.casc_outer_loop   = [np.sum(x) for x in self.mem_loops]
            self.mem_loops         = [[x * np.prod(params['numBatches']) for x in sublist] for sublist in self.mem_loops]
        else:
            self.act_memtile_mm2s1 = [[f"H:0:1 W:{row * coretileActA}:{(row+1) * coretileActA}"] for row in range(dims.aie_rows)]
        self.act_memtile_mm2s2     = [[f"H:0:1 W:{row * coretileActB}:{(row+1) * coretileActB}"] for row in range(dims.aie_rows)]
        self.out_memtile_s2mm      = [[f"W:{row * coretileOfm}:{(row+1) * coretileOfm}"] for row in range(dims.aie_rows)]

        # Split over-long loop counts into chunks; for cascade the chunk limit
        # is kept a multiple of the batch count.
        if is_cascade:
            self.mem_loops             = check_limit_loop_length(self.mem_loops, params['numBatches'][0])
        else:
            self.mem_loops             = check_limit_loop_length(self.mem_loops)

        self.max_loop_len, self.mem_loops = equalize_loop_length(self.mem_loops)

        # A reenqueue column always starts with 2 transfers [init, last].
        # When its loop list is longer (or shorter) than that, the transfer list
        # is rebuilt to the same length so both can be passed to the
        # 'pack_transfers' API: the init transfer is repeated and the last
        # transfer stays at the end.
        def equalize_transfers(s2mm, loop):
            if len(s2mm) != len(loop):
                init_transfer = s2mm[0]
                last_transfer = s2mm[-1]
                # last_transfer is a str: it must be added as ONE list element.
                # `list += str` would splice in its individual characters.
                return [init_transfer] * (len(loop) - 1) + [last_transfer]
            return s2mm

        for col in self.reenqueue_cols:
            if not is_cascade:
                self.act_memtile_s2mm1[col] = equalize_transfers(self.act_memtile_s2mm1[col], self.mem_loops[col])
            self.act_memtile_s2mm2[col]     = equalize_transfers(self.act_memtile_s2mm2[col], self.mem_loops[col])
            self.out_memtile_mm2s[col]      = equalize_transfers(self.out_memtile_mm2s[col], self.mem_loops[col])

        # Shim-side transfers run once: repeat 1 in the first slot, 0 after.
        self.ShimActRepeat    = [1] + [0] * (self.max_loop_len - 1)
        self.ShimOutRepeat    = [1] + [0] * (self.max_loop_len - 1)
        self.ShimParamRepeat  = [1] + [0] * (self.max_loop_len - 1)
        self.ShimQdqPrmRepeat = [1] + [0] * (self.max_loop_len - 1)

    def gen_core_instrs(self, params):
        code  = "    core_instrs = {}\n"
        for col in range(params['dims'].aie_cols):
            for row in range(params['dims'].aie_rows):
                if self.core_loops[col] == 0:
                    code += "".join([f"    core_instrs[AieTile(TileType.Core, {col}, {row})] = gen_empty_core_instrs({params['CoreQdqAddr'][0]}, {params['dims'].qdq_bytes})\n" ])
                else:
                    if self.op_select == Add_SupportedOps.CASCADE_ADD:
                        code += "".join([f"    core_instrs[AieTile(TileType.Core, {col}, {row})] = gen_core_instrs({self.op_select.value}, {self.Msubv}, {self.Nsubv}, {params['CoreQdqAddr'][0]}, {params['dims'].qdq_bytes}, {params['CoreIfmaAddr']}, {params['CoreIfmbSize'] // params['ifmA_scale_factor']}, {params['CoreOfmAddr']}, {params['CoreOfmSize'] // params['ofm_scale_factor']}, {self.casc_outer_loop[col]}, {self.casc_inner_loop})\n"])
                    else:
                        code += "".join([f"    core_instrs[AieTile(TileType.Core, {col}, {row})] = gen_core_instrs({self.op_select.value}, {self.Msubv}, {self.Nsubv}, {params['CoreQdqAddr'][0]}, {params['dims'].qdq_bytes}, {params['CoreIfmaAddr']}, {params['CoreIfmaSize'] // params['ifmA_scale_factor']} + {params['CoreIfmbSize'] // params['ifmB_scale_factor']}, {params['CoreOfmAddr']}, {params['CoreOfmSize'] // params['ofm_scale_factor']}, {self.core_loops[col]})\n"])
        return code
    
    #Shim transfers
    def gen_shim_prm_pattern(self, params):
        shim_prm_bytes = params['ParamSize'] * params['dims'].aie_cols * params['dims'].aie_rows
        shim_prm_size  = params['ParamSize']
        offset         = params['ParamSize']
        code = ""
        for col in range(params['dims'].aie_cols):
            code += f"    shim_transfers += [DataTransfer({self.ShimParamRepeat}, AieTile(TileType.Shim, {col}), [{params['ShimPrmBufferIdx']}], {shim_prm_bytes}, [], [TransferParams(shim_dma({col}, DmaDir.MM2S, {params.get('ShimParamChannelId', 0)}),{shim_prm_size}, offset={col * offset})])]\n"
        return code
    
    def gen_shim_qdq_prm_pattern(self, params):
        code = ""
        for col in range(0, params['dims'].aie_cols, 2):
            code += f"    shim_transfers += [DataTransfer({self.ShimParamRepeat}, AieTile(TileType.Shim, {col}), [{params['ShimWgtBufferIdx']}], {params.get('ShimWgtSize',0) + params['dims'].qdq_bytes}, [], [TransferParams(shim_dma({col}, DmaDir.MM2S, 0), {params['dims'].qdq_bytes // 4}, offset={params.get('ShimWgtSize', 0) // 4})])]\n"
        return code
    
    def gen_shim_ifm(self, params):
        """Emit the shim MM2S transfer lines for the input feature maps.

        The IFM1 section (channel 0) is skipped entirely for ``CASCADE_ADD``;
        the IFM2 section (channel 1) is always emitted.  Buffer index and
        offset follow the same rule for both inputs: by default both read from
        ``ShimActBufferIdx`` with IFM2 placed after IFM1; an input flagged by
        ``self.ifmA_const`` (checked first) or ``self.ifmB_const`` is read from
        ``ShimWgtBufferIdx`` at ``ShimWgtSize + qdq_bytes`` and the other input
        starts at offset 0.  ``buffer_offset=`` is only emitted when non-zero.

        Args:
            params: dict read for ``dims``, ``in_ifmA_shape``, ``WgtDtype``,
                ``ShimActBufferIdx``, ``ShimWgtBufferIdx`` and optionally
                ``ShimWgtSize``.

        Returns:
            str: the generated lines.
        """
        off_a = 0
        off_b = np.prod(params['in_ifmA_shape']).astype(int) * getByteCount(params["WgtDtype"])
        buf_a = params['ShimActBufferIdx']
        buf_b = params['ShimActBufferIdx']
        if self.ifmA_const:
            buf_a = params['ShimWgtBufferIdx']
            off_a = params.get('ShimWgtSize',0) + params['dims'].qdq_bytes
            off_b = 0
        elif self.ifmB_const:
            buf_b = params['ShimWgtBufferIdx']
            off_a = 0
            off_b = params.get('ShimWgtSize',0) + params['dims'].qdq_bytes

        def closing(offset):
            # Tail of the generated call; buffer_offset only when non-zero.
            if offset == 0:
                return ")]\n"
            return f", buffer_offset={offset})]\n"

        parts = []
        if self.op_select != Add_SupportedOps.CASCADE_ADD:
            parts.append("    #SHIM ACTIVATION TRANSFERS IFM1\n")
            for col in self.valid_cols:
                parts.append(
                    f"    shim_transfers += [generate_shim_data_transfer({self.ShimActRepeat}, shim_dma({col}, DmaDir.MM2S, 0), {buf_a}, '{self.act_shim_memory1}', '{self.act_shim_mm2s1[col][0]}', {params['dims'].ifmA_bits}"
                )
                parts.append(closing(off_a))
        parts.append("    #SHIM ACTIVATION TRANSFERS IFM2\n")
        for col in self.valid_cols:
            parts.append(
                f"    shim_transfers += [generate_shim_data_transfer({self.ShimActRepeat}, shim_dma({col}, DmaDir.MM2S, 1), {buf_b}, '{self.act_shim_memory2}', '{self.act_shim_mm2s2[col][0]}', {params['dims'].ifmB_bits}"
            )
            parts.append(closing(off_b))
        return "".join(parts)
    
    def gen_shim_ofm(self, params):
        code = ""
        for col in self.valid_cols:
            code += f"    shim_transfers += [generate_shim_data_transfer({self.ShimOutRepeat}, shim_dma({col}, DmaDir.S2MM, 0), {params['ShimOutBufferIdx']}, '{self.out_shim_memory}', '{self.out_shim_s2mm[col][0]}', {params['dims'].ofm_bits})]\n"
        return code
    #Shim transfers--------------------->>>
    #Mem transfers
    def gen_memtile_prm_pattern(self, params):
        code = f"""    memtile_transfers += [
            DataTransfer({self.ShimParamRepeat}, AieTile(TileType.Memtile, col), [{params['MemtilePrmAddr']}], {params['ParamSize'] * 4},
                [TransferParams(memtile_dma(col, DmaDir.S2MM, {params.get('ShimParamChannelId', 0)}), {params['ParamSize']})],
                [TransferParams(memtile_dma(col, DmaDir.MM2S, row), {params['ParamSize'] // 4}, offset=row * {params['ParamSize'] // 4})
                for row in range({params['dims'].aie_rows})]
            ) for col in range({params['dims'].aie_cols})
        ]
"""     
        code += f"""    memtile_transfers += [
            DataTransfer({self.ShimQdqPrmRepeat}, AieTile(TileType.Memtile, col), [{params['MemtileQdqAddr']}], {params['dims'].qdq_bytes},
                [TransferParams(memtile_dma(col, DmaDir.S2MM, {params.get('ShimQdqChannelId',0)}), {params['dims'].qdq_bytes//4})],
                [TransferParams(memtile_dma(col, DmaDir.MM2S, 4), {params['dims'].qdq_bytes//4})]
            ) for col in range(0, {params['dims'].aie_cols}, 2)
        ]
"""
        return code
    
    def gen_memtile_ifm(self, params):
        pingAddr = params["MemtileIfmaPingAddr"]
        # since merging IFMA and IFMB, recalculating the pongAddr based on size
        # Buffer allocater provides 
        # PingA | PongA | PingB | PongB  #provided alloc by buffallocater continuous memory in memtile
        # PingA | PingB | PongA | PongB  #Reallocating since enough memory space is available
        pongAddr = pingAddr + params["MemtileActSizeIfmB"] + params["MemtileActSizeIfmA"]
        IfmAAddr = (
            [pingAddr, pongAddr]
            if params["ping_pong_enable"]["ifmA"] and params['numBatches'] == None
            else [params["MemtileIfmaPingAddr"]]
        )
        code = ""
        if self.op_select == Add_SupportedOps.CASCADE_ADD:
            for col in self.valid_cols:
                if col in self.reenqueue_cols:
                    loop_len = len(self.mem_loops[col])
                    memtile_s2mm      = [f'{s}' for s in self.act_memtile_s2mm2[col]]
                    s2mm_transfer  = f"pack_transfers(memtile_dma({col}, DmaDir.S2MM, 1), ['{self.act_memtile_memory}'] * {loop_len}, {memtile_s2mm} , [1] * {loop_len}, {params['dims'].ifmB_bits})"
                else:
                    s2mm_transfer  = f"generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, 1), '{self.act_memtile_memory2}', '{self.act_memtile_s2mm2[col][0]}', {params['dims'].ifmB_bits})"
                mm2s_transfer = "\n                 ".join([f"generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, {row}), '{self.act_memtile_memory2}', '{self.act_memtile_mm2s2[row][0]}', {params['dims'].ifmB_bits})," for row in range(params['dims'].aie_rows)])
                code += f"""    memtile_transfers += [
                DataTransfer({self.mem_loops[col]}, AieTile(TileType.Memtile, {col}), {IfmAAddr}, {params["MemtileActSizeIfmB"] + params["MemtileActSizeIfmA"]},
                    [{s2mm_transfer}],
                    [{mm2s_transfer}],
                    sync_strategy=SyncStrategy.Parallel_1_to_N
                )]
"""
        else:
            for col in self.valid_cols:
                if col in self.reenqueue_cols:
                    loop_len = len(self.mem_loops[col])
                    memtile_s2mm1   = [f'{s}' for s in self.act_memtile_s2mm1[col]]
                    memtile_s2mm2   = [f'{s}' for s in self.act_memtile_s2mm2[col]]
                    s2mm_transfer  = f"pack_transfers(memtile_dma({col}, DmaDir.S2MM, 0), ['{self.act_memtile_memory1}'] * {loop_len}, {memtile_s2mm1}, [1] * {loop_len}, {params['dims'].ifmA_bits}),\n"
                    s2mm_transfer += f"                 pack_transfers(memtile_dma({col}, DmaDir.S2MM, 1), ['{self.act_memtile_memory2}'] * {loop_len}, {memtile_s2mm2}, [1] * {loop_len}, {params['dims'].ifmB_bits}, buffer_offset={self.ifmB_buff_offset})"
                else:
                    s2mm_transfer  = f"generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, 0), '{self.act_memtile_memory1}', '{self.act_memtile_s2mm1[col][0]}', {params['dims'].ifmA_bits}),\n"
                    s2mm_transfer += f"                     generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, 1), '{self.act_memtile_memory2}', '{self.act_memtile_s2mm2[col][0]}', {params['dims'].ifmB_bits}, buffer_offset={self.ifmB_buff_offset})"
                mm2s_transfer = "\n                     ".join([f"generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, {row}), '{self.act_memtile_memory1}', '{self.act_memtile_mm2s1[row][0]}', {params['dims'].ifmA_bits})," for row in range(params['dims'].aie_rows)])
                mm2s_transfer += "\n                     "
                mm2s_transfer += "\n                     ".join([f"generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, {row}), '{self.act_memtile_memory2}', '{self.act_memtile_mm2s2[row][0]}', {params['dims'].ifmB_bits}, buffer_offset={self.ifmB_buff_offset})," for row in range(params['dims'].aie_rows)])

                code += f"""    memtile_transfers += [
                DataTransfer({self.mem_loops[col]}, AieTile(TileType.Memtile, {col}), {IfmAAddr}, {params["MemtileActSizeIfmB"] + params["MemtileActSizeIfmA"]},
                    [{s2mm_transfer}],
                    [{mm2s_transfer}]
                )]
"""
        return code

    def gen_memtile_ofm(self, params):
        ofm_Addr = [params['MemtileOfmPingAddr']]
        code = ""
        for col in self.valid_cols:
            loop = self.mem_loops[col] if self.op_select != Add_SupportedOps.CASCADE_ADD else [x // params['numBatches'][0] for x in self.mem_loops[col]]
            if col in self.reenqueue_cols:
                loop_len = len(self.mem_loops[col])
                memtile_mm2s      = [f'{s}' for s in self.out_memtile_mm2s[col]]
                mm2s_transfer = f"pack_transfers(memtile_dma({col}, DmaDir.MM2S, 5), ['{self.out_memtile_memory}'] * {loop_len}, {memtile_mm2s}, [1] * {loop_len}, {params['dims'].ofm_bits})"
            else:
                mm2s_transfer = f"generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, 5), '{self.out_memtile_memory}', '{self.out_memtile_mm2s[col][0]}', {params['dims'].ofm_bits})"
            s2mm_transfer = "\n                     ".join([f"generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, {2 + row}), '{self.out_memtile_memory}', '{self.out_memtile_s2mm[row][0]}', {params['dims'].ofm_bits})," for row in range(params['dims'].aie_rows)])

            code += f"""    memtile_transfers += [
                DataTransfer({loop}, AieTile(TileType.Memtile, {col}), {ofm_Addr}, {params["MemtileOutSize"]},
                    [{s2mm_transfer}],
                    [{mm2s_transfer}],
                    sync_strategy=SyncStrategy.Parallel_N_to_1
                )]
"""
        return code

    #Mem transfers--------------------->>>
    
    #get combined shim transfers
    def gen_shim_transfers(self, params):
        """Return the complete shim-transfer section of the generated script.

        The section initialises ``shim_transfers`` and then appends, in this
        order, the layer-parameter, QDQ-parameter, input and output transfer
        text produced by the corresponding ``gen_shim_*`` methods.

        Args:
            params: Dictionary forwarded unchanged to every sub-generator.

        Returns:
            str: the concatenated section text.
        """
        sections = (
            "\n    shim_transfers = []\n",
            "    #SHIM PARAM TRANSFERS\n",
            self.gen_shim_prm_pattern(params),
            "    #SHIM QDQ PARAM TRANSFERS\n",
            self.gen_shim_qdq_prm_pattern(params),
            self.gen_shim_ifm(params),
            "    #SHIM ACTIVATION TRANSFERS OFM\n",
            self.gen_shim_ofm(params),
        )
        return "".join(sections)
    
    #get combined mem transfers
    def gen_mem_transfers(self, params):
        """Return the complete memtile-transfer section of the generated script.

        The section initialises ``memtile_transfers`` and then appends, in this
        order, the parameter, input-activation and output-activation transfer
        text produced by the corresponding ``gen_memtile_*`` methods.

        Args:
            params: Dictionary forwarded unchanged to every sub-generator.

        Returns:
            str: the concatenated section text.
        """
        sections = (
            "    memtile_transfers = []\n",
            "    #MEMTILE PARAM TRANSFERS\n",
            self.gen_memtile_prm_pattern(params),
            "    #MEMTILE ACTIVATION TRANSFERS IFMs\n",
            self.gen_memtile_ifm(params),
            "    #MEMTILE ACTIVATION TRANSFERS OFMs\n",
            self.gen_memtile_ofm(params),
        )
        return "".join(sections)

    def gen_dma_pattern_code(self, params):
        """Return the core-instruction, memtile and shim sections as one text.

        The three sections are generated in that order and separated by a
        single newline character.

        Args:
            params: Dictionary forwarded unchanged to the three generators.

        Returns:
            str: ``core + "\\n" + memtile + "\\n" + shim``.
        """
        core_section = self.gen_core_instrs(params)
        memtile_section = self.gen_mem_transfers(params)
        shim_section = self.gen_shim_transfers(params)
        return "\n".join((core_section, memtile_section, shim_section))

# Registry of dataflow generators keyed by operation name.  MatAdd_Base picks
# an entry using the third "_"-separated token of params['op_ver'].
# Both generator objects are created once, when this module is imported.
operations = {
    "EleWise": EleWise(),
    "BroadCast": BroadCast()
}

class MatAdd_Base(BaseTemplate):
    """Dataflow-script template for the element-wise / broadcast add operators.

    ``gen_code`` concatenates source-text fragments into one script.  The
    fragments come from methods called on ``self`` that are not defined in
    this class (presumably provided by ``BaseTemplate`` - confirm) and from
    the generator selected out of the module-level ``operations`` registry.
    The ``params`` dictionary it consumes is built by
    ``self.prm.gen_dma_params``.
    """

    def __init__(self, _data=None):
        """Store ``_data`` and instantiate the nested helper classes.

        Args:
            _data: Opaque payload kept on ``self.data``; not read by this class.
        """
        super().__init__()
        self.data        = _data
        # The nested class names are intentionally shadowed by instances on
        # this object: after this point ``self.helper_func`` is an instance.
        self.helper_func = self.helper_func()
        self.prm         = self.params()

    def gen_code(self, params):
        """Return the full generated dataflow script as a single string.

        Args:
            params: Dictionary as built by ``params.gen_dma_params``.  Keys read
                here: ``op_ver``, ``op_name``, ``template_meta_data``
                (``overlay``), ``reenqueues``, ``dims`` (``aie_cols``,
                ``aie_rows``) and ``CoreStackAddr``.

        Returns:
            str: the concatenated script text.

        Raises:
            IndexError: if ``op_ver`` has fewer than three "_"-separated tokens.
            KeyError: if the third token is not a key of ``operations``.
        """
        logging.info("Generate dataflow for Elw/Broadcast operation")
        # The third "_"-separated token of op_ver names the generator to use.
        ops        = params['op_ver'].split("_")[2]
        data_flow  = ""
        data_flow += self.gen_headers(params['template_meta_data']['overlay'])
        if params['reenqueues']:
            # Packing helpers are only emitted when re-enqueueing is requested.
            data_flow += self.gen_pack_TransferParams()
            data_flow += self.generate_packed_shim_data_transfer()
        data_flow += self.gen_layer_params_generator()
        data_flow += operations[ops].gen_core_instrs_func() + "\n"
        data_flow += self.gen_dataflow(params)
        data_flow += operations[ops].gen_dma_pattern_code(params)
        data_flow += self.gen_kernel_name(params["op_name"], params["op_ver"])
        data_flow += self.gen_connection_run_compile(
            params["template_meta_data"]["overlay"],
            params["dims"].aie_cols,
            params["dims"].aie_rows,
            params["CoreStackAddr"][0],
        )
        data_flow += self.gen_main_func()

        return data_flow

    def gen_layer_params_generator(self):
        """Return the source text of the generated ``matadd_params`` function.

        The emitted function serialises its nine integer arguments, in order,
        as unsigned little-endian 2-byte values and concatenates them.
        """
        return """
def matadd_params(n: int, Msubv: int, Nsubv: int, qdq_addr: int, matA_addr: int, matB_addr: int, isFused: int, do_neg: int, itr: int):
    return (n.to_bytes(length=2, byteorder='little', signed=False)
            +Msubv.to_bytes(length=2, byteorder='little', signed=False)
            +Nsubv.to_bytes(length=2, byteorder='little', signed=False)
            +qdq_addr.to_bytes(length=2, byteorder='little', signed=False)
            +matA_addr.to_bytes(length=2, byteorder='little', signed=False)
            +matB_addr.to_bytes(length=2, byteorder='little', signed=False)
            +isFused.to_bytes(length=2, byteorder='little', signed=False)
            +do_neg.to_bytes(length=2, byteorder='little', signed=False)
            +itr.to_bytes(length=2, byteorder='little', signed=False)
            )
"""

    class helper_func:
        """Small stateless helpers shared by the template and ``params``."""

        def col_index(self, dims: MataddDims, array: int, col: int) -> int:
            """Return the flat column index of ``col`` inside ``array``.

            Both indices are checked against ``dims.aie_cols`` and
            ``dims.aie_arrays`` with ``assert`` (so the checks vanish under
            ``python -O``).
            """
            assert 0 <= col < dims.aie_cols
            assert 0 <= array < dims.aie_arrays
            return (array * dims.aie_cols) + col

        def reuse_chain_length(
            self, reuse_ratio: int, num_consumers: int, max_chain_length: int = 4
        ) -> int:
            """Return the shortest chain length that satisfies the lock limit.

            A length ``i`` in ``1..max_chain_length`` is valid when it divides
            ``reuse_ratio`` and ``(reuse_ratio // i) * num_consumers`` does not
            exceed 63.

            Raises:
                RuntimeError: if no length in the range is valid.
            """
            max_lock_value = 63
            for i in range(1, max_chain_length + 1):
                if reuse_ratio % i:
                    continue
                if (reuse_ratio // i) * num_consumers <= max_lock_value:
                    return i
            raise RuntimeError("Failed to allocate reuse chain!")

        def help_info(self) -> str:
            """Return the support matrix of this template as printable text.

            The table used to be built and then discarded (the method returned
            ``None``); it is now returned to the caller.
            """
            code = """
--------------------------------------------------------
| QDQ Add    |             |            IFM           | 
-------------------------------------------------------
|            |             | Pin    | Stream | Full   |
| No Padding | No PingPong | No     | Yes    | No     |
|            | Ping Pong   | No     | Yes    | No     |
| M Padding  | No PingPong | No     | No     | No     |
|            | Ping Pong   | No     | No     | No     |
| K Padding  | No PingPong | No     | No     | No     |
|            | Ping Pong   | No     | No     | No     |
-------------------------------------------------------
"""
            return code

    class params(helper_func):
        """Builds the ``params`` dictionary consumed by the code generators."""

        def gen_dma_params(self, _pipeline_data):
            """Translate the buffer-allocator output into a ``params`` dict.

            Args:
                _pipeline_data: Object whose ``info`` mapping holds the
                    buffer-allocator result under ``"BuffAllocator"``.  That
                    result is indexed with ``const.BufAllocator_Idx`` values.

            Returns:
                dict: ``{"dims": MataddDims, ...}`` populated from the
                allocator output and then passed to the selected generator's
                ``gen_dma_transfers`` (whose return value is ignored, so it
                presumably extends the dictionary in place - confirm).

            Raises:
                IndexError: if ``op_type`` has fewer than three "_"-separated
                    tokens.
                KeyError: if an expected allocator key is missing or the op
                    token is not a key of ``operations``.
            """
            def update_buffer_alloc_to_params(params, buff):
                """Copy allocator results and fixed settings into ``params``."""
                buff_prm = buff[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]
                # Fixed sizes and shim buffer indices.
                params["ParamSize"] = const.PARAM_SIZE
                params["CoreBankSize"] = const.CORE_BANK_SIZE
                params["ShimOutBufferIdx"] = 0
                params["ShimActBufferIdx"] = 1
                params["ShimWgtBufferIdx"] = 2
                params["ShimPrmBufferIdx"] = 3
                # Address/size tables are merged wholesale; later updates win
                # over earlier keys of the same name.
                params.update(buff[const.BufAllocator_Idx.CORE_TILE_ADDR_IDX.value])
                params.update(buff[const.BufAllocator_Idx.CORE_TILE_SIZE_IDX.value])
                params.update(buff[const.BufAllocator_Idx.MEM_TILE_ADDR_IDX.value])
                # Scheduling attributes.
                params["ping_pong_enable"] = buff_prm["sch_attr"].ping_pong_enable
                params["ActDataFlow"] = buff_prm["sch_attr"].dataflow_mode["ifmB"]
                params["ActPingPong"] = buff_prm["sch_attr"].ping_pong_enable["ifmB"]
                params["OfmPingPong"] = buff_prm["sch_attr"].ping_pong_enable["ofm"]
                # Tile sizes.
                params["ShimActSizeIfmA"] = buff_prm["ifmA_shim_tile_size"]
                params["ShimActSizeIfmB"] = buff_prm["ifmB_shim_tile_size"]
                params["ShimOutSize"] = buff_prm["ofm_shim_tile_size"]
                params["MemtileActSizeIfmA"] = buff_prm["ifmA_mem_tile_size"]
                params["MemtileActSizeIfmB"] = buff_prm["ifmB_mem_tile_size"]
                params["MemtileOutSize"] = buff_prm["ofm_mem_tile_size"]
                # Data types: "Wgt" is taken from ifmA and "Act" from ifmB.
                params["WgtDtype"] = buff_prm["ifmA"].dtype
                params["ActDtype"] = buff_prm["ifmB"].dtype
                params["OfmDtype"] = buff_prm["ofm"].dtype
                # Operator identification.
                params["op_name"] = buff_prm["orig_op_type"]
                params["op_ver"] = buff_prm["op_type"]
                params["template_meta_data"] = {
                    "op_type": buff_prm["orig_op_type"],
                    "mode": buff_prm["mode"],
                    "overlay": buff_prm["overlay"],
                }
                # DMA channel assignments.
                params['ShimParamChannelId'] = 1
                params['param_channel_id']   = 0
                params['ShimQdqChannelId']   = 0
                params['CoreQdqChId']        = 1
                params['ifmA_param_type'] = buff_prm['ifmA_param_type']
                params['ifmB_param_type'] = buff_prm['ifmB_param_type']
                # Total padding per input: the sum of its pad list.
                params["padding_ifmA"] = sum(buff_prm["padding"][0]["pad_ifmA_x"])
                params["padding_ifmB"] = sum(buff_prm["padding"][1]["pad_ifmB_x"])
                params['numBatches'] = buff_prm["num_batches"]
                params['ifmA_mode'] = buff_prm['ifmA_mode']
                params['ifmB_mode'] = buff_prm['ifmB_mode']
                params['ofm_mode'] = buff_prm['ofm_mode']
                params['in_ifmA_shape'] = buff_prm['in_ifmA_shape']
                params['in_ifmB_shape'] = buff_prm['in_ifmB_shape']
                params['out_ofm_shape'] = buff_prm['out_ofm_shape']
                params['multi_ch_batch_bcast'] = buff_prm['multi_ch_batch_bcast']
                # Only the first entry of each core loop list is kept.
                params['core_iters'] = [buff_prm['core_outer_loop'][0], buff_prm['core_inner_loop'][0]]
                params['inner_dim_is_1'] = buff_prm['inner_dim_is_1']
                params['ifmA_scale_factor'] = buff_prm['ifmA_scale_factor']
                params['ifmB_scale_factor'] = buff_prm['ifmB_scale_factor']
                params['ofm_scale_factor'] = buff_prm['ofm_scale_factor']

            buff     = _pipeline_data.info.get("BuffAllocator")
            buff_prm = buff[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]
            # Keyword construction keeps the mapping explicit and independent
            # of the field order of MataddDims.
            dims = MataddDims(
                M_subv=buff_prm["M_subV"],
                aie_rows=buff_prm["aie_rows"],
                aie_cols=buff_prm["aie_cols"],
                aie_arrays=buff_prm["aie_arrays"],
                outer_loop=buff_prm["outer_loop"],
                inner_loop=buff_prm["inner_loop"],
                ifmA_bits=buff_prm["ifmA_bits"],
                ifmB_bits=buff_prm["ifmB_bits"],
                ofm_bits=buff_prm["ofm_bits"],
                qdq_bytes=buff_prm["qdq_bytes"],
            )
            params = {"dims": dims}
            update_buffer_alloc_to_params(params, buff)
            # The third "_"-separated token of op_ver names the generator.
            ops    = params['op_ver'].split("_")[2]
            operations[ops].gen_dma_transfers(params)
            return params


class M1N16(MatAdd_Base):
    """MatAdd template variant that adds no behaviour of its own.

    Everything is inherited from ``MatAdd_Base``.  The class name presumably
    encodes the tiling mode this variant is selected for - confirm against the
    code that looks templates up by name.
    """

    def __init__(self, _data=None):
        """Initialise through ``MatAdd_Base``.

        The base initialiser already stores ``_data`` on ``self.data`` and
        creates ``self.helper_func`` and ``self.prm``, so ``_data`` is simply
        forwarded instead of being assigned a second time here.
        """
        super().__init__(_data)


class M1N32(MatAdd_Base):
    """MatAdd template variant that adds no behaviour of its own.

    Everything is inherited from ``MatAdd_Base``.  The class name presumably
    encodes the tiling mode this variant is selected for - confirm against the
    code that looks templates up by name.
    """

    def __init__(self, _data=None):
        """Initialise through ``MatAdd_Base``.

        The base initialiser already stores ``_data`` on ``self.data`` and
        creates ``self.helper_func`` and ``self.prm``, so ``_data`` is simply
        forwarded instead of being assigned a second time here.
        """
        super().__init__(_data)
