# fmt: on
from typing import Any

from OGOAT.src.L1_fusion.py_match.basic.preprocessing import (
    GemmMutliplier,
    TransposeGemmInitializer,
    TransposeGemmQdqInitializer,
)
from OGOAT.src.L1_fusion.py_match.checkers import DTypes, opType
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher, NoMatch
from OGOAT.src.L1_fusion.py_match.basic.matmul import MatMul


class Gemm(Matcher):
    """
    Fuse a quantized (QDQ) Gemm into a single ``MatMul_qdq_*`` custom op.

    Matches a Gemm whose A and B inputs (and optional bias C) come from
    DequantizeLinear nodes and whose output Y feeds a QuantizeLinear node.
    The Gemm is replaced with a ``MatMul_qdq_[bias_]<A>x<B>x<Y>`` node in the
    ``ai.onnx.contrib`` domain. Combines the GemmV, GemmWBias and GemmWoBias
    patterns.
    """

    dependencies = [
        GemmMutliplier(),
        TransposeGemmInitializer(),
        TransposeGemmQdqInitializer(),
    ]

    # Zero-point dtypes accepted for the activation input (A) and output (Y).
    _ZERO_POINT_DTYPES = ("int4", "int8", "int16", "uint4", "uint8", "uint16")

    def match(self) -> None:
        """
        Check that the current node is a Gemm that can be fused as MatMul_qdq.

        On success, sets ``self.has_bias`` for use by :meth:`modify`.

        Raises:
            NoMatch: if the op types, the alpha/beta/transA/transB attributes
                or the zero-point dtypes do not meet the fusion requirements.
        """
        n = self.n
        n.require(opType.Gemm)
        n("A").require(opType.DequantizeLinear)
        n("B").require(opType.DequantizeLinear)
        n("Y").require(opType.QuantizeLinear)

        alpha = n.get_attribute_value("alpha")
        beta = n.get_attribute_value("beta")
        if alpha != 1.0 or beta != 1.0:
            raise NoMatch("factors (alpha, beta) should be pushed to DequantizeLinear")

        # ONNX Gemm computes op(A) @ op(B), where op() transposes its input
        # when transA / transB is non-zero. The fused MatMul performs no
        # transpose, so fusing with either flag still set would silently change
        # the result. A transposed constant B is presumably folded into its
        # data (with transB cleared) by the Transpose*Initializer
        # dependencies; TODO confirm.
        for trans_attr in ("transA", "transB"):
            trans = n.get_attribute_value(trans_attr)
            # NOTE(review): assumes an absent attribute yields None or the ONNX
            # default 0. Both mean "no transpose".
            if trans is not None and trans != 0:
                raise NoMatch(f"{trans_attr} must be 0 to fuse Gemm into MatMul")

        self.has_bias = n("C").check_node()
        if self.has_bias:
            n("C").require(opType.DequantizeLinear)

        n("A.x_zero_point").require(DTypes(*self._ZERO_POINT_DTYPES))
        n("Y.y_zero_point").require(DTypes(*self._ZERO_POINT_DTYPES))
        # NOTE(review): B's zero-point dtype is not validated here, although
        # modify() embeds it in the new op type name. Confirm this is intended.

    def modify(self) -> None:
        """
        Replace the matched Gemm with a fused ``MatMul_qdq_*`` custom op.

        The new op type is ``MatMul_qdq_`` + (``bias_`` if a bias is present)
        + ``<A zp dtype>x<B zp dtype>x<Y zp dtype>``. It is named
        ``<gemm name>_<new type>``. Its inputs are the quantized tensors, scales
        and zero points taken from the surrounding DequantizeLinear /
        QuantizeLinear nodes. Its output is the QuantizeLinear output ``y``.
        """
        n = self.n
        new_type = "MatMul_qdq_"
        if self.has_bias:
            new_type += "bias_"
        new_type += (
            n("A.x_zero_point").get_dtype()
            + "x"
            + n("B.x_zero_point").get_dtype()
            + "x"
            + n("Y.y_zero_point").get_dtype()
        )
        new_name = n.get_name() + "_" + new_type

        attributes = MatMul.get_matmul_attributes(has_bias=self.has_bias)
        # Key order is kept stable because add_node may map it onto positional
        # op inputs. Bias slots are filled in below only when a bias exists.
        inputs = {
            "A": n("A.x"),
            "B": n("B.x"),
            "Bias": None,
            "A_scale": n("A.x_scale"),
            "A_zero_point": n("A.x_zero_point"),
            "B_scale": n("B.x_scale"),
            "B_zero_point": n("B.x_zero_point"),
            "Bias_scale": None,
            "Bias_zero_point": None,
            "Y_scale": n("Y.y_scale"),
            "Y_zero_point": n("Y.y_zero_point"),
        }

        if self.has_bias:
            inputs["Bias"] = n("C.x")
            inputs["Bias_scale"] = n("C.x_scale")
            inputs["Bias_zero_point"] = n("C.x_zero_point")

        outputs = {
            "Y": n("Y.y"),
        }

        # Only the Gemm node is removed explicitly here. Presumably the
        # now-unused DQ/Q nodes are cleaned up elsewhere; TODO confirm.
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
            new_name=new_name,
        )
