# fmt: on

from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    NoMatch,
    Element,
    MatcherError,
    Node,
    InputTensor,
    Tensor,
)
from OGOAT.src.L1_fusion.py_match.clean.remove_qdq import (
    Remove_Q_Plus_DQ,
)

import numpy as np
from onnx.helper import tensor_dtype_to_np_dtype
from typing import Any, Tuple

from OGOAT.src.L1_fusion.py_match.py_match_utils import get_scalar_tensor_value
from OGOAT.src.L1_fusion.py_match.skip import WalkCfgSkipNoop


# Base class for the two remove binary variants
class RemoveBinary(Matcher, QDQHelper):
    """
    Fold a binary operation that has a constant scalar operand into an
    adjacent quantization node, then remove the binary operation.

    Matched pattern (``Op`` is Add, Sub, Mul or Div)::

        [Quantize] -> Dequantize -> Op -> Quantize -> [Dequantize]
                                    ^
                        constant ---+

    Add/Sub are folded into the zero point and Mul/Div into the scale of
    either the activation's DequantizeLinear or the output QuantizeLinear
    (see ``find_destination_node``). After ``match`` succeeds, the following
    attributes are set:
      - quant_output: the QuantizeLinear fed by the binary op output
      - dequant_activation: the dequantized activation input of the op
      - k_value: the scalar value of the constant operand
      - dst_node / dst_scale / dst_zero_pt: node and initializers to update
      - new_scale / new_zp: the values ``modify`` writes into them
    """

    def match_qdq(self) -> None:
        """
        Handle the case where both inputs of the binary op are DequantizeLinear.

        One of the two dequantize nodes must have a constant input (an
        initializer, or an ``A.x.x`` / ``B.x.x`` Constant node) from which the
        scalar k value used for the folding is extracted. The other dequantize
        is treated as the activation dequantize, regardless of whether its own
        input is constant.

        Sets ``self.dequant_activation``, ``self.k_value``,
        ``self.zero_point_dtype`` and ``self.scale_dtype``.

        Raises:
            NoMatch: if no constant input is found, if the constant is the
                first input of a Sub/Div, or if the constant dequantize has no
                initializer scale / zero point.
        """
        value_A = None
        value_B = None
        n = self.n
        if n("A.x").check_initializer():
            self.flag_initializers_used(n("A.x"), n("A.x_scale"), n("A.x_zero_point"))
            value_A = n("A.x").require_tensor().get_initializer_array_or_none()
        elif n("B.x").check_initializer():
            self.flag_initializers_used(n("B.x"), n("B.x_scale"), n("B.x_zero_point"))
            value_B = n("B.x").require_tensor().get_initializer_array_or_none()
        elif n("A.x.x").check(opType.Constant):
            self.flag_initializers_used(
                n("A.x.x"), n("A.x.x_scale"), n("A.x.x_zero_point")
            )
            # either A or B can be constant scalar; the first attribute of the
            # Constant node is taken as its value
            value_list_A = list(n("A.x.x").require_node().get_attributes().values())
            if len(value_list_A) > 0:
                value_A = value_list_A[0]
        elif n("B.x.x").check(opType.Constant):
            self.flag_initializers_used(
                n("B.x.x"), n("B.x.x_scale"), n("B.x.x_zero_point")
            )
            value_list_B = list(n("B.x.x").require_node().get_attributes().values())
            if len(value_list_B) > 0:
                value_B = value_list_B[0]
        if value_A is None and value_B is None:
            raise NoMatch(
                f"Remove binary op only if one of the dequantize linear input has const scalar value"
            )

        # Sub and Div are not commutative: folding `k - x` or `k / x` into a
        # scale / zero point is not expressible, so reject a constant first input.
        if value_A is not None and (
            self.n.check(opType.Sub) or self.n.check(opType.Div)
        ):
            raise NoMatch(
                "Qdq initializer: Div and Sub are not supported when weight is the first input"
            )

        # If the constant scalar is on the B path, the activation is on the A
        # path, and vice versa.
        if value_B is not None:
            scalar_value = get_scalar_tensor_value(value_B)
            self.dequant_activation = n("A")
            scalar_node = n("B")
        else:
            scalar_value = get_scalar_tensor_value(value_A)
            self.dequant_activation = n("B")
            scalar_node = n("A")

        scalar_node_zero_pt = (
            scalar_node("x_zero_point").require_tensor().get_initializer_array_or_none()
        )
        scalar_node_scale = (
            scalar_node("x_scale").require_tensor().get_initializer_array_or_none()
        )

        # store dtype for creating new initializers with the same dtype
        self.zero_point_dtype = (
            scalar_node("x_zero_point").require_tensor().get_dtype_raw()
        )
        self.scale_dtype = scalar_node("x_scale").require_tensor().get_dtype_raw()

        if scalar_node_zero_pt is None or scalar_node_scale is None:
            raise NoMatch(f"Remove binary op only if input with valid scalar value")

        scalar_node_zero_pt = get_scalar_tensor_value(scalar_node_zero_pt)
        scalar_node_scale = get_scalar_tensor_value(scalar_node_scale)
        # k is the dequantized value of the constant: (q - zp) * scale
        self.k_value = (scalar_value - scalar_node_zero_pt) * scalar_node_scale

    def match_binary_op(self, n: Element) -> None:
        """
        Check that ``n`` is a binary operation (Add, Sub, Mul or Div) with a
        dequantized activation input, a constant second input, and an output
        consumed by a QuantizeLinear.

        The constant input can be:
          - an initializer
          - a dequantize node with a constant or initializer input

              constant --------------
                                     |
                                     V
        dequantized_activation -> BinaryOp -> quantized_output

        On a match, the quantize node, the dequantize activation and the
        constant value are saved in, respectively:
          - self.quant_output
          - self.dequant_activation
          - self.k_value

        Raises:
            NoMatch: if the op type or surrounding pattern does not match, or
                if a Sub/Div has its constant as the first input.
        """

        n.require(opType.Add | opType.Sub | opType.Mul | opType.Div)

        self.quant_output = n("C").require(opType.QuantizeLinear)
        if n("A").check(opType.DequantizeLinear) and n("B").check_initializer():
            initializer_val = n("B").require_tensor().get_initializer_array_or_none()
            self.k_value = get_scalar_tensor_value(initializer_val)
            self.dequant_activation = n("A")
        elif n("B").check(opType.DequantizeLinear) and n("A").check_initializer():
            # Sub and Div are not commutative, see match_qdq.
            if self.n.check(opType.Sub) or self.n.check(opType.Div):
                raise NoMatch(
                    "non qdq initializer: Div and Sub are not supported when weight is the first input"
                )

            initializer_val = n("A").require_tensor().get_initializer_array_or_none()
            self.k_value = get_scalar_tensor_value(initializer_val)
            self.dequant_activation = n("B")
        elif n("A").check(opType.DequantizeLinear) and n("B").check(
            opType.DequantizeLinear
        ):
            self.match_qdq()
        else:
            raise NoMatch(f"{self} does not match to critieria to remove binary")

    def find_destination_node(self):
        """
        Decide which Q/DQ node the binary operation is folded into.

        Called once the binary operation has been matched and the dequantize
        activation, the quantize output and the k value are known.

        Matched chain (bracketed nodes are optional):

        [Quantize] -> Dequantize -> Op -> Quantize -> [Dequantize]

        The output 'Quantize' is selected when the upstream
        '[Quantize] -> Dequantize' chain can be removed and there is no
        branching around the activation dequantize; otherwise the activation
        'Dequantize' is selected.

        This method sets:
          - self.branch_on_dequant_activation
          - self.remove_q_plus_dq_matched
          - self.dst_node
          - self.src_node
          - self.dst_scale
          - self.dst_zero_pt
        """

        self.branch_on_dequant_activation = False

        # Probe (match only, no modify) whether the Q -> DQ chain feeding the
        # activation could be removed.
        try:
            remove_qdq = Remove_Q_Plus_DQ()
            remove_qdq.n = self.dequant_activation("x")
            remove_qdq.match()
            self.remove_q_plus_dq_matched = True
        except MatcherError:
            self.remove_q_plus_dq_matched = False

        # NOTE(review): when the dequantize output has several readers but its
        # input is not a QuantizeLinear (and that input has a single reader),
        # the branch flag stays False; if the Q/DQ probe above also failed, the
        # fold then goes into the shared dequantize, which its other readers
        # also see. Confirm this case cannot occur or should count as a branch.
        if len(
            self.dequant_activation.require_tensor().get_readers()
        ) > 1 and self.dequant_activation("x").check(opType.QuantizeLinear):
            self.branch_on_dequant_activation = True
        # If there is branching just above the dequant, the Q/DQ removal at the
        # end of the optimization will also remove Q/DQ on the other branches,
        # which creates issues, so avoid that as well by treating this as branch.
        if len(self.dequant_activation("x").require_tensor().get_readers()) > 1:
            self.branch_on_dequant_activation = True

        # If the upstream Q/DQ chain cannot be removed, or there is branching,
        # fold into the activation dequantize ...
        if (not self.remove_q_plus_dq_matched) or self.branch_on_dequant_activation:
            self.dst_node = self.dequant_activation.require_node()
            self.src_node = self.quant_output

            self.dst_scale = self.dst_node("x_scale")
            self.dst_zero_pt = self.dst_node("x_zero_point")
            return

        # ... otherwise fold into the output quantize node.
        self.dst_node = self.quant_output.require_node()
        self.src_node = self.dequant_activation

        self.dst_scale = self.dst_node("y_scale")
        self.dst_zero_pt = self.dst_node("y_zero_point")

    def create_temp_dequant_node(self) -> None:
        """
        Create a new DequantizeLinear node that feeds only the binary op, and
        make it the fold destination.

        Used from ``modify`` when ``branch_on_dequant_activation`` is set, so
        that folding does not alter the dequantize seen by the other readers.
        The new node reuses the original's ``x``, ``x_scale`` and
        ``x_zero_point`` inputs; the binary op's activation input is rewired to
        its output, and ``dst_node``, ``dst_scale``, ``dst_zero_pt`` and
        ``dequant_activation`` are updated accordingly.

        Raises:
            NoMatch: if the rewired input cannot be found on the binary op.
        """
        dequant_input = self.dequant_activation("x")
        new_name = dequant_input.get_name() + "_dq"
        new_tensor = Tensor(
            dequant_input._model_dict, dequant_input._walk_cfg, new_name + "_out", None
        )

        # Give the new output tensor the shape and dtype of the original output
        dequant_output = self.dequant_activation("y")
        shape = dequant_output.get_shape()
        dtype = dequant_output.get_dtype()
        new_tensor.set_shape(shape, dtype)

        self.dst_node = self.add_node(
            type="DequantizeLinear",
            domain="ai.onnx",
            inputs={
                "x": dequant_input,
                "x_scale": self.dequant_activation("x_scale"),
                "x_zero_point": self.dequant_activation("x_zero_point"),
            },
            outputs={"y": new_tensor},
            attributes={
                "orig_name": "temp_dequant",
            },
            new_name=new_name,
        )
        self.dst_scale = self.dst_node("x_scale")
        self.dst_zero_pt = self.dst_node("x_zero_point")
        self.replace_input(self.n, self.dequant_activation, new_tensor)
        # Re-locate the activation input on the binary op so later steps use
        # the rewired tensor.
        self.dequant_activation = None
        for input_tensor in self.n.get_inputs():
            if new_tensor == input_tensor:
                self.dequant_activation = input_tensor
                break
        if self.dequant_activation is None:
            raise NoMatch(
                f"{new_tensor.get_name()} does not replace the dequantize activation input properly"
            )

    def check_needed_values(self) -> None:
        """
        Reject values that would make ``compute_new_scale_and_zp_values``
        divide by zero.

        Raises:
            NoMatch: if k is 0 for a Div folded into a dequantize or a Mul
                folded into a quantize (both compute scale / k), or if the
                destination scale is 0 for Add/Sub (which compute k / scale).
        """
        # scale_new = scale/k, when opType is Div or Mul depending on the destination node.
        # Avoid calculating new scale by dividing k_value that is 0
        if self.k_value == 0:
            if self.dst_node.check(opType.DequantizeLinear) and self.n.check(
                opType.Div
            ):
                raise NoMatch(
                    f"when opType is Div, remove binary op only if k_value is not zero, to avoid ZeroDivisionError"
                )
            elif self.dst_node.check(opType.QuantizeLinear) and self.n.check(
                opType.Mul
            ):
                raise NoMatch(
                    f"when opType is Mul, remove binary op only if k_value is not zero, to avoid ZeroDivisionError"
                )

        # k/scale is performed to compute the new zp if opType is Add or a Sub.
        # Avoid calculating new zero_point by dividing scale that is 0
        if self.dst_scale_value == 0 and self.n.check(opType.Add | opType.Sub):
            raise NoMatch(
                f"when opType is Add or Sub, remove binary op only if y_scale is not zero, to avoid ZeroDivisionError"
            )

    def check_new_qdq_parameter_fits(self) -> None:
        """
        Ensure the new zero point and scale are representable in the dtypes
        of the destination's existing initializers.

        Raises:
            NoMatch: if ``new_zp`` or ``new_scale`` is outside its dtype range.
        """
        scale_dtype = tensor_dtype_to_np_dtype(
            self.dst_scale.require_tensor().get_dtype_raw()
        )
        zp_dtype = tensor_dtype_to_np_dtype(
            self.dst_zero_pt.require_tensor().get_dtype_raw()
        )

        # np.iinfo does not cover the 4-bit types, so their ranges are listed here
        ml_dtypes_ranges = {
            "int4": (-8, 7),
            "uint4": (0, 15),
        }

        if str(zp_dtype) in ml_dtypes_ranges:
            zp_min, zp_max = ml_dtypes_ranges[str(zp_dtype)]
        else:
            zp_dtype_info = np.iinfo(zp_dtype)
            zp_min = zp_dtype_info.min
            zp_max = zp_dtype_info.max

        scale_dtype_info = np.finfo(scale_dtype)

        if zp_min > self.new_zp or zp_max < self.new_zp:
            raise NoMatch(f"New zp value {self.new_zp} cannot fit in dtype {zp_dtype}")

        if (
            scale_dtype_info.min > self.new_scale
            or scale_dtype_info.max < self.new_scale
        ):
            raise NoMatch(
                f"New scale value {self.new_scale} cannot fit in dtype {scale_dtype}"
            )

    def match(self) -> None:
        """
        Match a foldable binary op and precompute the destination node's new
        scale and zero point (applied later by ``modify``).

        Raises:
            NoMatch: if the pattern, the operand values, or the dtype ranges
                of the recomputed parameters do not allow the fold.
        """
        n = self.n = self.n.with_walk_cfg(WalkCfgSkipNoop())

        self.match_binary_op(n)

        self.find_destination_node()

        # Extract the value of the scale and zero point of the destination node.
        # Bail out with NoMatch when either is not a constant initializer,
        # rather than passing None on to get_scalar_tensor_value.
        dst_scale_array = (
            self.dst_scale.require_tensor().get_initializer_array_or_none()
        )
        dst_zero_pt_array = (
            self.dst_zero_pt.require_tensor().get_initializer_array_or_none()
        )
        if dst_scale_array is None or dst_zero_pt_array is None:
            raise NoMatch(
                "Remove binary op only if the destination scale and zero point are initializers"
            )
        self.dst_scale_value = get_scalar_tensor_value(dst_scale_array)
        self.dst_zero_pt_value = get_scalar_tensor_value(dst_zero_pt_array)

        # Check that the value needed to compute the new scale and zp are correct
        self.check_needed_values()

        # Compute the new value of the scale and zp of the destination node
        self.new_scale, self.new_zp = self.compute_new_scale_and_zp_values()

        # Check that the new scale and zp fit in the existing dtypes of the
        # destination node
        self.check_new_qdq_parameter_fits()

    def compute_new_scale_and_zp_values(self) -> Tuple[float, int]:
        """
        Compute the destination node's scale and zero point after folding.

        Add/Sub shift the zero point by k/scale, rounded to the nearest
        integer; Mul/Div multiply or divide the scale by k. The direction of
        the update depends on whether the destination is the DequantizeLinear
        or the QuantizeLinear node.

        Returns:
            (new_scale, new_zp)
        """
        n = self.n

        # NOTE(review): if get_scalar_tensor_value returns a NumPy integer
        # scalar (e.g. uint8), the zero-point arithmetic below stays in that
        # dtype and could wrap before check_new_qdq_parameter_fits sees it —
        # confirm the helper's return type.
        new_zp = self.dst_zero_pt_value
        new_scale = self.dst_scale_value
        # The zero-point shift is rounded rather than truncated: int() alone
        # maps e.g. 0.3 / 0.1 == 2.9999999999999996 to 2, one step short.
        # It is only computed for Add/Sub, the only ops whose scale was
        # checked to be non-zero in check_needed_values.
        if n.check(opType.Add):
            # Add + Dequant     --> Dequant  -----> zero_point_new = zp - k/scale
            if self.dst_node.check(opType.DequantizeLinear):
                new_zp -= int(round(self.k_value / new_scale))
            # Add + Quant     --> Quant      --------> zero_point_new = zp + k/scale
            else:
                new_zp += int(round(self.k_value / new_scale))
        elif n.check(opType.Sub):
            # Sub + Dequant       --> Dequant     -------------------> zero_point_new = zp + k/scale
            if self.dst_node.check(opType.DequantizeLinear):
                new_zp += int(round(self.k_value / new_scale))
            # Sub + Quant       --> Quant     -------------------> zero_point_new = zp - k/scale
            else:
                new_zp -= int(round(self.k_value / new_scale))
        elif n.check(opType.Mul):
            # Mul + Dequant     --> Dequant      -------------------> scale_new = scale*k
            if self.dst_node.check(opType.DequantizeLinear):
                new_scale *= self.k_value
            # Mul + Quant     --> Quant      -------------------> scale_new = scale/k
            else:
                new_scale /= self.k_value
        elif n.check(opType.Div):
            # Div + Dequant       --> Dequant     -------------------> scale_new = scale/k
            if self.dst_node.check(opType.DequantizeLinear):
                new_scale /= self.k_value
            # Div + Quant       --> Quant     -------------------> scale_new = scale * k
            else:
                new_scale *= self.k_value
        else:
            assert False, "opType is not implemented"

        return new_scale, new_zp

    def update_initializer(
        self, node: Node, initializer: InputTensor, new_value: Any, dtype_raw: Any
    ) -> None:
        """
        Set ``node``'s ``initializer`` input to ``new_value``.

        If the initializer has a single reader, it is updated in place.
        Otherwise a new initializer named ``<name>_mod`` is created and
        replaces it on ``node`` only, leaving other readers untouched.

        ``dtype_raw`` is currently unused: the replacement initializer always
        takes ``initializer``'s own dtype.
        """
        # If the initializer is only used by one node, simply change its value.
        # Otherwise ...
        if len(initializer.get_readers()) == 1:
            initializer.require_initializer().update_initializer_value(new_value)
            return

        # ... create a new initializer with the new value and replace the old one with it.
        new_name = initializer.get_name() + "_mod"
        new_initializer = self.add_initializer(
            new_name, new_value, initializer.get_dtype_raw()
        )
        self.replace_input(node, initializer, new_initializer)

    def modify(self) -> None:
        """
        Apply the fold computed by ``match``.

        The steps are:
          1. Write the new zero point (Add/Sub) or scale (Mul/Div) into the
             destination node.
          2. Record k as a ``<OpType>_const_value`` attribute on that node.
          3. Connect the output quantize directly to the activation dequantize.
          4. Try to remove the now-adjacent Q -> DQ chain.
          5. Delete the binary node.
        """
        n = self.n.require_node()
        if self.branch_on_dequant_activation:
            self.create_temp_dequant_node()

        # Update the new value of the scale and zp of the destination node.
        # The dtype argument comes from the destination initializer itself:
        # `zero_point_dtype` / `scale_dtype` are only assigned by `match_qdq`,
        # so reading them here raised AttributeError whenever the constant
        # operand was a plain initializer (the first two branches of
        # `match_binary_op`).
        if n.check(opType.Add | opType.Sub):
            self.update_initializer(
                self.dst_node,
                self.dst_zero_pt,
                self.new_zp,
                self.dst_zero_pt.get_dtype_raw(),
            )
            self.dst_node.set_attribute(f"{n.get_op_type()}_const_value", self.k_value)

        elif n.check(opType.Mul | opType.Div):
            self.update_initializer(
                self.dst_node,
                self.dst_scale,
                self.new_scale,
                self.dst_scale.get_dtype_raw(),
            )
            self.dst_node.set_attribute(f"{n.get_op_type()}_const_value", self.k_value)

        # save pointer to the q and dq nodes before rewiring
        quant_output_node = self.quant_output.require_node()
        dequant_act_node = self.dequant_activation.require_node()

        # Bypass the binary op: the dequantize is always connected to the
        # quantize.
        self.connect(self.quant_output, self.dequant_activation)

        # If folding happened in the dequantize node.
        # We possibly went from:
        #   dequantize -> [ bin_op ] -> quantize -> dequantize
        # to:
        #   dequantize -> quantize -> dequantize
        # see if we can possibly remove the quantize -> dequantize chain
        if self.dst_node.check(opType.DequantizeLinear):
            chain_start = quant_output_node
        # If folding happened in the quantize node.
        # We possibly went from:
        #   quantize -> dequantize -> [ bin_op ] -> quantize
        # to:
        #   quantize -> dequantize -> quantize
        # see if we can possibly remove the quantize -> dequantize chain
        else:
            chain_start = dequant_act_node("x")

        # Try to match and remove the qdq chain; failing to match is expected
        # and leaves the chain in place.
        try:
            remove_qdq = Remove_Q_Plus_DQ()
            remove_qdq.n = chain_start
            remove_qdq.match()
            remove_qdq.modify()
        except MatcherError:
            pass
        # Remove the binary node
        self.remove_node(n)
