import sys
import os
import pdb
#infra_path = (os.path.dirname(os.path.abspath(__file__))+"infra/")
infra_path = (os.path.dirname(os.path.abspath(__file__))+"/../infra/")
sys.path.append(infra_path)
import logging
import math
import textwrap
from enum import Enum

import const
import scheduler_utils as utils
from template_base import BaseTemplate, BaseDims
from gen_kernel_param import return_kernel_path, gen_blob

class M4N4(BaseTemplate):
    """Source-code generator for the "PWLA" (LUT-based activation) dataflow.

    Every ``gen_*`` method returns Python source text. ``gen_code`` joins them
    into a complete dataflow script: header, the ``gelu_qdq_params`` helper,
    the ``generate_dataflow`` function and the main entry point (the latter two
    partly produced by ``BaseTemplate`` helpers).

    Only ``"PWLA"`` is supported end to end; the ``"MatAdd"`` branch is guarded
    by a failing ``utils.sanity_check``.
    """

    def __init__(self, _op_code, _ver, data):
        """Store the operator configuration.

        Args:
            _op_code: Operator name. ``"PWLA"`` selects the LUT data flow;
                ``"MatAdd"`` calls ``utils.sanity_check(False, ...)``.
            _ver: Kernel version forwarded to ``return_kernel_path``,
                ``gen_blob`` and ``gen_kernel_name``.
            data: Caller-supplied payload, stored unchanged as ``self.data``.
        """
        super().__init__()
        self.op_code = _op_code
        self.ver = _ver
        self.data = data
        if self.op_code == "PWLA":
            self.num_tensor = 1
            # Name fragment used in generated debug comments, e.g. "CoreLutPingAddr".
            self.var = "Lut"
        elif self.op_code == "MatAdd":
            utils.sanity_check(False,"Not tested data flow. Check again")
            self.num_tensor = 2
            self.var = "QdqPrm"
            self.kernel_name = ['run_a16a16_matadd_qdq']
            self.kernel_inc = [
                'super.hh',
                'qdq/wrapper_qdq.cc',
                'matadd/matadd_32x64_kernel.c',
            ]
        # NOTE(review): any other op code leaves ``num_tensor``/``var`` unset;
        # such op codes are presumably rejected later by the sanity checks in
        # ``gen_core_instr`` / ``gen_dataflow_footer`` — confirm.

    def gen_code(self,op_params):
        """Return the complete generated dataflow script for this operator.

        Args:
            op_params: Dict of tiling / address parameters. It must contain every
                key read by the ``gen_*`` helpers (``'overlay'``, ``'aie_cols'``,
                ``'aie_rows'``, ``'CoreStackAddr'``, ``'CoreActSize'``, ...).

        Returns:
            The generated Python source as one string.
        """
        logging.info("Generate code for %s operation", self.op_code)
        # Tuple elements are evaluated left to right, so the generation order
        # (header, helper, dataflow, main) is preserved.
        return "".join((
            self.gen_op_header(op_params['overlay']),
            self.params_instruction_gelu(),
            self.gen_dataflow(op_params),
            self.gen_main_func(),
        ))

    def gen_op_header(self, overlay):
        """Return the script header produced by ``BaseTemplate.gen_headers``."""
        code = self.gen_headers(overlay)  #Base Class Impl
        return code

    def gen_dataflow(self, op_params):
        """Return the source of ``generate_dataflow(back_end)``.

        The function body is assembled from the attribute dump and the core,
        memtile and shim instruction sections. It is followed by the kernel-name
        section and the connection/run/compile section from the base class.
        """
        sections = [
            "def generate_dataflow(back_end: BackEnd):",
            self.gen_print_attr(op_params),
            self.gen_core_instr(op_params),
            self.gen_memtile_instr(op_params),
            self.gen_shimtile_instr(op_params),
            self.gen_kernel_name("PWLA", self.ver),
            self.gen_connection_run_compile(op_params['overlay'], op_params['aie_cols'], op_params['aie_rows'], op_params['CoreStackAddr']),
        ]
        return "".join(sections)

    def gen_silu_params(self, n):
        """Encode ``n`` as one unsigned byte (OverflowError outside 0..255)."""
        return (n).to_bytes(length=1, byteorder='little', signed=False)

    def gen_params(self, core_act_addr: int, lutad_addr: int, lutcd_addr: int,
                                subv_rows: int, subv_cols: int) -> bytes:
        """Build the PWLA kernel parameter blob via ``gen_blob``.

        The two trailing zeros stand for the fused-op and scratch-address
        arguments, which standalone LUT ops do not use.
        """
        input_arg_list = [core_act_addr, lutad_addr, lutcd_addr, subv_rows, subv_cols, 0, 0] #fused ops and scratch addr not needed in standalone LUT OPs
        print(f"input_arg_list :{input_arg_list}")
        kernel_params = gen_blob(input_arg_list, "PWLA", self.ver)
        return (kernel_params)

    def params_instruction_gelu(self):
        """Return the source of the generated ``gelu_qdq_params`` helper.

        The helper packs nine arguments, each as a 2-byte little-endian
        unsigned integer, into one ``bytes`` object.
        """
        return """
def gelu_qdq_params(qdq_prm_addr: int, lutad_addr: int, lutcd_addr: int, num_elements: int, tdm1_addr: int, tdm2_addr: int, fused_op: int, is_in_int16: int, is_out_int16: int):
    return ( qdq_prm_addr.to_bytes(length=2, byteorder='little', signed=False)
           + lutad_addr.to_bytes(length=2, byteorder='little', signed=False)
           + lutcd_addr.to_bytes(length=2, byteorder='little', signed=False)
           + num_elements.to_bytes(length=2, byteorder='little', signed=False)
           + tdm1_addr.to_bytes(length=2, byteorder='little', signed=False)
           + tdm2_addr.to_bytes(length=2, byteorder='little', signed=False)
           + fused_op.to_bytes(length=2, byteorder='little', signed=False)
           + is_in_int16.to_bytes(length=2, byteorder='little', signed=False)
           + is_out_int16.to_bytes(length=2, byteorder='little', signed=False)
    )
"""

    def gen_print_attr(self, op_params):
        """Return generated comment lines that record the GEMM/subvolume params."""
        return f"""
    #Print OPs params
    #Mgemm        = {op_params['Mgemm']}
    #Ngemm        = {op_params['Ngemm']}
    #Msubv        = {op_params['Msubv']}
    #Nsubv        = {op_params['Nsubv']}
    #ifm_byte_len = {op_params['ifm_byte_len']}
    #ofm_byte_len = {op_params['ofm_byte_len']}
    #backend_type = {op_params['backend_type']} """

    def gen_core_instr(self, op_params):
        """Return the generated ``core_instrs`` list for the compute tiles.

        The generated list does the following:
        * configures S2MM channel 0 for the ping/pong activation buffers;
        * receives the LUT + QDQ parameter block once on S2MM channel 1;
        * loops ``Core_Tn`` times, calling the first kernel listed by
          ``return_kernel_path("PWLA", ver)``.

        Raises (via ``utils.sanity_check``) for any op code other than "PWLA".
        """
        meta_kernel_list = list((return_kernel_path("PWLA",self.ver)[0]).keys())
        lutab_Addr = 0 # initializing the var to resolve pylint
        lutcd_Addr = 0
        if self.op_code == "PWLA":
            lutab_Addr = op_params['CoreLutabAddr'][0]
            lutcd_Addr = op_params['CoreLutcdAddr'][0]
        else:
            utils.sanity_check(False,f"Incorrect opcode: {self.op_code}")

        # NOTE(review): when only a subset of columns is active, the generated
        # Core_Tn is Core_Tn * Core_Tn (gen_memtile_instr squares Mem_Tn the
        # same way) — confirm this is the intended iteration count.
        return f"""
    #Core Instr
    #CoreActPingAddr = {op_params['CoreActPingAddr']}
    #CoreActPongAddr = {op_params['CoreActPongAddr']}
    #{f'Core{self.var}PingAddr'} = {op_params['CoreLutabAddr'][0]}
    #lutab_Addr      = {lutab_Addr}
    #lutcd_Addr      = {lutcd_Addr}
    #CoreQdqAddr     = {op_params['CoreQdqAddr']}
    #CoreTdmAddr0    = {op_params['CoreTdmAddr'][0]}
    #CoreTdmAddr1    = {op_params['CoreTdmAddr'][1]}
    #CoreOutPingAddr = {op_params['CoreOutPingAddr']}
    #CoreQdqSize     = {op_params['CoreQdqSize']}
    #CoreActSize     = {op_params['CoreActSize']}
    #{f'Core{self.var}Size'}     = {op_params['CoreLutabSize']+op_params['CoreLutcdSize']}
    #CoreOutSize     = {op_params['CoreOutSize']}
    Core_Tn          = {op_params['Core_Tn'] if op_params['aie_cols'] == len(op_params['active_col']) else op_params['Core_Tn'] * op_params['Core_Tn']}
    #num_elements    = {op_params['Msubv'] * op_params['Nsubv']}
    core_instrs = [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), {op_params['CoreActPingAddr']}, {op_params['CoreActPongAddr']}, {op_params['CoreActSize']}),
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), {op_params['CoreLutabAddr'][0]}, None, {op_params['CoreLutabSize']+op_params['CoreLutcdSize']+op_params['CoreQdqSize']}),
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
        ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), {op_params['CoreOutPingAddr']}, None, {op_params['CoreOutSize']}),
        Loop(Core_Tn, [
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
            CallKernel('{meta_kernel_list[0]}', kernel_params=gelu_qdq_params({op_params['CoreQdqAddr']}, {lutab_Addr}, {lutcd_Addr}, {op_params['Msubv'] * op_params['Nsubv']}, {op_params['CoreTdmAddr'][0]}, {op_params['CoreTdmAddr'][1]}, 0, {int(op_params['ifm_byte_len'] == 2)}, {int(op_params['ofm_byte_len'] == 2)})),
            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
        ]),
    ]"""

    def gen_mem_tiling_expr(self, tiling_expr, num_cols, enable_col_list, num_rows, size):
        """Return one memtile tiling expression per column (unicast channels only).

        Active columns (in ``enable_col_list``) get ``H:0:<num_rows> B:0:<size>``.
        Inactive columns get the zero-length ranges
        ``H:<num_rows>:<num_rows> B:<size>:<size>`` (start == end).

        ``tiling_expr`` is accepted for interface compatibility but is not used.
        """
        return [
            f"H:0:{num_rows} B:0:{size}"
            if col_id in enable_col_list
            else f"H:{num_rows}:{num_rows} B:{size}:{size}"
            for col_id in range(num_cols)
        ]


    def gen_memtile_instr(self, op_params):
        """Return the generated ``memtile_transfers`` list.

        The list contains four groups of ``DataTransfer``s:
        * params, one transfer per column;
        * LUT + QDQ block, one per column in ``range(broadcast_itr_expn)``,
          where ``'0,AieCols,2'`` (every other column) is used for non-"4x4"
          overlays;
        * activations, one per column;
        * outputs, one per column.

        Ping/pong address lists collapse to a single address when the pong
        address is ``None``.
        """
        ifm_shard_size = op_params['CoreActSize']
        ofm_shard_size = op_params['CoreOutSize']
        #Prm
        s2mm_param_mem_fmt, s2mm_param_tile_fmt = f"\'Byte:{op_params['CorePrmSize']}\'", f"\'Byte:0:{op_params['CorePrmSize']}\'"
        mm2s_param_mem_fmt, mm2s_param_tile_fmt = s2mm_param_mem_fmt, s2mm_param_tile_fmt
        #LUT/QDQ Prm
        str_var = f"{op_params['CoreLutSize']+op_params['CoreQdqSize']}"
        s2mm_lut_qdq_mem_fmt, s2mm_lut_qdq_tile_fmt = f"\'Byte:{str_var}\'", f"\'Byte:0:{str_var}\'"
        mm2s_lut_qdq_mem_fmt, mm2s_lut_qdq_tile_fmt = s2mm_lut_qdq_mem_fmt, s2mm_lut_qdq_tile_fmt
        #Act
        act_mem_fmt       = f"H:{op_params['aie_rows']} B:{ifm_shard_size}"
        s2mm_act_tile_fmt = f"H:0:{op_params['aie_rows']} B:0:{ifm_shard_size}"
        # Doubled braces emit a literal "{row}" so the generated code formats it.
        mm2s_act_tile_fmt = f"H:{{row}}:{{row+1}} B:0:{ifm_shard_size}"
        #Out
        out_mem_fmt       = f"H:{op_params['aie_rows']} B:{ofm_shard_size}"
        mm2s_out_tile_fmt = f"H:0:{op_params['aie_rows']} B:0:{ofm_shard_size}"
        s2mm_out_tile_fmt = f"H:{{row}}:{{row+1}} B:0:{ofm_shard_size}"

        if op_params['MemtileActPongAddr'] is not None:
            act_addr = [op_params['MemtileActPingAddr'], op_params['MemtileActPongAddr']]
        else:
            act_addr = [op_params['MemtileActPingAddr']]

        if op_params['MemtileOutPongAddr'] is not None:
            out_addr = [op_params['MemtileOutPingAddr'], op_params['MemtileOutPongAddr']]
        else:
            out_addr = [op_params['MemtileOutPingAddr']]

        broadcast_itr_expn = 'AieCols' if op_params['overlay'] == '4x4' else '0,AieCols,2'


        return f"""
    #Memtile Instr
    AieRows            = {op_params['aie_rows']}
    AieCols            = {op_params['aie_cols']}
    unicast_itr        = {op_params['Mem_Tn'] if op_params['aie_cols'] == len(op_params['active_col']) else op_params['Mem_Tn'] * op_params['Mem_Tn']}
    mem_act_tiling_expr = {self.gen_mem_tiling_expr(s2mm_act_tile_fmt, op_params['aie_cols'], op_params['active_col'], op_params['aie_rows'], ifm_shard_size)}
    mem_out_tiling_expr = {self.gen_mem_tiling_expr(mm2s_out_tile_fmt, op_params['aie_cols'], op_params['active_col'], op_params['aie_rows'], ofm_shard_size)}
    #Prm
    #MemtilePrmPingAddr = {op_params['MemtilePrmPingAddr']}
    #CorePrmSize        = {op_params['CorePrmSize']}
    #LUT
    #MemtileQdqAddr     = {op_params.get('MemtileQdqAddr')}
    #{f'Memtile{self.var}Addr'}     = {op_params['MemtileLutAddr']}
    #CoreLutSize        = {op_params['CoreLutSize']}
    #CoreQdqSize        = {op_params['CoreQdqSize']}
    #Act
    #Mem_Tn             = [{op_params['Mem_Tn']}]
    #MemtileActPingAddr = {op_params['MemtileActPingAddr']}
    #MemtileActPongAddr = {op_params['MemtileActPongAddr']}
    #MemtileActSize     = {op_params['MemtileActSize']}
    #Out
    #MemtileOutSize     = {op_params['MemtileOutSize']}
    #MemtileOutPingAddr = {op_params['MemtileOutPingAddr']}
    #MemtileOutPongAddr = {op_params['MemtileOutPongAddr']}
    
    memtile_transfers = [
    DataTransfer(
            [1],
            AieTile(TileType.Memtile, col, 0), [{op_params['MemtilePrmPingAddr']}], {op_params['CorePrmSize']},
            [generate_transfer_params(memtile_dma(col, DmaDir.S2MM, 1), {s2mm_param_mem_fmt}, {s2mm_param_tile_fmt}, 8)],
            [generate_transfer_params(memtile_dma(col, DmaDir.MM2S, row), {mm2s_param_mem_fmt}, {mm2s_param_tile_fmt}, 8)
        for row in range(AieRows)],
            sync_strategy=SyncStrategy.Parallel_1_to_N
    ) for col in range(AieCols)
    ] + [
    DataTransfer(
            [1],
            AieTile(TileType.Memtile, col, 0), [{op_params['MemtileLutAddr']}], {op_params['CoreLutSize']+op_params['CoreQdqSize']},
            [generate_transfer_params(memtile_dma(col, DmaDir.S2MM, 0), {s2mm_lut_qdq_mem_fmt}, {s2mm_lut_qdq_tile_fmt}, 8)],
            [generate_transfer_params(memtile_dma(col, DmaDir.MM2S, 4), {mm2s_lut_qdq_mem_fmt}, {mm2s_lut_qdq_tile_fmt}, 8)]
    ) for col in range({broadcast_itr_expn})
    ] + [
    DataTransfer(
                [unicast_itr],
                AieTile(TileType.Memtile, col, 0), {act_addr}, {op_params['MemtileActSize']},
                [generate_transfer_params(memtile_dma(col, DmaDir.S2MM, 1), f'{act_mem_fmt}', mem_act_tiling_expr[col], 8)],
                [generate_transfer_params(memtile_dma(col, DmaDir.MM2S, row), f'{act_mem_fmt}', f'{mm2s_act_tile_fmt}', 8)
            for row in range(AieRows)],
        ) for col in range(AieCols)
    ] + [
    DataTransfer(
                [unicast_itr],
                AieTile(TileType.Memtile, col, 0), {out_addr}, {op_params['MemtileOutSize']},
                [generate_transfer_params(memtile_dma(col, DmaDir.S2MM, 2 + row), f'{out_mem_fmt}', f'{s2mm_out_tile_fmt}', 8)
            for row in range(AieRows)],
                [generate_transfer_params(memtile_dma(col, DmaDir.MM2S, 5), f'{out_mem_fmt}', mem_out_tiling_expr[col], 8)],
            ) for col in range(AieCols)
    ]"""

    def gen_shim_act_transfer_param(self,op_params):
        """Return the shim MM2S activation transfer-param list as source text.

        "PWLA" yields one entry. "MatAdd" yields two entries with identical
        formats; the second entry is joined to the first by a backslash
        continuation, so its leading whitespace is part of the string. Any other
        op code returns ``None``. ``op_params`` is not used.
        """
        mm2s_act_mem_fmt, mm2s_act_tile_fmt = 'f\'H:1 W:{ShimActShardSize*AieCols}\'', 'f\'H:0:1 W:{col*ShimActShardSize}:{col*ShimActShardSize+ShimActShardSize}\''
        mm2s_act_mem_fmt2, mm2s_act_tile_fmt2 = 'f\'H:1 W:{ShimActShardSize*AieCols}\'', 'f\'H:0:1 W:{col*ShimActShardSize}:{col*ShimActShardSize+ShimActShardSize}\''
        if self.op_code == "PWLA":
            return f'[generate_transfer_params(shim_dma(col, DmaDir.MM2S, 1), {mm2s_act_mem_fmt}, {mm2s_act_tile_fmt}, 8)]'
        elif self.op_code in ["MatAdd"]:
            return f'[generate_transfer_params(shim_dma(col, DmaDir.MM2S, 1), {mm2s_act_mem_fmt}, {mm2s_act_tile_fmt}, 8),\
                            generate_transfer_params(shim_dma(col, DmaDir.MM2S, 1), {mm2s_act_mem_fmt2}, {mm2s_act_tile_fmt2}, 8)]'
        # Unsupported op code: made explicit (the original fell through silently).
        return None


    def gen_shim_tiling_expr(self, tiling_expr, num_cols, enable_col_list):
        """Return one shim tiling expression per column (unicast channels only).

        The text of ``tiling_expr`` from ``"W:"`` up to ``"H:"`` is replaced for
        each column. Active columns get ``W:<col>:<col+1>``; inactive columns get
        ``W:<num_cols>:<col+1>``.
        """
        # The W segment boundaries are the same for every column; slice once.
        w_start = tiling_expr.find("W:")
        h_start = tiling_expr.find("H:")
        head, tail = tiling_expr[:w_start], tiling_expr[h_start:]
        tiling_expr_list = []
        for col_id in range(num_cols):
            if col_id in enable_col_list:
                w_segment = f"W:{col_id}:{col_id+1} "
            else:
                # NOTE(review): start (num_cols) exceeds end (col_id+1) for every
                # column except the last, whereas gen_mem_tiling_expr uses a
                # start == end range. Confirm the backend treats this as "no data".
                w_segment = f"W:{num_cols}:{col_id+1} "
            tiling_expr_list.append(f"{head}{w_segment}{tail}")
        return tiling_expr_list

    def gen_shimtile_instr(self, op_params):
        """Return the generated ``shim_transfers`` list.

        The list contains the following ``DataTransfer`` groups:
        * params (buffer index 3), one per column;
        * LUT + QDQ block (index 2), broadcast over ``range(broadcast_itr_expn)``;
        * activations (index 1), one per column;
        * outputs (index 0), one per column.

        When only some columns are active, the shard sizes are scaled by
        ``Mem_Tn``.
        """
        ifm_shard_size = op_params['CoreActSize'] if op_params['aie_cols'] == len(op_params['active_col']) else op_params['CoreActSize'] * op_params['Mem_Tn']
        ofm_shard_size = op_params['CoreOutSize'] if op_params['aie_cols'] == len(op_params['active_col']) else op_params['CoreOutSize'] * op_params['Mem_Tn']
        #Prm
        param_mem_fmt  = f"\'Byte:{op_params['CorePrmSize']}\'"
        mm2s_param_tile_fmt = f"\'Byte:0:{op_params['CorePrmSize']}\'"
        #LUT/QDQ Prm
        str_var = f"{op_params['CoreLutSize']+op_params['CoreQdqSize']}"
        lut_qdq_mem_fmt       = f"\'Byte:{str_var}\'"
        mm2s_lut_qdq_tile_fmt = f"\'Byte:0:{str_var}\'"
        #Act
        act_mem_fmt  = f"R:{op_params['Core_Tn']} W:{op_params['aie_cols']} H:{op_params['aie_rows']} B:{ifm_shard_size}"
        mm2s_act_tile_fmt = f"R:0:{op_params['Core_Tn']} W:{{col}}:{{col+1}} H:0:{op_params['aie_rows']} B:0:{ifm_shard_size}"
        #Out
        out_mem_fmt  = f"R:{op_params['Core_Tn']} W:{op_params['aie_cols']} H:{op_params['aie_rows']} B:{ofm_shard_size}"
        s2mm_out_tile_fmt = f"R:0:{op_params['Core_Tn']} W:{{col}}:{{col+1}} H:0:{op_params['aie_rows']} B:0:{ofm_shard_size}"

        broadcast_itr_expn = 'AieCols' if op_params['overlay'] == '4x4' else '0,AieCols,2'

        return f"""
    #Shimtile Instr
    AieRows             = {op_params['aie_rows']}
    AieCols             = {op_params['aie_cols']}
    shim_act_tiling_expr = {self.gen_shim_tiling_expr(mm2s_act_tile_fmt, op_params['aie_cols'], op_params['active_col'])}
    shim_out_tiling_expr = {self.gen_shim_tiling_expr(s2mm_out_tile_fmt, op_params['aie_cols'], op_params['active_col'])}
    #CorePrmSize        = {op_params['CorePrmSize']}
    #CoreLutSize        = {op_params['CoreLutSize']}
    #CoreQdqSize        = {op_params['CoreQdqSize']}
    #ShimActSize        = {op_params['ShimActSize']}
    #ShimOutSize        = {op_params['ShimOutSize']}

    shim_transfers = [
        DataTransfer([1], AieTile(TileType.Shim, col), [3], {op_params['CorePrmSize']}, [],
            [generate_transfer_params(shim_dma(col, DmaDir.MM2S, 1), {param_mem_fmt}, {mm2s_param_tile_fmt}, 8)]
        ) for col in range(AieCols)
    ] + [
        DataTransfer([1], AieTile(TileType.Shim, col), [2], {str_var},[],
            [generate_transfer_params(shim_dma(col, DmaDir.MM2S, 0), {lut_qdq_mem_fmt}, {mm2s_lut_qdq_tile_fmt}, 8)]
        ) for col in range({broadcast_itr_expn})
    ] + [
        DataTransfer([1], AieTile(TileType.Shim, col), [1], {op_params['ShimActSize']}, [],
            [generate_transfer_params(shim_dma(col, DmaDir.MM2S, 1), f'{act_mem_fmt}', shim_act_tiling_expr[col], 8)],
        ) for col in range(AieCols)
    ] + [
        DataTransfer([1], AieTile(TileType.Shim, col), [0], {op_params['ShimOutSize']},
            [generate_transfer_params(shim_dma(col, DmaDir.S2MM, 0), f'{out_mem_fmt}', shim_out_tiling_expr[col], 8)],
            []
        ) for col in range(AieCols)
    ]"""

    def gen_dataflow_footer(self, overlay, col, row, core_stack_addr):
        """Return the kernel list/include assignments plus the compile section.

        The base-class ``gen_connection_run_compile`` output is stripped line by
        line and re-indented with two tabs. Calls ``utils.sanity_check(False, ...)``
        for any op code other than "PWLA".
        """
        if self.op_code != "PWLA":
            utils.sanity_check(False,"Unsupported OP")

        #SILU and GELU both are defined in same yaml file
        _list, _include = return_kernel_path("PWLA",self.ver)
        compile_code = self.gen_connection_run_compile(overlay, col, row, core_stack_addr)
        cleaned_lines = "\n".join([line.strip() for line in compile_code.splitlines()])
        indent_code = textwrap.indent(cleaned_lines,'\t\t')

        return f"""
        kernel_names = {_list}
        kernel_includes = {_include}
        {indent_code}
        """

# NOTE(review): Disabled debug harness, kept as a bare string literal, so none
# of it executes. It is stale:
# - M4N4.__init__ requires (_op_code, _ver, data), but the harness calls M4N4()
#   with no arguments.
# - gen_code() reads many op_params keys (e.g. 'overlay', 'aie_cols',
#   'CoreActSize') that the dict below does not define.
# Update both before re-enabling.
'''
#For debug purpose
if __name__ == "__main__":
    params = {
            'ParamSize': 200,
            'Mgemm': 0,
            'Ngemm': 0,
            'Msubv': 0,
            'Nsubv': 0,
            'ifm_byte_len': 0,
            'ofm_byte_len': 0,
            'AieCols': 0,
            'AieRows': 0,
            'backend_type': 0,
            }

    debug = M4N4()
    print(debug.gen_code(params))
'''
