from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    Matcher,
    MatcherError,
    NoMatch,
    Tensor,
    WalkCfgPlain,
)
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.helpers.noop_helper import NoopHelper
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import TransposeHelper
from OGOAT.src.L1_fusion.py_match.model_dict import ModelDict, QdqInfo


class Remove_QDQ(Matcher, QDQHelper):
    """
    Remove a DequantizeLinear directly followed by a QuantizeLinear when both
    use the same scale and zero point, with nothing in between.
    """

    def match(self) -> None:
        # A plain walk config is used on purpose: this pattern creates no fused
        # node, so absorbing noops (reshape/squeeze/unsqueeze) would break
        # shape consistency.
        self.n = self.n.with_walk_cfg(WalkCfgPlain())
        dq_node = self.n
        dq_node.require(opType.DequantizeLinear)

        q_node = dq_node("y")
        q_node.require(opType.QuantizeLinear)

        self.require_qdq_equal_scale_zeropoint(dq=dq_node, q=q_node)

    def modify(self) -> None:
        dq_node = self.n
        # Connect the Q node's output ("y.y") to the DQ node's input ("x"),
        # bypassing the DQ -> Q pair, then drop the DQ node.
        self.connect(dq_node("y.y"), dq_node("x"))
        self.remove_node(dq_node)


class Post_Remove_QDQ(Remove_QDQ):
    """
    Same pattern as ``Remove_QDQ`` registered under a separate name; both
    methods simply delegate to the parent implementation.
    """

    def match(self) -> None:
        """Delegate matching to the next class in the MRO."""
        result = super().match()
        return result

    def modify(self) -> None:
        """Delegate the rewrite to the next class in the MRO."""
        result = super().modify()
        return result


class Remove_Q_Plus_DQ(Matcher, QDQHelper):
    """
    Remove Q -> DQ node pairs that use the same scale and zero point and have
    nothing in between.

    Every DequantizeLinear reader of the QuantizeLinear output whose
    parameters match is bypassed. The QuantizeLinear itself is removed only
    when all of its readers were bypassed.
    """

    def match(self) -> None:
        # A plain walk config is used on purpose: this pattern creates no fused
        # node, so absorbing noops (reshape/squeeze/unsqueeze) would break
        # shape consistency.
        self.n = self.n.with_walk_cfg(WalkCfgPlain())
        q_node = self.n
        q_node.require(opType.QuantizeLinear)

        # Keep only the readers that are DQ nodes with the same quantization
        # parameters as this Q node.
        self.matching_dequantize_nodes = [
            reader
            for reader in q_node("y").get_readers()
            if reader.check(opType.DequantizeLinear)
            and self.check_qdq_equal_scale_zeropoint(dq=reader, q=q_node)
        ]

        if not self.matching_dequantize_nodes:
            raise NoMatch("No matching quantize -> dequantize chain found")

    def modify(self) -> None:
        q_node = self.n
        matches = self.matching_dequantize_nodes

        # Connect each matching DQ node's output to the Q node's input,
        # bypassing the Q -> DQ pair.
        for dq_node in matches:
            self.connect(dq_node("y"), q_node("x"))

        # Only drop the Q node when every reader of its output was one of the
        # bypassed DQ nodes; otherwise other readers still consume it.
        readers = q_node("y").require_tensor().get_readers()
        if len(readers) == len(matches):
            self.remove_node(q_node)


class RemoveQDQaroundOp(Matcher, TransposeHelper, QDQHelper, NoopHelper):
    """
    Remove DQ and Q nodes around a set of selected operators (mostly dataflow,
    but also some others) if those have equal quantization parameters.
    Only remove them around transpose operators if those are no-ops.
    This pattern is used very early in the fusions. Real transposes need to
    keep their QDQ around them, because many patterns depend on those.

    NOTE(review): ``match_op`` currently accepts every ``Transpose`` (it does
    not check for a no-op transpose). Confirm whether the no-op restriction
    above is still intended.

    Covers all dataflow operations except:
      - `DepthToSpace`: The DQ and Q nodes are not removed because there are two
          related `Transpose` nodes, and they (together with QDQ) will be merged
          in the `Dataflow` pattern

      - `GatherElements`: The DQ and Q nodes are not removed because they will
         be necessary for the `RemoveBinary` and `Remove_QDQ` patterns to work
         properly.
    """

    def match(self) -> None:
        """
        Match a supported operator that has a DequantizeLinear on its input and
        QuantizeLinear nodes with the same scale/zero point on its outputs.

        For ``Concat``, the comparison is delegated to
        ``check_input_output_qdq``; its second return value tells whether the
        parameters are the same.

        Raises:
            NoMatch: if the operator is unsupported or the Concat QDQ
                parameters differ. The ``require*`` helpers may also fail the
                match.
        """
        self.match_op()
        if self.n.check(opType.Concat):
            _, has_same = self.check_input_output_qdq(self.n)
            if not has_same:
                raise NoMatch("Input and output QDQ do not have same parameters")
        else:
            self.input.require(opType.DequantizeLinear)
            for output in self.n.get_outputs():
                output.require(opType.QuantizeLinear)
                self.require_qdq_equal_scale_zeropoint(dq=self.input, q=output)

    def match_op(self) -> None:
        """
        Match the main operator and set ``self.input`` and ``self.output`` to
        the input and output of the operator node.
        Raise NoMatch if the operator type is unsupported; the
        ``require_initializer``/``require_tensor`` checks may also fail the
        match.

        In this class, match on the selected set of operators. Reductions are
        only matched when ``is_noop_reduction`` holds for their input/output
        shapes. For ``Split``, which has several outputs, ``self.output`` is
        its first output.

        NOTE(review): the Transpose branch does not check that the transpose
        is a no-op, although the class docstring says it should.

        This method is overridden by derived classes.
        """
        n = self.n = self.n.with_walk_cfg(WalkCfgPlain())
        if n.check(opType.Transpose):
            self.input = n("data")
            self.output = n("transposed")
        elif n.check(opType.Unsqueeze):
            n("axes").require_initializer()
            self.input = n("data")
            self.output = n("expanded")
        elif n.check(opType.Squeeze):
            n("axes").require_initializer()
            self.input = n("data")
            self.output = n("squeezed")
        elif n.check(opType.Reshape):
            n("shape").require_initializer()
            self.input = n("data")
            self.output = n("reshaped")
        elif n.check(opType.Gather):
            n("indices").require_tensor()
            self.input = n("data")
            self.output = n("output")
        elif n.check(
            opType.ReduceSum
            | opType.ReduceMax
            | opType.ReduceMean
            | opType.ReduceMin
            | opType.ReduceProd
        ) and self.is_noop_reduction(n("data").get_shape(), n("reduced").get_shape()):
            self.input = n("data")
            self.output = n("reduced")
        elif n.check(opType.Slice):
            n("starts").require_initializer()
            n("ends").require_initializer()
            self.input = n("data")
            self.output = n("output")
        elif n.check(opType.Resize):
            n("roi").require_initializer()
            n("scales").require_initializer()
            n("sizes").require_initializer()
            self.input = n("X")
            self.output = n("Y")
        elif n.check(opType.Flatten):
            self.input = n("input")
            self.output = n("output")
        elif n.check(opType.Pad):
            self.input = n("data")
            self.output = n("output")
        elif n.check(opType.Split):
            self.input = n("input")
            # modify() always reads self.output's dtype. Without this line,
            # Split raised AttributeError (or reused a stale value from a
            # previous match). Use the first output, which is the same one
            # modify() takes the quantization parameters from.
            self.output = n.get_outputs()[0]
        elif n.check(opType.Concat):
            self.input = n("input0")
            self.output = n("concat_result")
        else:
            raise NoMatch("Unsupported op type for removing QDQ around node")

    def modify(self) -> None:
        """
        Re-create the matched operator (same type, domain, name and
        attributes) so that it reads the DequantizeLinear inputs' ``x`` and
        writes the QuantizeLinear outputs' ``y``, bypassing the DQ/Q nodes.

        Two attributes are added, ``orig_x_zero_point_dtype`` and
        ``orig_y_zero_point_dtype``. The output quantization parameters are
        recorded in the model dict under the node name, but only when no
        quantization information exists for the input's name.
        """
        n = self.n
        op_type = n.get_op_type()
        name = n.get_name()
        domain = n.get_domain()

        # Inputs: take the tensors that feed the DequantizeLinear node(s).
        inputs: dict[str, Element]
        if n.check(opType.Concat):
            inputs = {
                f"input{i}": in_elem("x") for i, in_elem in enumerate(n.get_inputs())
            }
        else:
            inputs = n.get_inputs_dict()
            # Assumes the data input (self.input) is the first entry for every
            # supported op - TODO confirm if new ops are added to match_op.
            inputs[next(iter(inputs))] = self.input("x")

        # Outputs: take the tensors produced by the QuantizeLinear node(s).
        outputs: dict[str, Element]
        if n.check(opType.Split):
            outputs = {
                f"outputs_{i}": out_elem("y")
                for i, out_elem in enumerate(n.get_outputs())
            }
        else:
            outputs = n.get_outputs_dict()
            outputs[next(iter(outputs))] = self.output("y")

        # Quantization parameters of the first output's QuantizeLinear
        # (match() verified that all outputs share them, except for Concat,
        # which was checked via check_input_output_qdq).
        first_output = n.get_outputs()[0]
        y_scale = first_output("y_scale").require_initializer()
        y_zero_point = first_output("y_zero_point").require_initializer()
        out_scale_value = y_scale.get_value()
        out_scale_dtype_raw = y_scale.get_dtype_raw()
        out_zp_value = y_zero_point.get_value()
        out_zp_dtype_raw = y_zero_point.get_dtype_raw()

        # Build a new dict instead of mutating the one returned by
        # get_attributes() in place.
        attributes = n.get_attributes() | {
            "orig_x_zero_point_dtype": self.input("x_zero_point")
            .require_initializer()
            .get_dtype(),
            "orig_y_zero_point_dtype": y_zero_point.get_dtype(),
        }
        output_type = self.output.get_dtype()

        model_dict = n._model_dict
        # NOTE(review): the lookup is keyed by the input tensor name, while
        # the insert is keyed by the node name. Confirm that this asymmetry is
        # intended.
        if model_dict.get_quantization_information(self.input.get_name()) is None:
            model_dict.add_quantization_information(
                name,
                QdqInfo(
                    out_scale_value,
                    out_scale_dtype_raw,
                    out_zp_value,
                    out_zp_dtype_raw,
                    output_type,
                ),
            )

        self.remove_node(n)
        self.add_node(
            type=op_type,
            domain=domain,
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
            new_name=name,
        )
