# fmt: on
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import numpy as np
import onnx

from OGOAT.src.L1_fusion.kernel_metadata_loader import KernelMetadataLoader
from OGOAT.src.L1_fusion.py_match.checkers import DTypeAny, DTypes, opType
from OGOAT.src.L1_fusion.py_match.helpers.fusion_configs import FusionConfigs
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    Initializer,
    MatcherError,
    Node,
    NoMatch,
    Nowhere,
    Tensor,
    OutputTensor,
)


@dataclass
class CheckQOrDQResult:
    """
    Outcome of looking for an optional DequantizeLinear (DQ) or QuantizeLinear (Q)
    node next to a tensor.

    present -- True when a DQ or Q node was found
    orig_tensor -- the tensor the check was run on
    quant_tensor -- quantized tensor (input of DQ / output of Q); Nowhere when there is no DQ or Q
    unquant_tensor -- unquantized tensor (output of DQ / input of Q); Nowhere when there is no DQ or Q
    scale -- scale of the DQ or Q node
    zero_point -- zero point of the DQ or Q node
    """

    # Field order defines the positional order of the generated __init__; keep it stable.
    present: bool
    orig_tensor: Element
    quant_tensor: Element
    unquant_tensor: Element
    scale: Optional[Initializer]
    zero_point: Optional[Initializer]

    @property
    def quant_or_orig_tensor(self) -> Element:
        """Quantized tensor when a DQ/Q node is present, otherwise the original tensor."""
        if not self.present:
            return self.orig_tensor
        return self.quant_tensor

    @property
    def unquant_or_orig_tensor(self) -> Element:
        """Unquantized tensor when a DQ/Q node is present, otherwise the original tensor."""
        if not self.present:
            return self.orig_tensor
        return self.unquant_tensor


@dataclass
class CheckQDQResult:
    """
    Combined outcome of checking for an optional DQ node at an input and an
    optional Q node at an output.

    dq -- result of the DQ node check
    q -- result of the Q node check
    q_prm_equal -- True if both DQ and Q are present and have equal
                   quantization parameters
    """

    dq: CheckQOrDQResult  # outcome of the DQ check
    q: CheckQOrDQResult  # outcome of the Q check
    q_prm_equal: bool  # DQ and Q both present with equal quantization parameters


class InitName(str, Enum):
    """
    Names of the scale / zero-point initializers of QDQ nodes.

    Members also derive from ``str``, so each one compares equal to its value.
    """

    # Scale and its zero point.
    SCALE = "scale"
    SCALE_ZERO_POINT = "scale_zero_point"
    # Output scale and output zero point.
    OUTPUT_SCALE = "output_scale"
    OUTPUT_ZERO_POINT = "output_zero_point"


class QDQHelper(KernelMetadataLoader):
    """
    A helper class providing checks on QDQ nodes: DequantizeLinear (DQ) nodes
    feeding a node's inputs and QuantizeLinear (Q) nodes consuming its outputs.
    """

    def check_qdq_equal_scale_zeropoint(
        self, dq: Element, q: Element, factor_dq: float = 1.0, factor_q: float = 1.0
    ) -> bool:
        """
        Check that scale and zero-point of a DQ and a Q node match.

        dq -- element from which "x_scale" / "x_zero_point" are read
        q -- element from which "y_scale" / "y_zero_point" are read
        factor_dq, factor_q -- multipliers applied to the DQ / Q scale values
            before comparing; zero points are compared without factors
        return -- True if both the scales and the zero points are equal

        All four tensors are passed through require_initializer(); whatever it
        raises for a non-initializer propagates to the caller.
        """
        return self.check_tensor_equal_value(
            dq("x_scale").require_initializer(),
            q("y_scale").require_initializer(),
            factor_dq,
            factor_q,
        ) and self.check_tensor_equal_value(
            dq("x_zero_point").require_initializer(),
            q("y_zero_point").require_initializer(),
        )

    def check_dequantize_equal_scale_zp(self, dq1: Element, dq2: Element) -> bool:
        """
        Check that two DQ nodes have equal "x_scale" and "x_zero_point" values.

        Returns False if any of the compared tensors is missing (Nowhere).
        """
        return self.check_tensor_equal_value(
            dq1("x_scale"), dq2("x_scale")
        ) and self.check_tensor_equal_value(dq1("x_zero_point"), dq2("x_zero_point"))

    def require_qdq_equal_scale_zeropoint(
        self, dq: Element, q: Element, factor_dq: float = 1.0, factor_q: float = 1.0
    ) -> None:
        """
        Like check_qdq_equal_scale_zeropoint, but raise NoMatch on mismatch.
        """
        if not self.check_qdq_equal_scale_zeropoint(dq, q, factor_dq, factor_q):
            raise NoMatch(f"QDQ scale/zero-point do not match: dq={dq}; q={q}")

    def check_tensor_equal_value(
        self,
        tensor_a: Tensor,
        tensor_b: Tensor,
        factor_a: float = 1.0,
        factor_b: float = 1.0,
    ) -> bool:
        """
        Check that two initializer tensors hold the same value.

        tensor_a, tensor_b -- tensors to compare; both must be initializers
            (require_initializer() is applied to each)
        factor_a, factor_b -- multipliers applied to the respective values
            before comparing
        return -- True if every element of value_a * factor_a equals the
            corresponding element of value_b * factor_b (after numpy
            broadcasting). False if either tensor is Nowhere or the shapes
            cannot be broadcast together. On success both initializers are
            flagged as used.

        The comparison is exact (no floating-point tolerance).
        """
        if tensor_a.check_nowhere() or tensor_b.check_nowhere():
            return False

        tensor_a = tensor_a.require_initializer()
        tensor_b = tensor_b.require_initializer()
        value_a = np.asarray(tensor_a.get_value_as_array()) * factor_a
        value_b = np.asarray(tensor_b.get_value_as_array()) * factor_b
        try:
            # Reduce the elementwise comparison to a single Python bool; using the
            # raw array in an `if` fails for multi-element (per-channel) values.
            equal = bool(np.all(value_a == value_b))
        except ValueError:
            # Shapes that cannot be broadcast against each other are not equal.
            equal = False
        if equal:
            tensor_a.flag_used()
            tensor_b.flag_used()
        return equal

    def require_tensor_equal_value(
        self,
        tensor_a: Tensor,
        tensor_b: Tensor,
        factor_a: float = 1.0,
        factor_b: float = 1.0,
    ) -> None:
        """
        Like check_tensor_equal_value, but raise NoMatch if the values differ.
        """
        if not self.check_tensor_equal_value(tensor_a, tensor_b, factor_a, factor_b):
            raise NoMatch("tensors do not match")

    def is_qdq_present(self, input: Element, output: Element) -> bool:
        """
        Check if `input` is produced by a DQ node and `output` is consumed by a Q node.
        """
        return input.check(opType.DequantizeLinear) and output.check(
            opType.QuantizeLinear
        )

    def check_qdq(
        self,
        node: Element,
        dtypes: Union[DTypes, DTypeAny],
        non_data_indices: Optional[list[int]] = None,
    ) -> tuple[str, dict[str, int]]:
        """
        Check if the node is surrounded by QDQ nodes and return the dtypes and
        attributes based on this.
         - If an input DQ or output Q node is missing, the dtype is taken from the
           node's own input/output; otherwise from the DQ/Q zero point.
         - If the DQ node at input i is missing, disable_dq<i>=1 is added as
           attribute; if present (first two inputs), disable_dq<i>=0.
         - If the Q node at the output is missing, disable_q=1 is added,
           otherwise disable_q=0.
         - Missing QDQ nodes are only tolerated when the kernel has bf16 support
           (native dtype "bf16" or "any") and the extend_qdq fusion config is
           enabled; otherwise NoMatch is raised.
         - Only the first two inputs contribute to the dtype suffix.
         - non_data_indices contains the indices of inputs which do not require a
           DQ node since the input is not data but more an attribute
           (e.g., axis input in ReduceSum).
         return -- (op_type_dtype_suffix: str, qdq_attributes: dict[str, int]),
           where the suffix is the collected dtypes joined by "x"
        """
        qdq_attributes: dict[str, int] = {}
        new_dtype: list[str] = []
        has_bf16_kernel_support = any(
            ntype in self.get_native_dtype(node.require_node().get_op_type())
            for ntype in ("bf16", "any")
        )
        extend_qdq = FusionConfigs.get_fusion_configs().extend_qdq

        for i, inp in enumerate(node.get_inputs()):
            if non_data_indices and i in non_data_indices:
                continue
            try:
                # we don't check or use the dtype of third input which is optional in most cases.
                # For ex: WCR_optimized_CLIP-patch16_model has 3 inputs in LayerNorm, third input has
                # different dtype than the first two and first two are only considered
                if not inp.check_initializer():
                    inp.require(opType.DequantizeLinear)
                if i < 2:
                    inp("x_zero_point").require(dtypes)
                    new_dtype.append(inp("x_zero_point").get_dtype())
                    qdq_attributes[f"disable_dq{i}"] = 0
            except NoMatch:
                if has_bf16_kernel_support and extend_qdq:
                    qdq_attributes[f"disable_dq{i}"] = 1
                    if i < 2:
                        new_dtype.append(node.get_input_dtype(i))
                else:
                    raise NoMatch(
                        f"Not a QDQ node nor kernel has bf16 support: {node.get_name()}"
                    )

        for i, out in enumerate(node.get_outputs()):
            try:
                out("y_zero_point").require(dtypes)
                new_dtype.append(out("y_zero_point").require_tensor().get_dtype())
                qdq_attributes["disable_q"] = 0
            except NoMatch:
                if has_bf16_kernel_support and extend_qdq:
                    new_dtype.append(node.get_output_dtype(i))
                    qdq_attributes["disable_q"] = 1
                else:
                    raise NoMatch(
                        f"Not a QDQ node nor kernel has bf16 support: {node.get_name()}"
                    )

        return "x".join(new_dtype), qdq_attributes

    def get_in_out_dict_for_qdq_node(
        self, node: Element
    ) -> tuple[dict[str, Optional[Element]], dict[str, Element]]:
        """
        Get the input and output dictionary for a node surrounded by QDQ.

        inputs maps:
          <input name>              -> DQ input "x" if DQ present, else the input itself
          <input name>_scale        -> DQ "x_scale", or None without DQ
          <input name>_zero_point   -> DQ "x_zero_point", or None without DQ
          <output name>_scale       -> Q "y_scale", or None without Q
          <output name>_zero_point  -> Q "y_zero_point", or None without Q
        outputs maps:
          <output name>             -> Q output "y" if Q present, else the output itself
        Names are the node's schema input/output names.
        """
        n = node.require_node()

        inputs: dict[str, Optional[Element]] = {}
        outputs: dict[str, Element] = {}

        # (input, schema name, produced-by-DQ) evaluated once, reused by both passes.
        named_inputs = [
            (inp, input_name, inp.check(opType.DequantizeLinear))
            for inp, input_name in zip(n.get_inputs(), n.get_schema_input_names())
        ]

        # Two passes keep the key order: all data inputs first, then scale/zero points.
        for inp, input_name, has_dq in named_inputs:
            inputs[input_name] = inp("x") if has_dq else inp

        for inp, input_name, has_dq in named_inputs:
            inputs[f"{input_name}_scale"] = inp("x_scale") if has_dq else None
            inputs[f"{input_name}_zero_point"] = inp("x_zero_point") if has_dq else None

        # get the output scale and zero point
        for out, output_name in zip(n.get_outputs(), n.get_schema_output_names()):
            if out.check(opType.QuantizeLinear):
                inputs[f"{output_name}_scale"] = out("y_scale")
                inputs[f"{output_name}_zero_point"] = out("y_zero_point")
                outputs[f"{output_name}"] = out("y")
            else:
                inputs[f"{output_name}_scale"] = None
                inputs[f"{output_name}_zero_point"] = None
                outputs[f"{output_name}"] = out

        return inputs, outputs

    @staticmethod
    def _absent_qdq_result(tensor: Element) -> CheckQOrDQResult:
        """Build the CheckQOrDQResult describing a tensor without DQ/Q node."""
        return CheckQOrDQResult(
            present=False,
            orig_tensor=tensor,
            quant_tensor=Nowhere(tensor._model_dict, tensor._walk_cfg, "no quant tensor"),
            unquant_tensor=Nowhere(
                tensor._model_dict, tensor._walk_cfg, "no unquant tensor"
            ),
            scale=None,
            zero_point=None,
        )

    def check_q(self, output_tensor: Element) -> CheckQOrDQResult:
        """
        Check for an optional Q node consuming `output_tensor`.

        The Q node only counts as present if its scale and zero point are
        initializers; any MatcherError yields a "not present" result.
        """
        try:
            q = output_tensor.require(opType.QuantizeLinear)
            q_scale = q("y_scale").require_initializer()
            q_zero_point = q("y_zero_point").require_initializer()
            return CheckQOrDQResult(
                present=True,
                orig_tensor=output_tensor,
                quant_tensor=q("y"),
                unquant_tensor=q("x"),
                scale=q_scale,
                zero_point=q_zero_point,
            )
        except MatcherError:
            return self._absent_qdq_result(output_tensor)

    def check_dq(self, input_tensor: Element) -> CheckQOrDQResult:
        """
        Check for an optional DQ node producing `input_tensor`.

        The DQ node only counts as present if its scale and zero point are
        initializers; any MatcherError yields a "not present" result.
        """
        try:
            dq = input_tensor.require(opType.DequantizeLinear)
            dq_scale = dq("x_scale").require_initializer()
            dq_zero_point = dq("x_zero_point").require_initializer()
            return CheckQOrDQResult(
                present=True,
                orig_tensor=input_tensor,
                quant_tensor=dq("x"),
                unquant_tensor=dq("y"),
                scale=dq_scale,
                zero_point=dq_zero_point,
            )
        except MatcherError:
            return self._absent_qdq_result(input_tensor)

    def check_dq_and_q(
        self, input_tensor: Element, output_tensor: Element
    ) -> CheckQDQResult:
        """
        Check optional DQ and Q nodes at input and output.

        q_prm_equal is only computed (and can only be True) when both are present.
        """
        dq = self.check_dq(input_tensor)
        q = self.check_q(output_tensor)

        q_prm_equal = False
        if dq.present and q.present:
            q_prm_equal = self.check_qdq_equal_scale_zeropoint(
                input_tensor, output_tensor
            )

        return CheckQDQResult(dq=dq, q=q, q_prm_equal=q_prm_equal)

    def check_input_output_qdq(self, node: Node) -> tuple[bool, bool]:
        """
        Check multiple QDQ properties of a node.

        returns:
            - has_qdq: all activation inputs are produced by DQ nodes and all
              outputs are consumed by Q nodes
            - has_same_values: has_qdq and every (input DQ, output Q) pair has the
              same scale and zero point values

        Args:
            node (Node): node
        """
        act_inputs = list(node.get_act_inputs())
        outputs = list(node.get_outputs())

        # Inputs and outputs are checked independently so that outputs are still
        # verified when the node has no activation inputs (and vice versa).
        has_dq = all(inp.check(opType.DequantizeLinear) for inp in act_inputs)
        has_q = all(out.check(opType.QuantizeLinear) for out in outputs)
        if not (has_q and has_dq):
            return False, False

        has_same = all(
            self.check_qdq_equal_scale_zeropoint(inp, out)
            for inp in act_inputs
            for out in outputs
        )
        return True, has_same

    def go_through_downward_qdq_chain(
        self, output_tensor: OutputTensor
    ) -> OutputTensor:
        """
        From current node, go through "Current_Node -> Q -> DQ ->" downwards qdq chain if present,
        and return the output of the DQ node.
        If no QDQ chain is present, or the Q and DQ scale/zero point differ, return the output
        tensor as is.
        """
        try:
            q_node = output_tensor.require(opType.QuantizeLinear).require_node()
            dq_node = q_node("y").require(opType.DequantizeLinear).require_node()
            self.require_qdq_equal_scale_zeropoint(dq_node, q_node)
            return dq_node("y")
        except MatcherError:
            return output_tensor

    def _get_initializer_or_dummy(
        self,
        initializer: Optional[Initializer],
        n: Node,
        init_name: InitName,
    ) -> Initializer:
        """
        Get the QDQ scale or zero point initializer if it exists, otherwise create a
        dummy initializer named "dummy_init_<node name>_<init_name value>".

        The dummy value and dtype come from the model's quantization information for
        node `n` when available; otherwise a scale defaults to 1.0 (FLOAT) and a zero
        point to 0 (UINT8).

        init_name -- an InitName member (or its string value)
        raises ValueError if init_name is not a valid InitName value
        """
        if initializer and initializer.check_initializer():
            return initializer

        # Normalize to the enum: accepts plain strings and rejects unknown names up
        # front instead of failing later on unbound value/dtype.
        init_name = InitName(init_name)
        qdq_info = n._model_dict.get_quantization_information(n.get_name())

        if init_name in (InitName.SCALE, InitName.OUTPUT_SCALE):
            value = qdq_info.scale if qdq_info else 1.0
            dtype = qdq_info.scale_dtype if qdq_info else onnx.TensorProto.FLOAT
        else:  # InitName.SCALE_ZERO_POINT or InitName.OUTPUT_ZERO_POINT
            value = qdq_info.zero_point if qdq_info else 0
            dtype = qdq_info.zero_point_dtype if qdq_info else onnx.TensorProto.UINT8

        return self.add_initializer(
            initializer_name=f"dummy_init_{n.get_name()}_{init_name.value}",
            value=value,
            dtype=dtype,
        )