import sys
import os

infra_path = os.path.dirname(os.path.abspath(__file__)) + "/infra/"
sys.path.append(infra_path)
import logging
import pdb
import math
from enum import Enum
import re
import const
import dataclasses
from template_base import BaseTemplate, BaseDims
import gen_kernel_param

import scheduler_utils as utils

@dataclasses.dataclass(slots=True)
class GemmDims(BaseDims):
    """BaseDims plus the two padding totals used by the RoPE template.

    RoPE_Base.params.gen_dma_params constructs this positionally, passing
    the summed ``pad_ifm_x`` and ``pad_ifm_y`` padding as the last two values.
    """

    Mpad: int  # filled from sum(padding[0]["pad_ifm_x"])
    Kpad: int  # filled from sum(padding[0]["pad_ifm_y"])


class RoPE_Base(BaseTemplate):
    """Code-generation template for the RoPE operation.

    ``gen_code`` stitches together the header, helper functions, dataflow,
    DMA transfer descriptions (core, memtile and shim), kernel name,
    connection/run/compile section and ``main`` into one source string.
    Besides the param and QDQ transfers, three inputs are moved
    (activation, sin and cos) and one output is written back.
    """

    # Byte size assumed for the QDQ block that precedes the sin/cos data
    # when they share the weight shim buffer (see _shim_sin_offset).
    # TODO: take this from params instead of hard-coding it.
    _SHIM_QDQ_BYTES = 64

    # Per-input memtile layout used by memtile_pattern:
    #   memStr -> (repeat slot, ping address key, pong address key,
    #              ping_pong_enable key, memtile buffer size key)
    # Pattern keys are "<memStr>_memtile_memory" / "_s2mm" / "_mm2s".
    _MEMTILE_STREAMS = {
        "act": (0, "MemtileIfmPingAddr", "MemtileIfmPongAddr", "ifm", "MemtileActSizeIfm"),
        "sin": (1, "MemtileSinPingAddr", "MemtileSinPongAddr", "sin", "MemtileSinSize"),
        "cos": (2, "MemtileCosPingAddr", "MemtileCosPongAddr", "cos", "MemtileCosSize"),
    }

    def __init__(self, _data=None):
        super().__init__()
        self.data = _data
        # The nested helper_func class is replaced on the instance by an
        # instance of itself; self.params stays the nested class.
        self.helper_func = self.helper_func()
        self.prm = self.params()

    def gen_code(self, params):
        """Return the complete generated source for one RoPE instance.

        Args:
            params: dict produced by ``params.gen_dma_params``; must contain
                "template_meta_data", "dims", "op_name", "op_ver",
                "CoreStackAddr" and every key read by the DMA generators.
        """
        logging.info("Generate code for RoPE operation")
        overlay = params["template_meta_data"]["overlay"]
        sections = [
            self.gen_headers(overlay),
            self.gen_helper_func(),
            self.gen_dataflow(),
            self.dma_pattern_code(params),
            self.gen_kernel_name(params["op_name"], params["op_ver"]),
            self.gen_connection_run_compile(
                overlay,
                params["dims"].aie_cols,
                params["dims"].aie_rows,
                params["CoreStackAddr"][0],
            ),
            self.gen_main_func(),
        ]
        return "".join(sections)

    def gen_rope_params(self):
        """Return source for ``rope_layer_params``, which packs seven fields
        as 2-byte little-endian unsigned integers."""
        return """
def rope_layer_params(n: int, Msubv: int, Nsubv: int, qdq_addr: int, tdm_addr1: int, tdm_addr2: int, fused_op_flag: int):
    return (n.to_bytes(length=2, byteorder='little', signed=False)
            +Msubv.to_bytes(length=2, byteorder='little', signed=False)
            +Nsubv.to_bytes(length=2, byteorder='little', signed=False)
            +qdq_addr.to_bytes(length=2, byteorder='little', signed=False)
            +tdm_addr1.to_bytes(length=2, byteorder='little', signed=False)
            +tdm_addr2.to_bytes(length=2, byteorder='little', signed=False)
            +fused_op_flag.to_bytes(length=2, byteorder='little', signed=False)
            )
"""

    def gen_helper_func(self):
        """Return the helper-function source emitted into the generated file."""
        return self.gen_rope_params()

    def gen_core_instr(self, params):
        """Return the core-tile program source.

        The QDQ buffer is loaded once on S2MM channel 1. Then, for ``dims.Tm``
        iterations, one combined ifm+sin+cos buffer (S2MM 0) and one output
        buffer (MM2S 0) are acquired around a kernel call.
        """
        meta_kernel_list = list((gen_kernel_param.return_kernel_path(
            params["op_name"], params["op_ver"]
        )[0]).keys())
        kernel_names = ["run_a16a16_rope_qdq"]
        # .index() doubles as validation: raises ValueError if the kernel is missing.
        kernel_idx = [meta_kernel_list.index(kname) for kname in kernel_names]
        code = f"""
    #CORE FLOW DEFINITION
    #ParamSize       = {const.PARAM_SIZE}
    #CoreBankSize    = {const.CORE_BANK_SIZE}
    #CoreQdqSize     = {params['CoreQdqSize']}
    #CoreActSize     = {params['CoreIfmSize']} + {params['CoreSinSize']} + {params['CoreCosSize']}
    #CoreOfmSize     = {params['CoreOfmSize']}
    #CoreStackSize   = {params['CoreStackSize']}
    #CoreQdqAddr     = {params['CoreQdqAddr'][0]}
    #CoreActAddr     = {params['CoreIfmAddr'][0]}, {params['CoreIfmAddr'][1]}
    #CoreSinAddr     = {params['CoreSinAddr'][0]}, {params['CoreSinAddr'][1]}
    #CoreCosAddr     = {params['CoreCosAddr'][0]}, {params['CoreCosAddr'][1]}
    #CoreOutAddr     = {params['CoreOfmAddr'][0]}, {params['CoreOfmAddr'][1]}
    #CoreStackAddr   = {params['CoreStackAddr'][0]}
    #CoreTdm1Addr    = {params['CoreTdm1Addr'][0]}
    #CoreTdm2Addr    = {params['CoreTdm2Addr'][0]}

    core_instrs = [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), {params['CoreQdqAddr'][0]}, None, {params['dims'].qdq_bytes}), #CoreQdqAddr, QdqSize
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), {params['CoreIfmAddr'][0]}, {params['CoreIfmAddr'][1]}, {params['CoreIfmSize'] + params['CoreSinSize'] + params['CoreCosSize']}), #CoreActPingAddr, CoreActPongAddr, CoreActSize
        ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), {params['CoreOfmAddr'][0]}, {params['CoreOfmAddr'][1]}, {params['CoreOfmSize']}), #CoreOutPingAddr, CoreOutPongAddr, CoreOutSize

        Loop({params['dims'].Tm}, [
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
            CallKernel('{meta_kernel_list[kernel_idx[0]]}', kernel_params={params['gemm_params'][0]}),
            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
        ]),
    ]
        """
        return code

    def gen_memtile_instr(self, params):
        """Return the memtile transfer section of the generated source.

        Side effect: rewrites the param, QDQ and output repeat lists in
        ``params`` to three-entry lists before they are consumed.
        """
        memtile_transfers = "\n    memtile_transfers = []\n    #MEMTILE PARAM TRANSFERS\n"
        params['MemtileParamRepeat'] = [params['MemtileParamRepeat'][0], 0, 0]
        memtile_transfers += self.memtile_prm_pattern(params)
        memtile_transfers += "    #MEMTILE QDQ PARAM TRANSFERS\n"
        params['MemtileQdqPrmRepeat'] = [params['MemtileQdqPrmRepeat'][0], 0, 0]
        memtile_transfers += self.memtile_qdq_pattern(params)
        memtile_transfers += (
            f"    #MEMTILE ACTIVATION TRANSFERS IFMs, {params['ActDataFlow'].upper()}\n"
        )
        memtile_transfers += self._memtile_stream_transfers(params, "act")
        memtile_transfers += "    #MEMTILE SIN TRANSFERS IFMs \n"
        memtile_transfers += self._memtile_stream_transfers(params, "sin")
        memtile_transfers += "    #MEMTILE COS TRANSFERS IFMs \n"
        memtile_transfers += self._memtile_stream_transfers(params, "cos")
        memtile_transfers += "    #MEMTILE OFM TRANSFERS\n"
        # sum() instead of [0] keeps this idempotent: a second call used to
        # read the leading 0 of [0, 0, x] and emit a zero repeat count.
        params['MemtileOutRepeat'] = [0, 0, sum(params['MemtileOutRepeat'])]
        memtile_transfers += "".join(
            self.memtile_ofm_pattern(params, col)
            for col in range(params["dims"].aie_cols)
        )
        memtile_transfers += "\n"
        return memtile_transfers

    def _memtile_stream_transfers(self, params, memStr):
        """Concatenate memtile_pattern output for every AIE column."""
        return "".join(
            self.memtile_pattern(params, col, memStr)
            for col in range(params["dims"].aie_cols)
        )

    def memtile_pattern(self, params, col, memStr):
        """Return the memtile DataTransfer source for one input on one column.

        Args:
            params: generator params dict.
            col: AIE column index.
            memStr: which input to describe: "act", "sin" or "cos".

        Raises:
            ValueError: if memStr is not one of the supported inputs.
        """
        if memStr not in self._MEMTILE_STREAMS:
            raise ValueError(f'Invalid input: {memStr}')
        slot, ping_key, pong_key, pp_key, size_key = self._MEMTILE_STREAMS[memStr]
        dims = params["dims"]
        memory = params[f"{memStr}_memtile_memory"]
        s2mm_pattern = params[f"{memStr}_memtile_s2mm"]
        mm2s_patterns = params[f"{memStr}_memtile_mm2s"]

        # One MM2S entry per core row, repeated reuse_chain_length times.
        mm2s_transfers = [
            f"generate_transfer_params(memtile_dma({col}, DmaDir.MM2S, {row}),  '{memory}', '{mm2s_patterns[row]}', {dims.act_bits})"
            for row in range(dims.aie_rows)
            for _ in range(params["reuse_chain_length"])
        ]
        mm2s_transfers_join = ",\n             ".join(mm2s_transfers)

        # Each input is active in its own repeat slot (act=0, sin=1, cos=2).
        repeatCnt = [0, 0, 0]
        repeatCnt[slot] = params['MemtileActRepeatIfm'][0]

        ping_pong = params["ping_pong_enable"]
        # cos previously read the "sin" flag (copy/paste slip). Use its own
        # flag, falling back to "sin" when the schedule has no "cos" entry.
        if pp_key == "cos" and pp_key not in ping_pong:
            pp_key = "sin"
        memAddr = (
            [params[ping_key], params[pong_key]]
            if ping_pong[pp_key]
            else [params[ping_key]]
        )

        code = f"""    memtile_transfers += [
        DataTransfer({repeatCnt}, AieTile(TileType.Memtile, {col}), {memAddr}, {params[size_key]},
        [generate_transfer_params(memtile_dma({col}, DmaDir.S2MM, 0), '{memory}', '{s2mm_pattern}', {dims.act_bits})],
        [{mm2s_transfers_join}],\n"""

        reuse_ratio = params["MemtileActReuseRatio"] // params["reuse_chain_length"]
        if reuse_ratio == 1:
            code += " " * 12 + ")]\n"
        else:
            code += " " * 12 + f"reuse_ratio={reuse_ratio})]\n"
        return code

    def shim_qdq_pattern(self, params):
        """Return the shim QDQ-parameter transfer source.

        Raises:
            ValueError: for overlays other than "4x4" and "8x4". Previously these
                produced an unterminated ``[DataTransfer(`` in the output.
        """
        overlay = params["template_meta_data"]["overlay"]
        if overlay == "4x4":
            col_range = f"range({params['dims'].aie_cols})"
        elif overlay == "8x4":
            # only every other column issues the QDQ transfer on this overlay
            col_range = f"range(0, {params['dims'].aie_cols}, 2)"
        else:
            raise ValueError(f"Unsupported overlay for QDQ shim transfer: {overlay}")
        code = f"""    shim_transfers += [
        DataTransfer(
            {params['ShimQdqPrmRepeat']}, AieTile(TileType.Shim, col), [{params['ShimWgtBufferIdx']}], {params.get('ShimWgtSize',0) + params['dims'].qdq_bytes},
            [],
            [TransferParams(shim_dma(col, DmaDir.MM2S, 0), {params['dims'].qdq_bytes // 4}, offset={params.get('ShimWgtSize', 0) // 4})],
"""
        code += f"        ) for col in {col_range}\n    ]\n"
        return code

    def shim_act_pattern(self, params, col, itr):
        """Return the shim activation transfer source for one column (repeat slot 0)."""
        code = f"    shim_transfers += [generate_shim_data_transfer([{params['ShimActRepeatIfm'][0]}, 0, 0], shim_dma({col}, DmaDir.MM2S, 0), {params['ShimActBufferIdx']}, '{params['act_shim_memory']}', '{params['act_shim_mm2s'][col][itr]}', {params['dims'].act_bits})]\n"
        return code

    def _shim_sin_offset(self, params):
        """Return the buffer offset of the sin data in its shim buffer.

        The offset is in the same units as ShimActSizeIfm (presumably bytes).
        """
        if params['SinCosIdx'] == params.get('ShimActBufferIdx', 1):
            # sin/cos share the activation buffer and follow all columns' activation
            return params["ShimActSizeIfm"] * params['dims'].aie_cols
        # otherwise they share the weight buffer and follow the QDQ block
        return self._SHIM_QDQ_BYTES

    def shim_sin_pattern(self, params, col, itr):
        """Return the shim sin transfer source for one column (repeat slot 1)."""
        shimTransferOffset = self._shim_sin_offset(params)
        code = f"    shim_transfers += [generate_shim_data_transfer([0, {params['ShimActRepeatIfm'][0]}, 0], shim_dma({col}, DmaDir.MM2S, 0), {params['SinCosIdx']}, '{params['sin_shim_memory']}', '{params['sin_shim_mm2s'][col][itr]}', {params['dims'].act_bits}, buffer_offset={shimTransferOffset})]\n"
        return code

    def shim_cos_pattern(self, params, col, itr):
        """Return the shim cos transfer source for one column (repeat slot 2).

        The cos data sits directly after the sin data of all columns.
        """
        shimTransferOffset = self._shim_sin_offset(params) + (
            params["ShimSinSize"] * params['dims'].aie_cols
        )
        code = f"    shim_transfers += [generate_shim_data_transfer([0, 0, {params['ShimActRepeatIfm'][0]}], shim_dma({col}, DmaDir.MM2S, 0), {params['SinCosIdx']}, '{params['cos_shim_memory']}', '{params['cos_shim_mm2s'][col][itr]}', {params['dims'].act_bits}, buffer_offset={shimTransferOffset})]\n"
        return code

    def shim_ofm_pattern(self, params, col):
        """Return the shim output (S2MM) transfer source for one column."""
        return (
            f"    shim_transfers += [generate_shim_data_transfer([0, 0, {params['ShimOutRepeat'][0]}], shim_dma({col}, DmaDir.S2MM, 0), {params['ShimOutBufferIdx']},"
            f" '{params['out_shim_memory']}', '{params['out_shim_s2mm'][col]}',{params['dims'].out_bits})]\n"
        )

    def _shim_stream_transfers(self, params, pattern):
        """Concatenate ``pattern(params, col, itr)`` over all columns.

        Only a single outer iteration (itr == 0) is emitted; ShimActOuterLoop
        is not consulted here.
        """
        return "".join(
            pattern(params, col, itr)
            for itr in range(1)
            for col in range(params["dims"].aie_cols)
        )

    def gen_shim_instr(self, params):
        """Return the shim transfer section of the generated source.

        Side effect: rewrites the shim param/QDQ repeat lists to three entries.
        """
        shim_transfers = "    #SHIM FLOW DEFINITION\n"
        shim_transfers += "    shim_transfers = []\n"
        shim_transfers += "    #SHIM PARAM TRANSFERS\n"
        params['ShimParamRepeat'] = [params['ShimParamRepeat'][0], 0, 0]
        shim_transfers += "".join(
            self.shim_prm_pattern(params, col)
            for col in range(params["dims"].aie_cols)
        )
        shim_transfers += "    #SHIM QDQ PARAM TRANSFERS\n"
        params['ShimQdqPrmRepeat'] = [params['ShimQdqPrmRepeat'][0], 0, 0]
        # shim_qdq_pattern already covers every column inside the generated code
        shim_transfers += self.shim_qdq_pattern(params)
        dataflow = params['ActDataFlow'].upper()
        shim_transfers += f"    #SHIM ACTIVATION TRANSFERS IFM, {dataflow}\n"
        shim_transfers += self._shim_stream_transfers(params, self.shim_act_pattern)
        shim_transfers += f"    #SHIM SIN TRANSFERS IFM, {dataflow}\n"
        shim_transfers += self._shim_stream_transfers(params, self.shim_sin_pattern)
        shim_transfers += f"    #SHIM COS TRANSFERS IFM, {dataflow}\n"
        shim_transfers += self._shim_stream_transfers(params, self.shim_cos_pattern)
        shim_transfers += "    #SHIM OFM TRANSFERS\n"
        shim_transfers += "".join(
            self.shim_ofm_pattern(params, col)
            for col in range(params["dims"].aie_cols)
        )
        return shim_transfers

    def dma_pattern_code(self, params):
        """Return core, memtile and shim transfer sources, in that order."""
        core_transfers = self.gen_core_instr(params)
        memtile_transfers = self.gen_memtile_instr(params)
        shim_transfers = self.gen_shim_instr(params)
        return core_transfers + memtile_transfers + shim_transfers

    class helper_func:
        """Stateless helpers shared by the template and its ``params`` class."""

        def col_index(self, dims: GemmDims, array: int, col: int) -> int:
            """Return the flat column index ``array * aie_cols + col``."""
            assert 0 <= col < dims.aie_cols
            assert 0 <= array < dims.aie_arrays
            return (array * dims.aie_cols) + col

        def reuse_chain_length(
            self, reuse_ratio: int, num_consumers: int, max_chain_length: int = 4
        ) -> int:
            """Return the smallest i in [1, max_chain_length] that divides
            reuse_ratio and keeps (reuse_ratio // i) * num_consumers <= 63.

            Raises:
                RuntimeError: if no such i exists.
            """
            max_lock_value = 63
            for i in range(1, max_chain_length + 1):
                is_valid = ((reuse_ratio % i) == 0) and (
                    ((reuse_ratio // i) * num_consumers) <= max_lock_value
                )
                if is_valid:
                    return i
            raise RuntimeError("Failed to allocate reuse chain!")

        def help_info(self):
            """Return a text table of supported padding / ping-pong / IFM modes.

            Previously the table was built but never returned.
            """
            # NOTE(review): the header still reads "QDQ Add", which looks copied
            # from another template; confirm the table is accurate for RoPE.
            code = """
--------------------------------------------------------
| QDQ Add    |             |            IFM           | 
-------------------------------------------------------
|            |             | Pin    | Stream | Full   |
| No Padding | No PingPong | No     | Yes    | No     |
|            | Ping Pong   | No     | Yes    | No     |
| M Padding  | No PingPong | No     | No     | No     |
|            | Ping Pong   | No     | No     | No     |
| K Padding  | No PingPong | No     | No     | No     |
|            | Ping Pong   | No     | No     | No     |
-------------------------------------------------------
"""
            return code

    class params(helper_func):
        """Builds the ``params`` dict consumed by the RoPE code generators."""

        def gen_dma_params(self, _pipeline_data):
            """Build generator params from the pipeline's buffer allocation.

            Args:
                _pipeline_data: object whose ``info["BuffAllocator"]`` holds the
                    allocator output indexed by ``const.BufAllocator_Idx``.

            Returns:
                dict holding "dims" (a GemmDims), buffer addresses/sizes,
                repeat counts and transfer-pattern strings.

            Raises:
                Exception: for unsupported activation dataflow modes.
            """
            def update_buffer_alloc_to_params(params, buff):
                buff_prm = buff[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]
                params["ParamSize"] = const.PARAM_SIZE
                params["CoreBankSize"] = const.CORE_BANK_SIZE
                params["ShimOutBufferIdx"] = 0
                params["ShimActBufferIdx"] = 1
                params["ShimWgtBufferIdx"] = 2
                params["ShimPrmBufferIdx"] = 3
                params.update(buff[const.BufAllocator_Idx.CORE_TILE_ADDR_IDX.value])
                params.update(buff[const.BufAllocator_Idx.CORE_TILE_SIZE_IDX.value])
                params.update(buff[const.BufAllocator_Idx.MEM_TILE_ADDR_IDX.value])
                params["ping_pong_enable"] = buff_prm["sch_attr"].ping_pong_enable
                params["ActDataFlow"] = buff_prm["sch_attr"].dataflow_mode["ifm"]
                params["ActPingPong"] = buff_prm["sch_attr"].ping_pong_enable["ifm"]
                params["OfmPingPong"] = buff_prm["sch_attr"].ping_pong_enable["ofm"]
                params["ShimActSizeIfm"] = buff_prm["ifm_shim_tile_size"]
                params["ShimSinSize"] = buff_prm["sin_shim_tile_size"]
                params["ShimCosSize"] = buff_prm["cos_shim_tile_size"]
                params["ShimOutSize"] = buff_prm["ofm_shim_tile_size"]
                params["MemtileActSizeIfm"] = buff_prm["ifm_mem_tile_size"]
                params["MemtileSinSize"] = buff_prm["sin_mem_tile_size"]
                params["MemtileCosSize"] = buff_prm["cos_mem_tile_size"]
                params["MemtileOutSize"] = buff_prm["ofm_mem_tile_size"]
                params["ActDtype"] = buff_prm["ifm"].dtype
                params["SinDtype"] = buff_prm["sin"].dtype
                params["CosDtype"] = buff_prm["cos"].dtype
                params["OfmDtype"] = buff_prm["ofm"].dtype
                params["ActFormat"] = decode_data_format(buff_prm["ifm"].dtype)
                params["OfmFormat"] = decode_data_format(buff_prm["ofm"].dtype)
                params["min_cols"] = 8
                params["min_rows"] = 8
                params["op_name"] = buff_prm["orig_op_type"]
                params["op_ver"] = buff_prm["op_type"]
                params["param_ptr"] = {
                    "orig_op_type": buff_prm["orig_op_type"],
                    "op_type": buff_prm["op_type"],
                }
                params["template_meta_data"] = {
                    "op_type": buff_prm["orig_op_type"],
                    "mode": buff_prm["mode"],
                    "overlay": buff_prm["overlay"],
                }
                params["ioinfo"] = buff_prm["ioinfo"]
                params["enable_padding"] = False
                params["enable_pack_shim"] = True
                params["enable_packtransfer"] = False
                params['ShimParamMode']     = 'unicast'
                params['ShimQdqMode']       = 'broadcast'
                params['ShimParamChannelId'] = 1
                params['param_channel_id']   = 0
                params['ShimQdqChannelId']   = 0
                params['CoreQdqChId']        = 1

            def decode_data_format(dtype):
                # Only the "default" format is produced, regardless of dtype.
                return "default"

            def calc_rep_params(params):
                """Fill in repeat counts for a dims.Tm-iteration streaming schedule."""

                def format_Prm_Qdq(
                    params, MemPrmRep=None, MemQdqRep=None, ShimPrmRep=None, ShimQdqRep=None
                ):
                    # None defaults give each call fresh [1] lists; shared mutable
                    # defaults would be corrupted if a repeat list were extended.
                    params["MemtileParamRepeat"] = [1] if MemPrmRep is None else MemPrmRep
                    params["MemtileQdqPrmRepeat"] = [1] if MemQdqRep is None else MemQdqRep
                    params["ShimParamRepeat"] = [1] if ShimPrmRep is None else ShimPrmRep
                    params["ShimQdqPrmRepeat"] = [1] if ShimQdqRep is None else ShimQdqRep

                def format_Act_repeat(
                    params,
                    MemActPingRep,
                    MemActPongRep,
                    MemActRR,
                    ShimActOLoop,
                    ShimActRep,
                    ShimActRR=1,
                ):
                    # MemActPongRep is accepted for call-site symmetry but not stored.
                    params["MemtileActRepeatIfm"] = MemActPingRep
                    params["ShimActRepeatIfm"] = ShimActRep
                    params["MemtileActReuseRatio"] = MemActRR
                    params["ShimActOuterLoop"] = ShimActOLoop
                    params["ShimActReuseRatio"] = ShimActRR
                    utils.sanity_check(
                        (sum(ShimActRep) <= 64),
                        f"Max Shim BD repeat is 64 (ShimActRep = {ShimActRep})",
                        "Message",
                    )

                Tm, Tn, Tk = params["dims"].Tm, 1, 1

                format_Prm_Qdq(params)

                # activation flow configuration: only "stream" is supported
                act_flow = params["ActDataFlow"]
                if act_flow == "pin":
                    raise Exception("Activation Pinning mode is not supported")
                if act_flow == "full":
                    raise Exception("Activation Full mode is not supported")
                if act_flow != "stream":
                    raise Exception(f"Unsupported activation dataflow mode: {act_flow}")
                if params["ActPingPong"]:
                    format_Act_repeat(params, [Tm], [Tm], 1, Tm, [Tn])
                else:
                    format_Act_repeat(params, [Tm * Tk * Tn], None, 1, 1, [1])

                # ofm flow configuration
                params["MemtileOutRepeat"] = [Tn * Tm]
                params["ShimOutRepeat"] = [1]
                utils.sanity_check(
                    (sum(params["ShimOutRepeat"]) <= 64),
                    f"Max Shim BD repeat is 64 (ShimOutRepeat = {params['ShimOutRepeat']})",
                    "Message",
                )

                # pad every "*Repeat" list with zeros to a common length
                repeat_keys = [
                    key
                    for key, value in params.items()
                    if key.endswith("Repeat") and value is not None
                ]
                max_len = max((len(params[key]) for key in repeat_keys), default=0)
                for key in repeat_keys:
                    if len(params[key]) < max_len:
                        # rebind instead of extending in place so no shared list is mutated
                        params[key] = params[key] + [0] * (max_len - len(params[key]))

            buff = _pipeline_data.info.get("BuffAllocator")
            buff_prm = buff[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]
            split = [int(x) for x in re.findall(r'\d+', str(buff_prm['mode']))]
            base_Data = (
                split,
                [1],
                [
                    buff_prm["M"] + sum(buff_prm["padding"][0]["pad_ifm_x"]),
                ],
                [buff_prm["K"]],
                buff_prm["N"],
                [buff_prm["M_subV"]],
                [buff_prm["K_subV"]],
                buff_prm["N_subV"],
                buff_prm["aie_rows"],
                buff_prm["aie_cols"],
                buff_prm["aie_arrays"],
                buff_prm["ifm_bits"],
                buff_prm["ofm_bits"],
                buff_prm["qdq_bytes"],
                buff_prm["outer_loop"],
                buff_prm["inner_loop"],
                buff_prm["acc_loop"],
                buff_prm["wgt_subv_rows"],
                buff_prm["wgt_subv_cols"],
                buff_prm["ifm_core_subv_bytes"],
                buff_prm["wgt_core_subv_bytes"],
                buff_prm["ofm_core_subv_bytes"],
                buff_prm["sum_core_subv_bytes"],
                buff_prm["tdm_core_subv_bytes"],
                buff_prm["wgt_bits"],
                const.BITS_PER_BYTE,
                buff_prm["bias_bits"],
                buff_prm["tdm_bits"],
            )
            op_Data = (
                sum(buff_prm["padding"][0]["pad_ifm_x"]),
                sum(buff_prm["padding"][0]["pad_ifm_y"]),
            )
            data = base_Data + op_Data
            dims = GemmDims(*data)
            params = {"dims": dims}
            update_buffer_alloc_to_params(params, buff)
            calc_rep_params(params)
            self.gen_dma_transfers(params)
            return params

        def gen_dma_transfers(self, params):
            """Add transfer-pattern strings and kernel params to ``params`` in place.

            Each allocator size is divided by its dtype's byte width before it is
            formatted into the "W:..." pattern strings read by the generators.
            """
            def decompose(X):
                """Split X into (M, N) with M * N == X, N a power of two that is a
                multiple of 32, and M, N < 128. The smallest such N wins.

                Raises:
                    RuntimeError: if no such split exists.
                """
                # bit_length() covers every power of two <= X, so X == 32 now maps
                # to (1, 32); the old ceil(log2(X)) bound never tried N == X.
                for pwr in range(max(int(X), 0).bit_length()):
                    N = 2**pwr
                    M = X // N
                    if M * N == X and M < 128 and N < 128 and (N % 32 == 0):
                        return M, N
                raise RuntimeError(
                    "Unable to find M and N value to recreate needed subvolume size"
                )

            def getByteCount(dtype):
                # Membership tests use ==, so dtype objects that compare equal
                # to these strings are accepted too.
                if dtype in ["uint16", "int16"]:
                    return 2
                elif dtype in ["uint32", "int32"]:
                    return 4
                elif dtype in ["uint8", "int8"]:
                    return 1
                # previously returned None, failing later with an opaque TypeError
                raise ValueError(f"Unsupported dtype for RoPE DMA sizing: {dtype}")

            dims = params["dims"]
            params["reuse_chain_length"] = 1
            params["sync_strategy"] = "SyncStrategy.Parallel_1_to_N"  #'Default'

            shimActsize = params["ShimActSizeIfm"] // getByteCount(params["ActDtype"])
            shimSinsize = params["ShimSinSize"] // getByteCount(params["SinDtype"])
            shimCossize = params["ShimCosSize"] // getByteCount(params["CosDtype"])
            shimOutsize = params["ShimOutSize"] // getByteCount(params["OfmDtype"])

            memtileAct = params["MemtileActSizeIfm"] // getByteCount(params["ActDtype"])
            memtileSin = params["MemtileSinSize"] // getByteCount(params["SinDtype"])
            memtileCos = params["MemtileCosSize"] // getByteCount(params["CosDtype"])
            memtileOut = params["MemtileOutSize"] // getByteCount(params["OfmDtype"])

            coretileAct = params["CoreIfmSize"] // getByteCount(params["ActDtype"])
            coretileSin = params["CoreSinSize"] // getByteCount(params["SinDtype"])
            coretileCos = params["CoreCosSize"] // getByteCount(params["CosDtype"])
            coretileOut = params["CoreOfmSize"] // getByteCount(params["OfmDtype"])

            # sin/cos share the weight buffer for this op version, else the act buffer
            if params["op_ver"] == "RoPE_qdq_uint16xuint16":
                params["SinCosIdx"] = params["ShimWgtBufferIdx"]
            else:
                params["SinCosIdx"] = params["ShimActBufferIdx"]
            fused_op_flag = 0
            M, N = decompose(dims.M_subv[0])
            params["gemm_params"] = [
                f"rope_layer_params(0, {M}, {N}, {params['CoreQdqAddr'][0]}, {params['CoreTdm1Addr'][0]}, {params['CoreTdm2Addr'][0]}, {fused_op_flag})"
            ] * 4

            params["act_shim_memory"] = f"W:{shimActsize * dims.aie_cols}"
            params["act_shim_mm2s"] = [
                [
                    (
                        f"W:0:{shimActsize * dims.aie_cols}:{memtileAct * dims.aie_cols} "
                        if memtileAct != shimActsize
                        else ""
                    ) + f"W:{col*memtileAct}:{(col+1)*memtileAct}"
                    for itr in range(params["ShimActOuterLoop"])
                ]
                for col in range(dims.aie_cols)
            ]

            params["sin_shim_memory"] = f"W:{shimSinsize * dims.aie_cols}"
            params["sin_shim_mm2s"] = [
                [
                    (
                        f"W:0:{shimSinsize * dims.aie_cols}:{memtileSin * dims.aie_cols} "
                        if memtileSin != shimSinsize
                        else ""
                    ) + f"W:{col*memtileSin}:{(col+1)*memtileSin}"
                    for itr in range(params["ShimActOuterLoop"])
                ]
                for col in range(dims.aie_cols)
            ]

            params["cos_shim_memory"] = f"W:{shimCossize * dims.aie_cols}"
            params["cos_shim_mm2s"] = [
                [
                    (
                        f"W:0:{shimCossize * dims.aie_cols}:{memtileCos * dims.aie_cols} "
                        if memtileCos != shimCossize
                        else ""
                    ) + f"W:{col*memtileCos}:{(col+1)*memtileCos}"
                    for itr in range(params["ShimActOuterLoop"])
                ]
                for col in range(dims.aie_cols)
            ]

            params["act_memtile_memory"] = f"W:{memtileAct}"
            params["act_memtile_s2mm"] = f"W:0:{memtileAct}"
            params["act_memtile_mm2s"] = [
                f"W:{row*coretileAct}:{(row+1)*coretileAct}"
                for row in range(dims.aie_rows)
            ]

            params["sin_memtile_memory"] = f"W:{memtileSin}"
            params["sin_memtile_s2mm"] = f"W:0:{memtileSin}"
            params["sin_memtile_mm2s"] = [
                f"W:{row*coretileSin}:{(row+1)*coretileSin}"
                for row in range(dims.aie_rows)
            ]

            params["cos_memtile_memory"] = f"W:{memtileCos}"
            params["cos_memtile_s2mm"] = f"W:0:{memtileCos}"
            params["cos_memtile_mm2s"] = [
                f"W:{row*coretileCos}:{(row+1)*coretileCos}"
                for row in range(dims.aie_rows)
            ]

            params["out_memtile_memory"] = f"W:{memtileOut}"
            params["out_memtile_s2mm"] = [
                f"W:{row*coretileOut}:{(row+1)*coretileOut}"
                for row in range(dims.aie_rows)
            ]
            params["out_memtile_mm2s"] = f"W:0:{memtileOut}"

            params["out_shim_memory"] = f"W:{shimOutsize  * dims.aie_cols}"
            params["out_shim_s2mm"] = [
                (
                    f"W:0:{shimOutsize * dims.aie_cols}:{memtileOut * dims.aie_cols} "
                    if memtileOut != shimOutsize
                    else ""
                ) + f"W:{col*memtileOut}:{(col+1)*memtileOut}"
                for col in range(dims.aie_cols)
            ]


class M1N16(RoPE_Base):
    """RoPE template variant registered as M1N16; inherits all generation logic."""

    def __init__(self, _data=None):
        # RoPE_Base.__init__ stores _data and builds helper_func / prm itself.
        super().__init__(_data)


class M1N32(RoPE_Base):
    """RoPE template variant registered as M1N32; inherits all generation logic."""

    def __init__(self, _data=None):
        # RoPE_Base.__init__ stores _data and builds helper_func / prm itself.
        super().__init__(_data)
