import sys
import os
infra_path = (os.path.dirname(os.path.abspath(__file__))+"/infra/")
sys.path.append(infra_path)
import logging
import scheduler_utils as utils
import template_gemm_4x4 as matmul_4x4
import template_gemm_8x4 as matmul_8x4
import template_conv as conv
import template_LUT_ops as LUT_ops
import template_add as add_op
import template_rope as rope_op
import template_layernorm as layernorm
import template_groupnorm as groupnorm
import template_lpnorm as lpnorm
import template_softmax as softmax
import pdb


def get_4x4_overlay_template_obj(template_meta_data, data):
	"""Build the template object for an op targeting the "4x4" overlay.

	Args:
		template_meta_data: Mapping read for the keys 'overlay', 'op_type',
			'mode' and, for PWLA only, 'ver'.
		data: Forwarded unchanged to the selected template constructor.

	Returns:
		The object produced by the selected template class. When no branch
		constructs a template (unsupported MatMul mode or unknown op_type),
		utils.sanity_check is called with a False condition and, if that call
		returns, the result is None.
	"""
	def check_mode(is_valid):
		# One place for the mode diagnostic shared by every branch below.
		utils.sanity_check(is_valid, f"Incorrect mode. Please check! mode: {template_meta_data['mode']}")

	utils.sanity_check(template_meta_data['overlay'] == "4x4", f"Incorrect overlay. Please check! overlay: {template_meta_data['overlay']}")

	op_type = template_meta_data['op_type']

	if op_type == "MatMul":
		mode = template_meta_data['mode']
		# (mode string, class name inside template_gemm_4x4)
		matmul_modes = (
			("M4K1N4", "M4N4"),
			("M1K1N16", "M1N16"),
			("M16K1N1", "M16N1"),
		)
		for mode_name, class_name in matmul_modes:
			if mode == mode_name:
				return getattr(matmul_4x4, class_name)(data)
		check_mode(False)
		return None

	if op_type == "Conv":
		check_mode(template_meta_data['mode'] == "H4CO4")
		return conv.M4N4(data)

	if op_type == "PWLA":
		check_mode(template_meta_data['mode'] == "N16")
		return LUT_ops.M4N4(op_type, template_meta_data['ver'], data)

	if op_type in ["Add", "Mul"]:
		check_mode(template_meta_data['mode'] in ["N16", "M16N1"])
		return add_op.M1N16(data)

	# These ops share the same shape: mode must be "N16" and the template
	# class is the module's M4N4.
	n16_ops = (
		("LayerNormalization", layernorm),
		("GroupNormalization", groupnorm),
		("LpNormalization", lpnorm),
		("Softmax", softmax),
	)
	for op_name, op_module in n16_ops:
		if op_type == op_name:
			check_mode(template_meta_data['mode'] == "N16")
			return op_module.M4N4(data)

	utils.sanity_check(False,f"Incorrect template op_code passed. op_code: {template_meta_data['op_type']}")


def get_8x4_overlay_template_obj(template_meta_data, data):
	"""Build the template object for an op targeting the "8x4" overlay.

	Args:
		template_meta_data: Mapping read for the keys 'overlay', 'op_type',
			'mode' and, for PWLA only, 'ver'.
		data: Forwarded unchanged to the selected template constructor.

	Returns:
		The object produced by the selected template class. When no branch
		constructs a template (unsupported MatMul mode or unknown op_type),
		utils.sanity_check is called with a False condition and, if that call
		returns, the result is None.
	"""
	def check_mode(is_valid):
		# One place for the mode diagnostic shared by every branch below.
		utils.sanity_check(is_valid, f"Incorrect mode. Please check! mode: {template_meta_data['mode']}")

	utils.sanity_check(template_meta_data['overlay'] == "8x4", f"Incorrect overlay. Please check! overlay: {template_meta_data['overlay']}")

	op_type = template_meta_data['op_type']

	if op_type == "PWLA":
		check_mode(template_meta_data['mode'] == "N32")
		return LUT_ops.M4N4(op_type, template_meta_data['ver'], data)

	# These ops share the same shape: mode must be "N32" and the template
	# class is the module's M4N4 (the same class name the 4x4 overlay uses).
	n32_ops = (
		("LayerNormalization", layernorm),
		("GroupNormalization", groupnorm),
		("LpNormalization", lpnorm),
		("Softmax", softmax),
	)
	for op_name, op_module in n32_ops:
		if op_type == op_name:
			check_mode(template_meta_data['mode'] == "N32")
			return op_module.M4N4(data)

	if op_type == "MatMul":
		mode = template_meta_data['mode']
		# (mode string, class name inside template_gemm_8x4)
		matmul_modes = (
			("M4K1N8", "M4N8"),
			("M8K1N4", "M8N4"),
			("M1K1N32", "M1N32"),
			("M32K1N1", "M32N1"),
			("B4M8K1N1", "B4M8N1"),
			("B32M1K1N1", "B32M1N1"),
		)
		for mode_name, class_name in matmul_modes:
			if mode == mode_name:
				return getattr(matmul_8x4, class_name)(data)
		check_mode(False)
		return None

	if op_type in ["Add", "Mul"]:
		check_mode(template_meta_data['mode'] in ["N32","M32N1"])
		return add_op.M1N32(data)

	if op_type == "RoPE":
		check_mode(template_meta_data['mode'] in ["N32","M32N1"])
		return rope_op.M1N32(data)

	utils.sanity_check(False,f"Incorrect template op_code passed. op_code: {template_meta_data['op_type']}")


def get_template_obj(template_meta_data, data):
	"""Dispatch to the overlay-specific template factory.

	Args:
		template_meta_data: Mapping read for 'overlay' and, on the "8x4"
			path, 'op_type'; it is forwarded as-is to the overlay factory.
		data: Forwarded unchanged to the overlay factory.

	Returns:
		Whatever the selected overlay factory returns. For an overlay other
		than "4x4" or "8x4", utils.sanity_check is called with a False
		condition and, if that call returns, the result is None.
	"""
	overlay = template_meta_data['overlay']

	if overlay == "4x4":
		return get_4x4_overlay_template_obj(template_meta_data, data)

	if overlay == "8x4":
		# Ops checked up front, before handing over to the 8x4 factory.
		ops_available_on_8x4 = (
			'PWLA', 'MatMul', 'Add', 'Mul',
			'LayerNormalization', 'GroupNormalization', 'LpNormalization',
			'RoPE', 'Softmax',
		)
		utils.sanity_check(template_meta_data['op_type'] in ops_available_on_8x4, f"Op Code: {template_meta_data['op_type']} is not supported in 8x4 yet")
		return get_8x4_overlay_template_obj(template_meta_data, data)

	utils.sanity_check(False,f"Invalid overlay: {template_meta_data['overlay']}. Please check!")


def get_template_define_struct(template_meta_data, data):
	"""Return the dims object for a "4x4"-overlay MatMul or Conv template.

	Args:
		template_meta_data: Mapping read for the keys 'overlay', 'op_type'
			and 'mode'.
		data: Forwarded unchanged to the dims constructor.

	Returns:
		matmul_4x4.GemmDims(data) for MatMul, conv.ConvDims(data) for Conv.
		For any other op_type, utils.sanity_check is called with a False
		condition and, if that call returns, the result is None.
	"""
	utils.sanity_check(template_meta_data['overlay'] == "4x4", f"Incorrect overlay. Please check! overlay: {template_meta_data['overlay']}")

	op_type = template_meta_data['op_type']

	if op_type == "MatMul":
		utils.sanity_check(template_meta_data['mode'] == "M4K1N4", f"Incorrect mode. Please check! mode: {template_meta_data['mode']}")
		return matmul_4x4.GemmDims(data)

	if op_type == "Conv":
		# get_4x4_overlay_template_obj accepts only mode "H4CO4" for Conv, so
		# checking for "M4K1N4" alone rejected every Conv metadata that the
		# template factory accepts. "M4K1N4" is still allowed so existing
		# callers that relied on the old check keep working.
		# NOTE(review): confirm whether "M4K1N4" is ever a real Conv mode.
		utils.sanity_check(template_meta_data['mode'] in ["H4CO4", "M4K1N4"], f"Incorrect mode. Please check! mode: {template_meta_data['mode']}")
		return conv.ConvDims(data)

	utils.sanity_check(False,f"Incorrect call. OPs dim not needed for op_code: {template_meta_data['op_type']}")


