import sys
import os
import subprocess
import argparse
import json
import onnx
import re
from OGOAT.src.Ort.onnx_graph_partitioner import get_graph_channels

# ANSI SGR escape sequences used to colorize messages printed to the terminal.
red_code = "\033[91m"  # bright red foreground
yellow_code = "\033[93m"  # bright yellow foreground
reset_code = "\033[0m"  # reset all attributes back to the terminal default


def check_input_output_in_orig_onnx(fld, original_model):
    """Check that subgraph I/O names from context.json exist in the original model.

    Loads ``<fld>/cut_graphs/context.json`` and, for every entry in it, tests
    each name listed under "inputs" and "outputs" for membership in the value
    returned by ``get_graph_channels`` for the original ONNX model. A yellow
    warning is printed for every name that is not found, followed by a
    PASSED/FAILED summary line.

    Args:
        fld: Directory that contains the ``cut_graphs`` folder.
        original_model: Path to the original ONNX model file.
    """
    context = os.path.join(fld, "cut_graphs", "context.json")
    with open(context, encoding="utf-8") as f:
        io_node_dict = json.load(f)
    orig_model = onnx.load(original_model)
    orig_channels = get_graph_channels(orig_model)

    fail = False
    # Only the per-subgraph I/O lists matter here; the subgraph key is not used.
    for graph in io_node_dict.values():
        for name in graph["inputs"]:
            if name not in orig_channels:
                fail = True
                # Close the color sequence so later output is not tinted yellow.
                print(f"{yellow_code} Cannot find {name} in original onnx model{reset_code}")
        for name in graph["outputs"]:
            if name not in orig_channels:
                fail = True
                print(f"{yellow_code} Cannot find {name} in original onnx graph{reset_code}")

    if not fail:
        print("PASSED SUBGRAPH I/O CHECK")
    else:
        print("FAILED SUBGRAPH I/O CHECK")

def check_dupe_merge_dicts(dict_a, dict_b):
    """Merge two node dictionaries unless they share at least one key.

    When the key sets are disjoint, a new dict with the entries of both inputs
    is returned. When any key appears in both, a red warning listing the shared
    keys is printed and None is returned. Neither input is modified.
    """
    shared = set(dict_a) & set(dict_b)
    if shared:
        print(f"\033[91mWarning: Common nodes are present in subgraphs: {shared}\033[00m")
        return None

    combined = dict(dict_a)
    combined.update(dict_b)
    return combined

def check_dupe_and_leftover_ops(fld, fused_model_f):
    """Check cut subgraphs for duplicated nodes and the fused model for leftover nodes.

    Two checks are run and their results printed:

    1. Every ``*.onnx`` file in ``<fld>/cut_graphs`` is loaded and its node
       names are merged into one dictionary. If ``check_dupe_merge_dicts``
       reports a node name shared with an earlier subgraph, the scan stops.
    2. Every node name of the fused model that is in none of the merged
       subgraphs and is not listed in ``<fld>/skipped_nodes.json`` is
       reported, followed by a PASSED/FAILED summary line.

    Args:
        fld: Directory containing ``cut_graphs`` and ``skipped_nodes.json``.
        fused_model_f: Path to the fused ONNX model file.
    """
    sub_dir = os.path.join(fld, "cut_graphs")
    fused_model = onnx.load(fused_model_f)
    all_nodes = [node.name for node in fused_model.graph.node]

    # Collect the nodes of all cut subgraphs and sanity check that no node
    # appears in more than one of them. os.listdir order is arbitrary, so the
    # names are sorted to make the scan (and its messages) reproducible.
    merged_dict = {}
    for f in sorted(os.listdir(sub_dir)):
        if not f.endswith(".onnx"):
            continue
        subgraph_onnx = os.path.join(sub_dir, f)
        onnx_model = onnx.load(subgraph_onnx)
        # The value stored per node name is not used below; only keys matter.
        subgraph_dict = {node.name: node for node in onnx_model.graph.node}
        temp_merged_dict = check_dupe_merge_dicts(merged_dict, subgraph_dict)
        # None signals a conflict. An empty dict is a valid merge result (e.g.
        # a subgraph without nodes) and must not be mistaken for a failure.
        if temp_merged_dict is None:
            print(f"{red_code}Couldn't merge the dictionary from {subgraph_onnx} due to node conflicts.{reset_code}")
            # NOTE: subgraphs after this one are not scanned, so their nodes
            # will show up as leftovers in the check below.
            break
        merged_dict = temp_merged_dict

    print("FINISHED SUBGRAPH OVERLAPPING CHECK. THE FOLLOWING LEFTOVER NODES WERE NOT LABELED SKIP_OP IN THE GRAPH PARTITIONER")

    skipped_nodes = os.path.join(fld, "skipped_nodes.json")
    with open(skipped_nodes, encoding="utf-8") as f:
        # Presumably a mapping keyed by node name; only membership is tested.
        skipped_nodes_dict = json.load(f)

    failed = False
    # Every node of the fused graph must either belong to a subgraph or have
    # been recorded as skipped by the partitioner.
    for node in all_nodes:
        if node in merged_dict or node in skipped_nodes_dict:
            continue
        print(f"{yellow_code}{node} is not included in any of the AIE subgraphs even though it was not labeled skipped.{reset_code}")
        failed = True

    if failed:
        print("FAILED CPU NODE CHECK")
    else:
        print("PASSED CPU NODE CHECK")


def main(args):
    """Run both partition sanity checks for the parsed command-line arguments.

    The checks operate on ``args.out_dir`` when it is given, otherwise on the
    directory part of ``args.model_path``.
    """
    fld = args.out_dir if args.out_dir is not None else os.path.dirname(args.model_path)

    # First verify the subgraph I/O names against the original model, then
    # look for duplicated and leftover nodes relative to the fused model.
    check_input_output_in_orig_onnx(fld, args.orig_model_path)
    check_dupe_and_leftover_ops(fld, args.model_path)

if __name__ == "__main__":
    # Command-line entry point: parse the model paths and run the checks.
    parser = argparse.ArgumentParser(
        description=(
            "Sanity checks for a partitioned onnx graph: subgraph I/O names, "
            "duplicated nodes across subgraphs and leftover nodes"
        ),
        usage='use "%(prog)s --help" for more info',
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "-mp",
        "--model_path",
        required=True,
        help="Path to fused onnx model",
    )
    parser.add_argument(
        "-omp",
        "--orig_model_path",
        required=True,
        help="Path to original onnx model",
    )

    # Accepted on the command line, but main() never reads it: the checks
    # always load <out_dir>/cut_graphs/context.json.
    parser.add_argument(
        "-c",
        "--context",
        required=False,
        help="path to context.json file (currently unused; cut_graphs/context.json is always read)",
    )
    parser.add_argument(
        "--out_dir",
        required=False,
        help=(
            "Directory containing cut_graphs/ and skipped_nodes.json\n"
            "(defaults to the directory of --model_path)"
        ),
    )
    args = parser.parse_args()

    main(args)
