# fmt: on
import numpy as np

from OGOAT.src.L1_fusion.py_match.checkers import opType, DTypes
from OGOAT.src.L1_fusion.py_match.helpers.bias_helper import BinaryOpHelper
from OGOAT.src.L1_fusion.py_match.helpers.elementwise_helper import (
    has_same_element_size,
)
from OGOAT.src.L1_fusion.py_match.helpers.fusion_configs import FusionConfigs
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Element, Matcher, NoMatch


class BinaryOp(Matcher, QDQHelper):
    """
    Unifies the following classic QDQ binary-arithmetic patterns:
    Add, Sub, Mul, Div

    A matched node is replaced by a single custom node in the
    ``ai.onnx.contrib`` domain whose type name is
    ``"<OpType>_qdq_" + ("EleWise_" | "BroadCast_") + <new_dtype>``.
    """

    def match(self) -> None:
        """
        Accept Add/Sub/Mul/Div nodes and collect their QDQ information.

        ``check_qdq`` is run against the uint8/uint16/int8/int16 dtypes, and
        its results are stored in ``self.new_dtype`` and
        ``self.qdq_attributes`` for use by ``modify``.

        Raises:
            NoMatch: when ``extend_qdq`` is disabled and input "B" has an
                empty (rank-0) shape. Other rejections presumably come from
                ``require`` / ``check_qdq`` (not visible here).
        """
        n = self.n
        n.require(opType.Add | opType.Sub | opType.Mul | opType.Div)
        self.new_dtype, self.qdq_attributes = self.check_qdq(
            n, DTypes("uint8", "uint16", "int8", "int16")
        )

        # A rank-0 "B" operand is only accepted when extend_qdq is enabled.
        extend_qdq = FusionConfigs.get_fusion_configs().extend_qdq
        if not extend_qdq and len(n("B").get_shape()) == 0:
            raise NoMatch("B input should have a non empty shape")

    def modify(self) -> None:
        """
        Replace the matched node with one fused QDQ custom node.

        The new node is added in the ``ai.onnx.contrib`` domain. Its
        attributes are the original node's attributes, plus ``orig_type``
        (the original op type) and ``num_of_tensor_inputs`` (= 2), merged
        with the QDQ attributes gathered in ``match``.
        """
        n = self.n

        in0_shape = n("A").get_shape()
        in1_shape = n("B").get_shape()
        out_shape = n("C").get_shape()

        self.qdq_name = n.get_op_type() + "_qdq_"

        new_type = self.qdq_name + BinaryOp.get_new_type_suffix_with_keyword(
            in0_shape, in1_shape, out_shape, self.new_dtype
        )
        inputs, outputs = self.get_in_out_dict_for_qdq_node(n)

        # Build a fresh dict rather than writing into whatever get_attributes()
        # returned; NOTE(review): that mapping may be the node's own attribute
        # storage — unverified, copying is the safe choice either way.
        attributes = {
            **n.get_attributes(),
            "orig_type": n.get_op_type(),
            "num_of_tensor_inputs": 2,
        }
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes | self.qdq_attributes,
        )

    @staticmethod
    def get_new_type_suffix_with_keyword(
        in0_shape: list[int],
        in1_shape: list[int],
        out_shape: list[int],
        new_type_suffix: str,
    ) -> str:
        """
        Prefix ``new_type_suffix`` with "EleWise_" or "BroadCast_".

        The operation counts as element-wise when the two input shapes are
        identical, or when ``has_same_element_size`` holds and
        ``BinaryOpHelper`` reports a valid unidirectional broadcast from the
        input shapes to ``out_shape``. Every other case counts as broadcasting.

        Args:
            in0_shape: shape of input "A".
            in1_shape: shape of input "B".
            out_shape: shape of output "C".
            new_type_suffix: suffix appended after the keyword (the dtype
                suffix produced by ``check_qdq`` in ``match``).

        Returns:
            ``"EleWise_" + new_type_suffix`` or ``"BroadCast_" + new_type_suffix``.

        Raises:
            TypeError: if any of the three shapes is not a ``list``.
        """
        if (
            not isinstance(in0_shape, list)
            or not isinstance(in1_shape, list)
            or not isinstance(out_shape, list)
        ):
            raise TypeError(
                "in0_shape, in1_shape, and out_shape must be lists of integers."
            )

        # Identical shapes short-circuit; otherwise a broadcast that does not
        # change the element count is still treated as element-wise.
        is_elementwise = in0_shape == in1_shape or (
            has_same_element_size(in0_shape, in1_shape)
            and BinaryOpHelper.can_broadcast_unidirectionally(in0_shape, in1_shape)
            and BinaryOpHelper.is_valid_unidirectional_broadcast(
                in0_shape, in1_shape, out_shape
            )
        )
        keyword = "EleWise_" if is_elementwise else "BroadCast_"
        return keyword + new_type_suffix


# Module-level BinaryOp matcher instance.
# NOTE(review): presumably discovered/registered by the py_match fusion
# framework when this module is loaded — confirm against the loader.
binary_op = BinaryOp()
