# fmt: on
import dataclasses
import logging
from typing import Optional, Any, ClassVar
from collections import OrderedDict

from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.helpers.common_type import Perm, TensorShape
from OGOAT.src.L1_fusion.py_match.helpers.perm_helper import PermutationHelper
from OGOAT.src.L1_fusion.py_match.helpers.reshape_transpose_helper import (
    ReshapeTransposeHelper,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Tensor,
    InputTensor,
    MatcherError,
    NoMatch,
    Node,
    OutputTensor,
)

# Module-level shorthands for helper methods, so code below can call e.g.
# `permute(...)` without spelling out the owning helper class.
dim = ReshapeTransposeHelper.get_tensor_dim
rank = ReshapeTransposeHelper.get_tensor_rank
node_rank = ReshapeTransposeHelper.get_node_rank
permute = PermutationHelper.permute


@dataclasses.dataclass
class ReshapeTransposeState(ReshapeTransposeHelper, PermutationHelper):
    """
    State of matching combination of transposes, reshapes in graph used as a base for `ReshapeTransposeUpState` and `ReshapeTransposeDownState`.

    **Required Fields**:
        - owner: an optional name of a Node (usually MatMul) for which this pattern is used. Default := `inp.name`
        - label: an optional name of the state, usually the name of the MatMul's input. Used for logging/debugging.
        - tail -- original input (UP state) / output (DOWN state) tensor, bottom of represented subgraph,
                  for UP: one of the inputs of MatMul_qdq.[A|B]
                  for DOWN: current input
        - head -- top of represented subgraph,
                  for UP: current input
                  for DOWN: the output of Matmul_qdq.Y
        - transp_shape -- shape at input of transpose
        - perm -- permutation of transpose
        - matched -- list of matched nodes
    """

    # REQUIRED
    owner: str
    label: str
    tail: Tensor
    head: Tensor
    transp_shape: list[int]
    perm: list[int]
    out_shape: Optional[OutputTensor] = None
    _matched: OrderedDict[Node, str] = dataclasses.field(default_factory=OrderedDict)

    # INTERNAL
    _in_reshape_: "NodeView" = dataclasses.field(init=False, default=None)
    _out_reshape_: "NodeView" = dataclasses.field(init=False, default=None)
    _state_transpose_: "NodeView" = dataclasses.field(init=False, default=None)
    _acc_transpose_: "NodeView" = dataclasses.field(init=False, default=None)
    _head_node_: Node = None
    _tail_node_: Node = None

    @property
    def name(self):
        """Return the display name of the state, formatted as ``<owner>_<label>``."""
        owner, label = self.owner, self.label
        return f"{owner}_{label}"

    @property
    def matched(self) -> list[Node]:
        """Return the recorded nodes as a fresh list, in the order they were recorded."""
        # Unpacking the ordered mapping yields its keys (the nodes) in insertion order.
        return [*self._matched]

    def has_nontrivial_transpose(self) -> bool:
        """Return True unless `is_identity_perm` reports `self.perm` as an identity permutation."""
        return not self.is_identity_perm(self.perm)

    def has_matched_transpose(self) -> bool:
        """Return True when at least one recorded node is a Transpose."""
        return any(node.check(opType.Transpose) for node in self.matched)

    def has_exact_rtr_chain(self) -> bool:
        """
        Check whether the state holds exactly three matched nodes forming a
        Reshape-Transpose-Reshape chain.

        The test passes only when all of the following hold:
            - exactly three nodes are recorded,
            - at least one of them is a Transpose,
            - both the reader and the writer of `tail` are Reshape nodes.
        """
        if len(self.matched) != 3:
            return False
        if not self.has_matched_transpose():
            return False

        # NOTE(review): both endpoint checks go through `self.tail`; confirm that the
        # reader check was not meant to use `self.head` instead.
        tail = self.tail
        return tail.get_reader().check(opType.Reshape) and tail.get_writer().check(
            opType.Reshape
        )

    @classmethod
    def push_transpose_through_reshape(
        cls, tr: "TransposeView", re: "ReshapeView"
    ) -> tuple["ReshapeView", "TransposeView"]:
        """
        Rewrite the pair `Transpose -> Reshape` as an equivalent `Reshape -> Transpose`.

        **Args**:
            - `tr`: view of the transpose that currently comes first
            - `re`: view of the reshape that currently follows it

        **Returns**:
            - `(reshape_view, transpose_view)`: views of the swapped pair, derived with `viewAs`

        **Raises**:
            - `NoMatch`: when `trans_reshape_swap` reports that the swap is not valid
        """
        logging.debug(
            f"SWAP(Tr-Re):: [{tr.input_shape}] -> Tr({tr.perm}) -> [{tr.output_shape}] ->  Re({re.shape})"
        )
        swap_ok, swapped_shape, swapped_perm = cls.trans_reshape_swap(
            tr.input_shape, tr.perm, re.shape
        )
        if not swap_ok:
            # Not expected to happen; otherwise the approach breaks down because
            # the state would look like `Tr -> Re -> Tr -> Re`.
            logging.debug("SWAP(Tr-Re)::FAIL")
            raise NoMatch(f"Can't swap Transpose({tr.perm}) with Reshape({re.shape})")

        # The swap succeeded: the reshape now feeds the (re-permuted) transpose.
        moved_reshape = re.viewAs(
            input_shape=tr.input_shape, output_shape=swapped_shape
        )
        moved_transpose = tr.viewAs(
            input_shape=swapped_shape,
            output_shape=cls.permute(swapped_shape, swapped_perm),
            perm=swapped_perm,
        )
        logging.debug(
            f"SWAP(Tr-Re)::SUCCESS:: [{moved_reshape.input_shape}] -> Re*({moved_reshape.shape}) ->  Tr*({moved_transpose.perm}) -> [{moved_transpose.output_shape}]"
        )
        return moved_reshape, moved_transpose

    @classmethod
    def push_reshape_through_transpose(
        cls, re: "ReshapeView", tr: "TransposeView"
    ) -> tuple["TransposeView", "ReshapeView"]:
        """
        Rewrite the pair `Reshape -> Transpose` as an equivalent `Transpose -> Reshape`.

        **Args**:
            - `re`: view of the reshape that currently comes first
            - `tr`: view of the transpose that currently follows it

        **Returns**:
            - `(transpose_view, reshape_view)`: views of the swapped pair, derived with `viewAs`

        **Raises**:
            - `MatcherError`: when the reshape input has more than 4 dimensions
            - `NoMatch`: when `reshape_trans_swap` reports that the swap is not valid
        """
        logging.debug(
            f"SWAP(Re-Tr):: [{re.input_shape}] -> Re({re.shape}) -> Tr({tr.perm}) -> [{tr.output_shape}]"
        )

        source_shape = re.input_shape
        if len(source_shape) > 4:
            raise MatcherError(
                "Shape with 5 dimensions are found, but only shapes up to 4 dimensions are currently supported."
            )

        swap_ok, swapped_perm = cls.reshape_trans_swap(source_shape, re.shape, tr.perm)
        if not swap_ok:
            logging.debug("SWAP(Re-Tr)::FAIL")
            raise NoMatch(f"Can't swap Reshape({re.shape}) with Transpose({tr.perm})")

        # The swap succeeded: the transpose now acts directly on the reshape's input.
        moved_transpose = tr.viewAs(
            input_shape=re.input_shape,
            output_shape=cls.permute(re.input_shape, swapped_perm),
            perm=swapped_perm,
        )
        moved_reshape = re.viewAs(
            input_shape=cls.permute(re.input_shape, swapped_perm),
            output_shape=tr.output_shape,
        )
        logging.debug(
            f"SWAP(Re-Tr)::SUCCESS:: [{moved_transpose.input_shape}] -> Tr*({moved_transpose.perm}) -> [{moved_transpose.output_shape}] -> Re*({moved_reshape.shape}) -> [{moved_reshape.output_shape}]"
        )
        return moved_transpose, moved_reshape

    def get_last_matched_name(self, default_name=None) -> str:
        """
        Return the name of the most recently recorded node.

        Falls back to `default_name` when no node has been recorded yet.
        """
        nodes = self.matched
        return nodes[-1].get_name() if nodes else default_name

    def get_last_reshape(self) -> Optional[Node]:
        """Return the most recently recorded Reshape node, or None when there is none."""
        newest_first = reversed(self.matched)
        return next(
            (node for node in newest_first if node.check(opType.Reshape)), None
        )

    @property
    def ReshapeIN(self) -> "ReshapeView":
        """
        Return the view of the state's input reshape.

        The view is created on first access and refreshed from the current
        state values on every later access.
        """
        info = self._get_reshape_in_info()
        view = self._in_reshape_
        if view:
            # Reuse the cached view, bringing it in line with the current state.
            view.update(**info)
            view.matched = info.get("matched")
        else:
            view_cls = type("_ReshapeView", (ReshapeView,), info)
            self._in_reshape_ = view_cls.of("State.ReshapeIN", self, **info)
        return self._in_reshape_

    @property
    def Transpose(self) -> "TransposeView":
        """
        Return the view of the state's transpose.

        The view is created on first access and refreshed from the current
        state values on every later access.
        """
        info = self._get_state_transpose_info()
        view = self._state_transpose_
        if view:
            # Reuse the cached view, bringing it in line with the current state.
            view.update(**info)
            view.matched = info.get("matched")
        else:
            view_cls = type("_TransposeView", (TransposeView,), info)
            self._state_transpose_ = view_cls.of("State.Transpose", self, **info)
        return self._state_transpose_

    @property
    def ReshapeOUT(self) -> "ReshapeView":
        """
        Return the view of the state's output reshape.

        The view is created on first access and refreshed from the current
        state values on every later access.
        """
        info = self._get_reshape_out_info()
        view = self._out_reshape_
        if view:
            # Reuse the cached view, bringing it in line with the current state.
            view.update(**info)
            view.matched = info.get("matched")
        else:
            view_cls = type("_ReshapeView", (ReshapeView,), info)
            self._out_reshape_ = view_cls.of("State.ReshapeOUT", self, **info)
        return self._out_reshape_

    def _get_state_transpose_info(self) -> dict[str, Any]:
        """
        Collect the values describing the state's transpose: its input shape,
        the permuted output shape, the permutation, and the nodes recorded
        under the "State.Transpose" position.
        """
        return {
            "input_shape": self.transp_shape,
            "output_shape": permute(self.transp_shape, self.perm),
            "perm": self.perm,
            "matched": [
                node
                for node, position in self._matched.items()
                if position == "State.Transpose"
            ],
        }

    def _get_reshape_in_info(self) -> dict[str, Any]:
        """
        Collect the values describing the state's input reshape: it maps the
        shape of the head tensor onto `transp_shape`, and owns the nodes
        recorded under the "State.ReshapeIN" position.
        """
        head_tensor = self.head.require_tensor()
        return {
            "input_shape": head_tensor.get_shape(),
            "output_shape": self.transp_shape,
            "matched": [
                node
                for node, position in self._matched.items()
                if position == "State.ReshapeIN"
            ],
        }

    def _get_reshape_out_info(self) -> dict[str, Any]:
        """
        Collect the values describing the state's output reshape: it maps the
        permuted `transp_shape` onto `out_shape` (or onto the tail's shape when
        `out_shape` is unset/empty), and owns the nodes recorded under the
        "State.ReshapeOUT" position.
        """
        return {
            "input_shape": permute(self.transp_shape, self.perm),
            "output_shape": self.out_shape or self.tail.get_shape(),
            "matched": [
                node
                for node, position in self._matched.items()
                if position == "State.ReshapeOUT"
            ],
        }

    def __post_init__(self):
        """Emit debug log lines describing the freshly constructed state."""
        tag = str(self)
        logging.debug(f"{tag}.STATE == {self._log_state_()}] ")

        head_op = self.head.require_node().get_op_type()
        head_shape = self.head.require_tensor().get_shape()
        logging.debug(f"{tag}.HEAD == {head_op} -> [{head_shape}] ")

        # One line per view of the state, each listing the nodes attributed to it.
        logging.debug(
            f"{tag}.ReshapeIN == {self.ReshapeIN}: matched({len(self.ReshapeIN.matched)})={self.ReshapeIN.matched}"
        )
        logging.debug(
            f"{tag}.Transpose == {self.Transpose}: matched({len(self.Transpose.matched)})={self.Transpose.matched}"
        )
        logging.debug(
            f"{tag}.ReshapeOUT == {self.ReshapeOUT}: matched({len(self.ReshapeOUT.matched)})={self.ReshapeOUT.matched}"
        )
        logging.debug(f"{tag}.MATCHED({len(self.matched)}) = {self.matched}")

    def _log_state_stats_(self) -> None:
        """Log the statistics of the three state views together with a combined score."""
        # score = rank(ReshapeIN) - |dim(Transpose) - 3| - |rank(ReshapeOUT)|
        in_rank = self.ReshapeIN.get_node_rank()
        transpose_term = abs(self.Transpose.get_dim() - 3)
        out_term = abs(self.ReshapeOUT.get_node_rank())
        score = in_rank - transpose_term - out_term

        tag = str(self)
        logging.debug(f"{tag}.STATS")
        logging.debug(f"{tag}.STATS.ReshapeIN={self.ReshapeIN.get_stats()}")
        logging.debug(f"{tag}.STATS.Transpose={self.Transpose.get_stats()}")
        logging.debug(f"{tag}.STATS.ReshapeOUT={self.ReshapeOUT.get_stats()}")
        logging.debug(f"{tag}.STATS.score={score}")

    def _log_state_(self, log=False) -> str:
        """
        Build a one-line textual picture of the state:
        head shape -> ReIN -> Tr -> ReOUT -> final shape.

        **Args**:
            - `log`: when truthy, the line is also written to the debug log

        **Returns**:
            - the formatted description
        """
        segments = [
            f"[{self.head.require_tensor().get_shape()}] ->  ",
            f"ReIN({self.ReshapeIN.shape}) -> ",
            f"Tr({self.Transpose.perm}) -> [{self.Transpose.output_shape}] -> ",
            f"ReOUT({self.ReshapeOUT.output_shape}) -> ",
            f"[{self.out_shape or self.tail.require_tensor().get_shape()}]",
        ]
        description = "".join(segments)
        if log:
            logging.debug(f"{str(self)}.STATE:: {description}")
        return description


@dataclasses.dataclass
class ReshapeTransposeUpState(ReshapeTransposeState, ReshapeTransposeHelper, PermutationHelper):
    """
    State of matching combination of transposes, reshapes upwards in graph.

    State represents merged/overall graph part from `head` at the top to `tail` at the bottom, represented equivalently by:
        -> <head> -> State.ReshapeIN(transp_shape) -> <transp_input> -> State.Transpose(perm) -> State.ReshapeOUT(tail_shape) -> <tail> -> MatMul_qdq
    where <tail> == MatMul_qdq.A or MatMul_qdq.B

    Adding new nodes in RTR state occurs from State.ReshapeIN:
        [<head> -- new node] >>> [State.ReshapeIN -> State.Transpose -> State.ReshapeOUT]

    It will be fused into:
        -> <head> ->  Reshape_NOOP(transp_shape) -> <transp_input> -> MatMul_qdq_Tr(shapeX=tail_shape, permX=perm)
    where 'X' is 'A' or 'B'.
    """

    def __str__(self):
        """Return the short tag used as a prefix in log messages."""
        return "RTRUpState({})".format(self.label)

    def __repr__(self):
        """Return a representation naming the tail tensor, the owner and the label."""
        tail_name = self.tail.require_tensor().get_name()
        return "ReshapeTransposeUpState({}, {}, {})".format(
            tail_name, self.owner, self.label
        )

    @staticmethod
    def fromInputTensor(
        inp: InputTensor, state_owner: str = None, state_label: str = None
    ) -> "ReshapeTransposeUpState":
        """
        Build the initial `ReshapeTransposeUpState` for an input tensor (e.g. one of MatMul_qdq.inputs).

        **Args**:
            - `inp`: the input tensor the state starts from; used as both `head` and `tail`
            - `state_owner`: optional name of the Node (usually MatMul) this pattern is used for. Default := `inp.name`
            - `state_label`: optional name of the state, usually the name of the MatMul's input. Default := "X".
              Used for logging/debugging.

        **Returns**:
            - a new `ReshapeTransposeUpState` describing the trivial chain:
                -> <inp> ->  State.ReshapeIN(inp.shape) -> <transp_input> -> State.Transpose(IDENTITY) -> State.ReshapeOUT(inp.shape)  -> <inp> -> MatMul_qdq
        """
        shape = inp.get_shape()
        initial_fields = dict(
            owner=state_owner or inp.get_name(),
            label=state_label or "X",
            tail=inp,
            head=inp,
            # the state keeps its own copy of the shape
            transp_shape=shape.copy(),
            perm=PermutationHelper.get_identity_perm(shape),
        )
        return ReshapeTransposeUpState(**initial_fields)

    def get_next_state(self, **kwargs) -> "ReshapeTransposeUpState":
        """
        Derive the successor state, overriding selected fields of the current one.

        **Args** (all optional keywords):
            - `owner`, `label`, `tail`, `head`, `transp_shape`, `perm`: replacement
              values; each defaults to the value held by this state
            - `matched`: a node to record as matched in the successor state
            - `matched_position`: the position recorded for `matched`. Default := "HEAD"

        **Returns**:
            - a new `ReshapeTransposeUpState`; `_head_node_` and the pending
              `_acc_transpose_` are carried over from this state
        """
        # Work on a copy of the matched mapping: the successor must not share the
        # mapping with this state. Otherwise recording a node would also change this
        # state (and every earlier one), e.g. the state that `match_upwards` returns
        # when it aborts, and a node would stay recorded even if building the
        # successor fails.
        next_matched = OrderedDict(self._matched)
        if matched_node := kwargs.get("matched"):
            next_matched[matched_node] = kwargs.get("matched_position", "HEAD")

        new_state = ReshapeTransposeUpState(
            owner=kwargs.get("owner", self.owner),
            label=kwargs.get("label", self.label),
            tail=kwargs.get("tail", self.tail),
            head=kwargs.get("head", self.head),
            transp_shape=kwargs.get("transp_shape", self.transp_shape),
            perm=kwargs.get("perm", self.perm),
            _matched=next_matched,
        )
        new_state._head_node_ = self._head_node_
        new_state._acc_transpose_ = self._acc_transpose_
        return new_state

    def match_upwards(self) -> "ReshapeTransposeUpState":
        """
        Repeatedly apply `match_upwards_step` until a step raises `MatcherError`.

        **Returns**:
            - the last state reached, or `self` when the last state still holds a
              pending accumulated transpose (see the comment below)
        """
        tag = str(self)
        logging.debug(f"{tag}.START:: {self.name} ")

        state = self
        try:
            while True:
                state = state.match_upwards_step()
        except MatcherError as err:
            logging.debug(f"{tag}.BREAK:: {err}")

        pending = state._acc_transpose_
        if pending is not None:
            # A pending accumulated transpose was never merged into the state: the
            # match was interrupted (e.g. by a shape mismatch or the end of the chain).
            # The matched nodes then include the transpose while `perm` does not, so
            # the match is abandoned to avoid an incorrect optimization
            # (e.g. removing the transpose).
            logging.debug(f"{tag}.ABORT:: Pending transpose {pending}")
            return self

        logging.debug(f"{tag}.DONE:: {self.name} \n")
        return state

    def match_upwards_step(self) -> "ReshapeTransposeUpState":
        """
        Try to absorb the node at `self.head` into the state and move one node up.

        Three cases are handled, in this order:
            1. `head` is a Transpose: it is merged into the pending `_acc_transpose_`
               (or starts it) and the next state is returned with unchanged
               `transp_shape`/`perm`; the node is recorded under the Transpose position.
            2. `head` is not a Transpose but a pending `_acc_transpose_` exists: the
               pending transpose is pushed through State.ReshapeIN and merged into
               the state's permutation. This updates `self` IN PLACE
               (`transp_shape`, `perm`, `_acc_transpose_`) and then falls through to case 3.
            3. `head` is a Reshape/Squeeze/Unsqueeze: it is pushed through
               State.Transpose when possible (recorded under the ReshapeOUT position),
               otherwise it is merged with State.ReshapeIN (recorded under the
               ReshapeIN position).

        **Returns**:
            - the next state, whose head is `self.head("data")`

        **Raises**:
            - `MatcherError`: when the pending transpose does not fit on top of State.Transpose
            - `NoMatch`: when `head` is none of the supported ops, or when the pending
              transpose cannot be swapped with State.ReshapeIN
              (presumably `NoMatch` is a `MatcherError` subclass -- the callers catch `MatcherError`)
        """
        # next node ==> TRANSPOSE  ------------------------------------------------------------------------------------
        if self.head.check(opType.Transpose):
            # [0] matched with a Transpose node
            self._head_node_ = self.head.require_node()
            head_tr = TransposeView.ofNode(
                self._head_node_, perm=self._head_node_.get_attribute_value("perm")
            )
            logging.debug(
                f"{str(self)}.MATCH:: <<Transpose({head_tr.perm})>> :: {self._head_node_.require_node().get_name()}"
            )

            # collect all consecutive Transpose nodes in `self._acc_transpose_`,
            # change the state only when not Transpose appears
            if self._acc_transpose_:
                # merge the existing `self._acc_transpose_` and the current Transpose
                merged_perm = permute(x=head_tr.perm, perm=self._acc_transpose_.perm)
                self._acc_transpose_ = self._acc_transpose_.viewAs(
                    input_shape=head_tr.input_shape,
                    output_shape=permute(head_tr.input_shape, merged_perm),
                    perm=merged_perm,
                )
            else:
                self._acc_transpose_ = head_tr

            # [1] STEP UP
            # The state's own transp_shape/perm stay untouched here; the transpose is
            # only accumulated and gets merged once a non-Transpose node shows up.
            return self.get_next_state(
                head=self.head("data"),
                transp_shape=self.transp_shape,
                perm=self.perm,
                matched=self._head_node_,
                matched_position=self.Transpose.view_name,
            )

        # the current node is not Transpose; if some Transpose nodes are collected, update the state
        if self._acc_transpose_:
            # [1] EXPECT: self._acc_transpose.output_shape == self.Transpose.input_shape
            # >>> this transpose must fit on top of transpose in state
            if self._acc_transpose_.output_shape != self.Transpose.input_shape:
                raise MatcherError(
                    f"Shapes are incompatible: "
                    f"out_shape={self._acc_transpose_.output_shape} != "
                    f"StateTranspose.input_shape={self.Transpose.input_shape}"
                )

            # [2] Tr-SWAP with State.ReshapeIN:
            # >>> try to swap this transpose with State.ReshapeIN
            # (raises NoMatch when the swap is not possible; not caught here)
            re_, tr_ = self.push_transpose_through_reshape(
                self._acc_transpose_, self.ReshapeIN
            )

            # [3] Tr-MERGE with State.Transpose:
            # >>> now try to combine the new transpose perm with the state perm (State.Transpose)
            #           next state.perm := apply the state.perm to new_perm
            new_state_perm = permute(x=tr_.perm, perm=self.Transpose.perm)
            new_state_tr = self.Transpose.viewAs(
                input_shape=re_.shape,
                output_shape=permute(re_.shape, new_state_perm),
                perm=new_state_perm,
            )

            # [4] Update the current state. Don't do STEP UP: the current node has not yet been processed
            self.transp_shape = new_state_tr.input_shape
            self.perm = new_state_tr.perm
            self._acc_transpose_ = None

        # next node ==> Reshape or Squeeze or Unsqueeze ----------------------------------------------------------------------------
        # >>> try to push Reshape, Squeeze or Unsqueeze node through the state
        if self.head.check(opType.Reshape | opType.Squeeze | opType.Unsqueeze):
            head_node = self.head.require_node()
            head_re = ReshapeView.ofNode(head_node)
            logging.debug(
                f"{str(self)}.MATCH:: <<Reshape({head_re.shape})>> :: {self.head.require_node().get_name()}"
            )

            # [1] Re-MERGE with State.ReshapeIN:
            # >>> merging this Reshape with State.ReshapeIN happens implicitly
            # (the merged view keeps the node's input shape and takes State.ReshapeIN's output shape)
            head_re = ReshapeView.ofNode(
                head_node,
                input_shape=head_re.input_shape,
                output_shape=self.ReshapeIN.output_shape,
            )

            # [2] Re-SWAP with State.Transpose:
            # >>> try to swap the merged reshape with State.Transpose
            try:
                tr_, re_ = self.push_reshape_through_transpose(head_re, self.Transpose)

            except MatcherError:
                # Re-SWAP is INVALID:
                # >>> this reshape is merged with State.ReshapeIN
                return self.get_next_state(
                    head=self.head("data"),
                    matched=head_node,
                    matched_position=self.ReshapeIN.view_name,
                )

            # [3] Re-MERGE with State.ReshapeOUT:
            # >>> this reshape is merged with State.ReshapeOUT
            # new_re_out = self.ReshapeOUT.viewAs(
            #     input_shape=re_.input_shape,
            #     output_shape=self.ReshapeOUT.output_shape,
            # )

            # [4] STEP UP
            return self.get_next_state(
                head=self.head("data"),
                transp_shape=tr_.input_shape,
                perm=tr_.perm,
                matched=head_node,
                matched_position=self.ReshapeOUT.view_name,
            )

        raise NoMatch(f"Not Reshape/Transpose: {self.head}")

    def optimize_state(self) -> "ReshapeTransposeUpState":
        """
        Try to obtain an equivalent state whose ReshapeOUT is an identity.

        State.Transpose is pushed through State.ReshapeOUT; the resulting state is
        kept only when its ReshapeOUT turns out to be an identity.

        **Returns**:
            - the optimized state (its label gets a "*" suffix) on success, otherwise `self`
        """
        self._log_state_stats_()

        # With an identity ReshapeOUT, pushing State.Transpose through it changes nothing.
        if self.ReshapeOUT.is_identity():
            return self

        try:
            logging.debug(f"{str(self)}.OPT::  StateTr <-> ReOUT  :: ")
            moved_reshape, moved_transpose = self.push_transpose_through_reshape(
                self.Transpose, self.ReshapeOUT
            )

            candidate = self.get_next_state(
                label=f"{self.label}*",
                transp_shape=moved_reshape.output_shape,
                perm=moved_transpose.perm,
            )
            logging.debug(f"{str(candidate)}.OPT::SUCCESS:: {candidate._log_state_()}")
            candidate._log_state_stats_()

            if candidate.ReshapeOUT.is_identity():
                return candidate

        except MatcherError as err:
            logging.debug(f"{str(self)}.OPT::FAIL:: {err} \n")

        # Either the swap failed or it did not make ReshapeOUT an identity.
        return self


@dataclasses.dataclass
class ReshapeTransposeDownState(ReshapeTransposeState, ReshapeTransposeHelper, PermutationHelper):
    """
    State of matching combination of transposes, reshapes downwards in graph.

    State represents merged/overall graph part from `head` at the top to `tail` at the bottom, represented equivalently by:
        -> MatMul_qdq -> <head> -> State.ReshapeIN(transp_shape) -> <transp_input> -> State.Transpose(perm) -> State.ReshapeOUT(tail_shape) -> <tail>
    where <head> == MatMul_qdq.Y

    Adding new nodes in RTR state occurs from State.ReshapeOUT:
        [State.ReshapeIN -> State.Transpose -> State.ReshapeOUT] <<< [<tail> -- new node]

    It will be fused into:
        -> MatMul_qdq_Tr(shape=head_shape, perm=perm) -> <transp_output> -> Reshape_NOOP(transp_shape) -> <tail>
    """

    def __str__(self):
        """Return the short tag used as a prefix in log messages."""
        return "RTRDownState({})".format(self.label)

    def __repr__(self):
        """Return a representation naming the head tensor, the owner and the label."""
        head_name = self.head.require_tensor().get_name()
        return "ReshapeTransposeDownState({}, {}, {})".format(
            head_name, self.owner, self.label
        )

    @staticmethod
    def fromOutputTensor(
        out: OutputTensor, state_owner: str = None, state_label: str = None
    ) -> "ReshapeTransposeDownState":
        """
        Build the initial `ReshapeTransposeDownState` for an output tensor (e.g. MatMul_qdq.output).

        **Args**:
            - `out`: the output tensor the state starts from; used as both `head` and `tail`
            - `state_owner`: optional name of the Node (usually MatMul) this pattern is used for. Default := `out.name`
            - `state_label`: optional name of the state, usually the name of the MatMul's output. Default := "X".
              Used for logging/debugging.

        **Returns**:
            - a new `ReshapeTransposeDownState` describing the trivial chain:
                -> MatMul_qdq -> <out> -> State.ReshapeIN(out.shape) -> <transp_input> -> State.Transpose(IDENTITY) -> State.ReshapeOUT(out.shape)
        """
        shape = out.get_shape()
        initial_fields = dict(
            owner=state_owner or out.get_name(),
            label=state_label or "X",
            tail=out,
            head=out,
            # the state keeps its own copies of the shape
            transp_shape=shape.copy(),
            perm=PermutationHelper.get_identity_perm(shape),
            out_shape=shape.copy(),
        )
        return ReshapeTransposeDownState(**initial_fields)

    def get_next_state(self, **kwargs) -> "ReshapeTransposeDownState":
        """
        Derive the successor state, overriding selected fields of the current one.

        **Args** (all optional keywords):
            - `owner`, `label`, `tail`, `head`, `transp_shape`, `perm`, `out_shape`:
              replacement values; each defaults to the value held by this state
            - `matched`: a node to record as matched in the successor state
            - `matched_position`: the position recorded for `matched`. Default := "HEAD"

        **Returns**:
            - a new `ReshapeTransposeDownState`; `_tail_node_` and the pending
              `_acc_transpose_` are carried over from this state
        """
        # Work on a copy of the matched mapping: the successor must not share the
        # mapping with this state. Otherwise recording a node would also change this
        # state (and every earlier one), e.g. the state that `match_downwards` returns
        # when it aborts, and a node would stay recorded even if building the
        # successor fails.
        next_matched = OrderedDict(self._matched)
        if matched_node := kwargs.get("matched"):
            next_matched[matched_node] = kwargs.get("matched_position", "HEAD")

        new_state = ReshapeTransposeDownState(
            owner=kwargs.get("owner", self.owner),
            label=kwargs.get("label", self.label),
            tail=kwargs.get("tail", self.tail),
            head=kwargs.get("head", self.head),
            transp_shape=kwargs.get("transp_shape", self.transp_shape),
            perm=kwargs.get("perm", self.perm),
            _matched=next_matched,
            out_shape=kwargs.get("out_shape", self.out_shape),
        )
        new_state._tail_node_ = self._tail_node_
        new_state._acc_transpose_ = self._acc_transpose_
        return new_state

    def match_downwards(self) -> "ReshapeTransposeDownState":
        """
        Repeatedly apply `match_downwards_step` until a step raises `MatcherError`.

        **Returns**:
            - the last state reached, or `self` when the last state still holds a
              pending accumulated transpose (see the comment below)
        """
        tag = str(self)
        logging.debug(f"{tag}.START:: {self.name} ")

        state = self
        try:
            while True:
                state = state.match_downwards_step()
        except MatcherError as err:
            logging.debug(f"{tag}.BREAK:: {err}")

        pending = state._acc_transpose_
        if pending is not None:
            # A pending accumulated transpose was never merged into the state: the
            # match was interrupted (e.g. by a shape mismatch or the end of the chain).
            # The matched nodes then include the transpose while `perm` does not, so
            # the match is abandoned to avoid an incorrect optimization
            # (e.g. removing the transpose).
            logging.debug(f"{tag}.ABORT:: Pending transpose {pending}")
            return self

        logging.debug(f"{tag}.DONE:: {self.name} \n")
        return state

    def match_downwards_step(self) -> "ReshapeTransposeDownState":
        """
        Try to absorb the node at `self.tail` into the state and move one node down.

        Three cases are handled, in this order:
            1. `tail` is a Transpose: it is merged into the pending `_acc_transpose_`
               (or starts it) and the next state is returned with unchanged
               `transp_shape`/`perm`; the node is recorded under the Transpose position.
            2. `tail` is not a Transpose but a pending `_acc_transpose_` exists:
               State.ReshapeOUT is pushed through the pending transpose, which is then
               merged into the state's permutation. This updates `self` IN PLACE
               (`transp_shape`, `out_shape`, `perm`, `_acc_transpose_`) and then
               falls through to case 3.
            3. `tail` is a Reshape/Squeeze/Unsqueeze: State.Transpose is pushed through
               it when possible (recorded under the ReshapeIN position), otherwise
               it is merged with State.ReshapeOUT (recorded under the ReshapeOUT position).

        **Returns**:
            - the next state, whose tail is the output tensor selected on the current tail

        **Raises**:
            - `MatcherError`: when the pending transpose does not fit below State.Transpose,
              or when `push_reshape_through_transpose` rejects the shapes
            - `NoMatch`: when `tail` is none of the supported ops, or when State.ReshapeOUT
              cannot be swapped with the pending transpose
              (presumably `NoMatch` is a `MatcherError` subclass -- the callers catch `MatcherError`)
        """
        # next node ==> TRANSPOSE  ------------------------------------------------------------------------------------
        if self.tail.check(opType.Transpose):
            # [0] matched with a Transpose node
            self._tail_node_ = self.tail.require_node()
            tail_tr = TransposeView.ofNode(
                self._tail_node_, perm=self._tail_node_.get_attribute_value("perm")
            )
            logging.debug(
                f"{str(self)}.MATCH:: <<Transpose({tail_tr.perm})>> :: {self._tail_node_.require_node().get_name()}"
            )

            # collect all consecutive Transpose nodes in `self._acc_transpose_`,
            # change the state only when not Transpose appears
            if self._acc_transpose_:
                # merge the existing `self._acc_transpose_` and the current Transpose
                merged_perm = permute(x=self._acc_transpose_.perm, perm=tail_tr.perm)
                self._acc_transpose_ = self._acc_transpose_.viewAs(
                    input_shape=self._acc_transpose_.input_shape,
                    output_shape=permute(self._acc_transpose_.input_shape, merged_perm),
                    perm=merged_perm,
                )
            else:
                self._acc_transpose_ = tail_tr

            # [1] STEP DOWN
            # The state's own transp_shape/perm stay untouched here; the transpose is
            # only accumulated and gets merged once a non-Transpose node shows up.
            return self.get_next_state(
                tail=self.tail("transposed"),
                transp_shape=self.transp_shape,
                out_shape=self.tail.get_shape(),
                perm=self.perm,
                matched=self._tail_node_,
                matched_position=self.Transpose.view_name,
            )

        # the current node is not Transpose; if some Transpose nodes are collected, update the state
        if self._acc_transpose_:
            # [1] EXPECT: self.Transpose.output_shape == self._acc_transpose.input_shape
            # >>> this transpose must fit on bottom of transpose in state, not on top!
            if self.Transpose.output_shape != self._acc_transpose_.input_shape:
                raise MatcherError(
                    f"Shapes are incompatible: "
                    f"StateTranspose.output_shape={self.Transpose.output_shape} != "
                    f"in_shape={self._acc_transpose_.input_shape}"
                )

            # [2] Tr-SWAP with State.ReshapeOUT:
            # >>> try to swap this transpose with State.ReshapeOUT
            # (raises when the swap is not possible; not caught here)
            tr_, re_ = self.push_reshape_through_transpose(
                self.ReshapeOUT, self._acc_transpose_
            )

            # [3] Tr-MERGE with State.Transpose:
            # >>> now try to combine the new transpose perm with the state perm (State.Transpose)
            #           next state.perm := apply the new_perm to state.perm
            # NOTE(review): `output_shape` below is derived from the state's transpose
            # *output* shape; only `input_shape` and `perm` of `new_state_tr` are read
            # afterwards, so that value does not reach the state -- confirm it is intended.
            new_state_perm = permute(x=self.Transpose.perm, perm=tr_.perm)
            new_state_tr = self.Transpose.viewAs(
                input_shape=self.Transpose.input_shape,
                output_shape=permute(self.Transpose.output_shape, new_state_perm),
                perm=new_state_perm,
            )

            # [4] Update the current state. Don't do STEP DOWN: the current node has not yet been processed
            self.transp_shape = new_state_tr.input_shape
            self.out_shape = re_.output_shape
            self.perm = new_state_tr.perm
            self._acc_transpose_ = None

        # next node ==> Reshape or Squeeze or Unsqueeze ----------------------------------------------------------------------------
        # >>> try to push Reshape, Squeeze or Unsqueeze node through the state
        if self.tail.check(opType.Reshape | opType.Squeeze | opType.Unsqueeze):
            tail_node = self.tail.require_node()
            tail_re = ReshapeView.ofNode(tail_node)
            logging.debug(
                f"{str(self)}.MATCH:: <<Reshape({tail_re.shape})>> :: {self.tail.require_node().get_name()}"
            )

            # [1] Re-MERGE with State.ReshapeOUT:
            # >>> merging this Reshape with State.ReshapeOUT happens implicitly
            # (the merged view takes State.ReshapeOUT's input shape and keeps the node's output shape)
            tail_re = ReshapeView.ofNode(
                tail_node,
                input_shape=self.ReshapeOUT.input_shape,
                output_shape=tail_re.output_shape,
            )

            # [2] Re-SWAP with State.Transpose:
            # >>> try to swap State.Transpose with the merged reshape
            try:
                re_, tr_ = self.push_transpose_through_reshape(self.Transpose, tail_re)
            except MatcherError:
                # Re-SWAP is INVALID:
                # >>> this reshape is merged with State.ReshapeOUT
                return self.get_next_state(
                    tail=self.tail(self.get_out_tensor_schema_for_tail()),
                    out_shape=tail_re.output_shape,
                    matched=tail_node,
                    matched_position=self.ReshapeOUT.view_name,
                )

            # [3] Re-MERGE with State.ReshapeIN:
            # [4] STEP DOWN
            return self.get_next_state(
                tail=self.tail(self.get_out_tensor_schema_for_tail()),
                transp_shape=tr_.input_shape,
                out_shape=self.out_shape,
                perm=tr_.perm,
                matched=tail_node,
                matched_position=self.ReshapeIN.view_name,
            )

        raise NoMatch(f"Not Reshape/Transpose: {self.tail}")

    def get_out_tensor_schema_for_tail(self) -> str:
        """
        Return the name of the output tensor to follow for the op at `self.tail`.

        **Returns**:
            - "reshaped" for Reshape, "squeezed" for Squeeze,
              "expanded" for Unsqueeze, "transposed" for Transpose

        **Raises**:
            - `MatcherError`: when the tail is none of the ops listed above
        """
        if self.tail.check(opType.Reshape):
            return "reshaped"
        if self.tail.check(opType.Squeeze):
            return "squeezed"
        # Fixed: this branch used the misspelled `opType.Unqueeze`; the rest of the
        # file refers to the op as `opType.Unsqueeze`.
        if self.tail.check(opType.Unsqueeze):
            return "expanded"
        if self.tail.check(opType.Transpose):
            return "transposed"
        raise MatcherError(
            f"Unsupported op type for getting the output tensor schema: {self.tail}."
        )

    def optimize_state(self) -> "ReshapeTransposeDownState":
        """
        Try to obtain an equivalent state whose ReshapeIN is an identity.

        State.ReshapeIN is pushed through State.Transpose; the resulting state is
        kept only when its ReshapeIN turns out to be an identity.

        **Returns**:
            - the optimized state (its label gets a "*" suffix) on success, otherwise `self`
        """
        self._log_state_stats_()

        # With an identity ReshapeIN, pushing it through State.Transpose changes nothing.
        if self.ReshapeIN.is_identity():
            return self

        try:
            logging.debug(f"{str(self)}.OPT::  ReIN <-> StateTr :: ")
            moved_transpose, moved_reshape = self.push_reshape_through_transpose(
                self.ReshapeIN, self.Transpose
            )

            candidate = self.get_next_state(
                label=f"{self.label}*",
                transp_shape=moved_transpose.input_shape,
                perm=moved_transpose.perm,
                out_shape=moved_reshape.output_shape,
            )
            logging.debug(f"{str(candidate)}.OPT::SUCCESS:: {candidate._log_state_()}")
            candidate._log_state_stats_()

            if candidate.ReshapeIN.is_identity():
                return candidate

        except MatcherError as err:
            logging.debug(f"{str(self)}.OPT::FAIL:: {err} \n")

        # Either the swap failed or it did not make ReshapeIN an identity.
        return self

@dataclasses.dataclass
class NodeView:
    """
    A light-weight view of "real" (from Model) and "imaginary" (from State) nodes.
    Provides a view of input/output shapes and of the node's parameters
    (like perm for Transpose).

    **Fields**:
        - `view_name`: name/type of the viewed object. If a model node is viewed,
          this is set to that node's opType.
        - `origin`: the viewed object or its name. Mainly used for bookkeeping and debugging.
        - `input_shape`, `output_shape`: the TensorShapes of the viewed object.
    """

    view_name: str
    origin: str | Node | ReshapeTransposeState
    input_shape: TensorShape
    output_shape: TensorShape
    # Per-class cache of dataclass field names, filled lazily by `fields()`.
    _fields_: ClassVar[set[str]] = set()

    def __repr__(self):
        return f"View({self.view_name}, {repr(self.origin)})"

    def __str__(self):
        return f"[{self.input_shape}] -> {self.view_name}() -> [{self.output_shape}]"

    @classmethod
    def fields(cls) -> set[str]:
        """Return the names of the dataclass fields of `cls`, cached per class."""
        # Look only into the class's own namespace: a plain `cls._fields_`
        # lookup would let a subclass pick up the cache already filled for a
        # base class and thereby miss the subclass's extra fields.
        cached = cls.__dict__.get("_fields_")
        if not cached:
            cached = {f.name for f in dataclasses.fields(cls)}
            cls._fields_ = cached
        return cached

    @classmethod
    def of(cls, view_name, origin, *shapes, **info) -> "NodeView":
        """
        Build a view from explicit values.

        Entries of `info` that are not fields of `cls` are silently dropped.
        """
        known = cls.fields()
        info = {name: val for name, val in info.items() if name in known}
        return cls(view_name, origin, *shapes, **info)

    @classmethod
    def ofNode(cls, node: Node, *shapes, **info) -> "NodeView":
        """
        Returns a view of a model Node:
            View(view_name=node.opType, origin=node)
        with shapes and params provided by caller/user.

        Unlike `of`, `info` is passed to the constructor unfiltered.
        """
        return cls(node.get_op_type(), node, *shapes, **info)

    def viewAs(self, **info) -> "NodeView":
        """
        Returns a new view of the viewed object as if it had the values provided by `info` dict.

        Values not present in `info` are taken from this view. This includes
        subclass-specific fields, so e.g. a view with an extra required field
        can be re-viewed with only a new shape.
        """
        input_shape = info.pop("input_shape", self.input_shape)
        output_shape = info.pop("output_shape", self.output_shape)
        # Carry over every remaining field that the caller did not override;
        # the four base fields are passed explicitly below.
        base_fields = {"view_name", "origin", "input_shape", "output_shape"}
        for name in self.fields() - base_fields:
            info.setdefault(name, getattr(self, name))
        return self.of(self.view_name, self.origin, input_shape, output_shape, **info)

    def update(self, **info) -> None:
        """Update `input_shape` / `output_shape` in place from `info`, if given."""
        self.input_shape = info.pop("input_shape", self.input_shape)
        self.output_shape = info.pop("output_shape", self.output_shape)

    @dataclasses.dataclass(kw_only=True)
    class Stats:
        """Summary values of a view, as collected by `NodeView.get_stats`."""

        node_rank: int
        output_rank: int
        output_dim: int
        identity: bool

    def get_stats(self) -> "NodeView.Stats":
        """Collect node rank, output rank, output dim and identity flag of this view."""
        return NodeView.Stats(
            node_rank=self.get_node_rank(),
            output_rank=self.get_output_rank(),
            output_dim=dim(self.output_shape),
            identity=self.is_identity(),
        )

    def get_node_rank(self) -> int:
        return node_rank(self.input_shape, self.output_shape)

    def get_input_rank(self) -> int:
        return rank(self.input_shape)

    def get_output_rank(self) -> int:
        return rank(self.output_shape)

    def is_identity(self) -> bool:
        """A generic view is never considered an identity; subclasses refine this."""
        return False

    def is_nontrivial(self) -> bool:
        return not self.is_identity()


@dataclasses.dataclass(kw_only=True)
class TransposeView(NodeView):
    """
    A `NodeView` of a Transpose, carrying its permutation in addition to the shapes.

    **Fields**:
        - `perm`: permutation of the viewed transpose (keyword-only).
    """

    perm: Perm

    def __str__(self):
        return f"{self.view_name}({self.perm})"

    def verbose(self) -> str:
        """Return a string showing input shape, permutation and output shape."""
        return f"[{self.input_shape}] -> Tr({self.perm}) -> [{self.output_shape}]"

    @classmethod
    def ofNode(cls, node: Node, **info) -> "TransposeView":
        """
        Build a view of a model Transpose node.

        Shapes default to those of the node's "data" input and "transposed"
        output tensors and may be overridden through `info`. The default shapes
        are looked up even when overrides are given. `perm` is taken from
        `info` and is None when not supplied.
        """
        input_shape = info.pop("input_shape", node("data").require_tensor().get_shape())
        output_shape = info.pop(
            "output_shape", node("transposed").require_tensor().get_shape()
        )
        return super().ofNode(node, input_shape, output_shape, perm=info.get("perm"))

    def update(self, **info) -> None:
        """Update shapes and `perm` in place; values missing from `info` are kept."""
        super().update(**info)
        # Keep the current permutation when none is given. The previous default
        # was `self.output_shape`, which overwrote `perm` with a shape.
        self.perm = info.pop("perm", self.perm)

    def is_identity(self) -> bool:
        return PermutationHelper.is_identity_perm(self.perm)

    def get_dim(self) -> int:
        return dim(self.perm)


@dataclasses.dataclass
class ReshapeView(NodeView):
    """A `NodeView` of a Reshape; the target shape is the view's `output_shape`."""

    @property
    def shape(self) -> TensorShape:
        """Alias for `output_shape`."""
        return self.output_shape

    def __str__(self):
        return f"{self.view_name}({self.shape})"

    def verbose(self) -> str:
        """Return a string showing input shape, target shape and output shape."""
        return f"[{self.input_shape}] -> Re({self.shape}) -> [{self.output_shape}]"

    @classmethod
    def ofNode(cls, node: Node, **info) -> "ReshapeView":
        """
        Build a view of a model node.

        Shapes default to those of the node's "data" input tensor and of its
        first output tensor; `input_shape` / `output_shape` in `info` override
        them. The default shapes are looked up even when overrides are given.
        """
        data_shape = node("data").require_tensor().get_shape()
        shape_in = info.get("input_shape", data_shape)

        first_output = node.get_outputs()[0]
        result_shape = first_output.require_tensor().get_shape()
        shape_out = info.get("output_shape", result_shape)

        return super().ofNode(node, shape_in, shape_out)

    def is_identity(self) -> bool:
        """Delegate to `ReshapeTransposeHelper.is_identity_reshape` on the two shapes."""
        shape_in, shape_out = self.input_shape, self.output_shape
        return ReshapeTransposeHelper.is_identity_reshape(shape_in, shape_out)
