import math
from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.helpers.batch_helper import (
    BatchHelper,
    NodeBatchSignature,
)
from OGOAT.src.L1_fusion.py_match.helpers.bias_helper import BinaryOpHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    NoMatch,
)
from OGOAT.src.L1_fusion.py_match.py_match_utils import get_value_from_dequantize_linear


class GeluBasic(Matcher, BatchHelper):
    """Fuse the exact (erf-based) GELU subgraph into a single ``Gelu`` node.

    The pattern is anchored at a ``Div`` node. Each downstream hop may be
    direct or go through a QDQ chain (``go_through_downward_qdq_chain``)::

        Div(X, sqrt(2)) -> Erf -> Add(., 1) -> Mul(., X) -> Mul(., 0.5)

    i.e. ``0.5 * X * (1 + erf(X / sqrt(2)))``. The scalar constants may be
    plain initializers or ``DequantizeLinear`` outputs.
    """

    def match(self) -> None:
        """Check whether ``self.n`` roots the erf-based GELU pattern.

        On success, stores the final ``Mul`` output tensor in ``self.output``
        for :meth:`modify`.

        Raises:
            NoMatch: if an op type, constant value or DQ input does not fit
                the pattern.
        """
        # TODO: when approximate is "tanh" the formula is
        #   0.5*X*(1+tanh(sqrt(2/pi)*(X+0.044715*X^3)))
        # Only the exact form is matched here:
        #   0.5*X*(1+erf(X/sqrt(2)))
        n = self.n
        div_node = n.require(opType.Div)
        # The Div numerator (X) must be dequantized; its signature is compared
        # against the inner Mul's DQ input further below.
        dq_div_act_input = div_node("A").require(opType.DequantizeLinear).require_node()
        div_dq_signature = NodeBatchSignature.from_node(
            dq_div_act_input,
            ignore_attributes=["orig_name"],
            batch_by_out_tensor=False,
        )
        self._require_scalar_const(div_node("B"), math.sqrt(2), abs_tol=1e-6)

        erf_node = (
            self.go_through_downward_qdq_chain(div_node("C"))
            .require(opType.Erf)
            .require_node()
        )

        # Resolve the Erf -> Add link once (direct or through a QDQ chain) and
        # reuse it both to find the Add and to locate its constant operand.
        erf_to_add_tensor = self.go_through_downward_qdq_chain(erf_node("output"))
        add_node = erf_to_add_tensor.require(opType.Add).require_node()
        other_input_name_add = BinaryOpHelper.get_other_input_name(
            add_node,
            erf_to_add_tensor,
        )
        self._require_scalar_const(add_node(other_input_name_add), 1, abs_tol=1e-2)

        # Same for the Add -> inner Mul link.
        add_to_mul_inner_tensor = self.go_through_downward_qdq_chain(add_node("C"))
        mul_inner_node = add_to_mul_inner_tensor.require(opType.Mul).require_node()
        other_input_name_inner_mul = BinaryOpHelper.get_other_input_name(
            mul_inner_node,
            add_to_mul_inner_tensor,
        )
        dq_input_inner_mul = (
            mul_inner_node(other_input_name_inner_mul)
            .require(opType.DequantizeLinear)
            .require_node()
        )
        mul_dq_signature = NodeBatchSignature.from_node(
            dq_input_inner_mul,
            ignore_attributes=["orig_name"],
            batch_by_out_tensor=False,
        )

        # Both X operands must come from DQ nodes with equal batch signatures
        # (orig_name ignored) -- presumably the same dequantized X.
        if div_dq_signature != mul_dq_signature:
            raise NoMatch(
                f"{self} does not match GeluBasic pattern as div and mul DQ inputs differ"
            )

        # Same for the inner Mul -> outer Mul link.
        mul_inner_to_outer_tensor = self.go_through_downward_qdq_chain(
            mul_inner_node("C")
        )
        mul_outer_node = mul_inner_to_outer_tensor.require(opType.Mul).require_node()
        other_input_name_outer_mul = BinaryOpHelper.get_other_input_name(
            mul_outer_node,
            mul_inner_to_outer_tensor,
        )
        self._require_scalar_const(
            mul_outer_node(other_input_name_outer_mul), 0.5, abs_tol=1e-2
        )
        self.output = mul_outer_node("C")

    def _require_scalar_const(self, tensor, expected: float, abs_tol: float) -> None:
        """Raise ``NoMatch`` unless ``tensor`` is a constant close to ``expected``.

        Args:
            tensor: Operand to inspect; either an initializer or the output of
                a ``DequantizeLinear`` node.
            expected: Value the constant must equal within ``abs_tol``.
            abs_tol: Absolute tolerance for ``math.isclose`` (``rel_tol`` is 0).

        Raises:
            NoMatch: if the value cannot be converted to a single float or is
                not close to ``expected``. The ``require`` calls presumably
                also raise ``NoMatch`` when the operand is neither form.
        """
        if tensor.check_initializer():
            value = tensor.require_initializer().get_value()
        else:
            dq_node = tensor.require(opType.DequantizeLinear).require_node()
            value = get_value_from_dequantize_linear(dq_node)
        try:
            # float() performs the same conversion math.isclose would.
            scalar = float(value)
        except (TypeError, ValueError) as err:
            # Non-scalar constants (e.g. multi-element arrays; assumes
            # get_value() yields numpy-like values) cannot form this pattern,
            # so report NoMatch instead of crashing the matcher.
            raise NoMatch(f"{self} does not match GeluBasic pattern") from err
        if not math.isclose(scalar, expected, rel_tol=0, abs_tol=abs_tol):
            raise NoMatch(f"{self} does not match GeluBasic pattern")

    def modify(self) -> None:
        """Remove the anchor ``Div`` node and add a fused ``Gelu`` node.

        The new node (domain ``ai.onnx.contrib``) reads the Div numerator
        ``n("A")`` as ``X`` and writes the tensor recorded by :meth:`match`
        as ``Y``. The other matched nodes are not removed here; presumably
        the framework cleans up the now-dangling ops (not visible in this
        class).
        """
        n = self.n
        inputs = {"X": n("A")}
        outputs = {"Y": self.output}
        # FIXME: the tanh approximation is not implemented; always exact erf.
        attributes = {
            # Attribute name per the ONNX Gelu spec (no trailing space).
            "approximate": "none",
            "num_of_tensor_inputs": 1,
        }
        self.remove_node(n)
        self.add_node(
            type="Gelu",
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )
