# fmt: on
from typing import Any
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Checker,
    Element,
    Initializer,
    MatcherOrCategory,
    NoMatch,
    Node,
    Tensor,
)


class CheckSyntaxSugar:
    """
    Syntactic sugar to allow e.g. 'selector < opType.Gemm'
    instead of 'selector < OpType("Gemm")'.

    Looking up a (non-dunder) attribute ``sugar.Name`` returns
    ``cls("Name")``, where ``cls`` is the checker class given to the
    constructor.
    """

    def __init__(self, cls: type) -> None:
        # Checker class instantiated with the looked-up attribute name.
        self._cls = cls

    def __getattr__(self, attr: str) -> "Checker":
        """
        Build ``self._cls(attr)`` for an attribute not found normally.

        Dunder names (e.g. ``__deepcopy__``, ``__getstate__``,
        ``__setstate__``) raise AttributeError instead of being turned into
        checkers: ``copy``, ``pickle`` and ``hasattr`` probe for these
        optional hooks and must see them as absent. ``_cls`` is excluded too,
        so an instance created without ``__init__`` (as ``copy``/``pickle``
        do) fails cleanly instead of recursing through ``self._cls``.
        """
        if attr == "_cls" or (attr.startswith("__") and attr.endswith("__")):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        return self._cls(attr)


class AttrValue(Checker):
    """
    Match a node whose attribute ``attr_name`` compares equal to
    ``attr_value``.

    Elements that are not ``Node`` instances never match.
    """

    def __init__(self, attr_name: str, attr_value: Any) -> None:
        self._attr_name = attr_name  # attribute looked up on the node
        self._attr_value = attr_value  # value the attribute must equal

    def check(self, element: Element) -> bool:
        if isinstance(element, Node):
            actual = element.get_attribute_value(self._attr_name)
            return actual == self._attr_value
        return False


class CategoryCheck(Checker):
    """
    Match a node produced by a given matcher, or by any matcher that belongs
    to a given matcher category.

    The decision is delegated to ``category.contains_matcher`` using the
    node's recorded matcher name.
    """

    def __init__(self, category: MatcherOrCategory) -> None:
        # Matcher or category consulted via ``contains_matcher``.
        self.category = category

    def check(self, element: Element) -> bool:
        candidate = element.get_non_tensor()
        if isinstance(candidate, Node):
            producer = candidate.get_matcher_name()
            if producer is not None:
                return self.category.contains_matcher(producer)
        # Either not a node, or a node without a recorded matcher name.
        return False


class FusedWithQDQNode(Checker):
    """
    Check if a node is fully fused with QDQ nodes, meaning all of its inputs
    and outputs are connected through Dequantize (DQ) and Quantize (Q) nodes.

    The check looks only at the node's attributes. The node fails if its
    ``disable_q`` attribute, or any attribute whose name starts with
    ``disable_dq`` (e.g. ``disable_dq0``), equals 1. Non-node elements
    also fail.
    """

    def __init__(self):
        """Takes no configuration."""

    def check(self, element: Element) -> bool:
        candidate = element.get_non_tensor()
        if not isinstance(candidate, Node):
            return False
        return not any(
            (name == "disable_q" or name.startswith("disable_dq")) and flag == 1
            for name, flag in candidate.get_attributes().items()
        )


class DType(Checker):
    """
    Match a tensor or initializer whose data type equals the given one.

    Elements that are neither ``Tensor`` nor ``Initializer`` never match.
    """

    def __init__(self, dtype: str):
        # Data type compared against ``element.get_dtype()``.
        self._dtype = dtype

    def __repr__(self) -> str:
        # Must name this class; it previously printed "DTypes(...)", which
        # confused it with the multi-dtype checker.
        return f"DType({repr(self._dtype)})"

    def check(self, element: Element) -> bool:
        if not isinstance(element, (Tensor, Initializer)):
            return False
        return element.get_dtype() == self._dtype


# Shorthand: 'dType.<name>' builds DType("<name>").
dType = CheckSyntaxSugar(DType)


class DTypes(Checker):
    """
    Match a tensor or initializer whose data type is one of ``dtypes``.

    Elements that are neither ``Tensor`` nor ``Initializer`` never match.
    """

    def __init__(self, *dtypes: str):
        # Accepted data types, stored as the tuple received.
        self._dtypes = dtypes

    def __repr__(self) -> str:
        return "DTypes(*" + repr(self._dtypes) + ")"

    def check(self, element: Element) -> bool:
        return (
            isinstance(element, (Tensor, Initializer))
            and element.get_dtype() in self._dtypes
        )


# No CheckSyntaxSugar alias here: the sugar passes exactly one constructor
# argument, so it does not fit checkers that take several (e.g. DTypes).


class DTypeAny(Checker):
    """
    Check that there is a tensor and it has (any) data type.

    Matches a ``Tensor`` or ``Initializer`` whose ``get_dtype()`` is not the
    empty string. All other elements fail.
    """

    def __repr__(self) -> str:
        # Plain literal: the former f-string had no placeholders (ruff F541).
        return "DTypeAny()"

    def check(self, element: Element) -> bool:
        if not isinstance(element, (Tensor, Initializer)):
            return False
        return element.get_dtype() != ""  # any non-empty string is fine


# Shared instance; used directly as 'dTypeAny' (no attribute lookup needed).
dTypeAny = DTypeAny()


class OpType(Checker):
    """
    Match a node with exactly the given op type that is not frozen for L1
    fusion.

    A node is treated as frozen when its ``L1_fusion_frozen`` attribute value
    is truthy and not equal to 0. Frozen nodes never match.
    """

    def __init__(self, op_type: str):
        # Op type compared against ``node.get_op_type()``.
        self._op_type = op_type

    def __repr__(self) -> str:
        return f"OpType({self._op_type!r})"

    def check(self, element: Element) -> bool:
        candidate = element.get_non_tensor()
        if not isinstance(candidate, Node):
            return False
        if not (candidate.get_op_type() == self._op_type):
            return False
        frozen_flag = candidate.get_attribute_value("L1_fusion_frozen")
        # A falsy flag (or one equal to 0) leaves the node eligible.
        return not (frozen_flag and frozen_flag != 0)


# Shorthand: 'opType.<Name>' builds OpType("<Name>").
opType = CheckSyntaxSugar(OpType)


class OpTypeIgnoreFrozen(OpType):
    """OpType checker that ignores the L1_fusion_frozen attribute.

    Unlike OpType, a node matches whenever its op type equals the requested
    one, whatever the value of its ``L1_fusion_frozen`` attribute.
    """

    def __repr__(self) -> str:
        # Override the inherited OpType repr so debug output shows which
        # variant is in use. The inherited repr printed "OpType(...)".
        return f"OpTypeIgnoreFrozen({repr(self._op_type)})"

    def check(self, element: Element) -> bool:
        node = element.get_non_tensor()
        if not isinstance(node, Node):
            return False
        # Only check the op type, ignore the L1_fusion_frozen attribute
        return node.get_op_type() == self._op_type


# Shorthand: 'opTypeIgnoreFrozen.<Name>' builds OpTypeIgnoreFrozen("<Name>").
opTypeIgnoreFrozen = CheckSyntaxSugar(OpTypeIgnoreFrozen)


class PartialOpType(Checker):
    """
    Match a node whose op type contains ``partial_op_type``, tested with the
    ``in`` operator.

    The ``L1_fusion_frozen`` attribute is not consulted.
    """

    def __init__(self, partial_op_type: str):
        # Fragment that must occur in the node's op type.
        self._partial_op_type = partial_op_type

    def __repr__(self) -> str:
        return f"PartialOpType({self._partial_op_type!r})"

    def check(self, element: Element) -> bool:
        candidate = element.get_non_tensor()
        return isinstance(candidate, Node) and (
            self._partial_op_type in candidate.get_op_type()
        )


# Shorthand: 'partialOpType.<Fragment>' builds PartialOpType("<Fragment>").
partialOpType = CheckSyntaxSugar(PartialOpType)


class OpTypes(Checker):
    """
    Match a node whose op type is one of ``op_types``.

    The ``L1_fusion_frozen`` attribute is not consulted.
    """

    def __init__(self, *op_types: str):
        # Accepted op types, stored as the tuple received.
        self._op_types = op_types

    def __repr__(self) -> str:
        return "OpTypes(*" + repr(self._op_types) + ")"

    def check(self, element: Element) -> bool:
        candidate = element.get_non_tensor()
        if isinstance(candidate, Node):
            return candidate.get_op_type() in self._op_types
        return False


class ShapeNonOneDim(Checker):
    """
    Match a tensor whose shape has exactly ``non_one_dimension`` entries that
    differ from 1.

    Only ``Tensor`` elements are considered. Anything else fails.
    """

    def __init__(self, non_one_dimension: int) -> None:
        # Required number of shape entries that are not equal to 1.
        self._non_one_dimension = non_one_dimension

    def check(self, element: Element) -> bool:
        if not isinstance(element, Tensor):
            return False
        non_one_count = sum(1 for extent in element.get_shape() if extent != 1)
        return non_one_count == self._non_one_dimension
