# fmt: on
import onnx

from OGOAT.src.L1_fusion.L1_utils.ops_definition_utils import OnnxOpsWrapper
from OGOAT.src.L1_fusion.L1_utils.safe_runner import SafeRunner
from OGOAT.src.L1_fusion.py_match.fusion_ops import yaml_to_py
from OGOAT.src.L1_fusion.py_match.model_dict import ModelDict
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher, MatcherOrCategory
from OGOAT.src.L1_fusion.py_match.skip import WalkCfgSkipNoop
from OGOAT.src.utils.context import Logger


class Fuser:
    """
    Run python pattern matchers over an ONNX model through a SafeRunner.

    Each matcher class runs at most once until ``_matchers_done`` is cleared,
    which happens whenever a failure is caught. Failures are counted per
    matcher class. A class that fails more than ``max_failure_count`` times is
    added to ``disabled_matchers`` and skipped from then on.
    """

    def __init__(
        self,
        model: onnx.ModelProto,
        onnx_ops: OnnxOpsWrapper,
        logger: Logger,
        runner: SafeRunner,
    ) -> None:
        self._model = model
        self._onnx_ops = onnx_ops
        self._model_dict = ModelDict(self._model, self._onnx_ops)
        # Matcher classes that completed a run since the last failure reset
        self._matchers_done: set[type] = set()
        # Sum of matches reported by matcher runs that passed all checks
        self._match_cnt = 0
        self._logger = logger
        self._runner = runner

        # Set of matcher classes which were disabled due to failures when
        # running them
        self.disabled_matchers: set[type] = set()

        # Number of times a failure was caught per matcher class
        self.matcher_failures_count: dict[type, int] = dict()

        # Maximum number of failures allowed per matcher before disabling it
        self.max_failure_count = 5

    def set_model(self, model: onnx.ModelProto) -> None:
        """
        Replace the model being processed and rebuild its ModelDict.

        The set of completed matchers, the failure counters and the match
        count are left unchanged.
        """
        self._model = model
        self._model_dict = ModelDict(self._model, self._onnx_ops)

    def check_and_handle_error(self, matcher: Matcher) -> None:
        """
        Check whether the runner caught an error and, if so, handle it.

        On failure:
        - increment the failure count for the matcher's class;
        - once the count exceeds ``max_failure_count``, disable the class and
          warn the user;
        - clear the set of matchers that have run and reset the logger
          indentation;
        - call the runner's ``raise_error()``, which (per the runner) raises a
          SafeRunnerError to tell the caller that an error was caught.

        Does nothing if the runner has not failed.
        """
        if not self._runner.has_failed:
            return

        # Count the number of failures for this matcher class
        failures_count = self.matcher_failures_count.get(type(matcher), 0) + 1
        self.matcher_failures_count[type(matcher)] = failures_count

        # If max failure exceeded disable the matcher entirely
        if failures_count > self.max_failure_count:
            self.disabled_matchers.add(type(matcher))
            self._logger.warning(
                f"Maximum number of failures of {self.max_failure_count} exceeded for pattern '{matcher.get_matcher_class_name()}' disabling it entirely."
            )

        # Reset the list of matchers that have run
        self._matchers_done.clear()

        # Reset logger indentation
        self._logger.reset_indentation()

        # Propagate error on to the caller
        self._runner.raise_error()

    @property
    def match_cnt(self) -> int:
        """Total number of matches found by successful matcher runs."""
        return self._match_cnt

    def is_matcher_disabled(self, matcher: Matcher) -> bool:
        """
        Return True if ``matcher`` should be skipped.

        A matcher is skipped when its class has already run since the last
        reset, or when its class was disabled after repeated failures.
        """
        matcher_type = type(matcher)
        return (
            matcher_type in self._matchers_done
            or matcher_type in self.disabled_matchers
        )

    def run_with_dependencies(self, matcher: Matcher) -> None:
        """
        Run every dependency of ``matcher``, then ``matcher`` itself.

        Skips everything if ``matcher`` is already done or disabled. A failure
        in a dependency propagates from ``check_and_handle_error``, so
        ``matcher`` is not run in that case.

        NOTE(review): a matcher is only marked done after it has run, so a
        dependency cycle would recurse without bound.
        """
        if self.is_matcher_disabled(matcher):
            return

        for dependency in matcher.dependencies:
            self.run_one(dependency)
        self.run_matcher(matcher)

    def run_matcher(self, matcher: Matcher) -> None:
        """
        Run a single matcher on the current model and record its results.

        Prints a one-line progress report. After a successful run that found at
        least one match, runs the model sanity check. Failures caught by the
        runner are handled by ``check_and_handle_error``, which raises.
        """
        matcher_name = matcher.get_matcher_class_name()
        # The progress line is completed either by "found N instances" or, on
        # failure, by "failed" (see _check_run).
        print("{:<40}python, ".format(f"starting: {matcher_name}, "), end="")
        match_cnt = self._runner.run(
            matcher.run, self._model_dict, WalkCfgSkipNoop(), self._runner, self._logger
        )
        self._check_run(matcher)

        # the sanity check is rather expensive in terms of execution time (about
        # 0.1s for PSR model), calling it once a pattern matched is acceptable
        if match_cnt > 0:
            self._runner.run(self._model_dict.sanity_check, self._logger)
            self._check_run(matcher)

        print(f"found {match_cnt} instances")
        self._matchers_done.add(type(matcher))
        self._match_cnt += match_cnt

    def _check_run(self, matcher: Matcher) -> None:
        """Close the pending progress line on failure, then handle the error."""
        if self._runner.has_failed:
            # The "starting: ..." line was printed with end=""; terminate it so
            # warnings and later progress output do not get appended to it.
            print("failed")
        self.check_and_handle_error(matcher)

    def run_one(self, matcher_or_category: MatcherOrCategory) -> None:
        """Run every matcher yielded by ``matcher_or_category.get_matchers()``."""
        for matcher in matcher_or_category.get_matchers():
            self.run_with_dependencies(matcher)


def get_py_matcher(pattern_name: str) -> Matcher:
    """
    Return the python matcher registered under ``pattern_name`` in ``yaml_to_py``.

    The lookup is a plain subscript, so an unknown name raises whatever the
    ``yaml_to_py`` container raises for a missing key.
    """
    registered_matcher = yaml_to_py[pattern_name]
    return registered_matcher
