# fmt: on
from typing import Optional
from OGOAT.src.L1_fusion.py_match.helpers.common_type import Perm
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.helpers.reshape_transpose_state import (
    ReshapeTransposeUpState,
    ReshapeTransposeDownState,
)
from OGOAT.src.L1_fusion.py_match.helpers.transpose_with_optional_qdq import (
    TransposeWithOptionalQDQ,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    InputTensor,
    NoMatch,
    OutputTensor,
)


class TransposeHelper(QDQHelper):
    """
    Matcher helpers for Transpose nodes (optionally wrapped in QDQ) and for
    Reshape/Transpose chains adjacent to a node's inputs and outputs.
    """

    # "perm" attributes accepted as an NHWC -> NCHW transpose (rank 3 or 4).
    _NHWC_TO_NCHW_PERMS = ([0, 2, 1], [0, 3, 1, 2])
    # "perm" attributes accepted as an NCHW -> NHWC transpose (rank 3 or 4).
    _NCHW_TO_NHWC_PERMS = ([0, 2, 1], [0, 2, 3, 1])

    @staticmethod
    def get_unrolled_perm(perm: Perm, prefix: str) -> dict[str, int]:
        """
        Produce an "unrolled" version of a permutation list attribute.

        For example, for prefix="perm" and perm=[0, 3, 1, 2], return
        {"perm_1": 0, "perm_2": 3, "perm_3": 1, "perm_4": 2}.

        Exactly 4 keys ("<prefix>_1" .. "<prefix>_4") are produced for a
        non-empty permutation: shorter permutations are padded with -1
        (e.g. [0, 2, 1] gives "<prefix>_4": -1) and entries beyond the fourth
        are dropped. An empty (or None) permutation yields an empty dict.
        """
        if not perm:
            return {}
        return {f"{prefix}_{i+1}": perm[i] if i < len(perm) else -1 for i in range(4)}

    def get_unrolled_perm_attribute(self, node: Element, prefix: str) -> dict[str, int]:
        """
        Extract individual transpose permutation attributes in a structured format.

        Args:
            node: The transpose node to extract permutation from.
            prefix: Identifier for the permutation location ('perm1' for input transpose,
                    'perm2' for output transpose).

        Returns:
            Dictionary containing individual permutation values under keys (e.g., prefix_1,
            prefix_2, etc.). A non-empty permutation always yields 4 entries (missing
            dimensions are filled with -1, extra dimensions are dropped); an empty or
            missing "perm" attribute yields an empty dict.
        """
        perm_attr = node.require_node().get_attribute_value("perm")
        return self.get_unrolled_perm(perm_attr, prefix)

    def require_nchw_conversion(
        self, input_node: Element, output_node: Element
    ) -> tuple[TransposeWithOptionalQDQ, TransposeWithOptionalQDQ]:
        """
        Require that the input and output specified are transpose nodes performing
        a NCHW conversion, by first converting to NCHW then reverting to NHWC.

        Args:
            input_node: Element whose "x" input must be a (QDQ-wrapped) Transpose
                with an NHWC -> NCHW permutation.
            output_node: Element whose "y" output must be a (QDQ-wrapped) Transpose
                with an NCHW -> NHWC permutation.

        Returns:
            The matched (input transpose, output transpose) pair.

        Raises:
            NoMatch: if either transpose is absent (as raised by
                `TransposeWithOptionalQDQ.match`, presumably) or its permutation
                is not one of the accepted conversions.
        """
        input_transpose = TransposeWithOptionalQDQ.match(input_node("x"))
        perm_attr = input_transpose.transpose_node.get_attribute_value("perm")
        if perm_attr not in self._NHWC_TO_NCHW_PERMS:
            raise NoMatch(
                "permutation attribute does not match an NHWC to NCHW conversion"
            )
        output_transpose = TransposeWithOptionalQDQ.match(output_node("y"))
        perm_attr = output_transpose.transpose_node.get_attribute_value("perm")
        # NOTE(review): the two perms are validated independently, so a rank-3
        # input perm paired with a rank-4 output perm is accepted; confirm intent.
        if perm_attr not in self._NCHW_TO_NHWC_PERMS:
            raise NoMatch(
                "permutation attribute does not match an NCHW to NHWC conversion"
            )

        return input_transpose, output_transpose

    def match_up_rtr_chain(
        self,
        inp: Element,
        state_owner: Optional[str] = None,
        state_label: Optional[str] = None,
    ) -> Optional["ReshapeTransposeUpState"]:
        """
        Match upwards in graph / at input a combination of
         - QDQ with equal parameters
         - transposes
         - reshapes
        Match as much of it as can be represented by a single transpose permutation.

        **Args**:
            - `inp` -- Input tensor at which to start matching upwards, e.g. input of Matmul_qdq node
            - `state_owner` -- an optional name of a Node (usually MatMul) for which this pattern is used
              (presumably defaults to `inp.name` inside `ReshapeTransposeUpState.fromInputTensor`).
            - `state_label` -- an optional name of the state, usually input name. Used for logging/debugging.

        **Returns**:
            - `None` if `inp` is not an `InputTensor`.
            - Otherwise the result of `ReshapeTransposeUpState.match_upwards()` on a state
              initialized at `inp`, i.e. the object describing the matched Re-Tr chain.
        """
        if not isinstance(inp, InputTensor):
            # Return None (not a tuple) so the result matches the Optional annotation
            # and is falsy for callers that test the outcome of the match.
            return None

        # initialize the state: [Reshape -> Transpose -> Reshape] that does nothing
        # >>> equivalent to the identity transformation
        # >>> schema: [the orig Re-Tr chain] -> [the initial state: Reshape -> Transpose -> Reshape] -> <in> -> MM
        state = ReshapeTransposeUpState.fromInputTensor(inp, state_owner, state_label)
        return state.match_upwards()

    def match_down_rtr_chain(
        self,
        out: Element,
        state_owner: Optional[str] = None,
        state_label: Optional[str] = None,
    ) -> Optional["ReshapeTransposeDownState"]:
        """
        Match downwards in graph / at output a combination of
         - QDQ with equal parameters
         - transposes
         - reshapes
        Match as much of it as can be represented by a single transpose permutation.

        **Args**:
            - `out` -- Output tensor at which to start matching downwards, e.g. output of Matmul_qdq node
            - `state_owner` -- an optional name of a Node (usually MatMul) for which this pattern is used
              (presumably defaults to `out.name` inside `ReshapeTransposeDownState.fromOutputTensor`).
            - `state_label` -- an optional name of the state, usually output name. Used for logging/debugging.

        **Returns**:
            - `None` if `out` is not an `OutputTensor`.
            - Otherwise the result of `ReshapeTransposeDownState.match_downwards()` on a state
              initialized at `out`, i.e. the object describing the matched Re-Tr chain.
        """
        if not isinstance(out, OutputTensor):
            # Return None (not a tuple) so the result matches the Optional annotation
            # and is falsy for callers that test the outcome of the match.
            return None

        # initialize the state: [Reshape -> Transpose -> Reshape] that does nothing
        # >>> equivalent to the identity transformation
        # >>> schema: MM -> <out> -> [the initial state: Reshape -> Transpose -> Reshape] -> [the orig Re-Tr chain]
        state = ReshapeTransposeDownState.fromOutputTensor(
            out, state_owner, state_label
        )
        return state.match_downwards()
