from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.helpers.bias_helper import BinaryOpHelper
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    NoMatch,
    Node,
)
from OGOAT.src.L1_fusion.py_match.py_match_utils import get_value_from_dequantize_linear


class LayerNormBasic(Matcher, QDQHelper):
    """Match a decomposed layer-normalization subgraph and fuse it into one node.

    Matching starts at a ReduceMean node (``self.n``) and follows this shape
    (Q/DQ pairs between some of the nodes are skipped through
    ``go_through_downward_qdq_chain``):

        DQ -> ReduceMean -> Sub --> Pow -> ReduceMean -> Add(epsilon) -> Sqrt
                             |                                            |
                             +------------------> Div <-------------------+
                                                   |
                                                   v
                                          [Mul (scale)] -> [Add (bias)]

    The trailing Mul and Add are optional.  ``match`` records
    ``epsilon_value``, ``input_scale``, ``input_bias`` and ``output`` on the
    instance; ``modify`` consumes them to emit a ``LayerNormalization`` node.
    """

    def match(self) -> None:
        """Check that ``self.n`` is the first ReduceMean of a layer-norm pattern.

        Sets ``self.epsilon_value`` and ``self.output`` on success, plus
        ``self.input_scale`` / ``self.input_bias`` (``None`` when the
        corresponding Mul / Add is absent).

        Raises:
            NoMatch: if the graph around ``self.n`` does not have the expected
                structure (raised here or by the ``require*`` helpers).
        """
        # Reset the optional inputs up front: the early-return paths below
        # never assign them, yet modify() reads both unconditionally.  Without
        # this, a pattern lacking scale/bias would hit an AttributeError (or
        # reuse stale values if the matcher instance is reused).
        self.input_scale = None
        self.input_bias = None

        n = self.n
        reducemean_node_first = n.require(opType.ReduceMean)
        dq_node = reducemean_node_first("data.x.x").require(opType.DequantizeLinear)
        sub_node_input = reducemean_node_first("reduced").require(opType.Sub)
        sub_node = sub_node_input.require_node()
        # The Sub input that is NOT fed by the first ReduceMean.
        other_sub_input_name = BinaryOpHelper.get_other_input_name(
            sub_node, sub_node_input
        )

        # DQ node can be duplicated if its output is consumed by more than one node.
        # If that's the case we need to match the following pattern, with DQ1 and DQ2 having
        # the same scale and zp and same input:
        #
        # Q ---> DQ1 -> ReduceMean ------
        #    |                           |
        #    |                           v
        #    --> DQ2  ----------------> Sub ---> [...]
        if dq_node != sub_node(other_sub_input_name):
            second_dq = sub_node(other_sub_input_name).require(opType.DequantizeLinear)
            if not self.check_dequantize_equal_scale_zp(
                dq_node, second_dq
            ) or second_dq("x") != dq_node("x"):
                raise NoMatch(
                    "dq nodes attached to reducemean and sub should be equivalent and have the same input"
                )

        # The Sub output must feed exactly two readers: one Pow and one Div.
        pow_node: Node = None
        div_node: Node = None
        sub_node_readers = sub_node("C").require_tensor().get_readers()
        if len(sub_node_readers) != 2:
            raise NoMatch(f"{self} does not match LayerNormBasic pattern")
        for reader in sub_node_readers:
            if reader.check(opType.Pow):
                pow_node = reader.require(opType.Pow).require_node()
            elif reader.check(opType.Div):
                div_node = reader.require(opType.Div).require_node()
        if pow_node is None or div_node is None:
            raise NoMatch(f"{self} does not match LayerNormBasic pattern")

        # The Sub result has to be the Div's "A" input.
        if div_node("A").require_node() != sub_node:
            raise NoMatch(f"{self} does not match LayerNormBasic pattern")

        reducemean_node_second = pow_node("Z").require(opType.ReduceMean)

        # The QDQ-chain helper is applied twice before the Add is required.
        # NOTE(review): presumably to step over two consecutive Q/DQ pairs --
        # confirm against go_through_downward_qdq_chain.
        add_node_input = self.go_through_downward_qdq_chain(
            reducemean_node_second("reduced")
        )
        add_node_input = self.go_through_downward_qdq_chain(add_node_input)
        add_node_input.require(opType.Add)
        # FIXME find correct epsilon value
        add_node = add_node_input.require_node()
        other_add_input_name = BinaryOpHelper.get_other_input_name(
            add_node, add_node_input
        )
        # The other Add operand must come from a DequantizeLinear; its value is
        # used as epsilon.
        epsilon_dq_node = (
            add_node(other_add_input_name)
            .require(opType.DequantizeLinear)
            .require_node()
        )
        self.epsilon_value = get_value_from_dequantize_linear(epsilon_dq_node)
        sqrt_node_input = self.go_through_downward_qdq_chain(add_node("C"))
        sqrt_node = sqrt_node_input.require(opType.Sqrt).require_node()
        # The Sqrt output must be consumed by the same Div found above.
        if sqrt_node("Y").require_node() != div_node:
            raise NoMatch(f"{self} does not match LayerNormBasic pattern")

        # Optional scale: a Mul after the Div.
        scale_node_input = self.go_through_downward_qdq_chain(div_node("C"))
        if not scale_node_input.check(opType.Mul):
            self.output = div_node("C")
            return  # No scale, early return
        scale_node_input = scale_node_input.require(opType.Mul)
        scale_node = scale_node_input.require_node()
        scale_input_name = BinaryOpHelper.get_other_input_name(
            scale_node, scale_node_input
        )

        self.input_scale = scale_node(scale_input_name)

        # Optional bias: an Add after the Mul.
        bias_node_input = self.go_through_downward_qdq_chain(scale_node_input("C"))
        if not bias_node_input.check(opType.Add):
            self.output = scale_node("C")
            return  # No bias, early return
        bias_node_input = bias_node_input.require(opType.Add)
        bias_node = bias_node_input.require_node()
        bias_input_name = BinaryOpHelper.get_other_input_name(
            bias_node, bias_node_input
        )

        self.input_bias = bias_node(bias_input_name)
        self.output = bias_node("C")

    def modify(self) -> None:
        """Replace the matched ReduceMean with a fused ``LayerNormalization`` node.

        Uses the values recorded by ``match``: ``input_scale`` / ``input_bias``
        become the "Scale" / "B" inputs (``None`` when absent or falsy),
        ``output`` becomes "Y", and ``epsilon_value`` the "epsilon" attribute.
        """
        n = self.n
        # FIXME find correct inputs
        inputs = {
            "X": n("data"),
            "Scale": self.input_scale if self.input_scale else None,
            "B": self.input_bias if self.input_bias else None,
        }
        # FIXME Outputs has optional item "Mean" and "InvStdDev" in onnx document, for now they are ignored
        outputs = {"Y": self.output}
        # Only the first entry of the ReduceMean "axes" attribute is used; fall
        # back to -1 when the attribute is missing or empty.
        # NOTE(review): newer ReduceMean opsets may carry axes as an input
        # rather than an attribute -- confirm this lookup covers those models.
        axes = n.get_attribute_value("axes")
        attributes = {
            "axis": axes[0] if axes else -1,
            "epsilon": self.epsilon_value,
            "stash_type": 1,
        }
        self.remove_node(n)
        attributes["num_of_tensor_inputs"] = 1
        self.add_node(
            type="LayerNormalization",
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )
