from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import TransposeHelper
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher


class Pool(Matcher, TransposeHelper, QDQHelper):
    """Fuse a quantized pooling pattern into one custom ``ai.onnx.contrib`` op.

    Matched pattern, where ``self.n`` is the pooling node::

        DequantizeLinear -> {AveragePool | GlobalAveragePool | MaxPool} -> QuantizeLinear

    For ``GlobalAveragePool`` and ``MaxPool`` the pattern must also satisfy
    ``require_nchw_conversion`` (provided by a helper base class). In that case
    the fused op is wired to ``input()`` of the first node it returns and
    ``output()`` of the second, rather than to the DQ input / Q output.

    The replacement node's type is
    ``<pool op type>_qdq_<input zero-point dtype>x<output zero-point dtype>``.
    It carries the pooling node's attributes plus ``num_of_tensor_inputs = 1``.
    It also receives the DQ/Q scales and zero points as extra inputs.
    """

    def match(self) -> None:
        """Check the pattern around ``self.n`` and record the tensors to rewire.

        On success this sets ``self.input``, ``self.output``,
        ``self.input_scale``, ``self.input_zero_point``, ``self.output_scale``
        and ``self.output_zero_point``. Match failure is signalled by the
        ``require`` / ``require_nchw_conversion`` calls themselves (presumably
        by raising — see ``Matcher``).
        """
        pool = self.n
        pool.require(opType.AveragePool | opType.GlobalAveragePool | opType.MaxPool)

        # The pool must be fed by a DequantizeLinear (on "X") and consumed by a
        # QuantizeLinear (on "Y").
        dq = pool("X")
        q = pool("Y")
        dq.require(opType.DequantizeLinear)
        q.require(opType.QuantizeLinear)

        self.input = dq("x")
        self.output = q("y")

        # Quantization parameters are forwarded as-is to the fused op.
        self.input_zero_point = dq("x_zero_point")
        self.input_scale = dq("x_scale")
        self.output_zero_point = q("y_zero_point")
        self.output_scale = q("y_scale")

        # NOTE(review): AveragePool is not put through the NCHW requirement
        # below — confirm this asymmetry is intentional.
        if not pool.check(opType.GlobalAveragePool | opType.MaxPool):
            return

        # Presumably a pair of layout transposes around the DQ/Q nodes; the
        # fused op then consumes/produces the tensors outside that pair.
        in_transpose, out_transpose = self.require_nchw_conversion(dq, q)
        self.input = in_transpose.input()
        self.output = out_transpose.output()

    def modify(self) -> None:
        """Replace the matched pooling node with the fused QDQ custom op.

        Must run after a successful ``match``, because it reads the tensors
        that ``match`` stored on ``self``.
        """
        pool = self.n

        # The fused type name encodes the original op type and the zero-point
        # dtypes of the input DQ and output Q.
        op_name = pool.get_op_type()
        in_dtype = self.input_zero_point.get_dtype()
        out_dtype = self.output_zero_point.get_dtype()
        fused_type = op_name + "_qdq_" + in_dtype + "x" + out_dtype

        fused_inputs = {
            "X": self.input,
            "X_scale": self.input_scale,
            "X_zero_point": self.input_zero_point,
            "Y_scale": self.output_scale,
            "Y_zero_point": self.output_zero_point,
        }
        fused_outputs = {"Y": self.output}

        attrs = pool.get_attributes()
        # Presumably marks "X" as the only activation tensor among the five
        # inputs (the rest are quantization parameters) — confirm with the
        # consumer of this attribute.
        # NOTE(review): this writes into whatever get_attributes() returned; if
        # that is the node's live dict it is mutated, though the node is
        # removed immediately afterwards.
        attrs["num_of_tensor_inputs"] = 1

        self.remove_node(pool)
        self.add_node(
            type=fused_type,
            domain="ai.onnx.contrib",
            inputs=fused_inputs,
            outputs=fused_outputs,
            attributes=attrs,
        )
