from typing import List, Dict, Optional, Tuple, Callable, Type, Deque, Union
from enum import Enum
import sys
import os
import logging
import pdb
from string import Template
import scheduler_utils as utils
import OGOAT.src.Scheduling_Engine.schedules.scheduler as schedule
import template_selector
from BufferAllocatorResult import get_buffer_allocations, BufferAllocations
import const

# Helper-module directories that sit next to this file. They are appended to
# sys.path so modules inside them can be imported by bare name.
# NOTE(review): this runs after the imports at the top of the file, so it only
# affects imports executed later; confirm none of those earlier imports rely
# on these directories.
infra_path, code_gen_path = (
    os.path.dirname(os.path.abspath(__file__)) + "/infra/",
    os.path.dirname(os.path.abspath(__file__)) + "/code_gen/",
)
sys.path.extend([infra_path, code_gen_path])

class AccessPattern(schedule.Stage):
    def __init__(self, artifacts_dict):
        self._program_args     = artifacts_dict['program_arg_obj']
        self._overlay_params   = artifacts_dict['overlay_info_obj']
        self._layer_params     = artifacts_dict['layer_info_obj']
        self._kernel_params    = artifacts_dict['kernel_info_obj']
        self._schedule_params  = artifacts_dict['scheduling_obj']
        self._mem_tile_params  = artifacts_dict['mem_tile_params_obj']
        self._core_tile_params = artifacts_dict['core_tile_params_obj']
        self._shim_tile_params  = artifacts_dict['shim_tile_params_obj']
        self._padding_params   = artifacts_dict['layer_padding_obj']
        self.param_ptr         = {
                                    'orig_op_type': self._layer_params.get_value('orig_op_type'),
                                    'op_type': self._layer_params.get_value('op_type')
                                  }

        self.template_meta_data = {
                                    'op_type': self._layer_params.get_value('orig_op_type'),
                                    'ver': self._layer_params.get_value('op_type'),
                                    'mode': self._overlay_params.get_value('mode'),
                                    'overlay': self._overlay_params.get_value('overlay')
                                  }
        
        self.fast_pm            = self._program_args["fast_pm"]

    def execute(self, _pipeline_data):
        _pipeline_data.info['fast_pm'] = self.fast_pm
        op_name = self.param_ptr['orig_op_type'].split('_')[0]
        if op_name == "PWLA":
            self.execute_SiluGelu(_pipeline_data)
        elif op_name in ["MatMul"]:
            self.execute_MatMul(_pipeline_data)
        elif op_name in ["Add", "Mul"]:
            self.execute_MatAdd(_pipeline_data)
        elif op_name in ["RoPE"]:
            self.execute_RoPE(_pipeline_data)
        elif op_name in ["Conv"]:
            self.execute_Conv(_pipeline_data)
        elif op_name in ["LayerNormalization", "Softmax"]:
            self.execute_layernorm(_pipeline_data)
        elif op_name in ["GroupNormalization"]:
            self.execute_groupnorm(_pipeline_data)
        elif op_name in ["LpNormalization"]:
            self.execute_lpnorm(_pipeline_data)
        elif op_name == "MHA":
            self.execute_MHAHead(_pipeline_data)
        else:
            utils.sanity_check(False,f"Invalid opcode: {op_name}")

    def print_data_flow(self, instr):
        if not self._program_args.get("call_DMAC", False):
            with open(self._program_args['output_dir']+"/data_flow.py", "w") as text_file:
                text_file.write(instr)
        else:
            import types
            original_dir = os.getcwd()
            os.chdir(self._program_args['output_dir'])
            mod = types.ModuleType("data_flow")
            mod.__dict__['__file__'] = "data_flow.py"
            exec(instr, mod.__dict__)
            mod.main()
            os.chdir(original_dir)

    def execute_MHAHead(self, _pipeline_data):
        """Dump the buffer-allocation results for an MHA head.

        The allocator output carried by ``_pipeline_data`` is converted into a
        ``BufferAllocations`` object and written to the program's output
        directory for the subsequent data-flow stage.
        """
        get_buffer_allocations(_pipeline_data).dump(self._program_args["output_dir"])

    def execute_RoPE(self, _pipeline_data):
        """Generate and emit the data-flow program for a RoPE layer.

        The selected template's ``prm`` helper derives the DMA parameters from
        the pipeline data, and ``gen_code`` renders them into source text.
        """
        template = template_selector.get_template_obj(self.template_meta_data, _pipeline_data)
        self.code_template = template
        dma_params = template.prm.gen_dma_params(_pipeline_data)
        self.print_data_flow(template.gen_code(dma_params))

    def execute_MatAdd(self, _pipeline_data):
        """Generate and emit the data-flow program for an element-wise Add/Mul layer.

        DMA parameters come from the template's ``prm`` helper and are rendered
        with the template's ``gen_code``.
        """
        self.code_template = template_selector.get_template_obj(self.template_meta_data, _pipeline_data)
        prm_builder = self.code_template.prm
        dma_params = prm_builder.gen_dma_params(_pipeline_data)
        generated = self.code_template.gen_code(dma_params)
        self.print_data_flow(generated)

    def execute_SiluGelu(self, _pipeline_data):
        """Generate and emit the data-flow program for a PWLA (SiLU/GELU) layer.

        Unpacks the buffer-allocator result stored in
        ``_pipeline_data.info['BuffAllocator']`` (indexed by
        ``const.BufAllocator_Idx``) into a flat parameter dict that covers the
        core, mem and shim tiles. The dict is rendered with the template chosen
        by ``template_selector``, and the result goes to ``print_data_flow``.
        """
        self.code_template = template_selector.get_template_obj(self.template_meta_data, _pipeline_data)
        buff           = _pipeline_data.info.get('BuffAllocator')
        core_tile_addr = buff[const.BufAllocator_Idx.CORE_TILE_ADDR_IDX.value]
        core_tile_size = buff[const.BufAllocator_Idx.CORE_TILE_SIZE_IDX.value]
        mem_tile_addr  = buff[const.BufAllocator_Idx.MEM_TILE_ADDR_IDX.value]
        buff_prm       = buff[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]

        # Values shared by several entries below, computed once.
        core_ifm_addr = core_tile_addr['CoreIfmAddr']
        core_ofm_addr = core_tile_addr['CoreOfmAddr']
        lut_size      = core_tile_size['CoreLutabSize'] + core_tile_size['CoreLutcdSize']
        ifm_itr       = buff_prm['mem_tile'].itr['ifm'][0]
        ifm_mem_size  = buff_prm['ifm_mem_tile_size']
        ofm_mem_size  = buff_prm['ofm_mem_tile_size']
        ifm_shim_size = buff_prm['ifm_shim_tile_size']
        ofm_shim_size = buff_prm['ofm_shim_tile_size']

        params = {
            # For debug
            'ParamSize': const.PARAM_SIZE,
            # The N dimension is taken from the allocator's K / K_subV entries.
            'Mgemm': buff_prm['M'],
            'Ngemm': buff_prm['K'],
            'Msubv': buff_prm['M_subV'],
            'Nsubv': buff_prm['K_subV'],
            'ifm_byte_len': buff_prm['ifm'].bytes,
            'ofm_byte_len': buff_prm['ofm'].bytes,
            # For core tile. A pong address is None when the address list has
            # a single entry.
            'CoreActPingAddr': core_ifm_addr[0],
            'CoreActPongAddr': None if len(core_ifm_addr) < 2 else core_ifm_addr[1],
            'CoreLutabAddr': core_tile_addr['CoreLutabAddr'],
            'CoreLutcdAddr': core_tile_addr['CoreLutcdAddr'],
            'CoreQdqAddr': core_tile_addr['CoreQdqAddr'][0],
            'CoreTdmAddr': core_tile_addr['CoreTdmAddr'],
            'CoreLutSize': lut_size,
            'CoreOutPingAddr': core_ofm_addr[0],
            'CoreOutPongAddr': None if len(core_ofm_addr) < 2 else core_ofm_addr[1],
            'CoreActSize': core_tile_size['CoreIfmSize'],
            'CoreLutabSize': core_tile_size['CoreLutabSize'],
            'CoreLutcdSize': core_tile_size['CoreLutcdSize'],
            'CoreTdmSize': core_tile_size['CoreTdmSize'],
            'CoreQdqSize': core_tile_size['CoreQdqSize'],
            'CoreOutSize': core_tile_size['CoreOfmSize'],
            'Core_Tn': ifm_itr,
            # For memtile ("Words" entries are byte sizes // 4)
            'MemtilePrmPingAddr': mem_tile_addr['MemtilePrmAddr'],
            'CorePrmSize': const.PARAM_SIZE,
            'CorePrmWords': const.PARAM_SIZE // 4,
            'MemtileLutAddr': mem_tile_addr['MemtileLutAddr'],
            'MemtileQdqAddr': mem_tile_addr['MemtileQdqAddr'],
            'MemtileLutWords': lut_size // 4,
            'Mem_Tn': ifm_itr,
            'MemtileActPingAddr': mem_tile_addr['MemtileIfmPingAddr'],
            'MemtileActPongAddr': mem_tile_addr['MemtileIfmPongAddr'],
            'MemtileActSize': ifm_mem_size,
            'MemtileActWords': ifm_mem_size // 4,
            'MemtileActmm2swords': (ifm_mem_size // 4) // buff_prm['aie_rows'],
            'MemtileOutPingAddr': mem_tile_addr['MemtileOfmPingAddr'],
            'MemtileOutPongAddr': mem_tile_addr['MemtileOfmPongAddr'],
            # NOTE(review): assumes an equal split of the OFM across cores.
            'CoreOutWords': core_tile_size['CoreOfmSize'] // 4,
            'MemtileOutSize': ofm_mem_size,
            'MemtileOutWords': ofm_mem_size // 4,
            # For shim tile
            'Shim_Tn': ifm_itr,
            'ShimLutWords': lut_size // 4,
            'ShimActSize': ifm_shim_size,
            'ShimActWords': ifm_shim_size // 4,
            'ShimOutSize': ofm_shim_size,
            'ShimOutWords': ofm_shim_size // 4,
            # Misc
            'aie_cols': buff_prm['aie_cols'],
            'aie_rows': buff_prm['aie_rows'],
            'overlay': buff_prm['overlay'],
            'active_col': buff_prm['active_col'],
            'backend_type': 'BackEnd.Adf',
            'CoreStackAddr': core_tile_addr['CoreStackAddr'][0],
        }
        data_flow = self.code_template.gen_code(params)
        self.print_data_flow(data_flow)


    def execute_layernorm(self, _pipeline_data):
        """Generate and emit the data-flow program for LayerNormalization or Softmax.

        The op code given to the template is 'LRN' for LayerNormalization and
        the op-name prefix itself otherwise. The M/N sizes come from the
        allocator's IFM dims; tilings come from the core/mem/shim tile params.
        """
        self.code_template = template_selector.get_template_obj(self.template_meta_data, _pipeline_data)
        alloc = _pipeline_data.info.get('BuffAllocator')
        buff_prm = alloc[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]
        overlay_params = self._overlay_params
        aie_inst = overlay_params.get_value('aieinst')
        overlay = overlay_params.get_value('overlay')
        op_name = self.param_ptr['orig_op_type'].split('_')[0]
        opc = 'LRN' if op_name == 'LayerNormalization' else op_name

        params = dict(
            op_code=opc,
            op_ver=self.param_ptr['op_type'],
            ParamSize=const.PARAM_SIZE,
            Mlrn=buff_prm['ifm'].dim[0],
            Nlrn=buff_prm['ifm'].dim[1],
            TdimLayer=self._kernel_params.get_value('tdim_layer'),
            IfmBytesPerElem=self._layer_params.get_value('in_ifm_bytes'),
            OfmBytesPerElem=self._layer_params.get_value('out_ofm_bytes'),
            Kgran=self._kernel_params.get_value('kernel_gran'),
            CoreTilings=self._core_tile_params.get_value('subvols')['ifm'],
            MemTilings_ifm_s2mm=self._mem_tile_params.get_value('subvols')['ifm']['s2mm'],
            MemTilings_ifm_mm2s=self._mem_tile_params.get_value('subvols')['ifm']['mm2s'],
            ShimTilings=self._shim_tile_params.get_value('subvols')['ifm'],
            SplitType=self._kernel_params.get_value('split_type'),
            Overlay=overlay,
            AieInst=aie_inst,
            AieCols=overlay_params.get_value('shape')['col'],
            AieRows=overlay_params.get_value('shape')['row'],
        )
        self.print_data_flow(self.code_template.gen_code(params))
    
    def execute_groupnorm(self, _pipeline_data):
        """Generate and emit the data-flow program for a GroupNormalization layer.

        ``AieInst`` is 2 on an overlay shaped 8 columns x 4 rows and 1
        otherwise. Core sub-volumes and iteration counts are passed for both
        the 'ifm' tiling and the 'ofm' tiling (the '*Norm' entries).
        """
        self.code_template = template_selector.get_template_obj(self.template_meta_data, _pipeline_data)
        alloc = _pipeline_data.info.get('BuffAllocator')
        buff_prm = alloc[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]
        overlay = self._overlay_params.get_value('overlay')
        is_8x4 = (self._overlay_params.get_value('shape')['col'] == 8
                  and self._overlay_params.get_value('shape')['row'] == 4)
        aie_inst = 2 if is_8x4 else 1

        core = self._core_tile_params
        params = dict(
            op_code='GPN',
            op_ver=self.param_ptr['op_type'],
            ParamSize=const.PARAM_SIZE,
            Mlrn=buff_prm['ifm'].dim[0],
            Nlrn=buff_prm['ifm'].dim[1],
            TdimLayer=self._kernel_params.get_value('tdim_layer'),
            IfmBytesPerElem=self._layer_params.get_value('in_ifm_bytes'),
            OfmBytesPerElem=self._layer_params.get_value('out_ofm_bytes'),
            CoreMsubv=core.get_value('subvols')['ifm'][0],
            CoreNsubv=core.get_value('subvols')['ifm'][1],
            CoreTsubv=core.get_value('iters')['ifm'][0],
            CoreMsubvNorm=core.get_value('subvols')['ofm'][0],
            CoreNsubvNorm=core.get_value('subvols')['ofm'][1],
            CoreTsubvNorm=core.get_value('iters')['ofm'][0],
            MemMsubv=self._mem_tile_params.get_value('subvols')['ifm'][0],
            MemNsubv=self._mem_tile_params.get_value('subvols')['ifm'][1],
            Overlay=overlay,
            AieInst=aie_inst,
            AieCols=self._overlay_params.get_value('shape')['col'],
            AieRows=self._overlay_params.get_value('shape')['row'],
        )
        self.print_data_flow(self.code_template.gen_code(params))
    
    def execute_lpnorm(self, _pipeline_data):
        """Generate and emit the data-flow program for an LpNormalization layer.

        Uses the same parameter layout as the LayerNorm/Softmax handler, with
        the fixed op code 'LpNorm'.
        """
        self.code_template = template_selector.get_template_obj(self.template_meta_data, _pipeline_data)
        alloc = _pipeline_data.info.get('BuffAllocator')
        buff_prm = alloc[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]
        aie_inst = self._overlay_params.get_value('aieinst')
        overlay = self._overlay_params.get_value('overlay')

        kernel = self._kernel_params
        layer = self._layer_params
        ifm_mem_tilings = None  # filled per-direction below
        params = dict(
            op_code='LpNorm',
            op_ver=self.param_ptr['op_type'],
            ParamSize=const.PARAM_SIZE,
            Mlrn=buff_prm['ifm'].dim[0],
            Nlrn=buff_prm['ifm'].dim[1],
            TdimLayer=kernel.get_value('tdim_layer'),
            IfmBytesPerElem=layer.get_value('in_ifm_bytes'),
            OfmBytesPerElem=layer.get_value('out_ofm_bytes'),
            Kgran=kernel.get_value('kernel_gran'),
            CoreTilings=self._core_tile_params.get_value('subvols')['ifm'],
            MemTilings_ifm_s2mm=self._mem_tile_params.get_value('subvols')['ifm']['s2mm'],
            MemTilings_ifm_mm2s=self._mem_tile_params.get_value('subvols')['ifm']['mm2s'],
            ShimTilings=self._shim_tile_params.get_value('subvols')['ifm'],
            SplitType=kernel.get_value('split_type'),
            Overlay=overlay,
            AieInst=aie_inst,
            AieCols=self._overlay_params.get_value('shape')['col'],
            AieRows=self._overlay_params.get_value('shape')['row'],
        )
        del ifm_mem_tilings
        generated = self.code_template.gen_code(params)
        self.print_data_flow(generated)


    def execute_MatMul(self, _pipeline_data):
        """Generate and emit the data-flow program for a MatMul layer.

        Unlike the RoPE/Add/Mul handlers, the MatMul template exposes
        ``gen_dma_params`` and ``gen_instr`` directly on the template object.
        """
        template = template_selector.get_template_obj(self.template_meta_data, _pipeline_data)
        self.code_template = template
        dma_params = template.gen_dma_params(_pipeline_data)
        instr = template.gen_instr(dma_params)
        self.print_data_flow(instr)

    def execute_Conv(self, _pipeline_data):
        def update_buffer_alloc_to_params(params, buff):
            """Merge buffer-allocator output and fixed DMA constants into ``params``.

            Adds the parameter/bank-size constants and the shim buffer slot
            indices, then merges the core-tile address/size dicts and the
            mem-tile address dict from ``buff``. Finally it copies the
            scheduler's dataflow and ping-pong attributes, the tile sizes, the
            dtypes and the decoded data formats. Mutates ``params`` in place.
            """
            alloc_idx = const.BufAllocator_Idx
            buff_prm = buff[alloc_idx.BUFF_ALLOC_PARAM_IDX.value]
            params['ParamSize'] = const.PARAM_SIZE
            params['CoreBankSize'] = const.CORE_BANK_SIZE
            # Shim buffer slots are numbered in this fixed order.
            for slot, key in enumerate(('ShimOutBufferIdx', 'ShimActBufferIdx',
                                        'ShimWgtBufferIdx', 'ShimPrmBufferIdx')):
                params[key] = slot
            # Allocator address/size dicts are merged wholesale.
            for part in (alloc_idx.CORE_TILE_ADDR_IDX,
                         alloc_idx.CORE_TILE_SIZE_IDX,
                         alloc_idx.MEM_TILE_ADDR_IDX):
                params.update(buff[part.value])
            sch_attr = buff_prm['sch_attr']
            params['ping_pong_enable'] = sch_attr.ping_pong_enable
            params['ActDataFlow'] = sch_attr.dataflow_mode['ifm']
            params['WgtDataFlow'] = sch_attr.dataflow_mode['wgt']
            params['ActPingPong'] = sch_attr.ping_pong_enable['ifm']
            params['WgtPingPong'] = sch_attr.ping_pong_enable['wgt']
            params['OfmPingPong'] = sch_attr.ping_pong_enable['ofm']
            params['ShimActSize'] = buff_prm['ifm_shim_tile_size']
            params['ShimWgtSize'] = buff_prm['wgt_shim_tile_size']
            params['MemtileActSize'] = buff_prm['ifm_mem_tile_size']
            params['MemtileWgtSize'] = buff_prm['wgt_mem_tile_size']
            params['MemtileOutSize'] = buff_prm['ofm_mem_tile_size']
            ifm, wgt, ofm = buff_prm['ifm'], buff_prm['wgt'], buff_prm['ofm']
            params['ActDtype'] = ifm.dtype
            params['WgtDtype'] = wgt.dtype
            params['OfmDtype'] = ofm.dtype
            params['ActFormat'] = decode_data_format(ifm.dtype)
            params['WgtFormat'] = decode_data_format(wgt.dtype)
            params['OfmFormat'] = decode_data_format(ofm.dtype)
            params['min_cols'] = 8
            params['min_rows'] = 8

        def decode_data_format(dtype):
            """Map an element dtype name to the data-format tag used by the templates.

            'bfp16' -> 'BFP_A8', 'bf16' -> 'BF16_A8'; anything else -> 'default'.
            """
            for name, fmt in (("bfp16", "BFP_A8"), ("bf16", "BF16_A8")):
                if dtype == name:
                    return fmt
            return "default"

        def calc_rep_params(params):
            """Fill in the BD repeat counts, reuse ratios and padding flags in ``params``.

            Reads 'dims' (outer_loop/inner_loop/acc_loop used as Tm/Tn/Tk, plus
            Mpad and Kpad), 'ActDataFlow', 'ActPingPong', 'WgtDataFlow' and
            'WgtPingPong'. Writes the memtile/shim '*Repeat' and '*ReuseRatio'
            entries, 'ShimActOuterLoop', and the flags 'enable_padding',
            'enable_pack_shim' and 'enable_packtransfer'. Finally, every
            list-valued '*Repeat' entry is zero-padded to the length of the
            longest one. Mutates ``params`` in place.

            Raises:
                AssertionError: if 'ActDataFlow' or 'WgtDataFlow' is not one of
                    "pin", "stream", "full".
                Exception: for unsupported combinations (activation "full",
                    weight "pin", weight "full" with ping-pong).
            """
            def format_Prm_Qdq(params, MemPrmRep=None, MemQdqRep=None, ShimPrmRep=None, ShimQdqRep=None):
                # None means [1]. A fresh list is built on every call because
                # the padding step below extends these lists in place.
                params['MemtileParamRepeat']   = [1] if MemPrmRep is None else MemPrmRep
                params['MemtileQdqPrmRepeat']  = [1] if MemQdqRep is None else MemQdqRep
                params['ShimParamRepeat']      = [1] if ShimPrmRep is None else ShimPrmRep
                params['ShimQdqPrmRepeat']     = [1] if ShimQdqRep is None else ShimQdqRep

            def format_Act_repeat(params,
                                  MemActPingRep, MemActPongRep, MemActRR,
                                  ShimActOLoop, ShimActRep, ShimActRR=1,
                                  ):
                params['MemtileActPingRepeat'] = MemActPingRep
                params['MemtileActPongRepeat'] = MemActPongRep
                params['MemtileActReuseRatio'] = MemActRR
                params['ShimActOuterLoop']     = ShimActOLoop
                params['ShimActRepeat']        = ShimActRep
                params['ShimActReuseRatio']    = ShimActRR
                utils.sanity_check((sum(ShimActRep)  <= 64), f"Max Shim BD repeat is 64 (ShimActRep = {ShimActRep})", "Message")

            def format_Wgt_repeat(params, MemWgtRep, MemWgtRR, ShimWgtRep, ShimWgtRR=1):
                params['MemtileWgtRepeat']     = MemWgtRep
                params['MemtileWgtReuseRatio'] = MemWgtRR
                params['ShimWgtRepeat']        = ShimWgtRep
                params['ShimWgtReuseRatio']    = ShimWgtRR
                utils.sanity_check((sum(ShimWgtRep)  <= 64), f"Max Shim BD repeat is 64 (ShimWgtRep = {ShimWgtRep})", "Message")

            dims = params['dims']
            if dims.Mpad != 0:
                params['enable_padding']       = True
                params['enable_pack_shim']     = True
                params['enable_packtransfer']  = True
            elif dims.Kpad != 0:
                params['enable_padding']       = True
                params['enable_pack_shim']     = True
                params['enable_packtransfer']  = False
            else:
                params['enable_padding']       = False
                params['enable_pack_shim']     = True
                params['enable_packtransfer']  = False

            # Membership test: `x == "pin" or "stream" or "full"` is always
            # truthy and therefore never rejects an unknown mode.
            assert params['ActDataFlow'] in ("pin", "stream", "full"), "Activation mode: pin, Stream"
            assert params['WgtDataFlow'] in ("pin", "stream", "full"), "Weight mode: pin, Stream"

            Tm = dims.outer_loop
            Tn = dims.inner_loop
            Tk = dims.acc_loop

            format_Prm_Qdq(params)

            #activation flow configuration
            act_mode = params['ActDataFlow']
            if act_mode == "pin":
                if params['ActPingPong'] and dims.Mpad >= 0:
                    # Alternate ping/pong over the Tm outer iterations; the
                    # pattern is identical with and without M padding.
                    repeat_ping = ([1, 0] * (Tm//2)) + ([1] * (Tm % 2))
                    repeat_pong = ([0, 1] * (Tm//2)) + ([0] * (Tm % 2))
                    format_Act_repeat(params, repeat_ping, repeat_pong, Tn, 1, [1])
                elif not params['ActPingPong'] and dims.Mpad > 0:
                    repeat_ping = [Tm-1] + [0] * (Tm-3) + [1]
                    format_Act_repeat(params, repeat_ping, None, Tn, 1, [1])
                elif not params['ActPingPong'] and dims.Mpad == 0:
                    format_Act_repeat(params, [Tm], None, Tn, 1, [1])
            elif act_mode == "stream":
                n_itr = Tm * Tk * Tn
                if params['ActPingPong']:
                    repeat_ping = ([1, 0] * (n_itr//2)) + ([1] * (n_itr % 2))
                    repeat_pong = ([0, 1] * (n_itr//2)) + ([0] * (n_itr % 2))
                    format_Act_repeat(params, repeat_ping, repeat_pong, 1, Tm, [Tn])
                elif dims.Mpad > 0:
                    assert params['MInsert0']==0, "only supported appending 0 to M at this moment"
                    repeat_ping = [n_itr-1] + [1]
                    format_Act_repeat(params, repeat_ping, None, 1, Tm, [Tn])
                elif dims.Mpad == 0:
                    format_Act_repeat(params, [n_itr], None, 1, Tm, [Tn])
            elif act_mode == "full":
                raise Exception("Only Activation Pinning is supported")

            #weight flow configuration
            wgt_mode = params['WgtDataFlow']
            if wgt_mode == "pin":
                raise Exception("Weight Pinning is not supported")
            elif wgt_mode == "stream":
                format_Wgt_repeat(params, [Tm * Tk * Tn], 1, [Tm])
            elif wgt_mode == "full":
                if params['WgtPingPong']:
                    raise Exception("Wgt full mode with pingpong is not supported")
                format_Wgt_repeat(params, [1], Tm, [1])

            #ofm flow configuration
            params['MemtileOutRepeat']     = [Tn * Tm]
            params['ShimOutRepeat']        = [1]
            utils.sanity_check((sum(params['ShimOutRepeat']) <= 64), f"Max Shim BD repeat is 64 (ShimOutRepeat = {params['ShimOutRepeat']})", "Message")

            #pad repeat lists: zero-extend every non-None '*Repeat' entry to the longest length
            repeat_keys = [key for key in params
                           if key.endswith('Repeat') and params[key] is not None]
            max_len = max((len(params[key]) for key in repeat_keys), default=0)
            for key in repeat_keys:
                if len(params[key]) < max_len:
                    params[key] += [0] * (max_len - len(params[key]))

        def gen_mem_act_transfers(self, params):
            """Populate the activation memtile DMA entries in ``params``.

            Sets 'act_memtile_memory', 'act_memtile_s2mm' and 'act_memtile_mm2s'
            using ``self.code_template.helper_func``. ``self`` is passed
            explicitly because this is a nested helper, not a bound method.

            For the M-padding cases each entry becomes a per-column list of
            (default, with-per-column-Mpad) pairs. When ping-pong is enabled,
            each column instead holds two such pairs.

            Raises:
                Exception: when padding is enabled but no branch matches (e.g.
                    both Mpad and Kpad are non-zero).
            """
            helpfunc = self.code_template.helper_func
            dims     = params['dims']
            mode     = params['ActDataFlow']
            PingPong = params['ActPingPong']
            Mpad     = params['dims'].Mpad
            Kpad     = params['dims'].Kpad

            # K padding only: `(not PingPong or PingPong)` is always true, so
            # these two branches apply regardless of the ping-pong setting.
            if mode == "pin" and (not PingPong or PingPong) and Mpad == 0 and Kpad > 0: #both pingpong and non-pingpong cases
                params['act_memtile_memory']   =  helpfunc.act_memtile_memory(params)
                params['act_memtile_s2mm']     =  helpfunc.act_memtile_s2mm(params)
                params['act_memtile_mm2s']     =  helpfunc.act_memtile_mm2s(params, Kpad=Kpad)
            elif mode == "stream" and (not PingPong or PingPong) and Mpad == 0 and Kpad > 0:
                params['act_memtile_memory']   =  helpfunc.act_memtile_memory(params)
                params['act_memtile_s2mm']     =  helpfunc.act_memtile_s2mm(params)
                params['act_memtile_mm2s']     =  helpfunc.act_memtile_mm2s(params,Kpad=Kpad)
            elif (mode == "pin" or mode == "stream") and not PingPong and Mpad > 0 and Kpad == 0: #not supported
                    # M padding without ping-pong. A column whose padded
                    # iteration count exceeds its full-transfer count by one
                    # receives the remainder Mpad % M_subv; other columns get 0.
                    # NOTE(review): full_transfers is computed but unused here.
                    full_transfers = dims.M // dims.M_subv
                    padded_itr = [(dims.M + params['dims'].Mpad) // dims.aie_cols // dims.M_subv] * dims.aie_cols
                    full_itr   = [(dims.M // dims.M_subv + dims.aie_cols -x ) // dims.aie_cols for x in range(1, dims.aie_cols+1)]
                    final_itr  = [x - y for x, y in zip (padded_itr, full_itr)]

                    # Mpad is rebound from an int to a per-column list here.
                    Mpad                           = [params['dims'].Mpad % dims.M_subv if final_itr[col]==1 else 0                                 for col in range(dims.aie_cols)]
                    params['act_memtile_memory']   = [(helpfunc.act_memtile_memory(params), helpfunc.act_memtile_memory(params, Mpad[col]))         for col in range(dims.aie_cols)]
                    params['act_memtile_s2mm']     = [(helpfunc.act_memtile_s2mm(params), 
                                                       helpfunc.act_memtile_s2mm(params, Mpad[col]))                                                for col in range(dims.aie_cols)]
                    params['act_memtile_mm2s']     = [(helpfunc.act_memtile_mm2s(params), helpfunc.act_memtile_mm2s(params))                        for col in range(dims.aie_cols)]
                
            #Directly split the transfer when there are ping pong buffer
            elif (mode == "pin" or mode == "stream") and PingPong and Mpad > 0 and Kpad == 0: #not supported
                # M padding with ping-pong: the same per-slot remainder logic,
                # applied over 2 * aie_cols slots. Slot 2*col / 2*col+1 are
                # presumably the ping / pong buffers of column col; confirm.
                full_transfers = dims.M // dims.M_subv
                padded_itr = [(dims.M + params['dims'].Mpad) // dims.aie_cols // dims.M_subv // 2] * dims.aie_cols * 2
                full_itr   = [(full_transfers + dims.aie_cols * 2 - (x + 1)) // dims.aie_cols // 2 for x in range(dims.aie_cols * 2)]
                final_itr  = [x - y for x, y in zip (padded_itr, full_itr)]

                Mpad                           = [params['dims'].Mpad % dims.M_subv if final_itr[col]==1 else 0 
                                                            for col in range(dims.aie_cols*2)]
                params['act_memtile_memory']   = [((helpfunc.act_memtile_memory(params), helpfunc.act_memtile_memory(params, Mpad[2*col])),
                                                 (helpfunc.act_memtile_memory(params), helpfunc.act_memtile_memory(params, Mpad[2*col+1]))) 
                                                            for col in range(dims.aie_cols)]
                params['act_memtile_s2mm']     = [((helpfunc.act_memtile_s2mm(params), 
                                                    helpfunc.act_memtile_s2mm(params, Mpad[2*col])),
                                                   (helpfunc.act_memtile_s2mm(params), 
                                                    helpfunc.act_memtile_s2mm(params, Mpad[2*col+1]))) 
                                                            for col in range(dims.aie_cols)]
                params['act_memtile_mm2s']     = [((helpfunc.act_memtile_mm2s(params), helpfunc.act_memtile_mm2s(params)),
                                                 (helpfunc.act_memtile_mm2s(params), helpfunc.act_memtile_mm2s(params)))
                                                            for col in range(dims.aie_cols)]
            #non-padding cases
            elif not params['enable_padding']:
                params['act_memtile_memory']   =  helpfunc.act_memtile_memory(params)
                params['act_memtile_s2mm']     =  helpfunc.act_memtile_s2mm(params)
                params['act_memtile_mm2s']     =  helpfunc.act_memtile_mm2s(params)
            else:
                raise Exception("Reach unsupported case setting")

        def gen_shim_act_transfers(self, params):
            """Populate the activation shim DMA entries in ``params``.

            Sets 'act_shim_memory' and 'act_shim_mm2s'. 'act_shim_mm2s' is a
            per-column list of per-iteration descriptors, with
            ``params['ShimActOuterLoop']`` iterations per column. ``self`` is
            passed explicitly to reach ``self.code_template.helper_func``.
            """
            helpfunc = self.code_template.helper_func
            dims     = params['dims']
            mode     = params['ActDataFlow']
            PingPong = params['ActPingPong']
            Mpad     = params['dims'].Mpad
            Kpad     = params['dims'].Kpad
            
            # NOTE: this branch, the `Mpad == 0 and Kpad == 0` branch and the
            # final else all have identical bodies.
            if mode == "pin" and not PingPong and Mpad == 0 and Kpad > 0:
                params['act_shim_memory'] = helpfunc.act_shim_memory(params)
                params['act_shim_mm2s']   = [[helpfunc.act_shim_mm2s(params, col, itr) 
                                                            for itr in range(params['ShimActOuterLoop'])]
                                                            for col in range(dims.aie_cols)]
            # NOTE(review): in the next two branches Mpad is still the integer
            # params['dims'].Mpad, so `Mpad[col]` raises TypeError as soon as
            # the comprehension runs. The `=="stream"` sub-branches are
            # unreachable because mode == "pin" is already required.
            elif mode == "pin" and not PingPong and Mpad > 0 and Kpad == 0: #not supported
                if params['ActDataFlow']=="pin":
                    params['act_shim_memory']      = [(tuple([helpfunc.act_shim_memory(params, params['dims'].Mpad)]*2))                           for col in range(dims.aie_cols)]
                    params['act_shim_mm2s']        = [[(helpfunc.act_shim_mm2s_wpadding(dims, params['ActDataFlow'], col, itr, False), 
                                                        helpfunc.act_shim_mm2s_wpadding(dims, params['ActDataFlow'], col, itr, itr == params['ShimActOuterLoop']-1, Mpad[col]))
                                                                for itr in range(params['ShimActOuterLoop'])]
                                                                for col in range(dims.aie_cols)]
                if params['ActDataFlow']=="stream":
                    params['act_shim_memory']      =  helpfunc.act_shim_memory(params)
                    params['act_shim_mm2s']        = [[helpfunc.act_shim_mm2s(params, col, itr, itr == params['ShimActOuterLoop']-1, Mpad[col]) 
                                                                for itr in range(params['ShimActOuterLoop'])]
                                                                for col in range(dims.aie_cols)]
            elif mode == "pin" and PingPong and Mpad > 0 and Kpad == 0: #not supported
                if params['ActDataFlow']=="pin":
                    # NOTE(review): act_shim_memory is called here with
                    # (dims, ActFormat, Mpad), unlike the (params[, Mpad]) form
                    # used elsewhere in this function; confirm the helper signature.
                    params['act_shim_memory']      = [(tuple([helpfunc.act_shim_memory(dims, params['ActFormat'], params['dims'].Mpad)]*2))                           for col in range(dims.aie_cols)]
                    params['act_shim_mm2s']        = [[(helpfunc.act_shim_mm2s_wpadding(dims, params['ActDataFlow'], col, itr, False), 
                                                        helpfunc.act_shim_mm2s_wpadding(dims, params['ActDataFlow'], col, itr, itr == params['ShimActOuterLoop']-1, Mpad[col]))
                                                                for itr in range(params['ShimActOuterLoop'])]
                                                                for col in range(dims.aie_cols)]
                if params['ActDataFlow']=="stream":
                    params['act_shim_memory']      =  helpfunc.act_shim_memory(params)
                    params['act_shim_mm2s']        = [[helpfunc.act_shim_mm2s(params, col, itr, itr == params['ShimActOuterLoop']-1, Mpad[col]) 
                                                                for itr in range(params['ShimActOuterLoop'])]
                                                                for col in range(dims.aie_cols)]
            elif Mpad == 0 and Kpad == 0:
                params['act_shim_memory'] =  helpfunc.act_shim_memory(params)
                params['act_shim_mm2s']   = [[helpfunc.act_shim_mm2s(params, col, itr) 
                                                            for itr in range(params['ShimActOuterLoop'])]
                                                            for col in range(dims.aie_cols)]
            else:
                params['act_shim_memory'] =  helpfunc.act_shim_memory(params)
                params['act_shim_mm2s']   = [[helpfunc.act_shim_mm2s(params, col, itr) 
                                                            for itr in range(params['ShimActOuterLoop'])]
                                                            for col in range(dims.aie_cols)]
                
        def gen_wgt_tranfsers(self, params):
            """Attach the weight memtile and shim DMA descriptors to ``params``.

            Every descriptor comes from the active template's ``helper_func``. The
            memtile S2MM/MM2S helpers receive ``params['dims']`` and the weight data-flow
            mode (``params['WgtDataFlow']``). The shim MM2S side gets one entry per AIE
            column. Entries are written into ``params`` one at a time, so later helpers
            see the keys stored by earlier ones.

            (The misspelled name is kept on purpose: callers reference it as-is.)
            """
            hf = self.code_template.helper_func
            layout = params['dims']
            flow = params['WgtDataFlow']

            params['wgt_memtile_memory'] = hf.wgt_memtile_memory(params)
            params['wgt_memtile_s2mm'] = hf.wgt_memtile_s2mm(layout, flow)
            params['wgt_memtile_mm2s'] = hf.wgt_memtile_mm2s(layout, flow)
            params['wgt_shim_memory'] = hf.wgt_shim_memory(params)

            per_column = []
            for col_idx in range(layout.aie_cols):
                per_column.append(hf.wgt_shim_mm2s(params, 0, col_idx))
            params['wgt_shim_mm2s'] = per_column

        def gen_out_transfers(self, params):
            """Attach the output memtile and shim DMA descriptors to ``params``.

            Every descriptor comes from the active template's ``helper_func``. The
            memtile S2MM side gets one entry per AIE row. The shim S2MM side gets one
            entry per AIE column. Both shim helpers also receive
            ``params['dims'].Mpad``.
            """
            hf = self.code_template.helper_func
            layout = params['dims']

            params['out_memtile_memory'] = hf.out_memtile_memory(params)

            per_row = []
            for row_idx in range(layout.aie_rows):
                per_row.append(hf.out_memtile_s2mm(params, row_idx))
            params['out_memtile_s2mm'] = per_row

            params['out_memtile_mm2s'] = hf.out_memtile_mm2s(params)

            m_pad = layout.Mpad
            params['out_shim_memory'] = hf.out_shim_memory(params, m_pad)

            per_column = []
            for col_idx in range(layout.aie_cols):
                per_column.append(hf.out_shim_s2mm(params, 0, col_idx, m_pad))
            params['out_shim_s2mm'] = per_column

        def gen_dma_transfers(self, params):
            """Fill ``params`` with core GEMM parameters and every DMA transfer descriptor.

            Sets ``reuse_chain_length``, ``sync_strategy`` and a four-entry
            ``gemm_params`` list. The list is built with ``gemm_params_qdq`` when
            ``params['CoreQdqAddr'][0]`` is not None, otherwise with
            ``gemm_params_bfp16``. The function then runs the shim/memtile activation,
            weight and output transfer generators, which also write into ``params``.

            Raises:
                ValueError: on the bfp16 path, if the
                    ``(M_subv, K_subv, N_subv, transpose_out)`` tuple is not in the
                    supported sub-volume table.
            """
            dims = params['dims']
            helpfunc = self.code_template.helper_func
            params['reuse_chain_length'] = helpfunc.reuse_chain_length(dims.outer_loop, dims.aie_rows)
            params['sync_strategy'] = 'SyncStrategy.Parallel_1_to_N'  # alternative: 'Default'

            qdq_addr = params['CoreQdqAddr'][0]
            if qdq_addr is not None:  # qdq (identity test per PEP 8, not `!= None`)
                sum_addr = params['CoreSumAddr'][0]
                tdm_addr_0 = params['CoreTdmAddr'][0]
                tdm_addr_1 = params['CoreTdmAddr'][1]
                # Second positional flag of each of the four entries, in the original order;
                # all other arguments are shared.
                params['gemm_params'] = [
                    helpfunc.gemm_params_qdq(dims, flag, qdq_addr, sum_addr, tdm_addr_0,
                                             tdm_addr_1, self.param_ptr, True)
                    for flag in (1, 0, 0, 1)
                ]
            else:  # bfp16
                supported_subv_list = [
                    (16, 80, 80, 0),
                    (16, 64, 80, 0),
                    (128, 64, 16, 0),
                    (32, 128, 32, 0),
                    (64, 80, 64, 0),
                    (64, 80, 80, 0),
                    (32, 128, 64, 0),
                    (128, 64, 32, 0),
                    (64, 128, 32, 0),
                    (32, 128, 80, 0),
                    (32, 128, 64, 1),
                    (64, 128, 32, 1),
                    (80, 128, 32, 1),
                    (64, 80, 64, 1),
                ]
                transpose_out = 0  # TODO: hardcoded for now
                act_in_bfp16 = params['ActDtype'] == 'bfp16'
                act_out_bfp = params['OfmDtype'] == 'bfp16'
                gelu_fuse = 0
                # TODO (carried over): the sub-volume id is offset by 14 when bfp16 output
                # is enabled. NOTE(review): 14 currently equals len(supported_subv_list);
                # confirm whether the offset is meant to track the table length.
                bfp_out_subv_offset = 14

                subv_key = (dims.M_subv, dims.K_subv, dims.N_subv, transpose_out)
                try:
                    sub_vol_id = supported_subv_list.index(subv_key)
                except ValueError:
                    raise ValueError(
                        "unsupported bfp16 GEMM sub-volume "
                        f"(M_subv, K_subv, N_subv, transpose_out)={subv_key}; "
                        f"supported: {supported_subv_list}"
                    ) from None
                if act_out_bfp:
                    sub_vol_id += bfp_out_subv_offset

                tdm_addr_0 = params['CoreTdmAddr'][0]
                tdm_addr_1 = params['CoreTdmAddr'][1]
                # (first, second, last) positional flags of the four gemm_params_bfp16
                # entries, in the original order; all other arguments are shared.
                bfp16_flags = ((1, 0, 1), (0, 0, 0), (0, 1, 0), (1, 1, 1))
                params['gemm_params'] = [
                    helpfunc.gemm_params_bfp16(first, second, sub_vol_id, tdm_addr_0, tdm_addr_1,
                                               transpose_out, act_in_bfp16, act_out_bfp,
                                               gelu_fuse, last)
                    for first, second, last in bfp16_flags
                ]

            gen_shim_act_transfers(self, params)
            gen_mem_act_transfers(self, params)
            gen_wgt_tranfsers(self, params)
            gen_out_transfers(self, params)

        def gen_dma_params(_pipeline_data):
            """Build the ``params`` dict that is passed to ``self.code_template.gen_code``.

            Steps:
            1. Take the buffer-allocator result stored under
               ``_pipeline_data.info['BuffAllocator']``.
            2. Pick its parameter entry via
               ``const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX``.
            3. Pack those values, plus the summed IFM x/y padding, into the positional
               tuple given to ``template_selector.get_template_define_struct``.
            4. Complete ``params`` with ``update_buffer_alloc_to_params``,
               ``calc_rep_params`` and ``gen_dma_transfers``.

            The tuple is positional: the key order below must not change.
            """
            alloc = _pipeline_data.info.get('BuffAllocator')
            alloc_prm = alloc[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]

            m_rows = alloc_prm['M']
            pad_ifm_x_total = sum(alloc_prm['padding'][0]['pad_ifm_x'])

            # Values copied verbatim, in order, between the padded M and BITS_PER_BYTE.
            keys_before_byte_width = (
                'K', 'N',
                'M_subV', 'K_subV', 'N_subV',
                'aie_rows', 'aie_cols', 'aie_arrays',
                'ifm_bits', 'ofm_bits', 'qdq_bytes',
                'outer_loop', 'inner_loop', 'acc_loop',
                'wgt_subv_rows', 'wgt_subv_cols',
                'ifm_core_subv_bytes', 'wgt_core_subv_bytes',
                'ofm_core_subv_bytes',
                'sum_core_subv_bytes', 'tdm_core_subv_bytes',
                'wgt_bits',
            )
            # Values copied verbatim, in order, after BITS_PER_BYTE.
            keys_after_byte_width = (
                'bias_bits', 'tdm_bits',
                'is_pwla_fused', 'is_rope_fused', 'is_fused_rope_actxact', 'is_elew_fused',
            )
            define_data = (
                (m_rows + pad_ifm_x_total,)
                + tuple(alloc_prm[key] for key in keys_before_byte_width)
                + (const.BITS_PER_BYTE,)
                + tuple(alloc_prm[key] for key in keys_after_byte_width)
                # Trailing op-specific entries: total IFM x and y padding.
                + (pad_ifm_x_total, sum(alloc_prm['padding'][0]['pad_ifm_y']))
            )

            dims = template_selector.get_template_define_struct(self.template_meta_data, define_data)
            # NOTE: checks of dims.outer_loop / inner_loop / acc_loop against the mem-tile
            # 'iters' (ifm[0], wgt[1], wgt[0]) used to live here but were disabled.

            result = {
                'dims': dims,
                'op_name': self.param_ptr['orig_op_type'],
                'op_ver': self.param_ptr['op_type'],
            }
            update_buffer_alloc_to_params(result, alloc)
            calc_rep_params(result)
            gen_dma_transfers(self, result)
            return result
        
        #logging.info("Executing Access Pattern stage. %s", _pipeline_data.info.get('BuffAllocator'))
        self.code_template = template_selector.get_template_obj(self.template_meta_data, _pipeline_data) #M4N4(_pipeline_data)
        #self.code_template = template_selector.get_template_obj("conv", _pipeline_data) #M4N4(_pipeline_data)
        params             = gen_dma_params(_pipeline_data)
        data_flow          = self.code_template.gen_code(params)
        self.print_data_flow(data_flow)

