# fmt: on
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional
from typing_extensions import Self

import numpy as np
import onnx

from OGOAT.src.L1_fusion.kernel_metadata_loader import (
    KernelMetadataLoader,
)

# from OGOAT.src.L1_fusion.py_match.checkers import AttrValue
from OGOAT.src.L1_fusion.py_match.model_dict import ModelDict
from OGOAT.src.L1_fusion.py_match.helpers.common_type import (
    NumpyDType,
    OnnxDType,
    TensorShape,
)
from OGOAT.src.utils.context import Logger
from OGOAT.src.L1_fusion.L1_utils.safe_runner import SafeRunner

# Name of the node attribute that records which matcher produced a node
# (read back by Node.get_matcher_name()).
matcher_name_attribute = "matcher_name"
# Separator between the segments of an element path such as "a.b.c"
# (used by Element.__call__ and when composing Nowhere/error paths).
path_sep = "."


class MatcherError(ValueError):
    """Common base class of the matcher exceptions defined in this module."""


class ExpectedNowhereButFound(MatcherError):
    """Raised when an element exists at a path that was required to be empty."""

    def __init__(self, path: str) -> None:
        message = f"Expected {path} to be not found, but it was."
        super().__init__(message)


class GraphInputNotFound(MatcherError):
    """Raised when a graph input was required at a path but is not there."""

    def __init__(self, path: str) -> None:
        message = f"GraphInput {path} not found."
        super().__init__(message)


class GraphOutputNotFound(MatcherError):
    """Raised when a graph output was required at a path but is not there."""

    def __init__(self, path: str) -> None:
        message = f"GraphOutput {path} not found."
        super().__init__(message)


class InitializerNotFound(MatcherError):
    """Raised when an initializer was required at a path but is not there."""

    def __init__(self, path: str) -> None:
        message = f"Initializer {path} not found."
        super().__init__(message)


class NodeNotFound(MatcherError):
    """Raised when a node was required at a path but is not there."""

    def __init__(self, path: str) -> None:
        message = f"Node {path} not found."
        super().__init__(message)


class TensorNotFound(MatcherError):
    """Raised when a tensor was required at a path but is not there."""

    def __init__(self, path: str) -> None:
        message = f"Tensor {path} not found."
        super().__init__(message)


class ShapeMismatchError(MatcherError):
    """Matcher error carrying a shape-related message."""

    def __init__(self, msg: str) -> None:
        message = f"Shape error: {msg}"
        super().__init__(message)


class NoMatch(MatcherError):
    """Raised when a checker or matcher does not match an element."""

    def __init__(self, msg: str) -> None:
        message = f"No match: {msg}"
        super().__init__(message)


class Element(ABC):

    def __init__(self, model_dict: ModelDict, walk_cfg: "WalkCfgBase") -> None:
        self._model_dict = model_dict
        self._walk_cfg = walk_cfg

    def __call__(self, path: str) -> "Element":
        """
        element("a.b.c") is syntactic sugar for
        element.get_connection("a").get_connection("b").get_connection("c")
        """
        current = self
        for name in path.split(path_sep):
            current = current.get_connection(name)
        return current

    def __eq__(self, other) -> bool:
        """
        Two `Element` objects are equal if and only if:
            - both of them are instances of the same class.
            - both are referencing same `ModelDict` object
            - both have the same name
        """
        return (
            # self and other are from same class hierarchy:
            (isinstance(other, type(self)) and isinstance(self, type(other)))
            # both are referencing same ModelDict object:
            and id(self._model_dict) == id(other._model_dict)
            # both have the same name:
            and self.get_name() == other.get_name()
        )

    def __hash__(self):
        """
        Hashed values:
            - name of the class at the top of the class hierarchy this object belongs to (usually it is this object's class)
            - id of referenced ModelDict object
            - name of this object
        """
        hashable = (type(self).__name__, id(self._model_dict), self.get_name())
        return hash(hashable)

    def __ne__(self, rhs) -> bool:
        return not self.__eq__(rhs)

    def __repr__(self) -> str:
        return f"Element()"

    def check(self, checker: "Checker") -> bool:
        """
        Check if the element passes the check implemented by the checker.
        Return the boolean result of the check.
        """
        return checker.check(self)

    def check_one(self, checker: "Checker") -> bool:
        """
        If there are multiple options for the element to check, check if exactly
        one option satisfies the checker.
        Multiple options can only happen for tensors.
        It is important to check that exactly one option satisfies the cheker in
        order to avoid ambiguities in pattern matching.
        """
        # The base class implementation checks the one and only option.
        return self.check(checker)

    def match(self, *matchers: "Matcher") -> "Matcher":
        """
        Try to match each matcher in matchers.
        Stop after the first matcher that matched and return it.
        Raise NoMatch if no matcher matched.
        """
        msgs: list[str] = []
        for matcher in matchers:
            try:
                matcher.do_match(self)
                return matcher
            except MatcherError as m_e:
                msgs.append(str(m_e))
        raise NoMatch("Submatcher(s) did not match: " + "; ".join(msgs))

    def match_opt(self, *matchers: "Matcher") -> Optional["Matcher"]:
        """
        Try to match each matcher in matchers.
        Stop after the first matcher that matched and return it.
        Return None if no matcher matched.
        """
        try:
            return self.match(matchers)
        except MatcherError:
            return None

    def require(self, checker: "Checker") -> "Element":
        """
        Require the element to pass the check implemented by the checker.
        Return self if the check passes.
        Raise NoMatch if the check fails.
        """
        if not checker.check(self):
            raise NoMatch(f"{self} does not match {checker}")
        return self

    def require_one(self, checker: "Checker") -> "Element":
        """
        If there are multiple options for the element to check, require that
        exactly one option satisfies the checker.
        Multiple options can only happen for tensors.
        It is important to check that exactly one option satisfies the cheker in
        order to avoid ambiguities in pattern matching.
        """
        # The base class implementation requires the check for the one and only
        # option.
        return self.require(checker)

    def get_model_activation_dtype(self) -> Optional[str]:
        return self._model_dict._activation_type

    def get_model_activation_dtype_sorted_list(self) -> Optional[str]:
        return self._model_dict._activation_dtype_sorted_list

    def get_connection(self, name: str) -> "Element":
        return Nowhere(
            self._model_dict,
            self._walk_cfg,
            f"{self.get_name()}{path_sep}{name}",
        )

    def get_inputs(self) -> "Inputs":
        return Inputs(
            self._model_dict,
            self._walk_cfg,
            f"{self.get_name()}{path_sep}inputs",
            [],
        )

    def get_inputs_dict(self) -> dict[str, "InputTensor"]:
        """
        Get inputs of the node as dict suitable for Matcher.add_node(inputs=).
        """
        return {
            name: input_
            for name, input_ in zip(self.get_schema_input_names(), self.get_inputs())
        }

    def get_input_dtype(self, idx: int) -> str:
        return self.get_inputs()[idx].require_tensor().get_dtype()

    @abstractmethod
    def get_name(self) -> str:
        pass

    def get_non_tensor(self) -> "Element":
        """
        If at a tensor, walk the graph to the next element that is not a tensor.
        If already at an element that is not a tensor, do nothing.
        """
        return self

    def get_outputs(self) -> "Outputs":
        return Outputs(
            self._model_dict,
            self._walk_cfg,
            f"{self.get_name()}{path_sep}outputs",
            [],
        )

    def get_outputs_dict(self) -> dict[str, "OutputTensor"]:
        """
        Get outputs of the node as dict suitable for Matcher.add_node(output=).
        """
        return {
            name: output
            for name, output in zip(self.get_schema_output_names(), self.get_outputs())
        }

    def get_output_dtype(self, idx: int) -> str:
        return self.get_outputs()[idx].require_tensor().get_dtype()

    def get_schema_input_names(self) -> list[str]:
        return []

    def get_schema_output_names(self) -> list[str]:
        return []

    def check_graph_input(self) -> bool:
        return False

    def check_graph_output(self) -> bool:
        return False

    def check_input_tensor(self) -> bool:
        return False

    def check_output_tensor(self) -> bool:
        return False

    def check_initializer(self) -> bool:
        return False

    def check_node(self) -> bool:
        return False

    def check_nowhere(self) -> bool:
        return False

    def check_tensor(self) -> bool:
        return False

    def require_graph_input(self) -> "GraphInput":
        raise GraphInputNotFound(f"{self.get_name()}{path_sep}graph_input")

    def require_graph_output(self) -> "GraphOutput":
        raise GraphOutputNotFound(f"{self.get_name()}{path_sep}graph_output")

    def require_initializer(self) -> "Initializer":
        raise InitializerNotFound(f"{self.get_name()}{path_sep}initializer")

    def require_node(self) -> "Node":
        raise NodeNotFound(f"{self.get_name()}{path_sep}node")

    def require_nowhere(self) -> "Nowhere":
        raise ExpectedNowhereButFound(self.get_name())

    def require_tensor(self) -> "Tensor":
        raise TensorNotFound(f"{self.get_name()}{path_sep}tensor")

    @abstractmethod
    def shallow_clone(self) -> Self:
        """
        Return a shallow clone of the element (i.e. a cloned python object, but
        pointing to the same node/tensor/initializer in the ONNX graph).
        """
        pass

    def skip(self) -> Self:
        """
        Skip nodes according to current walk config for input and ouput tensors
        (overriden in InputTensor and OutputTensor classes).
        Do not skip anything for all other elements.
        """
        return self

    def with_walk_cfg(self, walk_cfg: "WalkCfgBase") -> Self:
        """
        Return shallow clone of element, but with different "graph walk config".
        """
        clone = self.shallow_clone()
        clone._walk_cfg = walk_cfg
        return clone


class GraphInput(Element):
    """Element representing a graph input, identified by its name."""

    def __init__(
        self,
        model_dict: ModelDict,
        walk_cfg: "WalkCfgBase",
        graph_input_name: str,
    ) -> None:
        super().__init__(model_dict, walk_cfg)
        self._graph_input_name = graph_input_name

    def __repr__(self) -> str:
        return f"GraphInput({self._graph_input_name!r})"

    def get_dtype(self) -> str:
        """
        Return the data type the model dict stores for this graph input.
        Raise GraphInputNotFound if the lookup fails with a KeyError.
        """
        name = self._graph_input_name
        try:
            dtype = self._model_dict.get_data_type(name)
        except KeyError:
            raise GraphInputNotFound(f"{name}{path_sep}dtype")
        return dtype

    def get_name(self) -> str:
        return self._graph_input_name

    def get_shape(self) -> list[int]:
        """
        Return the shape of this graph input with every dimension as int.
        Raise GraphInputNotFound if the lookup fails with a KeyError.
        """
        name = self._graph_input_name
        try:
            dims = [int(dim) for dim in self._model_dict.get_shape(name)]
        except KeyError:
            raise GraphInputNotFound(f"{name}{path_sep}shape")
        return dims

    def check_graph_input(self) -> bool:
        return True

    def require_graph_input(self) -> "GraphInput":
        return self

    def shallow_clone(self) -> Self:
        model_dict, walk_cfg = self._model_dict, self._walk_cfg
        return GraphInput(model_dict, walk_cfg, self._graph_input_name)


class GraphOutput(Element):
    """Element representing a graph output, identified by its name."""

    def __init__(
        self,
        model_dict: ModelDict,
        walk_cfg: "WalkCfgBase",
        graph_output_name: str,
    ) -> None:
        super().__init__(model_dict, walk_cfg)
        self._graph_output_name = graph_output_name

    def __repr__(self) -> str:
        return f"GraphOutput({self._graph_output_name!r})"

    def get_dtype(self) -> str:
        """
        Return the data type the model dict stores for this graph output.
        Raise GraphOutputNotFound if the lookup fails with a KeyError.
        """
        name = self._graph_output_name
        try:
            dtype = self._model_dict.get_data_type(name)
        except KeyError:
            raise GraphOutputNotFound(f"{name}{path_sep}dtype")
        return dtype

    def get_name(self) -> str:
        return self._graph_output_name

    def get_shape(self) -> list[int]:
        """
        Return the shape of this graph output with every dimension as int.
        Raise GraphOutputNotFound if the lookup fails with a KeyError.
        """
        name = self._graph_output_name
        try:
            dims = [int(dim) for dim in self._model_dict.get_shape(name)]
        except KeyError:
            raise GraphOutputNotFound(f"{name}{path_sep}shape")
        return dims

    def check_graph_output(self) -> bool:
        return True

    def require_graph_output(self) -> "GraphOutput":
        return self

    def shallow_clone(self) -> Self:
        model_dict, walk_cfg = self._model_dict, self._walk_cfg
        return GraphOutput(model_dict, walk_cfg, self._graph_output_name)


class Initializer(Element):
    """Element representing an initializer, identified by its name."""

    def __init__(
        self, model_dict: ModelDict, walk_cfg: "WalkCfgBase", init_name: str
    ) -> None:
        super().__init__(model_dict, walk_cfg)
        self._init_name = init_name

    def __repr__(self) -> str:
        return f"Initializer({repr(self._init_name)})"

    def get_dtype(self) -> NumpyDType:
        """
        Return the data type the model dict stores for this initializer.
        Raise InitializerNotFound if the lookup fails with a KeyError.
        """
        try:
            return self._model_dict.get_data_type(self._init_name)
        except KeyError:
            raise InitializerNotFound(f"{self._init_name}{path_sep}dtype")

    def get_dtype_raw(self) -> OnnxDType:
        """
        Return the raw data type the model dict stores for this initializer.
        Raise TensorNotFound if the lookup fails with a KeyError.
        """
        try:
            return self._model_dict.get_data_type_raw(self._init_name)
        except KeyError:
            # NOTE(review): unlike the sibling accessors this raises
            # TensorNotFound, not InitializerNotFound. Kept unchanged because
            # callers may catch this exact type.
            raise TensorNotFound(f"{self._init_name}{path_sep}dtype")

    def get_name(self) -> str:
        return self._init_name

    def get_shape(self) -> list[int]:
        """
        Return the shape of this initializer with every dimension as int.
        Raise InitializerNotFound if the lookup fails with a KeyError.
        """
        try:
            return [int(dim) for dim in self._model_dict.get_shape(self._init_name)]
        except KeyError:
            raise InitializerNotFound(f"{self._init_name}{path_sep}shape")

    def get_value(self) -> Any:
        """Return the initializer value as provided by the model dict."""
        return self._model_dict.get_initializer_value(self.get_name())

    def get_value_as_array(self) -> np.ndarray:
        """Return the initializer converted to a numpy array."""
        return onnx.numpy_helper.to_array(
            self._model_dict.get_initializer(self.get_name())
        )

    def multiply(self, factor: float) -> None:
        """Ask the model dict to multiply this initializer by `factor`."""
        self._model_dict.initializer_multiplication(self.get_name(), factor)

    def check_initializer(self) -> bool:
        return True

    def require_initializer(self) -> "Initializer":
        return self

    def shallow_clone(self) -> Self:
        return Initializer(self._model_dict, self._walk_cfg, self._init_name)

    def update_initializer_value(self, new_value: Any) -> None:
        """Replace the value of this initializer in the model dict."""
        self._model_dict.update_initializer_value(self.get_name(), new_value)

    def flag_used(self) -> None:
        """
        Flag the initializer used once for the lost initializer check
        """
        self._model_dict._lost_ini_helper.inc(self.get_name())


class Nowhere(Element):
    """
    Element standing for "nothing found"; it only remembers the path that
    was looked up.
    """

    def __init__(
        self, model_dict: ModelDict, walk_cfg: "WalkCfgBase", path: str
    ) -> None:
        super().__init__(model_dict, walk_cfg)
        self._path = path

    def __repr__(self) -> str:
        return f"NotFound({self._path})"

    def get_inputs(self) -> "Inputs":
        """Return an empty inputs collection anchored at this path."""
        no_names = []
        return Inputs(self._model_dict, self._walk_cfg, self._path, no_names)

    def get_name(self) -> str:
        """Return the path that was not found."""
        return self._path

    def get_outputs(self) -> "Outputs":
        """Return an empty outputs collection anchored at this path."""
        no_names = []
        return Outputs(self._model_dict, self._walk_cfg, self._path, no_names)

    def check_nowhere(self) -> bool:
        return True

    def require_nowhere(self) -> "Nowhere":
        return self

    def shallow_clone(self) -> Self:
        model_dict, walk_cfg = self._model_dict, self._walk_cfg
        return Nowhere(model_dict, walk_cfg, self._path)


class Node(Element):
    """Element representing a node of the graph, identified by its name."""

    def __init__(
        self, model_dict: ModelDict, walk_cfg: "WalkCfgBase", node_name: str
    ) -> None:
        super().__init__(model_dict, walk_cfg)
        self._node_name = node_name

    def __eq__(self, other) -> bool:
        return (
            # self and other are from same class hierarchy:
            isinstance(other, Node)
            # both are referencing same ModelDict object:
            and id(self._model_dict) == id(other._model_dict)
            # both have the same name:
            and self.get_name() == other.get_name()
        )

    def __hash__(self):
        hashable = (Node.__name__, id(self._model_dict), self.get_name())
        return hash(hashable)

    def __repr__(self) -> str:
        return f"Node({repr(self._node_name)})"

    def change_op_type(self, new_op_type: str) -> None:
        """
        Change the op_type of this node.
        """
        self._model_dict.change_op_type(self._node_name, new_op_type)

    def set_attribute(self, attr_name: str, attr_value: Any) -> None:
        """Set attribute `attr_name` of this node to `attr_value`."""
        self._model_dict.set_attribute(self.get_name(), attr_name, attr_value)

    def remove_attribute(self, attr_name: str) -> None:
        """Remove attribute `attr_name` from this node."""
        self._model_dict.remove_attribute(self.get_name(), attr_name)

    def setdefault_attribute(self, attr_name: str, attr_value: Any) -> Any:
        """Analogous to dict.setdefault:
            sets attribute `attr_name` to `attr_value` ONLY IF there is no attribute `attr_name`
        :return value of `attr_name` attribute
        """
        if not self.has_attribute(attr_name):
            self.set_attribute(attr_name, attr_value)
        return self.get_attribute_value(attr_name)

    def get_attribute_value(self, attribute_name: str) -> Any:
        return self._model_dict.get_attribute_value(self._node_name, attribute_name)

    def get_attributes(self) -> dict[str, Any]:
        return self._model_dict.get_attributes(self._node_name)

    def get_raw_attributes(self) -> dict[str, Any]:
        return self._model_dict.get_raw_attributes(self._node_name)

    def get_raw_attribute(self, attribute_name: str) -> Any:
        return self._model_dict.get_raw_attribute(self._node_name, attribute_name)

    def get_connection(self, name: str) -> Element:
        """
        Return the input or output tensor of this node whose schema parameter
        name is `name`.  Inputs are looked up before outputs.
        Return a `Nowhere` if the op type is unknown to the ONNX ops table
        (KeyError) or if the op type has no input/output with that name.
        """
        op_type = self.get_op_type()
        # Path used when the op type itself is unknown; probably a node that
        # was already fused.
        unknown_op_path = (
            f"{self._node_name}{path_sep}{op_type}(unknown){path_sep}.{name}"
        )
        try:
            input_idx = self._model_dict.get_onnx_ops().get_input_prm_idx_by_name(
                op_type, name
            )
        except KeyError:
            return Nowhere(self._model_dict, self._walk_cfg, unknown_op_path)
        if input_idx is not None:
            return self.get_inputs()[input_idx]
        try:
            output_idx = self._model_dict.get_onnx_ops().get_output_prm_idx_by_name(
                op_type, name
            )
        except KeyError:
            return Nowhere(self._model_dict, self._walk_cfg, unknown_op_path)
        if output_idx is not None:
            return self.get_outputs()[output_idx]
        return Nowhere(
            self._model_dict,
            self._walk_cfg,
            f"{self._node_name}{path_sep}{name}(op_type={op_type}, unknown_input_output={name})",
        )

    def get_domain(self) -> str:
        return self._model_dict.get_domain(self._node_name)

    def get_inputs(self) -> "Inputs":
        """
        Return the inputs of this node.
        If the model dict has no input names for the node (KeyError), the
        returned collection is empty.
        """
        try:
            input_names = self._model_dict.get_input_names(self._node_name)
        except KeyError:
            input_names = []
        return Inputs(self._model_dict, self._walk_cfg, self._node_name, input_names)

    def get_act_inputs(self) -> list["InputTensor"]:
        """
        Return the inputs that are input tensors with a non-empty name and
        that are neither initializers nor tensors with an empty shape.
        A KeyError raised while inspecting the inputs ends the collection;
        the inputs gathered up to that point are returned.
        """
        act_inputs = []
        try:
            for input_tensor in self.get_inputs():
                if not input_tensor.check_input_tensor() or not input_tensor.get_name():
                    continue
                if (
                    input_tensor.check_initializer()
                    or len(input_tensor.require_tensor().get_shape()) == 0
                ):
                    continue
                act_inputs.append(input_tensor)
        except KeyError:
            pass
        return act_inputs

    def get_act_outputs(self) -> list["OutputTensor"]:
        """Return all outputs of this node as a list."""
        return list(self.get_outputs())

    def get_matcher_name(self) -> Optional[str]:
        """
        Return name of matcher that has produced this node.
        If the node does not have the attribute indicating the matcher name,
        return None.
        """
        try:
            matcher_name = self.get_attribute_value(matcher_name_attribute)
        except KeyError:
            return None
        return matcher_name

    def get_name(self) -> str:
        return self._node_name

    def get_op_type(self) -> str:
        return self._model_dict.get_op_type(self._node_name)

    def get_outputs(self) -> "Outputs":
        """
        Return the outputs of this node.
        If the model dict has no output names for the node (KeyError), the
        returned collection is empty.
        """
        try:
            output_names = self._model_dict.get_output_names(self._node_name)
        except KeyError:
            output_names = []
        return Outputs(self._model_dict, self._walk_cfg, self._node_name, output_names)

    def get_node_list_by_optype(self) -> list[Self]:
        """
        Get list of nodes with the same op_type as this node.
        The returned nodes use a fresh WalkCfgPlain walk config.
        """
        node_names = self._model_dict.get_node_names(self.get_op_type())
        return [
            Node(self._model_dict, WalkCfgPlain(), node_name)
            for node_name in node_names
        ]

    def is_consumer(self, start_node, target_node) -> bool:
        """
        Returns True if there is a downward path from start_node to target_node.
        The graph is walked depth-first; elements are visited once per name.
        """
        stack = [start_node]
        visited = set()

        while stack:
            current_node = stack.pop()
            if current_node.get_name() in visited:
                continue
            visited.add(current_node.get_name())

            if current_node == target_node:
                return True

            # Graph outputs are end points of the walk.
            if isinstance(current_node, GraphOutput):
                continue
            for output in current_node.get_outputs():
                readers = output.get_readers()
                # get_readers() may yield nothing, or a single (truthy,
                # non-iterable) Nowhere element when no reader is known.
                if not readers or isinstance(readers, Element):
                    continue
                for reader in readers:
                    stack.append(reader.get_non_tensor())

        return False

    def check_node(self) -> bool:
        return True

    def require_node(self) -> "Node":
        return self

    def get_schema_input_names(self) -> list[str]:
        return self._model_dict.get_onnx_ops().get_input_names(self.get_op_type())

    def get_schema_output_names(self) -> list[str]:
        return self._model_dict.get_onnx_ops().get_output_names(self.get_op_type())

    def has_attribute(self, attribute_name: str) -> bool:
        """
        Return True if the node has a value other than None for
        `attribute_name`.  A KeyError from the lookup is treated as "attribute
        missing", in line with get_matcher_name().
        """
        try:
            return self.get_attribute_value(attribute_name) is not None
        except KeyError:
            return False

    def shallow_clone(self) -> Self:
        return Node(self._model_dict, self._walk_cfg, self._node_name)


class Tensor(Element):
    """
    Element representing a tensor, identified by its name.

    Many accessors are forwarded to the "next" non-tensor element (see
    get_non_tensor()).  The base class has no direction and therefore no
    next element; InputTensor and OutputTensor define it.
    """

    def __init__(
        self,
        model_dict: ModelDict,
        walk_cfg: "WalkCfgBase",
        tensor_name: str,
        origin_node_name: Optional[str] = None,
    ) -> None:
        super().__init__(model_dict, walk_cfg)
        self._tensor_name = tensor_name
        self._origin_node_name = origin_node_name

    def __eq__(self, other) -> bool:
        return (
            # self and other are from same class hierarchy:
            isinstance(other, Tensor)
            # both are referencing same ModelDict object:
            and id(self._model_dict) == id(other._model_dict)
            # both have the same name:
            and self.get_name() == other.get_name()
        )

    def __hash__(self):
        hashable = (Tensor.__name__, id(self._model_dict), self.get_name())
        return hash(hashable)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({repr(self._tensor_name)})"

    def check_one(self, checker: "Checker") -> bool:
        """
        If there are multiple options for the element to check, check if exactly
        one option satisfies the checker.
        Multiple options can only happen for tensors.
        It is important to check that exactly one option satisfies the checker
        in order to avoid ambiguities in pattern matching.
        """
        cnt = sum(1 for next_elem in self.get_next_list() if next_elem.check(checker))
        return cnt == 1

    def get_connection(self, name: str) -> Element:
        return self.get_non_tensor().get_connection(name)

    def get_dtype(self) -> NumpyDType:
        """
        Return the data type the model dict stores for this tensor.
        Raise TensorNotFound if the lookup fails with a KeyError.
        """
        try:
            return self._model_dict.get_data_type(self._tensor_name)
        except KeyError:
            raise TensorNotFound(f"{self._tensor_name}{path_sep}dtype")

    def get_dtype_raw(self) -> OnnxDType:
        """
        Return the raw data type the model dict stores for this tensor.
        Raise TensorNotFound if the lookup fails with a KeyError.
        """
        try:
            return self._model_dict.get_data_type_raw(self._tensor_name)
        except KeyError:
            raise TensorNotFound(f"{self._tensor_name}{path_sep}dtype")

    def get_inputs(self) -> "Inputs":
        # If this tensor can be auto-converted to a node, forward to node.
        return self.get_non_tensor().get_inputs()

    def get_name(self) -> str:
        return self._tensor_name

    def get_next(self) -> Element:
        """
        Get next element if it is defined uniquely.
        Return a `Nowhere` if there is no or more than one next element.
        """
        next_list = self.get_next_list()
        if len(next_list) == 1:
            return next_list[0]
        return Nowhere(
            self._model_dict,
            self._walk_cfg,
            f"{self._tensor_name}{path_sep}next",
        )

    def get_next_list(self) -> list[Element]:
        """
        Get list of all options for next element. Next means the element at the
        "other" side of the tensor when walking a graph.
        """
        # It is not clear in base class if to return reader or writer, so
        # there are no options.  An empty list (rather than a Nowhere element)
        # keeps get_next(), check_one() and require_one() working, which take
        # the length of / iterate over the result.
        return []

    def get_non_tensor(self) -> Element:
        return self.get_next()

    def get_origin(self) -> Element:
        """
        Return the node this tensor was reached from, or a `Nowhere` if no
        origin node name was given at construction.
        """
        if self._origin_node_name is None:
            return Nowhere(
                self._model_dict,
                self._walk_cfg,
                f"{self._tensor_name}{path_sep}origin",
            )
        return Node(self._model_dict, self._walk_cfg, self._origin_node_name)

    def get_outputs(self) -> "Outputs":
        # If this tensor can be auto-converted to a node, forward to node.
        return self.get_non_tensor().get_outputs()

    def get_reader(self) -> Element:
        """
        Return the reader of this tensor if there is exactly one, otherwise a
        `Nowhere`.
        """
        readers = self.get_readers()
        if len(readers) != 1:
            return Nowhere(
                self._model_dict,
                self._walk_cfg,
                f"{self._tensor_name}{path_sep}unique_reader",
            )
        return readers[0]

    def get_readers(self) -> list[Element]:
        """
        Return all readers of this tensor: a GraphOutput if the tensor is a
        graph output, followed by one Node per reader name known to the model
        dict.  The list is empty if the tensor has no reader at all.
        """
        readers: list[Element] = []
        if self._tensor_name in self._model_dict._output_names:
            readers.append(
                GraphOutput(self._model_dict, self._walk_cfg, self._tensor_name)
            )
        try:
            node_names = self._model_dict.get_reader_names(self._tensor_name)
        except KeyError:
            # No reader nodes are registered for this tensor.  Always return
            # a list (possibly empty) so that callers can take its length and
            # iterate over it; a Nowhere element supports neither.
            return readers
        readers.extend(
            Node(self._model_dict, self._walk_cfg, node_name)
            for node_name in node_names
        )
        return readers

    def get_schema_input_names(self) -> list[str]:
        return self.get_non_tensor().get_schema_input_names()

    def get_schema_output_names(self) -> list[str]:
        return self.get_non_tensor().get_schema_output_names()

    def get_shape(self) -> TensorShape:
        """
        Return the shape of this tensor with every dimension converted to int.
        If a dimension cannot be converted (ValueError), return the shape as
        stored in the model dict.
        Raise TensorNotFound if the lookup fails with a KeyError.
        """
        try:
            return [int(dim) for dim in self._model_dict.get_shape(self._tensor_name)]
        except ValueError:
            return self._model_dict.get_shape(self._tensor_name)
        except KeyError:
            raise TensorNotFound(f"{self._tensor_name}{path_sep}shape")

    def set_shape(self, shape: TensorShape, dtype: NumpyDType) -> None:
        """Store `shape` and `dtype` for this tensor in the model dict."""
        self._model_dict.set_shape(self._tensor_name, shape, dtype)

    def get_writer(self) -> Element:
        """
        Return the element that provides this tensor: an Initializer, a
        GraphInput or the writer Node, checked in this order.
        Return a `Nowhere` if the writer lookup fails with a ValueError.
        """
        if self._tensor_name in self._model_dict._ini_dict:
            return Initializer(self._model_dict, self._walk_cfg, self._tensor_name)
        elif self._model_dict.is_graph_input(self._tensor_name):
            return GraphInput(self._model_dict, self._walk_cfg, self._tensor_name)
        try:
            node_name = self._model_dict.get_writer_name(self._tensor_name)
        except ValueError:
            return Nowhere(
                self._model_dict,
                self._walk_cfg,
                f"{self._tensor_name}{path_sep}writer",
            )
        return Node(self._model_dict, self._walk_cfg, node_name)

    def get_act_suppliers(self) -> list["InputTensor"]:
        """
        Tensor suppliers of a tensor X -- are input tensors of X's  writer-nodes
        """
        return list(self.get_writer().require_node().get_act_inputs())

    def get_act_consumers(self) -> list["OutputTensor"]:
        """
        Tensor consumers of a tensor X -- are output tensors of X's  reader-nodes
        """
        consumers_ = [
            o for r in self.get_readers() for o in r.require_node().get_act_outputs()
        ]
        return consumers_

    # FIXME deprecated: Use require_initializer().get_value_as_array()
    def get_initializer_array(self) -> np.ndarray:
        """
        Return the initializer stored under this tensor's name as numpy array.
        Raise KeyError if there is no such initializer.
        """
        tensor_proto = self._model_dict._ini_dict[self._tensor_name]
        return onnx.numpy_helper.to_array(tensor_proto)

    def get_initializer_array_or_none(self) -> Optional[np.ndarray]:
        """Like get_initializer_array(), but return None instead of raising KeyError."""
        try:
            return self.get_initializer_array()
        except KeyError:
            return None

    def require_one(self, checker: "Checker") -> "Tensor":
        """
        If there are multiple options for the element to check, require that
        exactly one option satisfies the checker.
        Multiple options can only happen for tensors.
        It is important to check that exactly one option satisfies the checker
        in order to avoid ambiguities in pattern matching.
        Return the one matching option; raise NoMatch if none or several match.
        """
        found: Optional[Element] = None
        msgs: list[str] = []
        for next_elem in self.get_next_list():
            try:
                next_elem.require(checker)
            except MatcherError as m_e:
                # Remember why this option failed for the final error message.
                msgs.append(str(m_e))
                continue
            if found is not None:
                raise NoMatch("multiple options matched")
            found = next_elem
        if found is None:
            raise NoMatch("No option matched: " + "; ".join(msgs))
        return found

    def check_graph_input(self) -> bool:
        return self.get_non_tensor().check_graph_input()

    def check_graph_output(self) -> bool:
        return self.get_non_tensor().check_graph_output()

    def check_initializer(self) -> bool:
        return self.get_non_tensor().check_initializer()

    def check_node(self) -> bool:
        # If this tensor can be auto-converted to a node, return True.
        return self.get_non_tensor().check_node()

    def check_tensor(self) -> bool:
        return True

    def check_shape(self, shape: TensorShape) -> bool:
        """Return True if the shape of this tensor equals `shape`."""
        return self.get_shape() == shape

    def require_graph_input(self) -> "GraphInput":
        return self.get_non_tensor().require_graph_input()

    def require_graph_output(self) -> "GraphOutput":
        return self.get_non_tensor().require_graph_output()

    def require_initializer(self) -> "Initializer":
        return self.get_non_tensor().require_initializer()

    def require_node(self) -> "Node":
        # If this tensor can be auto-converted to a node, return this node.
        return self.get_non_tensor().require_node()

    def require_tensor(self) -> "Tensor":
        return self

    def require_shape(self, shape: TensorShape) -> Self:
        """
        Return self if the shape of this tensor equals `shape`.
        Raise NoMatch otherwise.
        """
        if not self.check_shape(shape):
            raise NoMatch(
                f"Tensor({self._tensor_name}).shape={self.get_shape()} does not match required shape {shape}"
            )
        return self

    def shallow_clone(self) -> Self:
        return Tensor(
            self._model_dict,
            self._walk_cfg,
            self._tensor_name,
            self._origin_node_name,
        )


class InputTensor(Tensor):
    """
    Tensor flavour whose walk direction is upwards: the next element is the
    writer of the tensor and skipping uses the "up" side of the walk config.
    """

    def get_next_list(self) -> list[Element]:
        """
        Return a single-element list holding the writer of the tensor that
        remains after skip().
        """
        skipped = self.skip()
        return [skipped.get_writer()]

    def check_input_tensor(self) -> bool:
        """
        An input tensor always passes the input-tensor check.
        """
        return True

    def shallow_clone(self) -> Self:
        """
        Create a new InputTensor that refers to the same model dict, walk
        configuration, tensor name and origin node name as this one.
        """
        ctor_args = (
            self._model_dict,
            self._walk_cfg,
            self._tensor_name,
            self._origin_node_name,
        )
        return InputTensor(*ctor_args)

    def skip(self) -> Self:
        """
        Skip upwards as much as possible according to current walk config.
        """
        walk_cfg = self._walk_cfg
        return walk_cfg.skip_up(self)

    def search_for_shape(self) -> Self:
        """
        Skip upwards as much as possible to retrieve the tensor shape, if
        tensor shape is unknown.
        """
        walk_cfg = self._walk_cfg
        return walk_cfg.search_upward_for_shape(self)


class OutputTensor(Tensor):
    """
    Tensor flavour whose walk direction is downwards: the next elements are
    the readers of the tensor and skipping uses the "down" side of the walk
    config.
    """

    def get_next_list(self) -> list[Element]:
        """
        Return the readers of the tensor that remains after skip().
        """
        skipped = self.skip()
        return skipped.get_readers()

    def check_output_tensor(self) -> bool:
        """
        An output tensor always passes the output-tensor check.
        """
        return True

    def shallow_clone(self) -> Self:
        """
        Create a new OutputTensor that refers to the same model dict, walk
        configuration, tensor name and origin node name as this one.
        """
        ctor_args = (
            self._model_dict,
            self._walk_cfg,
            self._tensor_name,
            self._origin_node_name,
        )
        return OutputTensor(*ctor_args)

    def skip(self) -> Self:
        """
        Skip downward as much as possible according to current walk config.
        """
        walk_cfg = self._walk_cfg
        return walk_cfg.skip_down(self)

    def search_for_shape(self) -> Self:
        """
        Skip downwards as much as possible to retrieve the tensor shape, if
        tensor shape is unknown.
        """
        walk_cfg = self._walk_cfg
        return walk_cfg.search_downward_for_shape(self)


class InputsOrOutputs(ABC):
    """
    Base class for inputs or outputs of node (or element in general).

    Behaves like a read-only sequence: it has a length, can be indexed and
    can be iterated. Indexing is left to the derived classes.
    """

    def __init__(
        self,
        model_dict: ModelDict,
        walk_cfg: "WalkCfgBase",
        origin_node_name: str,
        connection_names: list[str],
    ) -> None:
        """
        model_dict -- model the connections belong to
        walk_cfg -- walk configuration stored for the derived classes
        origin_node_name -- name of the node owning the connections
        connection_names -- names of the connected tensors, in order
        """
        self._connection_names = connection_names
        self._origin_node_name = origin_node_name
        self._walk_cfg = walk_cfg
        self._model_dict = model_dict

    def __len__(self) -> int:
        """
        Number of connections.
        """
        return len(self._connection_names)

    def __iter__(self):
        """
        Iterate over the elements by position.
        The legacy iteration protocol based on __getitem__ stops on
        IndexError only, which the derived classes never raise; therefore
        the positions are counted explicitly.
        """
        positions = range(len(self))
        return (self[position] for position in positions)

    @abstractmethod
    def __getitem__(self, idx: int) -> Element:
        """
        Return the element at position idx.
        """


class Inputs(InputsOrOutputs):
    """
    Inputs of a node, indexable by position.
    """

    def __getitem__(self, idx: int) -> Element:
        """
        Return the InputTensor at position idx.
        Positions outside 0..len-1 (negative ones included) do not raise;
        they yield a Nowhere element carrying the requested path.
        """
        if 0 <= idx < len(self._connection_names):
            return InputTensor(
                self._model_dict,
                self._walk_cfg,
                self._connection_names[idx],
                self._origin_node_name,
            )
        return Nowhere(
            self._model_dict,
            self._walk_cfg,
            f"{self._origin_node_name}{path_sep}inputs[{idx}]",
        )


class Outputs(InputsOrOutputs):
    """
    Outputs of a node, indexable by position.
    """

    def __getitem__(self, idx: int) -> Element:
        """
        Return the OutputTensor at position idx.
        Positions outside 0..len-1 (negative ones included) do not raise;
        they yield a Nowhere element carrying the requested path.
        """
        if 0 <= idx < len(self._connection_names):
            return OutputTensor(
                self._model_dict,
                self._walk_cfg,
                self._connection_names[idx],
                self._origin_node_name,
            )
        return Nowhere(
            self._model_dict,
            self._walk_cfg,
            f"{self._origin_node_name}{path_sep}outputs[{idx}]",
        )


class WalkCfgBase(ABC):
    """
    Graph walk configuration.
    Decides which nodes to skip automatically during "walk" of ONNX graph.

    skip_down() and skip_up() must be provided by every configuration.
    The two search_*_for_shape() hooks are called by
    OutputTensor.search_for_shape() and InputTensor.search_for_shape(); they
    have a default that skips nothing, so that every configuration supports
    them. Configurations that can look through nodes override them.
    """

    @abstractmethod
    def skip_down(self, out_tensor: "OutputTensor") -> "OutputTensor":
        """
        Skip nodes automatically while walking the graph downwards, i.e.,
        towards the outputs.
        Return the output tensor behind the skipped nodes.
        Return the passed output tensor if nothing got skipped.
        """

    @abstractmethod
    def skip_up(self, in_tensor: "InputTensor") -> "InputTensor":
        """
        Skip nodes automatically while walking the graph upwards, i.e., towards
        the inputs.
        Return the input tensor behind the skipped elements.
        Return the passed input tensor if nothing got skipped.
        """

    def search_downward_for_shape(self, out_tensor: "OutputTensor") -> "OutputTensor":
        """
        Skip downwards to find a tensor with known shape.
        Default: nothing is skipped, the passed output tensor is returned.
        """
        return out_tensor

    def search_upward_for_shape(self, in_tensor: "InputTensor") -> "InputTensor":
        """
        Skip upwards to find a tensor with known shape.
        Default: nothing is skipped, the passed input tensor is returned.
        """
        return in_tensor


class WalkCfgPlain(WalkCfgBase):
    """
    Plain graph walk. Nothing is skipped automatically.
    """

    def skip_down(self, out_tensor: OutputTensor) -> OutputTensor:
        """
        Return the passed output tensor unchanged.
        """
        unchanged = out_tensor
        return unchanged

    def skip_up(self, in_tensor: InputTensor) -> InputTensor:
        """
        Return the passed input tensor unchanged.
        """
        unchanged = in_tensor
        return unchanged


class Checker(ABC):
    """
    Base class of all checkers, i.e., predicates on an element.

    Checkers can be combined with operators:
        a & b -- CheckerAnd(a, b)
        a | b -- CheckerOr(a, b)
        ~a    -- CheckerNot(a)
    """

    def __and__(self, rhs: "Checker") -> "Checker":
        """
        Return a checker that passes if this checker and rhs both pass.
        """
        return CheckerAnd(self, rhs)

    def __invert__(self) -> "Checker":
        """
        Return a checker that passes if this checker does not pass.
        This is the method Python calls for `~checker`.
        """
        return CheckerNot(self)

    def __not__(self) -> "Checker":
        """
        Return a checker that passes if this checker does not pass.
        Python has no `__not__` special method, so no operator invokes this;
        it is kept for callers that call it explicitly. Prefer `~checker`.
        """
        return CheckerNot(self)

    def __or__(self, rhs: "Checker") -> "Checker":
        """
        Return a checker that passes if this checker or rhs passes.
        """
        return CheckerOr(self, rhs)

    @abstractmethod
    def check(self, element: "Element") -> bool:
        """
        Return whether the passed element passes this check.
        """


class CheckerAnd(Checker):
    """
    Combination of two checkers that passes if both of them pass.
    """

    def __init__(self, a: Checker, b: Checker):
        self._a, self._b = a, b

    def check(self, element: Element) -> bool:
        """
        Run the first checker; the second one is consulted only if the first
        one passed.
        """
        first_result = self._a.check(element)
        if not first_result:
            return first_result
        return self._b.check(element)


class CheckerNot(Checker):
    """
    Checker that passes exactly if the wrapped checker does not pass.
    """

    def __init__(self, a: Checker):
        self._a = a

    def check(self, element: Element) -> bool:
        """
        Return the negated result of the wrapped checker.
        """
        inner_result = self._a.check(element)
        return not inner_result


class CheckerOr(Checker):
    """
    Combination of two checkers that passes if at least one of them passes.
    """

    def __init__(self, a: Checker, b: Checker):
        self._a, self._b = a, b

    def check(self, element: Element) -> bool:
        """
        Run the first checker; the second one is consulted only if the first
        one did not pass.
        """
        first_result = self._a.check(element)
        if first_result:
            return first_result
        return self._b.check(element)


class MatcherOrCategory(ABC):
    """
    Common base class for Matcher and matcher Category.
    Use this class in all functions that run single matchers or a category of
    matchers.

    The class derives from ABC so that the abstract methods below are
    actually enforced: a derived class can only be instantiated if it
    implements both of them.
    """

    @abstractmethod
    def contains_matcher(self, matcher_name: str) -> bool:
        """
        Return whether a matcher with the passed name is this object or is
        contained in it.
        """

    @abstractmethod
    def get_matchers(self) -> list["Matcher"]:
        """
        Return the matchers represented by this object as a flat list.
        """


class Matcher(MatcherOrCategory, ABC):
    """
    A fusion pattern, i.e., matching a subgraph and fusing it to a single node
    or any other similar modification of the matched subgraph.

    Derived classes implement match() and modify(). The remaining methods
    drive the matching (run(), do_match*()) and offer helpers to modify the
    graph (add_*, remove_*, replace_*, connect()).

    dependencies -- list of categories that need to run before this matcher
    node_blacklist -- names of nodes for which the runner caught an error;
                      run() skips these nodes
    """

    dependencies: list["MatcherOrCategory"] = []
    # NOTE(review): class attribute that check_and_handle_error() mutates in
    # place, i.e., the set is shared by all matcher classes that do not
    # define their own. Presumably intended so that the blacklist survives a
    # restart of the fusion -- confirm.
    node_blacklist: set[str] = set()

    @dataclass
    class NodeInfo:
        """Name, operator type and domain of a node."""

        name: str
        type: str
        domain: str

    def __init__(self) -> None:
        """
        Initialize matcher.
        """
        # Root element of the current match; set by do_match().
        self.n: Optional[Element] = None
        # Not used inside this class; presumably set from outside for
        # debugging -- confirm.
        self.hook: Optional["FusionDebugHook"] = None
        self._reset()

    def _get_model_dict(self) -> ModelDict:
        """
        Return the model dict of the current root element.
        Raise MatcherError if no root element has been set yet.
        """
        if self.n is None:
            raise MatcherError("cannot access _model_dict before matching")
        return self.n._model_dict

    def _reset(self) -> None:
        """
        Forget the root element of the previous match.
        """
        self.n = None

    def run_single_node(self, node: Node, logger: Logger) -> bool:
        """
        Match the pattern with the passed node as root and apply the
        modification if it matched.
        Return True if fusion for the node was successful, and False otherwise.
        Only a MatcherError raised while matching leads to False; exceptions
        raised while modifying are not caught here.
        """
        matcher_name = self.get_matcher_class_name()
        try:
            self.do_match(node)
            logger.debug(
                f"Successfully matching pattern {matcher_name} for node {node._node_name}"
            )
        except MatcherError as e:
            logger.debug(
                f"Failing matching pattern {matcher_name} for node {node._node_name}: {e}"
            )
            return False

        self.do_modify(logger)
        return True

    def check_and_handle_error(
        self, runner: SafeRunner, node: Node, logger: Logger
    ) -> None:
        """
        Check if an error was caught by the runner. If that is the case, add
        the node to the blacklist, emit a warning to inform the user, reset
        the logger indentation and let the runner raise its error for the
        caller.
        """
        if not runner.has_failed:
            return

        node_name = node._node_name
        self.node_blacklist.add(node_name)

        logger.warning(
            f"Error when matching pattern {self.get_matcher_class_name()} for node '{node_name}'. It will be skipped if fusion is restarted."
        )
        logger.reset_indentation()

        runner.raise_error()

    def run(
        self,
        model_dict: ModelDict,
        walk_cfg: "WalkCfgBase",
        runner: SafeRunner,
        logger: Logger,
    ) -> int:
        """
        Run this matcher on every node of the model.
        model_dict -- model to match and modify
        walk_cfg -- walk configuration used for the root nodes
        runner -- executes the matching of a single node and records errors
        logger -- receives debug output
        Return the number of nodes for which the pattern matched and the
        modification was applied.
        """
        match_cnt = 0
        matcher_name = self.get_matcher_class_name()
        logger.debug(f"Start matching node for pattern {matcher_name}")
        logger.increase_indentation()
        for node_name in model_dict.get_node_names():
            if node_name in self.node_blacklist:
                logger.debug(
                    f"Skipping matching for node '{node_name}' as it is blacklisted due to failure."
                )
                continue
            # A modification done for an earlier node of this loop may have
            # removed the node in the meantime.
            if not model_dict.has_node(node_name):
                logger.debug(
                    f"Skipping matching for node '{node_name}' as it is no longer part of the graph."
                )
                continue

            node = Node(model_dict, walk_cfg, node_name)
            matched = runner.run(self.run_single_node, node, logger)
            self.check_and_handle_error(runner, node, logger)

            if matched:
                match_cnt += 1

        logger.decrease_indentation()
        return match_cnt

    def do_match(self, n: Element) -> None:
        """
        Set new root node.
        Reset internal state.
        Match the node and its surroundings to the pattern.
        n -- Root element (node) to match to pattern
        The function returns if there is a match.
        The function raises MatcherError or a derived exception if there is no
        match.
        """
        self._reset()
        self.n = n
        self._get_model_dict()._lost_ini_helper.reset()
        self.match()

    def do_match_first(self, *elems: Element) -> Element:
        """
        Call do_match on each of the passed nodes.
        Stop on the first match and return the matching node.
        Raise MatcherError if no node matched.
        """
        msgs: list[str] = []
        for elem in elems:
            try:
                self.do_match(elem)
                return elem
            except MatcherError as m_e:
                msgs.append(str(m_e))
        raise NoMatch("Element(s) did not match: " + "; ".join(msgs))

    def do_match_first_opt(self, *elems: Element) -> Optional[Element]:
        """
        Call do_match on each of the passed nodes.
        Stop on the first match and return the matching node.
        Return None if no node matched.
        """
        try:
            # Forward the elements one by one, not as a single tuple.
            return self.do_match_first(*elems)
        except MatcherError:
            # do_match_first() signals "nothing matched" with NoMatch, which
            # is a MatcherError.
            return None

    def collect_required_new_node_attributes(self):
        """
        Collect the required attributes needed to perform the modifications.
        If the attributes are present in the central node, use them. Otherwise,
        get their new values.
        """
        self.required_attr = {}

        attributes = self.n.get_attributes()
        # Fall back to the name of the root node if it carries no
        # "orig_name" attribute.
        self.required_attr["orig_name"] = attributes.get("orig_name", self.n.get_name())

    def do_modify(self, logger: Logger) -> None:
        """
        Apply the modification of the matched pattern to the graph.
        Clean up unconnected graph parts left behind by modification.
        In debug mode, run the lost initializer check.
        Remove unused initializer nodes.
        """
        self.collect_required_new_node_attributes()
        self.modify()
        model_dict = self._get_model_dict()
        model_dict.remove_unconnected()
        if logger.context.debug:
            model_dict._lost_ini_helper.check(self.get_matcher_class_name(), logger)
        model_dict.remove_unused_ini_nodes()
        # the sanity check is rather expensive in terms of execution time (about
        # 0.1s for PSR model), enable only for debugging when something goes
        # wrong
        # model_dict.sanity_check()

    @abstractmethod
    def match(self) -> None:
        """
        Match the node and its surroundings to the pattern.
        The function returns if there is a match.
        The function raises MatcherError or a derived exception if there is no
        match.
        """

    @abstractmethod
    def modify(self) -> None:
        """
        Apply the modification of the matched pattern to the graph.
        This may be a fusion, a replacement, or a removal.
        """

    def add_node(
        self,
        type: str,
        domain: str,
        inputs: dict[str, Optional[Tensor]],
        outputs: dict[str, Optional[Tensor]],
        attributes: Optional[dict[str, Any]] = None,
        new_name: Optional[str] = None,
        add_matcher_name: bool = True,
        model_dict: Optional[ModelDict] = None,
        required_attr: Optional[dict[str, str]] = None,
    ) -> Node:
        """
        Append a new node to the model and return it.
        type -- operator type of the new node
        domain -- domain of the new node
        inputs -- input tensors by key; entries that are None or Nowhere are
                  dropped
        outputs -- output tensors by key; entries that are None or Nowhere
                   are dropped
        attributes -- attributes of the new node; the passed dict is not
                      modified. "orig_name" is added if missing and
                      "native_dtype" is always set.
        new_name -- name of the new node, overrides the default name
                    <orig_name>_<type>
        add_matcher_name -- add the name of this matcher class as attribute
        model_dict -- model to add the node to, defaults to the model of the
                      current root element
        required_attr -- if passed, replaces self.required_attr; must
                         contain "orig_name"
        """
        if model_dict is None:
            model_dict = self._get_model_dict()
        if required_attr is not None:
            self.required_attr = required_attr
        # Filter all inputs and outputs which are None and use the name of the
        # tensor as value.
        inputs = {
            k: v.get_name()
            for k, v in inputs.items()
            if v is not None and not v.check_nowhere()
        }

        outputs = {
            k: v.get_name()
            for k, v in outputs.items()
            if v is not None and not v.check_nowhere()
        }

        # For the nodes of the original graph the attributes are added in
        # update_node_attributes before calling modify.
        # For fused node they are carried over the central node of the pattern.
        # Work on a copy: entries are added below and the dict of the caller
        # must not change as a side effect.
        attributes = {} if attributes is None else dict(attributes)
        assert (
            "orig_name" in self.required_attr
        ), "original name is needed for the naming schema of the new nodes"  # sanity check

        # If the required attributes were not specified,
        # add them to the attributes of the new node

        if "orig_name" not in attributes:
            attributes["orig_name"] = self.required_attr["orig_name"]

        attributes["native_dtype"] = self.get_native_dtype(type)

        # Extract the original name of the "central" node of the pattern.
        # WARNING: This hack is needed in order to keep the node naming equivalence
        # between the old yaml flow and the new python one.
        name = attributes["orig_name"] + "_" + type

        # In certain cases we need to force the name to be a specific one and not following the
        # naming scheme of: new= <original_name> + "_" + <type>
        # TODO: remove this hack when not needed anymore
        if new_name is not None:
            name = new_name

        model_dict.append_node(name, type, inputs, outputs, domain)

        for attr_name, attr_value in attributes.items():
            model_dict.append_attribute(name, attr_name, attr_value)

        if add_matcher_name:
            model_dict.append_attribute(
                name, matcher_name_attribute, self.get_matcher_class_name()
            )

        return Node(model_dict, WalkCfgPlain(), name)

    def flag_initializers_used(self, *initializers: Tensor):
        """
        Indicates the usage of an initializer for the lost initializer check.
        Every passed tensor is required to be an initializer.
        """
        for initializer in initializers:
            initializer.require_initializer()
            initializer._model_dict._lost_ini_helper.inc(initializer.get_name())

    def add_input(self, node: Element, tensor: Tensor) -> None:
        """
        Append the tensor to the inputs of the node.
        """
        node = node.get_non_tensor()
        assert isinstance(node, Node), "expected node"  # sanity check
        node._model_dict.append_input(node.get_name(), tensor.get_name())

    def add_output(self, node: Element, tensor: Tensor) -> None:
        """
        Append the tensor to the outputs of the node.
        """
        node = node.get_non_tensor()
        assert isinstance(node, Node), "expected node"  # sanity check
        node._model_dict.append_output(node.get_name(), tensor.get_name())

    def add_initializer(
        self,
        initializer_name: str,
        value: np.ndarray | float | int,
        # `dtype`` is not required if the value is `np.ndarray`
        dtype: Optional["onnx.TensorProto.DataType"] | OnnxDType = None,
    ) -> Initializer:
        """
        Add an initializer to the model of the current root element and
        return it.
        initializer_name -- name of the new initializer
        value -- array, or scalar stored with empty dims
        dtype -- data type of the initializer, required for scalar values
        """
        md = self._get_model_dict()
        if isinstance(value, np.ndarray):
            md.add_initializer(onnx.numpy_helper.from_array(value, initializer_name))
        elif isinstance(value, (float, int)):
            assert (
                dtype is not None
            ), "provide a data type for the new initializer"  # sanity check
            tensor = onnx.helper.make_tensor(initializer_name, dtype, [], [value])
            md.add_initializer(tensor)
        # NOTE(review): a value of any other type (e.g. a numpy scalar that is
        # neither float nor int) adds nothing to the model, but an Initializer
        # for the name is returned nevertheless -- confirm whether this
        # should raise instead.
        return Initializer(md, WalkCfgPlain(), initializer_name)

    def add_transposed_initializer(
        self,
        initializer: Initializer,
        transposed_initializer_name: str,
        permutation: Optional[list[int]] = None,
    ) -> Initializer:
        """
        Let the model add a transposed variant of the initializer under the
        passed name and return the new initializer.
        The permutation is forwarded to the model as is.
        """
        initializer._model_dict.add_transposed_initializer(
            initializer.get_name(), transposed_initializer_name, permutation
        )
        return Initializer(
            initializer._model_dict,
            initializer._walk_cfg,
            transposed_initializer_name,
        )

    def get_initializers_for_split(
        self, initializer: Initializer, split_factor: int
    ) -> Optional[list[Initializer]]:
        """
        Let the model split the initializer into split_factor parts and
        return the parts.
        If the model reports the initializer as scalar, no new initializers
        are created and the passed initializer is returned split_factor
        times.
        Raise MatcherError if the model reports that the split is impossible.
        """
        initializer_name = initializer.get_name()
        added_initializer_names = initializer._model_dict.add_split_initializer(
            initializer_name, split_factor
        )

        # if the initializer is scalar or cannot be split into equal parts with the `split_factor` along the `axis``,
        # new initializers will not be created
        if added_initializer_names == "scalar":
            return [initializer] * split_factor
        if added_initializer_names == "impossible_split":
            raise MatcherError(
                f"It's impossible to split the initializer into {split_factor} equal parts."
            )

        return [
            Initializer(
                initializer._model_dict,
                initializer._walk_cfg,
                name,
            )
            for name in added_initializer_names
        ]

    def connect(self, trgt: Element, src: Element) -> None:
        """
        Connect the target to the src.
        Whatever is connected to the target will be disconnected.
        trgt is supposed to be an OutputTensor, src is supposed to be an
        InputTensor. (The type annotations don't reflect this, because this is
        called with Elements).
        For all readers of the target tensor: set the tensor name of the input
        of the target reader to the name of the src tensor.
        If the target is a graph output, the writer of the src tensor is
        changed to write the graph output instead and all readers of the src
        tensor are changed to read the graph output.
        """
        trgt_tensor = trgt.require_tensor()
        assert (
            trgt_tensor.check_graph_output() or trgt_tensor.check_output_tensor()
        )  # sanity check

        src_tensor = src.require_tensor()
        assert (
            src_tensor.check_graph_input() or src_tensor.check_input_tensor()
        )  # sanity check

        # If the target tensor is an output of the graph,
        # change the node which produces the src_tensor so
        # that it now produces this graph output
        if trgt_tensor.check_graph_output():
            assert (
                not src_tensor.check_graph_input()
            ), "Cannot connect a graph input to a graph output"  # sanity check

            src_node = src_tensor.get_writer().require_node()
            new_input_name = trgt_tensor.get_name()
            src_node._model_dict.replace_output(
                src_node.get_name(), src_tensor.get_name(), new_input_name
            )

            # Connect all readers of src_tensor to the new output which replaces it
            for src_rd in src_tensor.get_readers():
                assert (
                    src_rd.check_node()
                ), f"Matcher.connect(): src_rd {src_rd.get_name()} type {type(src_rd)} not implemented"
                src_reader = src_rd.require_node()
                src_reader._model_dict.replace_input(
                    src_reader.get_name(), src_tensor.get_name(), new_input_name
                )
        # if target tensor is not a graph output we can connect all of its
        # readers to src tensor
        else:
            new_input_name = src_tensor.get_name()
            for trgt_rd in trgt_tensor.get_readers():
                assert (
                    trgt_rd.check_node()
                ), f"Matcher.connect(): trgt_rd {trgt_rd.get_name()} type {type(trgt_rd)} not implemented"
                trgt_reader = trgt_rd.require_node()
                trgt_reader._model_dict.replace_input(
                    trgt_reader.get_name(),
                    trgt_tensor.get_name(),
                    new_input_name,
                )

    def contains_matcher(self, matcher_name: str) -> bool:
        """
        Return whether the passed name is the class name of this matcher.
        """
        return self.get_matcher_class_name() == matcher_name

    def get_matchers(self) -> list["Matcher"]:
        """
        Return a list with this matcher as only entry.
        """
        return [self]

    @classmethod
    def get_matcher_class_name(cls) -> str:
        """
        Return the name of the matcher class.
        """
        return cls.__name__

    def get_native_dtype(self, op_type: str) -> Optional[str]:
        """
        Look up the native data type of the operator type in the kernel
        metadata.
        """
        kml = KernelMetadataLoader()
        return kml.get_native_dtype(op_type)

    def remove_input(self, input_tensor: Element) -> None:
        """
        Remove the input tensor from the inputs of its origin node.
        """
        assert isinstance(
            input_tensor, InputTensor
        ), "expected input tensor"  # sanity check
        node = input_tensor.get_origin()
        node._model_dict.remove_input(node.get_name(), input_tensor.get_name())

    def remove_inputs(self, *input_tensors: InputTensor) -> None:
        """
        Call remove_input() for each of the passed input tensors.
        """
        for input_tensor in input_tensors:
            self.remove_input(input_tensor)

    def remove_output(self, output_tensor: Element) -> None:
        """
        Remove the output tensor from the outputs of its origin node.
        """
        assert isinstance(
            output_tensor, OutputTensor
        ), "expected output tensor"  # sanity check
        node = output_tensor.get_origin()
        node._model_dict.remove_output(node.get_name(), output_tensor.get_name())

    def remove_outputs(self, *output_tensors: OutputTensor) -> None:
        """
        Call remove_output() for each of the passed output tensors.
        """
        for output_tensor in output_tensors:
            self.remove_output(output_tensor)

    def remove_node(self, node: Element) -> None:
        """
        Remove the node from its model.
        """
        node = node.get_non_tensor()
        assert isinstance(node, Node), "expected node"  # sanity check
        node._model_dict.remove_node(node.get_name())

    def remove_nodes(self, *nodes: Element) -> None:
        """
        Call remove_node() for each of the passed nodes.
        """
        for node in nodes:
            self.remove_node(node)

    def replace_input(
        self,
        node: Element,
        input_tensor_old: InputTensor,
        input_tensor_new: Tensor,
    ) -> None:
        """
        Make the node read input_tensor_new instead of input_tensor_old.
        """
        node = node.get_non_tensor()
        assert isinstance(node, Node), "expected node"  # sanity check
        node._model_dict.replace_input(
            node.get_name(),
            input_tensor_old.get_name(),
            input_tensor_new.get_name(),
        )

    def replace_output(
        self,
        node: Element,
        output_tensor_old: Tensor,
        output_tensor_new: Tensor,
    ) -> None:
        """
        Make the node write output_tensor_new instead of output_tensor_old.
        """
        node = node.get_non_tensor()
        assert isinstance(node, Node), "expected node"  # sanity check
        node._model_dict.replace_output(
            node.get_name(),
            output_tensor_old.get_name(),
            output_tensor_new.get_name(),
        )


class Category(MatcherOrCategory):
    """
    A category of matchers or sub-categories.
    """

    def __init__(self, matchers_or_categories: list["MatcherOrCategory"]) -> None:
        """
        matchers_or_categories -- members of this category, in running order
        """
        self.matchers_or_categories = matchers_or_categories

    def contains_matcher(self, matcher_name: str) -> bool:
        """
        Return whether any member of this category is or contains a matcher
        with the passed name.
        """
        return any(
            member.contains_matcher(matcher_name)
            for member in self.matchers_or_categories
        )

    def get_matchers(self) -> list["Matcher"]:
        """
        Return the matchers of all members as one flat list, keeping the
        order of the members.
        """
        matchers: list["Matcher"] = []
        for matcher_or_category in self.matchers_or_categories:
            matchers.extend(matcher_or_category.get_matchers())
        return matchers
