# fmt: on

from OGOAT.src.L1_fusion.py_match.checkers import DTypes, opType
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher


class Lut(Matcher, QDQHelper):
    """
    Unifies the classic element-wise activation patterns with QDQ into a
    single ``PWLA_qdq_<dtypes>`` node in the ``ai.onnx.contrib`` domain.

    Matched op types: Gelu, Sigmoid, LeakyRelu, Tanh, QuickGelu, Elu.
    The QDQ side is validated by ``self.check_qdq`` against the int8, int16,
    uint8 and uint16 dtypes.
    """

    def match(self) -> None:
        """
        Require a supported activation node and collect its QDQ metadata.

        Sets ``self.new_dtypes`` and ``self.qdq_attributes`` from
        ``self.check_qdq``. For Sigmoid, the first attribute whose key ends
        with ``"_const_value"`` on ``n("X").get_non_tensor()`` is also
        copied into ``self.qdq_attributes``.
        """
        n = self.n
        n.require(
            opType.Gelu
            | opType.Sigmoid
            | opType.LeakyRelu
            | opType.Tanh
            | opType.QuickGelu
            | opType.Elu
        )
        self.new_dtypes, self.qdq_attributes = self.check_qdq(
            n, DTypes("int8", "int16", "uint8", "uint16")
        )

        if n.check(opType.Sigmoid):
            # NOTE(review): presumably a constant attached to input "X" that
            # the PWLA kernel needs; only the first match is forwarded.
            const_attrs = n("X").get_non_tensor().get_attributes()
            for key, val in const_attrs.items():
                if key.endswith("_const_value"):
                    self.qdq_attributes[key] = val
                    break

    def modify(self) -> None:
        """
        Replace the matched node with a ``PWLA_qdq_<dtypes>`` node.

        The new node is added in the ``ai.onnx.contrib`` domain with the
        QDQ-adjusted inputs/outputs and carries the original attributes plus
        ``orig_type``, ``pwla_type`` (only when not already set) and
        ``num_of_tensor_inputs=1``, merged with ``self.qdq_attributes``.
        On a key collision the QDQ attributes win (right operand of ``|``).
        """
        n = self.n.require_node()
        new_type = "PWLA_qdq_" + self.new_dtypes
        inputs, outputs = self.get_in_out_dict_for_qdq_node(n)
        # Shallow copy so the matched node's own attribute mapping is not
        # mutated in place (get_attributes() may return a live reference).
        attributes = dict(n.get_attributes())
        # The op type has not been altered yet at this point, so it is
        # safe to record it as both the original and the PWLA type.
        op_type = n.get_op_type()
        attributes["orig_type"] = op_type
        # "pwla_type" should not be present yet, but an existing value is
        # respected, which may be helpful for future implementations.
        if attributes.get("pwla_type") is None:
            attributes["pwla_type"] = op_type
        attributes["num_of_tensor_inputs"] = 1
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes | self.qdq_attributes,
        )


# Module-level instance of the Lut matcher.
# NOTE(review): presumably discovered and run by the fusion driver via this
# module attribute — confirm against the loader.
lut = Lut()
