import sys
import os
import pdb
import textwrap
from os import path
# Put the "infra" directory that sits beside this file's directory on
# sys.path so the project-local imports below can be resolved.
infra_path = "{}/../infra/".format(os.path.dirname(os.path.abspath(__file__)))
sys.path.extend([infra_path])

import logging
import math
from enum import Enum

import scheduler_utils as utils

import const
#from template_base import BaseTemplate, BaseDims
from template_base import BaseTemplate
from template_base_norm import BaseNormTemplate
#from Scheduling_Engine.code_gen.template_base_norm__ import BaseNormTemplate
import gen_kernel_param

class M4N4(BaseNormTemplate):
	"""Code-generation template for the Group Norm operation.

	``gen_code`` assembles a Python script from four parts: the header from
	``gen_op_header``, the verbatim contents of the sibling ``group_norm.py``
	file, the norm-kernel includes produced by the base class, and a generated
	``generate_dataflow(code_backend)`` function from ``gen_dataflow_footer``.
	"""

	def __init__(self, _data=None):
		"""Create the template.

		Args:
			_data: Optional input object, stored on ``self.data``. No method
				defined in this class reads it.
		"""
		super().__init__()
		self.data = _data

	def gen_print_attr(self, op_params):
		"""Return a comment block listing the main operator parameters.

		Args:
			op_params: Dict providing 'Mlrn', 'Nlrn', 'Msubv', 'Nsubv',
				'ifm_byte_len', 'ofm_byte_len' and 'backend_type'.

		Returns:
			A multi-line string of ``#name = value`` comment lines.
		"""
		return f"""
		#Print OPs params
		#Mlrn        = {op_params['Mlrn']}
		#Nlrn        = {op_params['Nlrn']}
		#Msubv        = {op_params['Msubv']}
		#Nsubv        = {op_params['Nsubv']}
		#ifm_byte_len = {op_params['ifm_byte_len']}
		#ofm_byte_len = {op_params['ofm_byte_len']}
		#backend_type = {op_params['backend_type']} """

	def gen_code(self, op_params):
		"""Generate the complete Group Norm script.

		Args:
			op_params: Operator-parameter dict; see ``gen_op_header`` and
				``gen_dataflow_footer`` for the keys read by this class.

		Returns:
			The generated script text.
		"""
		logging.info("Generate code for Group Norm operation")
		data_flow = self.gen_op_header(op_params)  # overridden in this class
		data_flow += self.gen_dataflow(op_params)  # overridden in this class
		data_flow += self.gen_main_func()          # inherited from the base class
		return data_flow

	def gen_op_header(self, op_params):
		"""Return the script header: base-class headers plus math imports.

		Args:
			op_params: Dict providing 'Overlay', forwarded to ``gen_headers``.

		Returns:
			Header text ending with ``from math import floor, sqrt``.
		"""
		overlay = op_params['Overlay']
		code = self.gen_headers(overlay)  # inherited from the base class
		code += "\n"
		# Make floor/sqrt available to the code emitted after this header.
		code += "from math import floor, sqrt \n"
		return code

	def gen_dataflow(self, op_params):
		"""Return the dataflow body of the generated script.

		The body is the verbatim text of ``group_norm.py`` (located next to
		this file), followed by the base class's norm-kernel includes and the
		generated ``generate_dataflow`` function.

		Args:
			op_params: Operator-parameter dict, forwarded to
				``gen_norm_kernel_includes`` and ``gen_dataflow_footer``.

		Returns:
			The concatenated dataflow text.

		Raises:
			OSError: If ``group_norm.py`` cannot be opened.
		"""
		_adf = "BackEnd.Adf"
		curdir = path.dirname(__file__)
		fname = path.abspath(path.join(curdir, 'group_norm.py'))
		with open(fname, encoding="utf-8") as f:
			code = f.read()
		code += self.gen_norm_kernel_includes(op_params)  # inherited from the base class
		# Pass the backend name as a plain string (not wrapped in a set literal).
		# TODO: backend selection is not wired up; the footer ignores it.
		code += self.gen_dataflow_footer(op_params, _adf)
		return code

	def gen_dataflow_footer(self, op_params, _back_end):
		"""Return the source of a ``generate_dataflow(code_backend)`` function.

		The emitted function calls ``group_norm``, ``OverlayShape``,
		``get_core_instrs``, ``AieTile``/``TileType``, ``get_memtile_transfers``,
		``get_shim_transfers``, ``run_layer_compilation``, ``CascDir`` and
		``get_kernel_includes``. NOTE(review): these are presumably defined by
		the included ``group_norm.py`` / generated headers; not visible here.

		Within the emitted function, core instructions are produced for rows
		2..AieRows+1 but stored in ``instr_dict`` under rows 0..AieRows-1
		(the +2 offset presumably skips non-core rows -- confirm).

		Args:
			op_params: Dict providing 'Overlay', 'AieInst', 'AieCols',
				'AieRows', 'Mlrn', 'Nlrn', 'TdimLayer', 'CoreMsubv',
				'CoreNsubv', 'CoreTsubv', 'CoreMsubvNorm', 'CoreNsubvNorm',
				'CoreTsubvNorm', 'MemMsubv', 'MemNsubv', 'IfmBytesPerElem'
				and 'OfmBytesPerElem'.
			_back_end: Backend name; currently unused.

		Returns:
			The generated function source as a string. Its ``AieCols`` is
			``op_params['AieCols'] // op_params['AieInst']`` (integer division,
			so a zero 'AieInst' raises ZeroDivisionError here).
		"""
		overlay = op_params['Overlay']
		aie_inst = op_params['AieInst']
		col = op_params['AieCols']
		row = op_params['AieRows']
		overlay_function = "overlay_{}_dma_connections()".format(overlay)
		core_stream_connections = "overlay_{}_core_stream_bdcast()".format(overlay)
		return f"""
def generate_dataflow(code_backend):
            AieInst			= {aie_inst}
            AieRows         = {row}
            AieCols         = {col//aie_inst}
            Mlrn 			= {op_params['Mlrn']}
            Nlrn 			= {op_params['Nlrn']}
            TdimLayer 		= {op_params['TdimLayer']}
            CoreMsubv 		= {op_params['CoreMsubv']}
            CoreNsubv 		= {op_params['CoreNsubv']}
            CoreTsubv 		= {op_params['CoreTsubv']}
            CoreMsubvNorm 	= {op_params['CoreMsubvNorm']}
            CoreNsubvNorm 	= {op_params['CoreNsubvNorm']}
            CoreTsubvNorm 	= {op_params['CoreTsubvNorm']}
            MemMsubv 		= {op_params['MemMsubv']}
            MemNsubv 		= {op_params['MemNsubv']}
            InOutBytesPerElem = {[op_params['IfmBytesPerElem'], op_params['OfmBytesPerElem']]}
            kernel_names, kernel_includes = get_kernel_includes() 
            params = group_norm(AieInst,
								AieRows,
								AieCols, 
                                                                kernel_names,
		                        Mlrn, 
		                        Nlrn,
		                        TdimLayer,
                                        InOutBytesPerElem, 
								CoreMsubv,
								CoreNsubv,
								CoreTsubv,
								CoreMsubvNorm,
								CoreNsubvNorm,
								CoreTsubvNorm,
								MemMsubv,
								MemNsubv)
            overlay_shape = OverlayShape({col}, {row})
            core_instrs_array = []
            for col in range(params.AieCols*params.AieInst):
                for row in range(2,params.AieRows+2):
                    core_instrs_array.append(get_core_instrs(params, col, row))
            instr_dict = {{}}
            for col in range(AieCols*AieInst):
                for row in range(AieRows):
                    instr_dict[AieTile(TileType.Core, col, row)] = core_instrs_array[col*AieRows+row]
            memtile_transfers = get_memtile_transfers(params)  
            shim_transfers    = get_shim_transfers(params)
            dma_connections = {overlay_function}
            run_layer_compilation(overlay_shape,
		                          kernel_names,
		                          kernel_includes,
		                          instr_dict,
		                          memtile_transfers,
		                          shim_transfers,
		                          dma_connections,
		                          code_backend,
                                          param_channel_id=0,
                                          casc_dir=CascDir.Vertical,
                                          core_connections={core_stream_connections})
		"""
# Disabled debug harness: everything below is a bare module-level string
# literal, so none of it executes (not even when this file is run directly).
# NOTE(review): the sample `params` dict lacks keys that M4N4 reads
# ('Overlay', 'AieInst', 'TdimLayer', 'CoreMsubv', 'MemMsubv',
# 'IfmBytesPerElem', ...), so re-enabling it as-is would raise KeyError --
# extend the dict before turning this back into real code.
'''
#For debug purpose
if __name__ == "__main__":
	params = {
			'ParamSize': 200,
			'Mlrn': 0,
			'Nlrn': 0,
			'Msubv': 0,
			'Nsubv': 0,
			'ifm_byte_len': 0,
			'ofm_byte_len': 0,
			'AieCols': 0,
			'AieRows': 0,
			'backend_type': 0,
			}

	debug = M4N4()
	print(debug.gen_code(params))
	'''
