# fmt: on
import dataclasses
import os

import onnx

from onnx.helper import (
    make_opsetid,
)

from OGOAT.src.L1_fusion.py_match.nodes_tensors import OutputTensor
from OGOAT.src.L1_fusion.qdq_tagging import QdqTagging
from OGOAT.src.L1_fusion.graph_partitioning import partition_graph
from OGOAT.src.L1_fusion.static_pm_bin_selection import static_pm_partition_graph
from OGOAT.src.L1_fusion.topo_sort import graph_sort
from OGOAT.src.L1_fusion.L1_utils.ops_definition_utils import OnnxOpsWrapper
from OGOAT.src.L1_fusion.L1_utils.utils import save_model
from OGOAT.src.L1_fusion.py_match.const_eval import replace_const, remove_constants
from OGOAT.src.L1_fusion.py_match.helpers.fusion_configs import (
    FusionArguments,
    FusionConfigs,
)
from OGOAT.src.L1_fusion.py_match.model_dict import ModelDict
from OGOAT.src.L1_fusion.py_match.py_fuse import (
    Fuser,
    get_py_matcher,
)
from OGOAT.src.utils.context import Logger
from OGOAT.src.L1_fusion.L1_utils.utils import save_model
from OGOAT.src.L1_fusion.L1_utils.safe_runner import SafeRunner
from OGOAT.src.L1_fusion.py_match.fusion_frozen import (
    find_dtype_frozen_nodes,
    tag_l1_fusion_frozen,
)


class Fusion_L1:
    """
    Driver for the L1 fusion flow on an ONNX model.

    ``run_fusion`` chains the passes in this order: constant cleanup, int16
    QDQ cleanup, NHWC layout conversion, optional dtype-freeze discovery,
    optional L1 fusion-frozen tagging, pattern-based graph surgery and post
    fusion. Every pass is executed through a ``SafeRunner`` so that an
    exception does not abort the whole flow: the failure is recorded and,
    for every pass except post fusion, the model is reloaded from the last
    checkpoint written by ``_save_model``.

    NOTE(review): within this class only the int16 QDQ cleanup, NHWC
    conversion and post fusion passes write a checkpoint. Constant cleanup
    and the dtype-freeze / fusion-frozen tagging passes only modify the
    model in memory, so a recovery triggered later reloads a checkpoint that
    may not contain their changes.
    """

    def __init__(
        self,
        fusionArgs: FusionArguments,
        GRAPH_SURGERY_SEQ,
        MODULE_OPSET,
        logger: Logger,
        runner: SafeRunner,
    ):
        """
        Args:
            fusionArgs: fusion options (paths, flags, target, ...).
            GRAPH_SURGERY_SEQ: ordered pattern names applied during graph
                surgery. Entries ending with '*' are debug-hooked; the '*'
                is stripped in place by ``FusionDebugHook.from_surgery_seq``.
            MODULE_OPSET: mapping of domain -> opset version that is added
                to the fused model's ``opset_import``.
            logger: logger used for progress messages and warnings.
            runner: SafeRunner that catches and records exceptions raised by
                the passes it runs.
        """
        # check if any patterns are hooked ('pattern_name*'), and removes '*'
        self.dbg_hook = FusionDebugHook.from_surgery_seq(GRAPH_SURGERY_SEQ).create(
            fusionArgs.model_name, fusionArgs.out_dir_path, fusionArgs.external_data
        )

        self.fusionArgs = fusionArgs
        self.GRAPH_SURGERY_SEQ = GRAPH_SURGERY_SEQ
        self.MODULE_OPSET = MODULE_OPSET
        self.onnx_ops_wrapper = OnnxOpsWrapper()
        self.logger = logger

        # Runner that catches and stores any exception raised by the function it runs.
        self._runner = runner

        # Path to the last known valid onnx model. Reloaded to recover from a failure.
        self._last_valid_model_path: str = fusionArgs.model_path

    def _save_model(self, model: onnx.ModelProto) -> None:
        """
        Save ``model`` to the current output path and record that file as
        the last valid checkpoint used for recovery.
        """
        out_model_path = self.fusionArgs.out_model_path
        # Record the checkpoint only after the write succeeded, so a failing
        # save never leaves recovery pointing at a missing or partial file.
        save_model(model, out_model_path, self.fusionArgs.external_data)
        self._last_valid_model_path = out_model_path

    def run_fusion(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """
        Run the complete L1 fusion flow on ``model`` and return the result.

        Errors raised by individual passes are caught by the SafeRunner; if
        any occurred, an error summary is dumped at the end.
        """
        # constant evaluation and removal
        model = self.graph_const_cleanup(model)

        # QDQ int16 cleanup pass (before NHWC conversion)
        model = self.qdq_int16_cleanup_pass(model)

        # nhwc conversion
        model = self.nhwc_layout_conversion(model)

        # collect nodes for dtype freezing
        if not self.fusionArgs.no_dtype_freeze:
            model = self.find_dtype_frozen_nodes(model)

        # L1 fusion frozen
        if FusionConfigs.get_fusion_configs().keep_border_qdq:
            model = self.run_l1_fusion_frozen(model)
        # L1 fusion
        model = self.graph_surgery(model)

        # Post fusion
        self.post_fusion(model)

        if self._runner.errors_occured:
            self.logger.info(
                f"Some errors occured during fusion. For more information you can look at the summary file {self._runner.summary_file_path}"
            )
            self._runner.dump_error_summary()
        return model

    def post_fusion(self, model: onnx.ModelProto):
        """Run post fusion on ``model`` in place, logging a warning on failure."""
        # Post fusion is the last step so there is no need
        # to load the last valid graph from disk to continue fusion
        self._runner.run(self._post_fusion, model)
        if self._runner.has_failed:
            self.logger.warning("Error found when running post fusion")

    def _post_fusion(self, model: onnx.ModelProto):
        """
        Topologically sort the graph, tag QDQ nodes, optionally partition the
        graph, clear the quantization annotations and save the model.
        """
        self.logger.info("Start post fusion:")
        graph_sort(model, 0)  # order 0: topo order, 1: reserve order

        # call qdq tagging function
        qdq_tagging_obj = QdqTagging(model, self.fusionArgs.qdq_optimization)
        qdq_tagging_obj.tag_qdq_nodes()

        # call graph_partitioning function

        if self.fusionArgs.old_fusion_flow:
            self.logger.info("Running legacy partition graph")
            partition_graph(
                model, self.fusionArgs.out_model_path, self.fusionArgs.fast_pm_enable
            )
        elif self.fusionArgs.assign_pmid_before_partition:
            self.logger.info("Running static pm partition graph")
            is_target_procyon = self.fusionArgs.target == "procyon"
            static_pm_partition_graph(
                model,
                self.fusionArgs.out_model_path,
                self.fusionArgs.fast_pm_enable,
                self.fusionArgs.prebuilt_mladf_mha,
                is_target_procyon,
            )
        # NOTE: PM bin allocation is now done per-subgraph in WAIC_runtime.py by default

        # convert int4 inits to int8
        # if fargs.inits_int4_to_int8:
        #    convert_int4_inits_to_int8(model)

        # clean up quantization annotation
        model.graph.quantization_annotation.clear()
        self._save_model(model)

    def nhwc_layout_conversion(self, model: onnx.ModelProto):
        """
        Run the NHWC layout conversion. On failure, record the error and
        return the last valid known graph so fusion can continue.
        """
        self._runner.run(self._nhwc_layout_conversion, model)
        if self._runner.has_failed:
            self.logger.warning(
                f"Error found when running NHWC layout conversion: Recovering from last known valid graph: '{self._last_valid_model_path}'"
            )
            return self.load_last_valid_model()

        return model

    def _nhwc_layout_conversion(self, model: onnx.ModelProto):
        """transpose clean up and tagging"""
        from OGOAT.src.L1_fusion.convert_to_nhwc import NHWCLayoutConverter

        self.logger.info("Begin NHWC conversion")
        self.fusionArgs.out_model_suffix += "_nhwc"
        self._nhwc_sanity_check(model)

        layout_converter = NHWCLayoutConverter(
            model,
            self.fusionArgs.model_name,
            self.fusionArgs.out_dir_path,
            self.fusionArgs.external_data,
            dbg=self.fusionArgs.debug,
        )

        # Each intermediate stage is sanity checked before being saved as the
        # new recovery checkpoint.
        layout_converter.pre_fusion()
        layout_converter.layout_tagging_and_shielding()
        self._nhwc_sanity_check(model)
        self._save_model(model)

        layout_converter.convert_layout_to_nhwc()
        self._nhwc_sanity_check(model)
        self._save_model(model)

    def _nhwc_sanity_check(self, model: onnx.ModelProto) -> None:
        """
        Perform a sanity check on the model after an NHWC conversion step.
        This makes sure the SafeRunner recovery mechanism does not accidentally
        receive a non-working graph from NHWC conversion as a valid graph for
        recovery.
        """
        ModelDict(model, self.onnx_ops_wrapper).sanity_check(self.logger)

    def load_last_valid_model(self):
        """Load and return the last valid checkpoint from disk."""
        return onnx.load_model(
            self._last_valid_model_path,
            load_external_data=self.fusionArgs.external_data,
        )

    def graph_const_cleanup(self, model):
        """
        Run the constant cleanup for the given graph. In case of failure, save
        the error and return the last valid known graph for fusion to continue.
        """
        self._runner.run(self._graph_const_cleanup, model)
        if self._runner.has_failed:
            self.logger.warning(
                f"Constant cleanup failed. Recovering from last known valid graph: {self._last_valid_model_path}"
            )
            return self.load_last_valid_model()
        return model

    def _graph_const_cleanup(self, model):
        """Remove constants and replace constant sub-expressions in place (no checkpoint is written)."""
        self.logger.info("Start constant cleanup:")
        remove_constants(model)
        replace_const(model, self.onnx_ops_wrapper, self.fusionArgs, self.logger)

    def qdq_int16_cleanup_pass(self, model: onnx.ModelProto):
        """
        Run the QDQ int16 cleanup pass. In case of failure, save the error
        and return the last valid known graph for fusion to continue.
        """
        self._runner.run(self._qdq_int16_cleanup_pass, model)
        if self._runner.has_failed:
            self.logger.warning(
                f"QDQ int16 cleanup failed. Recovering from last known valid graph: {self._last_valid_model_path}"
            )
            return self.load_last_valid_model()
        return model

    def _qdq_int16_cleanup_pass(self, model: onnx.ModelProto):
        """
        Remove int16 QDQ pairs when ``fusionArgs.qdq_int16_cleanup`` is set.
        The model is saved as a checkpoint whether or not the pass is enabled.
        """
        from OGOAT.src.L1_fusion.qdq_cleanup import cleanup_int16_qdq

        self.logger.info("Start int16 QDQ cleanup pass for mixed precision models:")

        if self.fusionArgs.qdq_int16_cleanup:
            removed_count = cleanup_int16_qdq(model, self.logger)

            if removed_count > 0:
                self.logger.info(
                    f"Removed {removed_count} int16 QDQ pairs - model will be saved to _mod.onnx"
                )
        else:
            self.logger.info("Int16 QDQ cleanup pass disabled (--qdq_int16_cleanup 0)")

        self._save_model(model)

    def _run_l1_fusion_frozen(self, model: onnx.ModelProto):
        """Tag L1 fusion-frozen nodes in place (no checkpoint is written)."""
        self.logger.info("Start L1 fusion frozen:")
        tag_l1_fusion_frozen(model, self.onnx_ops_wrapper, self.logger)

    def _find_dtype_frozen_nodes(self, model: onnx.ModelProto):
        """Find dtype-frozen nodes in place (no checkpoint is written)."""
        self.logger.info("Start finding dtype frozen nodes:")
        # Resolves to the module-level helper imported from fusion_frozen,
        # not to the method of the same name on this class.
        find_dtype_frozen_nodes(model, self.onnx_ops_wrapper, self.logger)

    def run_graph_surgery_sequence(self, fuser: Fuser, model):
        """
        Run the graph surgery on the given model by applying the sequence of pattern
        requested by the user and modify the graph when they match.
        """
        self.logger.info("Start pattern matching:")

        dbg = self.dbg_hook
        fuser.set_model(model)
        for pattern in self.GRAPH_SURGERY_SEQ:
            if dbg.is_hooked(pattern):
                dbg.before_hook(pattern, model)

            matcher = get_py_matcher(pattern)
            fuser.run_one(matcher)

            if dbg.is_hooked(pattern):
                dbg.after_hook(pattern, model)

    def safe_run_graph_surgery_sequence(self, fuser: Fuser, model):
        """
        Run the graph surgery sequence on a given model.

        If a failure is caught in the process, the pattern that failed is
        disabled by the Fuser instance and the graph surgery is run again
        from the last valid known graph.

        Returns:
            The fused model on success, the unchanged model when the surgery
            sequence is empty, or the last valid known graph reloaded from
            disk when the retry budget is exhausted.
        """
        max_attempts = len(self.GRAPH_SURGERY_SEQ) * fuser.max_failure_count

        for _ in range(max_attempts):
            self._runner.run(self.run_graph_surgery_sequence, fuser, model)
            if not self._runner.has_failed:
                return model

            model = self.load_last_valid_model()
            self.logger.warning(
                f"Error found when running fusion recovering from last known valid graph: '{self._last_valid_model_path}'"
            )
            self.dbg_hook.delete_debug_files(self.logger)

        # With no pattern to apply there is nothing that could have failed:
        # keep the in-memory model instead of reloading an older checkpoint
        # (which would drop changes made since it was written).
        if not self.GRAPH_SURGERY_SEQ:
            return model

        # We stopped because all patterns were failing, return the
        # last valid known graph
        self.logger.warning(
            "Too many failures found in graph surgery stopping it and continuing fusion."
        )
        return self.load_last_valid_model()

    def graph_surgery(self, model):
        """
        Run pattern-based graph surgery and register the ``MODULE_OPSET``
        domains on the fused model's ``opset_import``.

        Returns the fused model (possibly a reloaded checkpoint on failure).
        """
        fuser = Fuser(model, self.onnx_ops_wrapper, self.logger, self._runner)
        fused_model = self.safe_run_graph_surgery_sequence(fuser, model)

        # Add the module opsets. An identical (domain, version) pair that is
        # already imported is skipped so it is not listed twice.
        imported = {(opset.domain, opset.version) for opset in fused_model.opset_import}
        for domain in self.MODULE_OPSET:
            version = self.MODULE_OPSET[domain]
            if (domain, version) in imported:
                continue
            fused_model.opset_import.append(make_opsetid(domain, version))
            imported.add((domain, version))

        # Verify that each domain is imported with a single version.
        opset_dict = {}
        for opset in fused_model.opset_import:
            if opset.domain in opset_dict:
                if opset_dict[opset.domain] != opset.version:
                    self.logger.warning(
                        f"domain={opset.domain} with two different versions: {opset.version} != {opset_dict[opset.domain]}"
                    )
            else:
                opset_dict[opset.domain] = opset.version

        self.fusionArgs.out_model_suffix += "_fused"
        return fused_model

    def run_l1_fusion_frozen(self, model: onnx.ModelProto):
        """
        Run L1 fusion-frozen tagging. On failure, record the error and return
        the last valid known graph.
        """
        self._runner.run(self._run_l1_fusion_frozen, model)
        if self._runner.has_failed:
            self.logger.warning(
                f"Error found when running L1 fusion frozen: Recovering from last known valid graph: '{self._last_valid_model_path}'"
            )
            return self.load_last_valid_model()
        return model

    def find_dtype_frozen_nodes(self, model: onnx.ModelProto):
        """
        Run dtype-frozen node discovery. On failure, record the error and
        return the last valid known graph.
        """
        self._runner.run(self._find_dtype_frozen_nodes, model)
        if self._runner.has_failed:
            self.logger.warning(
                f"Error found when finding dtype frozen nodes: Recovering from last known valid graph: '{self._last_valid_model_path}'"
            )
            return self.load_last_valid_model()
        return model


@dataclasses.dataclass
class FusionDebugHook:
    """
    Save debug snapshots of the model before and after selected ("hooked")
    patterns of the graph surgery sequence, and remember the written files
    so that ``delete_debug_files`` can remove them again.
    """

    # Prefix used for the snapshot file names.
    model_name: str
    # Directory the snapshot files are written into.
    out_dir: str
    # Forwarded to ``save_model`` as ``external_data``.
    external_data: bool
    # Pattern names (trailing '*' already stripped) that trigger snapshots.
    hooked_patterns: list[str]

    # Paths of the snapshot files written since the last deletion.
    debug_files_path: list[str] = dataclasses.field(default_factory=list)

    @staticmethod
    def from_surgery_seq(
        surgery_seq,
    ):
        """
        Strip the debug marker from hooked patterns and build a hook factory.

        IF surgery_seq[i] == 'pattern_name*'
        THEN
            surgery_seq[i] := 'pattern_name'
            hooked_patterns += ['pattern_name']

        Args:
            surgery_seq: mutable sequence of pattern names; hooked entries
                are rewritten in place without their trailing '*'.

        Returns:
            * a factory whose ``create(model_name, out_dir, external_data)``
              returns a FusionDebugHook with the collected hooked_patterns
            * modifies hooked entries in SURGERY_SEQ
        """
        import types

        hooked_patterns_ = []
        for i, pattern in enumerate(surgery_seq):
            # endswith() also handles an empty pattern name, where
            # pattern[-1] would raise IndexError.
            if pattern.endswith("*"):  # hooked
                pattern = pattern[:-1]
                hooked_patterns_.append(pattern)
                surgery_seq[i] = pattern

        def create(name_, out_dir_, external_data_):
            # Give every hook its own list so instances never share state.
            return FusionDebugHook(
                name_, out_dir_, external_data_, list(hooked_patterns_)
            )

        return types.SimpleNamespace(create=create)

    def delete_debug_files(self, logger: "Logger") -> None:
        """
        Remove every recorded snapshot file that still exists on disk, then
        forget the recorded paths.
        """
        for file_path in self.debug_files_path:
            if not os.path.exists(file_path):
                continue

            logger.info(f"Removing debug file {file_path}")
            os.remove(file_path)
        # Every recorded path has been handled; clearing keeps the list from
        # accumulating stale/duplicate entries across retries.
        self.debug_files_path.clear()

    def after_hook(self, pattern, model):
        """Save a snapshot of ``model`` taken after ``pattern`` was applied."""
        self.save_dbg_model(f"{self.model_name}_after_{pattern}", model)

    def before_hook(self, pattern, model):
        """Save a snapshot of ``model`` taken before ``pattern`` is applied."""
        self.save_dbg_model(f"{self.model_name}_before_{pattern}", model)

    def is_hooked(self, pattern):
        """Return True if snapshots are requested for ``pattern``."""
        return pattern in self.hooked_patterns

    def save_dbg_model(self, file_name, model):
        """Save the model to ``<out_dir>/<file_name>.onnx`` and record the path."""
        # Uncomment to add debug info to the saved model
        # import copy
        # model = copy.deepcopy(model)

        file_path = os.path.join(self.out_dir, f"{file_name}.onnx")
        self.debug_files_path.append(file_path)
        save_model(model, file_path, external_data=self.external_data)
