# fmt: on
from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.clean.remove import RemoveHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import NoMatch


class RemoveConcat(RemoveHelper):
    """
    Remove the Concat op along with qdq around it if:
      1. The Concat has exactly one input (so it only forwards that input).
      2. Scales and zero points of dq and q match.

    Only condition 1 is checked in this class (see ``match_op_specifics``).
    NOTE(review): condition 2 is presumably enforced by ``RemoveHelper`` —
    confirm.
    """

    def set_input_and_output(self) -> None:
        """Record the tensors on either side of the Concat being removed.

        ``self.input`` is set to the Concat's first (and, after matching,
        only) input. ``self.output`` is set to ``self.n("concat_result")``.
        NOTE(review): this assumes ``match_op_specifics`` has already run and
        verified there is exactly one input; otherwise ``[0]`` can raise
        ``IndexError`` on an input-less node — confirm the call order in
        ``RemoveHelper``.
        """
        self.input = self.n.get_inputs()[0]
        self.output = self.n("concat_result")

    def match_op_specifics(self) -> None:
        """Require the matched node to be a Concat with exactly one input.

        Raises:
            NoMatch: if the Concat does not have exactly one input. The
                ``self.n.require(opType.Concat)`` call is presumably what
                rejects non-Concat nodes (TODO confirm it raises ``NoMatch``).
        """
        self.n.require(opType.Concat)
        num_inputs = len(self.n.get_inputs())
        if num_inputs > 1:
            raise NoMatch("Concat has more than 1 input")
        if num_inputs == 0:
            # The previous single check reported "more than 1 input" here too,
            # which was misleading when debugging why a pattern did not match.
            raise NoMatch("Concat has no inputs")
