# fmt: on
import numpy as np
from typing import Any
from OGOAT.src.L1_fusion.py_match.adv.matmul_transpose import MatMulTranspose
from OGOAT.src.L1_fusion.py_match.adv.silu import Silu
from OGOAT.src.L1_fusion.py_match.adv.swish import Swish
from OGOAT.src.L1_fusion.py_match.basic.categories import linear_category
from OGOAT.src.L1_fusion.py_match.basic.non_linear import (
    Clip,
    Relu,
    get_non_linear_output_tensor,
)
from OGOAT.src.L1_fusion.py_match.checkers import (
    CategoryCheck,
    FusedWithQDQNode,
    opType,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Category,
    Element,
    Matcher,
    NoMatch,
    WalkCfgPlain,
)

from OGOAT.src.L1_fusion.py_match.basic.binary_op import BinaryOp
from OGOAT.src.L1_fusion.py_match.basic.lut import Lut
from OGOAT.src.L1_fusion.py_match.basic.matmul import MatMul

from OGOAT.src.L1_fusion.py_match.basic.reduction import norm_category

from OGOAT.src.L1_fusion.py_match.clean.move_reshape import MoveReshape

from OGOAT.src.L1_fusion.py_match.helpers.bias_helper import BinaryOpHelper
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper


class LinearPlusEWBinary(Matcher, BinaryOpHelper):
    """
    Fuse a quantized MatMul (plain or transposed) with the QDQ-fused
    element-wise binary op that consumes its output ``Y``.

    The binary op has two inputs: one driven by the linear op and one "free"
    input coming from elsewhere in the graph. The free input, the linear op's
    inputs and all quantization parameters become inputs of a single fused
    node. The fused node's type encodes the linear op, the optional
    bias/actxact/transpose flags, the binary op and the in1 x in2 x out dtypes.
    """

    dependencies = [MatMul(), MatMulTranspose(), BinaryOp()]

    def match(self) -> None:
        """Bind ``self.matmul`` and ``self.binary`` for the candidate pattern."""
        root = self.n
        is_linear = CategoryCheck(MatMul()) | CategoryCheck(MatMulTranspose())
        self.matmul = root.require(is_linear).require_node()
        self.binary = root("Y").require(CategoryCheck(BinaryOp())).require_node()
        # Only binary ops already merged with their Q/DQ nodes qualify.
        self.binary.require(FusedWithQDQNode())

    def modify(self):
        """Replace the linear + binary pair with one fused custom-domain node."""
        n = self.n
        binary = self.binary

        # Properties of the linear op that end up in the fused type name.
        has_bias = n.has_attribute("bias")
        has_transpose = any(
            n.has_attribute(flag)
            for flag in ("InTransposeA", "OutTranspose", "InTransposeB")
        )
        is_activation = n.has_attribute("actxact")
        binary_op_type = n("Y").get_non_tensor().get_op_type().partition("_")[0]
        linear_op_type = n.get_op_type().partition("_")[0]

        # binary_linear_input_name: binary input fed by the linear op.
        # binary_input_name: the other ("free") binary input.
        linear_side = self.get_input(binary, self.matmul)
        binary_linear_input_name = self.get_input_name(binary, linear_side, False)
        binary_input_name = self.get_other_input_name(binary, linear_side, False)

        in1_type = n("A_zero_point").get_dtype()
        in2_type = n("Bias_zero_point" if has_bias else "B_zero_point").get_dtype()
        out_type = n("Y.C_zero_point").get_dtype()

        # e.g. <lin>_qdq_[bias_|actxact_][transpose_]<ew>_qdq_<in1>x<in2>x<out>
        name_parts = [linear_op_type, "qdq"]
        if has_bias:
            name_parts.append("bias")
        elif is_activation:
            name_parts.append("actxact")
        if has_transpose:
            name_parts.append("transpose")
        name_parts += [binary_op_type, "qdq", f"{in1_type}x{in2_type}x{out_type}"]
        new_type = "_".join(name_parts)

        lin, ew = linear_op_type, binary_op_type
        free_in, linked_in = binary_input_name, binary_linear_input_name
        # Key order defines the fused node's input order; keep it stable.
        inputs = {
            f"{lin}_A_x": n("A"),
            f"{lin}_B_x": n("B"),
            f"{lin}_Bias": n("Bias") if has_bias else None,
            f"{ew}_A_x": binary(free_in),
            f"{lin}_A_scale": n("A_scale"),
            f"{lin}_A_zero_point": n("A_zero_point"),
            f"{lin}_B_scale": n("B_scale"),
            f"{lin}_B_zero_point": n("B_zero_point"),
            f"{lin}_Bias_scale": n("Bias_scale") if has_bias else None,
            f"{lin}_Bias_zero_point": n("Bias_zero_point") if has_bias else None,
            f"{ew}_A_x_scale": binary(f"{free_in}_scale"),
            f"{ew}_A_zero_point": binary(f"{free_in}_zero_point"),
            f"{lin}_Y_scale": n("Y_scale"),
            f"{lin}_Y_zero_point": n("Y_zero_point"),
            f"{ew}_B_x_scale": binary(f"{linked_in}_scale"),
            f"{ew}_B_zero_point": binary(f"{linked_in}_zero_point"),
            f"{ew}_C_y_scale": n("Y.C_scale"),
            f"{ew}_C_y_zero_point": n("Y.C_zero_point"),
        }
        outputs = {"C": n("Y.C")}

        # Gather everything from the graph before the linear node is removed.
        linear_attrs = n.get_attributes()
        binary_attrs = n("Y").get_non_tensor().get_attributes()
        binary_attrs.pop("num_of_tensor_inputs", None)
        self.remove_node(n)

        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=linear_attrs | binary_attrs,
        )


class LinearPlusNonLinear(Matcher, QDQHelper):
    """
    Fuse a quantized linear op (Conv or MatMul) with a following Relu or Clip.

    Replaces patterns such as:
    - ConvWBiasClip
    - ConvWBiasRelu
    - MatMulBiasRelu

    DequantizeLinear -> QuantizeLinear pairs between the linear op and the
    non-linear op are skipped. Each skipped pair must pass
    ``require_qdq_equal_scale_zeropoint``.
    """

    nonlinear_category = Category([Relu(), Clip()])
    dependencies = [linear_category, nonlinear_category]

    def match(self) -> None:
        """Locate the Relu/Clip fed (possibly through DQ->Q pairs) by the linear op.

        Sets ``self.output`` (the non-linear element) and ``self.type_name``
        ("relu" or "clip").

        Raises:
            NoMatch: if the non-linear op is neither Relu nor Clip. The
                ``require*`` checks reject the other mismatches, for example
                when the non-linear input quantization differs from the linear
                output quantization.
        """
        n = self.n
        n.require(CategoryCheck(linear_category))
        self.output = n("Y")
        # Step over DequantizeLinear -> QuantizeLinear pairs whose Q and DQ
        # use identical scale / zero-point.
        while self.output.check(opType.DequantizeLinear) and self.output("y").check(
            opType.QuantizeLinear
        ):
            self.require_qdq_equal_scale_zeropoint(self.output, self.output("y"))
            self.output = self.output("y.y")

        self.output.require(CategoryCheck(self.nonlinear_category))
        self.output.require(FusedWithQDQNode())
        op_type = self.output.get_non_tensor().get_op_type()
        if op_type.startswith("Relu"):
            self.type_name = "relu"
            non_linear_input_scale = self.output("X_scale")
            non_linear_input_zero_point = self.output("X_zero_point")
        elif op_type.startswith("Clip"):
            self.type_name = "clip"
            non_linear_input_scale = self.output("input_scale")
            non_linear_input_zero_point = self.output("input_zero_point")
        else:
            raise NoMatch("unsupported non linear operation")
        self.require_tensor_equal_value(n("Y_scale"), non_linear_input_scale)
        self.require_tensor_equal_value(n("Y_zero_point"), non_linear_input_zero_point)

    def modify(self) -> None:
        """Replace the linear op with one fused Conv/MatMul + Relu/Clip node."""
        n = self.n

        # Output quantization of the fused node = output quantization of the
        # non-linear op. It must be read from ``self.output``, not from the
        # linear node ``n``. Clip's quantized ports follow ONNX "input"/"output"
        # naming (see the "input_*" ports used in match()); Relu's follow "X"/"Y".
        if self.type_name == "clip":
            y_scale = self.output("output_scale")
            y_zero_point = self.output("output_zero_point")
        else:  # "relu" -- match() only accepts relu or clip
            y_scale = self.output("Y_scale")
            y_zero_point = self.output("Y_zero_point")

        dtypes = n("A_zero_point").get_dtype() + "x" + n("B_zero_point").get_dtype()
        orig_type = "MatMul" if n.get_op_type().startswith("MatMul_qdq") else "Conv"
        bias = "bias" if n.get_attribute_value("bias") == 1 else ""

        # e.g. Conv_qdq_biasrelu_<A>x<B>x<Y>
        new_type = (
            f"{orig_type}_qdq_{bias}{self.type_name}_{dtypes}"
            f"x{y_zero_point.get_dtype()}"
        )

        inputs = {
            "A": n("A"),
            "B": n("B"),
            "Bias": n("Bias"),
            "A_scale": n("A_scale"),
            "A_zero_point": n("A_zero_point"),
            "B_scale": n("B_scale"),
            "B_zero_point": n("B_zero_point"),
            "Bias_scale": n("Bias_scale"),
            "Bias_zero_point": n("Bias_zero_point"),
            "Y_scale": y_scale,
            "Y_zero_point": y_zero_point,
        }

        outputs = {"Y": get_non_linear_output_tensor(self.output)}

        copy_attributes = n.get_attributes()
        self.remove_node(n)
        attributes = {
            "trans_to_nhwc": 0,
        }
        if bias:
            # These attributes are not needed on the fused node. Tolerate
            # their absence, since not every linear pattern adds them.
            copy_attributes.pop("trans_to_nchw", None)
            copy_attributes.pop("nchw_act", None)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes | copy_attributes,
        )


class LinearPlusNorm(Matcher):
    """
    Fuse a linear op with the single-output normalization op that consumes
    its output ``Y``.

    The fused node currently uses the placeholder type
    "NEW_TYPE_LinearPlusNorm". Inputs and attributes of the linear op are
    prefixed with "lin_"; those of the norm op are prefixed with "norm_".
    """

    dependencies = [linear_category, norm_category]

    def match(self) -> None:
        """Accept linear -> norm chains where the norm has exactly one output."""
        linear = self.n
        linear.require(CategoryCheck(linear_category))
        linear("Y").require(CategoryCheck(norm_category))

        norm_outputs = linear("Y").get_non_tensor().get_schema_output_names()
        self.output_schemas = norm_outputs
        if len(norm_outputs) != 1:
            raise NoMatch("only norms with one output are supported")

    def modify(self):
        """Replace linear + norm with one fused node carrying both ops' inputs."""
        n = self.n
        norm_node = n("Y").get_non_tensor()

        inputs: dict[str, Element] = {
            "lin_" + name: n(name) for name in n.get_schema_input_names()
        }
        for name in norm_node.get_schema_input_names():
            norm_input = n("Y." + name)
            # The norm input produced by the linear op is internal to the fusion.
            if norm_input.get_non_tensor() == n:
                continue
            inputs["norm_" + name] = norm_input

        # TODO handle different number of output from the original nodes
        outputs = {"Y": n("Y." + self.output_schemas[0])}

        # The top-level input count is taken from the linear op only.
        # NOTE(review): confirm this is what the fused kernel expects.
        attributes: dict[str, Any] = {
            "num_of_tensor_inputs": n.get_attribute_value("num_of_tensor_inputs")
        }

        lin_attrs = n.get_attributes()
        norm_attrs = norm_node.get_attributes()
        for prefix, attrs in (("lin_", lin_attrs), ("norm_", norm_attrs)):
            attrs.pop("num_of_tensor_inputs", None)
            attributes.update((prefix + key, value) for key, value in attrs.items())

        self.remove_node(n)
        self.add_node(
            type="NEW_TYPE_LinearPlusNorm",
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )


class LinearPlusLut(Matcher, QDQHelper):
    """
    Fuse a quantized linear op with a following LUT-based (PWLA) activation.

    This class should match these existing YAML patterns:
    - ConvWeightsLeakyRelu
    - ConvWeightsBiasLeakyRelu
    - GemmVLeakyRelu
    - GemmWeightsBiasLeakyRelu
    - GemmWBiasGeluPSMU
    - GemmWBiasGelu
    - MatMulBiasGelu
    - ConvSilutoMatMulSilu

    It can potentially match more patterns, as long as a kernel
    implementation exists.
    """

    dependencies = [linear_category, Lut(), Silu(), Swish(), MoveReshape()]

    def match(self) -> None:
        """Bind the linear op's LUT consumer and check quantization compatibility.

        Sets ``matcher_name``, ``has_bias``, ``lut``, ``lut_type``,
        ``lut_input_name`` and ``lut_output_name`` for use in ``modify``.

        Raises:
            NoMatch: for Conv followed by anything other than LeakyRelu; the
                ``require*`` checks reject other mismatches.
        """
        n = self.n.require(CategoryCheck(linear_category)).require_node()

        self.matcher_name = n.get_matcher_name()
        if self.matcher_name == "Gemm":
            assert n.get_op_type().startswith(
                "MatMul"
            ), "Gemm pattern should always give a MatMul"

        # Check for a potential bias in the linear op
        self.has_bias = self.n("Bias").check_tensor()

        # Extract the lut node and its type
        self.lut = (
            n("Y")
            .require(
                CategoryCheck(Lut()) | CategoryCheck(Silu()) | CategoryCheck(Swish())
            )
            .require_node()
        )
        self.lut_type = self.lut.get_attribute_value("pwla_type")

        self.lut_input_name = self.lut.get_schema_input_names()[0]
        self.lut_output_name = self.lut.get_schema_output_names()[0]

        # The LUT must read the linear output with unchanged quantization.
        self.require_tensor_equal_value(
            n("Y_scale"), self.lut(f"{self.lut_input_name}_scale")
        )
        self.require_tensor_equal_value(
            n("Y_zero_point"), self.lut(f"{self.lut_input_name}_zero_point")
        )

        # Silu and Swish are exempt from the Q/DQ-fusion requirement; every
        # other LUT op must already be fused with its Q/DQ nodes.
        lut_elem = n("Y")
        is_silu_or_swish = lut_elem.check(CategoryCheck(Silu())) or lut_elem.check(
            CategoryCheck(Swish())
        )
        if not is_silu_or_swish:
            lut_elem.require(FusedWithQDQNode())

        if self.matcher_name == "Conv" and self.lut_type != "LeakyRelu":
            # see [JIRA]/AIESW-5484
            raise NoMatch("not supported pattern (tiler; kernel meta data)")

    def modify(self) -> None:
        """Replace the linear op and its LUT consumer with one fused PWLA node."""
        n = self.n
        if self.matcher_name in ("Gemm", "ConvtoMatmul", "LinearSlice"):
            new_type_prefix = "MatMul_qdq_"
        else:
            new_type_prefix = self.matcher_name + "_qdq_"
        if self.has_bias:
            new_type_prefix += "bias"

        new_type_prefix += "pwla_"

        new_type = (
            new_type_prefix
            + n("A_zero_point").get_dtype()
            + "x"
            + n("B_zero_point").get_dtype()
            + "x"
            + self.lut(f"{self.lut_output_name}_zero_point").get_dtype()
        )
        new_name = n.require_node().get_attribute_value("orig_name") + "_" + new_type

        # Start from the linear op's attributes; default num_batches to 1 when
        # it is missing or None.
        attributes = n.get_attributes()
        if attributes.get("num_batches") is None:
            attributes["num_batches"] = 1

        # FIXME: conv basic pattern add a trans_to_nchw attr but that's
        # not needed by this advance pattern. And does not even seem
        # to be used anymore. Remove this when the attribute is not added
        # for Conv pattern anymore.
        attributes.pop("trans_to_nchw", None)
        attributes.pop("nchw_act", None)
        attributes["trans_to_nhwc"] = 0

        inputs = {
            "A": n("A"),
            "B": n("B"),
            "Bias": None,
            "A_scale": n("A_scale"),
            "A_zero_point": n("A_zero_point"),
            "B_scale": n("B_scale"),
            "B_zero_point": n("B_zero_point"),
            "Bias_scale": None,
            "Bias_zero_point": None,
            "Y_scale": n("Y_scale"),
            "Y_zero_point": n("Y_zero_point"),
            "Lut_X_scale": self.lut(f"{self.lut_input_name}_scale"),
            "Lut_X_zero_point": self.lut(f"{self.lut_input_name}_zero_point"),
            "Lut_Y_scale": self.lut(f"{self.lut_output_name}_scale"),
            "Lut_Y_zero_point": self.lut(f"{self.lut_output_name}_zero_point"),
        }
        outputs = {"Y": self.lut(f"{self.lut_output_name}")}

        # Add the Bias if the linear op contains one
        if self.has_bias:
            inputs["Bias"] = n("Bias")
            inputs["Bias_scale"] = n("Bias_scale")
            inputs["Bias_zero_point"] = n("Bias_zero_point")

        if self.lut_type == "LeakyRelu":
            # Add the LeakyReLu_alpha attribute
            attributes["LeakyReLU_alpha"] = (
                n("Y").get_non_tensor().get_attribute_value("alpha")
            )

        if self.lut_type in {"QuickGelu", "Swish"}:
            # Forward the activation's own attributes (minus its input count).
            attrs = n("Y").get_non_tensor().get_attributes()
            attrs.pop("num_of_tensor_inputs", None)
            attributes |= attrs

        attributes["orig_type"] = n.get_op_type()
        # QuickGelu is emitted under the "Gelu" PWLA type.
        attributes["pwla_type"] = (
            "Gelu" if self.lut_type == "QuickGelu" else self.lut_type
        )

        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
            new_name=new_name,
        )
