import sys
import os
import logging
import json
import OGOAT.src.Scheduling_Engine.infra.scheduler_utils as utils
from OGOAT.src.Scheduling_Engine.infra import const
import copy
import dataclasses
import numpy as np
import math

from OGOAT.src.Scheduling_Engine.code_gen import gen_kernel_param

@dataclasses.dataclass(slots=True)
class BaseDims():
    # Dimension, tiling and buffer-size parameters read by BaseTemplate and
    # its subclasses through ``params['dims']``. Meanings marked
    # "presumably" are inferred from names only -- NOTE(review): confirm
    # against the code that fills these in.
    #
    # slots=True means a plain BaseDims instance cannot take attributes that
    # are not declared here. BaseTemplate also reads dims.B_itr, Mpad_ifm,
    # Kpad_ifm, Kpad_wgt, Npad_wgt, Munpad and Nunpad, so the object it
    # receives is presumably a subclass that declares them.

    # split[0] == 32 selects a separate ConfigBuffer sequence in
    # gen_core_instr.
    split: list
    # Problem sizes (presumably batch and the M x K by K x N GEMM sizes).
    B: int
    M: int
    K: int
    N: int
    # Subvolume sizes (presumably per core); M_subv * N_subv is passed as
    # the element count of the fused PWLA kernel call.
    B_subv: int
    M_subv: int
    K_subv: int
    N_subv: int
    # AIE array geometry; core instructions are generated for every
    # (col, row) in aie_cols x aie_rows.
    aie_rows: int
    aie_cols: int
    aie_arrays: int
    # Element bit widths; act_bits (8 or 16) also selects the GEMM kernel.
    act_bits: int
    out_bits: int
    # Size of the QDQ parameter buffer configured on each core and memtile.
    qdq_bytes: int
    # The Tm/Tn/Tk refers to the trip-count of M/N/K-loop of the loop-nest
    # in super-kernel.
    Tm: int
    Tn: int
    Tk: int
    # Weight subvolume shape and per-subvolume buffer sizes (presumably in
    # bytes, per the names).
    wgt_subv_rows: int
    wgt_subv_cols: int
    act_subv_bytes: int
    wgt_subv_bytes: int
    out_subv_bytes: int
    sum_subv_bytes: int
    tdm_subv_bytes: int
    # Remaining bit widths; wgt_bits is passed to the memtile weight
    # transfer params.
    wgt_bits: int
    bits_per_byte: int
    bias_bits: int
    tdm_bits: int

class BaseTemplate():
    def gen_instr(self, params=None):
        """Abstract hook: derived templates generate their instructions here.

        The base implementation only reports a failed sanity check.
        """
        missing_impl_msg = "Needs to be implemented in dervied class"
        utils.sanity_check(False, missing_impl_msg)

    def gen_print_attr(self, op_params):
        """Abstract hook expected to be overridden by derived templates.

        The base implementation only reports a failed sanity check.
        """
        missing_impl_msg = "Expect derived class to implement this"
        utils.sanity_check(False, missing_impl_msg)

    def helper_func(self):
        """Optional hook for derived templates; the base version does nothing."""

    def shorten_val(self, a, ref_in=None, elements=1):
        """Render a sequence as a compact Python list expression string.

        Consecutive identical groups of ``elements`` items are run-length
        encoded as ``[...]*cnt`` and the pieces are joined with ``' + '``,
        e.g. ``[1, 1, 1, 1]`` -> ``"[1]*4"``. For ``list`` inputs whose
        length divides evenly, the method recurses with the group size
        doubled (up to 16) and keeps an encoding only if it has strictly
        fewer pieces than the previous one.

        Args:
            a: Values to encode (list, tuple or array). Larger group sizes
                are only tried when ``a`` is a ``list``.
            ref_in: Pieces produced at the previous group size, used during
                recursion to keep the shorter encoding. ``None`` or empty
                means there is nothing to compare against.
            elements: Current group size.

        Returns:
            str: An expression intended to evaluate back to the values of
            ``a`` (numbers unquoted, strings quoted).
        """
        def append_val(val, cnt, elements):
            if cnt > 1:
                return [f"{val.tolist()}*{cnt}"]
            if elements > 1:
                # Emit each item as its own one-element list literal. Going
                # through tolist()/str() keeps numbers unquoted and strings
                # quoted, so the joined expression evaluates to the original
                # values (formatting as "['{x}']" turned numbers into str).
                return [str([x]) for x in val.tolist()]
            return [f"{val.tolist()}"]

        val = None
        cnt = 0
        out = []
        in_list = np.array(a).reshape(-1, elements)
        for idx, x in enumerate(in_list):
            if idx == 0:
                val = x
                cnt = 1
            elif all(x == val):
                cnt += 1
            else:
                out.extend(append_val(val, cnt, elements))
                cnt = 1
                val = x
            if idx == len(in_list) - 1:
                out.extend(append_val(val, cnt, elements))
        # Keep the previous (smaller group size) encoding unless this one is
        # strictly shorter.
        if ref_in and len(out) >= len(ref_in):
            out = ref_in
        is_2x = len(a) % (2 * elements) == 0
        if elements <= 8 and elements < (len(a) if isinstance(a, list) else 1) and is_2x:
            return BaseTemplate.shorten_val(self, a, out, 2 * elements)
        return ' + '.join(out)
    
    def gen_main_func(self):
        """Return the ``main`` / ``__main__`` entry-point section of a generated script.

        The emitted ``main`` calls ``generate_dataflow(backend)``. The
        command-line option ``-b/--backend`` (``Adf`` or ``TxnHostPatch``,
        default ``Adf``) selects the ``BackEnd`` member.
        """
        entry_lines = (
            "",
            "def main(backend=BackEnd.Adf):",
            "    generate_dataflow(backend)",
            "",
            "if __name__ == '__main__':",
            "    parser = argparse.ArgumentParser()",
            "    parser.add_argument('-b','--backend', required=False, choices=['Adf', 'TxnHostPatch'], default='Adf')",
            "    args = parser.parse_args()",
            '    main(eval(f"BackEnd.{args.backend}"))',
            "",
            "",
        )
        return "\n".join(entry_lines)

    def gen_kernel_name(self, op_code, op_ver):
        """Return generated source that assigns ``kernel_names`` and ``kernel_includes``.

        When ``program_args['combine_kernels']`` of the BuffAllocator entry
        at ``const.KERNEL_IDX`` is set, it is treated as the path of a JSON
        file whose 'kernel_list' and 'kernel_include' entries are used.
        Otherwise both come from
        ``gen_kernel_param.return_kernel_path(op_code, op_ver)``.
        """
        program_args = self.data.info.get('BuffAllocator')[const.KERNEL_IDX].get('program_args')
        combined_json_path = program_args.get('combine_kernels')
        if combined_json_path:
            with open(combined_json_path) as fh:
                combined = json.load(fh)
            kernels = combined['kernel_list']
            includes = combined['kernel_include']
        else:
            kernels, includes = gen_kernel_param.return_kernel_path(op_code, op_ver)
        return f"""
    kernel_names = {kernels}

    kernel_includes = {includes}"""

    def gen_headers(self, overlay, gemm_prms=0):
        """Return the import/preamble section of a generated dataflow script.

        Args:
            overlay: Overlay name spliced into the
                ``overlay_<overlay>_dma_connections`` import.
            gemm_prms: Selects one extra kernel-params import:
                1 -> conv gemm params, 2 -> MHA smxv params,
                3 -> direct-conv int8x8 params; other values add nothing.

        Returns:
            str: Generated source. ``config.ENABLE_FAST_PM = False`` is
            appended when ``self.data.info['fast_pm'] == 0``.
        """
        kernel_param_imports = {
            1: "from kernels.conv.gemm_params import generate_layer_kernel_params, GemmSubvDims, OPMode\n",
            2: "from kernels.mha_qdq.mha_params_smxv import generate_layer_kernel_params, MhaSubvDims, OPMode\n",
            3: "from kernels.conv.direct_conv_int8x8_generic.direct_conv_int8x8_generic_params import params\n",
        }
        preamble = f"""import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
import argparse
sys.path.append(os.path.join(CURRDIR, '..', '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))

from typing import Tuple, List, Type

from os import path, getcwd

from dmacompiler import \\
    OverlayShape, DataTransfer, TransferParams, SyncStrategy, BackEnd, \\
    DmaChannel, DmaDir, AieDma, AieTile, TileType, DmaConnection, \\
    memtile_dma, shim_dma, core_dma, \\
    ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, CoreInstr, \\
    compute_buffer_size, \\
    generate_transfer_params, \\
    generate_shim_data_transfer, \\
    run_layer_compilation, \\
    set_dev_gen, DevGen, config

from dataflow.dataflow_common import overlay_{overlay}_dma_connections
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True

"""
        fast_pm_line = "config.ENABLE_FAST_PM = False\n" if self.data.info['fast_pm'] == 0 else ""
        return preamble + fast_pm_line + kernel_param_imports.get(gemm_prms, "")

    def gen_connection_run_compile(self, overlay_name, aie_cols, aie_rows, CoreStackAddr, param_channel_id=0, enable_task_queue_optimization=True):
        """Return generated source that sets up the DMA connections and calls ``run_layer_compilation``.

        The overlay shape is ``OverlayShape(aie_cols, aie_rows)``. The
        remaining arguments are emitted verbatim as the ``core_stack_addr``,
        ``param_channel_id`` and ``enable_task_queue_optimization`` keyword
        arguments.
        """
        positional_args = (
            "overlay_shape",
            "kernel_names",
            "kernel_includes",
            "core_instrs",
            "memtile_transfers",
            "shim_transfers",
            "dma_connections",
            "back_end",
        )
        keyword_lines = [
            f"        core_stack_addr = {CoreStackAddr}",
            f"        param_channel_id = {param_channel_id}",
            f"        enable_task_queue_optimization = {enable_task_queue_optimization}",
        ]
        lines = [
            "",
            f"    dma_connections = overlay_{overlay_name}_dma_connections()",
            f"    overlay_shape = OverlayShape({aie_cols}, {aie_rows})",
            "    run_layer_compilation(",
        ]
        lines.extend(f"        {arg}," for arg in positional_args)
        lines.extend(line + "," for line in keyword_lines[:-1])
        lines.append(keyword_lines[-1] + ")")
        return "\n".join(lines) + "\n"

    def gen_pack_TransferParams(self, OpDims=0, ifm_repeat_counts=0):
        """Return the source of the generated ``pack_transfers`` helper.

        The emitted helper builds one ``TransferParams`` from parallel lists
        of memory/tiling formats. Each format's length/offset/step/wrap/
        padding is repeated ``tiling_iters[i]`` times, and padding is enabled
        only for MM2S DMAs.

        ``OpDims`` and ``ifm_repeat_counts`` are accepted for interface
        compatibility and do not affect the output.
        """
        helper_src = """
def pack_transfers(
    dma: AieDma,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iters: List[int],
    bits_per_elem: int,
    buffer_offset: int = 0,
) -> TransferParams:
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iters)
    def pack(items: list) -> list:
        assert len(items) == len(tiling_iters)
        res = []
        for item, num in zip(items, tiling_iters):
            res += [item] * num
        return res
    num_fmts = len(tiling_fmts)
    params = [
        generate_transfer_params(
            dma,
            memory_fmts[i],
            tiling_fmts[i],
            bits_per_block=bits_per_elem,
            enable_padding=(dma.channel.dir == DmaDir.MM2S),
            buffer_offset = buffer_offset,
        ) for i in range(num_fmts)
    ]
    packed_param = TransferParams(
        dma,
        pack([param.length_i(0) for param in params]),
        offset=pack([param.offset_i(0) for param in params]),
        step=pack([param.step_i(0) for param in params]),
        wrap=pack([param.wrap_i(0) for param in params]),
        padding=pack([param.padding_i(0) for param in params]),
    )
    return packed_param

"""
        return helper_src

    def generate_packed_shim_data_transfer(self):
        """Return the source of a generated ``generate_packed_shim_data_transfer`` helper.

        The emitted function builds one shim ``DataTransfer`` that reuses a
        BD with several tiling formats. It first generates transfer params
        without ``iter_step``. If that raises, it regenerates every format
        with ``use_iter_step=True`` and packs each position of the resulting
        chains separately, scaling the repeat counts by the returned repeat
        coefficients.

        The emitted code makes the fallback start from empty lists, so that
        params appended by a partially successful fast path cannot leak into
        the packed result. It catches ``Exception`` rather than using a bare
        ``except`` and builds a new repeat-count list instead of mutating the
        caller's.

        Returns:
            str: Python source text for the generated dataflow script.
        """
        return """
def generate_packed_shim_data_transfer(
    repeat_counts: List[int],
    dma: AieDma,
    shim_buffer_idx: int,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iter_nums: List[int],
    bits_per_elem: int,
    max_chain_length: int = 4,
    buffer_offset: int = 0
) -> DataTransfer:
    '''
    Reconfigures a BD with different transfer
    params at the shim for poll and re-enqueue
    '''
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iter_nums)
    def pack(items: list) -> list:
        assert len(items) == len(tiling_iter_nums)
        res = []
        for item, num in zip(items, tiling_iter_nums):
            res += [item] * num
        return res
    def try_or(list_, idx, val, default=None, expected_exc=(Exception,)):
        try:
            return eval(f'list_[idx].{val}')
        except expected_exc:
            return eval(f'default.{val}')
    num_fmts = len(tiling_fmts)
    params = []
    packed_params = []
    repeat_coeff_iter = [0] * sum(tiling_iter_nums)
    empty_bd = generate_transfer_params(dma, 'Bytes:32', 'Bytes:0:0', bits_per_elem)
    try:
        '''
        Try to generate transfer params without iter_step, this is the fast path for simple cases
        '''
        adr = 0
        for i in range(num_fmts):
            transfer_chain = generate_transfer_params(
                    dma,
                    memory_fmts[i],
                    tiling_fmts[i],
                    bits_per_block=bits_per_elem,
                    enable_padding=False,
                    use_iter_step=False,
                    max_chain_length=max_chain_length,
                    buffer_offset= buffer_offset
            )
            repeat_coeff_iter[adr:adr+tiling_iter_nums[i]] = [1]*tiling_iter_nums[i]
            adr += tiling_iter_nums[i]
            params.append(transfer_chain)
        packed_params.append(TransferParams(
            dma,
            pack([param.length_i(0) for param in params]),
            offset=pack([param.offset_i(0) for param in params]),
            step=pack([param.step_i(0) for param in params]),
            wrap=pack([param.wrap_i(0) for param in params]),
            padding=pack([param.padding_i(0) for param in params]),
            iter_step=pack([param.iter_step_i(0) for param in params]),
            iter_wrap=pack([param.iter_wrap_i(0) for param in params]),
        ))
    except Exception:
        # Fall back to iter_step. Start from empty lists so that params the
        # fast path appended before failing do not leak into this result.
        params = []
        packed_params = []
        adr = 0
        for i in range(num_fmts):
            repeat_coeff, transfer_chain = generate_transfer_params(
                    dma,
                    memory_fmts[i],
                    tiling_fmts[i],
                    bits_per_block=bits_per_elem,
                    enable_padding=False,
                    use_iter_step=True,
                    max_chain_length=max_chain_length,
                    buffer_offset=buffer_offset
            )
            repeat_coeff_iter[adr:adr+tiling_iter_nums[i]] = [repeat_coeff for _ in range(tiling_iter_nums[i])]
            adr += tiling_iter_nums[i]
            params.append(transfer_chain)
        for chain_idx in range(max([len(p) for p in params])):
            '''
            For each transfer in the chain, pack the transfer params again, this time using iter_step
            '''
            packed_params.append(TransferParams(
                dma,
                pack([try_or(param, chain_idx, 'length_i(0)', empty_bd) for param in params]),
                offset=pack([try_or(param, chain_idx, 'offset_i(0)', empty_bd) for param in params]),
                step=pack([try_or(param, chain_idx, 'step_i(0)', empty_bd) for param in params]),
                wrap=pack([try_or(param, chain_idx, 'wrap_i(0)', empty_bd) for param in params]),
                padding=pack([try_or(param, chain_idx, 'padding_i(0)', empty_bd) for param in params]),
                iter_step=pack([try_or(param, chain_idx, 'iter_step_i(0)', empty_bd) for param in params]),
                iter_wrap=pack([try_or(param, chain_idx, 'iter_wrap_i(0)', empty_bd) for param in params]),
            ))
    buffer_size = compute_buffer_size(memory_fmts[0], bits_per_elem)
    if dma.channel.dir == DmaDir.S2MM:
        write_params = packed_params
        read_params = []
    else:
        read_params = packed_params
        write_params = []
    # Scale each repeat count by the repeat coefficient of its tiling
    # iteration; build a new list instead of mutating the caller's.
    repeat_counts = [count * repeat_coeff_iter[idx] for idx, count in enumerate(repeat_counts)]
    return DataTransfer(
        repeat_counts,
        dma.tile, [shim_buffer_idx], buffer_size,
        write_params,
        read_params
    )

"""
    
    def DataTransfer(self, repeat_counts, tile, buffer_addrs, buffer_size, write_params, read_params, reuse_ratio=1, sync_strategy='SyncStrategy.Default', buffer_split=1):
        """Return the source text of a ``DataTransfer(...)`` call.

        Every argument is formatted with ``format()`` and emitted in
        signature order. ``buffer_split`` is accepted for interface
        compatibility but is not emitted.
        """
        call_args = (
            repeat_counts, tile, buffer_addrs, buffer_size,
            write_params, read_params, reuse_ratio, sync_strategy,
        )
        return "DataTransfer(" + ", ".join(format(arg) for arg in call_args) + ")"

    def gen_dataflow(self, enable_packtransfer=0):
        """Return the header line of the generated ``generate_dataflow`` function.

        ``enable_packtransfer`` is accepted for interface compatibility and
        does not affect the output.
        """
        signature = "generate_dataflow(back_end: BackEnd)"
        return "def " + signature + ":\n"

    def gen_gelu_params(self, gemm_qdq_prm_addr, ofm_dim):
        """Build the PWLA ``Gelu_qdq_bf16`` kernel-parameter blob.

        Args:
            gemm_qdq_prm_addr: Address of the GEMM QDQ params. The Gelu QDQ
                params and the two lookup tables are laid out after it.
            ofm_dim: Output dims; only ``ofm_dim[0]`` and ``ofm_dim[1]`` are used.

        Returns:
            Whatever ``gen_kernel_param.gen_blob`` returns for
            ``("PWLA", "Gelu_qdq_bf16")``.
        """
        # NOTE- TEMP since we change the way on how to pass param
        # NOTE(review): the layout here is [gemm qdq | gelu qdq | LUT ab |
        # LUT cd], whereas gen_core_instr's fused-PWLA path uses
        # [matmul qdq | LUT ab | LUT cd | gelu qdq]; confirm which is current.
        gelu_qdq_addr = gemm_qdq_prm_addr + 64  # 64 == 16 * 4
        lut_ab_addr = gelu_qdq_addr + 64
        lut_cd_addr = lut_ab_addr + 2560
        blob_args = [gelu_qdq_addr, lut_ab_addr, lut_cd_addr, ofm_dim[0], ofm_dim[1]]
        # NOTE- Temp workaround
        return gen_kernel_param.gen_blob(blob_args, "PWLA", "Gelu_qdq_bf16")

    def gen_core_instr(self, params):
        """Generate the source that defines ``core_instrs`` for every core tile.

        The emitted code defines ``gen_instr(row, col)``, which returns a
        list of dmacompiler core instructions:

        1. a one-time ConfigBuffer/AcqBuffer/RelBuffer of the QDQ buffer on
           S2MM channel ``CoreQdqChId`` (default 1);
        2. ifm/wgt/ofm ConfigBuffers, emitted at the loop level selected by
           ``cfgbuff_location`` (plus the split[0] == 32 variants);
        3. ``Loop(B_itr) > Loop(Tm) > Loop(Tn)`` around ``Tk`` kernel calls,
           with the ofm buffer (MM2S 0) acquired only before the last call;
        4. at most one fused post-op call (PWLA, RoPE or element-wise);
        5. releases of S2MM 1, S2MM 0 and MM2S 0.

        It then emits a loop that assigns ``core_instrs[AieTile(...)]`` for
        every (col, row) in ``dims.aie_cols`` x ``dims.aie_rows``.

        Args:
            params: Layer-parameter dict. Keys read here include 'op_name',
                'op_ver', 'dims', 'actxact', 'Wgtbits', 'Actbits',
                'gemm_params', 'CoreQdqAddr', 'CoreIfmAddr', 'CoreWgtAddr',
                'CoreOfmAddr', the matching 'Core*Size' entries, and the
                optional 'in_ch_mode', 'CoreQdqSize', 'CoreQdqChId',
                'CoreTdmAddr', 'is_pwla_fused', 'is_rope_fused' and
                'is_elew_fused'.

        Returns:
            str: Generated Python source.

        Raises:
            ValueError: if a selected kernel name is not in the op's kernel
                list (raised by ``list.index``).

        NOTE(review): ``params['CoreWgtSize']`` is halved in place for
        8-bit weights with 16-bit act x act. The caller's dict is mutated,
        and a second call on the same dict would halve it again.
        """
        # in_ch_mode 0 (default): ifm arrives on S2MM channel 1 and weights on
        # S2MM channel 0; any other mode swaps the two.
        mode = params.get('in_ch_mode', 0)
        ifm_ch_id = 1 if mode==0 else 0
        wgt_ch_id = 0 if mode==0 else 1

        # Names of the kernels registered for this op/version.
        meta_kernel_list = list((gen_kernel_param.return_kernel_path(params['op_name'], params['op_ver'])[0]).keys())
        # kernel_names[0] runs the non-final K iterations and kernel_names[1]
        # the final one (see the Tk branches below).
        if params.get('CoreQdqSize') is not None:
            if params['actxact'] and params['dims'].act_bits == 16:
                kernel_names = ['run_vxsmt', 'run_vxsmt']
            elif params['actxact'] and params['dims'].act_bits == 8:
                kernel_names = ['run_gemm_a8w8', 'run_gemm_a8w8']
            elif params['dims'].act_bits == 8:
                kernel_names = ['run_gemm_a8w8', 'run_gemm_a8w8']
            else:
                kernel_names = ['run_a16w8_gemm_tdm', 'run_a16w8_gemm_qdq']
        else:
            kernel_names = ['run_gemm', 'run_gemm']

        if params.get('is_pwla_fused'):
          #In case of fused op, matmul qdq size is increased from 64 to 5248 (2*64 + 2*2560) to accommodate the LUT and gelu/silu qdq params
          # Resulting core QDQ buffer layout:
          # [matmul qdq | LUT ab | LUT cd | gelu qdq]
          matmul_qdq_size = 64
          gelu_lutab_size = 2560
          gelu_lutcd_size = 2560
          gelu_qdq_size   = 64
          expected_qdq_size = matmul_qdq_size + gelu_lutab_size + gelu_lutcd_size + gelu_qdq_size
          utils.sanity_check(expected_qdq_size == params.get('CoreQdqSize'),f"CoreQdq size not matching the expectation for fused matmul. expected: {expected_qdq_size}, actual: {params.get('CoreQdqSize')}")
          gelu_lutab_addr = params['CoreQdqAddr'][0] + matmul_qdq_size
          gelu_lutcd_addr = gelu_lutab_addr + gelu_lutab_size
          gelu_qdq_addr   = gelu_lutcd_addr + gelu_lutcd_size

        # RoPE / element-wise fusion: the fused op's qdq params start at the
        # second half of the core QDQ buffer.
        if params.get('is_rope_fused'):
          #In case of fused rope op, matmul qdq size is increased from 64 to 128 (64x2) to accommodate the RoPE qdq params
          rope_qdq_addr = params['CoreQdqAddr'][0] + params.get('CoreQdqSize') // 2
        if params.get('is_elew_fused'):
          #In case of fused elew op, matmul qdq size is increased from 64 to 128 (64x2) to accommodate the Elew qdq params
          elew_qdq_addr = params['CoreQdqAddr'][0] + params.get('CoreQdqSize') // 2
        # Build a table of all 'Core*Addr' params, sorted by address, that is
        # emitted as comments at the top of the generated core flow.
        # Ping/pong pairs are split into '*0Addr' / '*1Addr' entries, None
        # addresses are dropped, and every remaining entry's matching
        # 'Core*Size' key is looked up when formatting.
        addr_prms = [x for x in params if x[0:4]=="Core" and x[-4:]=="Addr"]
        addr_dict = {}
        for addrs in addr_prms:
            if len(params[addrs])==1:
                addr_dict[addrs] = [params[addrs][0], addrs.replace('Addr', 'Size')]
            elif len(params[addrs])==2:
                addr_dict[addrs[0:-4]+'0'+addrs[-4:]] = [params[addrs][0], addrs.replace('Addr', 'Size')]
                addr_dict[addrs[0:-4]+'1'+addrs[-4:]] = [params[addrs][1], addrs.replace('Addr', 'Size')]
        new_addr_dict = {k: v for k, v in addr_dict.items() if v[0] is not None}
        sorted_addr_dict = dict(sorted(new_addr_dict.items(), key=lambda item: item[1]))
        addr_string = "".join(["    #{:<15} = {:<10} {:<15} {}\n".format(k, sorted_addr_dict[k][0], sorted_addr_dict[k][1], params[sorted_addr_dict[k][1]]) for k in sorted_addr_dict])

        meta_kernel_list = [key for key in meta_kernel_list] #NOTE- SCH only care for kernel name
        # Positions of the selected kernels in the op's kernel list.
        kernel_idx = [meta_kernel_list.index(kname) for kname in kernel_names]

        # Choose where the ifm/wgt/ofm ConfigBuffers are emitted:
        #   1 = inside Loop(B_itr), 2 = inside Loop(Tm), 3 = inside Loop(Tn),
        # i.e. the outermost level whose enclosed iteration count is <= 1024.
        #   0 = before all loops (whole nest <= 1024, or Tk alone > 1024),
        # assuming positive trip counts.
        # NOTE(review): 1024 is presumably a repeat-count limit -- confirm.
        if params['dims'].B_itr*params['dims'].Tm*params['dims'].Tn*params['dims'].Tk > 1024 and params['dims'].Tm*params['dims'].Tn*params['dims'].Tk <= 1024:
            cfgbuff_location = 1
        elif params['dims'].Tm*params['dims'].Tn*params['dims'].Tk > 1024 and params['dims'].Tn*params['dims'].Tk <= 1024:
            cfgbuff_location = 2
        elif params['dims'].Tn*params['dims'].Tk > 1024 and params['dims'].Tk <= 1024:
            cfgbuff_location = 3
        else:
            cfgbuff_location = 0

        # NOTE(review): mutates the caller's params dict (see docstring).
        if params['Wgtbits'] == 8 and params['Actbits'] == 16 and params['actxact']:
            params['CoreWgtSize'] = params['CoreWgtSize'] // 2
        # Address table followed by the start of the generated gen_instr().
        code =  f"""
    #CORE FLOW DEFINITION
{addr_string}
    def gen_instr(row, col):
        return  ["""

        # One-time QDQ parameter buffer load.
        code += f"""
        ConfigBuffer(DmaChannel(DmaDir.S2MM, {params.get('CoreQdqChId',1)}), {params['CoreQdqAddr'][0]}, None, {params['dims'].qdq_bytes}), #CoreQdqAddr, QdqSize
        AcqBuffer(DmaChannel(DmaDir.S2MM, {params.get('CoreQdqChId',1)})),
        RelBuffer(DmaChannel(DmaDir.S2MM, {params.get('CoreQdqChId',1)})),"""

        # split[0] == 32: only the ofm buffer is configured here. Otherwise,
        # for location 0, all three buffers are configured here.
        if params['dims'].split[0]==32:
            code += f"""
        ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), {params['CoreOfmAddr'][0]}, {params['CoreOfmAddr'][1]}, {params['CoreOfmSize']}), #CoreOfmPingAddr, CoreOfmPongAddr, CoreOfmSize"""
        elif cfgbuff_location==0:
            code += f"""
        ConfigBuffer(DmaChannel(DmaDir.S2MM, {ifm_ch_id}), {params['CoreIfmAddr'][0]}, {params['CoreIfmAddr'][1]}, {params['CoreIfmSize']}), #CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize
        ConfigBuffer(DmaChannel(DmaDir.S2MM, {wgt_ch_id}), {params['CoreWgtAddr'][0]}, {params['CoreWgtAddr'][1]}, {params['CoreWgtSize']}), #CoreWgtPingAddr, CoreWgtPongAddr, CoreWgtSize
        ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), {params['CoreOfmAddr'][0]}, {params['CoreOfmAddr'][1]}, {params['CoreOfmSize']}), #CoreOfmPingAddr, CoreOfmPongAddr, CoreOfmSize"""

        # Outer batch loop.
        code += f"""
        Loop({params['dims'].B_itr}, ["""

        if cfgbuff_location==1:
            code += f"""
            ConfigBuffer(DmaChannel(DmaDir.S2MM, {ifm_ch_id}), {params['CoreIfmAddr'][0]}, {params['CoreIfmAddr'][1]}, {params['CoreIfmSize']}), #CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize
            ConfigBuffer(DmaChannel(DmaDir.S2MM, {wgt_ch_id}), {params['CoreWgtAddr'][0]}, {params['CoreWgtAddr'][1]}, {params['CoreWgtSize']}), #CoreWgtPingAddr, CoreWgtPongAddr, CoreWgtSize
            ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), {params['CoreOfmAddr'][0]}, {params['CoreOfmAddr'][1]}, {params['CoreOfmSize']}), #CoreOfmPingAddr, CoreOfmPongAddr, CoreOfmSize"""

        # M loop.
        code += f"""
            Loop({params['dims'].Tm}, ["""

        if cfgbuff_location==2:
            code += f"""
                ConfigBuffer(DmaChannel(DmaDir.S2MM, {ifm_ch_id}), {params['CoreIfmAddr'][0]}, {params['CoreIfmAddr'][1]}, {params['CoreIfmSize']}), #CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize
                ConfigBuffer(DmaChannel(DmaDir.S2MM, {wgt_ch_id}), {params['CoreWgtAddr'][0]}, {params['CoreWgtAddr'][1]}, {params['CoreWgtSize']}), #CoreWgtPingAddr, CoreWgtPongAddr, CoreWgtSize
                ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), {params['CoreOfmAddr'][0]}, {params['CoreOfmAddr'][1]}, {params['CoreOfmSize']}), #CoreOfmPingAddr, CoreOfmPongAddr, CoreOfmSize"""

        # N loop.
        code += f"""
                Loop({params['dims'].Tn}, ["""

        if cfgbuff_location==3:
            code += f"""
                    ConfigBuffer(DmaChannel(DmaDir.S2MM, {ifm_ch_id}), {params['CoreIfmAddr'][0]}, {params['CoreIfmAddr'][1]}, {params['CoreIfmSize']}), #CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize
                    ConfigBuffer(DmaChannel(DmaDir.S2MM, {wgt_ch_id}), {params['CoreWgtAddr'][0]}, {params['CoreWgtAddr'][1]}, {params['CoreWgtSize']}), #CoreWgtPingAddr, CoreWgtPongAddr, CoreWgtSize
                    ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), {params['CoreOfmAddr'][0]}, {params['CoreOfmAddr'][1]}, {params['CoreOfmSize']}), #CoreOfmPingAddr, CoreOfmPongAddr, CoreOfmSize"""

        # split[0] == 32: per-Tn-iteration ConfigBuffer/Acq/Rel sequence on
        # the S2MM side. NOTE(review): the ConfigBuffers use ifm_ch_id /
        # wgt_ch_id but the Acq/Rel pairs are hard-coded to channel 0 --
        # confirm this for in_ch_mode != 0.
        if params['dims'].split[0]==32:
            code += f"""
                    ConfigBuffer(DmaChannel(DmaDir.S2MM, {wgt_ch_id}), {params['CoreWgtAddr'][0]}, {params['CoreWgtAddr'][1]}, 0), #CoreWgtPingAddr, CoreWgtPongAddr, CoreWgtSize
                    ConfigBuffer(DmaChannel(DmaDir.S2MM, {ifm_ch_id}), {params['CoreIfmAddr'][0]}, {params['CoreIfmAddr'][1]}, {params['CoreIfmSize']}), #CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    ConfigBuffer(DmaChannel(DmaDir.S2MM, {ifm_ch_id}), {params['CoreWgtAddr'][0]}, {params['CoreWgtAddr'][1]}, {params['CoreWgtSize']}), #CoreWgtPingAddr, CoreWgtPongAddr, CoreWgtSize
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    ConfigBuffer(DmaChannel(DmaDir.S2MM, {ifm_ch_id}), {params['CoreIfmAddr'][0]}, {params['CoreIfmAddr'][1]}, 0), #CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize"""

        # K reduction:
        #   Tk == 1: a single (final) kernel call with gemm_params[3];
        #   Tk == 2: first call (gemm_params[0]) + final call (gemm_params[2]);
        #   Tk >  2: first call, Loop(Tk-2) of middle calls (gemm_params[1]),
        #            final call.
        # The ofm buffer (MM2S 0) is acquired only before the final call; the
        # final call's buffers are released after the optional fused op.
        if params['dims'].Tk == 1:
                code += f"""
                    AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    CallKernel('{meta_kernel_list[kernel_idx[1]]}', kernel_params={params['gemm_params'][3]}),
"""
        elif params['dims'].Tk == 2:
                code += f"""
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    CallKernel('{meta_kernel_list[kernel_idx[0]]}', kernel_params={params['gemm_params'][0]}),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    CallKernel('{meta_kernel_list[kernel_idx[1]]}', kernel_params={params['gemm_params'][2]}),
"""
        else:
                code += f"""
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    CallKernel('{meta_kernel_list[kernel_idx[0]]}', kernel_params={params['gemm_params'][0]}),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    Loop({params['dims'].Tk-2}, [
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        CallKernel('{meta_kernel_list[kernel_idx[0]]}', kernel_params={params['gemm_params'][1]}),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    ]),
                    AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    CallKernel('{meta_kernel_list[kernel_idx[1]]}', kernel_params={params['gemm_params'][2]}),
"""

        # Optional fused post-op, run after the final GEMM call of the K
        # reduction. NOTE(review): the fused kernel is taken as
        # meta_kernel_list[2] for 16-bit activations and [1] otherwise, which
        # relies on the order of the op's kernel dict.
        if params.get('is_pwla_fused'):
            code += f"                    CallKernel('{meta_kernel_list[2] if params['dims'].act_bits == 16 else meta_kernel_list[1]}', kernel_params=LUTOPs_qdq_params({gelu_qdq_addr}, {gelu_lutab_addr}, {gelu_lutcd_addr}, {params['dims'].M_subv * params['dims'].N_subv}, {params['CoreTdmAddr'][0]}, {params['CoreTdmAddr'][1]}, 1, {int(params['dims'].act_bits == 16)})),\n"
        elif params.get('is_rope_fused'):
            code += f"                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),   #release wgt ping-pong to fetch sin/cos data\n"
            #NOTE- Assuming sin/cos size is same as ofm (since it's elew)
            code += f"                    ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), {params['CoreWgtAddr'][0]}, {params['CoreWgtAddr'][1]}, {params['CoreOfmSize']}+{params['CoreOfmSize']}), #CoreSinCosPingAddr, CoreSinCosPongddr, CoreSinCosSize  #Reused CoreWgtAddr\n"
            code += f"                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),   #fetch sin/cos data\n"
            code += f"                    CallKernel('{meta_kernel_list[2] if params['dims'].act_bits == 16 else meta_kernel_list[1]}', kernel_params=rope_layer_params(0, {params['dims'].M_subv}, {params['dims'].N_subv}, {rope_qdq_addr}, {params['CoreTdmAddr'][0]}, {params['CoreTdmAddr'][1]}, 1, row, col)),\n"
        elif params.get('is_elew_fused'):
            # NOTE(review): the ConfigBuffer below pairs CoreTdmAddr[0] with
            # CoreWgtAddr[1] as ping/pong, and its AcqBuffer carries a
            # "#fetch sin/cos data" note that looks copied from the RoPE
            # branch -- confirm intent.
            code += f"                     RelBuffer(DmaChannel(DmaDir.S2MM, 0)),   #release tdm ping-pong to fetch ifmB\n"
            #NOTE- Assuming the ifmB size is the same as ofm (since it's elew)
            code += f"                     ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), {params['CoreTdmAddr'][0]}, {params['CoreWgtAddr'][1]}, {params['CoreOfmSize']}), #CoreIfmBPingAddr, CoreIfmBPongAddr, CoreIfmBSIze #Reused CoretdmAddr\n"
            code += f"                     AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),   #fetch sin/cos data\n"
            code += f"                    CallKernel('{meta_kernel_list[2] if params['dims'].act_bits == 16 else meta_kernel_list[1]}', kernel_params=matadd_params(0, {params['dims'].M_subv}, {params['dims'].N_subv}, {elew_qdq_addr}, {params['CoreTdmAddr'][0]}, {params['CoreTdmAddr'][1]}, 1, row, col)),\n"
        else:
            #No OPs fusion
            pass
        # Release the final iteration's buffers and close Loop(Tn),
        # Loop(Tm), Loop(B_itr) and the instruction list.
        code += f"""                    RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                ]),
            ]),
        ]
    )]
"""
        # Assign the same generated instruction list to every core tile.
        code += "    core_instrs = {}\n"
        code +=f"    for col in range({params['dims'].aie_cols}):\n"
        code +=f"       for row in range({params['dims'].aie_rows}): \n"
        code += "           core_instrs[AieTile(TileType.Core, col, row)] =  gen_instr(row, col)\n"
        return code

    def memtile_stats(self, params):
        """Return a comment block that summarises the memtile tiling and buffer layout.

        The text is emitted into the generated script as ``#`` comments, and
        each memtile address/size is listed exactly once. ``params['dims']``
        must provide M, K, N, M_subv, K_subv, N_subv, Mpad_ifm, Kpad_ifm,
        Kpad_wgt, Npad_wgt, Munpad, Nunpad, Tm, Tk and Tn. ``K_ifmB`` is
        optional and shown as -1 when absent. Optional memtile entries
        (weights, RoPE, element-wise ifmB) are shown as ``None`` when missing
        from ``params``.

        Returns:
            str: Generated comment lines, ending with a blank line.
        """
        dims = params['dims']
        return f"""    #Split: {params['template_meta_data']['mode']}
    #M,K,K_ifmB, N      =({dims.M}, {dims.K}, {getattr(dims, 'K_ifmB', -1)}, {dims.N})
    #Msubv,Ksubv,Nsubv  =({dims.M_subv}, {dims.K_subv}, {dims.N_subv})
    #Mpad_ifm,Kpad_ifm  = {dims.Mpad_ifm}, {dims.Kpad_ifm}
    #Kpad_wgt,Npad_wgt  = {dims.Kpad_wgt}, {dims.Npad_wgt}
    #Mpad_ofm,Npad_ofm  = {dims.Munpad}, {dims.Nunpad}
    #Tm,Tk,Tn           = {dims.Tm}, {dims.Tk}, {dims.Tn}
    #ParamSize          = {params['ParamSize']}
    #MemtileActSize     = {params['MemtileActSize']}
    #MemtileWgtSize     = {params.get('MemtileWgtSize')}
    #MemtileOutSize     = {params['MemtileOutSize']}
    #MemtilePrmAddr     = {params['MemtilePrmAddr']}
    #MemtileQdqAddr     = {params['MemtileQdqAddr']} 
    #MemtileIfmPingAddr = {params['MemtileIfmPingAddr']}
    #MemtileIfmPongAddr = {params['MemtileIfmPongAddr']}
    #MemtileWgtPingAddr = {params.get('MemtileWgtPingAddr')}
    #MemtileWgtPongAddr = {params.get('MemtileWgtPongAddr')}
    #MemtileOfmPingAddr = {params['MemtileOfmPingAddr']}
    #MemtileOfmPongAddr = {params['MemtileOfmPongAddr']}
    #MemtileRoPEAddr = {params.get('MemtileRoPEAddr')}
    #MemtileRoPESize = {params.get('MemtileRoPESize')}
    #MemtileElewifmBAddr = {params.get('MemtileifmBAddr')}
    #MemtileElewifmBSize = {params.get('MemtileifmBSize')}
    #Total Memtile Memory Utilization: {params['MemtileTotal']} Bytes

"""

    def memtile_prm_pattern(self, params):
        """Return generated source for the memtile layer-parameter transfers.

        In 'unicast' ShimParamMode (the default), every memtile column
        receives ``ParamSize`` bytes on S2MM channel ``ShimParamChannelId``
        (default 1) and forwards them on MM2S channel ``row`` for each core
        row. In any other mode, each selected column receives on channel
        ``ShimParamChannelId`` (default 0) and sends on MM2S channel 4. The
        selected columns are every column for the '4x4' overlay and every
        other column for '8x4'.

        Raises:
            ValueError: in non-unicast mode for an overlay other than '4x4'
                or '8x4'. No column range is defined for it, and the
                generated list comprehension would otherwise be left
                unterminated.
        """
        dims = params['dims']
        param_size = params['ParamSize']
        param_len = param_size // 4
        prm_addr = params['MemtilePrmAddr']
        repeat = BaseTemplate.shorten_val(self, params['MemtileParamRepeat'])
        if params.get('ShimParamMode', 'unicast') == 'unicast':
            return f"""    memtile_transfers += [
        DataTransfer({repeat}, AieTile(TileType.Memtile, col), [{prm_addr}], {param_size},
            [TransferParams(memtile_dma(col, DmaDir.S2MM, {params.get('ShimParamChannelId', 1)}), {param_len})],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, row), {param_len})
             for row in range({dims.aie_rows})]
        ) for col in range({dims.aie_cols})
    ]
"""
        overlay = params['template_meta_data']['overlay']
        if overlay == '4x4':
            col_range = f"range({dims.aie_cols})"
        elif overlay == '8x4':
            col_range = f"range(0, {dims.aie_cols}, 2)"
        else:
            raise ValueError(f"memtile_prm_pattern: unsupported overlay {overlay!r} for non-unicast parameter transfers")
        return f"""    memtile_transfers += [
        DataTransfer(
            {repeat}, AieTile(TileType.Memtile, col), [{prm_addr}], {param_size},
            [TransferParams(memtile_dma(col, DmaDir.S2MM, {params.get('ShimParamChannelId', 0)}), {param_len})],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, 4), {param_len})],
        ) for col in {col_range}
    ]
"""

    def memtile_qdq_pattern(self, params):
        """Return generated source for the memtile QDQ-parameter transfers.

        In 'broadcast' ShimQdqMode, each selected memtile column receives
        ``dims.qdq_bytes`` on S2MM channel ``ShimQdqChannelId`` (default 0)
        and sends them on MM2S channel 4. The selected columns are every
        column for the '4x4' overlay and every other column for '8x4'. In
        any other mode (default 'unicast'), every column receives on channel
        ``ShimQdqChannelId`` (default 1) and forwards on MM2S channel
        ``row`` for each core row.

        Raises:
            ValueError: in broadcast mode for an overlay other than '4x4' or
                '8x4'. No column range is defined for it, and the generated
                list comprehension would otherwise be left unterminated.
        """
        dims = params['dims']
        qdq_bytes = dims.qdq_bytes
        qdq_len = qdq_bytes // 4
        qdq_addr = params['MemtileQdqAddr']
        repeat = BaseTemplate.shorten_val(self, params['MemtileQdqPrmRepeat'])
        if params.get('ShimQdqMode', 'unicast') != 'broadcast':
            return f"""    memtile_transfers += [
        DataTransfer({repeat}, AieTile(TileType.Memtile, col), [{qdq_addr}], {qdq_bytes},
            [TransferParams(memtile_dma(col, DmaDir.S2MM, {params.get('ShimQdqChannelId', 1)}), {qdq_len})],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, row), {qdq_len})
             for row in range({dims.aie_rows})]
        ) for col in range({dims.aie_cols})
    ]
"""
        overlay = params['template_meta_data']['overlay']
        if overlay == '4x4':
            col_range = f"range({dims.aie_cols})"
        elif overlay == '8x4':
            col_range = f"range(0, {dims.aie_cols}, 2)"
        else:
            raise ValueError(f"memtile_qdq_pattern: unsupported overlay {overlay!r} for broadcast QDQ transfers")
        return f"""    memtile_transfers += [
        DataTransfer(
            {repeat}, AieTile(TileType.Memtile, col), [{qdq_addr}], {qdq_bytes},
            [TransferParams(memtile_dma(col, DmaDir.S2MM, {params.get('ShimQdqChannelId', 0)}), {qdq_len})],
            [TransferParams(memtile_dma(col, DmaDir.MM2S, 4), {qdq_len})],
        ) for col in {col_range}
    ]
"""
    def memtile_act_pingpong_pattern(self, params, mem, col):
        """Return generated source for an activation memtile transfer using an inline ping/pong buffer pair.

        Args:
            params: Layer params. Reads MemtileActPingRepeat,
                MemtileIfmPingAddr, MemtileIfmPongAddr, MemtileActSize,
                MemtileActReuseRatio, dims.act_bits and the optional
                MemtileActchid (S2MM channel, default 0).
            mem: Format strings 'act_memtile_memory', 'act_memtile_s2mm' and
                'act_memtile_mm2s', plus the 'enable_padding' flag for MM2S.
            col: Memtile column.

        Returns:
            str: Generated source appending one DataTransfer.

        Raises:
            AssertionError: if ``params['MemtileActReuseRatio']`` is not 1.
                This mode does not support reuse.
        """
        # Validate before generating anything, and with an explicit raise
        # (not ``assert``) so the check also holds under ``python -O``.
        if params['MemtileActReuseRatio'] != 1:
            raise AssertionError('memtile act reuse ratio must equal to 1 in inline pingpong buffer mode')
        memtile_s2mm_id = params.get('MemtileActchid', 0)
        act_bits = params['dims'].act_bits
        padding = "" if mem['enable_padding'] == 0 else ", enable_padding=True"
        return f"""    memtile_transfers += [
        DataTransfer({params['MemtileActPingRepeat']}, AieTile(TileType.Memtile, {col}), [{params['MemtileIfmPingAddr']}, {params['MemtileIfmPongAddr']}], {params['MemtileActSize']},
            [generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, {memtile_s2mm_id}), '{mem['act_memtile_memory']}', '{mem['act_memtile_s2mm']}', {act_bits})],
            [generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, 4), '{mem['act_memtile_memory']}', '{mem['act_memtile_mm2s']}', {act_bits}{padding})],
            )]
"""

    def memtile_act_ping_pattern(self, params, mem, col):
        """Return generated source for the activation transfer into the memtile ping buffer.

        Args:
            params: Layer params. Reads MemtileActPingRepeat,
                MemtileIfmPingAddr, MemtileActSize, MemtileActReuseRatio,
                dims.act_bits and the optional MemtileActchid (S2MM channel,
                default 0).
            mem: Format strings 'act_memtile_memory', 'act_memtile_s2mm' and
                'act_memtile_mm2s', plus the 'enable_padding' flag for MM2S.
            col: Memtile column.

        Returns:
            str: Generated source appending one DataTransfer. ``reuse_ratio``
            is emitted only when it differs from 1.
        """
        s2mm_channel = params.get('MemtileActchid', 0)
        act_bits = params['dims'].act_bits
        memory_fmt = mem['act_memtile_memory']
        padding_kwarg = "" if mem['enable_padding'] == 0 else ", enable_padding=True"
        reuse = params['MemtileActReuseRatio']
        closing = ")]\n" if reuse == 1 else f"reuse_ratio={reuse})]\n"
        header = (
            "    memtile_transfers += [\n"
            f"        DataTransfer({params['MemtileActPingRepeat']}, AieTile(TileType.Memtile, {col}), "
            f"[{params['MemtileIfmPingAddr']}], {params['MemtileActSize']},\n"
        )
        s2mm_line = f"            [generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, {s2mm_channel}), '{memory_fmt}', '{mem['act_memtile_s2mm']}', {act_bits})],\n"
        mm2s_line = f"            [generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, 4), '{memory_fmt}', '{mem['act_memtile_mm2s']}', {act_bits}{padding_kwarg})],\n"
        return header + s2mm_line + mm2s_line + " " * 12 + closing

    def memtile_act_pong_pattern(self, params, mem, col):
        """Return source text for the pong-buffer memtile activation transfer.

        Returns ``''`` unless ``sum(params['MemtileActPongRepeat']) > 0``.
        Otherwise emits one ``DataTransfer`` on memtile column ``col`` targeting
        ``params['MemtileIfmPongAddr']``.  It has an S2MM write on channel
        ``params.get('MemtileActchid', 0)`` and an MM2S read on channel 4.  The
        optional ``enable_padding``/``reuse_ratio`` suffixes follow the same rules
        as the ping variant.
        """
        pong_repeat = params['MemtileActPongRepeat']
        if not sum(pong_repeat) > 0:
            return ''
        bits = params['dims'].act_bits
        fmt = mem['act_memtile_memory']
        s2mm_chan = params.get('MemtileActchid', 0)
        reuse = params['MemtileActReuseRatio']
        pieces = [
            "    memtile_transfers += [\n",
            f"            DataTransfer({pong_repeat}, AieTile(TileType.Memtile, {col}), "
            f"[{params['MemtileIfmPongAddr']}], {params['MemtileActSize']},\n",
            f"                [generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, {s2mm_chan}), "
            f"'{fmt}', '{mem['act_memtile_s2mm']}', {bits})],\n",
            f"                [generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, 4), "
            f"'{fmt}', '{mem['act_memtile_mm2s']}', {bits}",
            ")],\n" if mem['enable_padding'] == 0 else ", enable_padding=True)],\n",
            " " * 12 + (")]\n" if reuse == 1 else f"reuse_ratio={reuse})]\n"),
        ]
        return "".join(pieces)

    def memtile_wgt_pattern(self, params, col):
        """Return source text for the memtile weight transfer on column ``col``.

        One S2MM write (channel ``params.get('MemtileWgtchid', 1)``) feeds MM2S
        reads on channel ``row`` for every ``row`` in ``range(dims.aie_rows)``.
        Each row's read is repeated ``params['reuse_chain_length']`` times.  The
        pong address is included when ``params['WgtPingPong']`` is truthy.
        ``sync_strategy=`` is emitted unless it equals ``'Default'``.
        ``reuse_ratio=`` is emitted unless
        ``MemtileWgtReuseRatio // reuse_chain_length == 1``.
        """
        dims = params['dims']
        chain = params['reuse_chain_length']
        fmt = params['wgt_memtile_memory']

        reads = []
        for row in range(dims.aie_rows):
            for _ in range(chain):
                reads.append(f"generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, {row}), "
                             f"'{fmt}', '{params['wgt_memtile_mm2s'][row]}', {dims.wgt_bits})")
        reads_txt = ",\n             ".join(reads)

        if params['WgtPingPong']:
            addrs = [params['MemtileWgtPingAddr'], params['MemtileWgtPongAddr']]
        else:
            addrs = [params['MemtileWgtPingAddr']]

        text = (
            "    memtile_transfers += [\n"
            f"        DataTransfer({params['MemtileWgtRepeat']}, AieTile(TileType.Memtile, {col}), "
            f"{addrs}, {params['MemtileWgtSize']},\n"
            f"            [generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, {params.get('MemtileWgtchid', 1)}), "
            f"'{fmt}', '{params['wgt_memtile_s2mm']}', {dims.wgt_bits})],\n"
            f"            [{reads_txt}],\n"
        )
        if params['sync_strategy'] != 'Default':
            text += f"            sync_strategy={params['sync_strategy']},\n"
        ratio = params['MemtileWgtReuseRatio'] // chain
        text += " " * 12 + (")]\n" if ratio == 1 else f"reuse_ratio={ratio})]\n")
        return text
    
    def memtile_ofm_pattern(self, params, col):
        """Return source text for the memtile output (OFM) transfer on column ``col``.

        Each AIE row ``row`` writes on S2MM channel ``row + 2``.  A single MM2S
        read on channel 5 drains the buffer at ``params['MemtileOfmPingAddr']``.
        The transfer always uses ``SyncStrategy.Parallel_N_to_1``.
        """
        dims = params['dims']
        fmt = params['out_memtile_memory']
        writes = []
        for row in range(dims.aie_rows):
            writes.append(f"generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, {row + 2}), "
                          f"'{fmt}', '{params['out_memtile_s2mm'][row]}', {dims.out_bits})")
        writes_txt = ",\n             ".join(writes)
        return (
            "    memtile_transfers += [\n"
            f"        DataTransfer({params['MemtileOutRepeat']}, AieTile(TileType.Memtile, {col}), "
            f"[{params['MemtileOfmPingAddr']}], {params['MemtileOutSize']},\n"
            f"            [{writes_txt}],\n"
            f"            [generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, 5), "
            f"'{fmt}', '{params['out_memtile_mm2s']}', {dims.out_bits})],\n"
            "            sync_strategy=SyncStrategy.Parallel_N_to_1,)]\n"
        )

    def shim_stats(self, params):
        """Return a block of ``#``-prefixed summary lines describing the shim setup.

        Reads the dimension fields from ``params['dims']``; ``K_ifmB`` falls back
        to -1 when the dims object lacks it.  Also includes
        ``params['template_meta_data']['mode']`` and ``const.PARAM_SIZE``.  The
        optional ``ShimIfmSize``/``ShimWgtSize``/``ShimOfmSize`` entries print as
        ``None`` when absent.  Every line ends with a newline.
        """
        d = params['dims']
        rows = [
            f"    #Split: {params['template_meta_data']['mode']}",
            f"    #M,K,K_ifmB, N      =({d.M}, {d.K}, {getattr(d, 'K_ifmB', -1)}, {d.N})",
            f"    #Msubv,Ksubv,Nsubv  =({d.M_subv}, {d.K_subv}, {d.N_subv})",
            f"    #Mpad_ifm,Kpad_ifm  = {d.Mpad_ifm}, {d.Kpad_ifm}",
            f"    #Kpad_wgt,Npad_wgt  = {d.Kpad_wgt}, {d.Npad_wgt}",
            # NOTE(review): labelled Mpad_ofm/Npad_ofm but prints Munpad/Nunpad — confirm intended.
            f"    #Mpad_ofm,Npad_ofm  = {d.Munpad}, {d.Nunpad}",
            f"    #Tm,Tk,Tn           = {d.Tm}, {d.Tk}, {d.Tn}",
            f"    #ParamSize          = {const.PARAM_SIZE}",
            f"    #ShimIfmSize        = {params.get('ShimIfmSize')}",
            f"    #ShimWgtSize        = {params.get('ShimWgtSize')}",
            f"    #ShimOfmSize        = {params.get('ShimOfmSize')}",
        ]
        return "\n".join(rows) + "\n"

    def shim_prm_pattern(self, params, col):
        """Return source text for the shim-tile parameter (layer-param) transfer.

        Emits an MM2S-only ``DataTransfer`` of ``params['ParamSize']`` bytes from
        buffer ``params['ShimPrmBufferIdx']``.  The channel is
        ``params.get('ShimParamChannelId', 0)``.  The ``TransferParams`` length is
        ``ParamSize // 4``.
        """
        size = params['ParamSize']
        repeat = self.shorten_val(params['ShimParamRepeat'])
        channel = params.get('ShimParamChannelId', 0)
        read = f"TransferParams(shim_dma({col}, DmaDir.MM2S, {channel}), {size // 4})"
        return (f"    shim_transfers += [DataTransfer({repeat}, AieTile(TileType.Shim, {col}), "
                f"[{params['ShimPrmBufferIdx']}], {size}, [], [{read}])]\n")

    def shim_qdq_pattern(self, params, col):
        """Return source text for the shim-tile QDQ-parameter transfer.

        The declared buffer size is ``dims.qdq_bytes`` plus
        ``params.get('ShimWgtSize', 0)``, unless ``params.get('actxact')`` is
        truthy, in which case the weight size is not added.  The MM2S read
        transfers ``qdq_bytes // 4`` words on channel
        ``params.get('ShimQdqChannelId', 0)``.  ``offset=qdq_offset // 4`` is added
        when ``params.get('qdq_offset', 0)`` is non-zero.
        """
        qdq_bytes = params['dims'].qdq_bytes
        wgt_bytes = 0 if params.get('actxact') else params.get('ShimWgtSize', 0)
        offset = params.get('qdq_offset', 0)
        offset_arg = f", offset={offset // 4}" if offset != 0 else ""
        return (f"    shim_transfers += [DataTransfer({self.shorten_val(params['ShimQdqPrmRepeat'])}, "
                f"AieTile(TileType.Shim, {col}), [{params['ShimWgtBufferIdx']}], {wgt_bytes + qdq_bytes}, [], "
                f"[TransferParams(shim_dma({col}, DmaDir.MM2S, {params.get('ShimQdqChannelId', 0)}), "
                f"{qdq_bytes // 4}{offset_arg})])]\n")

    def shim_act_pattern_pack(self, params, col, itr):
        """Return source text for the shim activation MM2S transfer (channel 0).

        ``params['act_shim_memory']`` may be a format string or a nested
        container of format strings.  In the nested case the first leaf string is
        used.  The tiling format comes from ``params['act_shim_mm2s'][col][itr]``.
        """
        mem_format = params['act_shim_memory']
        # Descend to the first leaf string.  Previously only containers with
        # len > 1 were unwrapped (via [0][0]); a length-1 container such as
        # [['fmt']] was embedded as its list repr, producing invalid source.
        while not isinstance(mem_format, str):
            mem_format = mem_format[0]
        return (f"    shim_transfers += [generate_shim_data_transfer({params['ShimActRepeat']}, "
                f"shim_dma({col}, DmaDir.MM2S, 0), {params['ShimActBufferIdx']}, "
                f"'{mem_format}', '{params['act_shim_mm2s'][col][itr]}',{params['dims'].act_bits})]\n")

    def shim_wgt_pattern_pack(self, params, col):
        """Return source text for the shim weight MM2S transfer (channel 1).

        ``buffer_offset=`` is appended when ``params['wgt_shim_offset']`` is
        present and non-zero.  A missing key is treated as 0, matching how the
        offset value itself was already read and how ``shim_qdq_pattern`` treats
        ``qdq_offset``.
        """
        # Read the offset once with a default; the condition previously indexed
        # params['wgt_shim_offset'] directly and raised KeyError when it was absent.
        offset = params.get('wgt_shim_offset', 0)
        code = (f"    shim_transfers += [generate_shim_data_transfer({params['ShimWgtRepeat']}, "
                f"shim_dma({col}, DmaDir.MM2S, 1), {params['ShimWgtBufferIdx']}, "
                f"'{params['wgt_shim_memory']}', '{params['wgt_shim_mm2s'][col]}',{params['dims'].wgt_bits}")
        code += f", buffer_offset={offset})]\n" if offset != 0 else ")]\n"
        return code

    def shim_ofm_pattern(self, params, col):
        """Return source text for the shim output S2MM transfer (channel 0) on column ``col``.

        The tiling format comes from ``params['out_shim_s2mm'][col]``.
        """
        tiling = params['out_shim_s2mm'][col]
        head = (f"generate_shim_data_transfer({params['ShimOutRepeat']}, "
                f"shim_dma({col}, DmaDir.S2MM, 0), {params['ShimOutBufferIdx']}")
        return f"    shim_transfers += [{head}, '{params['out_shim_memory']}', '{tiling}',{params['dims'].out_bits})]\n"

    class ShimTransfers():
        """Generate shim-tile DMA transfer source text for one tensor.

        ``dma_pattern`` emits one statement per re-enqueue iteration and per
        column, in one of three forms:

        * a single memory format -> ``generate_shim_data_transfer``;
        * several formats whose tiling strings have at most 3 non-trivial
          dimensions -> ``DataTransfer`` with ``pack_transfers`` on the S2MM or
          the MM2S side, depending on ``DmaDir``;
        * otherwise -> ``generate_packed_shim_data_transfer``.

        All constructor arguments are deep-copied.
        """
        def __init__(
            self,
            ChannelId,
            DmaDir,
            Repeat,
            BufferIdx,
            BufferSize,
            MemoryFormat,
            TilingFormat,
            Offset,
            BitsPerCycle,
            ReEnqueueCnt,
            tiling_iters,
        ):
            # Deep copies so later mutation of the caller's containers cannot
            # change the generated code.
            self.ChannelId    = copy.deepcopy(ChannelId)
            self.DmaDir       = copy.deepcopy(DmaDir)
            self.Repeat       = copy.deepcopy(Repeat)
            self.BufferIdx    = copy.deepcopy(BufferIdx)
            self.BufferSize   = copy.deepcopy(BufferSize)
            self.MemoryFormat = copy.deepcopy(MemoryFormat)
            self.TilingFormat = copy.deepcopy(TilingFormat)
            self.Offset       = copy.deepcopy(Offset)
            self.BitsPerCycle = copy.deepcopy(BitsPerCycle)
            self.ReEnqueueCnt = copy.deepcopy(ReEnqueueCnt)
            self.tiling_iters = copy.deepcopy(tiling_iters)

        def check_tiling_dimension(self, TilingFmt):
            """Return the largest number of non-trivial dimensions in any tiling string.

            ``TilingFmt`` (a string or a nested list of strings) is flattened with
            numpy.  Each string is split on whitespace into ``':'``-separated
            tokens.  A token counts as a dimension when its last two fields differ.
            """
            max_dim = 0
            for fmt in np.array(TilingFmt).reshape(-1).tolist():
                dim = 0
                # split() rather than split(' '): repeated or trailing spaces would
                # otherwise yield empty tokens and an IndexError on fields[-2].
                for token in fmt.split():
                    fields = token.split(':')
                    if fields[-2] != fields[-1]:
                        dim += 1
                max_dim = max(dim, max_dim)
            return max_dim

        def dma_pattern(self, col_list):
            """Return shim transfer source text for every re-enqueue iteration and column.

            ``col_list`` is a pair ``(cols, col_indices)``.  ``cols`` are the shim
            columns; ``col_indices`` select each column's entry in
            ``TilingFormat``.  ``Repeat`` is either a list of per-iteration repeat
            lists or a single list reused for every iteration.
            """
            repeat = self.Repeat if isinstance(self.Repeat[0], list) else [self.Repeat] * self.ReEnqueueCnt
            if len(self.MemoryFormat) == 1:
                emit = self._single_format_transfer
            elif self.check_tiling_dimension(self.TilingFormat) <= 3:
                emit = self._packed_data_transfer
            else:
                emit = self._packed_shim_data_transfer
            code = ''
            for itr in range(self.ReEnqueueCnt):
                for col, col_idx in zip(col_list[0], col_list[1]):
                    code += emit(repeat[itr], col, col_idx, itr)
            return code

        def _single_format_transfer(self, rep, col, col_idx, itr):
            """One ``generate_shim_data_transfer`` call for the single-format case."""
            code = (f"    shim_transfers += [generate_shim_data_transfer({BaseTemplate.shorten_val(self, rep)}, "
                    f"shim_dma({col}, DmaDir.{self.DmaDir}, {self.ChannelId}), {self.BufferIdx}, "
                    f"'{self.MemoryFormat[0]}', '{self.TilingFormat[0][col_idx][itr]}', {self.BitsPerCycle}")
            code += f", buffer_offset={self.Offset})]\n" if self.Offset != 0 else ")]\n"
            return code

        def _pack_args(self, col_idx, itr):
            """Shared ``formats, tilings, iters, bits`` argument text for packed transfers."""
            return (f"{BaseTemplate.shorten_val(self, self.MemoryFormat)}, "
                    f"{BaseTemplate.shorten_val(self, np.array(self.TilingFormat)[:, col_idx, itr].tolist())}, "
                    f"{BaseTemplate.shorten_val(self, self.tiling_iters)}, {self.BitsPerCycle}")

        def _packed_data_transfer(self, rep, col, col_idx, itr):
            """One ``DataTransfer`` whose S2MM or MM2S list holds a ``pack_transfers`` call."""
            offset = f", buffer_offset={self.Offset}" if self.Offset != 0 else ""
            pack = (f"pack_transfers(shim_dma({col}, DmaDir.{self.DmaDir}, {self.ChannelId}), "
                    f"{self._pack_args(col_idx, itr)}{offset})")
            # The packed transfer goes in the S2MM slot for S2MM, otherwise in MM2S.
            s2mm, mm2s = (f"[{pack}]", "[]") if self.DmaDir == 'S2MM' else ("[]", f"[{pack}]")
            return (f"    shim_transfers += [DataTransfer({BaseTemplate.shorten_val(self, rep)}, "
                    f"AieTile(TileType.Shim, {col}), [{self.BufferIdx}], {self.BufferSize}, "
                    f"{s2mm}, {mm2s})]\n")

        def _packed_shim_data_transfer(self, rep, col, col_idx, itr):
            """One ``generate_packed_shim_data_transfer`` call (more than 3 tiling dimensions)."""
            return (f"    shim_transfers += [generate_packed_shim_data_transfer({BaseTemplate.shorten_val(self, rep)}, "
                    f"shim_dma({col}, DmaDir.{self.DmaDir}, {self.ChannelId}), {self.BufferIdx}, "
                    f"{self._pack_args(col_idx, itr)}, 4, buffer_offset={self.Offset})]\n")



    class MemTransfers():
        """Generate memtile ``DataTransfer`` source text for one tensor.

        Each ``dma_*_pattern`` method returns Python source that appends one
        ``DataTransfer`` per column to ``memtile_transfers``.  ``col_list`` is a
        pair ``(cols, col_indices)``: ``cols`` are the memtile columns used in
        ``AieTile``/``memtile_dma``, and ``col_indices`` select that column's
        entry in ``MemoryFormat``/``S2MM_Format``/``MM2S_Format``.

        All constructor arguments are deep-copied.  When
        ``ReuseRatio // reuse_chain_length == 0`` the chain length is reset to
        ``ReuseRatio``.
        """
        def __init__(
            self,
            ChannelId,
            Repeat,
            BufferSize,
            MemoryFormat,
            MM2S_Format,
            S2MM_Format,
            reuse_chain_length,
            ReuseRatio,
            PingAddr,
            PongAddr,
            PingPongEnable,
            BitsPerCycle,
            sync_strategy,
            enable_padding,
            tiling_iters,
            buffer_offset=0,
            bd_chain_length=1,
        ):
            self.ChannelId          = copy.deepcopy(ChannelId)
            self.Repeat             = copy.deepcopy(Repeat)
            self.BufferSize         = copy.deepcopy(BufferSize)
            self.MemoryFormat       = copy.deepcopy(MemoryFormat)
            self.MM2S_Format        = copy.deepcopy(MM2S_Format)
            self.S2MM_Format        = copy.deepcopy(S2MM_Format)
            self.reuse_chain_length = copy.deepcopy(reuse_chain_length)
            self.ReuseRatio         = copy.deepcopy(ReuseRatio)
            self.PingAddr           = copy.deepcopy(PingAddr)
            self.PongAddr           = copy.deepcopy(PongAddr)
            self.PingPongEnable     = copy.deepcopy(PingPongEnable)
            self.BitsPerCycle       = copy.deepcopy(BitsPerCycle)
            self.sync_strategy      = copy.deepcopy(sync_strategy)
            self.enable_padding     = copy.deepcopy(enable_padding)
            self.tiling_iters       = copy.deepcopy(tiling_iters)
            self.buffer_offset      = copy.deepcopy(buffer_offset)
            self.bd_chain_length    = copy.deepcopy(bd_chain_length)

            # A reuse chain longer than the reuse ratio is shortened to the ratio.
            if self.ReuseRatio // self.reuse_chain_length == 0:
                self.reuse_chain_length = self.ReuseRatio

        # -- shared source-text helpers ---------------------------------------

        def _buffer_addrs(self):
            """Return ``[PingAddr, PongAddr]`` when ping-pong is enabled, else ``[PingAddr]``."""
            return [self.PingAddr, self.PongAddr] if self.PingPongEnable else [self.PingAddr]

        def _memory_format(self, col_idx, chain=0):
            """Return ``MemoryFormat[chain][col_idx]`` if ``MemoryFormat[0]`` is a list, else ``MemoryFormat[0]``."""
            if isinstance(self.MemoryFormat[0], list):
                return self.MemoryFormat[chain][col_idx]
            return self.MemoryFormat[0]

        def _padding_suffix(self):
            """Return the ``enable_padding`` kwarg text for MM2S reads ('' when disabled)."""
            return "" if self.enable_padding == 0 else ", enable_padding=True"

        def _gen_params(self, col, direction, channel, mem_fmt, tiling, suffix=""):
            """Return one ``generate_transfer_params(...)`` call as source text."""
            return (f"generate_transfer_params(memtile_dma({col}, DmaDir.{direction}, {channel}), "
                    f"'{mem_fmt}', '{tiling}', {self.BitsPerCycle}{suffix})")

        def _pack(self, col, direction, channel, mem_fmts, tilings, suffix=""):
            """Return one ``pack_transfers(...)`` call as source text."""
            return (f"pack_transfers(memtile_dma({col}, DmaDir.{direction}, {channel}), "
                    f"{BaseTemplate.shorten_val(self, mem_fmts)}, "
                    f"{BaseTemplate.shorten_val(self, tilings)}, "
                    f"{BaseTemplate.shorten_val(self, self.tiling_iters)}, {self.BitsPerCycle}{suffix})")

        def _data_transfer(self, col, s2mm, mm2s, compact_tail=False):
            """Return one ``memtile_transfers += [DataTransfer(...)]`` statement.

            ``s2mm``/``mm2s`` are lists of transfer-call source strings.
            ``sync_strategy=`` is emitted unless it equals ``'Default'``.
            ``reuse_ratio=`` is emitted unless
            ``ReuseRatio // reuse_chain_length == 1``.  ``compact_tail``
            reproduces the output patterns' layout, which places those kwargs
            without a newline/indent before the closing ``)]``.
            """
            sep = ",\n             "
            nl, pad = ("", "") if compact_tail else ("\n", " " * 12)
            code = ("    memtile_transfers += [\n"
                    f"        DataTransfer({BaseTemplate.shorten_val(self, self.Repeat)}, AieTile(TileType.Memtile, {col}), "
                    f"{self._buffer_addrs()}, {self.BufferSize},\n"
                    f"            [{sep.join(s2mm)}],\n"
                    f"            [{sep.join(mm2s)}],\n")
            if self.sync_strategy != 'Default':
                code += f"            sync_strategy={self.sync_strategy},{nl}"
            ratio = self.ReuseRatio // self.reuse_chain_length
            code += pad + (")]\n" if ratio == 1 else f"reuse_ratio={ratio})]\n")
            return code

        # -- broadcast patterns (all MM2S reads on channel 4) --------------------

        def dma_broadcast_pattern(self, col_list, row_list):
            """Dispatch on the first dimension of ``MemoryFormat`` (1 -> 1-D, else 2-D)."""
            if np.array(self.MemoryFormat).shape[0]==1:
                return self.dma_broadcast_pattern_1d(col_list, row_list)
            else:
                return self.dma_broadcast_pattern_2d(col_list, row_list)

        def dma_broadcast_pattern_1d(self, col_list, row_list):
            """Single-format broadcast with ``generate_transfer_params`` calls.

            The MM2S read list is repeated ``reuse_chain_length`` times.  When
            ``MM2S_Format[0][0][0]`` is itself a list, each repetition emits one
            read per entry of that list.  ``row_list`` is unused.
            """
            code = ''
            sep = ",\n             "
            for col, col_idx in zip(col_list[0], col_list[1]):
                mem_fmt = self._memory_format(col_idx)
                pad = self._padding_suffix()
                mm2s_transfers = []
                for _ in range(self.reuse_chain_length):
                    if isinstance(self.MM2S_Format[0][0][0], list):
                        # Fixed: the memory format was previously rendered from the
                        # tuple (self, MemoryFormat) instead of the format string.
                        mm2s_transfers.append(sep.join(
                            self._gen_params(col, 'MM2S', 4, mem_fmt, self.MM2S_Format[0][col_idx][0][itr], pad)
                            for itr in range(len(self.MM2S_Format[0][0][0]))))
                    else:
                        mm2s_transfers.append(
                            self._gen_params(col, 'MM2S', 4, mem_fmt, self.MM2S_Format[0][col_idx][0], pad))
                s2mm = [self._gen_params(col, 'S2MM', self.ChannelId, mem_fmt, self.S2MM_Format[0][col_idx])]
                code += self._data_transfer(col, s2mm, mm2s_transfers)
            return code

        def dma_broadcast_pattern_2d(self, col_list, row_list):
            """Multi-format broadcast with ``pack_transfers`` calls; ``row_list`` is unused."""
            code = ''
            sep = ",\n             "
            for col, col_idx in zip(col_list[0], col_list[1]):
                mem_fmts = np.array(self.MemoryFormat)[:, col_idx].tolist()
                mm2s_transfers = []
                for _ in range(self.reuse_chain_length):
                    if isinstance(self.MM2S_Format[0][0][0], list):
                        mm2s_transfers.append(sep.join(
                            self._pack(col, 'MM2S', 4, mem_fmts,
                                       np.array(self.MM2S_Format)[:, col_idx, 0, itr].tolist())
                            for itr in range(len(self.MM2S_Format[0][0][0]))))
                    else:
                        mm2s_transfers.append(
                            self._pack(col, 'MM2S', 4, mem_fmts, np.array(self.MM2S_Format)[:, col_idx, 0].tolist()))
                s2mm = [self._pack(col, 'S2MM', self.ChannelId, mem_fmts,
                                   np.array(self.S2MM_Format)[:, col_idx].tolist())]
                code += self._data_transfer(col, s2mm, mm2s_transfers)
            return code

        # -- unicast patterns (MM2S read channel == row) ----------------------------

        def dma_unicast_pattern(self, col_list, row_list):
            """Dispatch to the BD-chain (``bd_chain_length == 2``) or plain unicast generators."""
            if self.bd_chain_length == 2:
                if np.array(self.MemoryFormat).shape[0]//2==1:
                    return self.dma_unicast_unicast_pattern_1d(col_list, row_list)
                else:
                    return self.dma_unicast_unicast_pattern_2d(col_list, row_list)
            elif np.array(self.MemoryFormat).shape[0]==1:
                return self.dma_unicast_pattern_1d(col_list, row_list)
            else:
                return self.dma_unicast_pattern_2d(col_list, row_list)

        def dma_unicast_pattern_1d(self, col_list, row_list):
            """Single-format unicast: one MM2S read per row, repeated ``reuse_chain_length`` times."""
            code = ''
            for col, col_idx in zip(col_list[0], col_list[1]):
                mem_fmt = self._memory_format(col_idx)
                pad = self._padding_suffix()
                mm2s = [self._gen_params(col, 'MM2S', row, mem_fmt, self.MM2S_Format[0][col_idx][row], pad)
                        for row in row_list
                        for _ in range(self.reuse_chain_length)]
                s2mm = [self._gen_params(col, 'S2MM', self.ChannelId, mem_fmt, self.S2MM_Format[0][col_idx])]
                code += self._data_transfer(col, s2mm, mm2s)
            return code

        def dma_unicast_pattern_2d(self, col_list, row_list):
            """Multi-format unicast with ``pack_transfers``; reads repeat per ``reuse_chain_length``."""
            code = ''
            for col, col_idx in zip(col_list[0], col_list[1]):
                mem_fmts = np.array(self.MemoryFormat)[:, col_idx].tolist()
                mm2s = [self._pack(col, 'MM2S', row, mem_fmts, np.array(self.MM2S_Format)[:, col_idx, row].tolist())
                        for row in row_list
                        for _ in range(self.reuse_chain_length)]
                s2mm = [self._pack(col, 'S2MM', self.ChannelId, mem_fmts,
                                   np.array(self.S2MM_Format)[:, col_idx].tolist())]
                code += self._data_transfer(col, s2mm, mm2s)
            return code

        def dma_unicast_unicast_pattern_1d(self, col_list, row_list):
            """BD-chained single-format unicast.

            Each chain entry ``chain`` gets its own S2MM write on channel
            ``ChannelId + chain`` and its own ``buffer_offset[chain]``.
            """
            code = ''
            for col, col_idx in zip(col_list[0], col_list[1]):
                pad = self._padding_suffix()
                mm2s = []
                s2mm = []
                for chain in range(self.bd_chain_length):
                    mem_fmt = self._memory_format(col_idx, chain)
                    offset = f", buffer_offset={self.buffer_offset[chain]}"
                    # NOTE(review): unlike dma_unicast_unicast_pattern_2d, reads are not
                    # repeated reuse_chain_length times here, although reuse_ratio is
                    # still divided by it — confirm this is intended.
                    mm2s.extend(self._gen_params(col, 'MM2S', row, mem_fmt,
                                                 self.MM2S_Format[chain][col_idx][row], offset + pad)
                                for row in row_list)
                    s2mm.append(self._gen_params(col, 'S2MM', self.ChannelId + chain, mem_fmt,
                                                 self.S2MM_Format[chain][col_idx], offset))
                code += self._data_transfer(col, s2mm, mm2s)
            return code

        def dma_unicast_unicast_pattern_2d(self, col_list, row_list):
            """BD-chained multi-format unicast with ``pack_transfers``.

            The format tables' first axis is split evenly across the
            ``bd_chain_length`` chain entries.
            """
            code = ''
            n_fmt = len(self.MemoryFormat)
            for col, col_idx in zip(col_list[0], col_list[1]):
                mm2s = []
                s2mm = []
                for chain in range(self.bd_chain_length):
                    # Rows of the format tables that belong to this chain entry.
                    start = int(((chain)/self.bd_chain_length)*n_fmt)
                    stop = int(((chain+1)/self.bd_chain_length)*n_fmt)
                    mem_fmts = np.array(self.MemoryFormat)[start:stop, col_idx].tolist()
                    offset = f", buffer_offset={self.buffer_offset[chain]}"
                    for row in row_list:
                        for _ in range(self.reuse_chain_length):
                            mm2s.append(self._pack(col, 'MM2S', row, mem_fmts,
                                                   np.array(self.MM2S_Format)[start:stop, col_idx, row].tolist(),
                                                   offset))
                    s2mm.append(self._pack(col, 'S2MM', self.ChannelId + chain, mem_fmts,
                                           np.array(self.S2MM_Format)[start:stop, col_idx].tolist(), offset))
                code += self._data_transfer(col, s2mm, mm2s)
            return code

        # -- output patterns (S2MM write channel == row + ChannelId - 3) ------------

        def dma_unicast_out_pattern(self, col_list, row_list):
            """Dispatch on ``len(MemoryFormat)`` (1 -> 1-D, else 2-D)."""
            if len(self.MemoryFormat)==1:
                return self.dma_unicast_out_pattern_1d(col_list, row_list)
            else:
                return self.dma_unicast_out_pattern_2d(col_list, row_list)

        def dma_unicast_out_pattern_1d(self, col_list, row_list):
            """Single-format output: per-row S2MM writes, one MM2S read on ``ChannelId``."""
            code = ''
            for col, col_idx in zip(col_list[0], col_list[1]):
                s2mm = [self._gen_params(col, 'S2MM', row + self.ChannelId - 3, self.MemoryFormat[0],
                                         self.S2MM_Format[0][col_idx][row])
                        for row in row_list
                        for _ in range(self.reuse_chain_length)]
                mm2s = [self._gen_params(col, 'MM2S', self.ChannelId, self.MemoryFormat[0],
                                         self.MM2S_Format[0][col_idx])]
                code += self._data_transfer(col, s2mm, mm2s, compact_tail=True)
            return code

        def dma_unicast_out_pattern_2d(self, col_list, row_list):
            """Multi-format output with ``pack_transfers``; the full ``MemoryFormat`` is passed."""
            code = ''
            for col, col_idx in zip(col_list[0], col_list[1]):
                s2mm = [self._pack(col, 'S2MM', row + self.ChannelId - 3, self.MemoryFormat,
                                   np.array(self.S2MM_Format)[:, col_idx, row].tolist())
                        for row in row_list
                        for _ in range(self.reuse_chain_length)]
                mm2s = [self._pack(col, 'MM2S', self.ChannelId, self.MemoryFormat,
                                   np.array(self.MM2S_Format)[:, col_idx].tolist())]
                code += self._data_transfer(col, s2mm, mm2s, compact_tail=True)
            return code

