from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher

class Neg(Matcher):
    """Fuse a quantized ``Neg`` (DequantizeLinear / Neg / QuantizeLinear) into one node.

    The replacement node is placed in the ``ai.onnx.contrib`` domain. Its type
    name is ``Neg_qdq_<A>x<B>``, where ``<A>`` and ``<B>`` are the dtypes of the
    ``X.x_zero_point`` and ``Y.y_zero_point`` tensors respectively.
    """

    def match(self) -> None:
        """Constrain the pattern: a Neg node with Dequantize on X and Quantize on Y."""
        node = self.n

        node.require(opType.Neg)
        # NOTE(review): presumably node("X") is the producer of input X and
        # node("Y") the consumer of output Y, per Matcher path semantics.
        expected_neighbours = (
            ("X", opType.DequantizeLinear),
            ("Y", opType.QuantizeLinear),
        )
        for port, op in expected_neighbours:
            node(port).require(op)

    def modify(self) -> None:
        """Replace the matched Neg node with a single fused QDQ custom op."""
        node = self.n

        # The fused op's type name encodes the input/output quantized dtypes,
        # taken from the zero-point tensors on either side.
        in_dtype = node("X.x_zero_point").get_dtype()
        out_dtype = node("Y.y_zero_point").get_dtype()
        fused_type = "Neg_qdq_" + in_dtype + "x" + out_dtype

        # Fused-node input name -> path of the tensor that feeds it.
        # Insertion order is preserved in the resulting dict.
        input_paths = (
            ("X", "X.x"),
            ("X_scale", "X.x_scale"),
            ("X_zero_point", "X.x_zero_point"),
            ("Y_scale", "Y.y_scale"),
            ("Y_zero_point", "Y.y_zero_point"),
        )
        fused_inputs = {name: node(path) for name, path in input_paths}
        fused_outputs = {"Y": node("Y.y")}

        # All lookups on the matched node happen before it is removed.
        attrs = node.get_attributes()
        self.remove_node(node)
        # Only "X" is marked as a tensor input; presumably the remaining inputs
        # are treated as quantization parameters by the consumer — confirm.
        attrs["num_of_tensor_inputs"] = 1

        self.add_node(
            type=fused_type,
            domain="ai.onnx.contrib",
            inputs=fused_inputs,
            outputs=fused_outputs,
            attributes=attrs,
        )