from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher, NoMatch
from OGOAT.src.L1_fusion.py_match.checkers import opType


class ConversionDqQ(Matcher):
    """
    Fuse DequantizeLinear -> QuantizeLinear chains left over after fusion.

    Each matching chain is replaced by a single
    ``Conversion_<input_type>x<output_type>`` node (domain ``ai.onnx.contrib``),
    reducing the number of standalone quantize and dequantize nodes that remain
    in the graph and cannot otherwise be removed.

    Only the type pairs listed in ``conversion_allowed`` are fused.
    """

    # Conversions that may be fused.
    # Key: the "from" type (dtype of the DequantizeLinear input "x").
    # Value: the "to" types (dtype of the QuantizeLinear output "y") that the
    # key type may be converted into.
    # Looked up through ``self`` so subclasses can override the table.
    conversion_allowed: dict[str, list[str]] = {
        "uint16": ["uint8"],
        "uint8": ["uint16"],
    }

    def match(self):
        """
        Match a DequantizeLinear node whose output ``y`` feeds a QuantizeLinear.

        On success, stores ``self.dq``, ``self.q``, ``self.dq_type`` (dtype of
        the DQ input ``x``) and ``self.q_type`` (dtype of the Q output ``y``).

        Raises:
            NoMatch: if the DQ -> Q pattern is not present, or if the
                ``dq_type -> q_type`` conversion is not listed in
                ``conversion_allowed``.
        """
        self.dq = self.n.require(opType.DequantizeLinear).require_node()
        self.q = self.dq("y").require(opType.QuantizeLinear).require_node()

        self.dq_type = self.dq("x").require_tensor().get_dtype()
        self.q_type = self.q("y").require_tensor().get_dtype()

        allowed = self.conversion_allowed
        if self.dq_type not in allowed:
            raise NoMatch(f"Conversion from type '{self.dq_type}' not allowed.")

        if self.q_type not in allowed[self.dq_type]:
            raise NoMatch(
                f"Conversion from type '{self.dq_type}' to type '{self.q_type}' not allowed."
            )

    def modify(self):
        """
        Replace the matched DQ -> Q chain with one ``Conversion_<from>x<to>`` node.

        The new node takes the DQ input and its quantization parameters
        (``x``, ``x_scale``, ``x_zero_point``) plus the Q parameters
        (``y_scale``, ``y_zero_point``), and produces the Q output ``y``.
        It inherits the DQ node's attributes, with ``num_of_tensor_inputs``
        set to 1.
        """
        inputs = {
            "x": self.dq("x"),
            "x_scale": self.dq("x_scale"),
            "x_zero_point": self.dq("x_zero_point"),
            "y_scale": self.q("y_scale"),
            "y_zero_point": self.q("y_zero_point"),
        }
        outputs = {"y": self.q("y")}

        # Copy before adding the extra attribute so the DQ node's own attribute
        # mapping is never mutated (get_attributes() may return a live
        # reference — TODO confirm).
        attributes = dict(self.dq.get_attributes())
        # Presumably only "x" is a runtime tensor input and the scales/zero
        # points are constants — confirm against the Conversion kernel.
        attributes["num_of_tensor_inputs"] = 1
        # NOTE(review): attributes of the Q node (e.g. axis) are not carried over.

        new_op_type = f"Conversion_{self.dq_type}x{self.q_type}"
        # NOTE(review): only the DQ node is removed explicitly; presumably the
        # framework drops the Q node once its output "y" is produced by the new
        # node — confirm.
        self.remove_node(self.dq)
        self.add_node(
            type=new_op_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )
