import os
import sys
import copy
import logging
infra_path = (os.path.dirname(os.path.abspath(__file__))+"/infra/")
sys.path.append(infra_path)
import const
from template_base_gemm import Gemm_base, GemmDims

class template_class_name(Gemm_base):
    #access buffer allocator parameters
    def gen_dma_params(self, _pipeline_data):
        """Build the DMA parameter dict from the buffer-allocator results.

        Reads the ``'BuffAllocator'`` entry of ``_pipeline_data.info``, selects
        its buffer-allocation parameter record
        (``const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX``) and packs the record
        into a ``GemmDims`` instance. The M dimension handed to ``GemmDims``
        is the allocator's ``M`` plus the total ifm x-padding.

        Args:
            _pipeline_data: pipeline object exposing an ``info`` mapping that
                contains a ``'BuffAllocator'`` entry.

        Returns:
            dict: parameter dict holding ``'dims'`` plus whatever the helper
            methods called below add to it.
        """
        buff = _pipeline_data.info.get('BuffAllocator')
        buff_prm = buff[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]
        # padding[0] carries the ifm pads and padding[2] the ofm pads; each pad
        # entry is an iterable whose elements are summed into one total.
        pad_ifm_x = sum(buff_prm['padding'][0]['pad_ifm_x'])
        pad_ifm_y = sum(buff_prm['padding'][0]['pad_ifm_y'])
        pad_ofm_x = sum(buff_prm['padding'][2]['pad_ofm_x'])
        pad_ofm_y = sum(buff_prm['padding'][2]['pad_ofm_y'])
        dims = GemmDims(buff_prm['M'] + pad_ifm_x,  # M widened by the total ifm x-padding
                        buff_prm['K'],
                        buff_prm['N'],
                        buff_prm['M_subV'], buff_prm['K_subV'], buff_prm['N_subV'],
                        buff_prm['aie_rows'], buff_prm['aie_cols'], buff_prm['aie_arrays'],
                        buff_prm['ifm_bits'], buff_prm['ofm_bits'], buff_prm['qdq_bytes'],
                        buff_prm['outer_loop'], buff_prm['inner_loop'], buff_prm['acc_loop'],
                        buff_prm['wgt_subv_rows'], buff_prm['wgt_subv_cols'],
                        buff_prm['ifm_core_subv_bytes'], buff_prm['wgt_core_subv_bytes'],
                        buff_prm['ofm_core_subv_bytes'],
                        buff_prm['sum_core_subv_bytes'], buff_prm['tdm_core_subv_bytes'],
                        # wgt_bits is used only when 'actxact' is set; otherwise BITS_PER_BYTE
                        buff_prm['wgt_bits'] if buff_prm['actxact'] else const.BITS_PER_BYTE,
                        const.BITS_PER_BYTE, buff_prm['bias_bits'],
                        buff_prm['tdm_bits'],
                        pad_ifm_x, pad_ifm_y,
                        pad_ofm_x, pad_ofm_y)
        params = {'dims': dims}
        # The return values of these helpers are unused, so they presumably
        # populate `params` in place (defined outside this view -- confirm).
        self.update_buffer_alloc_to_params(params, buff)
        self.calc_rep_params(params)
        self.set_kernelmode_prm(params)
        self.gen_dma_tile_params(params)
        return params

    #top level function call to generate data_flow.py string
    def gen_instr(self, params):
        """Generate the complete data-flow script text for the matmul operation.

        The sections are produced by generator methods and concatenated in
        this fixed order:
          1. header import statements
          2. kernel parameter function (e.g. gemm_qdq_params)
          3. helper functions (e.g. gen_pack_shim_transfer_with_iter_step)
          4. overlay mapping definition
          5. data-flow function definition
          6. DMA transfers (core, memtile, shim)
          7. kernel name and kernel-file include
          8. connection / run / compile call
          9. main function of the generated script

        Args:
            params: parameter dict. Keys read here: ``template_meta_data``
                (with ``'overlay'``), ``enable_packtransfer``,
                ``enable_pack_shim``, ``dims``, ``op_name``, ``op_ver`` and
                ``CoreStackAddr``; the whole dict is also forwarded to
                ``dma_pattern_code``.

        Returns:
            str: the generated source text.
        """
        logging.info("Generate code for matmul operation")

        # List elements are evaluated left to right, so the generator methods
        # run in the same order as before and each argument is read from
        # `params` right before its call.
        sections = [
            self.gen_headers(params['template_meta_data']['overlay'], 1),
            self.params_instruction(),
            self.gen_helper_func_instr(params['enable_packtransfer'],
                                       params['enable_pack_shim'],
                                       params['dims']),
            self.gen_8x4_overlay(),
            self.gen_dataflow(params['enable_packtransfer']),
            self.dma_pattern_code(params),
            self.gen_kernel_name(params['op_name'], params['op_ver']),
            self.gen_connection_run_compile(params['template_meta_data']['overlay'],
                                            params['dims'].aie_cols,
                                            params['dims'].aie_rows,
                                            params['CoreStackAddr'][0]),
            self.gen_main_func(),
        ]
        return "".join(sections)

    def dma_pattern_code(self, params):
        """Return the generated DMA transfer code: core, then memtile, then shim.

        Args:
            params: parameter dict forwarded unchanged to ``gen_core_instr``,
                ``gen_memtile_instr`` and ``gen_shim_instr``.

        Returns:
            str: the three generated sections concatenated in that order.
        """
        sections = (
            self.gen_core_instr(params),     # core transfers
            self.gen_memtile_instr(params),  # memtile transfers
            self.gen_shim_instr(params),     # shim transfers
        )
        return "".join(sections)

    def gen_memtile_instr(self, params):
        """Generate the memtile section of the data-flow script.

        Emitted in order:
          - memtile stats and the ``memtile_transfers = []`` initialiser
          - parameter transfers
          - QDQ parameter transfers, only when ``params.get('CoreQdqSize')``
            is not None
          - activation transfers for every column
          - weight transfers
          - OFM transfers for every column
          - a trailing blank line

        Activation and weight roles are swapped when calling the pattern
        helpers. Activation transfers go through ``memtile_wgt_pattern``, with
        the activation buffers stored under the ``wgt_*`` / ``MemtileWgt*``
        keys and ``dims.wgt_bits`` overwritten with ``dims.act_bits``. Weight
        transfers go through ``gen_memtile_act_patten``, with the weight
        buffers under the ``act_*`` / ``MemtileAct*`` keys and
        ``dims.act_bits`` overwritten with ``dims.wgt_bits``. Every value is
        deep-copied on its own, so those edits do not reach ``params``.

        Args:
            params: DMA parameter dict for this GEMM.

        Returns:
            str: the generated memtile code.
        """
        chunks = [
            "    #MEMTILE FLOW DEFINITION\n",
            self.memtile_stats(params),
            "    memtile_transfers = []\n",
            "    #MEMTILE PARAM TRANSFERS\n",
            self.memtile_prm_pattern(params),
        ]
        if params.get('CoreQdqSize') is not None:
            chunks += ["    #MEMTILE QDQ PARAM TRANSFERS\n", self.memtile_qdq_pattern(params)]

        chunks.append(f"    #MEMTILE ACTIVATION TRANSFERS, {params['ActDataFlow'].upper()}\n")
        for col in range(params['dims'].aie_cols):
            # This column's activation buffers, keyed the way memtile_wgt_pattern reads them.
            act_sources = {
                'wgt_memtile_memory': params['act_memtile_memory'][col][0],
                # The list is repeated before copying, so repeated entries stay
                # shared inside the copy (deepcopy preserves aliasing).
                'wgt_memtile_mm2s': params['act_memtile_mm2s'][col] * params['dims'].aie_rows,
                'wgt_memtile_s2mm': params['act_memtile_s2mm'][col][0],
                'dims': params['dims'],
                'reuse_chain_length': params['reuse_chain_length'],
                'MemtileWgtPingAddr': params['MemtileIfmPingAddr'],
                'MemtileWgtPongAddr': params['MemtileIfmPongAddr'],
                'ping_pong_enable': {'wgt': params['ping_pong_enable']['ifm']},
                'MemtileWgtRepeat': params['MemtileActPingRepeat'],
                'MemtileWgtSize': params['MemtileActSize'],
                # Fixed strategy; params['sync_strategy'] is intentionally not used here.
                'sync_strategy': 'Default',
                'MemtileWgtReuseRatio': params['MemtileActReuseRatio'],
                'MemtileWgtchid': params.get('MemtileWgtchid', 1),
            }
            act_params = {key: copy.deepcopy(value) for key, value in act_sources.items()}
            act_params['dims'].wgt_bits = params['dims'].act_bits
            chunks.append(self.memtile_wgt_pattern(act_params, col))

        chunks.append(f"    #MEMTILE WEIGHT TRANSFERS, {params['WgtDataFlow'].upper()}\n")
        wgt_sources = {
            # NOTE(review): `[[x]] * cols` makes every column reference the same
            # inner list, and deepcopy keeps that sharing -- confirm that
            # gen_memtile_act_patten never mutates a single column's entry.
            'act_memtile_memory': [[params['wgt_memtile_memory']]] * params['dims'].aie_cols,
            'act_memtile_s2mm': [[params['wgt_memtile_s2mm']]] * params['dims'].aie_cols,
            # Each weight mm2s entry is used for two consecutive columns.
            'act_memtile_mm2s': [[params['wgt_memtile_mm2s'][idx // 2]]
                                 for idx in range(2 * len(params['wgt_memtile_mm2s']))],
            'enable_padding': False,
            'MemtileActPingRepeat': params['MemtileWgtRepeat'],
            'MemtileActPongRepeat': params['MemtileWgtRepeat'],
            'MemtileIfmPingAddr': params['MemtileWgtPingAddr'],
            'MemtileIfmPongAddr': params['MemtileWgtPongAddr'],
            'MemtileActSize': params['MemtileWgtSize'],
            'dims': params['dims'],
            'MemtileActReuseRatio': params['MemtileWgtReuseRatio'],
            'ActPingPong': params['WgtPingPong'],
            'MemtileActchid': params.get('MemtileActchid', 0),
            'ActDataFlow': params['WgtDataFlow'],
        }
        wgt_params = {key: copy.deepcopy(value) for key, value in wgt_sources.items()}
        wgt_params['dims'].act_bits = params['dims'].wgt_bits
        chunks.append(self.gen_memtile_act_patten(wgt_params))

        chunks.append("    #MEMTILE OFM TRANSFERS\n")
        chunks.extend(self.memtile_ofm_pattern(params, col)
                      for col in range(params['dims'].aie_cols))
        chunks.append("\n")
        return "".join(chunks)
        
    def gen_shim_instr(self, params):
        """Generate the shim section of the data-flow script.

        Emitted in order:
          - shim stats and the ``shim_transfers = []`` initialiser
          - parameter transfers for every column
          - QDQ parameter transfers for even columns (0, 2, 4, ...), only
            when ``params.get('CoreQdqSize')`` is not None
          - activation transfers for every ``ShimActOuterLoop`` iteration and
            every column
          - weight transfers for every ``ShimWgtRepeat`` entry and every even
            column
          - OFM transfers for every column

        Every value handed to the pattern helpers is deep-copied separately,
        so the helpers receive copies rather than the objects in ``params``.

        Args:
            params: DMA parameter dict for this GEMM.

        Returns:
            str: the generated shim code.
        """
        parts = [
            "    #SHIM FLOW DEFINITION\n",
            self.shim_stats(params),
            "    shim_transfers = []\n",
            "    #SHIM PARAM TRANSFERS\n",
        ]
        parts.extend(self.shim_prm_pattern(params, col)
                     for col in range(params['dims'].aie_cols))
        # Same presence test as gen_memtile_instr: a missing key and an explicit
        # None both mean "no QDQ params". `is not None` also avoids
        # element-wise `!=` on array-like values.
        if params.get('CoreQdqSize') is not None:
            parts.append("    #SHIM QDQ PARAM TRANSFERS\n")
            parts.extend(self.shim_qdq_pattern(params, col)
                         for col in range(0, params['dims'].aie_cols, 2))

        parts.append(f"    #SHIM ACTIVATION TRANSFERS, {params['ActDataFlow'].upper()}\n")
        for itr in range(params['ShimActOuterLoop']):
            for col in range(params['dims'].aie_cols):
                act_params = {'ShimActRepeat':    copy.deepcopy(params['ShimActRepeat'][itr]),
                              'ShimActBufferIdx': copy.deepcopy(params['ShimActBufferIdx']),
                              'ShimIfmSize':      copy.deepcopy(params['ShimIfmSize']),
                              'act_shim_memory':  copy.deepcopy(params['act_shim_memory']),
                              'act_shim_mm2s':    copy.deepcopy(params['act_shim_mm2s']),
                              'dims':             copy.deepcopy(params['dims'])}
                parts.append(self.shim_act_pattern_pack(act_params, col, itr))

        parts.append(f"    #SHIM WEIGHT TRANSFERS, {params['WgtDataFlow'].upper()}\n")
        num_cols = params['dims'].aie_cols
        for itr, wgt_repeat in enumerate(params['ShimWgtRepeat']):
            # Columns 2k and 2k+1 share wgt_shim_mm2s[k][itr]. The list does not
            # depend on the column, so build it once per iteration; each
            # column still gets its own deep copy below.
            wgt_mm2s = [params['wgt_shim_mm2s'][x // 2][itr] for x in range(num_cols)]
            for col in range(0, num_cols, 2):
                wgt_params = {'ShimWgtRepeat':    copy.deepcopy(wgt_repeat),
                              'ShimWgtBufferIdx': copy.deepcopy(params['ShimWgtBufferIdx']),
                              'wgt_shim_memory':  copy.deepcopy(params['wgt_shim_memory']),
                              'wgt_shim_mm2s':    copy.deepcopy(wgt_mm2s),
                              'dims':             copy.deepcopy(params['dims'])}
                parts.append(self.shim_wgt_pattern_pack(wgt_params, col))

        parts.append("    #SHIM OFM TRANSFERS\n")
        parts.extend(self.shim_ofm_pattern(params, col)
                     for col in range(params['dims'].aie_cols))
        return "".join(parts)