# fmt: on
from OGOAT.src.L1_fusion.py_match.checkers import (
    CategoryCheck,
    FusedWithQDQNode,
)
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Element, Matcher, NoMatch
from OGOAT.src.L1_fusion.py_match.basic.binary_op import binary_op
from OGOAT.src.L1_fusion.py_match.basic.lut import lut


class Swish(Matcher, QDQHelper):
    """
    Fuse a QDQ Sigmoid lut node followed by a QDQ Mul node into one Swish node.

    Matched pattern:
        Swish:  x*Sigmoid(alpha*x)
    where alpha is a constant stored in the lut node's ``Mul_const_value``
    attribute. The replacement is a ``PWLA_qdq_<in>x<out>`` node in the
    ``ai.onnx.contrib`` domain with ``pwla_type="Swish"`` and alpha carried in
    its ``Alpha`` attribute.
    """

    dependencies = [lut, binary_op]

    def _check_tensor_values(self, lut_basic: Element, binary: Element) -> None:
        """
        Verify that ``binary`` multiplies the lut output by the lut's own input.

        One binary input ("A" or "B") must be produced by ``lut_basic`` with
        identical output quantization; the other binary input must be the same
        tensor that feeds the lut. On success ``self.X_scale`` is set to the
        binary node's scale for that shared input.

        Raises:
            NoMatch: if the lut output does not feed ``binary``, the two nodes
                do not share their ``X`` input, or the quantization parameters
                cannot be represented by a single fused node.
        """
        # The lut output may be wired to either binary input; try both orders.
        for lut_in, x_in in (("A", "B"), ("B", "A")):
            if binary(lut_in).get_non_tensor() != lut_basic:
                continue
            # The lut output must be quantized exactly as the binary reads it.
            self.require_tensor_equal_value(
                lut_basic("Y_scale"), binary(f"{lut_in}_scale")
            )
            self.require_tensor_equal_value(
                lut_basic("Y_zero_point"), binary(f"{lut_in}_zero_point")
            )
            if binary(x_in).get_non_tensor() != lut_basic("X").get_non_tensor():
                raise NoMatch("binary node and lut node have to share an input")
            # The fused node has a single X_scale/X_zero_point pair: modify()
            # takes the binary's scale (a differing lut X_scale is tolerated)
            # together with the lut's zero point, so the zero points must agree
            # regardless of whether the scales match.
            self.require_tensor_equal_value(
                lut_basic("X_zero_point"), binary(f"{x_in}_zero_point")
            )
            self.X_scale = binary(f"{x_in}_scale")
            return
        # Without this, self.X_scale would be missing (or left over from an
        # earlier match if the matcher instance is reused) when modify() runs.
        raise NoMatch("lut output is not an input of the binary node")

    def match(self) -> None:
        """
        Match a Sigmoid lut node whose output ``Y`` feeds a Mul binary node.

        Both nodes must already be fused with their Q/DQ nodes. On success
        ``self.Alpha`` and ``self.X_scale`` are populated for ``modify``.

        Raises:
            NoMatch: if the op types, the ``Mul_const_value`` attribute, or the
                quantization/tensor checks do not fit the Swish pattern.
        """
        n = self.n.require_node()
        lut_basic = n.require(CategoryCheck(lut))
        lut_basic.require(FusedWithQDQNode())
        n("Y").require(CategoryCheck(binary_op))
        binary = n("Y").require(FusedWithQDQNode()).require_node()
        lut_basic_type = n.get_attributes().get("pwla_type")
        binary_op_type = n("Y").require_node().get_op_type()

        if lut_basic_type != "Sigmoid" or not binary_op_type.startswith("Mul_qdq"):
            raise NoMatch(
                "Swish only support Sigmoid and Mul, other combination are not supported."
            )

        lut_node = lut_basic.require_node()
        if "Mul_const_value" not in lut_node.get_attributes():
            raise NoMatch(
                "Swish needs to have a Mul_const_value attribute which is the Alpha value."
            )

        # alpha in x*Sigmoid(alpha*x)
        self.Alpha = lut_node.get_attribute_value("Mul_const_value")

        self._check_tensor_values(lut_basic, binary)

    def modify(self) -> None:
        """
        Remove the matched lut node and add the fused Swish PWLA node.

        The new node reads the lut input ``X`` (quantized with ``self.X_scale``
        from the binary node and the lut's ``X_zero_point``) and writes
        ``Y.C`` (presumably the Mul output) with that tensor's quantization
        parameters. The lut's attributes are carried over, extended with
        ``orig_type``, ``pwla_type="Swish"`` and ``Alpha``.
        """
        n = self.n
        inputs = {
            "X": n("X"),
            "X_scale": self.X_scale,
            "X_zero_point": n("X_zero_point"),
            "Y_scale": n("Y.C_scale"),
            "Y_zero_point": n("Y.C_zero_point"),
        }
        outputs = {"Y": n("Y.C")}
        # Op type encodes input and output quantized dtypes: PWLA_qdq_<in>x<out>.
        new_type = (
            "PWLA_qdq_"
            + n("X_zero_point").get_dtype()
            + "x"
            + n("Y.C_zero_point").get_dtype()
        )
        # Work on a copy so the matched node's attribute mapping is not mutated.
        copy_attributes = dict(n.get_attributes())
        copy_attributes["orig_type"] = n.get_op_type()
        copy_attributes["pwla_type"] = "Swish"
        copy_attributes["Alpha"] = self.Alpha
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=copy_attributes,
        )
