"""
some helper functions to extract more information from model IR generated by parse_onnx_model.py
"""

import sys
import yaml
import glob
from collections import defaultdict

import os

import onnxruntime as ort
import json
import yaml
import csv
import numpy as np
import ast

def sanity_check(condition, message, level="Error"):
    """Report ``message`` when ``condition`` is falsy; do nothing otherwise.

    Behaviour by ``level``:
      * "Error"   -> raise AssertionError with the message wrapped in red ANSI codes;
      * "Warning" -> print "Warning: <message>" in yellow;
      * "Message" -> print "Message: <message>";
      * anything else -> silently ignored.
    """
    if condition:
        return
    if level == "Error":
        raise AssertionError("\033[91m" + message + "\033[0m")
    if level == "Warning":
        print("\033[33mWarning: " + message + "\033[0m")
    elif level == "Message":
        print("Message: " + message)

def add_unique_nodenames(unique_configs, all_nodes_dict, pm_id_dict):
    for unique_config in unique_configs.values():
        config_to_check = unique_config.copy()
        config_to_check.pop("Frequency")

        unique_config["nodenames"] = []
        attr = ast.literal_eval(unique_config["attributes"]) if unique_config["attributes"] != None else None
        if attr is None:
            attr = {}
        attr["pm_id"] = []
        for node_name, node_config in all_nodes_dict.items():
            if node_config == config_to_check:
                unique_config["nodenames"].append(node_name)
                if node_name in pm_id_dict:
                    attr["pm_id"].extend(pm_id_dict[node_name])
                attr["pm_id"] = sorted(list(set(attr["pm_id"])))
        if len(attr["pm_id"]) == 0:
            attr.pop("pm_id")
        if len(attr) == 0:
            unique_config["attributes"] = None
        else:
            unique_config["attributes"] = str(attr)

        # print(unique_config["Frequency"], len(unique_config["nodenames"]))

    return unique_configs

def get_unique_nodes_wrt_shapes_dtypes_attrs(model_IR_path: str):
    """Group the nodes of a model IR JSON file into unique configurations.

    Each node becomes a record with these fields:
      * its ``op_type``;
      * its inputs/outputs with the tensor "name" entries removed;
      * the string form of its shape, datatype, byte, attribute,
        qdq-symmetry, coefficient-shape and residency fields (None when a
        field is absent).

    ``pm_id`` and ``const_padding_value`` are removed from a node's
    attributes before the record is built. The removed ``pm_id`` values are
    collected per node and passed to ``add_unique_nodenames``.

    Identical records are counted ("Frequency") and ordered by their string
    signature. Each is keyed ``<op prefix>_<n>``, where the op prefix is the
    part of ``op_type`` before the first "_". The result is written as JSON
    to ``<part of model_IR_path before ".json">_unique_nodes.json``.

    Args:
        model_IR_path: path of the model IR JSON (mapping node name -> node info).
    """
    # Stringified per-node fields, in record order. The order is significant:
    # str(tuple(record.items())) is the sort key that fixes the <op>_<n> numbering.
    leading_fields = (
        "in_act_shape", "in_wgt_shape", "in_wgt1_shape", "out_act_shape",
        "in_datatype", "wgt_datatype", "wgt1_datatype", "out_datatype",
        "in_bytes", "wgt_bytes", "wgt1_bytes", "out_bytes",
    )
    trailing_fields = ("qdq_symmetry", "coeff_shape", "in_act_residency", "out_act_residency")
    # Fields converted back from their string form before the output is written.
    literal_fields = (
        "in_act_shape", "in_wgt_shape", "in_wgt1_shape", "out_act_shape",
        "in_bytes", "wgt_bytes", "wgt1_bytes", "out_bytes",
        "attributes", "coeff_shape", "qdq_symmetry",
    )

    with open(model_IR_path) as f:
        model_IR = json.load(f)

    all_node_dict = {}
    pm_id_dict = {}
    for node_name, node in model_IR.items():
        record = {"op_type": node["op_type"]}

        # Tensor names differ between otherwise identical nodes; drop them so
        # only the remaining tensor info takes part in the comparison.
        for io_key in ("inputs", "outputs"):
            for tensor_info in node[io_key]:
                tensor_info.pop("name", None)
            record[io_key] = str(node[io_key])

        for field in leading_fields:
            record[field] = str(node[field]) if field in node else None

        attributes = node.get("attributes")
        if isinstance(attributes, dict):
            # Removed before stringification so these values do not become
            # part of the grouping key; pm_id is kept aside per node.
            if "pm_id" in attributes:
                pm_id_dict[node_name] = attributes.pop("pm_id")
            attributes.pop("const_padding_value", None)
        record["attributes"] = str(node["attributes"]) if "attributes" in node else None

        for field in trailing_fields:
            record[field] = str(node[field]) if field in node else None

        all_node_dict[node_name] = record

    # Count identical records, keyed by their string signature.
    unique_configs = {}
    for record in all_node_dict.values():
        signature = str(tuple(record.items()))
        if signature in unique_configs:
            unique_configs[signature]["Frequency"] += 1
        else:
            unique_configs[signature] = dict(record, Frequency=1)

    # Re-key as "<op prefix>_<n>", numbering per prefix in signature order.
    unique_configs_with_op_type_keys = {}
    op_type_count = defaultdict(int)
    for signature in sorted(unique_configs):
        config = unique_configs[signature]
        op_type = config["op_type"].split("_")[0]
        unique_configs_with_op_type_keys[f"{op_type}_{op_type_count[op_type]}"] = config
        op_type_count[op_type] += 1

    final_dict = add_unique_nodenames(unique_configs_with_op_type_keys, all_node_dict, pm_id_dict)

    for op_config in final_dict.values():
        for field in literal_fields:
            if op_config[field] is not None:
                op_config[field] = ast.literal_eval(op_config[field])

    with open(model_IR_path.split(".json")[0] + "_unique_nodes.json", "w") as f:
        json.dump(final_dict, f, indent=4)

def get_node_count_wrt_shapes_attrs(model_IR, model_IR_path, print_to_csv=True, print_to_json=False):
    """Count how many nodes share each (op_type, shapes, attributes) combination.

    Each node of ``model_IR`` becomes a record of its ``op_type`` plus the
    string form of ``in_act_shape``, ``in_wgt_shape``, ``in_wgt1_shape``,
    ``out_act_shape`` and ``attributes`` (None when a field is absent).
    Identical records are counted under "Frequency". They are ordered by the
    string of their record items and numbered 0..N-1.

    Args:
        model_IR: mapping of node name -> node info dict (must contain "op_type").
        model_IR_path: path of the IR JSON. Output files use the part of the
            path before ".json" as their prefix.
        print_to_csv: write ``<prefix>_count.csv`` when True.
        print_to_json: write ``<prefix>_count.json`` when True.
    """
    counted_fields = ("in_act_shape", "in_wgt_shape", "in_wgt1_shape", "out_act_shape", "attributes")

    groups = {}
    for node in model_IR.values():
        record = {"op_type": node["op_type"]}
        record.update((field, str(node[field]) if field in node else None) for field in counted_fields)
        signature = str(tuple(record.items()))
        if signature not in groups:
            groups[signature] = dict(record, Frequency=0)
        groups[signature]["Frequency"] += 1

    numbered = {position: groups[signature] for position, signature in enumerate(sorted(groups))}

    if print_to_csv:
        column_names = ["op_type", "in_act_shape", "in_wgt_shape", "in_wgt1_shape", "out_act_shape", "attributes", "Frequency"]
        with open(model_IR_path.split(".json")[0] + "_count.csv", "w", newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(column_names)
            writer.writerows([entry[name] for name in column_names] for entry in numbered.values())

    if print_to_json:
        with open(model_IR_path.split(".json")[0] + "_count.json", "w") as json_file:
            json.dump(numbered, json_file, indent=4)



def get_all_nodes_df(model_IR, model_IR_path):
    """Write one CSV row per node of ``model_IR`` to ``<prefix>_nodes.csv``.

    Despite its name, this writes a CSV file and returns None. ``<prefix>``
    is the part of ``model_IR_path`` before ".json". Each row holds:
      * the node's "node_name" and "op_type";
      * the number of entries in its "inputs";
      * the string form of its shape, datatype and attribute fields (empty
        cells for absent fields).
    """
    optional_fields = (
        "in_act_shape", "in_wgt_shape", "in_wgt1_shape", "out_act_shape",
        "in_datatype", "wgt_datatype", "wgt1_datatype", "out_datatype",
        "attributes",
    )
    fieldnames = ["node_name", "op_type", "num_inputs", *optional_fields]

    # Build every row first so a malformed node fails before the file is created.
    rows = []
    for node in model_IR.values():
        row = {
            "node_name": node["node_name"],
            "op_type": node["op_type"],
            "num_inputs": len(node["inputs"]),
        }
        row.update({field: (str(node[field]) if field in node else None) for field in optional_fields})
        rows.append(row)

    out_path = model_IR_path.split(".json")[0] + "_nodes.csv"
    with open(out_path, "w", newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
        
def tag_layernorm_fusion_type(model_IR, metadata_path, device="strix", overlay="4x4"):
    """Set ``fusion_type`` on every LayerNorm node of ``model_IR``.

    A node is treated as LayerNorm when "layernorm" appears in its lower-cased
    ``op_type``. Nodes whose ``in_act_shape`` is ``["NA"]`` and empty node
    entries are skipped.

    For each LayerNorm node the required space is the input activation bytes
    plus the output activation bytes plus a fixed 8192 * 8 bytes. The node
    is tagged:
      * "L1_Fused" if that is below the summed core data memory of all cores;
      * "L2_Fused" if it is below 90% of mem-tile capacity times the number
        of columns;
      * "L3_Fused" otherwise.

    Args:
        model_IR: mapping of node name -> node info dict; modified in place.
        metadata_path: directory containing ``device.yaml``.
        device: top-level key of ``device.yaml`` whose "core_data_memory" and
            "memtile_capacity" entries are used.
        overlay: "<rows>x<cols>" string, e.g. "4x4".

    Returns:
        ``model_IR`` (the same object).
    """
    with open(os.path.join(metadata_path, "device.yaml"), 'r') as device_file:
        device_info = yaml.safe_load(device_file)

    aie_rows = int(overlay.split("x")[0])
    aie_cols = int(overlay.split("x")[1])
    total_cores = aie_rows * aie_cols
    # Sizes from device.yaml are scaled by 1024 (presumably given in KiB — TODO confirm).
    data_mem_per_core = device_info[device]["core_data_memory"] * 1024
    mem_tile_capacity = device_info[device]["memtile_capacity"] * 1024
    l1_budget = data_mem_per_core * total_cores
    l2_budget = mem_tile_capacity * aie_cols * 0.9

    for node_info in model_IR.values():
        if not node_info:
            continue
        if "layernorm" not in node_info["op_type"].lower() or node_info["in_act_shape"] == ["NA"]:
            continue

        # NOTE(review): the extra 8192 * 8 bytes is a fixed allowance whose
        # purpose is not visible here — TODO confirm what it reserves.
        required_space = (np.prod(node_info["in_act_shape"]) * node_info["in_bytes"]
                          + np.prod(node_info["out_act_shape"]) * node_info["out_bytes"]
                          + 8192 * 8)

        if required_space < l1_budget:
            node_info["fusion_type"] = "L1_Fused"
        elif required_space < l2_budget:
            node_info["fusion_type"] = "L2_Fused"
        else:
            node_info["fusion_type"] = "L3_Fused"

    return model_IR

def get_unique_ops_from_model_IR(model_IR):
    """Return the distinct ``op_type`` values of all nodes in ``model_IR``.

    The values are returned as a sorted list, so the order is stable
    across calls.
    """
    op_types = {node_info["op_type"] for node_info in model_IR.values()}
    return sorted(op_types)

def merge_dicts(dict1, dict2):
    """Recursively merge ``dict2`` into ``dict1`` in place; return None.

    For each key of ``dict2``:
      * key missing from ``dict1`` -> the value is copied over. Nested dicts
        are rebuilt so that later merges into ``dict1`` never write into
        ``dict2``.
      * both values are dicts -> merged recursively.
      * both values are lists -> concatenated, ``dict1``'s items first.
      * ``dict1`` holds a list, value is not a list -> value appended.
      * otherwise -> the two values become ``[old, value]``.

    ``dict2`` is never modified. Lists inside ``dict1`` are replaced by new
    lists rather than extended in place.
    """
    for key, value in dict2.items():
        if key not in dict1:
            if isinstance(value, dict):
                # Fresh container: storing ``value`` itself would let a later
                # recursive merge write into the caller's dict2.
                dict1[key] = {}
                merge_dicts(dict1[key], value)
            else:
                dict1[key] = value
            continue

        existing = dict1[key]
        if isinstance(existing, dict) and isinstance(value, dict):
            merge_dicts(existing, value)
        elif isinstance(existing, list) and isinstance(value, list):
            # New list instead of extend(): ``existing`` may be a list taken
            # by reference from an earlier dict2.
            dict1[key] = existing + value
        elif isinstance(existing, list):
            dict1[key] = existing + [value]
        else:
            dict1[key] = [existing, value]


def get_kernels_for_model(model_IR_path, metadata_path, kernel_func_list, print_to_file=False, fast_pm=False):
    """Collect the AIE kernels and kernel include files needed by a model.

    All ``*kernel_metadata.yaml`` files directly inside ``metadata_path`` are
    merged with ``merge_dicts``; top-level "base_kernel_list" and
    "base_kernel_include" entries are dropped. For every unique op_type in
    the model IR, the op's ``kernel_path.kernel_list`` and
    ``kernel_path.kernel_include`` entries are collected. Ops without
    metadata produce a warning through ``sanity_check``.

    The GroupNorm/LayerNorm wrapper kernels are then adjusted:
      * GroupNorm only -> keep "run_wrapper_gpn", drop "run_wrapper_lrn" and
        "run_wrapper_lrn_k_gpn";
      * LayerNorm only -> keep "run_wrapper_lrn", drop "run_wrapper_gpn" and
        "run_wrapper_lrn_k_gpn";
      * both -> keep "run_wrapper_gpn" and "run_wrapper_lrn_k_gpn", drop
        "run_wrapper_lrn".
    LayerNormalization ops whose name contains "Skip" do not count as LayerNorm here.

    Args:
        model_IR_path: path to the model IR JSON (mapping node name -> node info).
        metadata_path: directory searched for ``*kernel_metadata.yaml``.
        kernel_func_list: sequence of known kernel function names. A kernel's
            first position in it is its index in the written "kernel_list".
            Kernels not found print an "Error: ..." line and are left out of it.
        print_to_file: when True, write
            ``<model_IR_path without extension>_kernel_list.json``.
        fast_pm: written (negated) as "disable_fast_pm".

    Returns:
        (kernels, includes): lists, in arbitrary order, of the selected kernel
        names and kernel include entries.
    """
    with open(model_IR_path) as f:
        model_IR = json.load(f)

    # Only "*kernel_metadata.yaml" files directly inside metadata_path; sorted
    # so the merge order is reproducible.
    yaml_files = sorted(glob.glob(os.path.join(metadata_path, "*kernel_metadata.yaml")))
    merged_data = defaultdict(dict)
    for file_name in yaml_files:
        with open(file_name, 'r') as file:
            data = yaml.safe_load(file)
        if data is None:  # empty YAML document: nothing to merge
            continue
        merge_dicts(merged_data, data)
        merged_data.pop("base_kernel_list", None)
        merged_data.pop("base_kernel_include", None)

    all_kernel_list = []
    all_kernel_includes_list = []
    gpn_nodes_in_model = False
    lrn_nodes_in_model = False
    unsupported_mode = ["Skip"]  # LayerNorm variants that do not count as LayerNorm

    for op in get_unique_ops_from_model_IR(model_IR):
        # .get() avoids inserting empty entries into the defaultdict.
        op_metadata = merged_data.get(op)
        if op_metadata:
            all_kernel_list.extend(op_metadata["kernel_path"]["kernel_list"])
            all_kernel_includes_list.extend(op_metadata["kernel_path"]["kernel_include"])
        else:
            sanity_check(False, f"No AIE kernel implementation found for {op} operator", "Warning")

        if "GroupNormalization" in op:
            gpn_nodes_in_model = True
        if "LayerNormalization" in op and all(um not in op for um in unsupported_mode):
            lrn_nodes_in_model = True

    all_kernel_set = set(all_kernel_list)
    if gpn_nodes_in_model and lrn_nodes_in_model:
        all_kernel_set.update(("run_wrapper_gpn", "run_wrapper_lrn_k_gpn"))
        all_kernel_set.discard("run_wrapper_lrn")
    elif gpn_nodes_in_model:
        all_kernel_set.add("run_wrapper_gpn")
        all_kernel_set.discard("run_wrapper_lrn")
        all_kernel_set.discard("run_wrapper_lrn_k_gpn")
    elif lrn_nodes_in_model:
        all_kernel_set.add("run_wrapper_lrn")
        all_kernel_set.discard("run_wrapper_gpn")
        all_kernel_set.discard("run_wrapper_lrn_k_gpn")

    all_kernel_includes_set = set(all_kernel_includes_list)

    # First position of each name, matching list.index() for duplicates.
    kernel_positions = {}
    for position, name in enumerate(kernel_func_list):
        kernel_positions.setdefault(name, position)

    kernel_names = {}
    for s in all_kernel_set:
        if s in kernel_positions:
            kernel_names[s] = kernel_positions[s]
        else:
            print(f"Error: '{s}' not found in the kernel fun list!")

    kernel_includes = [s for s in sorted(all_kernel_includes_set) if s]
    if 'super.hh' in kernel_includes:
        # super.hh must come first.
        kernel_includes.remove('super.hh')
        kernel_includes.insert(0, 'super.hh')

    final_dict = {
        "kernel_list": dict(sorted(kernel_names.items(), key=lambda item: item[1])),
        "kernel_include": kernel_includes,
        "group_norm_in_model": gpn_nodes_in_model,
        "disable_fast_pm": not fast_pm,
    }

    if print_to_file:
        print("Saving kernels list for the model")
        out_path = os.path.splitext(os.path.normpath(model_IR_path))[0] + '_kernel_list.json'
        with open(out_path, 'w') as file:
            json.dump(final_dict, file, indent=4)

    return list(all_kernel_set), list(all_kernel_includes_set)


if __name__ == "__main__":
    # Usage: <this script> <model_IR.json>
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <model_IR.json>")
    model_IR_path = sys.argv[1]
    with open(model_IR_path) as f:
        model_IR = json.load(f)

    ## for debugging individual functions
    # get_all_nodes_df(model_IR, model_IR_path)
    get_node_count_wrt_shapes_attrs(model_IR, model_IR_path, print_to_csv=True, print_to_json=True)
    # get_unique_nodes_wrt_shapes_dtypes_attrs(model_IR_path)
    # model_IR = tag_layernorm_fusion_type(model_IR, "../../../Collaterals/", device="strix", overlay="4x4")
    # (kernel_func_list must be supplied for the call below)
    # model_kernels, model_kernel_includes = get_kernels_for_model(model_IR_path, "../../../Collaterals/", kernel_func_list, print_to_file=True)
