# fmt: on
from typing import Optional
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    InputTensor,
    NoMatch,
    Node,
    Nowhere,
    OutputTensor,
    WalkCfgPlain,
)
from OGOAT.src.L1_fusion.py_match.checkers import opType


class TransposeWithOptionalQDQ:
    """
    This helper class encapsulates Transposes and Transposes with surrounding QDQ
    (DequantizeLinear -> Transpose -> QuantizeLinear).

    Either both ``dequantize_node`` and ``quantize_node`` are set (QDQ form) or
    neither is (plain Transpose); ``__init__`` rejects anything in between.

    Instances can be called with a tensor path, similar to a node:
      * "x" / "data"        -> input of the whole pattern (see ``input``)
      * "y" / "transposed"  -> output of the whole pattern (see ``output``)
      * "x_scale", "x_zero_point", "y_scale", "y_zero_point"
                            -> quantization parameters (see the getters below)
    Any remaining dot-separated segments after "x"/"data"/"y"/"transposed" are
    forwarded to the resolved tensor.
    """

    @staticmethod
    def match(element: Element) -> "TransposeWithOptionalQDQ":
        """
        Build a TransposeWithOptionalQDQ anchored at ``element``.

        ``element`` may resolve to the Transpose itself, to the QuantizeLinear
        consuming the Transpose output, or to the DequantizeLinear feeding the
        Transpose. Neighbouring nodes are looked up with ``WalkCfgPlain``; the
        nodes stored in the result carry ``element``'s original walk config.

        Raises:
            NoMatch: if the element is none of the three op types. Presumably
                also (via ``require``/``require_node``) when the expected
                DequantizeLinear -> Transpose -> QuantizeLinear chain around a
                Q or DQ anchor is incomplete -- TODO confirm.
        """
        walker_cfg = element._walk_cfg
        # Inspect the direct neighbours with a plain walk config instead of the
        # caller's configuration; it is restored on the result below.
        node = element.with_walk_cfg(WalkCfgPlain()).get_non_tensor()
        quantize_node: Optional[Node] = None
        dequantize_node: Optional[Node] = None
        has_qdq = False

        if node.check(opType.Transpose):
            transpose_node = node
        elif node.check(opType.QuantizeLinear):
            # Anchored at Q: walk backwards Q("x") -> Transpose("data") -> DQ.
            quantize_node = node
            transpose_node = quantize_node("x").require_node()
            transpose_node.require(opType.Transpose)
            dequantize_node = transpose_node("data").require_node()
            dequantize_node.require(opType.DequantizeLinear)
            has_qdq = True
        elif node.check(opType.DequantizeLinear):
            # Anchored at DQ: walk forwards DQ("y") -> Transpose("transposed") -> Q.
            dequantize_node = node
            transpose_node = dequantize_node("y").require_node()
            transpose_node.require(opType.Transpose)
            quantize_node = transpose_node("transposed").require_node()
            quantize_node.require(opType.QuantizeLinear)
            has_qdq = True
        else:
            raise NoMatch("not a Transpose with optional QDQ")

        # set to old walker config
        transpose_node = transpose_node.with_walk_cfg(walker_cfg)
        if has_qdq:
            quantize_node = quantize_node.with_walk_cfg(walker_cfg)
            dequantize_node = dequantize_node.with_walk_cfg(walker_cfg)

        return TransposeWithOptionalQDQ(dequantize_node, transpose_node, quantize_node)

    def __init__(
        self,
        dequantize_node: Optional[Node],
        transpose_node: Node,
        quantize_node: Optional[Node],
    ):
        """
        Store the pattern's nodes.

        Raises:
            NoMatch: if exactly one of ``quantize_node`` / ``dequantize_node``
                is given; the QDQ wrapper must be complete or absent.
        """
        self.dequantize_node = dequantize_node
        self.transpose_node = transpose_node
        self.quantize_node = quantize_node
        if (self.quantize_node is None) != (self.dequantize_node is None):
            raise NoMatch("needed both q/dq around transpose or none")

    def get_first_node(self) -> Node:
        """Return the first node of the pattern: the DequantizeLinear in QDQ form, else the Transpose."""
        if self.has_qdq():
            return self.dequantize_node
        else:
            return self.transpose_node

    def input(self) -> InputTensor:
        """Return the pattern's input tensor: DequantizeLinear "x" in QDQ form, else Transpose "data"."""
        if self.has_qdq():
            return self.dequantize_node("x")
        else:
            return self.transpose_node("data")

    def output(self) -> OutputTensor:
        """Return the pattern's output tensor: QuantizeLinear "y" in QDQ form, else Transpose "transposed"."""
        if self.has_qdq():
            return self.quantize_node("y")
        else:
            return self.transpose_node("transposed")

    def has_qdq(self) -> bool:
        """Return True if the Transpose is wrapped in DequantizeLinear/QuantizeLinear."""
        # __init__ guarantees both are set or both are None.
        return not (self.quantize_node is None and self.dequantize_node is None)

    def input_scale(self) -> InputTensor:
        """
        Return the scale used to dequantize the pattern's input.

        In QDQ form this is the "x_scale" input of the own DequantizeLinear.
        If the object only represents a Transpose without surrounding QDQ, the
        value is taken from the DequantizeLinear after the Transpose.
        If there is no DequantizeLinear, then the function throws NoMatch.
        """
        if self.dequantize_node is None:
            dq = self.output().require(opType.DequantizeLinear).require_node()
            return dq("x_scale")
        else:
            return self.dequantize_node("x_scale")

    def input_zero_point(self) -> InputTensor:
        """
        Return the zero point used to dequantize the pattern's input.

        In QDQ form this is the "x_zero_point" input of the own DequantizeLinear.
        If the object only represents a Transpose without surrounding QDQ, the
        value is taken from the DequantizeLinear after the Transpose.
        If there is no DequantizeLinear, then the function throws NoMatch.
        """
        if self.dequantize_node is None:
            dq = self.output().require(opType.DequantizeLinear).require_node()
            return dq("x_zero_point")
        else:
            return self.dequantize_node("x_zero_point")

    def output_scale(self) -> InputTensor:
        """
        Return the scale used to quantize the pattern's output.

        The result is the "y_scale" *input* of a QuantizeLinear node (hence
        InputTensor). In QDQ form the own QuantizeLinear is used.
        If the object only represents a Transpose without surrounding QDQ, the
        value is taken from the QuantizeLinear before the Transpose.
        If there is no QuantizeLinear, then the function throws NoMatch.
        """
        if self.quantize_node is None:
            q = self.input().require(opType.QuantizeLinear).require_node()
            return q("y_scale")
        else:
            return self.quantize_node("y_scale")

    def output_zero_point(self) -> InputTensor:
        """
        Return the zero point used to quantize the pattern's output.

        The result is the "y_zero_point" *input* of a QuantizeLinear node
        (hence InputTensor). In QDQ form the own QuantizeLinear is used.
        If the object only represents a Transpose without surrounding QDQ, the
        value is taken from the QuantizeLinear before the Transpose.
        If there is no QuantizeLinear, then the function throws NoMatch.
        """
        if self.quantize_node is None:
            q = self.input().require(opType.QuantizeLinear).require_node()
            return q("y_zero_point")
        else:
            return self.quantize_node("y_zero_point")

    def __call__(self, path: str) -> Element:
        """
        Resolve a tensor path relative to the pattern.

        The quantization-parameter names are matched on the whole ``path``.
        Otherwise the first dot-separated segment selects the input
        ("x"/"data") or output ("y"/"transposed") tensor, and any remaining
        segments are forwarded to that tensor. Unknown heads yield a
        ``Nowhere`` element named after the Transpose node.
        """
        if path == "x_scale":
            return self.input_scale()
        if path == "x_zero_point":
            return self.input_zero_point()
        if path == "y_scale":
            return self.output_scale()
        if path == "y_zero_point":
            return self.output_zero_point()

        head, sep, rest = path.partition(".")
        # Compare the head segment for every alias, so "transposed.<...>" and
        # "data.<...>" resolve just like "y.<...>" and "x.<...>".
        if head in ("y", "transposed"):
            element = self.output()
        elif head in ("x", "data"):
            element = self.input()
        else:
            return Nowhere(
                self.transpose_node._model_dict,
                self.transpose_node._walk_cfg,
                f"{self.transpose_node.get_name()}.{path}",
            )

        if sep:
            return element(rest)
        else:
            return element
