import argparse
import logging
import os

import onnx
from onnx.checker import check_model

from OGOAT.src.L1_fusion.L1_utils.utils import (
    construct_dict,
)

def graph_sort(model, order=0):
    """Reorder ``model.graph.node`` in place by walking backwards from the graph outputs.

    Every node found by ``construct_dict`` is removed from the graph. The nodes
    are then re-appended in an order built backwards from the nodes that produce
    graph outputs. A node is placed only after every node listed for it in
    ``out_nodes_dict`` has already been placed. The backward list therefore has
    each producer after all of its consumers.

    ``in_nodes_dict`` / ``out_nodes_dict`` come from ``construct_dict`` (not
    visible here). They are presumably maps from a node name to the names of its
    producer / consumer nodes — TODO confirm against ``construct_dict``.

    Args:
        model: ONNX ``ModelProto``; its ``graph.node`` field is modified in place.
        order: ``0`` re-appends the nodes in topological (producer-first) order;
            any other value re-appends them in the reverse (consumer-first) order.

    Returns:
        None. The model is mutated in place.

    NOTE(review): nodes with no path to a graph output are not re-inserted. A
    node that has such an unreachable consumer is never placed either, because
    its consumers can never all be placed. Callers should pass graphs without
    dead branches, or run dead-code elimination first.
    """
    (
        _ini_dict,
        nodes_dict,
        in_nodes_dict,
        out_nodes_dict,
        _value_info_dict,
        _input_names,
        output_names,
    ) = construct_dict(model)

    # Detach every node; they are re-appended below in the new order.
    for node_name in nodes_dict:
        model.graph.node.remove(nodes_dict[node_name])

    graph_outputs = set(output_names.keys())

    # Backward order (consumers before producers). ``placed`` mirrors
    # ``node_list`` so membership checks are O(1) instead of O(n).
    node_list = []
    placed = set()

    # Seed with graph-output producers whose consumers are already seeded.
    # Each node is added at most once, even if it produces several graph outputs.
    # Producers skipped here are picked up by the backward walk once their
    # consumers are placed.
    for node_name, node in nodes_dict.items():
        if not any(out in graph_outputs for out in node.output):
            continue
        if all(consumer in placed for consumer in out_nodes_dict[node_name]):
            node_list.append(node_name)
            placed.add(node_name)

    # Backward walk, one frontier at a time. ``placed`` is only updated after a
    # whole frontier is processed, so every producer in a round is judged
    # against the same snapshot.
    frontier = list(node_list)
    while frontier:
        next_frontier = []
        queued = set()
        for node_name in frontier:
            for producer in in_nodes_dict[node_name]:
                if producer in placed or producer in queued:
                    continue
                # A producer may only be placed once all of its consumers are.
                if all(consumer in placed for consumer in out_nodes_dict[producer]):
                    next_frontier.append(producer)
                    queued.add(producer)
        node_list.extend(next_frontier)
        placed.update(next_frontier)
        frontier = next_frontier

    # node_list is consumer-first; reverse it for producer-first (topological) order.
    ordered = reversed(node_list) if order == 0 else node_list
    for node_name in ordered:
        model.graph.node.append(nodes_dict[node_name])

def sorted_model_path(model_path):
    """Return the output path for the sorted model: ``<path without extension>_sorted.onnx``.

    The extension of ``model_path`` (if any) is stripped with
    ``os.path.splitext``. The original code blindly dropped the last five
    characters, which mangled paths that do not end in ``.onnx``.
    """
    return os.path.splitext(model_path)[0] + '_sorted.onnx'


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--debug", help="Print lots of debugging statements", action="store_const", dest="loglevel", const=logging.DEBUG)
    parser.add_argument("-mp", "--model_path", help="path to onnx model and output destination.Required Field")

    args = parser.parse_args()
    if not args.model_path:
        # Refer to this script by its actual invocation name instead of a
        # hard-coded (and likely copy-pasted) file name.
        parser.error(
            "Please pass path/to/onnx/model using -mp or --model_path flags.\n"
            f"python3 {parser.prog} --help\n\t\t\tfor further info."
        )
    # level is None unless -d was given, in which case basicConfig leaves the default level.
    logging.basicConfig(level=args.loglevel)
    logging.debug("Debug mode is enabled!")

    model_path = args.model_path
    out_model_path = sorted_model_path(model_path)

    # load model
    model = onnx.load(model_path)

    graph_sort(model, 0)  # order 0: topo order, 1: reverse order

    check_model(model)

    # save model
    onnx.save(model, out_model_path)
