# (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.

from typing import Any
import numpy as np
import onnx
from OGOAT.src.L1_fusion.py_match.helpers.common_type import TensorShape
from OGOAT.src.L1_fusion.py_match.helpers.reshape_transpose_state import (
    ReshapeTransposeUpState,
)
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import (
    TransposeHelper,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    InputTensor,
    Matcher,
    NoMatch,
    Tensor,
    WalkCfgPlain,
)
from OGOAT.src.L1_fusion.py_match.checkers import opType


class RTROptimize(Matcher, TransposeHelper):
    """
    Collapse Reshape/Transpose/Squeeze/Unsqueeze chains feeding a node.

    This matcher optimizes chains of Reshape/Transpose/Squeeze/Unsqueeze into a
    shorter sequence of Reshape+Transpose+Reshape (each of those is only emitted
    if it is still needed after optimization). The matcher starts at the main
    "n" node and looks for such a chain behind every activation input of "n".
    """

    def match(self) -> None:
        """
        Record, per schema input of ``n``, the optimized RTR state to rebuild.

        Fills ``self.rtr_input_names`` (schema input name -> optimized
        ``ReshapeTransposeUpState``). An input is skipped when it is not an
        input tensor, has no name, refers to a tensor already recorded, is an
        initializer, is rank-0, has a chain of at most one node (or whose tail
        equals its head), has only distinct op types in its chain, or already
        is an exact Reshape-Transpose-Reshape chain after optimization.

        Raises:
            NoMatch: if ``n`` itself is a Reshape/Transpose/Squeeze/Unsqueeze,
                or if no input produced a chain worth rewriting.
        """
        n = self.n.with_walk_cfg(WalkCfgPlain())
        if n.check(
            opType.Reshape | opType.Transpose | opType.Unsqueeze | opType.Squeeze
        ):
            raise NoMatch(
                "Searching to RTR chain and do RTR optimization should not based on Reshape, Transpose, Squeeze, Unsqueeze nodes, because these nodes are part of RTR matching chain."
            )
        self.rtr_input_names: dict[str, ReshapeTransposeUpState] = {}

        # Some ops (e.g. Concat) can consume the same tensor several times;
        # each unique input tensor only needs to be optimized once.
        seen_input_names: set[str] = set()
        for input_name in n.get_schema_input_names():
            input_tensor = n(input_name)
            if not input_tensor.check_input_tensor():
                continue
            tensor_name = input_tensor.get_name()
            if not tensor_name or tensor_name in seen_input_names:
                continue
            # Constants and scalars are not candidates for RTR rewriting.
            if (
                input_tensor.check_initializer()
                or len(input_tensor.require_tensor().get_shape()) == 0
            ):
                continue

            state_input = self.match_up_rtr_chain(
                input_tensor, self.n.get_name(), input_name
            )
            if not (
                state_input
                and state_input.tail != state_input.head
                and len(state_input.matched) > 1
            ):
                continue

            optypes = {
                node.require_node().get_op_type() for node in state_input.matched
            }
            if len(optypes) == len(state_input.matched):
                # All matched nodes have different op types: rebuilding them as
                # RTR nodes gains nothing. The optimization targets runs of
                # back-to-back nodes of the same type.
                continue

            state_input = state_input.optimize_state()
            if state_input.has_exact_rtr_chain():
                # Already an exact Reshape-Transpose-Reshape chain; leave it.
                continue

            seen_input_names.add(tensor_name)
            self.rtr_input_names[input_name] = state_input

        if not self.rtr_input_names:
            raise NoMatch("The RTR chain did not match or had only one matched node.")

    def modify(self) -> None:
        """
        Emit the optimized Reshape -> Transpose -> Reshape stages and rewire ``n``.

        For every entry recorded by ``match``, up to three new nodes are added,
        each fed by the previous stage's output:

        * a Reshape when ``state.ReshapeIN`` changes the shape,
        * a Transpose when ``state.Transpose.is_nontrivial()``,
        * a Reshape when ``state.ReshapeOUT`` changes the shape.

        The corresponding input of ``n`` is then replaced with the last emitted
        tensor, or with the "data" input of ``state.matched[-1]`` when every
        stage was an identity and nothing was emitted.
        """
        n = self.n.with_walk_cfg(WalkCfgPlain())
        node_name = self.n.get_name()
        # Latest emitted tensor per schema input name; each stage chains onto it.
        self.input_tensor_names: dict[str, Tensor] = {}
        for input_name, state in self.rtr_input_names.items():
            new_node_base_name = f"{node_name}_{input_name}"

            if state.ReshapeIN.input_shape != state.ReshapeIN.output_shape:
                state_in = self.create_reshape(
                    state,
                    input_name,
                    new_node_base_name + "_rtr_reshape_in",
                    state.ReshapeIN.output_shape.copy(),
                )
                self.input_tensor_names[input_name] = state_in

            if state.Transpose.is_nontrivial():
                data_input = self.get_current_input_tensor(input_name, state)
                inputs = {"data": data_input}
                # The new Transpose gets its own name as orig_name. Deriving it
                # from the "central" node previously caused several nodes to
                # share the same orig_name.
                new_transpose_name = new_node_base_name + "_rtr_transpose"
                attributes = {"perm": state.perm, "orig_name": new_transpose_name}
                transposed_tensor = self._add_input_state_node(
                    new_transpose_name,
                    "Transpose",
                    state.Transpose.output_shape.copy(),
                    data_input.get_dtype(),
                    inputs,
                    "transposed",
                    attributes,
                )
                self.input_tensor_names[input_name] = transposed_tensor

            if state.ReshapeOUT.input_shape != state.ReshapeOUT.output_shape:
                state_out = self.create_reshape(
                    state,
                    input_name,
                    new_node_base_name + "_rtr_reshape_out",
                    state.ReshapeOUT.output_shape.copy(),
                )
                self.input_tensor_names[input_name] = state_out

            # If every stage was an identity (nothing recorded for input_name),
            # bypass the whole chain and feed n from the last matched node's input.
            new_input_tensor = (
                self.input_tensor_names[input_name]
                if input_name in self.input_tensor_names
                else state.matched[-1]("data")
            )
            self.replace_input(n, n(input_name), new_input_tensor)

    def get_current_input_tensor(
        self, input_name: str, state: ReshapeTransposeUpState
    ) -> Tensor:
        """
        Return the tensor the next emitted stage for ``input_name`` should consume.

        This is the most recently emitted stage output for that input if any,
        otherwise the head tensor of the matched chain.
        """
        return (
            self.input_tensor_names[input_name]
            if input_name in self.input_tensor_names
            else state.head.require_tensor()
        )

    def create_reshape(
        self,
        state: ReshapeTransposeUpState,
        input_name: str,
        new_name: str,
        output_shape: TensorShape,
    ) -> Tensor:
        """
        Add a Reshape node producing ``output_shape`` and return its output tensor.

        The node consumes the current tensor for ``input_name`` (see
        ``get_current_input_tensor``). Its target shape is stored as a new
        int64 initializer named ``f"{new_name}_shape"``.
        """
        # NOTE(review): both the "_rtr_reshape_in" and "_rtr_reshape_out" nodes
        # of one chain get the same orig_name here, unlike the Transpose in
        # modify(). Confirm that this sharing is intended.
        attributes = {
            "allowzero": 0,
            "orig_name": state.get_last_matched_name(),
            "num_of_tensor_inputs": 1,
        }
        data_input = self.get_current_input_tensor(input_name, state)
        # NOTE(review): the value is int64, but the third argument is the data
        # input's raw dtype. Confirm how add_initializer uses it.
        shape_initializer = self.add_initializer(
            f"{new_name}_shape",
            np.array(output_shape, dtype=np.int64),
            data_input.get_dtype_raw(),
        )
        inputs = {"data": data_input, "shape": shape_initializer}
        return self._add_input_state_node(
            new_name,
            "Reshape",
            output_shape,
            data_input.get_dtype(),
            inputs,
            "reshaped",
            attributes,
        )

    def _add_input_state_node(
        self,
        new_name: str,
        new_type: str,
        output_shape: TensorShape,
        output_dtype: str,
        inputs: dict[str, Tensor],
        output_name: str,
        attributes: dict[str, Any],
    ) -> Tensor:
        """
        Add a single-output node in the "ai.onnx.contrib" domain.

        A new output tensor named ``f"{new_name}_output"`` is created with
        ``output_shape``/``output_dtype`` and bound to the node's
        ``output_name`` output. The node is added under ``new_name`` without
        the matcher-name suffix. Returns the new output tensor.
        """
        state_output = Tensor(
            self.n._model_dict, self.n._walk_cfg, f"{new_name}_output", None
        )
        state_output.set_shape(output_shape, output_dtype)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs={output_name: state_output},
            attributes=attributes,
            new_name=new_name,
            add_matcher_name=False,
        )
        return state_output
