# fmt: on
from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.clean.remove import RemoveHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import NoMatch


class RemoveSlice(RemoveHelper):
    """
    Remove the Slice op along with qdq around it if:
        1. Every entry of the Slice steps input is 1, or steps is not present
           (ONNX default is 1)
        2. Input shape of Slice is same as its output shape
        3. Scales and zero points of dq and q match

    When all steps are 1 and the shape is unchanged, every axis is taken
    whole, so the Slice is an identity and can be dropped.
    """

    def set_input_and_output(self) -> None:
        """
        Record the Slice's data input and output tensors, and call
        ``flag_used()`` on each of the starts/ends/axes/steps inputs that is
        backed by an initializer.
        """
        self.input = self.n.get_inputs()[0]
        self.output = self.n("output")

        # Same per-parameter handling as before, in the same order.
        for param_name in ("starts", "ends", "axes", "steps"):
            param = self.n(param_name)
            if param.check_initializer():
                param.require_initializer().flag_used()

    def match_op_specifics(self) -> None:
        """
        Check the Slice-specific conditions for removal.

        Raises:
            NoMatch: if the node is not a Slice, if any constant step is not 1,
                or if the Slice input and output shapes differ.
        """
        self.n.require(opType.Slice)
        steps = self.n("steps")
        # Every step must be 1, not just the first one. With several sliced
        # axes (e.g. steps=[1, -1]), a full-axis reversal leaves the shape
        # unchanged but is not an identity, so the shape check alone would
        # not catch it. Steps is a 1-D tensor per the ONNX Slice spec.
        # NOTE(review): a steps input that is present but not an initializer
        # (runtime-computed) is not rejected here even though its value is
        # unknown. Confirm check_initializer() semantics for absent vs.
        # dynamic inputs.
        if steps.check_initializer() and any(
            step != 1 for step in steps.get_initializer_array()
        ):
            raise NoMatch("Slice steps is present and not 1")
        inputs = self.n.get_inputs()
        outputs = self.n.get_outputs()
        if inputs[0].get_shape() != outputs[0].get_shape():
            raise NoMatch("Slice input and output shape do not match")
