from OGOAT.src.L1_fusion.py_match.checkers import opTypeIgnoreFrozen
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    MatcherError,
    Node,
)
from OGOAT.src.L1_fusion.L1_utils.utils import onnxTensor_dtype_to_np_dtype


class Rename_QDQ(Matcher):
    """
    Rename QuantizeLinear / DequantizeLinear nodes to their NPU operator names.

    A (De)QuantizeLinear node whose output is a scalar has to stay on the CPU,
    so it is rejected by match(). Every other such node is replaced by a node
    whose type encodes the conversion it performs:
        - Quant_<input dtype>x<output dtype>    (for QuantizeLinear),
          e.g. Quant_float32xint8, Quant_float32xint16
        - Dequant_<input dtype>x<output dtype>  (for DequantizeLinear),
          e.g. Dequant_int8xfloat32, Dequant_int16xfloat32

    Frozen nodes (nodes with the L1_fusion_frozen attribute set) are handled
    as well; these typically are the QuantizeLinear/DequantizeLinear nodes
    connected to global inputs/outputs. Because the dtypes are read from the
    actual tensors, multiple quantization formats are supported.
    """

    def match(self):
        """Accept a (De)QuantizeLinear node unless its output "y" is a scalar.

        Raises:
            MatcherError: if the shape of output "y" equals [].
        """
        candidate = self.n

        # The opTypeIgnoreFrozen checker disregards L1_fusion_frozen, so
        # frozen Q/DQ nodes are matched too. Either op type is acceptable.
        accepted_ops = (
            opTypeIgnoreFrozen.QuantizeLinear | opTypeIgnoreFrozen.DequantizeLinear
        )
        candidate.require(accepted_ops)

        out_shape = candidate("y").require_tensor().get_shape()
        if out_shape == []:
            raise MatcherError(
                "Only non-scalar outputs of (De)QuantizeLinear can be run on NPU"
            )

    def modify(self):
        """Swap the matched node for an equivalent one carrying the NPU type.

        The replacement keeps the inputs, outputs and attributes of the
        original node, gains the attribute num_of_tensor_inputs = 1, lives in
        the "ai.onnx.contrib" domain and is named
        "<original name>_<new type>".

        Raises:
            RuntimeError: if the node is neither QuantizeLinear nor
                DequantizeLinear.
        """
        node = self.n.require_node()

        # Only the prefix depends on the direction of the conversion; the
        # dtype suffix is built the same way for both operators.
        if node.check(opTypeIgnoreFrozen.QuantizeLinear):
            prefix = "Quant"
        elif node.check(opTypeIgnoreFrozen.DequantizeLinear):
            prefix = "Dequant"
        else:
            # Not expected: match() only lets the two op types above through.
            raise RuntimeError(f"Unexpected node type: {node.get_op_type()}")

        # Read the actual dtypes of data input "x" and output "y".
        raw_in = node("x").require_tensor().get_dtype_raw()
        raw_out = node("y").require_tensor().get_dtype_raw()
        in_label = onnxTensor_dtype_to_np_dtype(raw_in)
        out_label = onnxTensor_dtype_to_np_dtype(raw_out)
        npu_type = f"{prefix}_{in_label}x{out_label}"

        # Collect everything needed from the old node before it is removed.
        replacement_name = "_".join((node.get_name(), npu_type))
        node_inputs = node.get_inputs_dict()
        node_outputs = node.get_outputs_dict()
        node_attrs = node.get_attributes()
        node_attrs["num_of_tensor_inputs"] = 1

        self.remove_node(node)
        self.add_node(
            type=npu_type,
            domain="ai.onnx.contrib",
            inputs=node_inputs,
            outputs=node_outputs,
            attributes=node_attrs,
            new_name=replacement_name,
        )
