# fmt: on
from abc import ABC, abstractmethod
from typing import TypeVar


from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.helpers.noop_helper import NoopHelper
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import TransposeHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    InputTensor,
    MatcherError,
    NoMatch,
    OutputTensor,
    WalkCfgBase,
    WalkCfgPlain,
)


# Constrained type variable for skip signatures: a skip maps an InputTensor to
# an InputTensor (walking upwards) or an OutputTensor to an OutputTensor
# (walking downwards), never one kind to the other.
T = TypeVar("T", InputTensor, OutputTensor)


class SkipBase(ABC):
    """
    Interface for an object that steps over one recognised pattern of nodes
    while the matcher walks the graph, either downwards or upwards.
    """

    @abstractmethod
    def skip(self, orig: T) -> T:
        """
        Step over the pattern that starts at ``orig``.

        Returns the tensor reached after the last skipped node. Raises
        MatcherError when nothing can be skipped at ``orig`` (the concrete
        skips in this module raise NoMatch, presumably a MatcherError
        subclass -- confirm).
        """
        ...


class SkipDownQDQ(SkipBase, QDQHelper):
    """
    Walking downwards, step over DequantizeLinear -> <op> -> QuantizeLinear,
    where <op> is stepped over by a wrapped skip object.
    """

    def __init__(self, skip_op: SkipBase) -> None:
        """
        skip_op: skip object applied to the operation sitting between the
        DequantizeLinear and the QuantizeLinear
        """
        SkipBase.__init__(self)
        QDQHelper.__init__(self)
        self._skip_op = skip_op

    def skip(self, orig: OutputTensor) -> OutputTensor:
        # The first node met must be the DequantizeLinear.
        orig.require(opType.DequantizeLinear)
        dq_out = orig("y")
        # Let the wrapped skip object step over the inner operation ...
        inner_out = self._skip_op.skip(dq_out)
        # ... which must be followed by a QuantizeLinear.
        inner_out.require(opType.QuantizeLinear)
        # Require the DQ and the Q to use equal scale and zero point.
        self.require_qdq_equal_scale_zeropoint(orig, inner_out)
        return inner_out("y")


class SkipDownReduceOp(SkipBase, NoopHelper):
    """
    Walking downwards, step over a Reduce{Sum,Max,Mean,Min,Prod} node that
    NoopHelper.is_noop_reduction classifies as a no-op based on its input
    and output shapes.
    """

    def skip(self, orig: OutputTensor) -> OutputTensor:
        reduce_ops = (
            opType.ReduceSum
            | opType.ReduceMax
            | opType.ReduceMean
            | opType.ReduceMin
            | opType.ReduceProd
        )
        orig.require(reduce_ops)
        in_shape = orig("data").get_shape()
        out_shape = orig("reduced").get_shape()
        if self.is_noop_reduction(in_shape, out_shape):
            return orig("reduced")
        raise NoMatch("not ReduceOp noop")


class SkipDownReshape(SkipBase):
    """
    Walking downwards, step over a Reshape whose target shape is an
    initializer.
    """

    def skip(self, orig: OutputTensor) -> OutputTensor:
        orig.require(opType.Reshape)
        target_shape = orig("shape")
        # Only a constant shape is required; its concrete value is not checked.
        target_shape.require_initializer()
        return orig("reshaped")


class SkipDownTranspose(SkipBase, NoopHelper):
    """
    Walking downwards, step over a Transpose that
    NoopHelper.is_noop_transpose classifies as a no-op.
    """

    def skip(self, orig: OutputTensor) -> OutputTensor:
        orig.require(opType.Transpose)
        # presumably the Transpose node itself -- confirm get_non_tensor()
        transpose = orig.get_non_tensor()
        if self.is_noop_transpose(transpose):
            return orig("transposed")
        raise NoMatch("not one dimensional tensor")


class SkipDownSlice(SkipBase, NoopHelper):
    """
    Walking downwards, step over a Slice that NoopHelper.is_noop_slice
    classifies as a no-op.
    """

    def skip(self, orig: OutputTensor) -> OutputTensor:
        matched = orig.require(opType.Slice)
        if self.is_noop_slice(matched):
            return orig("output")
        raise NoMatch("input shape does not match output shape")


class SkipDownFlatten(SkipBase):
    """
    Walking downwards, step over a Flatten node.

    No shape or attribute condition is checked: any Flatten is skipped.
    """

    def skip(self, orig: OutputTensor) -> OutputTensor:
        orig.require(opType.Flatten)
        flattened = orig("output")
        return flattened


class SkipDownSqUsq(SkipBase):
    """
    Walking downwards, step over a Squeeze or Unsqueeze whose axes input is
    an initializer.
    """

    def skip(self, orig: OutputTensor) -> OutputTensor:
        # Output name differs per op: Squeeze -> "squeezed",
        # Unsqueeze -> "expanded". Checked in this order.
        for op, out_name in (
            (opType.Squeeze, "squeezed"),
            (opType.Unsqueeze, "expanded"),
        ):
            if orig.check(op):
                result = orig(out_name)
                break
        else:
            raise NoMatch("not a squeeze or unsqueeze")
        # Only constant axes are required; their concrete value is not checked.
        orig("axes").require_initializer()
        return result


class SkipUpQDQ(SkipBase, QDQHelper):
    """
    Walking upwards, step over DequantizeLinear -> <op> -> QuantizeLinear,
    where <op> is stepped over by a wrapped skip object.
    """

    def __init__(self, skip_op: SkipBase) -> None:
        """
        skip_op: skip object applied to the operation sitting between the
        QuantizeLinear and the DequantizeLinear
        """
        SkipBase.__init__(self)
        QDQHelper.__init__(self)
        self._skip_op = skip_op

    def skip(self, orig: InputTensor) -> InputTensor:
        # Walking upwards, the first node met must be the QuantizeLinear.
        orig.require(opType.QuantizeLinear)
        q_in = orig("x")
        # Let the wrapped skip object step over the inner operation ...
        inner_in = self._skip_op.skip(q_in)
        # ... which must be preceded by a DequantizeLinear.
        inner_in.require(opType.DequantizeLinear)
        # Require the DQ and the Q to use equal scale and zero point.
        self.require_qdq_equal_scale_zeropoint(inner_in, orig)
        return inner_in("x")


class SkipUpQDQChain(SkipBase, QDQHelper):
    """
    Walking upwards, step over a QuantizeLinear -> DequantizeLinear pair that
    passes QDQHelper.require_qdq_equal_scale_zeropoint.
    """

    def skip(self, orig: InputTensor) -> InputTensor:
        dequant = orig.require(opType.DequantizeLinear)
        dq_in = orig("x")
        quant = dq_in.require(opType.QuantizeLinear)
        self.require_qdq_equal_scale_zeropoint(dequant, quant)
        return quant("x")


class SkipUpSlice(SkipBase, NoopHelper):
    """
    Walking upwards, step over a Slice that NoopHelper.is_noop_slice
    classifies as a no-op.
    """

    def skip(self, orig: InputTensor) -> InputTensor:
        matched = orig.require(opType.Slice)
        if self.is_noop_slice(matched):
            return orig("data")
        raise NoMatch("input shape does not match output shape")


class SkipUpReduceOp(SkipBase, NoopHelper):
    """
    Walking upwards, step over a Reduce{Sum,Max,Mean,Min,Prod} node that
    NoopHelper.is_noop_reduction classifies as a no-op based on its input
    and output shapes.
    """

    def skip(self, orig: InputTensor) -> InputTensor:
        reduce_ops = (
            opType.ReduceSum
            | opType.ReduceMax
            | opType.ReduceMean
            | opType.ReduceMin
            | opType.ReduceProd
        )
        orig.require(reduce_ops)
        in_shape = orig("data").get_shape()
        out_shape = orig("reduced").get_shape()
        if self.is_noop_reduction(in_shape, out_shape):
            return orig("data")
        raise NoMatch("not ReduceOp noop")


class SkipUpReshape(SkipBase):
    """
    Walking upwards, step over a Reshape whose target shape is an
    initializer.
    """

    def skip(self, orig: InputTensor) -> InputTensor:
        orig.require(opType.Reshape)
        target_shape = orig("shape")
        # Only a constant shape is required; its concrete value is not checked.
        target_shape.require_initializer()
        return orig("data")


class SkipUpTranspose(SkipBase, NoopHelper):
    """
    Walking upwards, step over a Transpose that NoopHelper.is_noop_transpose
    classifies as a no-op.
    """

    def skip(self, orig: InputTensor) -> InputTensor:
        orig.require(opType.Transpose)
        # presumably the Transpose node itself -- confirm get_non_tensor()
        transpose = orig.get_non_tensor()
        if self.is_noop_transpose(transpose):
            return orig("data")
        raise NoMatch("not one dimensional tensor")


class SkipUpFlatten(SkipBase):
    """
    Walking upwards, step over a Flatten node.

    No shape or attribute condition is checked: any Flatten is skipped.
    """

    def skip(self, orig: InputTensor) -> InputTensor:
        orig.require(opType.Flatten)
        unflattened = orig("input")
        return unflattened


class SkipUpSqUsq(SkipBase):
    """
    Walking upwards, step over a Squeeze or Unsqueeze whose axes input is
    an initializer.
    """

    def skip(self, orig: InputTensor) -> InputTensor:
        squeeze_like = opType.Squeeze | opType.Unsqueeze
        orig.require(squeeze_like)
        axes = orig("axes")
        # Only constant axes are required; their concrete value is not checked.
        axes.require_initializer()
        return orig("data")


class WalkCfgSkipNoop(WalkCfgBase):
    """
    Walk a graph while automatically skipping operations that are treated as
    no-ops for pattern matching.

    skip_down / skip_up step over, repeatedly and in any combination:
      * Reshape with an initializer shape, and Flatten (no shape check),
      * Squeeze / Unsqueeze with initializer axes,
      * Transpose, Reduce{Sum,Max,Mean,Min,Prod} and Slice when the
        corresponding NoopHelper check accepts them,
      * Transpose, Reshape, Squeeze/Unsqueeze and Slice wrapped between a
        DequantizeLinear and a QuantizeLinear that pass
        QDQHelper.require_qdq_equal_scale_zeropoint.

    search_upward_for_shape only steps over QuantizeLinear -> DequantizeLinear
    chains (see SkipUpQDQChain).
    """

    def __init__(self) -> None:
        # Non-skipping configuration used while the skip objects walk the
        # graph, so auto-skipping does not recurse into them.
        self._walk_plain = WalkCfgPlain()
        self._skip_down_list = [
            SkipDownReshape(),
            SkipDownTranspose(),
            SkipDownQDQ(SkipDownTranspose()),
            SkipDownFlatten(),
            SkipDownQDQ(SkipDownReshape()),
            SkipDownSqUsq(),
            SkipDownQDQ(SkipDownSqUsq()),
            SkipDownReduceOp(),
            SkipDownSlice(),
            SkipDownQDQ(SkipDownSlice()),
        ]

        self._skip_up_list = [
            SkipUpReshape(),
            SkipUpTranspose(),
            SkipUpQDQ(SkipUpTranspose()),
            SkipUpFlatten(),
            SkipUpQDQ(SkipUpReshape()),
            SkipUpSqUsq(),
            SkipUpQDQ(SkipUpSqUsq()),
            SkipUpReduceOp(),
            SkipUpSlice(),
            SkipUpQDQ(SkipUpSlice()),
        ]
        self._skip_up_for_shape_list = [SkipUpQDQChain()]

    def skip_as_much_as_possible(self, tensor: T, skip_list: list[SkipBase]) -> T:
        """
        Apply the skip objects in ``skip_list`` starting at ``tensor`` until a
        complete pass over the list skips nothing.

        A skip object that raises MatcherError simply does not apply at the
        current tensor. Returns the tensor reached, re-attached to this
        auto-skipping walk configuration.
        """
        # Turn off auto-skipping (don't recursively auto-skip in the skip objects).
        current = tensor.with_walk_cfg(self._walk_plain)
        skipped = True
        while skipped:
            skipped = False
            for skip in skip_list:
                try:
                    current = skip.skip(current)
                    skipped = True
                except MatcherError:
                    # Pattern not present at `current`; try the next skip object.
                    pass
        # Turn on auto-skipping again for the caller.
        return current.with_walk_cfg(self)

    def skip_down(self, out_tensor: OutputTensor) -> OutputTensor:
        """Skip as many no-ops as possible walking downwards from out_tensor."""
        return self.skip_as_much_as_possible(out_tensor, self._skip_down_list)

    def skip_up(self, in_tensor: InputTensor) -> InputTensor:
        """Skip as many no-ops as possible walking upwards from in_tensor."""
        return self.skip_as_much_as_possible(in_tensor, self._skip_up_list)

    def search_upward_for_shape(self, in_tensor: InputTensor) -> InputTensor:
        """Walk upwards from in_tensor skipping only Q -> DQ chains."""
        return self.skip_as_much_as_possible(in_tensor, self._skip_up_for_shape_list)
