# fmt: on
from OGOAT.src.L1_fusion.py_match.basic.dataflow import Dataflow
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import (
    QDQHelper, InitName
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    Node,
    OutputTensor,
    InputTensor,
)
from OGOAT.src.L1_fusion.py_match.checkers import opType
import numpy as np
import math

class ReshapeToRTR(Matcher, QDQHelper):
    """
    Match and replace Reshape nodes with Reshape_noop and Transpose_qdq.

    An ONNX Reshape never reorders elements (row-major order is preserved), so
    every Transpose_qdq emitted here must also leave the element order
    unchanged. A transpose has that property only when the dim it moves is 1,
    or every dim it moves past is 1 (e.g. [1, a, b] -> [a, b, 1]). Moving a
    non-1 dim past another non-1 dim (e.g. [2, 3] -> [3, 2]) is a real data
    transpose and is never emitted. Shape pairs that no shortcut can express
    safely fall back to the complex pipeline, whose transposes only ever move
    an inserted size-1 dim.

    For ex:
    Complex reshape case:
    - 1*5*8*6 -> Reshape_noop -> 1*1*5*8*6 -> Transpose_qdq -> 1*5*8*6*1 -> Reshape_noop -> 1*5*8*3*2*1 -> Transpose_qdq -> 1*1*5*8*3*2 -> Reshape_noop -> 1*5*8*3*2
    - 1*5*8*6 -> Reshape_noop -> 1*1*5*8*6 -> Transpose_qdq -> 1*5*8*6*1 -> Reshape_noop -> 1*5*48*1 -> Transpose_qdq -> 1*1*5*48 -> Reshape_noop -> 1*5*48

    Differs only by last dim case:
    - 1*6 -> Reshape_noop -> 1*1*6 -> Transpose_qdq -> 1*6*1

    Pure Transpose case:
    - 2*8*3*2*1 -> Transpose_qdq -> 1*2*8*3*2

    Tail move with adjacent squeeze/unsqueeze case:
    - 2*8*2*2*1 -> Reshape_noop -> 2*8*4*1 -> Transpose_qdq -> 1*2*8*4
    - 2*8*6*1 -> Reshape_noop -> 2*8*3*2*1 -> Transpose_qdq -> 1*2*8*3*2
    """

    def match(self) -> None:
        """Match a Reshape node and remember its data input and reshaped output tensors."""
        n = self.n.require_node()
        n.require(opType.Reshape)
        self.input = n("data").require_tensor()
        self.output = n("reshaped").require_tensor()

    def modify(self) -> None:
        """Replace the matched Reshape with an equivalent Reshape_noop/Transpose_qdq chain."""
        n = self.n.require_node()
        self.node_name = n.get_name()
        input_shape = self.input.get_shape()
        output_shape = self.output.get_shape()
        self.dtype = self.input.get_dtype()
        in_tensor = self.input
        # Used by the multi-step pipeline to keep generated node/tensor names unique.
        self.suffix_counter = 0

        # Pick the shortest valid pipeline. The predicates are tried in order;
        # the complex pipeline is valid for every remaining shape pair.
        if self._is_pure_transpose(input_shape, output_shape):
            self._handle_pure_transpose(in_tensor, output_shape)
        elif self._is_tail_move_with_adjacent_squeeze_or_unsqueeze(
            input_shape, output_shape
        ):
            self._handle_tail_move_squeeze_or_unsqueeze(
                in_tensor, input_shape, output_shape
            )
        elif self._differs_only_by_last_dim(input_shape, output_shape):
            self._handle_add_trailing_dim(in_tensor, output_shape)
        else:
            self._handle_complex_reshape(in_tensor, input_shape, output_shape)

        self.remove_node(n)

    def _handle_pure_transpose(
        self, in_tensor: InputTensor, output_shape: list[int]
    ) -> None:
        """Handle a reshape expressible as one order-preserving transpose (or identity)."""
        self._create_transpose_node(in_tensor, output_shape, output_tensor=self.output)

    def _handle_tail_move_squeeze_or_unsqueeze(
        self, in_tensor: InputTensor, input_shape: list[int], output_shape: list[int]
    ) -> None:
        """
        Handle a reshape where the last input dim moves to the front while the
        remaining dims are merged/split.

        The remaining dims are reshaped first, keeping the tail dim last. A
        single transpose then moves the tail dim to the front.
        """
        in_tensor = self._create_reshape_noop_node(
            in_tensor, output_shape[1:] + [input_shape[-1]]
        )("reshaped").require_tensor()
        self._create_transpose_node(in_tensor, output_shape, output_tensor=self.output)

    def _handle_add_trailing_dim(
        self, in_tensor: InputTensor, output_shape: list[int]
    ) -> None:
        """
        Handle a reshape that only appends a trailing dimension of 1.

        A leading 1 is added, and a transpose rotates it to the end.
        """
        in_tensor = self._create_reshape_noop_node(
            in_tensor, [1] + in_tensor.get_shape()
        )("reshaped").require_tensor()
        self._create_transpose_node(in_tensor, output_shape, output_tensor=self.output)

    def _handle_complex_reshape(
        self, in_tensor: InputTensor, input_shape: list[int], output_shape: list[int]
    ) -> None:
        """
        Handle an arbitrary reshape with the full five-node pipeline.

        Both transposes only move the inserted size-1 dim, so element order is
        preserved for any input/output shape pair.
        """
        # Add leading dimension
        in_tensor = self._create_reshape_noop_node(
            in_tensor, [1] + in_tensor.get_shape(), str(self.suffix_counter)
        )("reshaped").require_tensor()
        self.suffix_counter += 1

        # Transpose to move the leading dim=1 to last
        in_tensor = self._create_transpose_node(
            in_tensor, input_shape + [1], str(self.suffix_counter)
        )("transposed").require_tensor()
        self.suffix_counter += 1

        # Reshape to match output shape with extra transposed dim=1 at the end
        in_tensor = self._create_reshape_noop_node(
            in_tensor, output_shape + [1], str(self.suffix_counter)
        )("reshaped").require_tensor()
        self.suffix_counter += 1

        # Transpose to move last dim=1 back to front
        in_tensor = self._create_transpose_node(
            in_tensor, [1] + output_shape, str(self.suffix_counter)
        )("transposed").require_tensor()
        self.suffix_counter += 1

        # Final reshape to remove the extra dim=1 and match output shape
        self._create_reshape_noop_node(
            in_tensor, output_shape, str(self.suffix_counter), self.output
        )

    def _create_output_tensor(self, name: str, shape: list[int]) -> OutputTensor:
        """Register `name` with `shape` and the matched dtype, and return it as an OutputTensor."""
        self.n._model_dict.set_shape(name, shape, self.dtype)
        return OutputTensor(self.n._model_dict, self.n._walk_cfg, name, None)

    def _create_initializer(self, name: str, data: np.ndarray) -> InputTensor:
        """Add `data` as an initializer named `name` (thin wrapper over add_initializer)."""
        return self.add_initializer(name, data)

    def _differs_only_by_last_dim(
        self, input_shape: list[int], output_shape: list[int]
    ) -> bool:
        """
        Check if output shape differs from input shape only by having an extra last dimension of 1.
        """
        return (
            len(output_shape) == len(input_shape) + 1
            and output_shape[-1] == 1
            and input_shape == output_shape[:-1]
        )

    def _create_reshape_noop_node(
        self,
        input_tensor: InputTensor,
        output_shape: list[int],
        suffix: str = "",
        output: OutputTensor = None,
    ) -> Node:
        """
        Add a Reshape_noop node reshaping `input_tensor` to `output_shape`.

        The target shape is stored in a new int64 initializer. If `output` is
        None, a new output tensor named after this node and `suffix` is created.
        """
        reshape_inputs = {
            "data": input_tensor,
            "shape": self._create_initializer(
                f"{self.node_name}_reshape_shape_initializer_{suffix}",
                np.array(output_shape, dtype=np.int64),
            ),
        }

        reshape_outputs = {
            "reshaped": (
                output
                if output is not None
                else self._create_output_tensor(
                    f"{self.node_name}_reshape_output_tensor_{suffix}", output_shape
                )
            )
        }

        return self.add_node(
            type="Reshape_noop",
            domain="ai.onnx.contrib",
            inputs=reshape_inputs,
            outputs=reshape_outputs,
            attributes={"num_of_tensor_inputs": 1},
            new_name=f"{self.node_name}_reshape_noop_node_{suffix}",
        )

    def _create_transpose_node(
        self,
        input_tensor: InputTensor,
        output_shape: list[int],
        suffix: str = "",
        output_tensor: OutputTensor = None,
    ) -> Node:
        """
        Add a Transpose_qdq node mapping `input_tensor` to `output_shape`.

        The perm comes from `_calculate_transpose_perm`. Quantization params
        are taken from the matched Reshape's input/output (or dummies), and
        both dequantize and quantize are disabled via attributes.

        Raises:
            ValueError: if `output_shape` cannot be reached by an
                order-preserving transpose of the input shape.
        """
        perm = self._calculate_transpose_perm(input_tensor.get_shape(), output_shape)

        transpose_outputs = {
            "transposed": (
                output_tensor
                if output_tensor is not None
                else self._create_output_tensor(
                    f"{self.node_name}_transpose_output_tensor_{suffix}", output_shape
                )
            )
        }

        inputs = {
            "data": input_tensor,
            "data_scale": self._get_initializer_or_dummy(
                self.input("x_scale"), self.n, InitName.SCALE
            ),
            "data_zero_point": self._get_initializer_or_dummy(
                self.input("x_zero_point"), self.n, InitName.SCALE_ZERO_POINT
            ),
            "Y_scale": self._get_initializer_or_dummy(
                self.output("y_scale"), self.n, InitName.OUTPUT_SCALE
            ),
            "Y_zero_point": self._get_initializer_or_dummy(
                self.output("y_zero_point"), self.n, InitName.OUTPUT_ZERO_POINT
            ),
        }
        return self.add_node(
            type=f"Transpose_qdq_{self.dtype}x{self.dtype}",
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=transpose_outputs,
            attributes={
                "perm": perm,
                "disable_dq0": True,
                "disable_q": True,
                "num_of_tensor_inputs": 1,
            },
            new_name=f"{self.node_name}_transpose_qdq_node_{suffix}",
        )

    def _order_preserving_rotation_perm(
        self, input_shape: list[int], output_shape: list[int]
    ) -> "list[int] | None":
        """
        Return the perm that rotates one end dim of `input_shape` to the other
        end to give `output_shape`, keeping row-major element order unchanged.

        Return None if no such rotation exists. Moving a dim of size d past a
        run of dims keeps element order only if d == 1 or every dim in that run
        is 1. Case 1 (first dim -> last) is tried before case 2 (last dim ->
        first).
        """
        in_shape = list(input_shape)
        out_shape = list(output_shape)
        n = len(in_shape)
        if n == 0 or len(out_shape) != n:
            return None

        # Case 1: first dim moved to last
        head, rest = in_shape[0], in_shape[1:]
        if out_shape == rest + [head] and (head == 1 or all(d == 1 for d in rest)):
            return list(range(1, n)) + [0]

        # Case 2: last dim moved to first
        tail, rest = in_shape[-1], in_shape[:-1]
        if out_shape == [tail] + rest and (tail == 1 or all(d == 1 for d in rest)):
            return [n - 1] + list(range(n - 1))

        return None

    def _calculate_transpose_perm(
        self, input_shape: list[int], output_shape: list[int]
    ) -> list[int]:
        """
        Calculate the permutation for a transpose by matching dimensions. We always transpose between the first and last dim.
        Example: [1, 1, 6] -> [1, 6, 1] yields perm [1, 2, 0] not [0, 2, 1] as first dim 1 moves to last.

        Only rotations that keep the row-major element order are used. Equal
        shapes with no such rotation (e.g. [2, 2] -> [2, 2]) get the identity
        perm rather than a data-reordering one.

        Raises:
            ValueError: if `output_shape` is neither equal to `input_shape`
                nor an order-preserving rotation of it.
        """
        perm = self._order_preserving_rotation_perm(input_shape, output_shape)
        if perm is not None:
            return perm

        if list(input_shape) == list(output_shape):
            return list(range(len(input_shape)))

        # Emitting an identity perm here would declare an output shape the
        # transpose does not produce; fail loudly instead.
        raise ValueError(
            f"Cannot express {list(input_shape)} -> {list(output_shape)} "
            "as an order-preserving transpose"
        )

    def _is_pure_transpose(
        self, input_shape: list[int], output_shape: list[int]
    ) -> bool:
        """
        Return True if the reshape can be done with a single Transpose_qdq.

        This holds when output_shape equals input_shape, or is an end-to-end
        rotation of it that keeps row-major element order. A mere permutation
        of the dims is not enough: an ONNX Reshape never reorders data.

        Examples:
            _is_pure_transpose([1, 2, 3], [2, 3, 1]) -> True   (leading 1 moved to the end)
            _is_pure_transpose((3, 3, 4, 1), (1, 3, 3, 4)) -> True
            _is_pure_transpose([1, 2, 3], [3, 1, 2]) -> False  (moving 3 to the front reorders data)
            _is_pure_transpose([2, 3], [3, 2]) -> False        (a real transpose, not a reshape)
            _is_pure_transpose([1, 2, 6], [1, 2, 3, 2]) -> False
            _is_pure_transpose([2, 2], [4]) -> False
        """
        in_shape = list(input_shape)
        out_shape = list(output_shape)

        # Quick length check: transpose cannot change the number of axes
        if len(in_shape) != len(out_shape):
            return False

        return (
            in_shape == out_shape
            or self._order_preserving_rotation_perm(in_shape, out_shape) is not None
        )

    def _is_tail_move_with_adjacent_squeeze_or_unsqueeze(
        self, input_shape: list[int], output_shape: list[int]
    ) -> bool:
        """
        Detect a reshape of this form:
        - The last input dim is moved to the front unchanged.
        - The remaining input dims are reshaped (merged/split) to form the remaining output dims.
        - The total product must match (valid reshape).
        - Moving the last dim to the front must not reorder data. It must be 1,
          or all the remaining output dims must be 1.

        Returns True if the transformation matches this pattern.

        Examples:
        [128,128,2,2,1] -> [1,128,128,4]  => True (merge: 2*2=4)
        [128,128,4,1]   -> [1,128,128,2,2] => True (split: 4=2*2)
        [2,2,2,2,1]     -> [1,4,2,2]      => True (merge)
        [2,2,2,2,1]     -> [1,2,2,2,2]    => False (no reshape, same dims)
        [3,4,1]         -> [1,3,4]        => False (pure transpose, no reshape)
        [2,3,4]         -> [4,6]          => False (moving 4 to the front reorders data)
        """
        in_shape = list(input_shape)
        out_shape = list(output_shape)

        # Scalars have no tail to move (and indexing them would raise).
        if not in_shape or not out_shape:
            return False

        # Tail must be moved unchanged to the front
        tail = in_shape[-1]
        if out_shape[0] != tail:
            return False

        in_rest = in_shape[:-1]
        out_rest = out_shape[1:]

        # If no reshape occurs (same dims), return False (should use pure transpose instead)
        if in_rest == out_rest:
            return False

        if math.prod(in_rest) != math.prod(out_rest):
            return False

        # The handler transposes out_rest + [tail] -> output_shape; that step
        # must be an order-preserving rotation.
        return (
            self._order_preserving_rotation_perm(out_rest + [tail], out_shape)
            is not None
        )
