# fmt: on

from OGOAT.src.L1_fusion.py_match.basic.non_linear import Relu
from OGOAT.src.L1_fusion.py_match.checkers import CategoryCheck
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Category, Matcher


class ReluPWLA(Matcher, QDQHelper):
    """
    Rewrite the remaining Relu_qdq nodes as PWLA_qdq nodes.

    A matched node is replaced by a new node in the "ai.onnx.contrib" domain
    whose op type has "Relu" substituted by "PWLA". The inputs and outputs of
    the matched node are passed on to the new node, and the original op type
    is recorded in the new node's attributes.
    """

    dependencies = [Category([Relu()])]

    def match(self) -> None:
        """Require the candidate to pass the Relu category check."""
        self.n.require(CategoryCheck(Relu()))

    def modify(self) -> None:
        """Add the PWLA replacement node, then remove the matched node."""
        relu_node = self.n.require_node()
        pwla_op_type = relu_node.get_op_type().replace("Relu", "PWLA")

        # Extend the matched node's attributes with the PWLA bookkeeping
        # entries; the original op type is stored under "orig_type".
        node_attrs = relu_node.get_attributes()
        pwla_metadata = (
            ("orig_type", relu_node.get_op_type()),
            ("pwla_type", "Relu"),
            ("num_of_tensor_inputs", 1),
        )
        for key, value in pwla_metadata:
            node_attrs[key] = value

        self.add_node(
            type=pwla_op_type,
            domain="ai.onnx.contrib",
            inputs=relu_node.get_inputs_dict(),
            outputs=relu_node.get_outputs_dict(),
            attributes=node_attrs,
        )
        self.remove_node(relu_node)
