import os
import glob
import yaml
import json
import numpy
import argparse
import logging
import onnx
from collections import defaultdict
from onnx.helper import make_attribute

from kernel_func_list import kernel_func_list


def get_kernel_dict():
    """Merge every ``*_kernel_metadata.yaml`` collateral into a single dict.

    The files are looked up in the ``Collaterals`` directory located two
    levels above the directory that contains this file.

    Returns:
        dict: the union of the top-level mappings of all matching YAML
        files. When the same key occurs in several files, the file that
        sorts last by path wins.
    """
    parent_dir = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )
    pattern = os.path.join(parent_dir, "Collaterals/*_kernel_metadata.yaml")

    kernel_dict = {}
    # glob returns matches in arbitrary order; sort so that duplicate keys
    # across files are always resolved the same way.
    for yaml_path in sorted(glob.glob(pattern)):
        with open(yaml_path) as f:
            metadata = yaml.safe_load(f)
        # safe_load yields None for an empty document; there is nothing to
        # merge in that case (dict.update(None) would raise TypeError).
        if metadata:
            kernel_dict.update(metadata)
    return kernel_dict


def get_kernel_pm_dict(is_target_procyon: bool = False):
    """Build kernel-name -> PM-id tables from ``Collaterals/pm_kernel_map.json``.

    Every top-level JSON entry that is an object with a ``kernel_list``
    mapping contributes its key (the PM id) to each kernel named in that
    mapping. PM ids are converted to ``int`` by dropping their first three
    characters (presumably ids look like ``pm_<N>`` -- confirm against the
    JSON schema).

    A PM id counts as "proc" when any string in its ``kernel_include`` list
    contains ``"proc_opt"``.

    Args:
        is_target_procyon: when true, kernels whose PM ids include *all*
            proc PM ids are restricted to those proc ids, and
            ``run_slice_a8`` is forced to the proc ids (if any exist).

    Returns:
        tuple: ``(kernel_pm_dict, pristine_kernel_pm_dict)``.
        ``pristine_kernel_pm_dict`` maps each kernel to its non-proc PM ids.
        ``kernel_pm_dict`` is that very same dict object when
        ``is_target_procyon`` is false, otherwise the procyon-specific map.

    Raises:
        KeyError: if an entry that lists kernels has no ``kernel_include``.
    """
    parent_dir = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )
    json_file = os.path.join(parent_dir, "Collaterals/pm_kernel_map.json")

    with open(json_file) as f:
        kd = json.load(f)

    # kernel name -> PM-id strings that list it, in JSON file order.
    kernel_to_pm_ids = {}
    # PM ids that list at least one kernel, in JSON file order. Kept as an
    # ordered list (not a set) so the result does not depend on string hash
    # randomisation: callers pick the first PM id of a list.
    pm_ids_in_order = []
    for pm_id, entry in kd.items():
        if not (isinstance(entry, dict) and "kernel_list" in entry):
            continue
        kernel_names = list(entry["kernel_list"].keys())
        for pm_kernel in kernel_names:
            kernel_to_pm_ids.setdefault(pm_kernel, []).append(pm_id)
        if kernel_names:
            pm_ids_in_order.append(pm_id)

    proc_pm_ids = [
        pm_id
        for pm_id in pm_ids_in_order
        if any("proc_opt" in inc for inc in kd[pm_id]["kernel_include"])
    ]
    proc_pm_id_set = set(proc_pm_ids)

    pristine_kernel_pm_dict = {
        kernel: [int(pm_id[3:]) for pm_id in pm_ids if pm_id not in proc_pm_id_set]
        for kernel, pm_ids in kernel_to_pm_ids.items()
    }
    if not is_target_procyon:
        # Both return values intentionally refer to the same dict here.
        return pristine_kernel_pm_dict, pristine_kernel_pm_dict

    # NOTE(review): when no proc PM id exists, the empty set is a subset of
    # every list, so every kernel maps to [] here. This matches the previous
    # behaviour -- confirm it is intended.
    kernel_pm_dict = {
        kernel: (
            [int(pm_id[3:]) for pm_id in pm_ids if pm_id in proc_pm_id_set]
            if proc_pm_id_set.issubset(pm_ids)
            else [int(pm_id[3:]) for pm_id in pm_ids]
        )
        for kernel, pm_ids in kernel_to_pm_ids.items()
    }
    if proc_pm_ids:
        # Force proc PM ids for 'run_slice_a8' since it doesn't require an
        # actual kernel anyway.
        kernel_pm_dict["run_slice_a8"] = [int(pm_id[3:]) for pm_id in proc_pm_ids]

    return kernel_pm_dict, pristine_kernel_pm_dict


def dataflow_op_kernel_names(tilings_json_path):
    """Map node names to kernel-name lists as described by a tilings JSON file.

    Each top-level value of the JSON object that is itself an object may
    carry a ``kernel_names`` list and a ``layer_info.nodenames`` list. When
    both are non-empty, every listed node name is mapped to that (shared)
    ``kernel_names`` list. A node name seen again in a later entry takes the
    later entry's list.

    Args:
        tilings_json_path: path of the tilings JSON file to read.

    Returns:
        dict: node name -> list of kernel names.
    """
    with open(tilings_json_path, "r") as fh:
        tilings = json.load(fh)

    mapping = {}
    for entry in tilings.values():
        # Non-object entries carry no kernel information.
        if not isinstance(entry, dict):
            continue

        kernels = entry.get("kernel_names", [])
        names = entry.get("layer_info", {}).get("nodenames", [])
        if not (kernels and names):
            continue

        # Every node of the entry refers to the same kernel list object.
        mapping.update(dict.fromkeys(names, kernels))

    return mapping


def get_intersection(list1, list2):
    """Return the items of ``list1`` that also occur in ``list2``.

    The order of ``list1`` is kept and its duplicates are retained; a new
    list is always returned.
    """
    shared = []
    for item in list1:
        if item in list2:
            shared.append(item)
    return shared


def _effective_op_type(node):
    """Return the node's ``kernel_op_type`` attribute if present, else ``op_type``."""
    node_op = node.op_type
    for attr in node.attribute:
        if attr.name == "kernel_op_type":
            node_op = attr.s.decode("utf-8")
    return node_op


def _common_pm_ids(node_kernels, kernel_pm_dict):
    """Fold the PM-id lists of ``node_kernels`` into their running intersection.

    NOTE(review): an empty running result is treated like "not started yet",
    so after an empty intersection the next kernel's list restarts the fold.
    This is the long-standing behaviour and is kept as is -- confirm whether
    it is intended.
    """
    pm_ids = []
    for node_kernel in node_kernels:
        if not pm_ids:
            pm_ids.extend(kernel_pm_dict[node_kernel])
        else:
            pm_ids = get_intersection(pm_ids, kernel_pm_dict[node_kernel])
    return pm_ids


def static_pm_partition_graph(
    model,
    model_path,
    fast_pm_enable,
    prebuilt_mladf_mha,
    is_target_procyon,
    tiling_json_file=None,
    IR_json_file=None,
    save_model=False,
):
    """Group consecutive graph nodes by a shared PM id and tag them with ``pm_id``.

    Nodes are visited in graph order. A node takes part in the partitioning
    when its effective op type (``kernel_op_type`` attribute, falling back to
    ``op_type``) is a key of the kernel metadata, or when its name appears in
    the tilings JSON. The running partition keeps the intersection of the PM
    ids of its nodes; when a node shares no PM id with it, the partition is
    closed (its first PM id is chosen) and a new one starts at that node.

    Args:
        model: loaded ONNX model; ``model.graph.node`` is modified in place
            by appending ``pm_id`` attributes.
        model_path: destination used by ``onnx.save`` when ``save_model``.
        fast_pm_enable: accepted for interface compatibility; not read here.
        prebuilt_mladf_mha: when true, nodes whose effective op type contains
            ``"MHA_3p0_1col_Transpose_qdq_"`` are given the PM-id list ``[99]``.
        is_target_procyon: enables the procyon PM map (only if some node's
            kernel list contains ``run_gemm_a8w8``) and the per-node fallback
            to the pristine map when the procyon map yields no common PM id.
        tiling_json_file: optional tilings JSON providing per-node kernel
            lists that take precedence over the kernel metadata.
        IR_json_file: optional JSON file; if it exists, each top-level entry
            whose key is a tagged node name gets ``attributes.pm_id = [id]``
            and the file is rewritten.
        save_model: when true, write the modified model to ``model_path``.

    Returns:
        None.

    Raises:
        KeyError: if a node's kernel is missing from the PM map.
        IndexError: if a partition ends up without any PM id (see the
            NOTE(review) comments below).
    """
    kernel_dict = get_kernel_dict()
    dataflow_ops_kernel_map = (
        dataflow_op_kernel_names(tiling_json_file) if tiling_json_file else {}
    )
    model_nodes = model.graph.node
    # The procyon PM map is only requested when the graph uses run_gemm_a8w8.
    # NOTE(review): this check looks at node.op_type only, not at
    # kernel_op_type or the tilings kernel lists.
    is_gemm_a8w8 = any(
        "run_gemm_a8w8"
        in kernel_dict.get(node.op_type, {})
        .get("kernel_path", {})
        .get("kernel_list", [])
        for node in model_nodes
    )
    kernel_pm_dict, pristine_kernel_pm_dict = get_kernel_pm_dict(
        is_target_procyon and is_gemm_a8w8
    )

    current_kernel_list = []
    current_pm_list = []
    output_kernel_list = []  # bookkeeping only; not returned
    output_pm_list = []
    node_idx_pm_list = []  # half-open (start, end) node index ranges
    start = 0

    for i, node in enumerate(model_nodes):
        node_op = _effective_op_type(node)

        # Reset per node so the prebuilt-MHA branch never sees the previous
        # node's kernels (or unbound names when it is the first NPU node).
        node_kernels = []
        new_kernels = []

        if prebuilt_mladf_mha and "MHA_3p0_1col_Transpose_qdq_" in node_op:
            new_pm_list = [99]
        elif not (node_op in kernel_dict or node.name in dataflow_ops_kernel_map):
            # Not an NPU op: it does not influence the partitioning.
            continue
        else:
            # Kernel lists from the tilings JSON take precedence.
            if node.name in dataflow_ops_kernel_map:
                node_kernels = dataflow_ops_kernel_map[node.name]
            else:
                node_kernels = kernel_dict[node_op]["kernel_path"]["kernel_list"]
            node_kernels = [kernel for kernel in node_kernels if kernel]
            new_kernels = [
                kernel for kernel in node_kernels if kernel not in current_kernel_list
            ]

            new_pm_list = _common_pm_ids(node_kernels, kernel_pm_dict)
            if is_target_procyon and not new_pm_list:
                # The node's kernels share no PM id in the procyon map: fall
                # back to their pristine PM ids for this node only.
                new_pm_list = _common_pm_ids(node_kernels, pristine_kernel_pm_dict)

        # pm swap algorithm
        if not current_pm_list:
            common_pm_list = list(new_pm_list)
        else:
            common_pm_list = get_intersection(current_pm_list, new_pm_list)

        if not common_pm_list:
            # Close the running partition and start a new one at this node.
            # NOTE(review): if the running partition has no PM id (this node,
            # or the node that opened the partition, had none) the [0] lookup
            # raises IndexError -- unchanged from the previous behaviour.
            output_kernel_list.append(current_kernel_list)
            output_pm_list.append(current_pm_list[0])
            node_idx_pm_list.append((start, i))
            start = i
            current_pm_list = list(new_pm_list)
            current_kernel_list = list(node_kernels)
        else:
            current_pm_list = common_pm_list
            current_kernel_list.extend(new_kernels)

    # Close the last partition; it extends to the end of the graph.
    output_kernel_list.append(current_kernel_list)
    if current_pm_list:
        output_pm_list.append(current_pm_list[0])
    node_idx_pm_list.append((start, len(model_nodes)))

    # adding pm_id attribute to nodes
    node_name_to_pm_id = {}  # Track node names to pm_id for JSON update
    for i in range(len(node_idx_pm_list)):
        seg_start, seg_end = node_idx_pm_list[i]
        for j in range(seg_start, seg_end):
            node = model_nodes[j]
            node_op = _effective_op_type(node)
            # NOTE(review): nodes that joined the partitioning only through
            # the tilings JSON (op type not in kernel_dict) are not tagged.
            needs_pm_id = (
                node_op in kernel_dict
                or (prebuilt_mladf_mha and "MHA_3p0_1col_Transpose_qdq_" in node_op)
                or "noop" in node_op
                or "runtime" in node_op
            )
            if needs_pm_id:
                # Looked up lazily: a partition without a PM id only fails
                # here if it actually contains a node to tag.
                pm_id = output_pm_list[i]
                node.attribute.append(make_attribute("pm_id", pm_id))
                node_name_to_pm_id[node.name] = pm_id

    # Save modified model if requested
    if save_model:
        onnx.save(model, model_path)

    # Update IR JSON file if requested
    if IR_json_file is not None and os.path.exists(IR_json_file):
        with open(IR_json_file, "r") as f:
            ir_data = json.load(f)

        for node_name, node_data in ir_data.items():
            if isinstance(node_data, dict) and node_name in node_name_to_pm_id:
                # Add pm_id to attributes, creating the section if needed.
                node_data.setdefault("attributes", {})["pm_id"] = [
                    node_name_to_pm_id[node_name]
                ]

        # Save updated JSON
        with open(IR_json_file, "w") as f:
            json.dump(ir_data, f, indent=4)


def main(args):
    """Load an ONNX model, tag its nodes with ``pm_id`` and save it in place.

    Args:
        args: mapping with ``model_path`` (path of the ONNX model, also used
            as the output destination) and ``load_data`` (integer-like;
            non-zero makes ``onnx.load_model`` load external tensor data).
    """
    model_path = args["model_path"]
    load_data = bool(int(args["load_data"]))
    fast_pm_enable = True
    prebuilt_mladf_mha = False
    is_target_procyon = False

    # load model
    model = onnx.load_model(model_path, load_external_data=load_data)

    # save_model must be requested explicitly: without it (and without an IR
    # JSON file) the partitioning result would be discarded and the command
    # line tool would have no observable effect.
    static_pm_partition_graph(
        model,
        model_path,
        fast_pm_enable,
        prebuilt_mladf_mha,
        is_target_procyon,
        save_model=True,
    )


if __name__ == "__main__":
    # Command-line entry point: parse the options, configure logging and
    # hand the parsed options to main() as a plain dict.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--debug",
        help="Print lots of debugging statements",
        action="store_const",
        dest="loglevel",
        const=logging.DEBUG,
    )
    parser.add_argument(
        "-mp",
        "--model_path",
        help="path to onnx model and output destination. Required Field",
    )
    # The value is converted with int() and used as the load_external_data
    # flag of onnx.load_model, so it is an integer switch, not a path.
    # type=int makes argparse reject non-integer values with a usage error
    # instead of a traceback from main().
    parser.add_argument(
        "-ld",
        "--load_data",
        help="set to 1 to also load the model's external data files (large models). Optional Field. Default value = 0",
        type=int,
        default=0,
    )

    args = parser.parse_args()
    if not args.model_path:
        parser.error(
            "Please pass path/to/onnx/model using -mp or --model_path flags.\npython3 parse_onnx_model.py --help\n\t\t\tfor further info."
        )
    # loglevel is None unless --debug was given, which leaves the root
    # logger at its default level.
    logging.basicConfig(level=args.loglevel)
    logging.debug("Debug mode is enabled!")

    main(vars(args))
