# (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.

import math

import numpy as np

from OGOAT.src.L1_fusion.py_match.nodes_tensors import Element


# Operator types that is_elementwise() accepts (matched against a node's
# op_type or its "orig_type" attribute). Binary entries that broadcast are
# additionally rejected by is_elementwise().
elementwise_operators: list[str] = [
    # basic arithmetic
    "Add", "Sub", "Mul", "Div",
    # single-input math and activations
    "Abs", "Neg", "Exp", "Log", "Pow", "Sqrt",
    "Sigmoid", "Tanh", "Relu", "QuickGelu", "Gelu", "LeakyRelu",
    # rounding
    "Floor", "Ceil", "Round",
    # trigonometric
    "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
    # logical / comparison
    "And", "Equal", "Greater", "Less",
    # multi-input combinations, logical ops and selection
    "Max", "Min", "Mean", "Or", "Sum", "Where", "Xor",
]


def _is_op_type(l: list[str], e: Element) -> bool:
    """Return True if e's node has an op_type contained in l.

    When the current op_type does not match, the node still matches if its
    attributes contain an "orig_type" entry whose value is in l (i.e. the
    original op_type before any rewrite).
    """
    node = e.require_node()
    if node.get_op_type() in l:
        return True
    attrs = node.get_attributes()
    return "orig_type" in attrs and attrs["orig_type"] in l


def uses_broadcasting(e: Element) -> bool:
    """Return True when the node's first two inputs have different shapes.

    Only inputs 0 and 1 are compared; any further inputs are ignored, and the
    node is expected to have at least two inputs.
    """
    inputs = e.require_node().get_inputs()
    lhs_shape = inputs[0].get_shape()
    rhs_shape = inputs[1].get_shape()
    return lhs_shape != rhs_shape


def is_elementwise(e: Element) -> bool:
    """Return True if e is treated as a plain element-wise operator.

    A binary operator (see binary_operators) whose first two input shapes
    differ is rejected; otherwise the node matches when its op_type (or its
    "orig_type" attribute) is listed in elementwise_operators.
    """
    broadcasting_binary = is_binary(e) and uses_broadcasting(e)
    return not broadcasting_binary and _is_op_type(elementwise_operators, e)


def can_multidirectional_broadcasting(shape1: list[int], shape2: list[int]) -> bool:
    """
    Returns whether two shapes are compatible under multidirectional
    (NumPy-style) broadcasting.

    Shapes are aligned at their trailing dimensions; every aligned pair must
    be equal or contain a 1. Leading dimensions of the longer shape have no
    counterpart and are always compatible.
    """
    # Align from the right: [3, 4] vs [4] compares 4 with 4 (compatible),
    # whereas aligning from the left would wrongly compare 3 with 4.
    for dim1, dim2 in zip(reversed(shape1), reversed(shape2)):
        if dim1 != dim2 and dim1 != 1 and dim2 != 1:
            return False
    return True


# Operator types recognised by is_unary() (matched against a node's op_type
# or its "orig_type" attribute).
# NOTE(review): if these are ONNX op types, ONNX "Pow" takes two inputs
# (base and exponent); it is listed here rather than in binary_operators,
# presumably because the exponent is expected to be a constant. Confirm.
unary_operators: list[str] = [
    "Abs", "Neg", "Exp", "Log", "Pow", "Sqrt",
    "Sigmoid", "Tanh", "Relu", "QuickGelu", "Gelu", "LeakyRelu",
    "Floor", "Ceil", "Round",
    "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
]
# Two-input operator types recognised by is_binary() (matched against a
# node's op_type or its "orig_type" attribute).
# NOTE: these operators may also broadcast their inputs; is_elementwise()
# rejects a binary node whose first two input shapes differ.
binary_operators: list[str] = [
    "Add", "Sub", "Mul", "Div",
    "And", "Equal", "Greater", "Less", "Or", "Xor",
]


def is_unary(e: Element) -> bool:
    """Return True if e's op_type (or "orig_type" attribute) is in unary_operators."""
    return _is_op_type(l=unary_operators, e=e)


def is_binary(e: Element) -> bool:
    """Return True if e's op_type (or "orig_type" attribute) is in binary_operators."""
    return _is_op_type(l=binary_operators, e=e)


def has_same_element_size(
    in0_shape: list[int],
    in1_shape: list[int],
) -> bool:
    """
    Returns whether the two input shapes have the same number of elements.

    An empty shape (a scalar) holds one element. Element counts are computed
    with arbitrary-precision Python integers, so very large shapes cannot
    silently overflow as a fixed-width int64 product (np.prod) would.

    Raises TypeError if either shape is not a list.
    """
    if not isinstance(in0_shape, list) or not isinstance(in1_shape, list):
        raise TypeError("in0_shape and in1_shape must be lists of integers.")

    # int() normalises NumPy integer dims to Python ints (so the product cannot
    # wrap around); math.prod of an empty iterable is 1, covering scalars.
    size0 = math.prod(int(dim) for dim in in0_shape)
    size1 = math.prod(int(dim) for dim in in1_shape)

    return size0 == size1
