# fmt: on
from OGOAT.src.L1_fusion.py_match.basic.conv import conv_category
from OGOAT.src.L1_fusion.py_match.checkers import (
    AttrValue,
    CategoryCheck,
    DTypes,
)
from OGOAT.src.L1_fusion.py_match.helpers.bias_helper import BiasHelper
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import TransposeHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher, NoMatch
from OGOAT.src.L1_fusion.py_match.basic.matmul import MatMul


class ConvtoMatmul(Matcher, BiasHelper, TransposeHelper):
    """
    Fuse a quantized Conv that is equivalent to a matrix product into a
    custom ``MatMul_qdq_*`` node.

    Unifies the following classic patterns:
    ConvtoMatMul, ConvtoMatMulBias

    ConvtoMatMul_DS and ConvtoMatMulBias_DS are fused as ConvtoMatMul or
    ConvtoMatMulBias; the Unsqueeze at the input and the Squeeze at the output
    are optimized away later by the squeeze/unsqueeze cleanup task.

    A Conv node qualifies only when it has a 1x1 kernel (or no kernel_shape
    attribute), unit strides, no explicit padding, and group <= 1.
    """

    dependencies = [conv_category]

    def match(self) -> None:
        """Check that the current node is a Conv that can be rewritten as MatMul.

        On success, binds the node's named inputs/outputs on ``self`` and
        precomputes ``new_attributes`` and ``type_keyword`` for :meth:`modify`.

        Raises:
            NoMatch: if the node is grouped (group > 1). The ``require`` checks
                presumably signal failure the same way (framework behavior).
        """
        n = self.n
        n.require(CategoryCheck(conv_category))
        # NOTE(review): when kernel_shape is absent, ONNX infers it from the
        # weight tensor B, whose spatial dims are not checked here — confirm
        # that upstream guarantees 1x1 weights in that case.
        n.require(AttrValue("kernel_shape", [1, 1]) | AttrValue("kernel_shape", None))
        n.require(AttrValue("strides", [1, 1]) | AttrValue("strides", None))
        # Even with a 1x1 kernel and unit strides, non-zero pads enlarge the
        # spatial output with padding-only border positions, which a plain
        # MatMul cannot reproduce.
        n.require(AttrValue("pads", [0, 0, 0, 0]) | AttrValue("pads", None))
        group = n.get_attributes().get("group")
        if group is not None and group > 1:
            raise NoMatch(
                f"ConvtoMatmul does not support group > 1, found group: {group}"
            )

        # Bind every named port of the quantized conv so modify() can rewire
        # it onto the fused node.
        self.A = n("A")
        self.B = n("B")
        self.Bias = n("Bias")
        self.A_scale = n("A_scale")
        self.A_zero_point = n("A_zero_point")
        self.B_scale = n("B_scale")
        self.B_zero_point = n("B_zero_point")
        self.Bias_scale = n("Bias_scale")
        self.Bias_zero_point = n("Bias_zero_point")
        self.Y_scale = n("Y_scale")
        self.Y_zero_point = n("Y_zero_point")
        self.Y = n("Y")

        self.has_bias = self.Bias.check_tensor()

        # The fused node is connected to a QuantizeLinear node at its input and
        # a DequantizeLinear node at its output, so both zero points must carry
        # one of the supported integer dtypes.
        self.A_zero_point.require(DTypes("int8", "int16", "uint8", "uint16"))
        self.Y_zero_point.require(DTypes("int8", "int16", "uint8", "uint16"))

        self.new_attributes = MatMul.get_matmul_attributes(has_bias=self.has_bias)
        # Inserted into the fused op-type name to select the bias variant.
        self.type_keyword = "bias_" if self.has_bias else ""

    def modify(self) -> None:
        """Replace the matched Conv with a ``MatMul_qdq_*`` node.

        The new op type (domain ``ai.onnx.contrib``) encodes bias presence and
        the zero-point dtypes of A, B and Y, e.g.
        ``MatMul_qdq_bias_<A>x<B>x<Y>``. All inputs and the output ``Y`` bound
        in :meth:`match` are carried over unchanged.
        """
        n = self.n

        new_type = (
            f"MatMul_qdq_{self.type_keyword}"
            f"{self.A_zero_point.get_dtype()}"
            f"x{self.B_zero_point.get_dtype()}"
            f"x{self.Y_zero_point.get_dtype()}"
        )

        inputs = {
            "A": self.A,
            "B": self.B,
            "Bias": self.Bias,
            "A_scale": self.A_scale,
            "A_zero_point": self.A_zero_point,
            "B_scale": self.B_scale,
            "B_zero_point": self.B_zero_point,
            "Bias_scale": self.Bias_scale,
            "Bias_zero_point": self.Bias_zero_point,
            "Y_scale": self.Y_scale,
            "Y_zero_point": self.Y_zero_point,
        }
        outputs = {
            "Y": self.Y,
        }
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=self.new_attributes,
        )


# Module-level instance of the ConvtoMatmul matcher; presumably imported and
# run by the fusion pass that applies matchers — confirm against the caller.
conv_to_matmul = ConvtoMatmul()
