# fmt on
# (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    Matcher,
    NoMatch,
    Node,
    WalkCfgPlain,
)
from OGOAT.src.L1_fusion.py_match.checkers import opType


class MergeMultipleOutput(Matcher, QDQHelper):
    """
    Merge duplicate QuantizeLinear readers of a node's outputs into one.

    The pattern matches a node where every output tensor is read by more than
    one QuantizeLinear node, and all QuantizeLinear readers of an output share
    the same y_scale, y_zero_point and zero-point dtype. The rewrite redirects
    the consumer of each duplicate QuantizeLinear to the "y" output of the
    first QuantizeLinear reader, so a single quantization of each output is
    used downstream.
    """

    def check_multiple_readers_equality(self, readers: list[Element]) -> None:
        """
        Check that all readers are QuantizeLinear nodes with identical parameters.

        Every reader is compared against the first one: the dtype of
        y_zero_point and the values of the y_scale and y_zero_point
        initializers must all be equal.

        Args:
            readers: Reader elements of a single output tensor.

        Raises:
            NoMatch: if ``readers`` is empty, or if a reader's zero-point
                dtype, y_scale or y_zero_point differs from the first
                reader's. The ``require*`` calls are also expected to reject
                (presumably with NoMatch) readers that are not QuantizeLinear
                nodes or whose y_scale / y_zero_point are not initializers.
        """
        if not readers:
            raise NoMatch("No readers to compare")

        first_reader = readers[0].require_node().require(opType.QuantizeLinear)
        first_reader_y_scale = first_reader("y_scale").require_initializer()
        first_reader_y_zero_point = first_reader("y_zero_point").require_initializer()
        first_reader_dtype = first_reader_y_zero_point.get_dtype()

        # Check all other readers against the first one; merging readers that
        # quantize differently would change what their consumers receive.
        for reader in readers[1:]:
            q_node = reader.require_node().require(opType.QuantizeLinear)
            reader_y_scale = q_node("y_scale").require_initializer()
            reader_y_zero_point = q_node("y_zero_point").require_initializer()
            reader_dtype = reader_y_zero_point.get_dtype()

            if reader_dtype != first_reader_dtype:
                raise NoMatch("Reader dtypes do not match")

            if not self.check_tensor_equal_value(first_reader_y_scale, reader_y_scale):
                raise NoMatch("Reader y_scale initializers do not match")

            if not self.check_tensor_equal_value(
                first_reader_y_zero_point, reader_y_zero_point
            ):
                raise NoMatch("Reader y_zero_point initializers do not match")

    def _merge_multiple_qnode_to_single_one(self, n: Node) -> None:
        """
        Rewire consumers of duplicate QuantizeLinear readers onto the first one.

        For each output of ``n``, the first QuantizeLinear reader (Q1) is kept.
        For every other reader Qi, the input of Qi's consumer that refers to
        Qi's "y" tensor is replaced with Q1's "y" tensor via ``replace_input``.
        This is only valid once ``check_multiple_readers_equality`` has
        confirmed that all readers share scale, zero point and dtype.

        Before:
                    ┌─ > Q1( scale, zp) ─> Consumer1
                    │
         Node(out) ─┼─ > Q2( scale, zp) ─> Consumer2
                    │
                    └─ > Q3( scale, zp) ─> Consumer3

        After:

                                      ─> Consumer1
         Node(out) ── >Q1( scale, zp) ─> Consumer2
                                      ─> Consumer3

        Q2 and Q3 are not deleted by this method; only their consumer's input
        is rewired. NOTE(review): presumably a later cleanup removes the now
        unused QuantizeLinear nodes; confirm.

        Args:
            n: The node whose outputs' QuantizeLinear readers are merged.
        """
        for output in n.get_outputs():
            readers = output.require_tensor().get_readers()
            if len(readers) <= 1:
                # Nothing to merge (and no reference reader to index).
                continue

            # Output of the QuantizeLinear that is kept; all others merge into it.
            ref_tensor = readers[0]("y").require_tensor()

            # readers[0] is the reference itself, so rewiring it would be a
            # self-replacement; start from the first duplicate.
            for reader in readers[1:]:
                dup_tensor = reader("y").require_tensor()
                # NOTE(review): this resolves a single node from the duplicate's
                # "y" output (its consumer); if that tensor has several
                # consumers, only the one returned here is rewired - confirm.
                consumer = reader("y").require_node()

                # Find the consumer input that refers to the duplicate's output.
                dup_name = dup_tensor.get_name()
                in_tensor = next(
                    (
                        inp
                        for inp in consumer.get_inputs()
                        if inp.get_name() == dup_name
                    ),
                    None,
                )
                if in_tensor is None:
                    # Previously a stale tensor from an earlier iteration (or an
                    # unbound name) would have been passed to replace_input.
                    # Leave this duplicate untouched; the graph stays valid.
                    continue

                self.replace_input(consumer, in_tensor, ref_tensor)

    def match(self) -> None:
        """
        Require every output of the node to have multiple, equivalent Q readers.

        Raises:
            NoMatch: if any output has at most one reader, or if its readers
                fail ``check_multiple_readers_equality``.
        """
        n = self.n.require_node().with_walk_cfg(WalkCfgPlain())
        for output in n.get_outputs():
            readers = output.require_tensor().get_readers()
            if len(readers) <= 1:
                raise NoMatch("output has a single reader.")
            self.check_multiple_readers_equality(readers)

    def modify(self) -> None:
        """Merge the duplicate QuantizeLinear readers found by ``match``."""
        n = self.n.require_node().with_walk_cfg(WalkCfgPlain())
        self._merge_multiple_qnode_to_single_one(n)
