# fmt: on
from typing import Any

from OGOAT.src.L1_fusion.py_match.basic.gemm import Gemm
from OGOAT.src.L1_fusion.py_match.checkers import CategoryCheck, DTypes, opType
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher


# FIXME move to right basic pattern file
class Unsqueeze(Matcher):
    """Match a quantized Unsqueeze: DequantizeLinear -> Unsqueeze -> QuantizeLinear.

    ``modify`` collapses the match into a single ``ai.onnx.contrib`` node whose
    op type is ``Unsqueeze_qdq_<data zero-point dtype>x<expanded zero-point dtype>``.
    """

    def match(self) -> None:
        """Require the DQ -> Unsqueeze -> Q chain, restricted to Gemm-category inputs."""
        unsq = self.n
        unsq.require(opType.Unsqueeze)

        # The node behind the "data" input must dequantize, and the node
        # consuming the "expanded" output must re-quantize.
        dequant = unsq("data")
        dequant.require(opType.DequantizeLinear)
        quant = unsq("expanded")
        quant.require(opType.QuantizeLinear)

        # FIXME current: only match if part of a Gemm pattern
        unsq("data.x").require(CategoryCheck(Gemm()))

    def modify(self) -> None:
        """Replace the matched Unsqueeze with one custom-domain QDQ Unsqueeze node.

        The new node reads the DequantizeLinear input tensor directly and writes
        the QuantizeLinear output tensor, carrying both sets of scale/zero-point
        inputs. Only the Unsqueeze node itself is removed explicitly here;
        presumably the framework drops the now-unused DQ/Q nodes (TODO confirm).
        """
        unsq = self.n
        in_dtype = unsq("data.x_zero_point").get_dtype()
        out_dtype = unsq("expanded.y_zero_point").get_dtype()
        new_type = "Unsqueeze_qdq_" + in_dtype + "x" + out_dtype

        inputs = dict(
            data=unsq("data.x"),
            axes=unsq("axes"),
            data_scale=unsq("data.x_scale"),
            data_zero_point=unsq("data.x_zero_point"),
            expanded_scale=unsq("expanded.y_scale"),
            expanded_zero_point=unsq("expanded.y_zero_point"),
        )
        outputs = dict(expanded=unsq("expanded.y"))

        # Read the attributes before the node is removed from the graph.
        attributes = unsq.get_attributes()
        self.remove_node(unsq)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )


class GemmPlusUnsqueeze(Matcher):
    """
    Fuse a Gemm-category node that has a Bias with the quantized Unsqueeze consuming its output.

    NOTE: Special case for Unsqueeze; can be removed if Unsqueeze is handled in a
    generic way (e.g. pass-through).
    GGemmWBiasUnsqueeze
    """

    dependencies = [Gemm(), Unsqueeze()]

    def match(self) -> None:
        """Require Gemm (with Bias tensor) -> Unsqueeze-category node.

        On success the matched Unsqueeze node is stored on ``self.unsqueeze``
        for use by ``modify``.
        """
        n = self.n
        n.require(CategoryCheck(Gemm()))
        # FIXME also version without bias?
        n("Bias").require_tensor()
        self.unsqueeze = n("Y").require(CategoryCheck(Unsqueeze())).require_node()
        # Only uint8 / uint16 quantized Unsqueeze outputs are matched.
        self.unsqueeze("expanded_zero_point").require(DTypes("uint8", "uint16"))

    def modify(self) -> None:
        """Replace the Gemm and Unsqueeze nodes with one ``MatMul_qdq_Unsqueeze_WBias_*`` node.

        The op type encodes the zero-point dtypes of A, B and the Unsqueeze
        output. The node is named ``<orig_name>_<op type>``. The Unsqueeze axes
        are flattened into the attributes ``axes_1``, ``axes_2``, ...
        """
        n = self.n
        unsqueeze = self.unsqueeze

        new_type = (
            "MatMul_qdq_Unsqueeze_WBias_"
            + n("A_zero_point").get_dtype()
            + "x"
            + n("B_zero_point").get_dtype()
            + "x"
            + unsqueeze("expanded_zero_point").get_dtype()
        )
        # FIXME cleanup
        o_name = n.get_attribute_value("orig_name")
        new_name = o_name + "_" + new_type

        # FIXME change names: inputs are currently keyed by position "0".."10"
        inputs = {
            "0": n("A"),
            "1": n("B"),
            "2": n("Bias"),
            "3": n("A_scale"),
            "4": n("A_zero_point"),
            "5": n("B_scale"),
            "6": n("B_zero_point"),
            "7": n("Bias_scale"),
            "8": n("Bias_zero_point"),
            "9": unsqueeze("expanded_scale"),
            "10": unsqueeze("expanded_zero_point"),
        }
        outputs = {
            "Y": unsqueeze("expanded"),
        }

        # The axes come either from an initializer or, otherwise, from the
        # "value" attribute of the producing node (presumably a Constant op).
        axes_input = unsqueeze("axes")
        if axes_input.check_initializer():
            axes = axes_input.get_non_tensor().get_value()
        else:
            axes = axes_input.get_non_tensor().get_attribute_value("value")

        attributes: dict[str, Any] = {
            "Unsqueeze": 1,
            # presumably counts the tensor inputs A, B and Bias — TODO confirm
            "num_of_tensor_inputs": 3,
        }
        # One attribute per axis, numbered from 1: axes_1, axes_2, ...
        for i, dim in enumerate(axes, start=1):
            attributes[f"axes_{i}"] = dim

        self.remove_node(n)
        self.remove_node(unsqueeze)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
            new_name=new_name,
        )
