# fmt: on
from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.clean.remove import RemoveHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import NoMatch
from OGOAT.src.L1_fusion.L1_utils.utils import (
    onnxTensor_dtype_to_np_dtype,
)


class RemoveCast(RemoveHelper):
    """
    Drop a Cast node that performs no actual type conversion.

    This pattern matches only when:
      1. The dtype of the Cast's input tensor equals the dtype named by the
         Cast's ``to`` attribute, i.e. the Cast converts nothing.
      2. Per the original design, the surrounding QDQ pair is removed along
         with the Cast when the Q and DQ quantization parameters are equal.
         NOTE(review): that QDQ check is not implemented in this class;
         presumably it is handled by ``RemoveHelper`` — TODO confirm.
    """

    def set_input_and_output(self) -> None:
        """Bind ``self.input`` / ``self.output`` to the matched tensors.

        They are looked up under the pattern names "input" and "output".
        NOTE(review): presumably ``RemoveHelper`` uses these two attributes
        to reconnect the graph around the removed node — verify in the base
        class.
        """
        # Keep these as two separate statements so that self.input is
        # already assigned if the "output" lookup fails, exactly as before.
        self.input = self.n("input")
        self.output = self.n("output")

    def match_op_specifics(self) -> None:
        """Accept the matched node only if it is a no-op Cast.

        Raises:
            NoMatch: if the input tensor's dtype differs from the dtype
                requested by the ``to`` attribute. That is a real conversion,
                which must be kept.
        """
        # Presumably ``require`` rejects (raises NoMatch) any node that is
        # not a Cast; it returns the node handle used below.
        cast_node = self.n.require(opType.Cast)

        source_dtype = cast_node("input").get_dtype()
        # ``to`` holds the requested element type (presumably an ONNX tensor
        # dtype code). Map it to the dtype representation that get_dtype()
        # is compared against.
        target_dtype = onnxTensor_dtype_to_np_dtype(
            cast_node.get_attribute_value("to")
        )

        if source_dtype != target_dtype:
            raise NoMatch("Real cast can not be removed")
