# fmt: on

from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    NoMatch,
    Node,
)
import numpy as np


def get_value_from_dequantize_linear(node: Node) -> object:
    """Constant-fold a DequantizeLinear node whose inputs are scalar initializers.

    Computes ``(x - x_zero_point) * x_scale`` from the node's ``x``,
    ``x_scale`` and ``x_zero_point`` inputs. Each must be an initializer whose
    array collapses to a single value via ``get_scalar_tensor_value``.

    Args:
        node: candidate node to fold.

    Returns:
        The dequantized scalar value.

    Raises:
        NoMatch: if ``node`` is not a DequantizeLinear, if any of the three
            inputs is not an initializer, or if an initializer does not hold
            a single (uniform) value.
    """
    # TODO add QuantizeLinear support
    if not node.check(opType.DequantizeLinear):
        raise NoMatch("must be a DequantizeLinear node with initializer input")

    def scalar_initializer(input_name: str) -> object:
        # Every operand must be a constant; previously only "x" was checked,
        # so a non-initializer scale/zero-point slipped through to
        # get_initializer_array().
        tensor = node(input_name).require_tensor()
        if not tensor.check_initializer():
            raise NoMatch("must be a DequantizeLinear node with initializer input")
        value = get_scalar_tensor_value(tensor.get_initializer_array())
        # NumPy integer scalars (e.g. uint8) wrap around on subtraction;
        # promote them to Python scalars so ``x - zero_point`` is exact.
        if isinstance(value, np.generic):
            value = value.item()
        return value

    x = scalar_initializer("x")
    scale = scalar_initializer("x_scale")
    zero_point = scalar_initializer("x_zero_point")
    return (x - zero_point) * scale


def get_scalar_tensor_value(tensor_value: object) -> object:
    """Collapse a scalar-like or uniformly-valued tensor to a single scalar.

    Accepted inputs:
      * a NumPy array. A single-element array returns its element. A larger
        array is accepted only if every element compares equal to the first,
        in which case that element is returned as a Python scalar.
      * anything ``np.isscalar`` accepts (Python/NumPy scalars, strings),
        which is returned unchanged.
      * a non-empty, possibly nested, Python list whose leaves (as reported
        by ``flatten_list``) share exactly one distinct value.

    Raises:
        NoMatch: if the input is empty, holds more than one distinct value,
            or is of an unsupported type.
    """
    if isinstance(tensor_value, np.ndarray):
        flat = tensor_value.ravel()
        if flat.size == 0:
            raise NoMatch("must be a scalar value or a list of same value")
        first = flat.item(0)  # Python scalar, same as .item()/.tolist()
        # A lone element is returned as-is, even NaN (which never equals itself).
        if flat.size == 1:
            return first
        # Vectorized uniformity check, rather than materializing nested
        # Python lists and de-duplicating them element by element.
        if not bool(np.all(flat == flat[0])):
            raise NoMatch("must be a list of same value")
        return first
    if np.isscalar(tensor_value):
        return tensor_value
    if isinstance(tensor_value, list) and tensor_value:
        unique_values = flatten_list(tensor_value)
        if len(unique_values) == 1:
            return unique_values[0]
        raise NoMatch("must be a list of same value")
    raise NoMatch("must be a scalar value or a list of same value")

def get_zp_dtype_from_attribute(n: "Element") -> dict[str, str]:
    """Return the x/y zero-point dtypes recorded for element ``n``.

    If both ``orig_x_zero_point_dtype`` and ``orig_y_zero_point_dtype`` are
    present in ``n.get_attributes()``, their values are returned. Otherwise
    both entries fall back to ``n._model_dict._activation_type``.

    Returns:
        A dict with keys ``"x_zero_point_type"`` and ``"y_zero_point_type"``.
    """
    # Fetch the attributes once instead of re-querying for every lookup.
    attributes = n.get_attributes()
    if (
        "orig_x_zero_point_dtype" in attributes
        and "orig_y_zero_point_dtype" in attributes
    ):
        return {
            "x_zero_point_type": attributes["orig_x_zero_point_dtype"],
            "y_zero_point_type": attributes["orig_y_zero_point_dtype"],
        }

    activation_type = n._model_dict._activation_type
    return {
        "x_zero_point_type": activation_type,
        "y_zero_point_type": activation_type,
    }

def flatten_list(tensor_list: list) -> list:
    """Return the distinct non-list leaf values of a nested list.

    Nested ``list`` objects are walked depth-first. Every other item is
    treated as a leaf. Duplicates are collapsed using normal hash/equality
    semantics (so ``1``, ``1.0`` and ``True`` count as one value), and the
    result keeps the first-occurrence order, which makes it deterministic.

    Callers use ``len(result) == 1`` to test whether all leaves share a value.

    Raises:
        TypeError: if a leaf is unhashable.
    """
    # dict preserves insertion order (unlike set) and de-duplicates the same way.
    unique: dict = {}
    # Explicit stack of iterators instead of recursion, so very deeply
    # nested lists cannot hit the interpreter's recursion limit.
    stack = [iter(tensor_list)]
    while stack:
        for item in stack[-1]:
            if isinstance(item, list):
                stack.append(iter(item))
                break  # descend; the parent iterator resumes afterwards
            unique.setdefault(item, None)
        else:
            stack.pop()  # current level exhausted
    return list(unique)
