# fmt: on
from typing import Optional

from OGOAT.src.L1_fusion.py_match.checkers import OpTypes, opType
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    InputTensor,
    Matcher,
    Node,
)


class BiasHelper:
    """
    Helper that locates the input and output tensors of an "add bias" node
    following a given element.
    """

    def get_bias_elem_in_and_out_tensor(
        self, elem: Element
    ) -> tuple[bool, Optional[Element], Element]:
        """
        Look for a `DequantizeLinear -> Add` chain behind `elem` and pick the
        Add input that is fed by a dequantized initializer.

        The lookup succeeds when:
        - `elem("y")` is a DequantizeLinear,
        - `elem("y.y")` is an Add, and
        - one of the Add's inputs is a DequantizeLinear whose `"x"` passes
          `check_initializer()`.

        Returns:
            `(True, bias_input, add_output)` on success, where `bias_input`
            is the first matching Add input and `add_output` is
            `elem("y.y.C")`; `(False, None, elem)` otherwise.
        """
        not_found = (False, None, elem)

        if not elem("y").check(opType.DequantizeLinear):
            return not_found
        if not elem("y.y").check(opType.Add):
            return not_found

        # Resolve the Add output before walking the inputs (same order as
        # the lookups were always performed in).
        add_output = elem("y.y.C")
        for candidate in elem("y.y").get_inputs():
            if not candidate.check(opType.DequantizeLinear):
                continue
            if candidate("x").check_initializer():
                return True, candidate, add_output

        return not_found


class BinaryOpHelper:
    """
    Static helpers for binary operator nodes: resolving their "A"/"B" inputs
    and validating / computing broadcast shapes.
    """

    @staticmethod
    def get_input(node: Node, input_node: Node) -> Optional[InputTensor]:
        """
        Return the input of `node` that is connected to `input_node`.

        "A" is tested before "B"; an input matches when its
        `get_non_tensor()` result equals `input_node`.

        Returns:
            `node("A")` or `node("B")`, or None if neither input matches.
        """
        assert node.get_matcher_name() == "BinaryOp" or node.check(
            opType.Add | opType.Sub | opType.Mul | opType.Div
        )  # sanity check
        for name in ("A", "B"):
            if node(name).get_non_tensor() == input_node:
                return node(name)
        return None

    @staticmethod
    def _find_input_name(
        node: Node, input: InputTensor, skip: bool
    ) -> Optional[str]:
        """
        Return "A" or "B" depending on which input of `node` equals `input`,
        or None if neither does. Shared by the two public name lookups.
        """
        assert (
            node.get_matcher_name() == "BinaryOp"
            or node.get_matcher_name() == "MatMul"
            or node.check(opType.Add | opType.Sub | opType.Mul | opType.Div)
        )  # sanity check
        # Skip everything that is auto-skipped when going from tensor to node,
        # otherwise, the equality check below might not work due to a skipped
        # node in between.
        candidate = input.skip() if skip else input
        for name in ("A", "B"):
            if node(name) == candidate:
                return name
        return None

    @staticmethod
    def get_other_input_name(node: Node, input: InputTensor, skip: bool = True) -> str:
        """
        Return the name ("A" or "B") of the input of `node` that is NOT
        `input`.

        Args:
            node: binary operator (or MatMul) node to inspect.
            input: one of the node's inputs.
            skip: if True, compare against `input.skip()` instead of `input`.

        Raises:
            ValueError: if `input` is neither input "A" nor input "B".
        """
        found = BinaryOpHelper._find_input_name(node, input, skip)
        if found is None:
            raise ValueError(
                f"Input {input.get_name()} not found for node {node.get_name()}"
                f" with inputs {node('A')} and {node('B')}"
            )
        return "B" if found == "A" else "A"

    @staticmethod
    def get_input_name(node: Node, input: InputTensor, skip: bool = True) -> str:
        """
        Return the name ("A" or "B") under which `input` is connected to
        `node`.

        Args:
            node: binary operator (or MatMul) node to inspect.
            input: one of the node's inputs.
            skip: if True, compare against `input.skip()` instead of `input`.

        Raises:
            ValueError: if `input` is neither input "A" nor input "B".
        """
        found = BinaryOpHelper._find_input_name(node, input, skip)
        if found is None:
            raise ValueError(
                f"Input {input.get_name()} not found for node {node.get_name()}"
            )
        return found

    @staticmethod
    def is_valid_multidirectional_broadcast(
        shape1: list[int], shape2: list[int]
    ) -> bool:
        """
        Check whether two shapes are compatible for multidirectional
        broadcasting.

        The shapes are aligned on their trailing dimensions; every aligned
        pair must be equal or contain a 1. Leading dimensions that exist in
        only one shape are treated as 1 in the other and are therefore
        always compatible.
        """
        # zip() over the reversed shapes stops at the shorter rank, which is
        # equivalent to left-padding the shorter shape with 1s.
        return all(
            dim1 == dim2 or dim1 == 1 or dim2 == 1
            for dim1, dim2 in zip(reversed(shape1), reversed(shape2))
        )

    @staticmethod
    def is_valid_unidirectional_broadcast(
        input_a_shape: list[int],
        input_b_shape: list[int],
        output_shape: list[int],
    ) -> bool:
        """
        Check that broadcasting input A against input B yields exactly
        `output_shape`.

        Both inputs are left-padded with 1s to the output rank and combined
        per dimension (equal, or one of them is 1). Note that, despite the
        name, either operand may be the one that gets expanded.

        Args:
            input_a_shape: Shape of input A tensor (e.g., [448, 1280])
            input_b_shape: Shape of input B tensor (e.g., [1, 448, 1280])
            output_shape: Expected output shape (e.g., [1, 448, 1280])

        Returns:
            bool: True if the shapes are compatible and the broadcast result
            equals `output_shape`. False otherwise, including when any of
            the three shapes is empty or the output rank is not the larger
            of the two input ranks.

        Example:
            A: [448, 1280] -> broadcast to [1, 448, 1280]
            B: [1, 448, 1280] (already correct shape)
            Output: [1, 448, 1280]
        """
        if not input_a_shape or not input_b_shape or not output_shape:
            return False

        rank_a = len(input_a_shape)
        rank_b = len(input_b_shape)
        rank_output = len(output_shape)

        # Output rank should be the maximum of input ranks
        if rank_output != max(rank_a, rank_b):
            return False

        # Pad shorter shapes with leading 1s to match output rank
        padded_a = [1] * (rank_output - rank_a) + input_a_shape
        padded_b = [1] * (rank_output - rank_b) + input_b_shape

        computed_output = []
        for dim_a, dim_b in zip(padded_a, padded_b):
            if dim_a == dim_b or dim_b == 1:
                computed_output.append(dim_a)
            elif dim_a == 1:
                computed_output.append(dim_b)
            else:
                # Incompatible dimensions
                return False

        # Verify computed output matches expected output
        return computed_output == output_shape

    @staticmethod
    def can_broadcast_unidirectionally(
        shape1: list[int], shape2: list[int]
    ) -> bool:
        """
        Check if one shape can be broadcast to another by only prepending
        dimensions.

        This is more restrictive than multidirectional broadcasting: the
        lower-rank shape must be identical to the trailing dimensions of the
        higher-rank shape (for equal ranks, the shapes must be identical).

        Args:
            shape1: First tensor shape
            shape2: Second tensor shape

        Returns:
            bool: True if the condition above holds, or if either shape is
            empty.
        """
        if not shape1 or not shape2:
            return True

        if len(shape1) > len(shape2):
            smaller_shape, larger_shape = shape2, shape1
        else:
            smaller_shape, larger_shape = shape1, shape2

        # With equal ranks the offset is 0 and this compares the full shapes.
        rank_diff = len(larger_shape) - len(smaller_shape)
        return smaller_shape == larger_shape[rank_diff:]

    @staticmethod
    def get_output_shape(shape1: list[int], shape2: list[int]) -> list[int]:
        """
        Compute the shape of a multidirectional broadcast.

        The shapes are consumed from their trailing dimension; once one of
        them is exhausted, the remainder of the other one is taken as is.

        Raises:
            AssertionError: if two aligned dimensions differ and neither of
            them is 1.
        """
        if not shape1:
            return shape2
        if not shape2:
            return shape1
        if shape1[-1] == shape2[-1]:
            return BinaryOpHelper.get_output_shape(shape1[:-1], shape2[:-1]) + [
                shape1[-1]
            ]
        assert shape1[-1] == 1 or shape2[-1] == 1, "mismatching dimensions"
        return BinaryOpHelper.get_output_shape(shape1[:-1], shape2[:-1]) + [
            max(shape1[-1], shape2[-1])
        ]


class SubmatcherOptDqIni(Matcher):
    """
    Match either a bare initializer or a DequantizeLinear node whose "x"
    input is an initializer.

    The initializer is obtained through `require_initializer()`, so the match
    is aborted if it is missing (documented contract: a MatcherError is
    raised in that case).

    after matching:
    - self.has_dq: result of the DequantizeLinear check on the matched node
    - self.ini: initializer
    """

    def match(self):
        """
        Record whether a DQ node is present, then require the initializer on
        the DQ's "x" input (DQ present) or on the matched node itself.
        """
        node = self.n
        dq_present = node.check(opType.DequantizeLinear)
        # Publish the flag before the initializer lookup, which may abort.
        self.has_dq = dq_present
        ini_source = node("x") if dq_present else node
        self.ini = ini_source.require_initializer()

    def modify(self):
        """
        Nothing to rewrite; this matcher only collects information.
        """
        pass


class SubmatcherBinLinBias(Matcher):
    """
    Match an Add/Sub operator that acts as bias, i.e., one of its inputs is
    connected to an initializer (via an optional DQ node).

    after matching:
    - self.bias_input: the input selected by SubmatcherOptDqIni
    - self.other_input: the first input that differs from self.bias_input

    FIXME:
    This class is experimental WIP and currently unused.
    There is an idea to use it for ConvtoMatMul...
    """

    def match(self):
        """
        Require an Add or Sub node and split its inputs into the bias input
        and the remaining one.
        """
        node = self.n
        node.require(OpTypes("Add", "Sub"))

        candidates = node.get_inputs()
        # Take the first input accepted by the optional-DQ/initializer
        # submatcher as the bias.
        bias_matcher = SubmatcherOptDqIni()
        self.bias_input = bias_matcher.do_match_first(*candidates)

        non_bias = [
            candidate for candidate in candidates if candidate != self.bias_input
        ]
        self.other_input = non_bias[0]

    def modify(self):
        """
        Nothing to rewrite; this matcher only collects information.
        """
        pass
