import sys
import os
import argparse
import pdb
import json
import numpy as np
import yaml
# Directory two levels above the one that contains this file; get_split()
# resolves "Collaterals/overlays.yaml" relative to it.
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Op types this generator handles. main() filters IR nodes by substring match
# against these names, Gen_Tiling_Output.get_split() by exact membership.
# Each name is listed once (the original list repeated
# 'MatMul_qdq_uint16xint4xuint16').
supported_op = [
	'MatMul_qdq_slice_silu_uint16xint4xuint16',
	'MatMul_qdq_slice_uint16xint4xuint16',
	'MatMul_qdq_uint16xint4xuint16',
	'MatMul_qdq_uint16xint8xuint16',
	'MatMul_qdq_bias_uint16xint4xuint16',
	'MatMul_qdq_silu_uint16xint4xuint16',
	'MatMul_qdq_uint16xuint8xuint16',
]


class Gen_Tiling_Output():
	"""Assemble the tiling-parameter dictionary for one MatMul layer.

	An instance is built from a ``(shape, sub-volume)`` pair plus a parameter
	dict. ``gen_code`` returns a dict with the sections ``core_tile_params``,
	``mem_tile_params``, ``shim_tile_params``, ``dram_params``, ``scheduling``,
	``layer_padding``, ``kernel_info``, ``overlay_info``, ``layer_info`` and
	``testbench_args``.
	"""

	def __init__(self, _data, _params):
		"""Store the layer geometry and look up the overlay split.

		Args:
			_data: ``((M, K, N), (M_subV, K_subV, N_subV))``.
			_params: dict with keys ``row``, ``col``, ``split``, ``op_type``,
				``layer_obj`` and ``dtype``. ``dtype`` holds the element sizes
				for ifm, wgt and ofm in that order (presumably bytes per
				element, may be fractional -- confirm against the IR).
				``layer_obj`` must contain ``coeff_shape``.

		Raises:
			ValueError: if ``op_type`` is not in ``supported_op``.
		"""
		self.M_gran, self.K_gran, self.N_gran = (8, 8, 8)
		self.M, self.K, self.N = _data[0]
		self.M_subV, self.K_subV, self.N_subV = _data[1]
		# Each dimension is padded up to its OWN sub-volume. The original code
		# compared N against K_subV, which was a copy-paste slip.
		self.M_padded, self.K_padded, self.N_padded = self._pad_to_subvol(_data[0], _data[1])
		self.ifm_mode,  self.wgt_mode = "pin", "stream"
		self.ifm_ping_pong,  self.wgt_ping_pong,  self.ofm_ping_pong = True, True, False
		self.rows, self.cols = _params['row'], _params['col']
		self.split = _params['split']
		self.op_type = _params['op_type']
		self.layer_obj = _params['layer_obj']
		self.in_dtype, self.wgt_dtype, self.out_dtype = _params['dtype'][0],  _params['dtype'][1],  _params['dtype'][2]
		self.vec_coeff = (self.layer_obj["coeff_shape"][0] // self.N)
		self.bias_size = (self.vec_coeff * self.N_subV * ((self.K * self.N) // (self.K_subV * self.N_subV)))
		self.split_dict = self.get_split(self.split)

	@staticmethod
	def _pad_to_subvol(dims, subvols):
		"""Return ``dims`` with every entry raised to at least its sub-volume."""
		return tuple(max(dim, subv) for dim, subv in zip(dims, subvols))

	@staticmethod
	def _byte_size(dims, elem_bytes):
		"""Return ``prod(dims) * elem_bytes`` truncated to a Python int."""
		# np.prod replaces np.product, which was removed in NumPy 2.0.
		return int(np.prod(dims) * elem_bytes)

	def _bias_bytes(self):
		"""Return the per-output-column byte count added to weight buffers."""
		if 'qdq' in self.op_type:
			return 8 * (self.layer_obj['coeff_shape'][0] // self.layer_obj['in_wgt_shape'][1])
		return self.layer_obj['wgt1_bytes']

	def get_split(self, split):
		"""Return the MatMul overlay entry for ``split`` from overlays.yaml.

		The entry is read from ``Collaterals/overlays.yaml`` under
		``parent_dir``, keyed by ``"<cols>x<rows>"``, then ``"MatMul"``, then
		``split``.

		Raises:
			ValueError: if ``self.op_type`` is not in ``supported_op``.
		"""
		# Raise instead of `assert False`: asserts vanish under `python -O`,
		# which would make this method silently return None.
		if self.op_type not in supported_op:
			raise ValueError(f'Unsupport OPs: {self.op_type}')

		with open(os.path.join(parent_dir, "Collaterals/overlays.yaml")) as f:
			overlays_dict = yaml.safe_load(f)
		return overlays_dict[f'{self.cols}x{self.rows}']['MatMul'][split]

	def gen_core_tile_params(self):
		"""Return the core-tile sub-volumes and (all-ones) iteration counts."""
		core_tile_obj = {
			"subvols": {
				"ifm": [self.M_subV, self.K_subV],
				"wgt": [self.K_subV, self.N_subV],
				"ofm": [self.M_subV, self.N_subV],
			},
			"iters": {
				"ifm": [1, 1],
				"wgt": [1, 1],
				"ofm": [1, 1],
			},
		}
		return core_tile_obj

	def gen_mem_tile_params(self):
		"""Return mem-tile sub-volumes, iteration counts and buffer sizes."""
		bias_bytes = self._bias_bytes()
		# For the "M1K1N32" split the N sub-volume is scaled by the row count;
		# for any other split the ofm M sub-volume is scaled instead.
		is_n32 = self.split == "M1K1N32"
		n_subv = self.N_subV * self.rows if is_n32 else self.N_subV
		ofm_m_subv = self.M_subV if is_n32 else self.M_subV * self.rows
		#TODO: Check on wgt size calculation for M4N8 split
		mem_subvols = {
			"ifm": [self.M_subV, self.K],
			"wgt": [self.K_subV, n_subv],
			"ofm": [ofm_m_subv, n_subv],
		}
		mem_splits = self.split_dict['mem_splits']

		mem_tile_obj = {
			"subvols": mem_subvols,
			"iters": {
				"ifm": [1, 1],
				"wgt": [
					self.K // self.K_subV,
					((self.N // mem_splits['wgt'][1]) // n_subv),
				],
				"ofm": [
					1,
					((self.N // mem_splits['ofm'][1]) // n_subv),
				],
			},
			"sizes": {
				"ifm": self._byte_size(mem_subvols["ifm"], self.in_dtype),
				"wgt": int(self._byte_size(mem_subvols["wgt"], self.wgt_dtype) + self.N_subV * self.rows * bias_bytes),
				"ofm": self._byte_size(mem_subvols["ofm"], self.out_dtype),
			},
		}
		return mem_tile_obj

	def gen_shim_tile_params(self):
		"""Return shim-tile sub-volumes and buffer sizes.

		Each dimension is divided by the matching ``mem_splits`` factor and
		clamped from below by the 8-element granularity set in ``__init__``.
		"""
		mem_splits = self.split_dict['mem_splits']
		shim_subvols = {
			"ifm": [
				max(self.M_padded // mem_splits['ifm'][0], self.M_gran),
				max(self.K // mem_splits['ifm'][1], self.K_gran),
			],
			"wgt": [
				max(self.K // mem_splits['wgt'][0], self.K_gran),
				max(self.N // mem_splits['wgt'][1], self.N_gran),
			],
			"ofm": [
				max(self.M_padded // mem_splits['ofm'][0], self.M_gran),
				max(self.N // mem_splits['ofm'][1], self.N_gran),
			],
		}
		shim_tile_obj = {
			"subvols": shim_subvols,
			"sizes": {
				"ifm": self._byte_size(shim_subvols["ifm"], self.in_dtype),
				"wgt": int(self._byte_size(shim_subvols["wgt"], self.wgt_dtype) + self.bias_size),
				"ofm": self._byte_size(shim_subvols["ofm"], self.out_dtype),
			},
		}
		return shim_tile_obj

	def gen_dram_params(self):
		"""Return the DRAM tensor shapes and their sizes."""
		# Number of (K_subV x N_subV) weight blocks; the size of each block is
		# its packed weights plus N_subV * 8 * vec_coeff extra bytes.
		wgt_blocks = (self.K // self.K_subV) * self.N // self.N_subV
		wgt_block_bytes = (self.K_subV * self.N_subV) * self.wgt_dtype + self.N_subV * 8 * self.vec_coeff
		dram_params_obj = {
			"shapes": {
				"ifm": [self.M_padded, self.K],
				"wgt": [self.K, self.N],
				"ofm": [self.M_padded, self.N],
			},
			"sizes": {
				"ifm": self.in_dtype * self.M_padded * self.K,
				"wgt": int(wgt_blocks * wgt_block_bytes),
				"ofm": self.out_dtype * self.M_padded * self.N,
			},
		}
		return dram_params_obj

	def gen_kernel_obj(self):
		"""Return the per-bank placement constraints for the kernel buffers."""
		has_activation = "gelu" in self.op_type or "silu" in self.op_type
		qdq_bytes = (2*64+5120) if has_activation else 64
		qdq_bank = "BANK0" if has_activation else "BANK3"
		bias_bytes = self._bias_bytes()
		tdm_bytes = 4
		# NOTE: "int8" is a substring of "uint8", so both 8-bit weight
		# variants take the first branch.
		wgt_bank = ["BANK0", "BANK1"] if "int8" in self.op_type else ["BANK1.0", "BANK1.1"]
		is_8bit_wgt = "xuint8x" in self.op_type or "xint8x" in self.op_type
		wgt_unpack_size = 0 if is_8bit_wgt else self.K_subV * self.N_subV
		#TODO: Check how to find tdm_bytes and bias_bytes from unique node
		ifm_bytes = self.M_subV * self.K_subV * self.in_dtype
		wgt_bytes = int(self.K_subV * self.N_subV * self.wgt_dtype) + self.N_subV * bias_bytes
		tdm_half = int(0.5 * self.M_subV * self.N_subV * tdm_bytes)
		kernel_obj = {
			"placement_constraints": {
				"ifm": {
					"BANK2": ifm_bytes,
					"BANK3": ifm_bytes,
				},
				"wgt": {
					f"{wgt_bank[0]}": wgt_bytes,
					f"{wgt_bank[1]}": wgt_bytes,
				},
				"ofm": {
					"BANK2": self.M_subV * self.N_subV * self.out_dtype,
				},
				"qdq": {
					f"{qdq_bank}": qdq_bytes,
				},
				"WgtUnpack": {
					"BANK0": wgt_unpack_size,
				},
				"tdm": {
					"BANK0": tdm_half,
					"BANK1": tdm_half,
				},
				"ifm_sum": {
					"BANK0": self.M_subV * tdm_bytes,
				},
				"stack": {
					"BANK3": 5120,
				},
			},
		}
		return kernel_obj

	def gen_overlay_info(self):
		"""Return the overlay name (``"<cols>x<rows>"``), split mode and shape."""
		overlay_obj = {
			"overlay": f"{self.cols}x{self.rows}",
			"mode": self.split,
			"shape": {
				"row": self.rows,
				"col": self.cols,
			},
		}
		return overlay_obj

	def gen_scheduling(self):
		"""Return the fixed scheduling section."""
		# NOTE(review): these values are hard-coded and do not read the
		# ifm_mode / wgt_mode / *_ping_pong attributes set in __init__ (where
		# ifm_ping_pong is True) -- confirm which one is intended.
		sch_obj = {
			"ifm": "pin",
			"wgt": "stream",
			"ifm_ping_pong": False,
			"wgt_ping_pong": True,
			"ofm_ping_pong": False,
		}
		return sch_obj

	def gen_layer_padding(self):
		"""Return ``[before, after]`` padding per dimension for ifm, wgt and ofm."""
		m_pad = self.M_padded - self.M
		k_pad = self.K_padded - self.K
		n_pad = self.N_padded - self.N
		padding_obj = {
			"ifm": [[0, m_pad], [0, k_pad]],
			"wgt": [[0, k_pad], [0, n_pad]],
			"ofm": [[0, m_pad], [0, n_pad]],
		}
		return padding_obj

	def gen_layer_info(self):
		"""Return ``layer_obj`` extended with the normalized layer fields.

		NOTE: the returned dict is ``self.layer_obj`` itself, so the caller's
		dict is updated in place.
		"""
		layer_dict = self.layer_obj
		layer_dict['orig_op_type'] = "MatMul"
		layer_dict['in_ifm_shape'] = [self.M, self.K]
		layer_dict['out_ofm_shape'] = [self.M, self.N]

		# Re-expose the existing datatype and byte fields under tensor-specific
		# key names.
		layer_dict['in_ifm_datatype'] = self.layer_obj["in_datatype"]
		layer_dict['in_wgt_datatype'] = self.layer_obj["wgt_datatype"]
		layer_dict['in_wgt1_datatype'] = self.layer_obj["wgt1_datatype"]
		layer_dict['out_ofm_datatype'] = self.layer_obj["out_datatype"]

		layer_dict['in_ifm_bytes'] = self.layer_obj["in_bytes"]
		layer_dict['in_wgt_bytes'] = self.layer_obj["wgt_bytes"]
		layer_dict['in_wgt1_bytes'] = self.layer_obj["wgt1_bytes"]
		layer_dict['out_ofm_bytes'] = self.layer_obj["out_bytes"]
		return layer_dict

	def gen_testbench_args(self):
		"""Return the testbench host file and its compile flags."""
		is_w8 = "int8" in self.op_type
		tb_cpp = "test/gemm_int16x8_unit_test/main_gemm.cpp" if is_w8 else "test/gemm_int16x4_unit_test/main_gemm.cpp"
		dtype = "A16W8" if is_w8 else "A16W4"
		attributes = self.layer_obj['attributes']
		flag_dq = attributes['disable_q'][0] if 'disable_q' in attributes else 0
		# The weight type is the middle field of the trailing "AxBxC" token of
		# op_type; the flag is 0 when that field contains "uint", otherwise 1.
		wgt_type = self.op_type.split('_')[-1].split('x')[1]
		testbench_obj = {
			"HOST_NAME": f"{tb_cpp}",
			"COMPILE_FLAGS": {
				f"M_GEMM_{dtype}": self.M,
				f"K_GEMM_{dtype}": self.K,
				f"N_GEMM_{dtype}": self.N,
				f"M_GEMM_SUBV_{dtype}": self.M_subV,
				f"K_GEMM_SUBV_{dtype}": self.K_subV,
				f"N_GEMM_SUBV_{dtype}": self.N_subV,
				"GEMM_VEC_COEFFS": self.layer_obj["coeff_shape"][0] // self.N,
				"GEMM_GELU": 1 if "gelu" in self.op_type else 0,
				"GEMM_SILU": 1 if "silu" in self.op_type else 0,
				"GEMM_WGT_SIGN": 0 if 'uint' in wgt_type else 1,
				"QDQMODE": flag_dq,
			},
		}
		return testbench_obj

	def gen_code(self):
		"""Return the complete tiling dictionary built from all sections."""
		code = {
			"core_tile_params": self.gen_core_tile_params(),
			"mem_tile_params": self.gen_mem_tile_params(),
			"shim_tile_params": self.gen_shim_tile_params(),
			"dram_params": self.gen_dram_params(),
			"scheduling": self.gen_scheduling(),
			"layer_padding": self.gen_layer_padding(),
			"kernel_info": self.gen_kernel_obj(),
			"overlay_info": self.gen_overlay_info(),
			"layer_info": self.gen_layer_info(),
			"testbench_args": self.gen_testbench_args(),
		}
		return code

def print_code(instr, file_name="input.json"):
	"""Write ``instr`` to ``file_name`` as JSON indented by four spaces.

	Args:
		instr: JSON-serializable object to dump.
		file_name: destination path; an existing file is overwritten.
	"""
	with open(file_name, "w") as out_stream:
		json.dump(instr, out_stream, indent=4)

def find_subv_for_supported_shapes(dictionary, key):
	"""Look up ``key`` in ``dictionary``, falling back to a wildcard first entry.

	Args:
		dictionary: mapping from shape tuples to values.
		key: shape tuple to look up.

	Returns:
		``dictionary[key]`` when the exact key exists; otherwise the value
		stored under ``key`` with its first element replaced by ``None``;
		otherwise the string ``"Key not found"``.
	"""
	if key not in dictionary:
		# No exact match: retry with the first element replaced by None.
		wildcard_key = (None,) + key[1:]
		return dictionary.get(wildcard_key, "Key not found")
	return dictionary[key]

def main(args):
	"""Generate one tiling JSON file per supported MatMul node in the IR.

	For every node whose ``op_type`` contains a name from ``supported_op``,
	a sub-volume is chosen, a ``Gen_Tiling_Output`` is built, and its output
	is written to ``<ir_json dir>/<node name>/<node name>.json``.

	Args:
		args: dict with ``ir_json`` (path to the IR JSON file) and ``mode``
			(split name; ``"M4K1N8"`` selects an M sub-volume of 16, any
			other value 8).

	Raises:
		ValueError: if a node's shape has no sub-volume entry.
	"""
	with open(args["ir_json"]) as f:
		data = json.load(f)

	# Collect (node name, (M, K, N), params) for every supported node.
	# NOTE(review): this is a substring match, while
	# Gen_Tiling_Output.get_split requires an exact match -- confirm intent.
	layers = []
	for k, v in data.items():
		if not any(op in v['op_type'] for op in supported_op):
			continue
		# M and K are the last two activation dims, N the second weight dim.
		shape = (v['in_act_shape'][-2], v['in_act_shape'][-1], v['in_wgt_shape'][1])
		params = {
			"row": 4,
			"col": 8,
			"op_type": v['op_type'],
			"dtype": [v['in_bytes'], v['wgt_bytes'], v['out_bytes']],
			"split": args["mode"],
			"layer_obj": v,
		}
		layers.append((k, shape, params))

	M_subv = 16 if args["mode"] == "M4K1N8" else 8
	# (M, K, N) -> (M_subV, K_subV, N_subV); M is None because the lookup
	# helper falls back to a None first element. Duplicate keys that appeared
	# in the original literal (all with identical values) are listed once.
	supported_shapes = {
		(None, 3072, 8192): (M_subv, 256, 32),   #PSU0
		(None, 3072, 3072): (M_subv, 256, 48),   #PSU0
		(None, 3072, 9216): (M_subv, 256, 48),   #PSU0
		(None, 8192, 3072): (M_subv, 256, 48),   #PSU0
		(None, 1536, 1536): (M_subv, 192, 48),   #DS1 Node#6, Node#11
		(None, 1536, 256): (16, 256, 8),         #DS1 Node#7
		(None, 1536, 8960): (M_subv, 192, 56),   #DS1 Node#9 (Node#8: fused-op case below)
		(None, 8960, 1536): (M_subv, 160, 48),   #DS1 Node#10 (Node#12: int8 case below)
	}

	# Loop-invariant: outputs go next to the IR JSON file.
	output_dir = os.path.dirname(os.path.abspath(args["ir_json"]))

	for node_name, shape, params in layers:
		op_type = params["op_type"]
		k_n = (shape[1], shape[2])
		if "silu" in op_type and k_n == (1536, 8960):
			#NOTE - Hacky way to handle this: fused silu op (DS1 Node#8)
			subv = (M_subv, 192, 56)
		elif "int8" in op_type and k_n == (8960, 1536):
			#NOTE - Hacky way to handle this: int8 DS1 gemm (DS1 Node#12)
			subv = (M_subv, 280, 48)
		else:
			subv = find_subv_for_supported_shapes(supported_shapes, shape)
		# The lookup helper returns a string sentinel on a miss. The original
		# assert compared against a differently-cased string and never fired,
		# so check the type and raise explicitly.
		if not isinstance(subv, tuple):
			raise ValueError(f"Unsupported shape: {shape}")
		tiling_input = (shape, subv)

		print("Tiling input: ", tiling_input)
		test = Gen_Tiling_Output(tiling_input, params)
		code = test.gen_code()
		newpath = os.path.join(output_dir, node_name)
		os.makedirs(newpath, exist_ok=True)
		output_file_name = os.path.join(newpath, node_name + ".json")
		print_code(code, output_file_name)


if __name__ == "__main__":
	# Command-line entry point: parse the split mode and the IR JSON path,
	# then hand them to main() as a plain dict.
	parser = argparse.ArgumentParser(
		description="Generate compiler tiling params for MLOps. NOTE- IT's DUMMY",
		usage='use "%(prog)s --help" for more info',
		formatter_class=argparse.RawTextHelpFormatter,
	)

	# required knobs
	# The keyword is lowercase `choices`; the original `Choices=` made
	# add_argument raise TypeError before any argument was parsed.
	parser.add_argument('-m','--mode', choices=['M4K1N8', 'M1K1N32'], required=True, help="Tensor Split")
	parser.add_argument('-ir','--ir_json', type=str, required=True, help="Pass the supported shapes")
	args = parser.parse_args()

	main(vars(args))
