# fmt: on
from OGOAT.src.L1_fusion.py_match.basic.non_linear import Relu
from OGOAT.src.L1_fusion.py_match.checkers import (
    CategoryCheck,
    FusedWithQDQNode,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher, NoMatch
from OGOAT.src.L1_fusion.py_match.basic.binary_op import BinaryOp
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper


class EWBinaryPlusRelu(Matcher, QDQHelper):
    """
    Fuse an element-wise binary op and the Relu-category op on its "C"
    connection into one node.

    ``match`` checks the pattern and records the Relu-side element in
    ``self.output``; ``modify`` swaps the binary node for a new node whose
    type name encodes the original op type and the zero-point dtypes.
    """

    dependencies = [BinaryOp(), Relu()]

    def match(self) -> None:
        """
        Check that the current node is a BinaryOp followed by a QDQ-fused Relu.

        Sets ``self.output`` and ``self.type_name`` on success.

        Raises:
            NoMatch: if the op type of the output node does not start with
                "Relu". The ``require`` calls presumably also signal a
                failed match -- confirm against the Matcher framework.
        """
        binary = self.n
        binary.require(CategoryCheck(BinaryOp()))

        # "C" is presumably the binary op's output connection -- TODO confirm.
        self.output = binary("C").require(CategoryCheck(Relu()))
        self.output.require(FusedWithQDQNode())

        relu_op_type = self.output.get_non_tensor().get_op_type()
        if relu_op_type.startswith("Relu"):
            self.type_name = "relu"
            return
        raise NoMatch(f"unsupported op: {relu_op_type}")

    def modify(self) -> None:
        """
        Replace the matched binary node with a single fused node.

        The new type is
        ``<orig_type>_qdq_<type_name>_<A zp dtype>x<B zp dtype>x<Y zp dtype>``.
        A/B inputs and their scale/zero-point come from the binary node;
        the Y scale/zero-point and the Y output come from ``self.output``.
        """
        binary = self.n

        # Dtype suffix for the two operands, e.g. "<a>x<b>".
        a_zp_dtype = binary("A_zero_point").get_dtype()
        b_zp_dtype = binary("B_zero_point").get_dtype()
        operand_dtypes = a_zp_dtype + "x" + b_zp_dtype

        orig_type = binary.get_attribute_value("orig_type")
        type_prefix = orig_type + "_qdq_" + self.type_name + "_" + operand_dtypes + "x"
        fused_type = type_prefix + self.output("Y_zero_point").get_dtype()

        # Operand-side connections are taken from the binary node ...
        operand_names = (
            "A",
            "B",
            "A_scale",
            "A_zero_point",
            "B_scale",
            "B_zero_point",
        )
        fused_inputs = {name: binary(name) for name in operand_names}
        # ... while the output quantization parameters come from the Relu side.
        fused_inputs.update(
            {
                "Y_scale": self.output("Y_scale"),
                "Y_zero_point": self.output("Y_zero_point"),
            }
        )

        fused_outputs = {"Y": self.output("Y")}

        # Read the attributes before the node is removed from the graph.
        carried_attributes = binary.get_attributes()
        # NOTE(review): only the binary node is removed explicitly here; the
        # Relu node is presumably cleaned up by the framework -- confirm.
        self.remove_node(binary)
        self.add_node(
            type=fused_type,
            domain="ai.onnx.contrib",
            inputs=fused_inputs,
            outputs=fused_outputs,
            attributes=carried_attributes,
        )
