import onnx
import sys
import os
import json

import argparse
import logging
from L1_utils.model_IR_utils import (
    get_unique_nodes_wrt_shapes_dtypes_attrs,
    get_kernels_for_model,
)
from OGOAT.src.L1_fusion.L1_utils.safe_runner import SafeRunner
from OGOAT.src.L1_fusion.graph_info import GraphInfo
from OGOAT.src.utils.context import Context, Logger
from kernel_func_list import kernel_func_list
from OGOAT.src.L1_fusion.graph_info_utils import (
    GraphInfoParams,
    get_top_level_details_from_model,
)
from OGOAT.src.L1_fusion.L1_utils.utils import remove_additional_attributes


class ParseOnnxModel:
    """Wrap a loaded ONNX model and serialize its graph representation to JSON IR.

    The dtype/device options in ``main_params`` are forwarded to
    :class:`GraphInfoParams`; graph extraction itself is delegated to
    :class:`GraphInfo` and executed through a :class:`SafeRunner`.
    """

    def __init__(self, main_params, logger: Logger, model: onnx.ModelProto, runner: SafeRunner):
        """Store the model and build the graph-info parameters.

        Args:
            main_params: mapping (e.g. ``vars()`` of the CLI namespace) that must
                contain the keys ``assign_new_dtypes``, ``low_precision_act_dtype``,
                ``high_precision_act_dtype``, ``low_precision_wgt_dtype``,
                ``high_precision_wgt_dtype``, ``no_dtype_downcast`` and ``device``.
            logger: project logger, kept on the instance.
            model: the ONNX model to parse.
            runner: SafeRunner used later to execute graph extraction.

        Side effect: prints the total node count and the operators found in
        the model's top-level graph.
        """
        self.runner = runner
        self.logger = logger
        # Values may arrive as strings from the CLI, hence the explicit casts.
        self.graph_info_params = GraphInfoParams(
            assign_new_dtypes=int(main_params["assign_new_dtypes"]),
            low_precision_act_dtype=str(main_params["low_precision_act_dtype"]),
            high_precision_act_dtype=str(
                main_params["high_precision_act_dtype"]
            ),
            low_precision_wgt_dtype=str(main_params["low_precision_wgt_dtype"]),
            high_precision_wgt_dtype=str(
                main_params["high_precision_wgt_dtype"]
            ),
            no_dtype_downcast=bool(main_params["no_dtype_downcast"]),
            device=str(main_params["device"])
        )
        self.model = model
        # Model validation is currently disabled. If re-enabled, models larger
        # than 2GB must be checked by file path instead of by ModelProto:
        #   onnx.checker.check_model(model)
        #   onnx.checker.check_model(model_path)

        # Report basic details of the model; graph inputs/outputs are unused here.
        all_ops, total_nodes, _g_inputs, _g_outputs = (
            get_top_level_details_from_model(self.model)
        )
        print("Total number of nodes:", total_nodes)
        print("Operators in the model: ", all_ops)

    def parse_fused_model_to_ir(self, model_IR_path):
        """Extract the graph representation and save it as a JSON IR file.

        Args:
            model_IR_path: destination path of the JSON file.

        The extraction (``GraphInfo.get_graph_info``) runs through
        ``self.runner``. If it yields a falsy result, no file is written.
        """
        graph_info = GraphInfo(self.model, self.graph_info_params, self.runner)
        graph_ = self.runner.run(graph_info.get_graph_info)
        if not graph_:
            return
        # json.dump writes to the file handle directly; sys.stdout is left
        # untouched so a failure here cannot leave stdout bound to a closed file.
        with open(model_IR_path, "w") as f:
            json.dump(graph_, f, indent=4)

def build_arg_parser():
    """Build the command-line parser for the ONNX-to-IR L1 fusion parser.

    The parsed namespace is passed (via ``vars()``) to ``ParseOnnxModel`` as
    ``main_params``, so ``dest`` names must match the keys it reads.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--debug",
        help="Print lots of debugging statements",
        action="store_const",
        dest="loglevel",
        const=logging.DEBUG,
    )
    parser.add_argument(
        "-mp",
        "--model_path",
        help="path to onnx model and output destination.Required Field",
    )
    parser.add_argument(
        "-ld",
        "--load_data",
        type=int,
        help="flag (0/1) to load external tensor data stored alongside the model, for large models. Optional Field. Default value = 0",
        default=0,
    )

    parser.add_argument(
        "-new_dtypes",
        "--assign_new_dtypes",
        help="flag to set when assigning datatypes based on a heuristic for AIE. Optional Field. Default value = 0",
        default="0",
    )
    parser.add_argument(
        "-act_dtype_low",
        "--low_precision_act_dtype",
        help="low precision activation dtype for tensor datatype assignment. Optional Field. Default value = 'uint16'",
        default="uint16",
    )
    parser.add_argument(
        "-act_dtype_high",
        "--high_precision_act_dtype",
        help="high precision activation dtype for tensor datatype assignment. Optional Field. Default value = 'uint16'",
        default="uint16",
    )
    parser.add_argument(
        "-wgt_dtype_low",
        "--low_precision_wgt_dtype",
        help="low precision weights dtype for tensor datatype assignment. Optional Field. Default value = 'uint8'",
        default="uint8",
    )
    parser.add_argument(
        "-wgt_dtype_high",
        "--high_precision_wgt_dtype",
        help="high precision weights dtype for tensor datatype assignment. Optional Field. Default value = 'uint16'",
        default="uint16",
    )

    parser.add_argument(
        "-in",
        "--input_names",
        required=False,
        nargs="+",
        help="Names of inputs if model has dynamic shape inputs. Optional Field. Default value = ''",
        default="",
    )
    parser.add_argument(
        "-dims",
        "--input_dims",
        required=False,
        nargs="+",
        help="Shapes of inputs if model has dynamic shape inputs. Optional Field. Default value = ''",
        default="",
    )
    parser.add_argument(
        "-shape_params",
        "--in_shape_params",
        required=False,
        type=str,
        help="Dynamic shape parameters for inputs. Optional Field. Default value = '{}'",
        default="{}",
    )
    parser.add_argument(
        "-no_dtype_downcast",
        "--no_dtype_downcast",
        action="store_true",
        help="Disable dtype downcasting during L1 fusion",
    )
    # "-d" is already taken by --debug; reusing it here made argparse raise
    # a conflicting-option error while building the parser.
    parser.add_argument(
        "-dev",
        "--device",
        help="Name of device to run; one of strix, med or swv. Default = 'strix'",
        choices=["strix", "med", "swv"],
        default="strix",
    )
    return parser


def main(argv=None):
    """Parse an ONNX model into JSON IR, then derive unique nodes and kernels.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Writes ``<model_path without extension>_IR.json`` and exits with an error
    if that file was not produced.
    """
    parser = build_arg_parser()
    args = parser.parse_args(argv)
    if not args.model_path:
        parser.error(
            "Please pass path/to/onnx/model using -mp or --model_path flags.\npython3 parse_onnx_model.py --help\n\t\t\tfor further info."
        )
    logging.basicConfig(level=args.loglevel)
    logging.debug("Debug mode is enabled!")
    main_params = vars(args)
    # argparse.Namespace has no .get(); -d/--debug stores logging.DEBUG
    # into args.loglevel.
    debug = args.loglevel == logging.DEBUG
    context = Context(output_dir="WAIC_Outputs", debug=debug)
    logger = Logger(name="L1_fusion", context=context)
    logger.info("Start L1 fusion stage")
    runner = SafeRunner(
        logger=logger,
        output_dir_path="WAIC_Outputs",
        summary_file_name="fusion_error_summary.txt",
    )
    model = onnx.load_model(
        args.model_path, load_external_data=bool(args.load_data)
    )
    parse_onnx_model = ParseOnnxModel(main_params, logger, model, runner)
    model_IR_path = os.path.splitext(args.model_path)[0] + "_IR.json"
    runner.run(parse_onnx_model.parse_fused_model_to_ir, model_IR_path)
    if runner.errors_occured:
        logger.info(
            f"Some errors occured during fusion. For more information you can look at the summary file {runner.summary_file_path}"
        )
        runner.dump_error_summary()

    # The steps below read the IR file; stop with a clear message instead of
    # failing deep inside them when it was never written.
    if not os.path.exists(model_IR_path):
        sys.exit(f"IR file was not generated: {model_IR_path}")

    # generate unique nodes
    get_unique_nodes_wrt_shapes_dtypes_attrs(model_IR_path)
    # NOTE(review): this path is relative to the current working directory.
    get_kernels_for_model(
        model_IR_path,
        "../../Collaterals/",
        kernel_func_list,
        print_to_file=True,
    )

    ## TODO: move these functions out of first pass of parser> shape format change, Add op type change, layerNorm fusion type


if __name__ == "__main__":
    main()
