# fmt: on

from OGOAT.src.L1_fusion.py_match.checkers import (
    opType,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    MatcherOrCategory,
    NoMatch,
)
from OGOAT.src.L1_fusion.py_match.helpers.noop_helper import NoopHelper
import onnx


class AddNoopSuffix(Matcher, NoopHelper):

    def match(self) -> None:
        n = self.n
        if n.check(opType.Transpose) and self.is_noop_transpose(n):
            return
        if n.check(opType.Slice) and self.is_noop_slice(n):
            return
        elif n.check(opType.Unsqueeze | opType.Squeeze) and self.is_noop_unsqueeze(n):
            return
        elif n.check(
            opType.ReduceSum
            | opType.ReduceMax
            | opType.ReduceMean
            | opType.ReduceMin
            | opType.ReduceProd
        ) and self.is_noop_reduction(n("data").get_shape(), n("reduced").get_shape()):
            return
        elif n.check(opType.Reshape) and self.is_noop_reshape(n):
            return
        elif n.check(opType.Flatten):
            return
        else:
            raise NoMatch("Unsupported op type for adding the suffix `_noop`")

    def modify(self) -> None:
        n = self.n
        if n.check(
            opType.ReduceSum
            | opType.ReduceMax
            | opType.ReduceMean
            | opType.ReduceMin
            | opType.ReduceProd
        ):
            new_type = "Reduce_noop"
        else:
            new_type = n.get_op_type() + "_noop"
        attributes = n.get_attributes()
        attributes["num_of_tensor_inputs"] = 1
        inputs = n.get_inputs_dict()
        outputs = n.get_outputs_dict()
        name = n.get_name()
        self.remove_node(n)
        newnode = self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
            new_name=name,
        )
