import yaml
from typing import List, Dict
import os
import struct
import math
import logging 
import pdb
import re
# Directories resolved relative to this file's location. They are appended
# to sys.path below so that the project-local modules imported afterwards
# (scheduler_utils, kernel_func_list) can be found. Which directory provides
# which module is not visible here.
# collatoral_path is also the directory the *_kernel_metadata.yaml files are
# read from; its spelling is kept because other code uses this exact name.
infra_path = (os.path.dirname(os.path.abspath(__file__))+"/../infra/")
collatoral_path = (os.path.dirname(os.path.abspath(__file__))+"/../../../Collaterals/")
L1_path = (os.path.dirname(os.path.abspath(__file__))+"/../../L1_fusion/")
import sys
sys.path.append(infra_path)
sys.path.append(collatoral_path)
sys.path.append(L1_path)
import scheduler_utils as utils
import struct

from kernel_func_list import kernel_func_list

def return_kernel_path(op_name, op_ver):
	"""Return the kernel index mapping and include list for one op variant.

	Reads ``<yaml_name>_kernel_metadata.yaml`` from ``collatoral_path`` and
	looks up ``data[op_ver]['kernel_path']``.

	Args:
		op_name: Operator name. "Add"/"Mul" select the "Elemwise" YAML when
			``op_ver`` contains "EleWise", otherwise "Broadcast"; "PWLA"
			selects "LUT"; any other name is used as the YAML name directly.
		op_ver: Top-level key of the YAML file selecting the kernel variant.

	Returns:
		(kernel_dict, kernel_include): ``kernel_dict`` maps every entry of
		``kernel_list`` to its position in ``kernel_func_list``;
		``kernel_include`` is returned exactly as stored in the YAML.

	Raises:
		KeyError: if ``op_ver`` or one of the ``kernel_path`` keys is missing.
		ValueError: via ``.index`` when a listed kernel is absent from
			``kernel_func_list`` (assuming it is a list — TODO confirm).
	"""
	if op_name in ("Add", "Mul"):
		yaml_name = "Elemwise" if "EleWise" in op_ver else "Broadcast"
	elif op_name == "PWLA":
		yaml_name = "LUT"
	else:
		yaml_name = op_name

	metadata_file = collatoral_path + f"{yaml_name}_kernel_metadata.yaml"
	with open(metadata_file, 'r') as handle:
		kernel_path_info = yaml.safe_load(handle)[op_ver]['kernel_path']
	kernel_list = kernel_path_info['kernel_list']
	kernel_include = kernel_path_info['kernel_include']

	kernel_dict = {name: kernel_func_list.index(name) for name in kernel_list}
	return kernel_dict, kernel_include

def load_kernel_metadata(op_name, op_ver):
	"""Load one op variant's metadata, raw and with references resolved.

	The YAML file is chosen exactly as in ``return_kernel_path``: "PWLA"
	uses "LUT"; "Add"/"Mul" use "Elemwise" when ``op_ver`` contains
	"EleWise", else "Broadcast"; other names are used as-is.

	Args:
		op_name: Operator name used to pick the metadata YAML file.
		op_ver: Top-level key of that YAML file.

	Returns:
		(raw, resolved): ``raw`` is ``data[op_ver]`` as loaded; ``resolved``
		is ``substitute_val(data[op_ver], data)``, i.e. the same section with
		string references looked up against the whole document.
	"""
	if op_name == "PWLA":
		yaml_name = "LUT"
	elif op_name not in ("Add", "Mul"):
		yaml_name = op_name
	elif "EleWise" in op_ver:
		yaml_name = "Elemwise"
	else:
		yaml_name = "Broadcast"

	metadata_file = f"{collatoral_path}{yaml_name}_kernel_metadata.yaml"
	with open(metadata_file, 'r') as handle:
		full_metadata = yaml.safe_load(handle)
	variant = full_metadata[op_ver]
	return variant, substitute_val(variant, full_metadata)

def find_dict_by_name(name, context):
	"""Depth-first search for the value stored under key ``name``.

	At each dict, a direct key hit is returned before its values are
	searched; lists are searched element by element in order. Any other
	type is a leaf and never matches.

	Args:
		name: Key to look for.
		context: Arbitrarily nested dicts/lists (e.g. a loaded YAML document).

	Returns:
		The first value found under ``name``, or None when no dict in
		``context`` has that key. (A key whose value is None is
		indistinguishable from "not found" when nested.)
	"""
	if isinstance(context, dict):
		if name in context:
			return context[name]
		children = context.values()
	elif isinstance(context, list):
		children = context
	else:
		return None

	for child in children:
		found = find_dict_by_name(name, child)
		# Compare with None, not truthiness: a falsy hit (0, "", {}, []) must
		# be returned just like a direct top-level hit, not skipped.
		if found is not None:
			return found
	return None

def get_nested_val(d, keys):
	"""Follow ``keys`` through nested dicts/lists starting at ``d``.

	A dict step uses ``.get`` (a missing key yields None; any further step
	then yields None too). A list step needs a digit-string key whose
	integer value is a valid index. Any other step ends the walk with None.
	With no keys, ``d`` itself is returned.
	"""
	current = d
	for key in keys:
		if isinstance(current, dict):
			current = current.get(key)
			continue
		if not (isinstance(current, list) and key.isdigit()):
			return None
		position = int(key)
		if position >= len(current):
			return None
		current = current[position]
	return current

# Matches references such as  name["key"]  or  name["key"]["subkey"]:
# an identifier followed by one or two bracketed, double-quoted keys.
_REF_PATTERN = re.compile(r'(\w+)(\[".*?"\])(\[".*?"\])?')

def substitute_val(d, context):
	"""Recursively resolve string references in ``d`` against ``context``.

	Dicts and lists are rebuilt with every item resolved, so the result never
	aliases containers of ``d``. For a string, each reference match is tried
	in order: ``name`` is located anywhere in ``context`` with
	``find_dict_by_name`` and the key(s) are followed with
	``get_nested_val``. The first successful lookup REPLACES THE WHOLE
	STRING with the found value; it is not spliced into surrounding text.
	Strings with no resolvable reference, and all other types, are returned
	unchanged.

	Args:
		d: Value to resolve (typically one YAML section).
		context: Structure searched for referenced names (the whole document).

	Returns:
		The resolved value.
	"""
	if isinstance(d, dict):
		return {k: substitute_val(v, context) for k, v in d.items()}
	if isinstance(d, list):
		# Lists were previously returned as-is: references inside them were
		# never resolved and the list was shared with the raw data.
		return [substitute_val(item, context) for item in d]
	if isinstance(d, str):
		for dict_name, first_key, second_key in _REF_PATTERN.findall(d):
			keys = [first_key.strip('[]').replace('"', '')]
			if second_key:
				keys.append(second_key.strip('[]').replace('"', ''))
			sub_dict = find_dict_by_name(dict_name, context)
			if sub_dict:
				nested_val = get_nested_val(sub_dict, keys)
				if nested_val is not None:
					logging.info(f"Substituting {d} with {nested_val}")
					return nested_val
	return d

def eval_eqn(equations, context):
	"""Evaluate kernel-parameter equations in order.

	Each equation may refer to names in ``context``, to ``math``, to
	builtins, and to keys of equations evaluated earlier. Every successful
	result is also written into ``context`` (mutating the caller's dict),
	so later equations can use it.

	Non-string entries (e.g. numbers already produced by reference
	substitution) are taken as their own value instead of being passed to
	``eval``, which would reject them.

	Failures are not raised. The exception message is stored as that key's
	result and a warning is logged.

	NOTE(review): ``eval`` runs with full builtins; only use this on trusted
	metadata files.

	Args:
		equations: Mapping of result name -> expression string (or value).
		context: Mapping of names available to the expressions; updated.

	Returns:
		dict: result name -> evaluated value, or error message string.
	"""
	# Pass the names as globals, not as a separate locals mapping. Names in a
	# separate locals mapping are invisible inside nested scopes (generator
	# expressions, lambdas), which made e.g. "sum(x for _ in range(n))" fail.
	namespace = {"math": math, **context}
	evaluated_equations = {}
	for key, eqn in equations.items():
		if isinstance(eqn, (int, float)):
			output = eqn
		else:
			try:
				output = eval(eqn, namespace)
			except Exception as e:
				logging.warning(f"Failed to evaluate equation {key}: {eqn!r} ({e})")
				evaluated_equations[key] = str(e)
				continue
		namespace[key] = output
		context[key] = output  # In case it is used in later expressions
		evaluated_equations[key] = output
	return evaluated_equations

def prepare_kernel_blob(_val, struct_fmt):
	"""Pack evaluated kernel parameters into a little-endian byte blob.

	Fields are laid out in the order of ``struct_fmt['name']`` /
	``struct_fmt['type']``. Each field's offset is rounded up to a multiple
	of its own size. No padding is added after the last field.

	Args:
		_val: Mapping field name -> integer value. Must contain every name in
			``struct_fmt['name']``. Extra keys (e.g. intermediate equation
			results) are ignored.
		struct_fmt: Dict with parallel lists ``'type'`` (C type names, see
			``dtype_to_fmt``) and ``'name'``.

	Returns:
		bytes: the packed blob.

	Raises:
		KeyError: unknown type name, or a field name missing from ``_val``.
		struct.error: a value is not an integer or does not fit its field.
	"""
	# The former unused pre-conversion of every value with
	# to_bytes(2, signed=False) crashed on legitimate negative or >16-bit
	# values (int16/int32/uint32 fields) and on non-field entries. It is removed.
	utils.sanity_check(len(struct_fmt['type']) == len(struct_fmt['name']), f"Invalid struct fmt. len(type): {len(struct_fmt['type'])} and len(name): {len(struct_fmt['name'])}")
	# Define the mapping from data types to struct format characters
	dtype_to_fmt = {
		'uint8_t': 'B',
		'int8_t': 'b',
		'uint16_t': 'H',
		'int16_t': 'h',
		'uint32_t': 'I',
		'int32_t': 'i',
		'int': 'i',  # Assuming 'int' is equivalent to 'int32_t'
		'bool': 'i'  # Assuming 'bool' is stored as an int
	}
	# '<' = little-endian, standard sizes, no implicit struct alignment.
	field_fmts = ['<' + dtype_to_fmt[dtype] for dtype in struct_fmt['type']]

	# Compute each field's offset, aligned to its own size.
	# NOTE(review): a C compiler would also pad the struct's total size to
	# its largest member alignment; that tail padding is not added here.
	# Confirm the kernel does not expect sizeof(struct) bytes.
	offsets = []
	total_size = 0
	for fmt in field_fmts:
		size = struct.calcsize(fmt)
		misalignment = total_size % size
		if misalignment:
			total_size += size - misalignment
		offsets.append(total_size)
		total_size += size

	blob = bytearray(total_size)
	for offset, name, fmt in zip(offsets, struct_fmt['name'], field_fmts):
		struct.pack_into(fmt, blob, offset, _val[name])
	return bytes(blob)

# Ordered argument names per supported op code. Position i of the caller's
# input_list is bound to name i.
_INPUT_ARG_NAMES = {
	"MatMul": ('M_SUBV', 'K_SUBV', 'N_SUBV', 'zero_init', 'qdq_param_addr',
		'gemm_wgt_size', 'ifm_sum', 'tdm1', 'tdm2', 'isMatA'),
	"Conv": ('M_SUBV', 'K_SUBV', 'N_SUBV', 'Sx', 'Sy'),
	"PWLA": ('core_act_ping_addr', 'lutab_addr', 'lutcd_addr', 'subv_rows',
		'subv_cols', 'scratch_addr', 'fused_ops'),
	"GPN": ('Nsubv', 'LRN_EN', 'Nparam'),
	"LRN": ('Nsubv', 'LRN_EN', 'aie_rows', 'aie_cols'),
}

def gen_input_args(op_code, input_list):
	"""Bind a positional input list to the named kernel arguments of an op.

	Args:
		op_code: One of the keys of ``_INPUT_ARG_NAMES``.
		input_list: Positional argument values; its length must match the
			number of names for ``op_code``.

	Returns:
		dict: argument name -> value, in the table's order. For an unknown
		op code, ``utils.sanity_check(False, ...)`` is called and None is
		returned if that call does not raise (its behavior is not visible
		here).
	"""
	arg_names = _INPUT_ARG_NAMES.get(op_code)
	if arg_names is None:
		utils.sanity_check(False,f"Invalid opcode: {op_code}")
		return None
	expected = len(arg_names)
	# The message is derived from the table. Before, Conv/PWLA/LRN wrongly
	# reported "expect 3 input args".
	utils.sanity_check(len(input_list) == expected, f"Invalid input list opcode: {op_code}, expect {expected} input args")
	return {name: input_list[i] for i, name in enumerate(arg_names)}

def gen_blob(input_list, op_code, op_ver):
	"""Build the packed kernel-parameter blob for one operator instance.

	Steps:
	  1. Map ``input_list`` to named arguments (``gen_input_args``).
	  2. Load the op variant's metadata with references resolved
	     (``load_kernel_metadata``).
	  3. Merge ``kernel_param['args']`` (if present, converted by
	     ``utils.convert_values_to_int``) with the named inputs. Inputs win
	     on key clashes.
	  4. Evaluate ``kernel_param['equations']`` (``eval_eqn``) and pack the
	     results per ``kernel_param['struct_fmt']`` (``prepare_kernel_blob``).

	Args:
		input_list: Positional kernel inputs for ``op_code``.
		op_code: Operator name, e.g. "MatMul", "Conv", "PWLA", "GPN", "LRN".
		op_ver: Variant key inside the op's metadata YAML.

	Returns:
		bytes: the packed parameter blob.
	"""
	logging.info(f"==============Input Expressions for {op_ver}==================")
	input_args = gen_input_args(op_code, input_list)
	_, meta_data = load_kernel_metadata(op_code, op_ver)
	kernel_param = meta_data['kernel_param']
	args_int = {}
	if 'args' in kernel_param:
		logging.info(f"meta_data['kernel_param']['args']: {kernel_param['args']}")
		args_int = utils.convert_values_to_int(kernel_param['args'])
	input_vars = {**args_int, **input_args}
	# Log the expressions before evaluating them, so they appear even if
	# evaluation or packing fails.
	logging.info(f"meta_data['kernel_param']['equations']: {kernel_param['equations']}")
	val = eval_eqn(kernel_param['equations'], input_vars)
	packed_data = prepare_kernel_blob(val, kernel_param['struct_fmt'])
	logging.info(f"==========Kernel blob for {op_ver}=============")
	logging.info(f"Packed Data: {packed_data}")

	print(f"OP: {op_ver}, Kernel data: {val}, param (in bytes): {packed_data}")

	return packed_data

# Debug entry point, disabled by wrapping it in a bare string literal.
# Remove the surrounding triple quotes to run gen_blob on a sample GPN
# configuration; the trailing comments show alternative LRN values.
'''
if __name__ == "__main__":
	input_list = [32, 0, 64] #[32, 1, 4, 4]
	op_code =  'GPN' #'LRN'
	op_ver = 'GPN_bf16' #'LRN_bf16' 
	gen_blob(input_list, op_code, op_ver)
'''
