# fmt: on
from typing import Any

import onnx

from OGOAT.src.L1_fusion.py_match.checkers import DTypeAny, DTypes, opType
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.helpers.noop_helper import NoopHelper
from OGOAT.src.L1_fusion.py_match.helpers.transpose_with_optional_qdq import (
    TransposeWithOptionalQDQ,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Category,
    Element,
    Matcher,
    NoMatch,
    WalkCfgPlain,
)
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import TransposeHelper


class GroupNorm(Matcher, TransposeHelper, QDQHelper):
    """Fuse a QDQ GroupNorm decomposition into a GroupNormalization_qdq_* node.

    Anchored on an InstanceNormalization node, ``match`` walks:

    * upstream: the input DequantizeLinear, an NCHW conversion
      (``require_nchw_conversion``), an optional redundant DQ/Q pair and an
      input Reshape (with or without its own Q/DQ);
    * downstream: the output QuantizeLinear, the NCHW conversion, an optional
      redundant DQ/Q pair, an output Reshape (with or without Q/DQ), then
      DQ -> Mul -> Q -> DQ -> Add -> Q.

    The InstanceNormalization's own "scale"/"B" inputs must hold fixed values
    and are not forwarded; the fused node takes the Mul and Add "B" operands
    as ``mul_B`` / ``add_B`` instead.
    """

    def match(self) -> None:
        """Match the group-norm subgraph around an InstanceNormalization node.

        Any failed check aborts the match, either explicitly via NoMatch or
        through the ``require*`` helpers.
        """
        # This pattern handles no-ops explicitly -> plain walk config. The
        # plain-walk view is also stored on self so modify() reuses it.
        n = self.n = self.n.with_walk_cfg(WalkCfgPlain())

        # qdq InstanceNormalization
        n.require(opType.InstanceNormalization)

        # The InstanceNorm inputs "B" and "scale" are not forwarded to the
        # fused node, so they must hold exactly these raw (quantized) values.
        # NOTE(review): the length 32 is hard-coded (presumably the number of
        # groups) and the comparison assumes get_value() returns a plain list
        # -- confirm.
        b = n("B.x").require_initializer().get_value()
        scale = n("scale.x").require_initializer().get_value()
        if b != [0] * 32 or scale != [255] * 32:
            raise NoMatch("not supported input")
        # Mark the dropped initializers and their quantization params as used.
        n("B.x").require_initializer().flag_used()
        n("B.x_scale").require_initializer().flag_used()
        n("B.x_zero_point").require_initializer().flag_used()
        n("scale.x").require_initializer().flag_used()
        n("scale.x_scale").require_initializer().flag_used()
        n("scale.x_zero_point").require_initializer().flag_used()

        self.inst_norm_input = n("input").require(opType.DequantizeLinear)
        self.inst_norm_output = n("output").require(opType.QuantizeLinear)

        # NHWC transpose in/out: step over the layout conversion on both sides
        self.inst_norm_input, self.inst_norm_output = self.require_nchw_conversion(
            self.inst_norm_input, self.inst_norm_output
        )
        # Input side: Q/DQ pairs around the conversion must agree.
        if self.inst_norm_input.has_qdq():
            self.require_qdq_equal_scale_zeropoint(
                n("input"), self.inst_norm_input.quantize_node
            )
            if self.inst_norm_input("x").check(opType.QuantizeLinear):
                self.require_qdq_equal_scale_zeropoint(
                    self.inst_norm_input.dequantize_node, self.inst_norm_input("x")
                )
        else:
            if self.inst_norm_input("x").check(opType.QuantizeLinear):
                self.require_qdq_equal_scale_zeropoint(
                    n("input"), self.inst_norm_input("x")
                )

        # Match an optional DQ/Q pair
        if self.inst_norm_input("x").check(
            opType.QuantizeLinear
        ) and self.inst_norm_input("x.x").check(opType.DequantizeLinear):
            self.require_qdq_equal_scale_zeropoint(
                self.inst_norm_input("x.x"), self.inst_norm_input("x")
            )
            self.inst_norm_input = self.inst_norm_input("x.x")

        # Match the reshape in the input
        if self.inst_norm_input("x").check(opType.QuantizeLinear):
            # QDQ reshape: DQ -> Reshape -> Q
            self.reshape_in = self.inst_norm_input("x.x")
            self.reshape_in.require(opType.Reshape)
            self.reshape_in_data_xy = self.reshape_in("data.x")
            self.reshape_in_data_zero_point = self.reshape_in("data.x_zero_point")
            self.reshape_in_data_scale = self.reshape_in("data.x_scale")
            self.reshape_in("data").require(opType.DequantizeLinear)
            self.reshape_in_data_zero_point.require(DTypes("uint8", "uint16"))
        else:
            # Non-QDQ reshape: the quantization params are taken from the
            # InstanceNorm's input DQ (the Reshape presumably operates on the
            # already-quantized tensor).
            self.reshape_in = self.inst_norm_input("x")
            self.reshape_in.require(opType.Reshape)
            self.reshape_in_data_xy = self.reshape_in("data")
            self.reshape_in_data_zero_point = n("input.x_zero_point")
            self.reshape_in_data_scale = n("input.x_scale")
            self.reshape_in_data_zero_point.require(DTypes("uint8", "uint16"))

        # Output side: Q/DQ pairs around the conversion must agree.
        if self.inst_norm_output.has_qdq():
            self.require_qdq_equal_scale_zeropoint(
                self.inst_norm_output.dequantize_node, n("output")
            )
            if self.inst_norm_output("y").check(opType.DequantizeLinear):
                self.require_qdq_equal_scale_zeropoint(
                    self.inst_norm_output("y"), self.inst_norm_output.quantize_node
                )
            else:
                self.require_qdq_equal_scale_zeropoint(
                    self.inst_norm_output("y.reshaped"),
                    self.inst_norm_output.quantize_node,
                )
        else:
            if self.inst_norm_output("y").check(opType.DequantizeLinear):
                self.require_qdq_equal_scale_zeropoint(
                    self.inst_norm_output("y"), n("output")
                )
            else:
                self.require_qdq_equal_scale_zeropoint(
                    self.inst_norm_output("y.reshaped"), n("output")
                )

        # Match an optional DQ/Q pair
        if self.inst_norm_output("y").check(
            opType.DequantizeLinear
        ) and self.inst_norm_output("y.y").check(opType.QuantizeLinear):
            self.require_qdq_equal_scale_zeropoint(
                self.inst_norm_output("y"), self.inst_norm_output("y.y")
            )
            self.inst_norm_output = self.inst_norm_output("y.y")

        # Match the reshape in the output
        inp_to_reshape_out = self.inst_norm_output("y")
        if inp_to_reshape_out.check(opType.DequantizeLinear):
            # QDQ reshape
            reshape_out = inp_to_reshape_out("y").require(opType.Reshape)
            reshape_out("reshaped").require(opType.QuantizeLinear)
            self.require_qdq_equal_scale_zeropoint(
                inp_to_reshape_out, reshape_out("reshaped")
            )
            outp_of_reshape_out = reshape_out("reshaped.y")
        else:
            # non-QDQ reshape
            reshape_out = inp_to_reshape_out.require(opType.Reshape)
            outp_of_reshape_out = reshape_out("reshaped")
        # const_eval must have converted the shape to an initializer
        reshape_out("shape").require_initializer()

        # Match the Mul with its DQ'd "B" operand and Q output.
        # NOTE(review): the Mul/Add operand order is not checked -- the data
        # path is assumed to arrive on input "A" and the DQ'd parameter on
        # "B"; with swapped operands "B" would be the data-path DQ and would
        # still pass these checks.
        outp_of_reshape_out.require(opType.DequantizeLinear)
        self.mul_out = outp_of_reshape_out("y").require(opType.Mul)
        self.mul_out("B").require(opType.DequantizeLinear)
        self.mul_out("C").require(opType.QuantizeLinear)

        # Match the Add with its DQ'd "B" operand and Q output. The consumer
        # of the Mul's Q must be verified to be a DequantizeLinear *before*
        # its scale/zero-point are compared with that Q (same order as the
        # other Q/DQ checks in this matcher).
        self.mul_out("C.y").require(opType.DequantizeLinear)
        self.require_qdq_equal_scale_zeropoint(self.mul_out("C.y"), self.mul_out("C"))
        self.add_out = self.mul_out("C.y.y").require(opType.Add)
        self.add_out("B").require(opType.DequantizeLinear)
        self.add_out("C").require(opType.QuantizeLinear)
        self.add_out("C.y_zero_point").require(DTypes("uint8", "uint16"))

    def modify(self) -> None:
        """Replace the InstanceNormalization with a fused GroupNormalization_qdq node.

        The op type encodes the zero-point dtypes of the data input, the Mul
        "B" operand and the final Add output. The fused node reads the input
        Reshape's data and writes the quantized output of the Add's Q.
        """
        n = self.n
        new_type = (
            "GroupNormalization_qdq_"
            + self.reshape_in_data_zero_point.get_dtype()
            + "x"
            + self.mul_out("B.x_zero_point").get_dtype()
            + "x"
            + self.add_out("C.y_zero_point").get_dtype()
        )
        inputs = {
            "data": self.reshape_in_data_xy,
            "mul_B": self.mul_out("B.x"),
            "add_B": self.add_out("B.x"),
            "shape": self.reshape_in("shape"),
            "data_scale": self.reshape_in_data_scale,
            "data_zero_point": self.reshape_in_data_zero_point,
            "mul_B_scale": self.mul_out("B.x_scale"),
            "mul_B_zero_point": self.mul_out("B.x_zero_point"),
            "add_B_scale": self.add_out("B.x_scale"),
            "add_B_zero_point": self.add_out("B.x_zero_point"),
            "output_scale": self.add_out("C.y_scale"),
            "output_zero_point": self.add_out("C.y_zero_point"),
        }
        outputs = {
            "output": self.add_out("C.y"),
        }
        copy_attributes = n.get_attributes()
        copy_attributes["num_of_tensor_inputs"] = 3
        # Only the anchor node is removed here; the rest of the matched chain
        # is presumably cleaned up elsewhere once it becomes unused.
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=copy_attributes,
        )


class LayerNorm(Matcher, TransposeHelper, QDQHelper):
    """Fuse a QDQ-wrapped LayerNormalization into LayerNormalization_qdq_<dtype>.

    Only normalization over the last axis is accepted. When the Scale or B
    input is produced by a Transpose, ``modify`` substitutes a new,
    pre-transposed initializer together with the quantization parameters of
    that Transpose's ``data`` producer.
    """

    def match(self) -> None:
        node = self.n.with_walk_cfg(WalkCfgPlain())
        node.require(opType.LayerNormalization)

        # A missing axis attribute means the ONNX default (-1), which is fine;
        # any other value must address the last dimension of X.
        axis = node.get_attribute_value("axis")
        if axis not in (None, -1) and axis != len(node("X").get_shape()) - 1:
            raise NoMatch(
                "LayerNormalization_qdq is only supported in NPU when axis is on the last dimension."
            )

        # Remember which affine inputs arrive through a Transpose so that
        # modify() can fold the transpose into a new initializer.
        self.has_scale_transpose = bool(node("Scale").check(opType.Transpose))
        self.has_bias_transpose = bool(node("B").check(opType.Transpose))

        self.new_dtype, self.qdq_attributes = self.check_qdq(
            node, DTypes("uint8", "uint16", "int8", "int16")
        )

    def _transposed_initializer_inputs(self, node, input_name):
        """Fold the Transpose feeding ``input_name`` into a new initializer.

        Returns a tuple ``(initializer, scale, zero_point)`` where the last two
        are the ``x_scale`` / ``x_zero_point`` inputs of the Transpose's
        ``data`` producer.
        """
        matched = TransposeWithOptionalQDQ.match(node(input_name))
        folded_name = node(input_name).get_name() + "_trans"
        folded = self.add_transposed_initializer(
            matched.transpose_node("data.x").require_initializer(),
            folded_name,
        )
        return (
            folded,
            matched.transpose_node("data.x_scale"),
            matched.transpose_node("data.x_zero_point"),
        )

    def modify(self) -> None:
        node = self.n.with_walk_cfg(WalkCfgPlain())
        inputs, outputs = self.get_in_out_dict_for_qdq_node(node)

        if self.has_bias_transpose:
            (
                inputs["B"],
                inputs["B_scale"],
                inputs["B_zero_point"],
            ) = self._transposed_initializer_inputs(node, "B")

        if self.has_scale_transpose:
            (
                inputs["Scale"],
                inputs["Scale_scale"],
                inputs["Scale_zero_point"],
            ) = self._transposed_initializer_inputs(node, "Scale")

        attributes = node.get_attributes()
        attributes["num_of_tensor_inputs"] = 3
        self.remove_node(node)
        self.add_node(
            type="LayerNormalization_qdq_" + self.new_dtype,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes | self.qdq_attributes,
        )


class LpNormIndividual(Matcher, QDQHelper):
    """
    Detect L2 normalization subgraphs that compute: output = X / sqrt(sum(X^2)).

    SUPPORTED MATHEMATICAL PATTERNS (only 4 combinations at the moment):
    ====================================================================

    Mathematical Formula: Y = X / sqrt(ReduceSum(X^2, axis))

    IMPORTANT NOTICE:
    =================
    THERE ARE LOTS OF OTHER VARIANTS OF L2 NORMALIZATION PATTERNS, BUT AT THE MOMENT ONLY THE FOUR LISTED PATTERNS ARE SUPPORTED.

    SUPPORTED PATTERN COMBINATIONS:
    ===============================

    Pattern 1: X -> Pow(2) -> ReduceSum -> Pow(0.5) -> Div(X, result)
    Pattern 2: X -> Pow(2) -> ReduceSum -> Sqrt -> Div(X, result)
    Pattern 3: X -> Mul(X,X) -> ReduceSum -> Pow(0.5) -> Div(X, result)
    Pattern 4: X -> Mul(X,X) -> ReduceSum -> Sqrt -> Div(X, result)

    Between the steps, each successor is located with
    ``go_through_downward_qdq_chain`` (presumably skipping intermediate Q/DQ).

    CONSTRAINTS:
    ============
    - ReduceSum must operate on a single axis only (given as an initializer)
    - The Div operation must divide the original input X by the computed norm
    - For Mul-based squaring: both inputs (A and B) must reference the same tensor
    - For Pow-based squaring: exponent must be exactly 2
    - For Pow-based sqrt: exponent must be exactly 0.5

    FUSION OUTPUT:
    ==============
    All patterns are fused into: LpNormalization(p=2, axis=<axis>), with ``p``
    and ``axis`` emitted as plain Python ints.
    """

    def match(self) -> None:
        """Match the L2 normalization pattern, anchored on the squaring node."""
        n = self.n.require_node()
        n.require(opType.Mul | opType.Pow)

        # Match squaring operation: Mul(X,X) or Pow(X, 2)
        if n.check(opType.Mul):
            # Check if both inputs are the same (X * X)
            # NOTE(review): this relies on Element equality, whereas the Div
            # check below compares tensor names -- confirm both agree.
            if n("A") != n("B"):
                raise NoMatch(
                    "Mul inputs A and B are not the same tensor (not squaring)"
                )
            self.input_tensor = n("A")
            self.pow_in_val = 2
            out_tensor = n("C")
        elif n.check(opType.Pow):
            # Pow: X -> Pow(2)
            exponent = n("Y").require_initializer().get_value()
            if exponent != 2:
                raise NoMatch("Pow value is not 2 for squaring")
            # Store a plain int instead of the raw initializer value (which
            # need not be a Python int, e.g. a float exponent): it becomes the
            # integer "p" attribute, matching the Mul branch above.
            self.pow_in_val = 2
            self.input_tensor = n("X")
            out_tensor = n("Z")
        else:
            # Presumably unreachable after the require() above; kept as a guard.
            raise NoMatch("Node is neither Mul nor Pow")

        # Match ReduceSum operation
        reducesum_node = (
            self.go_through_downward_qdq_chain(out_tensor)
            .require(opType.ReduceSum)
            .require_node()
        )
        self.axes = reducesum_node("axes").require_initializer().get_value()
        if len(self.axes) != 1:
            raise NoMatch("LpNorm only supports one axis")

        # Match square root operation: Sqrt or Pow(0.5)
        self.sqrt_node = self.go_through_downward_qdq_chain(
            reducesum_node("reduced")
        ).require_node()
        # Try Sqrt first (Sqrt schema: X=input, Y=output), then Pow with elif
        if self.sqrt_node.check(opType.Sqrt):
            sqrt_output = self.sqrt_node("Y")
        elif self.sqrt_node.check(opType.Pow):
            # Try Pow with exponent 0.5 (Pow schema: X=base, Y=exponent, Z=output)
            exp_value = self.sqrt_node("Y").require_initializer().get_value()
            if exp_value != 0.5:
                raise NoMatch("Pow value is not 0.5 for square root")
            sqrt_output = self.sqrt_node("Z")
        else:
            raise NoMatch("Square root node is neither Sqrt nor Pow(0.5)")

        self.div_node = (
            self.go_through_downward_qdq_chain(sqrt_output)
            .require(opType.Div)
            .require_node()
        )

        # The numerator must be the original input, i.e. X / norm(X).
        if (
            self.div_node("A").require_tensor().get_name()
            != self.input_tensor.get_name()
        ):
            raise NoMatch("Div input does not match the original input tensor")

    def modify(self) -> None:
        """Replace the matched pattern with an LpNormalization node."""
        n = self.n
        inputs: dict[str, Element] = {"input": self.input_tensor}
        outputs: dict[str, Element] = {"output": self.div_node("C")}

        attributes = {
            # int() turns a possibly non-Python (e.g. numpy) integer read from
            # the axes initializer into a plain int for the attribute.
            "axis": int(self.axes[0]),
            "p": self.pow_in_val,
            "num_of_tensor_inputs": 1,
        }

        self.remove_node(n)
        self.add_node(
            type="LpNormalization",
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )


class LpNorm(Matcher):
    """Fold the input DQ (and optional output Q) into an LpNormalization_qdq node.

    The fused type is ``LpNormalization_qdq_<in_dtype>`` with ``disable_q=1``
    when no QuantizeLinear consumes the output, otherwise
    ``LpNormalization_qdq_<in_dtype>x<out_dtype>``.
    """

    # Presumably makes LpNormIndividual run first, so the LpNormalization
    # nodes it creates are available to this matcher.
    dependencies = [LpNormIndividual()]

    def match(self) -> None:
        node = self.n.require(opType.LpNormalization)
        # A dequantized input is mandatory; the output Q is optional.
        node("input").require(opType.DequantizeLinear)
        self.has_q = node("output").check(opType.QuantizeLinear)

    def modify(self) -> None:
        node = self.n
        data_in = node("input.x")
        data_scale = node("input.x_scale")
        data_zero_point = node("input.x_zero_point")
        attributes = node.get_attributes()

        outputs: dict[str, Element]
        if self.has_q:
            out_scale = node("output.y_scale")
            out_zero_point = node("output.y_zero_point")
            outputs = {"output": node("output.y")}
            type_suffix = "x" + out_zero_point.get_dtype()
        else:
            # No trailing Q: the fused node writes the original LpNorm output.
            out_scale = None
            out_zero_point = None
            attributes["disable_q"] = 1
            outputs = {"output": node("output")}
            type_suffix = ""

        inputs: dict[str, Element | None] = {
            "input": data_in,
            "input_scale": data_scale,
            "input_zero_point": data_zero_point,
            "output_scale": out_scale,
            "output_zero_point": out_zero_point,
        }
        fused_type = "LpNormalization_qdq_" + data_zero_point.get_dtype() + type_suffix
        attributes["num_of_tensor_inputs"] = 1

        self.remove_node(node)
        self.add_node(
            type=fused_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )


class Softmax(Matcher, QDQHelper):
    """Fuse a Softmax and its surrounding DQ/Q nodes into Softmax_qdq_<dtype>.

    ``check_qdq`` is called with ``DTypeAny()``, so no specific quantization
    dtype is required here.
    """

    def match(self) -> None:
        softmax = self.n
        softmax.require(opType.Softmax)
        # check_qdq yields the dtype suffix for the fused op type and extra
        # attributes that are merged into the fused node.
        self.new_dtype, self.qdq_attributes = self.check_qdq(softmax, DTypeAny())

    def modify(self) -> None:
        softmax = self.n
        fused_type = "Softmax_qdq_" + self.new_dtype
        fused_inputs, fused_outputs = self.get_in_out_dict_for_qdq_node(softmax)

        attributes = softmax.get_attributes()
        attributes["num_of_tensor_inputs"] = 1

        self.remove_node(softmax)
        self.add_node(
            type=fused_type,
            domain="ai.onnx.contrib",
            inputs=fused_inputs,
            outputs=fused_outputs,
            attributes=attributes | self.qdq_attributes,
        )


class ReduceQdq(Matcher, QDQHelper, NoopHelper):
    """Fuse a QDQ-wrapped ReduceSum/Mean/Max/Min into Reduce_qdq_<dtype>.

    The concrete reduction op is preserved in the ``reduce_type`` attribute.
    Reductions that ``is_noop_reduction`` reports as no-ops for the given
    input/output shapes are not matched.
    """

    def match(self) -> None:
        reduce_node = self.n
        reduce_node.require(
            opType.ReduceSum | opType.ReduceMean | opType.ReduceMax | opType.ReduceMin
        )
        data_shape = reduce_node("data").get_shape()
        reduced_shape = reduce_node("reduced").get_shape()
        if self.is_noop_reduction(data_shape, reduced_shape):
            raise NoMatch("No op reduction")
        # NOTE(review): [1] presumably lists input indices (the "axes" input)
        # exempt from the Q/DQ check -- confirm against check_qdq.
        self.new_dtype, self.qdq_attributes = self.check_qdq(
            reduce_node, DTypeAny(), [1]
        )

    def modify(self) -> None:
        reduce_node = self.n
        fused_inputs, fused_outputs = self.get_in_out_dict_for_qdq_node(reduce_node)

        attributes = reduce_node.get_attributes()
        attributes["num_of_tensor_inputs"] = 1
        # All four reductions share one fused type; keep which one it was.
        attributes["reduce_type"] = reduce_node.get_op_type()

        self.remove_node(reduce_node)
        self.add_node(
            type="Reduce_qdq_" + self.new_dtype,
            domain="ai.onnx.contrib",
            inputs=fused_inputs,
            outputs=fused_outputs,
            attributes=attributes | self.qdq_attributes,
        )


# FIXME: decide whether these matchers belong in a "norm" category or in a
# "reduction" category together with Softmax.
norm_category = Category(
    [
        GroupNorm(),
        LayerNorm(),
        LpNorm(),
    ]
)
