# fmt: on

from functools import reduce
from operator import mul
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    OutputTensor,
    InputTensor,
)
from OGOAT.src.L1_fusion.py_match.checkers import opType, DTypeAny
import numpy as np


class SqueezeUnsqueezeToRTR(Matcher, QDQHelper):
    """
    Transform a squeeze/unsqueeze which is not a no op and whose axes is the
    innermost dimension into:
        reshape_noop(ND -> 2D [1, N]) -> transpose_qdq(perm [1, 0]) -> reshape_noop
    where N is the product of the input dimensions and the final reshape restores
    the original node's output shape.

    NOTE(review): ``match`` does not itself verify the "not a no op" / innermost
    axes conditions; presumably they are guaranteed elsewhere — confirm.
    """

    def match(self) -> None:
        """Match a Squeeze/Unsqueeze node and record its boundary tensors.

        Sets:
            self.new_dtype, self.qdq_attributes: results of ``check_qdq`` on the node.
            self.input: ``data.x`` when ``data`` passes the DequantizeLinear check,
                otherwise ``data`` itself.
            self.output: ``y`` of the node's first output when that output passes
                the QuantizeLinear check, otherwise the first output itself.
        """
        n = self.n.require_node()
        n.require(opType.Unsqueeze | opType.Squeeze)

        self.new_dtype, self.qdq_attributes = self.check_qdq(n, DTypeAny(), [1])

        # Step over a DequantizeLinear on the input side, if present.
        self.input = (
            n("data.x") if n("data").check(opType.DequantizeLinear) else n("data")
        )
        # Step over a QuantizeLinear on the output side, if present.
        first_output = n.get_outputs()[0]
        self.output = (
            first_output("y")
            if first_output.check(opType.QuantizeLinear)
            else first_output
        )

    def modify(self) -> None:
        """Replace the matched node with reshape_noop -> Transpose_qdq -> reshape_noop.

        The first reshape flattens ``self.input`` to ``[1, N]``. The transpose
        (perm ``[1, 0]``) yields ``[N, 1]`` and receives the input-side and
        output-side scale/zero-point initializers (``None`` when not initializers).
        The second reshape writes ``self.output`` with its original shape.
        Finally the matched node is removed.
        """
        n = self.n.require_node()

        def initializer_or_none(tensor):
            # Forward a quantization parameter only when it is a constant initializer.
            return tensor if tensor.check_initializer() else None

        # Get input and output shapes
        input_tensor = self.input.require_tensor()
        output_tensor = self.output.require_tensor()
        input_shape = input_tensor.get_shape()
        output_shape = output_tensor.get_shape()

        # Calculate intermediate shapes and permutation
        (flat_shape, transposed_shape), perm = self._calculate_transformation_params(
            input_shape
        )

        # Create intermediate tensor names
        base_name = n.get_name()
        reshape_one_output_tensor_name = f"{base_name}_reshape_1_output"
        transpose_output_tensor_name = f"{base_name}_transpose_output"

        # The first reshape is a pure reshape of the input, so it keeps its dtype.
        n._model_dict.set_shape(
            reshape_one_output_tensor_name,
            flat_shape,
            input_tensor.get_dtype(),
        )
        # The transpose result is only reshaped (noop) into ``output_tensor`` and is
        # wired with the output-side scale/zero-point, so it must carry the output
        # dtype (previously the input dtype was used, which is wrong whenever the
        # input and output quantized dtypes differ).
        n._model_dict.set_shape(
            transpose_output_tensor_name,
            transposed_shape,
            output_tensor.get_dtype(),
        )

        # Create first reshape_noop node (ND -> [1, N])
        reshape1_inputs = {
            "data": input_tensor,
            "shape": self.add_initializer(
                f"{base_name}_reshape1_shape",
                np.array(flat_shape, dtype=np.int64),
            ),
        }
        reshape1_outputs = {
            "output": OutputTensor(
                n._model_dict, self.n._walk_cfg, reshape_one_output_tensor_name, None
            )
        }

        reshape1_node = self.add_node(
            type="reshape_noop",
            domain="ai.onnx.contrib",
            inputs=reshape1_inputs,
            outputs=reshape1_outputs,
            attributes={},
            new_name=f"{base_name}_reshape1",
        )

        # Create transpose_qdq node ([1, N] -> [N, 1])
        transpose_inputs = {
            "data": reshape1_node("output"),
            "data_scale": initializer_or_none(self.input("x_scale")),
            "data_zero_point": initializer_or_none(self.input("x_zero_point")),
            "output_scale": initializer_or_none(self.output("y_scale")),
            "output_zero_point": initializer_or_none(self.output("y_zero_point")),
        }

        transpose_outputs = {
            "transposed": OutputTensor(
                n._model_dict, self.n._walk_cfg, transpose_output_tensor_name, None
            )
        }

        perm_attribute = {"perm": perm}
        transpose_node = self.add_node(
            type="Transpose_qdq_" + self.new_dtype,
            domain="ai.onnx.contrib",
            inputs=transpose_inputs,
            outputs=transpose_outputs,
            attributes=perm_attribute | self.qdq_attributes,
            new_name=f"{base_name}_transpose",
        )

        # Create second reshape_noop node ([N, 1] -> original output shape)
        reshape2_inputs = {
            "data": transpose_node("transposed"),
            "shape": self.add_initializer(
                f"{base_name}_reshape2_shape",
                np.array(output_shape, dtype=np.int64),
            ),
        }

        reshape2_outputs = {"output": output_tensor}

        self.add_node(
            type="reshape_noop",
            domain="ai.onnx.contrib",
            inputs=reshape2_inputs,
            outputs=reshape2_outputs,
            attributes={},
            new_name=f"{base_name}_reshape2",
        )
        self.remove_node(n)

    def _calculate_transformation_params(
        self, input_shape
    ) -> tuple[list[list[int]], list[int]]:
        """Compute the flattened/transposed shapes and the transpose permutation.

        Args:
            input_shape: dimensions of the input tensor (any rank).

        Returns:
            ``([[1, N], [N, 1]], [1, 0])`` where ``N`` is the product of
            ``input_shape`` (``1`` for an empty shape).
        """
        # Convert ND to 2D, e.g. [1, 20, 30] -> [1, (1*20*30)]
        total_prod = reduce(mul, input_shape, 1)
        intermediate1_shape = [1, total_prod]

        # Always transpose the 2D matrix, so perm is always [1, 0]
        perm = [1, 0]
        # After transpose
        intermediate2_shape = [total_prod, 1]

        return [intermediate1_shape, intermediate2_shape], perm
