# (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.
from abc import ABC, abstractmethod
from typing import (
    Optional, Dict, Set, List, Any, NamedTuple, Union, ClassVar, override, Self
)
from dataclasses import dataclass, field
from collections import deque
import numpy as np

from L1_utils.safe_runner import SafeRunner
from OGOAT.src.L1_fusion.py_match.helpers.common_type import (
    TensorShape, Perm, NumpyDType, OnnxDType, NodeDict,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    InputTensor,
    OutputTensor,
    Matcher,
    Tensor,
    WalkCfgPlain,
    MatcherError,
    Initializer, 
    NoMatch, 
    Node,
    ShapeMismatchError,
)

from OGOAT.src.L1_fusion.py_match.checkers import (
    opType,
    partialOpType as fusedOpType,
    AttrValue,
)

from OGOAT.src.L1_fusion.py_match.helpers.perm_helper import (
    PermutationHelper as perm_helper,
)
from OGOAT.src.L1_fusion.py_match.helpers.reshape_transpose_helper import (
    ReshapeTransposeHelper,
)
from OGOAT.src.L1_fusion.py_match.adv.rtr_optimize import RTROptimize
from OGOAT.src.L1_fusion.py_match.adv.matmul_transpose import MatMulTranspose

import logging
import sys

from OGOAT.src.utils.context import Logger
from py_match.model_dict import ModelDict

# Dedicated logger for this module: the "RTR" child of the "L1_fusion" logger.
logger = logging.getLogger("L1_fusion").getChild("RTR")
# Handler that writes to stdout; it only emits records of level INFO or higher.
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setLevel(logging.INFO)
logger.addHandler(stream_handler)
# Records are not passed on to the handlers of ancestor loggers.
logger.propagate = False
# The logger itself accepts DEBUG and above; the stdout handler filters at INFO.
logger.setLevel(logging.DEBUG)

# Output name used to look up / create the output of an RTR Transpose node.
# NOTE(review): "-O 2" / "-O 3" presumably refer to an optimization level
# that decides the Transpose output name -- confirm against the caller.
# TR_OUTPUT = "Y" # -O 3
TR_OUTPUT = "transposed" # -O 2


def get_channel_dim(shape: TensorShape) -> int:
    """Return the channel dimension of *shape*, i.e. its last entry."""
    channel = shape[-1]
    return channel

def get_bulk_dim(shape: TensorShape) -> int:
    """Return the product of every dimension of *shape* except the last one."""
    leading_dims = shape[:-1]
    bulk = np.prod(leading_dims)
    return int(bulk)

def check_channel_dim(shapeA: TensorShape, shapeB: TensorShape) -> bool:
    """Return True when both shapes have the same channel (last) dimension."""
    return get_channel_dim(shapeA) == get_channel_dim(shapeB)

def check_bulk_dim(shapeA: TensorShape, shapeB: TensorShape) -> bool:
    """Return True when both shapes have the same bulk dimension
    (product of all dimensions except the last)."""
    return get_bulk_dim(shapeA) == get_bulk_dim(shapeB)

def as_input_tensor(tensor: Tensor, origin: Optional[Node]=None) -> InputTensor:
    """
    Return *tensor* viewed as an InputTensor.

    A tensor that already reports itself as an input tensor is returned
    unchanged. A tensor that reports itself as an output tensor is re-wrapped
    in a new InputTensor built from the same model dict, walk config and
    tensor name; *origin*, when given, supplies the origin node name.

    Raises:
        MatcherError: if the tensor is neither an input nor an output tensor.
    """
    if tensor.check_input_tensor():
        return tensor

    if not tensor.check_output_tensor():
        raise MatcherError(f"Cannot convert to InputTensor: {tensor}")

    return InputTensor(
        tensor._model_dict,
        tensor._walk_cfg,
        tensor._tensor_name,
        origin_node_name=origin.get_name() if origin else None,
    )

def as_output_tensor(tensor: Tensor, origin: Optional[Node]=None) -> OutputTensor:
    """
    Return *tensor* viewed as an OutputTensor.

    A tensor that already reports itself as an output tensor is returned
    unchanged. A tensor that reports itself as an input tensor is re-wrapped
    in a new OutputTensor built from the same model dict, walk config and
    tensor name; *origin*, when given, supplies the origin node name.

    Raises:
        MatcherError: if the tensor is neither an output nor an input tensor.
    """
    if tensor.check_output_tensor():
        return tensor

    if not tensor.check_input_tensor():
        raise MatcherError(f"Cannot convert to OutputTensor: {tensor}")

    return OutputTensor(
        tensor._model_dict,
        tensor._walk_cfg,
        tensor._tensor_name,
        origin_node_name=origin.get_name() if origin else None,
    )


# TODO: move to reshape_transpose_helper.py
class RTRStateError(MatcherError):
    """
    Error raised for an invalid RTR state or an invalid RTR request.

    The raw message and the (optional) related RTR object are kept on the
    instance as ``_msg`` and ``_rtr``; the exception text is prefixed with
    "RTR-Error" and, when an RTR is given, its string form.
    """

    def __init__(self, msg: str, rtr = None) -> None:
        prefix = f"RTR-Error @{rtr}" if rtr else "RTR-Error"
        super().__init__(f"{prefix}: {msg}")
        self._rtr = rtr
        self._msg = msg


class RTRMatchingError(NoMatch):
    """
    "No match" signal raised when an RTR does not fit an expected pattern.

    The raw message and the (optional) related RTR object are kept on the
    instance as ``_msg`` and ``_rtr``.
    """

    def __init__(self, msg: str, rtr = None) -> None:
        subject = f"No matching {rtr}" if rtr else "No matching RTR"
        # The base initializer is invoked on MatcherError explicitly (not via
        # super()), exactly as the surrounding code expects.
        MatcherError.__init__(self, f"{subject}: {msg}")
        self._rtr = rtr
        self._msg = msg


@dataclass(repr=False)
class RTRDimensions:
    _origin: Optional['RTRState'] = field(
        # default=None, init=True, repr=False, compare=False, 
        default=None, init=True, repr=False, compare=False,
    )
    x: int = field(init=True, compare=True, kw_only=True)
    y: int = field(init=True, compare=True, kw_only=True)
    C: int = field(init=True, compare=True, kw_only=True)
    m: int = field(default=1, init=True, compare=True, kw_only=True)
    """ m - bottom channel-splitting factor """
    k: int = field(default=1, init=True, compare=True, kw_only=True)
    """ m - top channel-splitting factor """
    p: int = field(default=1, init=True, compare=True, kw_only=True)
    """ p - padding factor """

    def __str__(self):
        return f"RTRDimensions[x={self.x}, y={self.y}, C={self.C}, m={self.m}, k={self.k}, p={self.p}]"

    def __repr__(self):
        return f"RTRDimensions(x={self.x}, y={self.y}, C={self.C}, m={self.m}, k={self.k}, p={self.p}, origin={str(self._origin)})"

    @property
    def X(self) -> int:
        return self.x * self.x
    
    @property
    def Y(self) -> int:
        return self.y * self.y
    
    @property
    def B(self) -> int:
        return self.X * self.Y

    @property
    def W(self) -> int:
        return self.x * self.y
    
    @property
    def c(self) -> int:
        if self.m:
            return self.C // self.m
        return self.C
    
    @property
    def z(self) -> int:
        return self.p * self.y

    @property
    def Z(self) -> int:
        return self.z * self.z

    @property
    def A(self) -> int:
        return self.X * self.Z

    @property
    def V(self) -> int:
        return self.x * self.z

    @property
    def D(self) -> int:
        if self.k:
            return self.C // self.k
        return self.C

    def __getitem__(self, key: str) -> TensorShape:
        """
        Allow bracket access with a compact RTR mnemonic string.
        Example: rtrDims["BxyC"] -> [self.B, self.x, self.y, self.C]
        """
        if not isinstance(key, str):
            raise TypeError(f"RTRDimensions key must be str, got {type(key).__name__}")
        return self.get_shape(key)    

    @classmethod
    def from_RTR(cls, rtr: 'RTRState', m: int = 1, p: int = 1) -> 'RTRDimensions':
        if cls.is_top_RTR(rtr):
            _, x, y, _, _, C = rtr.reshaped
        elif cls.is_bottom_RTR(rtr):
            _, x, _, y, _, C = rtr.reshaped
        else:
            raise RTRMatchingError(f"{str(rtr)} is neither TOP nor BOTTOM shuffle RTR.", rtr=rtr)

        return cls(rtr, x=x, y=y, C=C, m=m, p=p)

    @property
    def origin(self) -> Optional['RTRState']:
        return self._origin
    
    @classmethod
    def require_shuffle_RTR(cls, rtr: 'RTRState') -> Self:
        if not cls.is_shuffle_RTR(rtr):
            raise RTRMatchingError(f"{str(rtr)} is not a Shuffle RTR.", rtr=rtr)
        return cls

    @classmethod
    def require_top_RTR(cls, rtr: 'RTRState') -> Self:
        if not cls.is_top_RTR(rtr):
            raise RTRMatchingError(f"{str(rtr)} is not a Top RTR.", rtr=rtr)
        return cls

    @classmethod
    def require_bottom_RTR(cls, rtr: 'RTRState') -> Self:
        if not cls.is_bottom_RTR(rtr):
            raise RTRMatchingError(f"{str(rtr)} is not a Bottom RTR.", rtr=rtr)
        return cls

    @classmethod
    def is_shuffle_RTR(cls, rtr: 'RTRState') -> bool:
        # check = rtr.get_Transpose().check(AttrValue("perm", [0, 1, 3, 2, 4, 5]))
        check = rtr.perm == [0, 1, 3, 2, 4, 5]
        return check
    
    @classmethod
    def is_bottom_RTR(cls, rtr: 'RTRState') -> bool:
        if not cls.is_shuffle_RTR(rtr):
            return False
        
        _, x1, x2, y1, y2, C = rtr.reshaped
        if x1 != x2 or y1 != y2:
            return False
        
        x, y = x1, y1
        _, x1, y1, x2, y2, C = rtr.transposed
        if x1 != x2 or y1 != y2 or x != x1 or y != y1:
            return False                
        return True

    @classmethod
    def is_top_RTR(cls, rtr: 'RTRState') -> bool:
        if not cls.is_shuffle_RTR(rtr):
            return False
        
        _, x1, y1, x2, y2, C = rtr.reshaped
        if x1 != x2 or y1 != y2:
            return False
        
        x, y = x1, y1
        _, x1, x2, y1, y2, C = rtr.transposed
        if x1 != x2 or y1 != y2 or x != x1 or y != y1:
            return False                
        return True

    def get_shape(self, shape_pattern: str) -> TensorShape:
        """
        Parse a compact RTR mnemonic string into a TensorShape.
        - Letters map to RTRDimensions attributes/properties (x, y, C, m, X, Y, W, B, c, ...).
        - Digits produce literal sizes (supports multi-digit numbers).
        - Parentheses group tokens: the grouped result is multiplied into a single dimension.
          Example: "Bxy(3mc)" -> [self.B, self.x, self.y, 3 * self.m * self.c]
        """
        if not isinstance(shape_pattern, str) or not shape_pattern.strip():
            return []

        s = shape_pattern.strip()
        out: list[int] = []

        def resolve_token(token: str) -> int:
            if hasattr(self, token):
                val = getattr(self, token)
                if isinstance(val, int):
                    return val
            
            raise RTRStateError(
                f"Unknown RTR-dimension token='{token}' in RTR-shape pattern=['{shape_pattern}'].", 
                rtr=self._origin
            )

        def eval_segment(seg: str) -> int:
            # Evaluate a segment (no parentheses) into either a list of ints or a single grouped int.
            # For product (group content), compute and return a single int.
            i, n = 0, len(seg)
            acc: list[int] = []
            num_buf: list[str] = []

            def flush_number():
                nonlocal num_buf, acc
                if num_buf:
                    acc.append(int("".join(num_buf)))
                    num_buf = []

            while i < n:
                ch = seg[i]
                if ch.isdigit():
                    num_buf.append(ch)
                    i += 1
                    continue
                flush_number()
                if ch.isalpha():
                    token = ch
                    # Optional prime suffix X' Y'
                    if i + 1 < n and seg[i + 1] == "'":
                        i += 1  # consume "'"
                    acc.append(resolve_token(token))
                    i += 1
                    continue
                # Unsupported character
                raise RTRStateError(
                    f"Unexpected character='{ch}' in RTR-shape pattern=['{shape_pattern}'].",
                    rtr=self._origin
                )
            flush_number()

            # If called for group content, return product
            prod = 1
            for v in acc:
                prod *= v
            return prod

        i, n = 0, len(s)
        literal_buf: list[str] = []

        def flush_literal_buf():
            nonlocal literal_buf, out
            if not literal_buf:
                return
            # literal_buf contains a sequence like "Bxy" without parentheses -> expand to ints
            seq = "".join(literal_buf)
            # Expand each letter into a separate dimension
            j, m = 0, len(seq)
            num_acc: list[str] = []
            while j < m:
                ch = seq[j]
                if ch.isdigit():
                    num_acc.append(ch)
                    j += 1
                    continue
                if num_acc:
                    out.append(int("".join(num_acc)))
                    num_acc = []
                if ch.isalpha():
                    token = ch
                    if j + 1 < m and seq[j + 1] == "'":
                        j += 1
                    out.append(resolve_token(token))
                    j += 1
                    continue
                raise RTRStateError(
                    f"Unexpected character='{ch}' in RTR-shape pattern=['{shape_pattern}'].",
                    rtr=self._origin
                )
            
            if num_acc:
                out.append(int("".join(num_acc)))
                num_acc = []
            literal_buf = []

        while i < n:
            ch = s[i]
            if ch == '(':
                # Flush any preceding literal tokens
                flush_literal_buf()
                # Find matching ')'
                i += 1
                start = i
                depth = 1
                while i < n and depth > 0:
                    if s[i] == '(':
                        depth += 1
                    elif s[i] == ')':
                        depth -= 1
                        if depth == 0:
                            break
                    i += 1
                if depth != 0:
                    raise RTRStateError(
                        f"Unmatched '(' in RTR-shape pattern=['{shape_pattern}']",
                        rtr=self._origin
                    )
                
                group_content = s[start:i]
                # Evaluate group as product and append as single dimension
                out.append(eval_segment(group_content))
                i += 1  # consume ')'
                continue
            else:
                literal_buf.append(ch)
                i += 1

        # Flush trailing literal tokens
        flush_literal_buf()
        return out

    def check_shape(self, shape_pattern: str, tensor_shape: TensorShape) -> bool:
        """
        Returns True if get_shape(shape_pattern) equals tensor_shape.
        Supports parentheses grouping as in get_shape (e.g., "Bxy(3mc)").
        """
        if not isinstance(shape_pattern, str):
            raise TypeError("rtr_shape must be a string")
        if not isinstance(tensor_shape, list):
            raise TypeError("tensor_shape must be a TensorShape (list[int])")

        try:
            expected_shape = self.get_shape(shape_pattern)
        except Exception as e:
            # If the pattern cannot be parsed, treat as non-match
            return False

        check_ = expected_shape == tensor_shape
        return check_

    def check_channel_dim(self, tensor_shape: TensorShape) -> bool:
        if not isinstance(tensor_shape, list):
            raise TypeError("tensor_shape must be a TensorShape (list[int])")
        
        if not tensor_shape:
            return False
        
        C_dim = tensor_shape[-1]
        C_dim = get_channel_dim(tensor_shape)
        return C_dim in (self.C, 3*self.C, self.c, self.Y, self.Z, self.D)

    # def get_channel_dims(self, tensor_shape: TensorShape) -> List[int]:
    #     if not isinstance(tensor_shape, list):
    #         raise TypeError("tensor_shape must be a TensorShape (list[int])")
        
    #     if not tensor_shape:
    #         return []
        
    #     C_dim = tensor_shape[-1]
        
        
    #     return channel_dims

    pass


@dataclass(frozen=True)
class RTRState:
    """
    Immutable description of one RTR (Reshape -> Transpose -> Reshape) chain.

    A state whose head and tail are the same tensor is a *pointer*
    (see :meth:`is_pointer`); any other state is *live* and refers to nodes
    that exist in the graph. For a live state the ``get_*`` accessors return
    the existing nodes; for a pointer they return node dictionaries derived
    from the nodes of its ``origin`` state.

    Equality and hashing use only the head and tail tensor names.
    """

    head: Tensor
    tail: Tensor

    # shape_in == head.get_shape()
    shape_in: TensorShape
    # reshaped == head("reshaped").get_shape() == ReshapeIN.out_shape
    reshaped: TensorShape
    # perm == head("reshaped").attr["perm"] == Transpose.perm
    perm: Perm
    # shape_out == tail.get_shape() == ReshapeOUT.out_shape
    shape_out: TensorShape

    _subgraph: 'RTRSubgraph' = field(
        default=None, init=True, repr=False, compare=False,
    )
    _origin: Optional['RTRState'] = field(
        default=None, init=True, repr=False, compare=False,
    )

    # Validation caches and role flags. The dataclass is frozen, so they are
    # updated through object.__setattr__.
    _valid_channel_dim: bool = field(
        default=False, init=False, repr=False, compare=False,
    )
    _valid_bulk_dim: bool = field(
        default=False, init=False, repr=False, compare=False,
    )
    _sub_root: bool = field(
        default=False, init=False, repr=False, compare=False,
    )
    _sub_leaf: bool = field(
        default=False, init=False, repr=False, compare=False,
    )

    # ------------------------------------------------------------------
    # Basic properties
    # ------------------------------------------------------------------
    @property
    def key(self) -> str:
        """ key == head.get_name() """
        return self.head.get_name()

    @property
    def inverse_key(self) -> str:
        """ inverse_key == tail.get_name() """
        return self.tail.get_name()

    @property
    def channel_dim(self) -> int:
        """Last dimension of ``reshaped``."""
        return get_channel_dim(self.reshaped)

    @property
    def bulk_dim(self) -> int:
        """Product of all dimensions of ``reshaped`` except the last."""
        return get_bulk_dim(self.reshaped)

    @property
    def dim_size(self) -> int:
        """ dim_size := bulk_dim * channel_dim """
        return int(np.prod(self.reshaped))

    @property
    def transposed(self) -> TensorShape:
        """
        transposed == head("reshaped.Y").get_shape() == Transpose.out_shape
        """
        return perm_helper.permute(self.reshaped, self.perm)

    @property
    def subgraph(self) -> 'RTRSubgraph':
        """RTR subgraph this state belongs to."""
        return self._subgraph

    @property
    def origin(self) -> Optional['RTRState']:
        """State this one was derived from (None when not set)."""
        return self._origin

    def __hash__(self) -> int:
        return hash((self.key, self.inverse_key))

    def __eq__(self, other) -> bool:
        if not isinstance(other, RTRState):
            return False
        return (
            self.key == other.key
            and self.inverse_key == other.inverse_key
        )

    def __str__(self):
        if self.is_pointer():
            return f"RTRPointer({self.head})"
        return f"RTRState({self.head}, {self.tail})"

    def __repr__(self):
        body = (
            f"shape_in={self.shape_in}, "
            f"reshaped={self.reshaped}, "
            f"perm={self.perm}, "
            f"shape_out={self.shape_out}"
        )
        if self.is_pointer():
            return f"RTRState.Pointer(pointer={self.head}, {body})"
        return f"RTRState(head={self.head}, tail={self.tail}, {body})"

    # ------------------------------------------------------------------
    # Shape validation
    # ------------------------------------------------------------------
    def _assert_bulk_dim(self) -> None:
        """Raise ShapeMismatchError unless shape_in, reshaped and shape_out
        share the same bulk dimension."""
        in_bulk = get_bulk_dim(self.shape_in)
        out_bulk = get_bulk_dim(self.shape_out)
        reshaped_bulk = get_bulk_dim(self.reshaped)

        if not check_bulk_dim(self.shape_out, self.shape_in):
            raise ShapeMismatchError(
                f"{self} bulk dim mismatch: "
                f"shape_out={self.shape_out}:{out_bulk}, "
                f"shape_in={self.shape_in}:{in_bulk}"
            )

        if not check_bulk_dim(self.shape_out, self.reshaped):
            raise ShapeMismatchError(
                f"{self} bulk dim mismatch: "
                f"shape_out={self.shape_out}:{out_bulk}, "
                f"reshaped={self.reshaped}:{reshaped_bulk}"
            )

    def check_bulk_dim(self, shape: Optional[TensorShape] = None) -> bool:
        """
        Without *shape*: validate this state's own bulk dimensions (the
        positive result is cached). With *shape*: additionally require that
        *shape* has the same bulk dimension as ``reshaped``.
        """
        if shape and self.check_bulk_dim():
            return check_bulk_dim(self.reshaped, shape)

        # no args == self validation
        if not self._valid_bulk_dim:
            try:
                self._assert_bulk_dim()
                object.__setattr__(self, "_valid_bulk_dim", True)
            except ShapeMismatchError as e:
                logger.debug(e)
                return False

        return True

    def _assert_channel_dim(self) -> None:
        """Raise ShapeMismatchError unless shape_in, reshaped and shape_out
        share the same channel dimension."""
        in_channel = get_channel_dim(self.shape_in)
        out_channel = get_channel_dim(self.shape_out)
        reshaped_channel = get_channel_dim(self.reshaped)

        if not check_channel_dim(self.shape_out, self.shape_in):
            raise ShapeMismatchError(
                f"{self} channel dim mismatch: "
                f"shape_out={self.shape_out}:{out_channel}, "
                f"shape_in={self.shape_in}:{in_channel}"
            )

        if not check_channel_dim(self.shape_out, self.reshaped):
            raise ShapeMismatchError(
                f"{self} channel dim mismatch: "
                f"shape_out={self.shape_out}:{out_channel}, "
                f"reshaped={self.reshaped}:{reshaped_channel}"
            )

    def check_channel_dim(self, shape: Optional[TensorShape] = None) -> bool:
        """
        Without *shape*: validate this state's own channel dimensions (the
        positive result is cached). With *shape*: additionally require that
        *shape* has the same channel dimension as ``reshaped``.
        """
        if shape and self.check_channel_dim():
            return check_channel_dim(self.reshaped, shape)

        # no args == self validation
        if not self._valid_channel_dim:
            try:
                self._assert_channel_dim()
                object.__setattr__(self, "_valid_channel_dim", True)
            except ShapeMismatchError as e:
                # Use the module logger (as check_bulk_dim does), not the
                # root logger.
                logger.debug(e)
                return False

        return True

    def get_shape(self, tail=False) -> TensorShape:
        """Return the shape of the head tensor, or of the tail tensor when
        *tail* is true."""
        if tail:
            return self.tail.get_shape()
        return self.head.get_shape()

    def is_inverse_RTR(self, state: 'RTRState') -> bool:
        """
        True when *state* can act as the inverse of this RTR: the two
        permutations are cancellable and the states are not both top RTRs
        nor both bottom RTRs.
        """
        # if self.shape_in != state.shape_out:
        #     return False
        # if self.shape_out != state.shape_in:
        #     return False
        if not perm_helper.are_cancellable_perms(self.perm, state.perm):
            return False
        # if self.reshaped[:2] != state.reshaped[:2]:
            # return False

        if RTRDimensions.is_top_RTR(self) and RTRDimensions.is_top_RTR(state):
            return False
        if RTRDimensions.is_bottom_RTR(self) and RTRDimensions.is_bottom_RTR(state):
            return False

        return True

    # ------------------------------------------------------------------
    # Pointers
    # ------------------------------------------------------------------
    @classmethod
    def _Pointer(
        cls, pointer: Tensor, *, subgraph: 'RTRSubgraph', origin: 'RTRState', **kwargs
    ) -> 'RTRState':
        """Create a pointer state: head and tail are both *pointer*; the
        remaining fields come from *kwargs*."""
        kwargs['_origin'] = origin
        kwargs['_subgraph'] = subgraph
        return cls(pointer, pointer, **kwargs)

    def is_pointer(self) -> bool:
        """True when head and tail have the same tensor name."""
        return self.key == self.inverse_key

    def is_alive(self) -> bool:
        """True when this state is not a pointer."""
        return not self.is_pointer()

    def get_pointer(self, pointer: Optional[Tensor]=None, head=True, **kwargs):
        """
        Create a pointer with this state's shapes and permutation.

        The pointer tensor is *pointer* when given, otherwise this state's
        head (``head=True``) or tail. ``shape_in``, ``reshaped``, ``perm`` and
        ``shape_out`` may be overridden through *kwargs*. The new pointer's
        origin is this state, or this state's origin when it already is a
        pointer.
        """
        pointer = pointer if pointer else self.head if head else self.tail
        return RTRState._Pointer(
            pointer,
            subgraph=self.subgraph,
            origin=self.origin if self.is_pointer() else self,
            shape_in=kwargs.get("shape_in", self.shape_in),
            reshaped=kwargs.get("reshaped", self.reshaped),
            perm=kwargs.get("perm", self.perm),
            shape_out=kwargs.get("shape_out", self.shape_out),
        )

    def get_inverse_pointer(self, pointer: Optional[Tensor]=None, head=False):
        """
        Create a pointer describing the inverse RTR: input/output shapes are
        swapped, ``reshaped`` becomes this state's transposed shape and the
        permutation is inverted. Defaults to this state's tail tensor.
        """
        pointer = pointer if pointer else self.head if head else self.tail

        return RTRState._Pointer(
            pointer,
            subgraph=self.subgraph,
            origin=self.origin if self.is_pointer() else self,
            shape_in=self.shape_out,
            reshaped=self.transposed,
            perm=perm_helper.get_inverse_perm(self.perm),
            shape_out=self.shape_in,
        )

    @classmethod
    def SubRoot(
        cls, pointer: Union['RTRState', Tensor],
        origin: Optional['RTRState'] = None,
        **kwargs
    ) -> 'RTRState':
        """
        Create a pointer flagged as a subgraph root, either from an existing
        RTRState or from a tensor plus its origin state.

        Raises:
            RTRStateError: if neither form of argument is supplied.
        """
        if isinstance(pointer, RTRState):
            sub_root = pointer.get_pointer(**kwargs)
        elif isinstance(pointer, Tensor) and origin:
            sub_root = RTRState._Pointer(
                pointer,  subgraph=origin.subgraph, origin=origin, **kwargs
            )
        else:
            raise RTRStateError(
                "RTRState.SubRoot() expects a non-null RTRPointer, or a tensor and the origin RTRState",
            )

        object.__setattr__(sub_root, "_sub_root", True)
        return sub_root

    @classmethod
    def SubLeaf(
        cls, pointer: Union[Tensor, 'RTRState'],
        origin: Optional['RTRState'] = None,
        **kwargs
    ) -> 'RTRState':
        """
        Create a pointer flagged as a subgraph leaf, either from an existing
        RTRState or from a tensor plus its origin state.

        Raises:
            RTRStateError: if neither form of argument is supplied.
        """
        if isinstance(pointer, RTRState):
            sub_leaf = pointer.get_pointer(**kwargs)
        elif isinstance(pointer, Tensor) and origin:
            sub_leaf = RTRState._Pointer(
                pointer, subgraph=origin.subgraph, origin=origin, **kwargs
            )
        else:
            raise RTRStateError(
                "RTRState.SubLeaf() expects a non-null RTRPointer, or a tensor and the origin RTRState",
            )

        object.__setattr__(sub_leaf, "_sub_leaf", True)
        return sub_leaf

    def is_sub_root(self) -> bool:
        return self._sub_root

    def is_sub_leaf(self) -> bool:
        return self._sub_leaf

    def as_sub_root(self) -> 'RTRState':
        """Return a sub-root pointer derived from this state."""
        try:
            sub_root = RTRState.SubRoot(self)
        except RTRStateError as err:
            # Re-raise with this state attached to the message.
            raise RTRStateError(err._msg, rtr=self)
        return sub_root

    def as_sub_leaf(self) -> 'RTRState':
        """Return a sub-leaf pointer derived from this state."""
        try:
            sub_leaf = RTRState.SubLeaf(self)
        except RTRStateError as err:
            # Re-raise with this state attached to the message.
            raise RTRStateError(err._msg, rtr=self)
        return sub_leaf

    # ------------------------------------------------------------------
    # Node access / node dictionaries
    # ------------------------------------------------------------------
    def get_Reshape_IN(self) -> Node | NodeDict:
        """Existing Reshape_IN node for a live state, or a new node dict for
        a pointer."""
        if self.is_pointer():
            return self._make_Reshape_IN_dict()

        # if self is a live RTR --> return existing RTR.Reshape_IN node
        return self.tail.get_writer()("data.data").require_node()

    def get_Transpose(self) -> Node | NodeDict:
        """Existing Transpose node for a live state, or a new node dict for
        a pointer."""
        if self.is_pointer():
            return self._make_Transpose_dict()

        # if self is a live RTR --> return existing RTR.Transpose node
        return self.tail.get_writer()("data").require_node()

    def get_Reshape_OUT(self) -> Node | NodeDict:
        """Existing Reshape_OUT node for a live state, or a new node dict for
        a pointer."""
        if self.is_pointer():
            return self._make_Reshape_OUT_dict()

        # if self is a live RTR --> return existing RTR.Reshape_OUT node
        return self.tail.get_writer().require_node()

    def get_RTR_node(self) -> NodeDict:
        """Currently returns an empty dict."""
        return dict()

    def _fetch_RTR_nodes_data(self) -> dict:
        """
        Collect everything needed to materialise this state as nodes: the
        origin's Reshape/Transpose nodes, the names of the new nodes, the RTR
        attributes and the input/intermediate/output tensors.

        Raises:
            RTRStateError: for a pointer without an origin (there are no
                nodes to derive the data from).
        """
        if self.origin is None and self.is_pointer():
            # Falling back to `self` here would make get_Transpose() call
            # back into this method without end.
            raise RTRStateError("RTRPointer has no origin RTRState", rtr=self)
        origin = self.origin if self.origin else self

        input_name = self.head.get_name()
        rtr_name = f"{input_name}_RTR"
        output_name = f"{rtr_name}_out"

        # >>> Get RTR's Nodes data dict >>>
        transpose = origin.get_Transpose()
        reshape_IN = origin.get_Reshape_IN()
        reshape_OUT = origin.get_Reshape_OUT()
        if self.is_inverse_RTR(origin):
            # The inverse RTR uses the origin's reshapes in swapped roles.
            reshape_IN = origin.get_Reshape_OUT()
            reshape_OUT = origin.get_Reshape_IN()

        rtr_attributes = {
            "RTR_subgraph_id": self.subgraph.id,
        }
        if self.is_sub_root():
            rtr_attributes |= {"RTR_subgraph_root": self.subgraph.id}
        if self.is_sub_leaf():
            rtr_attributes |= {"RTR_subgraph_leaf": self.subgraph.id}

        rtr_nodes_data = dict(
            reshape_IN=reshape_IN,
            transpose=transpose,
            reshape_OUT=reshape_OUT,
            reshape_IN_name = f"{rtr_name}_Reshape_IN",
            transpose_name = f"{rtr_name}_Transpose",
            reshape_OUT_name = f"{rtr_name}_Reshape_OUT",
            rtr_node_name = f"{rtr_name}_Pointer",
            rtr_node_optype = f"RTRPointer",
            rtr_node_shape_name = f"{rtr_name}_shape_ini",
            rtr_attributes=rtr_attributes,
        )

        # >>> Get RTR's Tensors data dict >>>
        rtr_input = InputTensor(
            self.head._model_dict,
            self.head._walk_cfg,
            input_name,
            origin_node_name=rtr_nodes_data["reshape_IN_name"],
        )

        rtr_output = OutputTensor(
            self.head._model_dict,
            self.head._walk_cfg,
            output_name,
            origin_node_name=rtr_nodes_data["reshape_OUT_name"],
        )
        rtr_output.set_shape(self.shape_out, origin.output_ndtype)

        reshape_output = Tensor(
            self.head._model_dict,
            self.head._walk_cfg,
            rtr_nodes_data["reshape_IN_name"] + "_output",
            origin_node_name=None,
        )
        reshape_output.set_shape(self.reshaped, origin.reshaped_ndtype)

        transpose_output = Tensor(
            self.head._model_dict,
            self.head._walk_cfg,
            rtr_nodes_data["transpose_name"] + "_output",
            origin_node_name=None,
        )
        transpose_output.set_shape(self.transposed, origin.transposed_ndtype)

        rtr_tensors_data = dict(
            input_name=input_name,
            output_name=output_name,
            rtr_input=rtr_input,
            rtr_output=rtr_output,
            reshaped=reshape_output,
            transposed=transpose_output,
        )

        # >>> Get RTR-State data dict >>>
        rtr_state_data = rtr_nodes_data | rtr_tensors_data
        return rtr_state_data

    def _make_Reshape_IN_dict(self) -> NodeDict:
        """Node dict for the first Reshape of the RTR chain."""
        rtr_nodes_data = self._fetch_RTR_nodes_data()
        reshape_IN = rtr_nodes_data['reshape_IN']
        reshape_IN_name = rtr_nodes_data['reshape_IN_name']
        rtr_attributes = reshape_IN.get_attributes() | rtr_nodes_data['rtr_attributes']

        rtr_input: Tensor = rtr_nodes_data['rtr_input']
        reshaped: Tensor = rtr_nodes_data['reshaped']

        new_reshape_IN: NodeDict = dict(
            new_name=reshape_IN_name,
            type=reshape_IN.get_op_type(),
            domain="ai.onnx.contrib",
            inputs=reshape_IN.get_inputs_dict() | {"data": rtr_input},
            outputs={"reshaped": reshaped},
            attributes=rtr_attributes,
        )
        return new_reshape_IN

    def _make_Transpose_dict(self) -> NodeDict:
        """Node dict for the Transpose of the RTR chain."""
        rtr_nodes_data = self._fetch_RTR_nodes_data()
        transpose = rtr_nodes_data['transpose']
        transpose_name = rtr_nodes_data['transpose_name']
        rtr_attributes = transpose.get_attributes() | rtr_nodes_data['rtr_attributes']

        reshaped: Tensor = rtr_nodes_data['reshaped']
        transposed: Tensor = rtr_nodes_data['transposed']

        new_transpose: NodeDict = dict(
            new_name=transpose_name,
            type=transpose.get_op_type(),
            domain="ai.onnx.contrib",
            inputs={"data": reshaped},
            outputs={f"{TR_OUTPUT}": transposed},
            attributes=rtr_attributes,
        )
        return new_transpose

    def _make_Reshape_OUT_dict(self) -> NodeDict:
        """Node dict for the final Reshape of the RTR chain."""
        rtr_nodes_data = self._fetch_RTR_nodes_data()
        reshape_OUT = rtr_nodes_data['reshape_OUT']
        reshape_OUT_name = rtr_nodes_data['reshape_OUT_name']
        rtr_attributes = reshape_OUT.get_attributes() | rtr_nodes_data['rtr_attributes']

        transposed: Tensor = rtr_nodes_data['transposed']
        rtr_output: Tensor = rtr_nodes_data['rtr_output']

        new_reshape_OUT: NodeDict = dict(
            new_name=reshape_OUT_name,
            type=reshape_OUT.get_op_type(),
            domain="ai.onnx.contrib",
            inputs=reshape_OUT.get_inputs_dict() | {"data": transposed},
            outputs={"reshaped": rtr_output},
            attributes=rtr_attributes,
        )
        return new_reshape_OUT

    def _make_RTR_shape_ini(self, rtr_matcher: Matcher, shape: Optional[TensorShape] = None) -> Initializer:
        """
        Register a shape initializer (default: ``reshaped``) on *rtr_matcher*,
        using the dtype of the origin Reshape_IN's "shape" initializer.
        """
        shape = shape if shape else self.reshaped
        rtr_state_data = self._fetch_RTR_nodes_data()
        rtr_node_shape_name = rtr_state_data["rtr_node_shape_name"]
        # Same fallback as _fetch_RTR_nodes_data: a live state has no origin
        # and is its own source of nodes.
        origin = self.origin if self.origin else self
        rtr_origin_shape_ini: Initializer = origin.get_Reshape_IN()("shape").require_initializer()
        rtr_node_shape_dtype: OnnxDType = rtr_origin_shape_ini.get_dtype_raw()
        rtr_node_shape_ini: Initializer = rtr_matcher.add_initializer(
            rtr_node_shape_name,
            np.array(shape),
            rtr_node_shape_dtype,
        )
        return rtr_node_shape_ini

    def _make_RTR_node_dict(self, rtr_matcher: Matcher) -> NodeDict:
        """
        Node dict for a single fused "RTRPointer" node; its shape initializer
        is registered on *rtr_matcher*.
        """
        rtr_state_data = self._fetch_RTR_nodes_data()

        rtr_input: Tensor = rtr_state_data['rtr_input']
        rtr_output: Tensor = rtr_state_data['rtr_output']

        rtr_node_name = rtr_state_data["rtr_node_name"]
        rtr_node_optype = rtr_state_data["rtr_node_optype"]

        rtr_node_shape_name = rtr_state_data["rtr_node_shape_name"]
        # Same fallback as _fetch_RTR_nodes_data: a live state has no origin
        # and is its own source of nodes.
        origin = self.origin if self.origin else self
        rtr_origin_shape_ini: Initializer = origin.get_Reshape_IN()("shape").require_initializer()
        rtr_node_shape_dtype: OnnxDType = rtr_origin_shape_ini.get_dtype_raw()
        rtr_node_shape_ini: Initializer = rtr_matcher.add_initializer(
            rtr_node_shape_name,
            np.array(self.reshaped),
            rtr_node_shape_dtype,
        )

        rtr_attributes = {
            "allowzero": 0,
            "orig_name": self.head.get_name(),
            "num_of_tensor_inputs": 1,
        } | self.get_info("RTR") | rtr_state_data['rtr_attributes']

        rtr_node: NodeDict = dict(
            new_name=rtr_node_name,
            type=rtr_node_optype,
            domain="ai.onnx.contrib",
            inputs={"data": rtr_input, "shape": rtr_node_shape_ini},
            outputs={"Y": rtr_output},
            attributes=rtr_attributes,
        )
        return rtr_node

    def get_info(self, pfx: Optional[str]=None) -> Dict[str, Any]:
        """Shapes, permutation and channel/bulk dimensions as a dict; keys
        are prefixed with ``"<pfx>_"`` when *pfx* is given."""
        pfx = f"{pfx}_" if pfx else ""
        return {
            f"{pfx}shape_in": self.shape_in,
            f"{pfx}reshaped": self.reshaped,
            f"{pfx}perm": self.perm,
            f"{pfx}transposed": self.transposed,
            f"{pfx}shape_out": self.shape_out,
            f"{pfx}channel_dim": self.channel_dim,
            f"{pfx}bulk_dim": self.bulk_dim,
        }

    # ------------------------------------------------------------------
    # Data types
    # ------------------------------------------------------------------
    @property
    def input_ndtype(self) -> NumpyDType:
        """dtype of the head tensor."""
        return self.head.get_dtype()

    @property
    def reshaped_ndtype(self) -> NumpyDType:
        """dtype of the Transpose input; the head dtype for a pointer."""
        if self.is_pointer():
            return self.input_ndtype
        transpose = self.get_Transpose()
        return transpose("data").require_tensor().get_dtype()

    @property
    def transposed_ndtype(self) -> NumpyDType:
        """dtype of the Transpose output; the tail dtype for a pointer."""
        if self.is_pointer():
            return self.tail.get_dtype()
        transpose = self.get_Transpose()
        return transpose(TR_OUTPUT).require_tensor().get_dtype()

    @property
    def output_ndtype(self) -> NumpyDType:
        """dtype of the tail tensor."""
        return self.tail.get_dtype()

    pass    # end of RTRState


class DeadRTR(NamedTuple):
    """Record of a cancelled RTR: the pointer and the live state it met.

    ``kill_live_rtr`` builds these as ``DeadRTR(ptr=<pointer>, state=<live>)``.
    The "null" record returned by :meth:`NULL` has both fields set to ``None``.
    """

    # Pointer that triggered the cancellation; ``None`` only for the null record.
    ptr: 'Optional[RTRState]'
    # Live RTR state the pointer was cancelled against, if any.
    state: 'Optional[RTRState]' = None

    @staticmethod
    def NULL() -> 'DeadRTR':
        """Return the null record (no pointer, no state)."""
        return DeadRTR(ptr=None, state=None)

    @property
    def is_null(self) -> bool:
        """True when this record carries no pointer."""
        return self.ptr is None


@dataclass
class RTRSubgraph:
    """Bookkeeping for one RTR subgraph discovered by an RTRSubgraphMatcher.

    Holds the member nodes (as a set for membership tests and as a list in
    first-insertion order), the tensors recorded as subgraph inputs/outputs,
    and a back-reference to the owning matcher.
    """
    # Subgraph identifier assigned by the matcher.
    id: int
    # Owning matcher; excluded from repr and equality comparison.
    _rtr_matcher: 'RTRSubgraphMatcher' = field(
        default=None, init=True, repr=False, compare=False,
    )

    nodes: Set[Node] = field(default_factory=set, init=False)
    inputs: Set[InputTensor] = field(default_factory=set, init=False)
    outputs: Set[OutputTensor] = field(default_factory=set, init=False)
    # Same members as `nodes`, kept in the order they were first added.
    nodes_list: List[Node] = field(default_factory=list, init=False)

    @property
    def rtr_matcher(self) -> 'RTRSubgraphMatcher':
        """The matcher this subgraph was created by."""
        return self._rtr_matcher

    def __str__(self):
        counts = (
            f"nodes={len(self.nodes)}, "
            f"inputs={len(self.inputs)}, "
            f"outputs={len(self.outputs)}"
        )
        return f"RTRSubgraph(id={self.id}, {counts})"

    def is_subgraph_input(self, tensor: Tensor) -> bool:
        """Whether `tensor` has been recorded as an input of this subgraph."""
        recorded_inputs = self.inputs
        return tensor in recorded_inputs

    def is_subgraph_output(self, tensor: Tensor) -> bool:
        """Whether `tensor` has been recorded as an output of this subgraph."""
        recorded_outputs = self.outputs
        return tensor in recorded_outputs

    def get_live_RTRs(self) -> list[RTRState]:
        """Return the matcher's registered RTR states that map to this subgraph id."""
        live: list[RTRState] = []
        for candidate in self.rtr_matcher.rtr_states.values():
            owner = self.rtr_matcher.get_RTR_subgraph(candidate)
            if owner.id == self.id:
                live.append(candidate)
        return live

    def add_node(self, node: Node):
        """Add `node` to the subgraph; a node already present is left untouched."""
        if node in self.nodes:
            return
        self.nodes_list.append(node)
        self.nodes.add(node)


class RTRSubgraphMatcher(Matcher, ABC):

    def __init__(self) -> None:
        """Initialise per-match and per-pattern RTR bookkeeping."""
        super().__init__()

        # Per-match state: cleared by _reset() before each match().
        self.rtr_state: Optional[RTRState] = None
        self.subgraph: Optional[RTRSubgraph] = None
        self.dead_rtrs: Set[DeadRTR] = set()

        # Per-pattern state: lives as long as this matcher, cleared by _hard_reset().
        self.rtr_states: Dict[str, RTRState] = dict()        # head.name -> RTRState
        self.rtr_states_tails: Dict[str, RTRState] = dict()  # tail.name -> RTRState
        self.rtr_subgraphs: List[RTRSubgraph] = list()
        self.rtr_subgraph_index: Dict[str, int] = dict()     # rtr_head.name -> rtr_subgraph.id

        # Subgraph ids for which _debug() emits messages.
        self._dbg_subgraphs = list()

    def _reset(self):
        """Clear the per-match state (current RTR, current subgraph, dead RTRs)."""
        super()._reset()
        self.rtr_state: Optional[RTRState] = None
        self.subgraph: Optional[RTRSubgraph] = None
        # A fresh set each time, so records from a previous match are not shared.
        self.dead_rtrs: Set[DeadRTR] = set()

    def _hard_reset(self):
        """Clear the per-match state and all per-pattern RTR registries."""
        self._reset()
        self.rtr_states: Dict[str, RTRState] = dict()
        self.rtr_states_tails: Dict[str, RTRState] = dict()
        self.rtr_subgraphs: List[RTRSubgraph] = list()
        self.rtr_subgraph_index: Dict[str, int] = dict()

    def _info(self, info_msg) -> None:
        """Emit `info_msg` at INFO level on the module logger."""
        logger.info(info_msg)

    def _debug(self, dbg_msg) -> None:
        """Emit `dbg_msg` at DEBUG level, but only for subgraphs under debug.

        Nothing is logged when no subgraph is active or when the active
        subgraph's id is not listed in ``_dbg_subgraphs``.
        """
        active = self.subgraph
        if not active:
            return
        if active.id in self._dbg_subgraphs:
            logger.debug(dbg_msg)

    def _is_debug(self) -> bool:
        """True when at least one subgraph id is registered for debug output."""
        return True if self._dbg_subgraphs else False

    # ========================================================================= #
    # >>>>>              MATCHER :: MATCH()    UTILS                      <<<<< #
    # ========================================================================= #
    def _setup_from_rtr_state(
            self, rtr_state: RTRState, rtr_subgraph: Optional[RTRSubgraph] = None
    ) -> None:
        """Register `rtr_state` as the current RTR and bind it to a subgraph.

        Args:
            rtr_state: the RTR to make current; must be truthy and not yet
                registered under its key.
            rtr_subgraph: subgraph to attach to; when omitted a new
                ``RTRSubgraph`` with id ``len(self.rtr_subgraphs) + 1`` is made.

        Raises:
            RTRStateError: `rtr_state` is falsy.
            RTRMatchingError: an RTR with the same key is already registered.
        """
        if not rtr_state:
            raise RTRStateError(
                "Can't setup RTRMatcher from a NULL RTRState object",
                rtr=rtr_state
            )

        key = rtr_state.key
        if key in self.rtr_states:
            raise RTRMatchingError(f"{rtr_state} already processed", rtr=rtr_state)

        # Index the state both by its key and by its inverse key.
        self.rtr_state = rtr_state
        self.rtr_states[key] = rtr_state
        self.rtr_states_tails[rtr_state.inverse_key] = rtr_state

        # Attach to the supplied subgraph, or open a new one.
        if not rtr_subgraph:
            rtr_subgraph = RTRSubgraph(len(self.rtr_subgraphs) + 1, _rtr_matcher=self)
        self.subgraph = rtr_subgraph

        # The subgraph is appended on every call, including when it was supplied.
        self.rtr_subgraphs.append(self.subgraph)
        self.rtr_subgraph_index[key] = self.subgraph.id

        # Tag the three RTR nodes with the owning subgraph id, one at a time.
        for get_rtr_node in (
            rtr_state.get_Reshape_IN,
            rtr_state.get_Transpose,
            rtr_state.get_Reshape_OUT,
        ):
            get_rtr_node().set_attribute("RTR_subgraph_id", self.subgraph.id)

    def match_RTR_subgraph(self, rtr_state: RTRState) -> None:
        """Grow the current subgraph breadth-first from the RTR's Reshape_IN node.

        For each visited node, legal activation inputs and activation outputs
        are inspected: writers/readers that pass the legality checks are queued
        for visiting, while the remaining tensors are recorded as subgraph
        inputs/outputs and their writer is tagged as root/leaf.
        """
        start = rtr_state.get_Reshape_IN()
        self.subgraph.add_node(start)
        pending = deque([start])

        while pending:
            current = pending.popleft()
            candidates = set()

            # Upward: follow legal activation inputs to their writers.
            for act_in in self.get_legal_activations(current):
                producer = act_in.get_writer().require_node()
                absorbable = (
                    self.has_legal_inputs(producer, bulk=True)
                    and self.is_legal_node(producer)
                )
                if not absorbable:
                    # Boundary: the tensor enters the subgraph from outside.
                    self.subgraph.inputs.add(act_in)
                    producer.set_attribute(
                        "RTR_subgraph_root", self.subgraph.id
                    )
                    continue
                candidates.add(producer)

            # Downward: follow activation outputs to their readers.
            for act_out in current.get_act_outputs():
                producer = act_out.get_writer().require_node()
                passable = (
                    self.is_legal_tensor(act_out, bulk=True)
                    and self.is_legal_node(producer)
                )
                if not passable:
                    # Boundary: the tensor leaves the subgraph here.
                    self.subgraph.outputs.add(act_out)
                    producer.set_attribute(
                        "RTR_subgraph_leaf", self.subgraph.id
                    )
                    continue
                candidates.update(act_out.get_readers())

            # Nodes are added to the subgraph when queued, so each is visited once.
            for candidate in candidates:
                if candidate not in self.subgraph.nodes:
                    self.add_node_to_subgraph(candidate)
                    pending.append(candidate)

    def get_RTR_subgraph(self, rtr: RTRState) -> RTRSubgraph:
        """Return the subgraph that `rtr` was registered under.

        Raises:
            RTRMatchingError: no subgraph id is indexed for ``rtr.key``.
        """
        key = rtr.key

        # 0 doubles as the "missing" marker; ids handed out by this class start at 1.
        subgraph_id: int = self.rtr_subgraph_index.get(key, 0)
        if subgraph_id:
            # Subgraph ids are 1-based positions in `rtr_subgraphs`.
            return self.rtr_subgraphs[subgraph_id - 1]

        raise RTRMatchingError(
            f"RTRSubgraph(id={subgraph_id}) not found for RTRState({key})", 
            rtr=rtr
        )

    @classmethod
    def match_live_rtr_upward(cls, node: Node, rtr_subgraph: RTRSubgraph) -> Optional[RTRState]:
        """Try to match a Reshape-Transpose-Reshape chain ending at `node`.

        `node` must be a Reshape whose ``data`` path leads to a Transpose and
        then to another Reshape. Returns the resulting ``RTRState`` (after
        calling its channel/bulk dimension checks) or ``None`` if the chain is
        not present.
        """
        # Checked in order, each navigation only after the previous check passed.
        if not node.check(fusedOpType.Reshape):
            return None
        if not node("data").require_node().check(fusedOpType.Transpose):
            return None
        if not node("data.data").require_node().check(fusedOpType.Reshape):
            return None

        head_tensor = node("data.data.data").require_tensor()
        tail_tensor = node("reshaped").require_tensor()
        matched = RTRState(
            head=head_tensor,
            tail=tail_tensor,
            shape_in=head_tensor.get_shape(),
            reshaped=node("data.data.reshaped").require_tensor().get_shape(),
            perm=node("data").require_node().get_attribute_value("perm"),
            shape_out=tail_tensor.get_shape(),
            _subgraph=rtr_subgraph,
        )
        matched.check_channel_dim()
        matched.check_bulk_dim()
        return matched

    @classmethod
    def match_live_rtr_downward(cls, node: Node, rtr_subgraph: RTRSubgraph) -> Optional[RTRState]:
        """Try to match a Reshape-Transpose-Reshape chain starting at `node`.

        `node` must be a Reshape whose ``reshaped`` output leads to a Transpose
        and, through its ``TR_OUTPUT`` output, to another Reshape. Returns the
        resulting ``RTRState`` (after calling its channel/bulk dimension
        checks) or ``None`` if the chain is not present.
        """
        # Checked in order, each navigation only after the previous check passed.
        if not node.check(fusedOpType.Reshape):
            return None
        if not node("reshaped").require_node().check(fusedOpType.Transpose):
            return None
        if not node(f"reshaped.{TR_OUTPUT}").require_node().check(fusedOpType.Reshape):
            return None

        head_tensor = node("data").require_tensor()
        tail_tensor = node(f"reshaped.{TR_OUTPUT}.reshaped").require_tensor()
        matched = RTRState(
            head=head_tensor,
            tail=tail_tensor,
            shape_in=head_tensor.get_shape(),
            reshaped=node("reshaped").require_tensor().get_shape(),
            perm=node("reshaped").require_node().get_attribute_value("perm"),
            shape_out=tail_tensor.get_shape(),
            _subgraph=rtr_subgraph,
        )
        matched.check_channel_dim()
        matched.check_bulk_dim()
        return matched

    def match_live_rtr_state(self, node: Node, rtr_subgraph: RTRSubgraph) -> Optional[RTRState]:
        """Match an RTR at `node`, trying downward first and then upward.

        Returns ``None`` when neither direction matches, or when a current
        ``rtr_state`` exists and its channel-dimension check rejects the
        matched RTR's shape.
        """
        matched: Optional[RTRState] = (
            self.match_live_rtr_downward(node, rtr_subgraph)
            or self.match_live_rtr_upward(node, rtr_subgraph)
        )
        if not matched:
            return None

        reference = self.rtr_state
        if reference and not reference.check_channel_dim(matched.get_shape()):
            return None

        return matched

    @classmethod
    def require_live_rtr_upward(
        cls, node: Node, rtr_subgraph: RTRSubgraph, ref_rtr: RTRState = None
    ) -> RTRState:
        """Like ``match_live_rtr_upward`` but raising instead of returning None.

        Raises:
            RTRMatchingError: no upward RTR is matched at `node`, or `ref_rtr`
                is given and its channel-dimension check rejects the match.
        """
        matched = cls.match_live_rtr_upward(node, rtr_subgraph)
        if matched:
            if ref_rtr and not ref_rtr.check_channel_dim(matched.get_shape()):
                raise RTRMatchingError(
                    f"Incompatible channel dimensions with reference {ref_rtr}", rtr=matched
                )
            return matched

        raise RTRMatchingError(
            f"Failed to match upward RTRState at node {node.get_name()}", rtr=None
        )

    @classmethod
    def require_live_rtr_downward(
        cls, node: Node, rtr_subgraph: RTRSubgraph, ref_rtr: RTRState = None
    ) -> RTRState:
        """Like ``match_live_rtr_downward`` but raising instead of returning None.

        Raises:
            RTRMatchingError: no downward RTR is matched at `node`, or `ref_rtr`
                is given and its channel-dimension check rejects the match.
        """
        matched = cls.match_live_rtr_downward(node, rtr_subgraph)
        if matched:
            if ref_rtr and not ref_rtr.check_channel_dim(matched.get_shape()):
                raise RTRMatchingError(
                    f"Incompatible channel dimensions with reference {ref_rtr}", rtr=matched
                )
            return matched

        raise RTRMatchingError(
            f"Failed to match downward RTRState at node {node.get_name()}", rtr=None
        )

    @classmethod
    def require_live_rtr_state(
        cls, node: Node, rtr_subgraph: RTRSubgraph, ref_rtr: RTRState = None
    ) -> RTRState:
        """Match an RTR at `node` (downward first, then upward) or raise.

        Raises:
            RTRMatchingError: neither direction matches, or `ref_rtr` is given
                and its channel-dimension check rejects the match.
        """
        matched: Optional[RTRState] = (
            cls.match_live_rtr_downward(node, rtr_subgraph)
            or cls.match_live_rtr_upward(node, rtr_subgraph)
        )
        if matched:
            if ref_rtr and not ref_rtr.check_channel_dim(matched.get_shape()):
                raise RTRMatchingError(
                    f"Incompatible channel dimensions with reference {ref_rtr}", rtr=matched
                )
            return matched

        raise RTRMatchingError(
            f"Failed to match a live RTRState at node {node.get_name()}", rtr=None
        )

    def is_legal_tensor(self, tensor: Tensor, bulk=False) -> bool:
        """Decide whether `tensor` is compatible with the current RTR.

        A tensor without a shape (``None`` or empty) is never legal. Before any
        RTR is current, every shaped tensor is legal. Otherwise the current
        RTR's channel-dimension check must pass and, when `bulk` is true, its
        bulk-dimension check as well.
        """
        # TODO: check for GraphInput / GraphOutput
        shape = tensor.get_shape()
        if not shape:
            return False

        reference = self.rtr_state
        if not reference:
            return True

        channel_ok = reference.check_channel_dim(shape)
        if not bulk:
            # Channel only: also requiring the bulk check here was too strict,
            # and accepting either check fails.
            return channel_ok

        bulk_ok = reference.check_bulk_dim(shape)
        return channel_ok and bulk_ok

    @classmethod
    def is_illegal_node(cls, node: Node, **kwargs) -> bool:
        """Tell whether `node` belongs to an op family that blocks RTR subgraphs.

        Families checked, in order: Pad (Pad, Pad_qdq), Slice (slice, Slice)
        and Pool. A matching family is illegal by default; passing the keyword
        ``Pad``, ``Slice`` or ``Pool`` overrides the answer returned for that
        family. Nodes outside these families are never illegal.
        """
        families = (
            ("Pad", (opType.Pad, fusedOpType.Pad_qdq)),
            ("Slice", (fusedOpType.slice, fusedOpType.Slice)),
            ("Pool", (fusedOpType.Pool,)),
        )
        for family, checkers in families:
            if any(node.check(checker) for checker in checkers):
                return kwargs.get(family, True)
        return False

    @classmethod
    def is_legal_node(cls, node: Node, **kwargs) -> bool:
        """Negation of ``is_illegal_node`` for the same node and overrides."""
        illegal = cls.is_illegal_node(node, **kwargs)
        return not illegal
    
    def has_legal_inputs(self, node: Node, bulk=False) -> bool:
        """True when every activation input of `node` is a legal tensor.

        A node without activation inputs yields True.
        """
        return all(
            self.is_legal_tensor(act_in, bulk=bulk)
            for act_in in node.get_act_inputs()
        )

    def has_legal_outputs(self, node: Node, bulk=False) -> bool:
        """True when every activation output of `node` is a legal tensor.

        A single illegal output makes the result False; a node without
        activation outputs yields True.
        """
        return all(
            self.is_legal_tensor(act_out, bulk=bulk)
            for act_out in node.get_act_outputs()
        )

    def get_legal_activations(self, node: Node, bulk=False) -> set[InputTensor]:
        """Return the subset of `node`'s activation inputs that are legal tensors."""
        return {
            act_in
            for act_in in node.get_act_inputs()
            if self.is_legal_tensor(act_in, bulk=bulk)
        }

    def _add_RTR_to_subgraph(self, rtr: RTRState, subgraph: RTRSubgraph) -> None:
        """Register `rtr` under `subgraph` unless it is falsy or already known.

        Indexes the RTR by key and inverse key, records the owning subgraph id,
        and tags its three nodes with ``RTR_subgraph_id``.
        """
        # Nothing to do for a missing RTR or one that is already registered.
        if not rtr:
            return
        if rtr.key in self.rtr_states:
            return

        self.rtr_states[rtr.key] = rtr
        self.rtr_states_tails[rtr.inverse_key] = rtr
        self.rtr_subgraph_index[rtr.key] = subgraph.id

        for get_rtr_node in (rtr.get_Reshape_IN, rtr.get_Transpose, rtr.get_Reshape_OUT):
            get_rtr_node().set_attribute("RTR_subgraph_id", subgraph.id)

    def add_node_to_subgraph(self, node: Node) -> None:
        """Add `node` to the current subgraph, registering any RTR found at it.

        Only Reshape nodes are probed for a live RTR; the node itself is always
        tagged with ``RTR_subgraph_id`` and added to the subgraph.
        """
        if node.check(fusedOpType.Reshape):
            found: Optional[RTRState] = self.match_live_rtr_state(node, self.subgraph)
            if found:
                # TODO if node is in rtr chain already, skip check rtr_chain_exists
                self._add_RTR_to_subgraph(found, self.subgraph)

        node.set_attribute("RTR_subgraph_id", self.subgraph.id)
        self.subgraph.add_node(node)
    
    def get_rtr_suppliers(self, ptr: RTRState) -> list[RTRState]:
        """Return pointers to the tensors that supply `ptr`'s head.

        Resolution order:
          1. If the head is a subgraph input, there are no suppliers.
          2. If an RTR is registered in ``rtr_states_tails`` under ``ptr.key``,
             its head is the single supplier.
          3. Otherwise, if the head's writer node is part of the subgraph, the
             writer's activation inputs are the suppliers.
          4. Otherwise there are no suppliers.

        Each supplier tensor is wrapped via ``ptr.get_pointer`` when
        ``ptr.head.check_input_tensor()`` holds, else ``ptr.get_inverse_pointer``.
        """
        if self.subgraph.is_subgraph_input(ptr.head):
            return []

        # The subgraph-input case was handled by the early return above, so it
        # is not repeated in this chain.
        supplier_tensors: list[InputTensor] = []
        if rtr := self.rtr_states_tails.get(ptr.key):
            supplier_tensors = [rtr.head]
        elif (wr := ptr.head.get_writer().require_node()) and wr in self.subgraph.nodes:
            supplier_tensors = wr.get_act_inputs()

        return [
            ptr.get_pointer(s) if ptr.head.check_input_tensor() else ptr.get_inverse_pointer(s)
            for s in supplier_tensors
        ]

    def get_rtr_consumers(self, ptr: RTRState) -> list[RTRState]:
        """Return pointers to the tensors produced by readers of `ptr`'s head.

        Readers outside the current subgraph are ignored. When a reader is the
        Reshape_IN of a registered RTR with the same key as `ptr`, that RTR's
        tail is used; otherwise all of the reader's activation outputs are.
        Each tensor is wrapped via ``ptr.get_pointer`` when
        ``ptr.head.check_output_tensor()`` holds, else ``ptr.get_inverse_pointer``.
        """
        if self.subgraph.is_subgraph_output(ptr.head):
            return []

        consumer_tensors: list[OutputTensor] = []
        for reader in ptr.head.require_tensor().get_readers():
            reader_node = reader.require_node()
            if reader_node not in self.subgraph.nodes:
                continue

            # Registered RTRs are looked up by the name of the reader's first
            # activation input.
            first_act = reader_node.get_act_inputs()[0]
            live = self.rtr_states.get(first_act.get_name())
            enters_live_rtr = (
                live
                and live.key == ptr.key
                and live.get_Reshape_IN() == reader_node
            )
            if enters_live_rtr:
                consumer_tensors.append(live.tail)
            else:
                consumer_tensors.extend(reader_node.get_act_outputs())

        return [
            ptr.get_pointer(c) if ptr.head.check_output_tensor() else ptr.get_inverse_pointer(c)
            for c in consumer_tensors
        ]
    
    
    # ========================================================================= #
    # >>>>>                 MATCHER :: MODIFY()    UTILS                  <<<<< #
    # ========================================================================= #
    def kill_live_rtr(self, ptr: RTRState, dead_rtrs: Set[DeadRTR]) -> tuple[Optional[DeadRTR], list[RTRState]]:
        """Cancel `ptr` against the first live RTR of the subgraph it inverts.

        Live RTRs of the current subgraph are scanned in order. For the first
        one that `ptr` is the inverse of:
          * if ``ptr.key`` equals the live RTR's inverse key, the pair is
            recorded in `dead_rtrs` and no follow-up pointers are returned;
          * if ``ptr.key`` equals the live RTR's key, the pair is recorded and
            the consumers of `ptr` (except the one keyed by the live RTR's
            inverse key) are returned as follow-up pointers.

        Returns:
            ``(dead_record, next_pointers)``, or ``(None, [])`` when nothing
            was cancelled.
        """
        for live in self.subgraph.get_live_RTRs():
            if not ptr.is_inverse_RTR(live):
                continue

            if ptr.key == live.inverse_key:
                self._debug(f"  a) Cancelling RTR UP: \n  ^^ {ptr}\n  with {live}.")

                record = DeadRTR(ptr=ptr, state=live)
                dead_rtrs.add(record)
                return record, []

            if ptr.key == live.key:
                self._debug(f"  b) Cancelling RTR DOWN:\n  vv {ptr} \n  with {live}.")

                record = DeadRTR(ptr=ptr, state=live)
                dead_rtrs.add(record)

                # Continue below the cancelled RTR, skipping its own tail.
                follow_ups: list[RTRState] = [
                    consumer
                    for consumer in self.get_rtr_consumers(ptr)
                    if consumer.key != live.inverse_key
                ]
                return record, follow_ups

        return None, []

    def remove_rtr_state(self, rtr: RTRState) -> None:
        """Bypass `rtr` by calling ``self.connect`` on its tail and head tensors."""
        tail_tensor, head_tensor = rtr.tail, rtr.head
        self.connect(tail_tensor, head_tensor)

    def insert_rtr_pointer(self, ptr: RTRState) -> Node:
        """Insert a single node built from pointer `ptr` and rewire its readers.

        The node description comes from ``ptr._make_RTR_node_dict``. Readers of
        ``ptr.head`` are then updated through ``replace_input`` when the head's
        origin is "nowhere" or is that reader itself.

        Returns:
            The newly added node.
        """
        # Produce & insert the RTR-pointer node.
        node_spec: NodeDict = ptr._make_RTR_node_dict(rtr_matcher=self)
        inserted: Node = self.add_node(**node_spec)

        # Names handed to replace_input, taken from the pointer's Reshape_IN
        # "data" input and Reshape_OUT "reshaped" output.
        data_name = ptr.get_Reshape_IN()["inputs"]["data"].get_name()
        reshaped_name = ptr.get_Reshape_OUT()["outputs"]["reshaped"].get_name()
        readers = ptr.head.get_readers()
        origin = ptr.head.get_origin()

        for reader in readers:
            reader_node = reader.require_node()
            origin_is_nowhere = origin.check_nowhere()
            origin_is_reader = origin.get_name() == reader_node.get_name()
            if not (origin_is_nowhere or origin_is_reader):
                continue
            reader_node._model_dict.replace_input(
                reader_node.get_name(), data_name, reshaped_name,
            )

        return inserted

    def insert_rtr_state(self, ptr: RTRState) -> tuple[Node, Node, Node]:
        """Insert the three RTR nodes described by pointer `ptr` and rewire readers.

        Node descriptions come from ``ptr.get_Reshape_IN()``,
        ``ptr.get_Transpose()`` and ``ptr.get_Reshape_OUT()``, added in that
        order. Readers of ``ptr.head`` are then updated through
        ``replace_input`` when the head's origin is "nowhere" or is that reader
        itself.

        Returns:
            ``(reshape_in, transpose, reshape_out)`` as added nodes.
        """
        # Produce & insert the RTR nodes, one after another.
        reshape_in_spec: NodeDict = ptr.get_Reshape_IN()
        reshape_in: Node = self.add_node(**reshape_in_spec)

        transpose_spec: NodeDict = ptr.get_Transpose()
        transpose: Node = self.add_node(**transpose_spec)

        reshape_out_spec: NodeDict = ptr.get_Reshape_OUT()
        reshape_out: Node = self.add_node(**reshape_out_spec)

        # Names handed to replace_input, taken from the inserted nodes.
        data_name = reshape_in("data").get_name()
        reshaped_name = reshape_out("reshaped").get_name()
        readers = ptr.head.get_readers()
        origin = ptr.head.get_origin()

        for reader in readers:
            reader_node = reader.require_node()
            origin_is_nowhere = origin.check_nowhere()
            origin_is_reader = origin.get_name() == reader_node.get_name()
            if not (origin_is_nowhere or origin_is_reader):
                continue
            reader_node._model_dict.replace_input(
                reader_node.get_name(), data_name, reshaped_name,
            )

        return reshape_in, transpose, reshape_out

    pass # end of RTRSubgraphMatcher


class MHA_RTRCancellation(RTRSubgraphMatcher, ReshapeTransposeHelper):
    dependencies = [RTROptimize()]


    def __init__(self) -> None:
        """Initialise the MHA-specific matching state on top of the base matcher."""
        super().__init__()
        # Per-match state, cleared again in _reset().
        self.subgraph_transposes: List[Node] = []
        self.top_rtr: Optional[RTRState] = None
        self.rtr_dims = None
        self.channel_split_reshape = None
        # Subgraph ids registered for debug output.
        self._dbg_subgraphs = [1]
    
    def _reset(self):
        """Clear the base per-match state and the MHA-specific per-match fields."""
        super()._reset()
        self.subgraph_transposes: List[Node] = []
        self.top_rtr = None
        self.rtr_dims = None
        self.channel_split_reshape = None

    @override
    def _info(self, info_msg) -> None:
        """Log `info_msg` at INFO level, prefixed with the current subgraph id."""
        super()._info(f"[MHA_RTR {self.subgraph.id}] {info_msg}")

    @override
    def _debug(self, dbg_msg) -> None:
        """Log `dbg_msg` at DEBUG level, prefixed with the current subgraph id.

        NOTE(review): unlike the base implementation this logs for every
        subgraph, regardless of ``_dbg_subgraphs`` — confirm that is intended.
        """
        logger.debug(f"[MHA_RTR {self.subgraph.id}] {dbg_msg}")

    @classmethod
    def has_MMT_support(cls) -> bool:
        """Check that the MatMulTranspose configuration allows the MHA-RTR pattern.

        Requires ``MAX_RANK >= 4`` and both ``ENABLE_INNERMOST_DIM_FUSION`` and
        ``ENABLE_ACTxWGT_FUSION`` to be truthy.
        """
        cfg = MatMulTranspose.get_configs()
        supported = (
            cfg.MAX_RANK >= 4
            and cfg.ENABLE_INNERMOST_DIM_FUSION
            and cfg.ENABLE_ACTxWGT_FUSION
        )
        return bool(supported)

    @override
    def match(self) -> None:
        """Match the MHA pattern above a Bottom-RTR and collect its subgraph.

        Steps, as performed below:
          1. Require MatMulTranspose support and a Reshape at the current node,
             then match an upward RTR ending at that node (the Bottom-RTR).
          2. Walk upward from the Bottom-RTR head through
             MatMul -> Reshape -> Transpose(perm=[0, 2, 1, 3]) -> MatMul(actxact=1),
             deriving the RTR dimensions 'm' (channel split) and 'p' (padding).
          3. Collect the rest of the subgraph with ``match_RTR_subgraph``.

        Raises:
            NoMatch: the MatMulTranspose configuration lacks the needed support.
            RTRMatchingError: the channel-split Reshape does not yield a valid
                split factor (the ``require*`` helpers may raise as well).
        """
        if not self.has_MMT_support():
            raise NoMatch(f"MatMulTranspose configs do not support MHA-RTR pattern.")

        n = self.n.with_walk_cfg(WalkCfgPlain())
        n = n.require(opType.Reshape | fusedOpType.Reshape).require_node()

        self.subgraph = RTRSubgraph(len(self.rtr_subgraphs) + 1, _rtr_matcher=self)

        # 1) Match Bottom-RTR at its Reshape-OUT node
        self.rtr_state: RTRState = self.require_live_rtr_upward(n, self.subgraph)
        RTRDimensions.require_shuffle_RTR(self.rtr_state) 
        RTRDimensions.require_bottom_RTR(self.rtr_state) 

        self._debug(f"=========  RTR-MHA-RTR  =========")
        self._debug(f"Matched Bottom-RTR: {self.rtr_state} ")
        
        # NOTE(review): no subgraph is passed here, so _setup_from_rtr_state
        # creates a new RTRSubgraph (same id) and rebinds self.subgraph; the
        # RTRState matched above was given the earlier instance — confirm
        # this is intended.
        self._setup_from_rtr_state(self.rtr_state)
        self.rtr_state.get_Reshape_OUT().set_attribute(
            "RTR_subgraph_leaf", self.subgraph.id
        )
        self.subgraph.outputs.add(self.rtr_state.tail)

        # 2) Match MM+Re+MM above bottom RTR to find channel split factor 'm' and padding factor 'p'
        head = self.rtr_state.head                                          # head --> Reshape_IN.data
        head = head.require_node().require(fusedOpType.MatMul)              # head --> MatMul
        self.subgraph.add_node(head.require_node())

        head = head('A').require_shape(self.rtr_state.shape_in)             # head --> MatMul.A
        head = head.require_node().require(fusedOpType.Reshape)             # head --> Reshape: [BC] -> [XYmc]
        self.channel_split_reshape = head.require_node()
        self.subgraph.add_node(head.require_node())
        
        # Getting channel splitting factor 'm'
        channel_split_shape = head('data').get_shape()
        is_valid_diff, in_mask, out_mask = self.reshape_diff(self.rtr_state.shape_in, channel_split_shape)
        if not is_valid_diff or len(in_mask[-1]) != 1:
            raise RTRMatchingError(
                f"Failed to match MHA-RTR channel-split factor (m) from Reshape({head.get_name()}) above Bottom-RTR {self.rtr_state}."
            )
        
        # 'm' is the second-to-last dimension of the Reshape's input shape.
        channel_split_factor = channel_split_shape[-2]
        self.rtr_dims = RTRDimensions.from_RTR(self.rtr_state, m=channel_split_factor)
        # NOTE(review): assert statements are stripped under `python -O`.
        assert channel_split_shape == self.rtr_dims.get_shape('XYmc')
        self._debug(f"> Matched RTR-dimension 'm={self.rtr_dims.m}' @{head}: {self.rtr_dims}")

        # passing Transpose
        head = head('data').require_node().require(fusedOpType.Transpose)       # head --> Transpose
        head = head.require(AttrValue("perm", [0, 2, 1, 3]))                    # head --> Transpose(0,2,1,3) -> [XYmc]
        self.subgraph.add_node(head.require_node())


        # Getting padding factor 'p'
        head = head('data').require_shape(self.rtr_dims.get_shape("XmYc"))      # head --> [XmYc]
        head = head.require_node().require(fusedOpType.MatMul)                  # head --> MatMul
        head = head.require(AttrValue("actxact", 1))                            # head --> MatMul_actxact: [XmYZ]*[XmZc] -> [X]
        # self.subgraph.nodes.add(head.require_node())

        # 'p' is the integer square root (truncated) of last-dim(A) // Y.
        padded_input_shape_A = head('A').get_shape()
        padding_factor = int((padded_input_shape_A[-1] // self.rtr_dims.Y) ** 0.5)
        self.rtr_dims.p = padding_factor

        self._debug(f"> Matched RTR-dimension 'Z={self.rtr_dims.Z}' @{head}: {self.rtr_dims}")
        self._debug(f"\t\t MHA(XmYZ{self.rtr_dims['XmYZ']} x XmZc{self.rtr_dims['XmZc']}) --> RTR(XmYc{self.rtr_dims['XmYc']}) --> Bottom-RTR(XYmc{self.rtr_dims['XYmc']} ->{self.rtr_dims['xxyyC']}->{self.rtr_dims['xyxyC']}) --> {self.rtr_dims['1WWC']}")
        self._debug(f"\t\t Bottom-RTR[XYmc] = {self.rtr_dims['XYmc']}")
        assert head('A').get_shape() == self.rtr_dims.get_shape('XmYZ'), f"shape mismatch {head('A')}: {head('A').get_shape()} != XmYZ == {self.rtr_dims.get_shape('XmYZ')}"
        assert head('B').get_shape() == self.rtr_dims.get_shape('XmZc'), f"shape mismatch {head('B')}: {head('B').get_shape()} != XmZc == {self.rtr_dims.get_shape('XmZc')}"

        # head --> Reshape: [BC] -> [XYmc]
        # head(data) --> MatMul 
        head_ptr = self.rtr_state.get_pointer(head('Y'))
        # NOTE(review): any exception from the subgraph walk is only logged, so
        # match() still completes with a possibly partial subgraph — confirm
        # this best-effort behaviour is wanted.
        try:
            self.match_RTR_subgraph(head_ptr)
        except Exception as e:
            logger.error(f"Error matching RTR-subgraph: {e}")

        self._debug(f"RTRSubgraph={repr(self.subgraph)}")
        self._debug("-" * 100)
        return
    
    @override
    def match_RTR_subgraph(self, ptr: RTRState) -> None:
        """Collect the MHA subgraph by walking upward from `ptr`'s head node.

        The walk is breadth-first over writers of legal activation inputs:
          * A Reshape that ends an upward RTR which is the inverse of the
            Bottom-RTR stops the walk on that path; the first such RTR is
            registered as the Top-RTR and its head becomes a subgraph input.
          * Every other visited node is tagged with ``RTR_subgraph_id`` and
            added to the subgraph; Transposes are also collected in
            ``subgraph_transposes``.
          * A writer is followed when all its activation inputs are legal and
            it is a legal node (Pool ops allowed), or when it is a MatMul and
            ``rtr_dims.k`` is still 1 (in which case ``k`` is derived from it).
            Otherwise the tensor is recorded as a subgraph input and the
            writer is tagged ``RTR_subgraph_root``.
        """
        head_node = ptr.head.require_node()
        self.subgraph.add_node(head_node)
        queue = deque([head_node])
        while queue:
            node = queue.popleft()
            neighbor_nodes = set()

            if node.check(fusedOpType.Reshape):
                rtr = self.match_live_rtr_upward(node, self.subgraph)
                if rtr and self.rtr_state.is_inverse_RTR(rtr):
                    # Only the first inverse RTR encountered becomes the Top-RTR;
                    # either way the walk does not continue through this node.
                    if not self.top_rtr:
                        self._add_RTR_to_subgraph(rtr, self.subgraph)
                        self.subgraph.inputs.add(rtr.head)
                        rtr.get_Reshape_IN().set_attribute(
                            "RTR_subgraph_root", self.subgraph.id
                        )
                        self.top_rtr = rtr
                        self._debug(f"Matched Top-RTR: {self.top_rtr} ")
                    continue

            node.set_attribute("RTR_subgraph_id", self.subgraph.id)
            self.subgraph.add_node(node)

            if node.check(fusedOpType.Transpose):
                self.subgraph_transposes.append(node)

            for input_tensor in self.get_legal_activations(node):
                writer = input_tensor.get_writer().require_node()
                if self.has_legal_inputs(writer) and self.is_legal_node(writer, Pool=False):
                    neighbor_nodes.add(writer)
                elif writer.check(fusedOpType.MatMul) and self.rtr_dims.k == 1:
                    # k = C divided by the channel (last) dim of the MatMul's A input.
                    self.rtr_dims.k = self.rtr_dims.C // get_channel_dim(writer("A").get_shape())
                    self._debug(f"> Matched RTR-dimension 'D={self.rtr_dims.D}' @{writer}: {self.rtr_dims}")
                    self._debug(f"> \t '@.A'={writer('A').get_shape()}")
                    neighbor_nodes.add(writer)
                else:
                    self.subgraph.inputs.add(input_tensor)
                    writer.set_attribute(
                        "RTR_subgraph_root", self.subgraph.id
                    )

            # NOTE(review): nodes join `subgraph.nodes` only when dequeued, so a
            # node reachable from several consumers can be queued (and visited)
            # more than once — confirm whether that is acceptable.
            for neighbor in neighbor_nodes:
                if neighbor in self.subgraph.nodes:
                    continue
                queue.append(neighbor)
        return

    @override
    def is_legal_tensor(self, tensor: Tensor, bulk=False) -> bool:
        """Decide whether `tensor` may belong to the MHA-RTR subgraph.

        A tensor without a shape is never legal; before a Bottom-RTR is matched
        every shaped tensor is legal. The head and tail of a matched Top-RTR
        are always illegal. Otherwise the channel check of ``rtr_dims`` must
        pass and, when `bulk` is true, the Bottom-RTR's bulk check as well.
        """
        # TODO: check for GraphInput / GraphOutput
        shape = tensor.get_shape()
        if not shape:
            return False

        if not self.rtr_state:
            return True

        # The walk must not cross the Top-RTR boundary tensors.
        top = self.top_rtr
        if top and (tensor == top.tail or tensor == top.head):
            return False

        channel_ok = self.rtr_dims.check_channel_dim(shape)
        if not bulk:
            return channel_ok

        bulk_ok = self.rtr_state.check_bulk_dim(shape)
        return channel_ok and bulk_ok

    def get_origin_rtr(self) -> Optional[str]:
        """Name of the Bottom-RTR's Reshape_OUT node, or ``None`` if none is matched.

        Returns:
            ``self.rtr_state.get_Reshape_OUT().get_name()`` when ``rtr_state``
            is truthy, otherwise ``None``.
        """
        if not self.rtr_state:
            return None
        return self.rtr_state.get_Reshape_OUT().get_name()

    @override
    def modify(self) -> None:
        """Rewrite the matched MHA-RTR subgraph.

        Does nothing unless the subgraph holds exactly two live RTRs. Otherwise
        it modifies the Top-RTR, then visits the subgraph nodes in insertion
        order (skipping nodes no longer present in the model dict) to set the
        ``RTR_*`` attributes and rewrite Squeeze, MaxPool and Transpose nodes,
        and finally modifies the Bottom-RTR.
        """
        live_rtrs = self.subgraph.get_live_RTRs()
        self._debug(f"### {self.subgraph} has {len(live_rtrs)} live RTRs before cancellation.")
        # Exactly two live RTRs are required; any other count leaves the graph untouched.
        if len(live_rtrs) != 2:
            return

        # Remove Top-RTR, replace with modified Top-RTR.Reshape_IN
        self._debug(f"Modifying Top-RTR...")
        self._modify_top_RTR()

        # Update nodes in rtr_subgraph
        for node in self.subgraph.nodes_list:
            # Nodes missing from the model dict at this point are skipped.
            if not self._get_model_dict().has_node(node.get_name()):
                self._debug(f"\t ... skipping node {node}")
                continue

            self._debug(f"Modifying {node} in RTR-subgraph {self.subgraph.id}...")

            # Set RTR attributes
            node.set_attribute("RTR_origin", self.get_origin_rtr())
            # The bulk dim of the node's first output selects the RTR_shape label:
            # A or A*m -> "xzxzmc", B or B*m -> "xyxymc", anything else -> "NULL".
            A_bulk = (self.rtr_dims.A, self.rtr_dims.A * self.rtr_dims.m)
            B_bulk = (self.rtr_dims.B, self.rtr_dims.B * self.rtr_dims.m)
            if get_bulk_dim(node.get_outputs()[0].get_shape()) in A_bulk:
                node.setdefault_attribute("RTR_shape", str(self.rtr_dims["xzxzmc"]))
            elif get_bulk_dim(node.get_outputs()[0].get_shape()) in B_bulk:
                node.setdefault_attribute("RTR_shape", str(self.rtr_dims["xyxymc"]))
            else:
                node.setdefault_attribute("RTR_shape", "NULL")

            # Modify:  Reshape + Squeeze 
            if node.check(fusedOpType.Squeeze):
                # Reshape the tensor two "data" hops above the Squeeze to "xZxC"
                # (dtype kept), then connect the Squeeze output with it.
                node("data.data").set_shape(
                    self.rtr_dims["xZxC"], node("data.data").require_tensor().get_dtype()
                )
                self.connect(
                    node("squeezed").require_tensor(), node("data.data").require_tensor()
                )
                pass
            
            # Modify:  Reshape + MaxPool 
            if node.check(fusedOpType.MaxPool):
                self._modify_pool(node)
                pass
            
            # Modify:  pre-MHA Transposes  &  post-MHA Transpose + Reshape
            if node.check(fusedOpType.Transpose):
                # Classify the Transpose by the shape of its "data" input.
                is_XZmc = self.rtr_dims.check_shape("XZmc", node("data").get_shape())
                is_XmZc = self.rtr_dims.check_shape("XmZc", node("data").get_shape())
                is_XYmc = self.rtr_dims.check_shape("XYmc", node("data").get_shape())
                is_XmYc = self.rtr_dims.check_shape("XmYc", node("data").get_shape())
                is_AC = self.rtr_dims.check_shape("AC", node("data").get_shape())

                # case 1 @ MatMul.B
                # NOTE(review): RTR_perm is a tuple here but a list in the other
                # cases — confirm the attribute consumer accepts both.
                if node.check(AttrValue("perm", [0, 2, 3, 1])) and (is_XZmc or is_AC):
                    node.set_attribute("RTR_perm", (0, 2, 4, 5, 1, 3))
                    node.set_attribute("RTR_transposed", self.rtr_dims["xxmczz"])
                    self._modify_transpose_1(node)

                # case 2 @ MatMul_1.B
                elif node.check(AttrValue("perm", [0, 2, 1, 3])) and (is_XZmc or is_AC):
                    node.set_attribute("RTR_perm", [0, 2, 4, 1, 3, 5])
                    node.set_attribute("RTR_transposed", self.rtr_dims["xxmzzc"])
                    self._modify_transpose_2(node)

                # case 3 @ MatMul.A
                elif node.check(AttrValue("perm", [0, 2, 1, 3])) and is_XYmc:
                    self._debug(f"\t ... modify transpose #3 @ {node}")
                    node.set_attribute("RTR_perm", [0, 1, 4, 2, 3, 5])
                    node.set_attribute("RTR_shape", str(self.rtr_dims["xxyymc"]))
                    node.set_attribute("RTR_transposed", self.rtr_dims["xxmyyc"])
                    
                    # The node feeding this Transpose gets the same RTR_shape.
                    re = node("data").require_node()
                    re.set_attribute("RTR_shape", str(self.rtr_dims["xxyymc"]))

                # case 4 @ MatMul_1.out
                elif node.check(AttrValue("perm", [0, 2, 1, 3])) and is_XmYc:
                    node.set_attribute("RTR_perm", [0, 3, 1, 4, 2, 5])
                    node.set_attribute("RTR_transposed", self.rtr_dims["xyxymc"])
                    self._modify_transpose_post_MHA(node)

                else:
                    pass

        # Remove Bottom-RTR, replace with modified Bottom-RTR.Reshape_OUT
        self._debug(f"Modifying Bottom-RTR...")
        self._modify_bottom_RTR()

        # self.hook.debug_hook("MHA-RTR")
        logger.debug((f"=" * 120) + "\n")
        return
    
    def _modify_top_RTR(self) -> None:
        """
        Replace the Top-RTR (Reshape_IN -> Transpose -> Reshape_OUT) with a
        rebuilt chain that follows this subgraph's RTR dimensions:
            head --(Re)-> [VxzD] --(Tr:0213)-> [VzxD] --(Re)-> [xZxD] == tail
        ``head`` / ``tail`` are ``self.top_rtr.head`` / ``self.top_rtr.tail``;
        the bracketed names are keys resolved through ``self.rtr_dims``.

        The original three nodes are removed; the new Transpose and Reshape_OUT
        inherit the attributes of the nodes they replace and are tagged with
        additional ``RTR_*`` attributes.
        """
        # Snapshot names, attributes and the output dtype of the original
        # Top-RTR nodes before those nodes are removed further down.
        rtr_name = self.top_rtr.get_Reshape_IN().get_name()
        top_transpose = self.top_rtr.get_Transpose()
        top_transpose_name = self.top_rtr.get_Transpose().get_name()
        top_reshape_OUT = self.top_rtr.get_Reshape_OUT()
        top_transpose_attributes = top_transpose.get_attributes()
        top_reshape_OUT_attributes = top_reshape_OUT.get_attributes()
        # dtype of the original Reshape_OUT output; reused for every tensor created here
        tr_output_dtype = top_reshape_OUT("reshaped").require_tensor().get_dtype()

        # Reshape_IN   : head --(Re)-> [VxzD]
        # The node dict and its shape initializer are produced by helpers on the
        # Top-RTR pointer object.
        top_Reshape_IN: NodeDict = self.top_rtr.get_pointer()._make_Reshape_IN_dict()
        top_rtr_shape_ini: Initializer = self.top_rtr.get_pointer()._make_RTR_shape_ini(
            # self, shape=self.rtr_dims["BC"]
            self, shape=self.rtr_dims["VxzD"]
        )
        top_Reshape_IN["inputs"] = {
            "data": self.top_rtr.head, 
            "shape": top_rtr_shape_ini
        }
        # New intermediate tensor. Note that model dict and walk config are taken
        # from self.rtr_state.head, not from self.top_rtr.head
        # (presumably both belong to the same model -- TODO confirm).
        reshape_1_output = Tensor(
            self.rtr_state.head._model_dict,
            self.rtr_state.head._walk_cfg,
            f"{rtr_name}_TOP_RTR_Reshape_IN_output",
            origin_node_name=None,
        )
        reshape_1_output.set_shape(self.rtr_dims["VxzD"], tr_output_dtype)
        top_Reshape_IN['outputs']['reshaped'] = reshape_1_output
        # top_Reshape_IN['outputs']['reshaped'] = self.top_rtr.tail
        # NOTE(review): "RTR_shape" is stored un-stringified here, while the
        # Transpose / Reshape_OUT below store str(...) -- confirm which form
        # downstream consumers expect.
        top_Reshape_IN['attributes'] |= {
            "RTR_shape": self.rtr_dims["xyxymc"],
            "RTR_origin": self.get_origin_rtr(),
        }

        # remove TOP-RTR
        self.remove_node(self.top_rtr.get_Reshape_IN())
        self.remove_node(self.top_rtr.get_Transpose())
        self.remove_node(self.top_rtr.get_Reshape_OUT())

        # insert optimized chain
        top_Reshape_IN: Node = self.add_node(**top_Reshape_IN)

        # Transpose    : [VxzD] --(Tr:0213)-> [VzxD]
        transpose_1_output = Tensor(
            self.rtr_state.head._model_dict,
            self.rtr_state.head._walk_cfg,
            f"{rtr_name}_TOP_RTR_Transpose_output",
            origin_node_name=None,
        )
        transpose_1_output.set_shape(self.rtr_dims["VzxD"], tr_output_dtype)

        # The right-hand dict wins on key clashes, so the RTR_* / perm / orig_name
        # values below override anything inherited from the original Transpose.
        # NOTE(review): "RTR_perm" lists six axes while the "xzxzD" / "xzzxD"
        # keys have five letters -- confirm the intended rank.
        top_Transpose: Node = self.add_node(
            new_name=f"{rtr_name}_TOP_Transpose",
            type="Transpose",
            domain="ai.onnx.contrib",
            inputs= {"data": reshape_1_output},
            outputs={"transposed": transpose_1_output},
            attributes=top_transpose_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xzxzD"]),
                "RTR_transposed": self.rtr_dims["xzzxD"],
                "RTR_perm": [0, 1, 3, 2, 4, 5],
                "perm": [0, 2, 1, 3],
                "orig_name": f"{top_transpose_name}",
            },
        )

        # Reshape_OUT  : [VzxD] --(Re)-> [xZxD]
        # The existing tail tensor is kept as the chain output; only its shape
        # is updated to the new layout.
        reshape_2_ini = self.add_initializer(
            f"{rtr_name}_TOP_RTR_Reshape_OUT_shape",
            np.array(self.rtr_dims["xZxD"], dtype=np.int32),
        )
        self.top_rtr.tail.set_shape(self.rtr_dims["xZxD"], tr_output_dtype)

        top_Reshape_OUT: Node = self.add_node(
            new_name=f"{rtr_name}_TOP_RTR_Reshape_OUT",
            type="Reshape",
            domain="ai.onnx.contrib",
            inputs= {"data": transpose_1_output, "shape": reshape_2_ini},
            outputs={"reshaped": self.top_rtr.tail},
            attributes=top_reshape_OUT_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xzzxD"]),
            },
        )
        return

    def _modify_bottom_RTR(self) -> None:
        """
        Replace the Bottom-RTR (``self.rtr_state``) with a single Transpose
        followed by a rebuilt Reshape_OUT:
            head [WyxC] --(Tr:0213)-> [WxyC] --(Re)-> tail
        The bracketed names are keys resolved through ``self.rtr_dims``.
        The head tensor's shape is overwritten with "WyxC"; the tail tensor is
        reused unchanged as the output of the new Reshape.
        """
        # Snapshot name, boundary tensors, dtypes and Transpose attributes
        # before the original nodes are removed.
        rtr_name = self.rtr_state.get_Reshape_OUT().get_name()
        rtr_input = self.rtr_state.head
        tr_output = self.rtr_state.tail
        rtr_input_dtype = rtr_input.require_tensor().get_dtype()
        tr_output_dtype = tr_output.require_tensor().get_dtype()
        transpose_attributes = self.rtr_state.get_Transpose().get_attributes()

        # The Bottom-RTR input is re-declared in the [WyxC] layout.
        rtr_input.set_shape(self.rtr_dims["WyxC"], rtr_input_dtype)
        
        # Bottom Reshape_OUT   : [WxyC] --(Re)-> [1WWC]
        # The node dict is built while the original Reshape_OUT still exists.
        # NOTE(review): "RTR_shape" is stored un-stringified here, while the
        # Transpose below stores str(...) -- confirm which form is expected.
        bottom_Reshape_OUT: NodeDict = self.rtr_state.get_pointer()._make_Reshape_OUT_dict()
        bottom_Reshape_OUT['attributes'] |= {
            "RTR_shape": self.rtr_dims["xyxymc"],
            "RTR_origin": self.get_origin_rtr(),
        }

        # Remove the original Bottom-RTR nodes.
        self.remove_node(self.rtr_state.get_Reshape_IN())
        self.remove_node(self.rtr_state.get_Transpose())
        self.remove_node(self.rtr_state.get_Reshape_OUT())

        # Transpose_3   : [WyxC] --(Tr:0213)-> [WxyC]
        transpose_3_output = Tensor(
            self.rtr_state.head._model_dict,
            self.rtr_state.head._walk_cfg,
            f"{rtr_name}_RTR_Transpose_3_output",
            origin_node_name=None,
        )
        transpose_3_output.set_shape(self.rtr_dims["WxyC"], tr_output_dtype)

        # The right-hand dict wins on key clashes with the inherited attributes.
        # NOTE(review): "RTR_origin" is rtr_name here, whereas the Reshape_OUT
        # above uses self.get_origin_rtr() -- confirm both are intended.
        transpose_3: Node = self.add_node(
            new_name=f"{rtr_name}_RTR_Transpose_3",
            type="Transpose",
            domain="ai.onnx.contrib",
            inputs= {"data": rtr_input},
            outputs={"transposed": transpose_3_output},
            # outputs={"transposed": tr_output},
            attributes=transpose_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": rtr_name,
                "RTR_shape": str(self.rtr_dims["xyyxmc"]),
                "RTR_transposed": self.rtr_dims["xyxymc"],
                "RTR_perm": [0, 1, 3, 2, 4, 5],
                "perm": [0, 2, 1, 3],  
                "orig_name": f"{rtr_name}_RTR_Transpose_3",
            },
        )

        # Bottom_RTR Reshape  : [WxyC] --(Re)-> [1WWC]
        # Rewire the prepared Reshape_OUT dict: it now consumes the new
        # Transpose output and writes into the original tail tensor.
        bottom_Reshape_OUT["inputs"] |= {
            "data": transpose_3_output,
        }
        bottom_Reshape_OUT['outputs']['reshaped'] = tr_output
        bottom_Reshape_OUT: Node = self.add_node(**bottom_Reshape_OUT)
        
        return

    def _modify_pool(self, pool: Node) -> None:
        """
        Rewrite the Reshape feeding a Pool node (e.g. MaxPool) so that it
        consumes the RTR layout.
            orig: [AC] -> Re -> [XzzC] ->  Pool -> [XyyC]
            mod:  [xZxC] -> Tr -> Re -> [XzzC] ->  Pool -> [XyyC]
        replaced with
            [xZxC] --(Tr:0213)-> [xxZC] --(Re)-> [XzzC] --(Pool)-> [XyyC]

        Does nothing when ``pool`` is not a fused Pool node or when its "X"
        input does not match the "XzzC" shape of ``self.rtr_dims``.
        The producer of "X" is obtained via ``require(fusedOpType.Reshape)``
        (presumably raising when it is not a fused Reshape -- TODO confirm).

        Args:
            pool: the Pool node whose input Reshape is replaced.
        """
        # TODO: use category check for Pool matcher
        if not pool.check(fusedOpType.Pool):
            return
        
        pool_name = pool.get_name()
        pool_input = pool("X")
        is_XzzC = self.rtr_dims.check_shape("XzzC", pool_input.get_shape())
        if not is_XzzC:
            return
        
        self._debug(f"\t ... modify MaxPool #1: {pool}")
        
        # Snapshot the Reshape's input tensor, dtype and attributes before the
        # node is removed.
        reshape = pool("X").require_node().require(fusedOpType.Reshape)
        reshape_input = reshape("data").require_tensor()
        reshape_input_dtype = reshape_input.get_dtype()
        reshape_attributes = reshape.get_attributes() | {
            "RTR_subgraph_id": self.subgraph.id,
            "RTR_origin": self.get_origin_rtr(),
        }

        # remove existing node
        self.remove_node(reshape)

        # Transpose_2   : [xZxC] --(Tr:0213)-> [xxZC]
        # The former Reshape input is re-declared in the [xZxC] layout.
        reshape_input.set_shape(self.rtr_dims["xZxC"], reshape_input_dtype)
        transpose_2_output = Tensor(
            self.rtr_state.head._model_dict,
            self.rtr_state.head._walk_cfg,
            f"{pool_name}_RTR_Transpose_output",
            origin_node_name=None,
        )
        transpose_2_output.set_shape(self.rtr_dims["xxZC"], reshape_input_dtype)

        transpose_2: Node = self.add_node(
            new_name=f"{pool_name}_RTR_Transpose",
            type="Transpose",
            domain="ai.onnx.contrib",
            inputs= {"data": reshape_input},
            outputs={"transposed": transpose_2_output},
            attributes={
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xzzxmc"]),
                "RTR_transposed": self.rtr_dims["xxzzmc"],
                "RTR_perm": [0, 3, 1, 2, 4, 5],
                "perm": [0, 2, 1, 3],
                "orig_name": f"{pool_name}_RTR_Transpose",
            },
        )
        
        # Reshape_3     : [xxZC] --(Re)-> [XzzC]
        reshape_3_ini = self.add_initializer(
            f"{pool_name}_RTR_Reshape_shape",
            np.array(self.rtr_dims["XzzC"], dtype=np.int32),
        )

        # The new Reshape writes into the Pool's existing "X" tensor, so the
        # Pool node itself does not need to be rewired.
        # (RTR_subgraph_id / RTR_origin are already in reshape_attributes; they
        # are simply set again by the merge.)
        reshape_3: Node = self.add_node(
            new_name=f"{pool_name}_RTR_Reshape",
            type="Reshape",
            domain="ai.onnx.contrib",
            inputs= {"data": transpose_2_output, "shape": reshape_3_ini},
            outputs={"reshaped": pool_input},
            attributes=reshape_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xxzzmc"]),
            },
        )

        # Tag the Pool itself with the RTR layout recorded for its node.
        pool.set_attribute("RTR_shape", str(self.rtr_dims["xxyymc"]))
        return


    def _modify_transpose_1(self, transpose: Node) -> None:
        """
        Replace a perm=(0,2,3,1) Transpose with a Transpose + Reshape pair that
        consumes the RTR layout.
            orig: [XZmc] -> Tr(0,2,3,1) -> [XmcZ]
            mod:  [xZxC] -> Tr -> Re -> [XmcZ]
        replaced with
            [xZxC] --(Tr:0231)-> [xxCZ] --(Re)-> [XmcZ]

        Does nothing unless ``transpose`` is a fused Transpose with
        perm [0, 2, 3, 1] whose "data" input matches the "XZmc" or "AC" shape
        of ``self.rtr_dims``.

        Args:
            transpose: the Transpose node to replace.
        """
        if not transpose.check(fusedOpType.Transpose):
            return
        
        if not transpose.check(AttrValue("perm", [0, 2, 3, 1])):
            return

        tr_name = transpose.get_name()
        tr_input = transpose("data")
        is_XZmc = self.rtr_dims.check_shape("XZmc", tr_input.get_shape())
        is_AC = self.rtr_dims.check_shape("AC", tr_input.get_shape())
        if not is_XZmc and not is_AC:
            return

        self._debug(f"\t ... modify Transpose #1: @{transpose}")
        # Snapshot tensors, dtypes and attributes before the node is removed.
        # NOTE(review): tr_input_dtype is not used below.
        tr_input_dtype = transpose("data").require_tensor().get_dtype()
        # TR_OUTPUT is the module-level name of the Transpose output slot.
        tr_output = transpose(TR_OUTPUT)
        tr_output_dtype = tr_output.require_tensor().get_dtype()
        # The new Reshape inherits its base attributes from the Bottom-RTR's
        # Reshape_IN (self.rtr_state), not from a node next to this Transpose.
        reshape_attributes = self.rtr_state.get_Reshape_IN().get_attributes()
        transpose_attributes = transpose.get_attributes()

        # remove existing node
        self.remove_node(transpose)

        # Transpose_2   : [xZxC] --(Tr:0231)-> [xxCZ]
        transpose_2_output = Tensor(
            self.rtr_state.head._model_dict,
            self.rtr_state.head._walk_cfg,
            f"{tr_name}_RTR_Transpose_2_output",
            origin_node_name=None,
        )
        transpose_2_output.set_shape(self.rtr_dims["xxCZ"], tr_output_dtype)

        # The right-hand dict wins on key clashes, so the RTR_* / perm / orig_name
        # values below override anything inherited from the removed Transpose.
        transpose_2: Node = self.add_node(
            new_name=f"{tr_name}_RTR_Transpose_2",
            type="Transpose",
            domain="ai.onnx.contrib",
            inputs= {"data": tr_input},
            outputs={"transposed": transpose_2_output},
            attributes=transpose_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xzzxmc"]),
                "RTR_transposed": self.rtr_dims["xxmczz"],
                "RTR_perm": [0, 3, 4, 5, 1, 2],
                "perm": [0, 2, 3, 1],
                "orig_name": f"{tr_name}_RTR_Transpose_2",
            },
        )
        
        # Reshape_3     : [xxCZ] --(Re)-> [XmcZ]
        reshape_3_ini = self.add_initializer(
            f"{tr_name}_RTR_Reshape_3_shape",
            np.array(self.rtr_dims["XmcZ"], dtype=np.int32),
        )
        # The original Transpose output tensor is reused as the Reshape output,
        # with its shape set to [XmcZ].
        tr_output.set_shape(self.rtr_dims["XmcZ"], tr_output_dtype)

        reshape_3: Node = self.add_node(
            new_name=f"{tr_name}_RTR_Reshape_3",
            type="Reshape",
            domain="ai.onnx.contrib",
            inputs= {"data": transpose_2_output, "shape": reshape_3_ini},
            outputs={"reshaped": tr_output},
            attributes=reshape_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xxmczz"]),
            },
        )

        return

    def _modify_transpose_2(self, transpose: Node) -> None:
        """
        Replace a perm=(0,2,1,3) Transpose with a Transpose + Reshape +
        Transpose chain that consumes the RTR layout.
            orig: [XZmc] -> Tr(0,2,1,3) -> [XmZc]
            mod:  [xZxC] -> Tr -> Re -> Tr -> [XmZc]
        replaced with
            [xZxC] --(Tr:0213)-> [xxZC] --(Re)-> [XZmc] --(Tr:0213)-> [XmZc]

        Does nothing unless ``transpose`` is a fused Transpose with
        perm [0, 2, 1, 3] whose "data" input matches the "XZmc" or "AC" shape
        of ``self.rtr_dims``.

        Args:
            transpose: the Transpose node to replace.
        """
        if not transpose.check(fusedOpType.Transpose):
            return
        
        if not transpose.check(AttrValue("perm", [0, 2, 1, 3])):
            return

        tr_name = transpose.get_name()
        tr_input = transpose("data")
        is_XZmc = self.rtr_dims.check_shape("XZmc", tr_input.get_shape())
        is_AC = self.rtr_dims.check_shape("AC", tr_input.get_shape())
        if not is_XZmc and not is_AC:
            return

        self._debug(f"\t ... modify Transpose #2 @ {transpose}")
        # Snapshot the output tensor, dtype and attributes before removal.
        # TR_OUTPUT is the module-level name of the Transpose output slot.
        tr_output = transpose(TR_OUTPUT)
        tr_output_dtype = tr_output.require_tensor().get_dtype()
        # The new Reshape inherits its base attributes from the Bottom-RTR's
        # Reshape_IN (self.rtr_state).
        reshape_attributes = self.rtr_state.get_Reshape_IN().get_attributes()
        transpose_attributes = transpose.get_attributes()
        
        # remove existing node
        self.remove_node(transpose)

        # Transpose_2   : [xZxC] --(Tr:0213)-> [xxZC]
        transpose_2_output = Tensor(
            self.rtr_state.head._model_dict,
            self.rtr_state.head._walk_cfg,
            f"{tr_name}_RTR_Transpose_2_output",
            origin_node_name=None,
        )
        transpose_2_output.set_shape(self.rtr_dims["xxZC"], tr_output_dtype)

        # The right-hand dict wins on key clashes with the inherited attributes.
        transpose_2: Node = self.add_node(
            new_name=f"{tr_name}_RTR_Transpose_2",
            type="Transpose",
            domain="ai.onnx.contrib",
            # inputs= {"data": reshape_2_output},
            inputs= {"data": tr_input},
            outputs={"transposed": transpose_2_output},
            attributes=transpose_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xzzxmc"]),
                "RTR_transposed": self.rtr_dims["xxzzmc"],
                "RTR_perm": [0, 3, 1, 2, 4, 5],
                "perm": [0, 2, 1, 3],
                "orig_name": f"{tr_name}_RTR_Transpose_2",
            },
        )
        
        # Reshape_3     : [xxZC] --(Re)-> [XZmc]
        reshape_3_ini = self.add_initializer(
            f"{tr_name}_RTR_Reshape_3_shape",
            np.array(self.rtr_dims["XZmc"], dtype=np.int32),
        )
        reshape_3_output = Tensor(
            self.rtr_state.head._model_dict,
            self.rtr_state.head._walk_cfg,
            f"{tr_name}_RTR_Reshape_3_output",
            origin_node_name=None,
        )
        reshape_3_output.set_shape(self.rtr_dims["XZmc"], tr_output_dtype)

        reshape_3: Node = self.add_node(
            new_name=f"{tr_name}_RTR_Reshape_3",
            type="Reshape",
            domain="ai.onnx.contrib",
            inputs= {"data": transpose_2_output, "shape": reshape_3_ini},
            outputs={"reshaped": reshape_3_output},
            attributes=reshape_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xxzzmc"]),
            },
        )

        # Transpose_3   : [XZmc] --(Tr:0213)-> [XmZc]
        # The original Transpose output tensor is reused as the chain output,
        # with its shape set to [XmZc].
        tr_output.set_shape(self.rtr_dims["XmZc"], tr_output_dtype)
        transpose_3: Node = self.add_node(
            new_name=f"{tr_name}_RTR_Transpose_3",
            type="Transpose",
            domain="ai.onnx.contrib",
            inputs= {"data": reshape_3_output},
            outputs={"transposed": tr_output},
            attributes=transpose_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xxzzmc"]),
                "RTR_transposed": self.rtr_dims["xxmzzc"],
                "RTR_perm": [0, 1, 4, 2, 3, 5],
                "perm": [0, 2, 1, 3],
                "orig_name": f"{tr_name}_RTR_Transpose_3"
            },
        )
        return

    def _modify_transpose_post_MHA(self, transpose: Node) -> None:
        """
        Modify a Post-MHA Transpose node to match the RTR dimensions.
            [XmYc] -> Tr(0,2,1,3) -> [XYmc]
        replaced with
            [XmYc] --(Tr:0213)-> [XYmc] --(Re)-> [xxYC] --(Tr:0213)-> [xYxC] --(Re)-> [WyxC] --(Tr:0213)-> [WxyC] --(Re)-> [WWmc]

        This method builds the chain up to and including the Reshape that
        produces [WyxC]; the trailing Transpose / Reshape are created in
        ``_modify_bottom_RTR``.

        The Transpose and the fused Reshape that consumes its output are both
        removed; the Reshape's "reshaped" tensor is reused as the output of
        the new chain.

        Does nothing unless ``transpose`` is a fused Transpose with
        perm [0, 2, 1, 3] whose "data" input matches the "XmYc" shape and is
        produced by a fused MatMul.

        Args:
            transpose: the Post-MHA Transpose node to replace.
        """
        if not transpose.check(fusedOpType.Transpose):
            return
        
        if not transpose.check(AttrValue("perm", [0, 2, 1, 3])):
            return

        tr_name = transpose.get_name()
        tr_input = transpose("data")
        is_XmYc = self.rtr_dims.check_shape("XmYc", tr_input.get_shape())
        if not is_XmYc:
            return
        if not tr_input.require_node().check(fusedOpType.MatMul):
            return

        self._debug(f"\t ... modify transpose #4 (Post-MHA) @ {transpose}")
        # The Transpose's consumer is fetched with require(fusedOpType.Reshape)
        # (presumably raising when it is not a fused Reshape -- TODO confirm).
        # Despite its name, tr_output is the output of that Reshape.
        reshape = transpose(TR_OUTPUT).require_node().require(fusedOpType.Reshape)
        tr_output = reshape("reshaped")
        tr_output_dtype = tr_output.require_tensor().get_dtype()
        reshape_attributes = reshape.get_attributes()
        transpose_attributes = transpose.get_attributes()
        
        # remove existing nodes
        self.remove_node(transpose)
        self.remove_node(reshape)

        # Transpose_1   : [XmYc] --(Tr:0213)-> [XYmc]
        transpose_1_output = Tensor(
            self.rtr_state.head._model_dict,
            self.rtr_state.head._walk_cfg,
            f"{tr_name}_RTR_Transpose_1_output",
            origin_node_name=None,
        )
        transpose_1_output.set_shape(self.rtr_dims["XYmc"], tr_output_dtype)

        # The right-hand dict wins on key clashes with the inherited attributes.
        transpose_1: Node = self.add_node(
            new_name=f"{tr_name}_RTR_Transpose_1",
            type="Transpose",
            domain="ai.onnx.contrib",
            inputs= {"data": tr_input},
            outputs={"transposed": transpose_1_output},
            attributes=transpose_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xxmyyc"]),
                "RTR_transposed": self.rtr_dims["xxyymc"],
                "RTR_perm": [0, 1, 3, 4, 2, 5],
                "perm": [0, 2, 1, 3],
                "orig_name": f"{tr_name}_RTR_Transpose_1"
            },
        )

        # Reshape_1     : [XYmc] --(Re)-> [xxYC]
        reshape_1_ini = self.add_initializer(
            f"{tr_name}_RTR_Reshape_1_shape",
            np.array(self.rtr_dims["xxYC"], dtype=np.int32),
        )
        reshape_1_output = Tensor(
            self.rtr_state.head._model_dict,
            self.rtr_state.head._walk_cfg,
            f"{tr_name}_RTR_Reshape_1_output",
            origin_node_name=None,
        )
        reshape_1_output.set_shape(self.rtr_dims["xxYC"], tr_output_dtype)

        reshape_1: Node = self.add_node(
            new_name=f"{tr_name}_RTR_Reshape_1",
            # type=self.rtr_state.get_Reshape_IN().get_op_type(),
            type="Reshape",
            domain="ai.onnx.contrib",
            inputs= {"data": transpose_1_output, "shape": reshape_1_ini},
            outputs={"reshaped": reshape_1_output},
            attributes=reshape_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xxyymc"]),
            },
        )

        # ------------------------------------------------------------------------------ #
        # Transpose_2   : [xxYC] --(Tr:0213)-> [xYxC]
        transpose_2_output = Tensor(
            self.rtr_state.head._model_dict,
            self.rtr_state.head._walk_cfg,
            f"{tr_name}_RTR_Transpose_2_output",
            origin_node_name=None,
        )
        transpose_2_output.set_shape(self.rtr_dims["xYxC"], tr_output_dtype)

        transpose_2: Node = self.add_node(
            new_name=f"{tr_name}_RTR_Transpose_2",
            type="Transpose",
            domain="ai.onnx.contrib",
            inputs= {"data": reshape_1_output},
            outputs={"transposed": transpose_2_output},
            attributes=transpose_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xxyymc"]),
                "RTR_transposed": self.rtr_dims["xyyxmc"],
                "RTR_perm": [0, 2, 3, 1, 4, 5],
                "perm": [0, 2, 1, 3],
                "orig_name": f"{tr_name}_RTR_Transpose_2"
            },
        )

        # Reshape_2     : [xYxC] --(Re)-> [WyxC]
        reshape_2_ini = self.add_initializer(
            f"{tr_name}_RTR_Reshape_2_shape",
            np.array(self.rtr_dims["WyxC"], dtype=np.int32),
        )
        # NOTE(review): reshape_2_output is created and shaped but never wired
        # into a node -- Reshape_2 writes into tr_output instead. Looks like a
        # leftover; confirm whether constructing it has side effects before
        # removing it.
        reshape_2_output = Tensor(
            self.rtr_state.head._model_dict,
            self.rtr_state.head._walk_cfg,
            f"{tr_name}_RTR_Reshape_2_output",
            origin_node_name=None,
        )
        reshape_2_output.set_shape(self.rtr_dims["WyxC"], tr_output_dtype)
        # The removed Reshape's output tensor becomes the chain output, with
        # its shape set to [WyxC].
        tr_output.set_shape(self.rtr_dims["WyxC"], tr_output_dtype)

        reshape_2: Node = self.add_node(
            new_name=f"{tr_name}_RTR_Reshape_2",
            # type=self.rtr_state.get_Reshape_IN().get_op_type(),
            type="Reshape",
            domain="ai.onnx.contrib",
            inputs= {"data": transpose_2_output, "shape": reshape_2_ini},
            outputs={"reshaped": tr_output},
            attributes=reshape_attributes | {
                "RTR_subgraph_id": self.subgraph.id,
                "RTR_origin": self.get_origin_rtr(),
                "RTR_shape": str(self.rtr_dims["xyyxmc"]),
            },
        )
        return


class RTRStateMatcher(RTRSubgraphMatcher):
    """
    Matcher that collapses one Reshape -> Transpose -> Reshape triple into a
    single node built by the RTR pointer's ``_make_RTR_node_dict`` helper.
    Log lines emitted through ``_info`` / ``_debug`` carry a "[RTR_PTR]" prefix.
    """

    def __init__(self) -> None:
        super().__init__()
        self._dbg_subgraphs = [1]

    def _reset(self):
        # Nothing extra to reset; defer entirely to the base class.
        super()._reset()

    @override
    def _info(self, info_msg) -> None:
        """Forward *info_msg* to the base logger with the matcher prefix."""
        super()._info(f"[RTR_PTR] {info_msg}")

    @override
    def _debug(self, dbg_msg) -> None:
        """Forward *dbg_msg* to the base logger with the matcher prefix."""
        super()._debug(f"[RTR_PTR] {dbg_msg}")

    @override
    def match(self) -> None:
        """
        Require the current node to be a (fused) Reshape that belongs to a live
        shuffle-RTR state, and set this matcher up from that state.
        """
        plain_walker = self.n.with_walk_cfg(WalkCfgPlain())
        reshape_node = plain_walker.require(
            opType.Reshape | fusedOpType.Reshape
        ).require_node()

        # Subgraph ids are 1-based and follow the number of subgraphs seen so far.
        next_id = len(self.rtr_subgraphs) + 1
        self.subgraph = RTRSubgraph(next_id, _rtr_matcher=self)

        live_state: RTRState = self.require_live_rtr_state(reshape_node, self.subgraph)
        RTRDimensions.require_shuffle_RTR(live_state)
        self._setup_from_rtr_state(live_state)

        self._info("=========  RTR-States Matcher  =========")
        self._info(f"Matched Node = {reshape_node}")
        self._info(f"RTRState={self.rtr_state}")

        logger.info("\n" + "=" * 100)

    @override
    def modify(self) -> None:
        """
        Swap the matched Reshape_IN / Transpose / Reshape_OUT nodes for one RTR
        node wired between the state's head and tail tensors.
        """
        state = self.rtr_state

        # Prepare the replacement while the original nodes still exist.
        fused_spec: NodeDict = state.get_pointer()._make_RTR_node_dict(rtr_matcher=self)
        fused_spec['inputs']['data'] = state.head
        fused_spec['outputs']['Y'] = state.tail

        # Each getter is called only right before its node is removed, so the
        # lookup/removal order is Reshape_IN, Transpose, Reshape_OUT.
        for fetch_node in (
            state.get_Reshape_IN,
            state.get_Transpose,
            state.get_Reshape_OUT,
        ):
            self.remove_node(fetch_node())

        self.add_node(**fused_spec)



class TransposeCounter(Matcher):
    """
    Read-only matcher that tallies Transpose nodes.

    A node is counted when it is a (fused) Transpose, is not flagged as a
    fused 'noop', and its "data" input has a shape of at most four entries.
    The tally is kept in ``count`` and reported by ``run``.
    """

    def __init__(self, label: str) -> None:
        super().__init__()
        self.label = label
        self.count = 0

    def match(self) -> None:
        """Count the current node if it qualifies; raise NoMatch for noops."""
        candidate = (
            self.n.with_walk_cfg(WalkCfgPlain())
            .require(opType.Transpose | fusedOpType.Transpose)
            .require_node()
        )

        if candidate.check(fusedOpType.noop):
            raise NoMatch(f"Node({candidate.get_name()}) is 'noop' and will be skipped.")

        input_rank = len(candidate('data').require_tensor().get_shape())
        if input_rank <= 4:
            self.count += 1

    def modify(self) -> None:
        """Intentionally empty: this matcher never changes the graph."""

    def run(self, model_dict: ModelDict, walk_cfg: "WalkCfgBase", runner: SafeRunner, logger: Logger) -> int:
        """Run the base matcher, log the tally on *logger*, and return the base result."""
        outcome = super().run(model_dict, walk_cfg, runner, logger)
        logger.info(f"\n> TransposeCounter[{self.label}].count = {self.count}  | {outcome}")
        return outcome



class TransposeCounterBefore(TransposeCounter):
    """TransposeCounter preset with the "BEFORE" label."""

    def __init__(self) -> None:
        super().__init__(label="BEFORE")
        # A fresh instance always starts its tally at zero.
        self.count = 0


class TransposeCounterAfter(TransposeCounter):
    """TransposeCounter preset with the "AFTER" label."""

    def __init__(self) -> None:
        super().__init__(label="AFTER")
        # A fresh instance always starts its tally at zero.
        self.count = 0

class TransposeCounterFinal(TransposeCounter):
    """TransposeCounter preset with the "FINAL" label."""

    def __init__(self) -> None:
        super().__init__(label="FINAL")
        # A fresh instance always starts its tally at zero.
        self.count = 0

        