import os
import ctypes
import json
import argparse
import logging
import shutil
import onnx
import graph_surgery
import add_onnx_tensor_shapes
from io import BytesIO
from pathlib import Path

from OGOAT.src.L1_fusion.L1_utils.model_IR_utils import (
    get_unique_nodes_wrt_shapes_dtypes_attrs,
)
from OGOAT.src.utils.context import Logger, Context
from OGOAT.src.L1_fusion.L1_utils.safe_runner import SafeRunner
from OGOAT.src.L1_fusion.L1_utils.utils import remove_additional_attributes, save_model
from OGOAT.src.L1_fusion.parse_onnx_model import ParseOnnxModel
from OGOAT.src.L1_fusion.L1_utils.utils import find_invalid_graph_output

# Default YAML file with shape parameters and graph input values. It lives next
# to this module and serves as the default of --default_shape_params_values.
_MODULE_DIR = os.path.dirname(__file__)
default_shape_params_values = os.path.join(
    _MODULE_DIR, "default_shape_params_values.yml"
)


def uses_external_data(file_name: str) -> tuple[bool, onnx.ModelProto]:
    """
    Check whether the ONNX model at ``file_name`` stores initializer data externally.

    The model is loaded with ``load_external_data=False``, so the model proto is
    read but external tensor payloads are not loaded.

    Returns:
        A ``(has_external_data, model)`` tuple. ``has_external_data`` is True if at
        least one initializer has ``data_location == onnx.TensorProto.EXTERNAL``.
        ``model`` is the loaded ``ModelProto``, with external tensors not loaded.
    """
    loaded = onnx.load(file_name, load_external_data=False)
    external_flag = onnx.TensorProto.EXTERNAL

    has_external_data = False
    for initializer in loaded.graph.initializer:
        if initializer.data_location == external_flag:
            has_external_data = True
            break

    return has_external_data, loaded


def get_external_data_locations(file_name: str) -> set[str]:
    """
    Collect the external data file names referenced by a model's initializers.

    The model is loaded with ``load_external_data=False``. Only initializers with
    ``data_location == onnx.TensorProto.EXTERNAL`` are considered. For each of
    them, the value of every ``external_data`` entry whose key is ``"location"``
    is collected.

    Returns:
        The set of ``location`` values, exactly as stored in the model.
    """
    model = onnx.load(file_name, load_external_data=False)
    return {
        entry.value
        for tensor in model.graph.initializer
        if tensor.data_location == onnx.TensorProto.EXTERNAL
        for entry in tensor.external_data
        if entry.key == "location"
    }


def as_protobuf(model) -> bytes:
    """Serialize ``model`` into an in-memory byte string with ``onnx.save_model``."""
    with BytesIO() as stream:
        onnx.save_model(model, stream)
        serialized = stream.getvalue()
    return serialized


def vaiml_compile(
    model_path: str,
    output_dir: str,
    vaiml_shared_lib_path: str,
    optimize_level: int,
    prebuilt_mladf_mha: bool,
):
    """
    Compile an ONNX model with the VAIML C++ front end through its C interface.

    The model is loaded without its external tensor data and serialized to
    protobuf bytes. It is then passed to ``vaiml_compile_v3`` together with:

    - the model's directory, so the library can load external data itself;
    - a JSON config that selects the DMAC recipe with a ``frontend-mlir``
      output type.

    Args:
        model_path: Path to the ONNX model file.
        output_dir: Output directory handed to the library.
        vaiml_shared_lib_path: Path to the VAIML shared library to load.
        optimize_level: Value of ``optimize_level`` in the JSON config.
        prebuilt_mladf_mha: Value of ``prebuilt_mladf_mha`` in the JSON config.

    Raises:
        RuntimeError: If ``vaiml_shared_lib_path`` does not exist.
    """
    if not os.path.exists(vaiml_shared_lib_path):
        raise RuntimeError(
            f"Could not find vaiml library at the provided path {vaiml_shared_lib_path}"
        )

    # Load the shared library and use it via the C interface
    flexml_lib = ctypes.cdll.LoadLibrary(vaiml_shared_lib_path)

    # Serialize the model without external data: vaiml loads it on its own.
    model = onnx.load_model(model_path, load_external_data=False)
    protobuf_model = as_protobuf(model)
    onnx_protobuf_size = len(protobuf_model)

    # Use an absolute directory. For a bare file name ("model.onnx"),
    # os.path.dirname() returns "" instead of the current directory.
    onnx_external_data_dir = os.path.dirname(os.path.abspath(model_path)).encode(
        encoding="utf-8"
    )
    output_dir_bytes = output_dir.encode(encoding="utf-8")

    # The config dictionary selects the DMAC recipe and asks the front end to
    # produce a frontend-mlir result only.
    config = json.dumps(
        {
            "recipes": "DMAC",
            "keep_outputs": True,
            "output_type": "frontend-mlir",
            "optimize_level": optimize_level,
            "prebuilt_mladf_mha": prebuilt_mladf_mha,
        }
    ).encode(encoding="utf-8")

    # Declare the C signatures so ctypes converts the arguments correctly.
    flexml_lib.vaiml_context_create.restype = ctypes.POINTER(ctypes.c_void_p)
    flexml_lib.vaiml_compile_v3.argtypes = [
        ctypes.c_void_p,  # context
        ctypes.c_char_p,  # protobuf_model
        ctypes.c_int,  # onnx_protobuf_size
        ctypes.c_char_p,  # onnx_external_data_dir
        ctypes.c_char_p,  # output_dir
        ctypes.c_char_p,  # config
        ctypes.c_int,  # ai_analyzer_profiling
        ctypes.c_int,  # ai_analyzer_visualization
        ctypes.c_char_p,  # logging_level
    ]

    # Create the context instance needed by the vaiml interface
    context = flexml_lib.vaiml_context_create()
    try:
        # NOTE(review): the return value of vaiml_compile_v3 is ignored; confirm
        # whether it reports a status that should be checked.
        flexml_lib.vaiml_compile_v3(
            context,
            protobuf_model,
            onnx_protobuf_size,
            onnx_external_data_dir,
            output_dir_bytes,
            config,
            False,  # ai_analyzer_profiling
            False,  # ai_analyzer_visualization
            b"info",  # logging_level
        )
    finally:
        # Free the context even if the call raises, so the native context is
        # not leaked. One example is ctypes.ArgumentError from argument
        # conversion.
        flexml_lib.vaiml_context_destroy(context)


def run_VAIML_fusion(args):
    """
    Run the VAIML C++ front end and rename its unique-nodes JSON for the Tiler.

    Args:
        args: Namespace-like object providing ``cpp_fe`` (path to the VAIML
            shared library), ``model_path``, ``output_dir``,
            ``optimization_level`` and ``prebuilt_mladf_mha``.

    Raises:
        ValueError: If ``args.cpp_fe`` is empty.
        RuntimeError: If the VAIML library path does not exist (raised by
            ``vaiml_compile``), or if ``unique_nodes.json`` was not produced in
            ``args.output_dir``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not args.cpp_fe:
        raise ValueError("path to the vaiml shared library is required")

    print(
        "VAIML front end selected from TA, start compilation from flexmlcompile interface"
    )

    vaiml_compile(
        args.model_path,
        args.output_dir,
        args.cpp_fe,
        int(args.optimization_level),
        bool(args.prebuilt_mladf_mha),
    )

    # FIXME: The cpp F.E currently produces an empty directory named '0'. The
    # M.E picks it up, which causes a crash. For now we remove it if it exists
    # and is empty. A ticket is open to fix the F.E so it does not generate it
    # when 'recipe="DMAC"'.
    empty_dir = os.path.join(args.output_dir, "0")
    if os.path.isdir(empty_dir) and not os.listdir(empty_dir):
        # os.rmdir removes only this directory. os.removedirs would also try to
        # remove every parent directory that becomes empty afterwards.
        os.rmdir(empty_dir)

    # Find the json file generated by the vaiml compile command and rename it
    # so that the Tiler picks it up.
    json_result_path = os.path.join(args.output_dir, "unique_nodes.json")
    if not os.path.exists(json_result_path):
        raise RuntimeError(
            f"Could not find the unique nodes json in path {json_result_path}"
        )

    model_name_without_ext = os.path.splitext(os.path.basename(args.model_path))[0]
    expected_json_result_name = (
        model_name_without_ext + "_mod_nhwc_fused_IR_unique_nodes.json"
    )
    expected_json_result_path = os.path.join(args.output_dir, expected_json_result_name)

    print(f"Renaming {json_result_path} into {expected_json_result_path}")
    # os.replace overwrites an existing destination on every platform, so
    # reruns into the same output_dir work. os.rename fails on Windows when
    # the destination already exists.
    os.replace(json_result_path, expected_json_result_path)


def main(args):
    """
    Run the L1 fusion flow on a single ONNX model.

    Args:
        args: Option dictionary, e.g. ``vars()`` of the CLI namespace.
            - Reads ``output_dir``, ``model_path``, ``skip_step`` and ``debug``.
            - Forwards the whole dict to shape inference, graph surgery and
              ``ParseOnnxModel``.
            - Mutates it in place: ``args["load_data"]`` is set to whether
              the model's initializers use external data.

    Files handled in ``output_dir`` (``<name>`` is the model file's base name):
        - ``<name>_mod.onnx``: input of graph surgery. It is copied here when
          shape inference fails or is skipped.
        - ``<name>_mod_nhwc_fused_IR.json``: path given to
          ``parse_fused_model_to_ir`` and ``get_unique_nodes_wrt_shapes_dtypes_attrs``.
        - ``<name>_mod_nhwc_fused.onnx``: final fused model.

    Raises:
        RuntimeError: If some graph outputs are not produced by any node.
        FileNotFoundError: If shape inference is skipped (``"1.0"`` in
            ``skip_step``) and ``<name>_mod.onnx`` is missing next to the input
            model. With external data, ``<name>_mod.onnx.data`` must also exist.
    """
    output_dir = args.get("output_dir")
    model_path = args.get("model_path")
    model_name_without_ext = os.path.splitext(os.path.basename(model_path))[0]

    # Paths are built from the base name instead of with
    # str.replace(".onnx", ...). That call would also rewrite any ".onnx"
    # appearing in directory names.
    mod_model_name = f"{model_name_without_ext}_mod.onnx"
    # Pre-computed "_mod" model (and its external data file) next to the input.
    src_mod_model_path = os.path.join(os.path.dirname(model_path), mod_model_name)
    src_mod_data_path = src_mod_model_path + ".data"

    use_external_data, model = uses_external_data(model_path)
    invalid_graph_output = find_invalid_graph_output(model)
    if invalid_graph_output:
        raise RuntimeError(
            f"model {model_path} is not valid because some graph global output {invalid_graph_output} are not produced by any node."
        )

    # Make sure the destination is a directory before copying into it.
    # shutil.copy to a non-existent path creates a *file* with that name.
    os.makedirs(output_dir, exist_ok=True)

    args["load_data"] = use_external_data
    if args["load_data"]:
        for location in get_external_data_locations(model_path):
            shutil.copy(Path(model_path).parent / location, output_dir)

    context = Context(output_dir=output_dir, debug=args.get("debug", False))
    logger = Logger(name="L1_fusion", context=context)
    logger.info("Start L1 fusion stage")

    runner = SafeRunner(
        logger=logger,
        output_dir_path=output_dir,
        summary_file_name="fusion_error_summary.txt",
    )

    # FIXME: ideally the shape inference should run in the L1 Fusion class,
    # just before the const cleanup
    # infer shapes
    if "1.0" not in args.get("skip_step", []):
        runner.run(add_onnx_tensor_shapes.add_model_info, args)
        if runner.has_failed:
            logger.warning(
                "Error in the shape inference. Continuing fusion without it."
            )
            if os.path.exists(src_mod_model_path):
                shutil.copy(src_mod_model_path, output_dir)
                # Same pairing as the skip branch below: a "_mod.onnx" model
                # comes with its "_mod.onnx.data" file when it exists.
                if args.get("load_data") and os.path.exists(src_mod_data_path):
                    shutil.copy(src_mod_data_path, output_dir)
            else:
                shutil.copy(model_path, os.path.join(output_dir, mod_model_name))
    else:
        logger.info(
            "Skipping L1 shape inference, appending _mod.onnx to model path for L1 fusion"
        )
        if not os.path.exists(src_mod_model_path):
            raise FileNotFoundError(
                f"ONNX model file '{src_mod_model_path}' does not exist."
            )
        shutil.copy(src_mod_model_path, output_dir)
        if args.get("load_data"):
            if not os.path.exists(src_mod_data_path):
                raise FileNotFoundError(
                    f"Data file '{src_mod_data_path}' does not exist."
                )
            shutil.copy(src_mod_data_path, output_dir)

    # graph surgery
    gs_args = args.copy()
    gs_args["model_name"] = model_name_without_ext
    gs_args["model_path"] = os.path.join(output_dir, mod_model_name)
    fused_model = graph_surgery.main(logger, runner, **gs_args)

    fused_base_name = f"{model_name_without_ext}_mod_nhwc_fused"
    fused_model_path = os.path.join(output_dir, fused_base_name + ".onnx")
    model_IR_path = os.path.join(output_dir, fused_base_name + "_IR.json")

    # Parse the fused graph returned by graph surgery (reused, not reloaded).
    ps_args = args.copy()
    parse_onnx_model = ParseOnnxModel(ps_args, logger, fused_model, runner)
    runner.run(parse_onnx_model.parse_fused_model_to_ir, model_IR_path)
    if runner.errors_occured:
        logger.info(
            f"Some errors occured during fusion. For more information you can look at the summary file {runner.summary_file_path}"
        )
        runner.dump_error_summary()

    # generate unique IR json
    get_unique_nodes_wrt_shapes_dtypes_attrs(model_IR_path)

    # Drop "pm_id" attributes. Iterate over a copy because matching attributes
    # are removed from the repeated field in place.
    for node in fused_model.graph.node:
        for attr in node.attribute[:]:
            if attr.name == "pm_id":
                node.attribute.remove(attr)

    if os.path.exists(fused_model_path):
        os.remove(fused_model_path)
    remove_additional_attributes(fused_model)
    # save final model
    save_model(fused_model, fused_model_path, args["load_data"])


if __name__ == "__main__":
    # Command-line entry point. main() also needs "output_dir", so it is
    # exposed here as well.
    parser = argparse.ArgumentParser(
        description="L1 fusion",
        usage='use "%(prog)s --help" for more info',
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # required knobs
    parser.add_argument(
        "-mp",
        "--model_path",
        help="path to onnx model and output destination.Required Field",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        help="directory where the L1 fusion outputs are written. Required Field",
    )
    # debug/profile knobs
    parser.add_argument(
        "-skip",
        "--skip_step",
        nargs="+",
        default=[],
        action="extend",
        help="Skip WAIC step, none, 1.0: skip L1 shape inference, 2+: skip everything after L1 fusion",
    )
    parser.add_argument(
        "-d",
        "--debug",
        help="Print lots of debugging statements",
        action="store_const",
        dest="loglevel",
        const=logging.DEBUG,
    )
    parser.add_argument(
        "-df",
        "--debug_file_name",
        help="Debug log file name",
        default="dbg_log.txt",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        choices=["debug", "info", "error"],
        help="Verbosity for debug logs",
        default="debug",
    )
    parser.add_argument(
        "-in",
        "--input_names",
        required=False,
        nargs="+",
        help="Names of inputs if model has dynamic shape inputs. Optional Field. Default value = ''",
        default="",
    )
    parser.add_argument(
        "-dims",
        "--input_dims",
        required=False,
        nargs="+",
        help="Shapes of inputs if model has dynamic shape inputs. Optional Field. Default value = ''",
        default="",
    )
    parser.add_argument(
        "-method",
        "--shape_infer_method",
        help="method to get tensor shapes from the model. Optional Field. Default value = 'onnx_shape_infer'",
        default="onnx_shape_infer",
    )
    parser.add_argument(
        "-new_dtypes",
        "--assign_new_dtypes",
        help="flag to set when assigning datatypes based on a heuristic for AIE. Optional Field. Default value = 0",
        default="0",
    )
    parser.add_argument(
        "-act_dtype_low",
        "--low_precision_act_dtype",
        help="low precision activation dtype for tensor datatype assignment. Optional Field. Default value = 'uint16'",
        default="uint16",
    )
    parser.add_argument(
        "-act_dtype_high",
        "--high_precision_act_dtype",
        help="high precision activation dtype for tensor datatype assignment. Optional Field. Default value = 'uint16'",
        default="uint16",
    )
    parser.add_argument(
        "-wgt_dtype_low",
        "--low_precision_wgt_dtype",
        help="low precision weights dtype for tensor datatype assignment. Optional Field. Default value = 'uint8'",
        default="uint8",
    )
    parser.add_argument(
        "-wgt_dtype_high",
        "--high_precision_wgt_dtype",
        help="high precision weights dtype for tensor datatype assignment. Optional Field. Default value = 'uint16'",
        default="uint16",
    )
    parser.add_argument(
        "-shape_params",
        "--in_shape_params",
        required=False,
        type=str,
        help="Dynamic shape parameters for inputs as a JSON string. Optional Field. Default value = '{}'",
        default="{}",
    )
    parser.add_argument(
        "--fixed_input_values",
        required=False,
        type=str,
        help="Fixed input values to the neural network. JSON syntax: input name -> value. Optional Field. Default value = '{}'",
        default="{}",
    )
    parser.add_argument(
        "--default_shape_params_values",
        required=False,
        type=str,
        help=f"YML file specifying default shape parameters and graph input values. Default: {default_shape_params_values}",
        default=default_shape_params_values,
    )
    parser.add_argument(
        "--qdq_optimization",
        type=int,
        choices=[0, 1],
        default=0,
        help="Enable QDQ optimization at end of L1 fusion. Default is 0 (disabled).",
    )
    parser.add_argument(
        "--qdq_int16_cleanup",
        type=int,
        choices=[0, 1],
        default=1,
        help="Enable int16 QDQ cleanup pass before NHWC conversion (AIESW-18212). Removes int16 QDQ pairs adjacent to int8 QDQ. Default is 1 (enabled).",
    )

    args = parser.parse_args()
    if not args.model_path:
        parser.error(
            "Please pass path/to/onnx/model using -mp or --model_path flags.\n"
            f"python3 {parser.prog} --help\n\t\t\tfor further info."
        )
    # main() joins paths onto args["output_dir"]; without it, it fails with a
    # TypeError, so report a usage error instead.
    if not args.output_dir:
        parser.error(
            "Please pass the output directory using -o or --output_dir flags.\n"
            f"python3 {parser.prog} --help\n\t\t\tfor further info."
        )
    logging.basicConfig(level=args.loglevel)
    logging.debug("Debug mode is enabled!")
    # Run without profiling
    main(vars(args))
