# fmt: on
from OGOAT.src.L1_fusion.py_match.checkers import CategoryCheck, FusedWithQDQNode
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher, NoMatch
from OGOAT.src.L1_fusion.py_match.basic.binary_op import binary_op
from OGOAT.src.L1_fusion.py_match.basic.lut import lut


class Silu(Matcher, QDQHelper):
    """
    Fuse a QDQ Sigmoid LUT node and the QDQ Mul consuming it into one Silu node.

    Matched pattern (both nodes must already be fused with their Q/DQ nodes)::

        X --> lut(pwla_type="Sigmoid") --> Y --+
        |                                      +--> Mul_qdq --> C
        +--------------------------------------+

    i.e. ``C = X * Sigmoid(X)``. Either Mul input (A or B) may carry the LUT
    output; the other Mul input must come from the same place as the LUT input
    X. Scale and zero point must be equal on both ends of each shared edge.
    """

    dependencies = [lut, binary_op]

    def match(self) -> None:
        """
        Check that ``self.n`` is the Sigmoid LUT of a Silu pattern.

        Raises:
            NoMatch: if the LUT/binary categories or op types are not
                Sigmoid + Mul_qdq, if the Mul does not consume the LUT output,
                if the Mul's other input does not share the LUT's input, or if
                the quantization parameters on the shared edges differ.
        """
        n = self.n
        lut_basic = n.require(CategoryCheck(lut))
        lut_basic.require(FusedWithQDQNode())
        n("Y").require(CategoryCheck(binary_op))
        # NOTE(review): n("Y") appears to resolve to the node consuming the LUT
        # output (it is category-checked as a binary op) — confirm in nodes_tensors.
        binary = n("Y").require(FusedWithQDQNode()).require_node()

        # Cheap op-type filter first, so non-Silu candidates are rejected
        # before any tensor-value comparisons are performed.
        # FIXME once there are Gelu/leakyRelu/Sigmoid followed by any Binary, the following condition needs to be adapted
        lut_basic_type = lut_basic.get_attributes().get("pwla_type")
        binary_op_type = binary.get_op_type()
        if lut_basic_type != "Sigmoid" or not binary_op_type.startswith("Mul_qdq"):
            raise NoMatch(
                "Current LutOp only support Sigmoid and Mul, other combination are not supported in yml"
            )

        # Find which Mul input carries the Sigmoid output. The other input must
        # then be fed from the same place as the Sigmoid itself.
        if binary("A").get_non_tensor() == lut_basic:
            lut_side, x_side = "A", "B"
        elif binary("B").get_non_tensor() == lut_basic:
            lut_side, x_side = "B", "A"
        else:
            # Without this, neither branch's checks would run and the pattern
            # would match while ignoring the Mul's other operand.
            raise NoMatch("binary node does not consume the lut node output")

        # The LUT output feeds the Mul directly, so its quantization must agree.
        self.require_tensor_equal_value(
            lut_basic("Y_scale"), binary(f"{lut_side}_scale")
        )
        self.require_tensor_equal_value(
            lut_basic("Y_zero_point"), binary(f"{lut_side}_zero_point")
        )

        # NOTE(review): this compares the producers of the two tensors, not the
        # tensors themselves. If get_non_tensor() returns the same placeholder
        # (e.g. None) for every graph input, two different graph inputs would
        # pass — confirm.
        if binary(x_side).get_non_tensor() != lut_basic("X").get_non_tensor():
            raise NoMatch("binary node and lut node have to share an input")
        self.require_tensor_equal_value(
            lut_basic("X_scale"), binary(f"{x_side}_scale")
        )
        self.require_tensor_equal_value(
            lut_basic("X_zero_point"), binary(f"{x_side}_zero_point")
        )

    def modify(self) -> None:
        """
        Replace the matched Sigmoid LUT node with one fused Silu PWLA node.

        The new node reads the LUT input X with its scale and zero point. It
        writes the Mul output ``Y.C`` using the Mul's output scale and zero
        point. Its op type is ``PWLA_qdq_<X zp dtype>x<C zp dtype>`` in the
        ``ai.onnx.contrib`` domain, and it carries the LUT's attributes plus
        ``orig_type`` and ``pwla_type="Silu"``.
        """
        n = self.n
        inputs = {
            "X": n("X"),
            "X_scale": n("X_scale"),
            "X_zero_point": n("X_zero_point"),
            "Y_scale": n("Y.C_scale"),
            "Y_zero_point": n("Y.C_zero_point"),
        }
        outputs = {"Y": n("Y.C")}
        new_type = (
            f"PWLA_qdq_{n('X_zero_point').get_dtype()}"
            f"x{n('Y.C_zero_point').get_dtype()}"
        )
        # Work on a copy so the matched node's own attribute mapping is never
        # mutated, whether or not get_attributes() already returns a fresh dict.
        copy_attributes = dict(n.get_attributes())
        copy_attributes["orig_type"] = n.get_op_type()
        copy_attributes["pwla_type"] = "Silu"
        # NOTE(review): only the LUT node is removed explicitly here. Presumably
        # the framework retires the Mul that previously produced Y.C — confirm.
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=copy_attributes,
        )


# Shared matcher instance. This follows the same convention as the imported
# `lut` / `binary_op` instances that are listed in `dependencies`.
silu = Silu()
