# fmt: on
from collections import defaultdict
from itertools import groupby
from math import prod
import numpy as np
from typing import Any, Iterable
from dataclasses import dataclass, field
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    Initializer,
    InputTensor,
    MatcherError,
    NoMatch,
    Node,
    OutputTensor,
    Tensor,
)
from OGOAT.src.L1_fusion.py_match.skip import WalkCfgSkipNoop
from OGOAT.src.L1_fusion.py_match.checkers import AttrValue, opType
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper


@dataclass(frozen=True)
class AttributeBatchSignature:
    """
    The batch signature of an attribute is an immutable extract from the
    attribute that contains the information that needs to be equal for several
    nodes to be batched together. As this batch signature is immutable, it is
    hashable and can be used as a key of a map.
    The attribute batch signature contains:
      - attribute name
      - attribute value: str/int/float scalars are stored unchanged, a list or
        tuple of ints/floats is stored as a tuple of floats
    """

    name: str
    # Scalars are kept as-is; numeric sequences become a (hashable) tuple of floats.
    value: "str | int | float | tuple[float, ...]"

    @staticmethod
    def from_attribute(name: str, value: Any) -> "AttributeBatchSignature":
        """
        Build the batch signature of attribute `name` with value `value`.

        Raises MatcherError if `value` is neither a str/int/float scalar nor a
        list/tuple whose elements are all ints/floats.
        """
        if isinstance(value, (str, int, float)):
            val = value
        elif isinstance(value, (list, tuple)) and all(
            isinstance(v, (int, float)) for v in value
        ):
            # Convert to an immutable tuple so the signature stays hashable.
            val = tuple(float(v) for v in value)
        else:
            raise MatcherError(
                f"AttributeBatchSignature: Unsupported attribute value type: {type(value)}"
            )
        return AttributeBatchSignature(name=name, value=val)


@dataclass(frozen=True)
class InputOutputBatchSignature:
    """
    Common base for the batch signatures of node inputs and outputs.

    It is the part of an input/output that must be identical for several
    nodes to be batched together. Being a frozen dataclass, it is immutable
    and hashable, so it can be used as a key of a map.

    Fields shared by inputs and outputs:
      - prm_name: schema parameter name of the input/output
      - dtype: data type of the tensor
      - shape: shape of the tensor
    Subclasses add fields specific to inputs or outputs.

    FIXME: the first dimension could perhaps be left out of the shape, so
           that already batched nodes can be batched again, e.g.
           num_heads=3 with shape 3 x 8 x 16 x 32 plus
           num_heads=7 with shape 7 x 8 x 16 x 32 into
           num_heads=10 with shape 10 x 8 x 16 x 32.
    """

    prm_name: str  # schema parameter name
    dtype: str  # tensor data type
    shape: tuple[int, ...]  # tensor shape, as an immutable tuple


@dataclass(frozen=True)
class InputBatchSignature(InputOutputBatchSignature):
    """
    Batch signature of a node input. On top of the common fields
    (prm_name, dtype, shape) it records whether the input is an initializer.
    """

    is_initializer: bool  # True if the input tensor is an initializer

    @staticmethod
    def from_input(in_prm_name: str, input_: InputTensor) -> "InputBatchSignature":
        """Build the batch signature of `input_`, bound to schema input `in_prm_name`."""
        tensor_dtype = input_.get_dtype()
        tensor_shape = tuple(input_.get_shape())
        initializer_flag = input_.check_initializer()
        return InputBatchSignature(
            prm_name=in_prm_name,
            dtype=tensor_dtype,
            shape=tensor_shape,
            is_initializer=initializer_flag,
        )


@dataclass(frozen=True)
class OutputBatchSignature(InputOutputBatchSignature):
    """
    Batch signature of a node output. It only uses the common fields:
      - prm_name: schema output name
      - dtype: data type of the output tensor
      - shape: shape of the output tensor
    """

    @staticmethod
    def from_output(
        out_prm_name: str,
        output_: OutputTensor,
    ) -> "OutputBatchSignature":
        """Build the batch signature of `output_`, bound to schema output `out_prm_name`."""
        tensor_dtype = output_.get_dtype()
        tensor_shape = tuple(output_.get_shape())
        return OutputBatchSignature(
            prm_name=out_prm_name,
            dtype=tensor_dtype,
            shape=tensor_shape,
        )


@dataclass(frozen=True)
class NodeBatchSignature:
    """
    The batch signature of a node is an immutable extract from the node that
    contains the information that needs to be equal for several nodes to be
    batched together. As this batch signature is immutable, it is hashable and
    can be used as a key of a map.
    The node batch signature contains:
      - op_type
      - attributes (except the ignored ones)
      - signatures of the inputs (paired with the schema input names)
      - signatures of the outputs (paired with the schema output names)
      - single_output_reader flag (only computed when batching by output
        tensor, see from_node)
    """

    op_type: str
    attributes: tuple[AttributeBatchSignature, ...]
    inputs: tuple[InputBatchSignature, ...]
    outputs: tuple[OutputBatchSignature, ...]
    single_output_reader: bool

    @staticmethod
    def from_node(
        node: Node,
        ignore_attributes: Iterable[str],
        batch_by_out_tensor: bool,
    ) -> "NodeBatchSignature":
        """
        Build the batch signature of `node`.

        node -- node to extract the signature from
        ignore_attributes -- names of attributes left out of the signature
                             (callers pass e.g. "orig_name", "native_dtype");
                             any iterable is accepted
        batch_by_out_tensor -- when True, `single_output_reader` is computed.
            It is True only if the node (walking with WalkCfgSkipNoop) has
            exactly one output, that output (after skipping no-ops) has exactly
            one reader, `check(opType.Concat_runtime)` holds for it, and the
            node returned by `require_node()` has no "create_by_batch"
            attribute. In every other case it is False.
        """
        # Materialize once: a one-shot iterator (e.g. a generator) would be
        # consumed by the first membership test below and later attributes
        # would no longer be ignored.
        ignored = frozenset(ignore_attributes)

        single_output_reader = False

        # If 'batch_by_out_tensor' is enabled, inspect the output (skipping no-ops).
        if batch_by_out_tensor:
            outputs = node.with_walk_cfg(WalkCfgSkipNoop()).get_outputs()
            # Currently we only support batching nodes having only one output and one reader.
            # FIXME This is a limitation, we should support batching nodes with multiple readers once the infrastructure is ready.
            if len(outputs) == 1:
                out_tensor_after_skip = outputs[0].skip()
                readers = out_tensor_after_skip.get_readers()
                # move forward only if out_tensor_after_skip has a single reader
                if len(readers) == 1:
                    # Concat_runtime not created by batching ("create_by_batch" attribute unset)
                    if out_tensor_after_skip.check(opType.Concat_runtime) and out_tensor_after_skip.require_node().get_attribute_value("create_by_batch") is None:
                        single_output_reader = True

        return NodeBatchSignature(
            op_type=node.get_op_type(),
            attributes=tuple(
                AttributeBatchSignature.from_attribute(name, value)
                for name, value in node.get_attributes().items()
                if name not in ignored
            ),
            inputs=tuple(
                InputBatchSignature.from_input(prm, inp)
                for prm, inp in zip(node.get_schema_input_names(), node.get_inputs())
            ),
            outputs=tuple(
                OutputBatchSignature.from_output(prm, out)
                for prm, out in zip(node.get_schema_output_names(), node.get_outputs())
            ),
            single_output_reader=single_output_reader,
        )


@dataclass(frozen=True)
class UpwardChainSignature:
    """
    Signature of the graph region above an activation input, as built by
    SignatureChainBuilder.get_upward_chain_signature:
      - chain: batch signatures of the traversed nodes, in visiting order
      - end_tensor_names: names of the tensors where the traversal stopped

    Both fields are lists, which are not hashable. __hash__ is therefore
    defined explicitly by folding every element, in order, into a running
    hash. Equality is the dataclass-generated field-wise comparison.
    """

    chain: list["NodeBatchSignature"]
    end_tensor_names: list[str]

    def __hash__(self):
        # Chain elements first, then end tensor names; each step hashes the
        # pair (previous digest, element).
        digest = hash(None)
        for element in (*self.chain, *self.end_tensor_names):
            digest = hash((digest, element))
        return digest


class SignatureChainBuilder:
    """
    Computes upward chain signatures. Starting from a tensor, the builder walks
    towards the producers through "pass-through" ops. It records the batch
    signature of every node traversed and the names of the tensors where the
    walk stops.
    """

    # Op types the upward walk continues through; any other producer (or a
    # tensor whose writer is not a node) terminates that branch of the walk.
    OP_LIST_TO_CONTINUE_TRAVERSE = {
        "Reshape",
        "Transpose",
        "Resize",
        "Squeeze",
        "Unsqueeze",
        "DequantizeLinear",
        "QuantizeLinear",
    }

    def get_upward_chain_signature(
        self, input_tensor: InputTensor
    ) -> UpwardChainSignature:
        """
        Return the UpwardChainSignature of `input_tensor`.

        Raises MatcherError if a traversed node has more than one output.
        """
        node_signatures: list["NodeBatchSignature"] = []
        stop_tensor_names: list[str] = []
        self._traverse_upwards(input_tensor, node_signatures, stop_tensor_names)
        return UpwardChainSignature(
            chain=node_signatures,
            end_tensor_names=stop_tensor_names,
        )

    def _traverse_upwards(
        self,
        input_tensor: InputTensor,
        chain: list[NodeBatchSignature],
        end_tensor_names: list[str],
    ) -> None:
        """
        Walk upwards from `input_tensor` depth-first (explicit stack), filling
        the two output lists in place.

        chain -- receives the NodeBatchSignature (ignoring "orig_name" and
                 "native_dtype") of every traversed node, in visiting order
        end_tensor_names -- receives the name of every tensor whose writer is
                 not a node, or is a node whose op type is not in
                 OP_LIST_TO_CONTINUE_TRAVERSE
        Each tensor name is visited at most once; initializer inputs are not
        followed.

        Raises MatcherError when a traversed node has more than one output:
        comparing such chains would require analyzing how the nodes are
        interconnected (which output feeds which input).
        """
        to_visit: list[InputTensor] = [input_tensor]
        seen_names: set[str] = set()

        while to_visit:
            tensor = to_visit.pop()
            tensor_name = tensor.get_name()
            if tensor_name in seen_names:
                continue
            seen_names.add(tensor_name)

            producer = tensor.get_writer()

            # Stop this branch unless the producer is a node with a pass-through op type.
            passes_through = (
                producer.check_node()
                and producer.get_op_type() in self.OP_LIST_TO_CONTINUE_TRAVERSE
            )
            if not passes_through:
                end_tensor_names.append(tensor_name)
                continue

            if len(producer.get_outputs()) > 1:
                raise MatcherError(
                    f"Node '{producer.get_name()}' has multiple outputs, which is not supported in upward chain traversal."
                )

            chain.append(
                NodeBatchSignature.from_node(
                    producer,
                    ignore_attributes=["orig_name", "native_dtype"],
                    batch_by_out_tensor=False,
                )
            )

            # Follow activation (non-initializer) inputs only. They are pushed
            # in input order, so the last one is explored first.
            to_visit.extend(
                producer_input
                for producer_input in producer.get_inputs()
                if not producer_input.check_initializer()
            )


@dataclass
class BatchDataGenerator(SignatureChainBuilder):
    """
    Groups the activation inputs of a list of batchable nodes.

    For each activation (non-initializer) input name, the upward chain
    signature of that input is computed for every node. Inputs with equal
    signatures are treated as equivalent. Only the first one is kept as a
    unique input, and every node is assigned the sequence ID of its signature.

    Call generate_data() before using the getters.
    """

    # Nodes to be batched; node_list[0] is the reference for the input names.
    node_list: list[Node]
    # Names of the activation (non-initializer) inputs; (re)computed by generate_data().
    act_input_names: list[str] = field(default_factory=list)
    # Activation input name -> {"inputs<i>": tensor}, where i is the index in
    # node_list of the first node whose input has a given upward chain signature.
    unique_act_inputs_dict: dict[str, dict[str, Tensor]] = field(default_factory=dict)
    # Activation input name -> one sequence ID per node (node_list order);
    # equal IDs mean equal upward chain signatures.
    act_inputs_seq_id_for_all_nodes: dict[str, list[int]] = field(default_factory=dict)

    def generate_data(self):
        """
        Determine the activation input names from node_list[0] and compute the
        unique activation inputs and sequence IDs for all nodes.
        """
        act_input_names = [
            input_name
            for input_name, inp_tensor in self.node_list[0].get_inputs_dict().items()
            if not inp_tensor.check_initializer()
        ]
        self.act_input_names = act_input_names
        self.generate_batched_data_from_act_inputs()

    def generate_batched_data_from_act_inputs(self):
        """
        For every activation input name, fill unique_act_inputs_dict and
        act_inputs_seq_id_for_all_nodes.

        Example: for input "M" and 6 nodes whose "M" inputs have 3 distinct
        upward chain signatures, unique_act_inputs_dict["M"] holds 3 tensors and
        act_inputs_seq_id_for_all_nodes["M"] holds 6 IDs in the range 0..2.
        """
        for act_input_name in self.act_input_names:
            unique_act_inputs: dict[str, Tensor] = {}
            # Upward chain signature -> sequence ID (dense, in order of first appearance)
            sig_to_id: dict[UpwardChainSignature, int] = {}
            sig_id_list: list[int] = []

            for i, node in enumerate(self.node_list):
                input_tensor = node(act_input_name)
                upward_chain_sig = self.get_upward_chain_signature(input_tensor)
                if upward_chain_sig not in sig_to_id:
                    # First occurrence of this signature: assign the next ID
                    # and keep this tensor as the representative input.
                    sig_to_id[upward_chain_sig] = len(sig_to_id)
                    unique_act_inputs[f"inputs{i}"] = input_tensor

                sig_id_list.append(sig_to_id[upward_chain_sig])

            self.unique_act_inputs_dict[act_input_name] = unique_act_inputs
            self.act_inputs_seq_id_for_all_nodes[act_input_name] = sig_id_list

    def get_activation_input_names(self) -> list[str]:
        """Get the activation input names computed by generate_data()."""
        return self.act_input_names

    def get_unique_act_inputs_dict(self, act_input_name: str) -> dict[str, Tensor]:
        """Retrieve the unique activation input dict for a given activation input name."""
        return self.unique_act_inputs_dict[act_input_name]

    def get_all_unique_act_inputs_dict(self) -> dict[str, dict[str, Tensor]]:
        """Retrieve all unique activation input dicts."""
        return self.unique_act_inputs_dict

    def get_act_inputs_sig_id_list(self, act_input_name: str) -> list[int]:
        """Retrieve the sequence ID list for a given activation input name."""
        return self.act_inputs_seq_id_for_all_nodes[act_input_name]


class BatchHelper(QDQHelper):
    """
    Helper methods for batch fusions.

    The shape helpers (e.g. get_individual_shape_for_batching) read
    `self.batch_dimension`, which is assigned by set_batch_dimension.
    """

    def set_batch_dimension(self, dimension: int) -> None:
        """
        Set the dimension along which batching takes place.

        Dimensions are counted from 1 starting at the right (e.g. a rank-3
        shape has dimensions [3, 2, 1]). When an individual tensor is prepared
        for batching (see get_individual_shape_for_batching):
          - if it has fewer dimensions than `dimension`, dimensions of size 1
            are prepended;
          - otherwise its leading dimensions are collapsed (multiplied) and a
            leading 1 is added, so the result has exactly `dimension`
            dimensions.
        """
        self.batch_dimension = dimension

    def _insert_reshape(
        self, name: str, data: Tensor, shape: Initializer, reshaped: Tensor
    ) -> Node:
        """
        Insert a Reshape node named `name` that reads `data` and `shape` and
        writes `reshaped`. Return the new node.
        """
        inputs = {"data": data, "shape": shape}
        outputs = {"reshaped": reshaped}
        return self.add_node(
            type="Reshape",
            domain="",
            inputs=inputs,
            outputs=outputs,
            attributes={},
            new_name=name,
            add_matcher_name=False,
        )

    def _insert_reshape_after(
        self, output_tensor: Tensor, reshape_name: str, shape_before: list[int]
    ) -> Node:
        """
        Insert a Reshape node that writes `output_tensor`.

        A new tensor "<reshape_name>_input" with shape `shape_before` and the
        dtype of `output_tensor` becomes the Reshape's "data" input. The
        Reshape converts it to the current shape of `output_tensor`.
        Return the new Reshape node; its new input tensor is `node("data")`.
        """
        shape_initializer = self._add_shape_initializer(output_tensor.get_shape())

        input_tensor = Tensor(
            output_tensor._model_dict,
            output_tensor._walk_cfg,
            reshape_name + "_input",
            reshape_name,
        )
        input_tensor.set_shape(shape_before, output_tensor.get_dtype())

        return self._insert_reshape(
            reshape_name, input_tensor, shape_initializer, output_tensor
        )

    def get_splitted_batch_outputs(
        self, batch_nodes: list[Node]
    ) -> dict[str, InputTensor]:
        """
        Add the batch split nodes for the outputs of the new batch node to the
        graph.
        Return the input tensors of the new split nodes as a dictionary with the
        same output parameter names and in the same order as the outputs of the
        original nodes.
        """

        # All the nodes to be batched have the same structure, so it is okay
        # to get the required information just from the first one.
        main_node = batch_nodes[0]
        out_prm_names = main_node.get_schema_output_names()
        main_node_name = main_node.get_name()

        # outputs[out_prm_idx][orig_node_idx] = output tensor
        # One independent list per output parameter: `[[]] * n` would alias a
        # single list and mix the outputs of all parameters together.
        outputs: list[list[Element]] = [[] for _ in out_prm_names]

        # collect outputs from all original nodes
        for node in batch_nodes:
            for out_prm_idx, outp in enumerate(node.get_outputs()):
                outputs[out_prm_idx].append(outp)

        # create the split nodes
        split_inputs = {}
        for out_prm_idx, out_prm_name in enumerate(out_prm_names):
            node_name = f"{main_node_name}_split_{out_prm_name}"

            # For each output of the split, add a Reshape if its dimensions do
            # not match the batching shape; the batched output shape computed
            # below depends on these shapes.
            for idx, output in enumerate(outputs[out_prm_idx]):
                shape = output.get_shape()
                individual_batching_shape = self.get_individual_shape_for_batching(
                    output
                )
                if shape != individual_batching_shape:
                    reshape = self._insert_reshape_after(
                        output,
                        node_name + "_output" + str(idx),
                        individual_batching_shape,
                    )
                    # the split now writes the Reshape's input instead
                    outputs[out_prm_idx][idx] = reshape("data").require_tensor()

            input_tensor = InputTensor(
                main_node._model_dict,
                main_node._walk_cfg,
                f"{main_node_name}_split_{out_prm_name}_input",
                node_name,
            )
            # The output of the batched node, which is the input of the split,
            # is computed from the (possibly reshaped) split outputs: they
            # carry the leading batching dimension of size 1.
            self.set_batched_tensor_shape_dtype(input_tensor, outputs[out_prm_idx])
            self.add_node(
                "Split_runtime",
                domain="ai.onnx.contrib",
                inputs={"input": input_tensor},
                outputs={
                    f"output_{orig_node_idx}": outp
                    for orig_node_idx, outp in enumerate(outputs[out_prm_idx])
                },
                attributes={"axis": 0, "num_outputs": len(batch_nodes)},
                new_name=node_name,
                add_matcher_name=False,
            )
            split_inputs[out_prm_name] = input_tensor

        return split_inputs

    def _insert_reshape_before(
        self, input_tensor: Tensor, reshape_name: str, shape_after: list[int]
    ) -> Node:
        """
        Insert a Reshape node that reads `input_tensor`.

        The Reshape writes a new tensor "<reshape_name>_output" with shape
        `shape_after` and the dtype of `input_tensor`.
        Return the new Reshape node; its new output tensor is `node("reshaped")`.
        """

        shape_initializer = self._add_shape_initializer(shape_after)

        output_tensor = Tensor(
            input_tensor._model_dict,
            input_tensor._walk_cfg,
            reshape_name + "_output",
            reshape_name,
        )
        output_tensor.set_shape(shape_after, input_tensor.get_dtype())

        return self._insert_reshape(
            reshape_name, input_tensor, shape_initializer, output_tensor
        )

    def create_concat_runtime_node(
        self, n: Node, input_name: str, inputs: dict[str, Tensor]
    ) -> Node:
        """
        Create a Concat_runtime node (axis 0) that concatenates the given
        tensors into the batched input of a batched node.

        n -- node being batched (its "orig_name" attribute is used to name the
             new node and its output tensor)
        input_name -- name of the input parameter being batched
        inputs -- dictionary with the Concat_runtime input parameter names as
                  keys and the tensors to concatenate as values. NOTE: entries
                  that need a Reshape for batching are replaced in place by the
                  reshaped tensor, i.e. the caller's dictionary is modified.
        return -- the new Concat_runtime node; its output is "concat_result"
        """
        new_name = input_name + "_Concat_runtime_" + n.get_attribute_value("orig_name")
        output_tensor = Tensor(
            n._model_dict,
            n._walk_cfg,
            f'{input_name}_{n.get_attribute_value("orig_name")}_Concat_runtime_output',
            None,
        )

        # For each input of the concat, add a Reshape if its dimensions do not match for batching
        for key in inputs.keys():
            input_tensor = inputs[key]
            shape = input_tensor.get_shape()
            # FIXME The individual shape computation is currently limited to non batched matmul
            # the following case is not supported:
            # For e.g 3 matmul with A shape = [3, 2, 6] and B shape = [3, 6, 10] and batch_dimension = 3.
            # As result of this function the new shapes will be A shape [1, 6, 6] and B shape [1, 18, 10] and after batching the shape of the new node will be [3, 6, 6] and [3, 18, 10]
            # which violate the requirement that the K dimension should be the same.

            individual_batching_shape = self.get_individual_shape_for_batching(
                input_tensor
            )
            if shape != individual_batching_shape:
                reshape = self._insert_reshape_before(
                    input_tensor,
                    new_name + key + "Reshape",
                    individual_batching_shape,
                )
                inputs[key] = reshape("reshaped").require_tensor()

        input_tensors = list(inputs.values())
        self.set_batched_tensor_shape_dtype(output_tensor, input_tensors)
        # NOTE(review): not read in this class; kept in case other code relies on it.
        self.n = n
        return self.add_node(
            type="Concat_runtime",
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs={"concat_result": output_tensor},
            attributes={"axis": 0, "create_by_batch": True},
            new_name=new_name,
            add_matcher_name=False,
            model_dict=n._model_dict,
        )

    def check_input_shape(self, node: Node, min_length: int) -> None:
        """
        Check that every non-initializer input of `node` has at least
        `min_length` dimensions; raise MatcherError otherwise.
        """
        for val in node.get_inputs_dict().values():
            if not val.check_initializer() and len(val.get_shape()) < min_length:
                raise MatcherError(
                    f"The shape of input tensor - {val} is not valid for batch fusion."
                )

    def get_init_vals(
        self,
        node_name: str,
        inputs: dict[str, InputTensor],
        initializer_value_arrays: dict[str, dict[str, Any]],
        input_name: str = None,
    ) -> None:
        """
        Collect the values of the initializer inputs into
        `initializer_value_arrays` (updated in place).

        node_name -- prefix of the name given to a batched initializer
        inputs -- dictionary with the input parameter names as keys and the
                  input tensors as values; non-initializer inputs are skipped
        initializer_value_arrays -- accumulators keyed by input name, each a
                  dict with "initializer_vals" (list of arrays), "dtype" and
                  "name"
        input_name -- accumulator key used for all inputs; if None, the key of
                      each entry of `inputs` is used

        Each value is made at least 1-D and, unless its first dimension is
        already 1, gets a leading dimension of size 1 so the arrays can later
        be concatenated along axis 0. An accumulator's "dtype" and "name" are
        set from the first initializer added while its "dtype" is still None.
        Every collected initializer is flagged as used.
        """
        for key, value in inputs.items():
            if not value.check_initializer():
                continue
            value.require_initializer().flag_used()

            array_val = np.atleast_1d(value.require_initializer().get_value_as_array())

            if array_val.shape[0] != 1:
                array_val = np.reshape(array_val, tuple([1] + list(array_val.shape)))
            dtype = value.get_dtype()
            val_dict = initializer_value_arrays[input_name if input_name else key]

            val_dict["initializer_vals"].append(array_val)

            if val_dict["dtype"] is None:
                val_dict["dtype"] = dtype
                val_dict["name"] = f"{node_name}{value.get_name()}"

    def get_reshaped_tensor(self, input_tensor: Tensor, node_name: str) -> Tensor:
        """
        Prepare an input tensor for batching. If its shape does not match the required
        individual batching shape, insert a reshape node and return the reshaped tensor.
        Otherwise, return the original tensor.
        """
        shape = input_tensor.get_shape()
        individual_batching_shape = self.get_individual_shape_for_batching(input_tensor)

        if shape != individual_batching_shape:
            reshape = self._insert_reshape_before(
                input_tensor,
                "Reshape_" + node_name,
                individual_batching_shape,
            )
            input_tensor = reshape("reshaped").require_tensor()
        return input_tensor

    def _is_ordered(self, seq_list: list[int]) -> bool:
        """
        Return True if seq_list consists of repeated, contiguous, ordered sequences, else False.
        For example:   [0, 1, 0, 1] is ordered, but [0, 1, 1, 0] is not.
                    or [0, 1, 2, 0, 1, 2] is ordered, but [0, 1, 2, 2, 0, 1] is not.
                    or [0, 0, 1, 1, 2, 2] is marked as ordered, but [0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2] is not.
        An empty list is not ordered.
        See more testcases in IsOrderedTests class of OGOAT/testing/py_match_fast_unit_tests.py file.
        """

        # number of unique elements in list
        unique_elem_count = len(set(seq_list))
        if unique_elem_count < 1:
            return False
        # number of times each element must be repeated in list
        repeat_count = len(seq_list) // unique_elem_count
        if unique_elem_count * repeat_count != len(seq_list):
            return False  # division does not yield a whole integer
        # first supported case: 0, 1, 2, 0, 1, 2
        # list must equal first unique_elem_count repeated repeat_count times
        if seq_list == seq_list[:unique_elem_count] * repeat_count:
            return True
        # second supported case: 0, 0, 1, 1, 2, 2
        # list sampled with stride repeat_count, each sample repeated repeat_count times, must be equal
        if seq_list == [
            n for n in seq_list[::repeat_count] for r in range(repeat_count)
        ]:
            return True
        # other cases not supported
        return False

    def get_concated_batch_inputs(
        self,
        node_list: list[Node],
        unique_act_inputs_dict: dict[str, dict[str, Tensor]],
    ) -> dict[str, Tensor]:
        """
        Generate the input tensors for the new batch node.
        node_list -- list of batched nodes
        unique_act_inputs_dict -- dictionary with the activation input names as keys
                                 and a dictionary of unique input tensors as values.
                                 A single unique tensor is only reshaped (if needed);
                                 several are concatenated with a Concat_runtime node.
        return -- dictionary with the input parameter names as keys and the
                 input tensors as values (activation inputs first, then the
                 batched initializers)
        """

        initializer_value_arrays = defaultdict[str, dict[str, Any]](
            lambda: {"initializer_vals": [], "dtype": None, "name": None}
        )
        batch_inputs: dict[str, Tensor] = {}

        for input_name, unique_act_inputs in unique_act_inputs_dict.items():
            if len(unique_act_inputs) == 1:
                # Only one unique input tensor exists
                _, unique_tensor = next(iter(unique_act_inputs.items()))
                batch_inputs[input_name] = self.get_reshaped_tensor(
                    unique_tensor,
                    node_list[0].get_attribute_value("orig_name") + input_name,
                )
            else:
                # Inputs differ, need to concat
                batch_inputs[input_name] = self.create_concat_runtime_node(
                    node_list[0], input_name, unique_act_inputs
                )("concat_result")

        # fetch the initializer tensors values as np array from the original nodes
        for node in node_list:
            node_name = node.get_attribute_value("orig_name")
            inputs_dict = node.get_inputs_dict()
            self.get_init_vals(node_name, inputs_dict, initializer_value_arrays)

        # update inputs with the initializer tensors
        batch_inputs |= {
            key: self.add_initializer(
                val["name"],
                np.concatenate(val["initializer_vals"]),
                val["dtype"],
            )
            for key, val in initializer_value_arrays.items()
        }
        return batch_inputs

    def get_batched_shape(self, individual_shapes: list[list[int]]) -> list[int]:
        """
        Get the batched shape from the individual shapes before batching.
        individual_shapes -- list of shapes of inputs/outputs of original nodes,
                             shapes may only differ in the first dimension
        return -- shape of input/output of batched node: the first dimensions
                  are summed, the remaining ones are taken from the first shape
        """

        first = sum(i_s[0] for i_s in individual_shapes)
        return [first] + individual_shapes[0][1:]

    def cmp_batch_position(self, A: Element, B: Element) -> int:
        """
        Compare two batch nodes to determine their relative position in the
        batching (cmp-style: negative if A comes first, positive if B comes
        first, 0 if equivalent).

        No-op nodes are skipped. The first output of each node (arbitrary
        choice) is followed to its reader. If both readers are the same node,
        the nodes are ordered by the index of their tensor among the inputs of
        that common reader. If both are connected through the same tensor (a
        fork of one output), they compare equal. If a node has no outputs, or
        no common direct reader is found, the nodes are ordered by name so that
        the result stays stable.
        This allows optimizing more chains of slice -> concat produced between
        two batched layers.

        Raises MatcherError if the input indices in the common reader cannot be
        determined.
        """

        # Enable skipping the noop operations
        n1 = A.with_walk_cfg(WalkCfgSkipNoop())
        n2 = B.with_walk_cfg(WalkCfgSkipNoop())

        # Get the outputs of both nodes and take the first output tensor, skipping all noops.
        outputs1 = n1.get_outputs()
        outputs2 = n2.get_outputs()
        if not outputs1 or not outputs2:
            # nodes A and B should compare equal, compare by name to make output stable
            return BatchHelper.compare_by_string(A.get_name(), B.get_name())

        output1 = outputs1[0].skip()
        output2 = outputs2[0].skip()

        # See whether the two output tensors are read by one common node.
        output_n1 = output1.get_reader()
        output_n2 = output2.get_reader()
        if (
            not output_n1.check_node()
            or not output_n2.check_node()
            or output_n1 != output_n2
        ):
            # nodes A and B should compare equal, compare by name to make output stable
            return BatchHelper.compare_by_string(A.get_name(), B.get_name())

        # find the index, among the common reader's inputs, of each node's output
        output_idx_n1 = None
        output_idx_n2 = None
        for idx, reader_input in enumerate(output_n1.get_inputs()):
            if reader_input == output1:
                output_idx_n1 = idx
            if reader_input == output2:
                output_idx_n2 = idx

        # order the nodes following the common reader's input order
        if output_idx_n2 is None or output_idx_n1 is None:
            raise MatcherError(
                "Could not determine output indices for batch position comparison."
            )
        return output_idx_n1 - output_idx_n2

    @staticmethod
    def compare_by_string(A: str, B: str) -> int:
        """Return -1, 1 or 0 if A is less than, greater than or equal to B."""
        if A < B:
            return -1
        elif A > B:
            return 1
        return 0

    def get_batch_nodes(
        self,
        node: Element,
        op_same_level: list[Node],
        batch_by_out_tensor: bool,
    ) -> list[Node]:
        """
        Get the nodes of `op_same_level` that can be batched with `node`, i.e.
        whose batch signature (see get_batch_signature) equals that of `node`.

        node -- reference node
        op_same_level -- candidate nodes; the result keeps their order
        batch_by_out_tensor -- forwarded to get_batch_signature; if True, the
                               signature also contains the single_output_reader
                               flag (nodes must feed a single reader to match)
        return -- the matching candidates. The list may be empty or contain a
                  single node; no exception is raised here for those cases.
        """

        sig = self.get_batch_signature(node, batch_by_out_tensor)
        batch_nodes = [
            so
            for so in op_same_level
            if self.get_batch_signature(so, batch_by_out_tensor) == sig
        ]

        return batch_nodes

    def get_batch_signature(
        self, node: Node, batch_by_out_tensor: bool
    ) -> NodeBatchSignature:
        """Batch signature of `node`, ignoring the "orig_name" and "native_dtype" attributes."""
        return NodeBatchSignature.from_node(
            node,
            ignore_attributes=["orig_name", "native_dtype"],
            batch_by_out_tensor=batch_by_out_tensor,
        )

    def set_batched_tensor_shape_dtype(
        self, batched_tensor: Tensor, individual_tensors: list[Tensor]
    ) -> None:
        """
        Set shape and dtype for new batched tensor based on individual tensors.
        batched_tensor -- batched tensor, whose shape and dtype will be set
        individual_tensors -- list of input/output tensors of original nodes,
                              shapes must be batch-able, dtypes must match
        """
        shape = self.get_batched_shape([t.get_shape() for t in individual_tensors])
        dtype = individual_tensors[0].get_dtype()
        batched_tensor.set_shape(shape, dtype)

    def get_individual_shape_for_batching(self, individual_tensor: Tensor) -> list[int]:
        """
        Return a shape for `individual_tensor` that is suitable for batching:
        it has exactly `self.batch_dimension` dimensions and starts with 1.

        - Fewer dimensions than batch_dimension: dimensions of size 1 are
          prepended.
        - Otherwise: leading dimensions are collapsed (multiplied into the
          next one) until batch_dimension - 1 remain, then a 1 is prepended.
        The tensor itself is not modified.
        """
        # Work on a copy: collapsing below pops from the list and must never
        # alter the tensor's own shape (get_shape() may return its storage).
        shape = list(individual_tensor.get_shape())
        if len(shape) < self.batch_dimension:
            return [1] * (self.batch_dimension - len(shape)) + shape

        while len(shape) >= self.batch_dimension:
            dim = shape.pop(0)
            shape[0] *= dim

        assert len(shape) == self.batch_dimension - 1
        return [1] + shape

    def _add_shape_initializer(self, shape: list[int]) -> Initializer:
        """Add an int64 initializer named "shape_<d0>x<d1>x..." holding `shape`."""
        name = "shape_" + "x".join(list(map(str, shape)))
        return self.add_initializer(name, np.array(shape).astype(np.int64))

    @staticmethod
    def extract_matmul_batch_nb(a_shape: list[int], b_shape: list[int]) -> int:
        """
        Return the number of batches of a MatMul with input shapes `a_shape`
        and `b_shape`, where the batch count of a shape is the product of all
        but its last two dimensions.

        Returns 1 if B's batch count is 1, and the common count if A's and B's
        are equal. Raises NoMatch otherwise.
        """
        batched_dims_a = a_shape[:-2]
        batched_dims_b = b_shape[:-2]

        batch_dim_a = prod(batched_dims_a)
        batch_dim_b = prod(batched_dims_b)

        # If B has a batch dimension of 1, we can reshape the batch dimensions
        # of A so it also has a batch dimension of 1:
        # from [batched_dims x M x K] to [1 x (batched_dims x M) x K]
        # which is still a valid shape for the matmul since B has a shape of [1 x K x N]
        if batch_dim_b == 1:
            return 1

        if batch_dim_a == batch_dim_b:
            return batch_dim_a

        raise NoMatch(
            f"Unsupported batching, batch dimensions should be the same: "
            f"A={a_shape} -> {batch_dim_a=}; "
            f"B={b_shape} -> {batch_dim_b=};"
        )
