import os
import sys
import argparse
import copy
import json
import logging
import traceback
from enum import Enum, auto, IntEnum
from typing import TypeAlias, Literal, NamedTuple, Iterable, Optional, ClassVar, Deque, Any
from collections import Counter, defaultdict, deque
from math import prod
from dataclasses import dataclass, field

import numpy as np
import onnx
from onnx import (
    numpy_helper,
    TensorProto,
    NodeProto,
    ValueInfoProto,
)
from onnx.helper import (
    make_node,
    make_tensor_value_info,
    tensor_dtype_to_np_dtype,
    make_attribute,
)

from OGOAT.src.L1_fusion.L1_utils.utils import (
    model_dict,
    get_attrs,
    set_attribute,
    get_attribute,
    TensorInfo,
    check_binary_shapes,
    ShapeMismatchError,
    onnxTensorProto_to_array,
    onnxTensorProto_from_array,
    save_model,
    collect_unused_ini_nodes
)

from OGOAT.src.L1_fusion.py_match.helpers.common_type import (
    TensorShape,
    OnnxDType,
    Perm,
)
from topo_sort import graph_sort
from OGOAT.src.L1_fusion.py_match.helpers.perm_helper import PermutationHelper
from OGOAT.src.L1_fusion.py_match.helpers.reshape_transpose_helper import ReshapeTransposeHelper

# Short module-level aliases for the static helpers used throughout this file.
get_dim = ReshapeTransposeHelper.get_tensor_dim
get_rank = ReshapeTransposeHelper.get_tensor_rank
reshape_diff = ReshapeTransposeHelper.reshape_diff
permute = PermutationHelper.permute

# "!!" markers wrapped in ANSI color escapes (SGR 33 = yellow, SGR 31 = red).
_WRN = "\033[33m!!\033[0m"
_ERR = "\033[31m!!\033[0m"

# Child logger "L1_fusion.NHWC"; its level is forced to DEBUG at import time.
logger = logging.getLogger("L1_fusion").getChild("NHWC")
logger.setLevel(logging.DEBUG)
# Optional stand-alone stdout handler, currently disabled:
# stream_handler = logging.StreamHandler(sys.stdout)
# stream_handler.setLevel(logging.DEBUG)
# logger.addHandler(stream_handler)
# logger.propagate = False
# logger.setLevel(logging.DEBUG)

# Default ONNX element types for integer and floating-point tensors.
# NOTE(review): not referenced in the visible part of this file.
DEFAULT_INT_DTYPE = TensorProto.UINT16
DEFAULT_FLOAT_DTYPE = TensorProto.FLOAT

def is_floating_dtype(tensor_dtype):
    """Return True when the ONNX tensor dtype maps to a NumPy floating type."""
    return np.issubdtype(tensor_dtype_to_np_dtype(tensor_dtype), np.floating)

def is_integer_dtype(tensor_dtype):
    """Return True when the ONNX tensor dtype maps to a NumPy integer type."""
    return np.issubdtype(tensor_dtype_to_np_dtype(tensor_dtype), np.integer)


# Op types used by LayoutTagger._initialize_layout_from_NCHW_ops (passed as
# `op_types=` to md.get_nodes): still-UNKNOWN activations of these nodes are
# tagged NCHW.
NCHW_OPERATORS = [
    'Conv',
    'ConvTranspose',
    'InstanceNormalization',
    'GlobalAveragePool',
    'AveragePool',
    'MaxPool',
    'DepthToSpace',
    'SpaceToDepth',
]
# 'Softmax' plus the same Reduce* names that REDUCE_OPS lists.
# NOTE(review): not referenced in the visible part of this file.
NHWC_OPERATORS = [
    'Softmax',
    'ReduceSum',
    'ReduceSumSquare',
    'ReduceProd',
    'ReduceMean',
    'ReduceMin',
    'ReduceMax',
    'ReduceL1',
    'ReduceL2',
]

"""
Layout Agnostic ops:
    - don't have any special params that need a special handler
    - converted by DefaultHandler, i.e. only the layout of I/O tensors is changed
"""
# Set of op-type names (disjoint from LAYOUT_AWARE_OPS as currently written).
# NOTE(review): not referenced in the visible part of this file.
LAYOUT_AGNOSTIC_OPS = {
    'QuantizeLinear',
    'DequantizeLinear',
    'Shape',
    'Sigmoid',
    'Gemm',
    'LeakyRelu',
    'Relu',
}

"""
Layout Aware ops:
    - have input params (initializers or attributes) that need to be converted 
    - a special handler is required for each such op
"""
# The 'Softmax' and Reduce* entries repeat the names in NHWC_OPERATORS.
# NOTE(review): not referenced in the visible part of this file.
LAYOUT_AWARE_OPS = {
    'Reshape',
    'Unsqueeze',
    'Squeeze',
    'MatMul',
    'Mul',
    'Div',
    'Sub',
    'Add',
    'Concat',
    'Resize',
    'Slice',
    'LayerNormalization',
    'LpNormalization',
    'Flatten',
    'Softmax',
    'ReduceSum',
    'ReduceSumSquare',
    'ReduceProd',
    'ReduceMean',
    'ReduceMin',
    'ReduceMax',
    'ReduceL1',
    'ReduceL2',
    'Split',
    'Pad',
}

# Reduce* op types; LayoutTagger._initialize_layout_from_Reduce_ops passes this
# set as `op_types=` to md.get_nodes and then only acts on ReduceMean nodes.
REDUCE_OPS = {
    'ReduceSum',
    'ReduceSumSquare',
    'ReduceProd',
    'ReduceMean',
    'ReduceMin',
    'ReduceMax',
    'ReduceL1',
    'ReduceL2',
}

# dim -> perm
# Keys 1..4 are tensor ranks. The `-4` key holds an additional 4-D permutation;
# it is only reachable through the `.values()` lookup in is_NHWC_perm /
# is_NCHW_perm (no `dims` given) or by passing -4 explicitly as a dim.
NHWC_PERMUTATIONS = {
    1:[0], 2: [0, 1], 3: [0, 2, 1], 4: [0, 2, 3, 1],
    -4: [0, 2, 1, 3], # == [0, 2, 3, 1]^T
}
# Same structure as NHWC_PERMUTATIONS; the 1-D, 2-D and 3-D entries are identical
# in both tables, only the 4-D entries differ.
NCHW_PERMUTATIONS = {
    1:[0], 2: [0, 1], 3: [0, 2, 1], 4: [0, 3, 1, 2],
    -4: [0, 1, 3, 2], # == [0, 3, 1, 2]^T
}

def is_NCHW_perm(perm_: Perm, *dims: int) -> bool:
    """Tell whether ``perm_`` is one of the NCHW_PERMUTATIONS entries.

    With ``dims`` given, only the entries stored under those keys are
    considered (an unknown key raises KeyError); otherwise every entry is.
    """
    if not dims:
        candidates = NCHW_PERMUTATIONS.values()
    else:
        candidates = [NCHW_PERMUTATIONS[d] for d in dims]
    return perm_ in candidates

def is_NHWC_perm(perm_: Perm, *dims: int) -> bool:
    """Tell whether ``perm_`` is one of the NHWC_PERMUTATIONS entries.

    With ``dims`` given, only the entries stored under those keys are
    considered (an unknown key raises KeyError); otherwise every entry is.
    """
    if not dims:
        candidates = NHWC_PERMUTATIONS.values()
    else:
        candidates = [NHWC_PERMUTATIONS[d] for d in dims]
    return perm_ in candidates

def is_layout_perm(perm_: Perm, *dims) -> bool:
    """Return True when ``perm_`` matches an NCHW or an NHWC permutation."""
    # Both checks are always evaluated, exactly as two separate statements would.
    matches = (is_NCHW_perm(perm_, *dims), is_NHWC_perm(perm_, *dims))
    return any(matches)



# Closed set of action names: four `create_*` and four `remove_*` variants.
# NOTE(review): consumers of this alias are outside the visible part of the file.
ModAction: TypeAlias = Literal[
    'create_node', 'create_vinfo', 'create_ini', 'create_tag',
    'remove_constant', 'remove_node', 'remove_vinfo', 'remove_ini',
]

class ModActionError(ValueError):
    """ValueError subclass for ModAction-related failures (raised outside this view)."""

class LayoutMismatchError(ValueError):
    """ValueError subclass for layout mismatches (raised outside this view)."""

class LayoutHandlerError(ValueError):
    """ValueError subclass for layout-handler failures (raised outside this view)."""


class Layout(IntEnum):
    """Layout tag attached to tensors and nodes.

    The integer values are chosen so that `unify` can merge tags by
    multiplication: UNKNOWN (0) absorbs everything, ANY / NONE (+1 / -1) are
    neutral, and any product of 6 or more collapses to CONFLICT.
    """
    NONE = -1
    UNKNOWN = 0
    ANY = 1
    NCHW = 2
    NHWC = 3
    NHCW = 4
    CONFLICT = 6

    def __str__(self):
        """Render as the bare member name, e.g. ``NCHW``."""
        return self.name

    def __repr__(self):
        """Render as ``<Layout.NAME>``."""
        return f"<{self.__class__.__name__}.{self.name}>"

    def transposed(self) -> 'Layout':
        """Swap NCHW and NHWC; every other tag is returned unchanged."""
        if self is Layout.NCHW:
            return Layout.NHWC
        if self is Layout.NHWC:
            return Layout.NCHW
        return self

    def apply_transpose(self, perm_):
        """Return the tag after a Transpose with permutation ``perm_``.

        The tag is swapped only when ``perm_`` is one of the known layout
        permutations; any other permutation leaves it untouched.
        """
        return self.transposed() if is_layout_perm(perm_) else self

    @classmethod
    def unify(cls, *_layout_tags) -> 'Layout':
        """Merge several tags into one summary tag.

        Duplicates are dropped first, then the distinct values are multiplied.
        The absolute value folds NONE (-1) onto ANY (+1), and the result is
        capped at CONFLICT.
        """
        # UNK = 0; ANY/NONE = +1/-1; NCHW/NHWC = 2/3; CONFLICT = 6
        distinct_tags = set(_layout_tags)
        magnitude = abs(prod(distinct_tags))
        capped = min(magnitude, Layout.CONFLICT)
        return Layout(capped)

    @classmethod
    def get_perm_layout(cls, perm_: Perm, *dims: int) -> 'Layout':
        """Classify a permutation as ANY, NCHW, NHWC or UNKNOWN."""
        matches_nchw = is_NCHW_perm(perm_, *dims)
        matches_nhwc = is_NHWC_perm(perm_, *dims)
        if matches_nchw and matches_nhwc:  # 1D, 2D, 3D
            return Layout.ANY
        if matches_nchw:  # 4D
            return Layout.NCHW
        if matches_nhwc:  # 4D
            return Layout.NHWC
        return Layout.UNKNOWN

@dataclass
class TaggedShape:
    """A tensor shape annotated with a layout tag and per-axis denotation.

    Attributes:
        shape: the tensor shape the tag refers to.
        tag: layout classification of the tensor.
        tag_source: name of the node/tensor (or ``"init"``) the tag came from.
        shape_denotation: per-axis labels such as ``["N", "C", "H", "W"]``;
            empty when no denotation is known.
    """
    shape: TensorShape
    tag: Layout
    tag_source: str | None = None
    shape_denotation: list[str] = field(default_factory=list)

    def __hash__(self):
        # Hash on (tag, shape) only; equal instances still hash equally because
        # the generated __eq__ compares a superset of these fields.
        _shape = tuple(self.shape)
        _self = (self.tag, _shape)
        return hash(_self)

    def __str__(self):
        denote_ = self.get_denotation_str()
        return f"TaggedShape(shape={self.shape}, tag={self.tag}, denotation={denote_}, source={self.tag_source})"

    @classmethod
    def NULL(cls):
        """Return a fresh empty-shape instance tagged ``Layout.NONE``."""
        # TODO: maybe should return the same singleton object
        return TaggedShape([], Layout.NONE)

    def get_denotation_str(self):
        """Return the denotation joined by ``|``, or ``"NONE"`` when empty."""
        if self.shape_denotation:
            denote_ = "|".join(self.shape_denotation)
        else:
            denote_ = "NONE"
        return denote_

    def _resolve(self, shape, tag, source, denote):
        """Fill omitted `update` / `require` arguments from this instance.

        ``tag`` and ``denote`` are "omitted" only when they are ``None``.
        ``Layout.UNKNOWN`` has the integer value 0 and is therefore falsy, so a
        truthiness test would silently discard an explicit UNKNOWN tag.
        ``shape`` and ``source`` keep the historical truthiness rule: an empty
        shape or empty string means "keep the current value".
        """
        return (
            shape if shape else self.shape,
            tag if tag is not None else self.tag,
            source if source else self.tag_source,
            denote if denote is not None else self.shape_denotation,
        )

    def update(
            self, shape: TensorShape=None, tag: Layout=None, source: str | None = None, denote: list[str] | None = None
    ) -> 'TaggedShape':
        """Return a new TaggedShape with the given fields replaced.

        Omitted fields are copied from ``self``; ``self`` is not modified.
        """
        shape, tag, source, denote = self._resolve(shape, tag, source, denote)
        return TaggedShape(shape, tag, source, denote)

    class UnexpectedTaggedShape(ValueError):
        """Raised by `require` when an expectation does not hold."""

    def require(
            self, shape: TensorShape=None, tag: Layout=None, source: str | None = None, denote: list[str] | None = None
    ) -> 'TaggedShape':
        """Assert that the given fields match this instance and return ``self``.

        Omitted fields are not checked.

        Raises:
            TaggedShape.UnexpectedTaggedShape: if any supplied field differs.
        """
        shape, tag, source, denote = self._resolve(shape, tag, source, denote)
        check_ = (
            shape == self.shape,
            tag == self.tag,
            source == self.tag_source,
            denote == self.shape_denotation
        )
        if not all(check_):
            raise TaggedShape.UnexpectedTaggedShape(
                f"{self}.require({shape=}, {tag=}, denotation={denote}, {source=}) == {check_}"
            )
        return self

    def apply_transpose(self, perm_: Perm) -> 'TaggedShape':
        """Return the TaggedShape that results from transposing by ``perm_``.

        Without a denotation only the shape is permuted and the tag follows
        `Layout.apply_transpose`. With a denotation the labels are permuted
        too: if the batch label leaves axis 0 the layout is dropped
        (``Layout.NONE``), otherwise the tag is re-derived from where ``"C"``
        ends up.
        """
        shape_ = permute(self.shape, perm_)
        tag_ = self.tag.apply_transpose(perm_)
        denote_ = self.shape_denotation
        if self.shape_denotation:
            logger.debug(
                f"APPLY :: Transpose({perm_}) to {self}"
            )
            denote_ = permute(self.shape_denotation, perm_)

            if denote_[0] != "N":
                logger.debug(
                    f"      xx PERM BATCH :: {self.shape_denotation} >> {denote_} :: {self.shape_denotation[0]} >> {denote_[0]}"
                )
                return self.update(shape_, Layout.NONE, denote=[])

            tag_ = self._get_layout_tag(denote_, self.tag)
            if self.tag is not tag_:
                logger.debug(
                    f"      :: SWITCHING LAYOUT :: {self.tag} >> {tag_}"
                )

        return self.update(shape_, tag_, denote=denote_)

    def get_channel_dim(self):
        """Return the axis index labelled ``"C"``, or None if there is none."""
        if self.shape_denotation and "C" in self.shape_denotation:
            return self.shape_denotation.index("C")
        return None

    def get_batch_dim(self):
        """Return the axis index labelled ``"N"``, or None if there is none."""
        if self.shape_denotation and "N" in self.shape_denotation:
            return self.shape_denotation.index("N")
        return None

    @staticmethod
    def _get_layout_denotation(tag: Layout, dim: int) -> list[str]:
        """Return the canonical axis labels for ``tag`` at rank ``dim``.

        Only NCHW/NHWC at rank 3 or 4 have labels; anything else yields ``[]``.
        """
        if tag is Layout.NCHW:
            if dim == 4:
                return list("NCHW")
            elif dim == 3:
                return ["N", "C", "HW"]
            else:
                return []
        if tag is Layout.NHWC:
            if dim == 4:
                return list("NHWC")
            elif dim == 3:
                return ["N", "HW", "C"]
            else:
                return []
        return []

    @staticmethod
    def _get_layout_tag(denote: list[str], default_tag = Layout.UNKNOWN) -> Layout:
        """Derive a layout tag from the position of ``"C"`` in ``denote``.

        Channel last gives NHWC, channel at axis 1 gives NCHW, anything else
        (including an empty or single-entry denotation) gives ``default_tag``.
        """
        # Length guards avoid IndexError on denotations shorter than two axes.
        if denote and denote[-1] == "C":
            return Layout.NHWC
        if len(denote) > 1 and denote[1] == "C":
            return Layout.NCHW
        return default_tag


    def apply_reshape(self, reshaped: TensorShape) -> 'TaggedShape':
        """Return the TaggedShape that results from reshaping to ``reshaped``.

        Without a denotation only the shape changes. Otherwise the result of
        ``reshape_diff(self.shape, reshaped)`` decides whether the layout
        survives: it is dropped (``Layout.NONE``) when the batch axis is lost
        or when the channel axis is split or merged on one side only.
        In all other cases the tag is kept and the denotation is regenerated
        for the new rank.

        NOTE(review): ``in_mask`` / ``out_mask`` are treated here as
        collections of axis-index groups; confirm against ``reshape_diff``.
        """
        if not self.shape_denotation:
            return self.update(reshaped)

        def get_dims_mask(dim_, mask_) -> list[int]:
            # First group in `mask_` that contains axis `dim_`, or [] if none.
            for m in mask_:
                if dim_ in m:
                    return m
            return []

        is_valid_diff, in_mask, out_mask = reshape_diff(self.shape, reshaped)
        logger.debug(
            f"APPLY :: Reshape({reshaped}) to {self}"
        )
        if not is_valid_diff:
            dim_N_diff = self.shape[0] != reshaped[0]
            logger.debug(
                f"      :: INVALID DIFF :: BATCH {'DIFF' if dim_N_diff else 'SAME'}"
            )
            denote_ = self._get_layout_denotation(self.tag, len(reshaped))
            return self.update(reshaped, denote=denote_)

        logger.debug(
            f"      :: VALID DIFF :: {in_mask=}; {out_mask=};"
        )

        if not out_mask[0] or (not in_mask[0] and reshaped[0] == 1):
            logger.debug(
                f"      xx DROPPING BATCH :: {self.shape} >> {reshaped} :: {in_mask=}; {out_mask=};"
            )
            return self.update(reshaped, Layout.NONE, denote=[])

        dimC = self.get_channel_dim()
        dimC_out_mask = get_dims_mask(dimC, out_mask)
        dimC_in_mask = get_dims_mask(dimC, in_mask)

        # Channel axis participates in a many-to-many regrouping: keep the tag.
        if len(dimC_in_mask) >= 2 and len(dimC_out_mask) >= 2:
            logger.debug(
                f"      :: N2M DIFF :: {self.shape} >> {reshaped} :: dimC={dimC} >> in_mask[dimC]={dimC_in_mask} >> out_mask[dimC]={dimC_out_mask};"
            )
            denote_ = self._get_layout_denotation(self.tag, len(reshaped))
            return self.update(reshaped, denote=denote_)

        if len(dimC_out_mask) >= 2:
            logger.debug(
                f"      xx SPLITTING CHANNEL :: {self.shape} >> {reshaped} :: dimC={dimC} >> out_mask[dimC]={dimC_out_mask};"
            )
            return self.update(reshaped, Layout.NONE, denote=[])

        if len(dimC_in_mask) >= 2:
            logger.debug(
                f"      xx MERGING CHANNEL :: {self.shape} >> {reshaped} :: dimC={dimC} >> in_mask[dimC]={dimC_in_mask};"
            )
            return self.update(reshaped, Layout.NONE, denote=[])

        denote_ = self._get_layout_denotation(self.tag, len(reshaped))
        return self.update(reshaped, denote=denote_)


class LayoutTagger:
    def __init__(self, model, model_name, out_dir, external_data: bool, verbose:bool=True):
        """Bind the tagger to a model and build its ``model_dict`` index.

        Args:
            model: the model whose tensors are tagged.
            model_name: base name used for debug dump files.
            out_dir: directory that dump files are written to.
            external_data: forwarded to ``save_model`` when dumping.
            verbose: enables the per-tensor debug line during propagation.
        """
        # Plain configuration captured verbatim from the caller.
        self.model, self.model_name = model, model_name
        self.out_dir, self.external_data = out_dir, external_data
        self._verbose = verbose
        # Graph lookup structure used by every tagging pass.
        self.md = model_dict.create(self.model)
        # Both maps start empty; reset_layout() / tag_nodes_v2() fill them.
        self.nodes_layout: dict[str, TaggedShape] = {}
        self.tensors_layout: dict[str, TaggedShape] = {}

    def get_perm(self, node: str|NodeProto) -> Perm:
        """Return the ``perm`` attribute of a Transpose node, else ``[]``.

        ``node`` is resolved through ``md.get_node``; an unresolved node or a
        node of any other op type yields an empty list.
        """
        resolved = self.md.get_node(node)
        if resolved and resolved.op_type == "Transpose":
            return get_attribute(resolved, "perm")
        return []

    def get_node_layout(self, node: str|NodeProto) -> TaggedShape:
        """Summarize the layout of a node from the tags of its tensors.

        The returned shape is that of the node's first output. The tag is
        ``Layout.unify`` over all input and output activations, except for
        Transpose nodes where only the outputs are considered. A node without
        outputs yields an empty ``Layout.NONE`` result and a warning.
        """
        md = self.md
        node = md.get_node(node)

        inputs = md.get_node_activations(node)
        outputs = md.get_node_outputs(node)
        if not outputs:
            logger.warning(
                f"\033[33m"
                f"[NODE LAYOUT] :: Node({node.name}) has no outputs :: {outputs} ::"
                f"\033[0m"
            )
            return TaggedShape([], Layout.NONE)

        # Only the first output's shape is reported (backward compatibility).
        first_output = next(iter(outputs))
        out_shape = md.get_shape(first_output)

        # A Transpose changes layout between input and output by design,
        # so its inputs are left out of the summary.
        if node.op_type == "Transpose":
            inputs.clear()

        tags = [self.tensors_layout[a_name].tag for a_name in inputs | outputs]
        return TaggedShape(out_shape, Layout.unify(*tags))

    def get_nodes_layout(self) -> dict[str, TaggedShape]:
        """Compute `get_node_layout` for every node in ``md.nodes``."""
        layouts: dict[str, TaggedShape] = {}
        for n_name, n in self.md.nodes.items():
            layouts[n_name] = self.get_node_layout(n)
        return layouts

    def reset_layout(self) -> None:
        """Discard all tags and re-initialize one entry per value-info tensor.

        Initializers start as ``ANY``; other tensors start as ``UNKNOWN`` when
        their rank is 3 or 4 and as ``NONE`` otherwise. Every entry gets the
        source label ``"init"``.
        """
        md = self.md

        def initial_tag(tensor_: ValueInfoProto) -> Layout:
            if md.is_initializer(tensor_.name):
                return Layout.ANY
            if 2 < md.get_dim(tensor_) < 5:
                return Layout.UNKNOWN
            return Layout.NONE

        self.nodes_layout = {}
        fresh: dict[str, TaggedShape] = {}
        for t in md.vinfo.values():
            tag = initial_tag(t)
            fresh[t.name] = TaggedShape(md.get_shape(t), tag, "init")
        self.tensors_layout = fresh

    def _initialize_layout_from_NCHW_ops(self, ) -> set[str]:
        """Tag the activations around NCHW-only operators as NCHW.

        For every node whose op type is in ``NCHW_OPERATORS``, each input or
        output activation that is still ``UNKNOWN`` receives an NCHW tag with
        the node name as source.

        Returns:
            Names of the tensors tagged by this pass.
        """
        md = self.md
        newly_tagged: set[str] = set()

        ### NCHW ops ###
        for node in md.get_nodes(op_types=NCHW_OPERATORS):
            node_acts = md.get_node_activations(node) | md.get_node_outputs(node)
            for a_name, act in node_acts.items():
                if self.tensors_layout[a_name].tag is not Layout.UNKNOWN:
                    continue
                shape = md.get_shape(act)
                # 4-D tensors get one label per axis; others share "HW".
                denotation = list("NCHW") if md.get_dim(act) == 4 else ["N", "C", "HW"]
                self.tensors_layout[a_name] = TaggedShape(shape, Layout.NCHW, node.name, denotation)
                newly_tagged.add(a_name)

        return newly_tagged

    def _initialize_layout_from_Reduce_ops(self, ) -> set[str]:
        """Infer layout tags from the reduction axis of ReduceMean nodes.

        Nodes are selected via ``REDUCE_OPS`` but only ReduceMean is acted on.
        Axes are read from the ``axes`` attribute if present, otherwise from a
        constant second input. A reduction over axis 1 tags the still-UNKNOWN
        activations of the node as NCHW; a reduction over the last axis
        (``-1`` or ``rank - 1`` of the first input) tags them as NHWC.

        Returns:
            Names of the tensors tagged by this pass.
        """
        md = self.md
        tagged_acts: set[str] = set()
        for node in md.get_nodes(op_types=REDUCE_OPS):
            if not node.op_type.startswith("ReduceMean"):
                continue

            # `node.attribute` is a repeated field of AttributeProto messages,
            # so the attribute has to be located by its `name`; a plain
            # `"axes" in node.attribute` test never matches.
            axes = None
            axes_attr = next((attr for attr in node.attribute if attr.name == "axes"), None)
            if axes_attr is not None:
                axes = list(axes_attr.ints)
            elif len(node.input) > 1:
                axes_input = node.input[1]
                if axes_input in md.ini:
                    axes = numpy_helper.to_array(md.ini[axes_input]).tolist()

            if axes == [1]:
                tag, denote_4d, denote_3d = Layout.NCHW, "NCHW", ("N", "C", "HW")
            elif axes == [-1] or axes == [len(md.get_shape(node.input[0])) - 1]:
                tag, denote_4d, denote_3d = Layout.NHWC, "NHWC", ("N", "HW", "C")
            else:
                # Any other (or unknown) axes carry no layout information.
                continue

            node_acts = md.get_node_activations(node) | md.get_node_outputs(node)
            n_layouts = {
                a_name: TaggedShape(
                    md.get_shape(act), tag, node.name,
                    # A fresh list per tensor so denotations are never shared.
                    list(denote_4d) if md.get_dim(act) == 4 else list(denote_3d),
                )
                for a_name, act in node_acts.items()
                if self.tensors_layout[a_name].tag is Layout.UNKNOWN
            }
            tagged_acts |= n_layouts.keys()
            self.tensors_layout |= n_layouts
        return tagged_acts

                    
    def tag_nodes_v2(self, dbg=False) -> None:
        """Tag every tensor with a layout and derive per-node layouts.

        Steps:
          1. `reset_layout` gives every tensor its initial tag.
          2. Tensors around NCHW operators and ReduceMean nodes are seeded.
          3. Tags are propagated wave by wave: each tagged tensor pushes its
             layout up to still-UNKNOWN suppliers and down to still-UNKNOWN
             consumers. Transpose and Reshape-like producers transform the
             layout on the way; a Reshape-like step that loses the layout
             (``Layout.NONE``) leaves the neighbour untouched.
          4. ``self.nodes_layout`` is recomputed from the final tensor tags.

        Args:
            dbg: when true, the seeded layout is dumped to
                ``<model_name>_initialized_layout.onnx`` / ``.json``.

        Raises:
            TaggedShape.UnexpectedTaggedShape: if a propagated shape does not
                match the shape recorded for the neighbouring tensor.
        """
        # tag tensors, not nodes
        # 0D, 1D, 2D -- no tags
        # --------------------------
        # * N, B      : BATCH
        # * F, H, W   : FEATURES
        # * C         : CHANNELS
        # * G         : CHANNEL GROUP == BATCH for CHANNELS
        # --------------------------
        # NCHW-activations layouts:
        #   3D(2) -- 1CF == NCF,        : InstanceNorm
        #   3D(3) -- CHW
        #         -- CCF
        #   4D(3) -- 1CHW == NCHW       : Conv, AveragePool, MaxPool
        #         -- 1GCF
        #         -- G1CF
        #         -- GC1F
        #   4D(4) -- NCHW
        #         -- GCCF
        #   5D(4) --
        #   5D(5) -- GCCHW
        #         -- GCHCW
        # --------------------------
        # NHWC-activations layouts:
        #   5D(4) -- 1FGCC
        #   5D(3) -- 1F1CC

        def count_tags(layout_dict: dict[str, TaggedShape]) -> Counter[Layout]:
            # Histogram of layout tags over all entries.
            _counter = Counter([
                l.tag for l in layout_dict.values()
            ])
            return _counter

        md = self.md
        logger.info(f"Start Layout Tagging V2")
        logger.debug(f"  |- ini-nodes / nodes = {len(md._initializer_nodes)} / {len(md.nodes)}")
        logger.debug(f"  |- const-tensors = {len(md.ini)}")
        logger.debug(f"  |- act-tensors + ini-tensors / total = {len(md.get_activations())} + {len(md.get_dynamic_initializers())} / {len(md.vinfo)}")

        # -------------------------------------------------------------------------------------------
        # INITIALIZE LAYOUT TAGS
        # -------------------------------------------------------------------------------------------
        self.reset_layout()
        tensor_layout_counts = count_tags(self.tensors_layout)
        unknown_total = tensor_layout_counts[Layout.UNKNOWN]
        logger.debug(f"# Initial tensor layout:")
        logger.debug(f"  |- {dict(tensor_layout_counts)}")

        # tag inputs/outputs of NCHW ops
        tagged_acts: set[str] = self._initialize_layout_from_NCHW_ops()
        tagged_acts |= self._initialize_layout_from_Reduce_ops()
        tensor_layout_counts = count_tags(self.tensors_layout)
        logger.debug(f"# Infer tensors layout from NCHW nodes:")
        logger.debug(f"  |- tensors tagged + left unknown / total unknown = {len(tagged_acts)} + {tensor_layout_counts[Layout.UNKNOWN]} / {unknown_total}")
        logger.debug(f"  |- {dict(tensor_layout_counts)}")

        ### layout initialization is done ###
        self.nodes_layout = self.get_nodes_layout()
        if dbg: self.save_tagged_model(f"{self.model_name}_initialized_layout.onnx", self.nodes_layout, self.tensors_layout)
        if dbg: self.save_tensor_layout_dict(f"{self.model_name}_initialized_layout.json", self.tensors_layout)


        # -------------------------------------------------------------------------------------------
        # PROPAGATE LAYOUT TAGS
        # -------------------------------------------------------------------------------------------
        bfs_step = 0
        tagged_acts_count = 0
        logger.debug(f"# Propagate layout tags with BFS")
        while tagged_acts:
            tagged_acts_count += len(tagged_acts)
            logger.debug(f"  * BFS[{bfs_step}]: tagged_acts / total unknown = {tagged_acts_count} / {unknown_total}")
            next_tagged_acts = set()

            #####   BSF  STEP   #####
            while tagged_acts:
                act = tagged_acts.pop()
                a_layout = self.tensors_layout[act]
                up_layout = a_layout.update(source=act)
                down_layout = a_layout.update(source=act)

                # Only neighbours that are still UNKNOWN may receive a tag.
                a_suppliers = md.get_tensor_suppliers(act, activations_only=True)
                a_suppliers = [s for s in a_suppliers if self.tensors_layout[s].tag is Layout.UNKNOWN]

                a_consumers = md.get_tensor_consumers(act)
                a_consumers = [s for s in a_consumers if self.tensors_layout[s].tag is Layout.UNKNOWN]

                #####   STEP  UP   >>>   to SUPPLIERS   #####
                supp_layouts: dict[str, TaggedShape] = dict()
                if a_suppliers and md.is_output_of(act, "Transpose"):
                    n_transpose = md.get_writer(act)
                    n_down_perm = self.get_perm(n_transpose)
                    # Going against the data flow, so apply the inverse permutation.
                    n_up_perm = PermutationHelper.get_inverse_perm(n_down_perm)
                    supplier_, *_ = a_suppliers

                    up_layout = up_layout.apply_transpose(n_up_perm)
                    up_layout = up_layout.require(md.get_shape(supplier_))
                    logger.debug(
                        f"\033[33m"
                        f"LAYOUT UP :: Supplier<{up_layout.tag}>({supplier_}) :: {up_layout} ::\n"
                        f"          :: Tensor<{a_layout.tag}>({act}).shape={md.get_shape(act)} "
                        f"is produced by Transpose({n_transpose.name}).perm={n_down_perm}"
                        f"\033[0m"
                    )

                    supp_layouts[supplier_] = up_layout
                elif a_suppliers and md.is_output_of(act, "Reshape", "Squeeze", "Unsqueeze", "Flatten"):
                    n_reshape = md.get_writer(act)
                    supplier_, *_ = a_suppliers
                    n_input_shape = md.get_shape(supplier_)

                    up_layout = up_layout.apply_reshape(n_input_shape)
                    up_layout = up_layout.require(md.get_shape(supplier_))
                    logger.debug(
                        f"\033[33m"
                        f"LAYOUT UP :: Supplier<{up_layout.tag}>({supplier_}) :: {up_layout} ::\n"
                        f"          :: Tensor<{a_layout.tag}>({act}).shape={md.get_shape(act)} "
                        f"is produced by Reshape({n_reshape.name}).up_shape={n_input_shape}"
                        f"\033[0m"
                    )

                    # A lost layout is simply not propagated to the supplier.
                    # This must not `continue` the enclosing while-loop, or the
                    # consumers of `act` would never be visited from here.
                    if up_layout.tag is not Layout.NONE:
                        supp_layouts[supplier_] = up_layout
                else:
                    supp_layouts = {
                        s: up_layout.update(md.get_shape(s))
                        for s in a_suppliers
                    }

                # propagate layout up
                next_tagged_acts |= supp_layouts.keys()
                self.tensors_layout |= supp_layouts


                #####   STEP  DOWN   >>>   to CONSUMERS   #####
                cons_layouts = dict()
                for consumer_ in a_consumers:
                    if md.is_output_of(consumer_, "Transpose"):
                        n_transpose = md.get_writer(consumer_)
                        n_down_perm = self.get_perm(n_transpose)

                        cons_layout: TaggedShape = down_layout.apply_transpose(n_down_perm)
                        cons_layout = cons_layout.require(md.get_shape(consumer_))
                        logger.debug(
                            f"\033[33m"
                            f"LAYOUT DOWN :: Consumer<{cons_layout.tag}>({consumer_}) :: {cons_layout} ::\n"
                            f"            :: Tensor<{a_layout.tag}>({act}).shape={md.get_shape(act)} "
                            f"is consumed by Transpose({n_transpose.name}).perm={n_down_perm}"
                            f"\033[0m"
                        )
                        cons_layouts[consumer_] = cons_layout
                    elif md.is_output_of(consumer_, "Reshape", "Squeeze", "Unsqueeze", "Flatten"):
                        n_reshape = md.get_writer(consumer_)
                        n_output_shape = md.get_shape(consumer_)

                        cons_layout = down_layout.apply_reshape(n_output_shape)
                        cons_layout = cons_layout.require(md.get_shape(consumer_))
                        logger.debug(
                            f"\033[33m"
                            f"LAYOUT DOWN :: Consumer<{cons_layout.tag}>({consumer_}) :: {cons_layout} ::\n"
                            f"            :: Tensor<{a_layout.tag}>({act}).shape={md.get_shape(act)} "
                            f"is consumed by Reshape({n_reshape.name}).down_shape={n_output_shape}"
                            f"\033[0m"
                        )

                        # Skip only this consumer when the reshape loses the layout.
                        if cons_layout.tag is Layout.NONE:
                            continue

                        cons_layouts[consumer_] = cons_layout
                    else:
                        cons_layouts[consumer_] = (
                            down_layout.update(md.get_shape(consumer_))
                        )

                # propagate layout down
                next_tagged_acts |= cons_layouts.keys()
                self.tensors_layout |= cons_layouts

                if self._verbose:
                    logger.debug(
                        f"  |- BFS[{bfs_step}][{len(tagged_acts)}]: Tensor<{a_layout}>({act}) propagated "
                        f"UP to {len(a_suppliers)} suppliers and "
                        f"DOWN to {len(a_consumers)} consumers."
                    )

            logger.debug(f"  * BFS[{bfs_step}]: next_tagged_acts = {len(next_tagged_acts)}")
            tagged_acts = next_tagged_acts
            bfs_step += 1

        ### LAYOUT TAGGING is DONE ###
        self.nodes_layout = self.get_nodes_layout()
        tensor_layout_counts = count_tags(self.tensors_layout)
        node_layout_counts = count_tags(self.nodes_layout)
        logger.debug("# LAYOUT TAGGING is DONE ")
        logger.debug(f"  |- tensors tagged + left unknown / total unknown"
                     f"  =  {tagged_acts_count} + {tensor_layout_counts[Layout.UNKNOWN]} / {unknown_total}")
        logger.debug(f"  |- Tensors layout: {dict(tensor_layout_counts)}")
        logger.debug(f"  |- Nodes layout: {dict(node_layout_counts)}")
        return

    def save_tagged_model(
            self, file_name: str, nodes_layout: dict[str, TaggedShape] | None = None, tensors_layout: dict[str, TaggedShape] | None = None,
    ) -> None:
        """Save a copy of the model with layout information added as attributes.

        Each node of the copy gets a ``node_layout`` attribute plus one
        ``input_<i>_layout`` / ``output_<i>_layout`` attribute per activation,
        holding the denotation string of that tensor. ``self.model`` itself is
        not modified. Falsy layout arguments fall back to the tagger's state.
        """
        nodes_layout = nodes_layout or self.get_nodes_layout()
        tensors_layout = tensors_layout or self.tensors_layout

        def denotation_of(tensor_name: str) -> str:
            # Tensors without an entry are reported through the NULL shape.
            entry = tensors_layout.get(tensor_name, TaggedShape.NULL())
            return str(entry.get_denotation_str())

        # add memory layout tag
        _model = copy.deepcopy(self.model)
        for node in _model.graph.node:
            node_tag = nodes_layout.get(node.name, TaggedShape.NULL()).tag.name
            node.attribute.append(make_attribute("node_layout", node_tag))

            input_index = self.md.get_node_activations_index(node)
            for a_name, a_idx in input_index.items():
                node.attribute.append(
                    make_attribute(f"input_{a_idx}_layout", denotation_of(a_name))
                )

            output_index = self.md.get_node_outputs_index(node)
            for a_name, a_idx in output_index.items():
                node.attribute.append(
                    make_attribute(f"output_{a_idx}_layout", denotation_of(a_name))
                )

        model_path = os.path.join(self.out_dir, file_name)
        save_model(_model, model_path, external_data=self.external_data)
        logger.debug(f"> Tagged model saved: {model_path}")

    def save_tensor_layout_dict(
            self, file_name: str, layout_dict: dict[str, TaggedShape] | None = None, NCHW_inputs: list | None = None,
    ):
        """Dump a layout dictionary as JSON into ``out_dir / file_name``.

        Every entry records layout name, shape, tag source, an element type
        (the op type for nodes, ``"Initializer"``, ``"Activation"`` or null),
        the denotation, and whether the name appears in ``NCHW_inputs``.
        Falsy arguments fall back to ``self.tensors_layout`` / an empty list.
        """
        layout_dict = layout_dict or self.tensors_layout
        NCHW_inputs = NCHW_inputs or []

        def element_type(name_: str):
            # Checked in priority order: node, initializer, activation.
            if self.md.is_node(name_):
                return self.md.nodes[name_].op_type
            if self.md.is_initializer(name_):
                return "Initializer"
            if self.md.is_activation(name_):
                return "Activation"
            return None

        mem_layout_dict = {}
        for _name, tag_shape in layout_dict.items():
            _type = element_type(_name)
            mem_layout_dict[_name] = {
                "layout": tag_shape.tag.name,
                "shape": tag_shape.shape,
                "source": tag_shape.tag_source,
                "type": _type,
                "denotation": tag_shape.shape_denotation,
                "is_nchw_input": _name in NCHW_inputs,
            }

        target_path = os.path.join(self.out_dir, file_name)
        with open(target_path, 'w') as f:
            json.dump(mem_layout_dict, f, indent=2)
        return
    pass


class OpTag(NamedTuple):
    shape: TensorShape
    tag: str                    # one of {NONE, UNKNOWN, ANY, NCHW, NHWC, CONFLICT}

    @classmethod
    def NOT_FOUND(cls) -> 'OpTag':
        return OpTag([], "Not Found")

    def with_tag(self, tag_: str | Layout) -> 'OpTag':
        if isinstance(tag_, Layout):
            tag_ = tag_.name
        return OpTag(self.shape, tag_)


class ModChain(NamedTuple):
    """Immutable record describing one chain.

    NOTE(review): no construction or use of this type is visible in this part
    of the file; the per-field notes below are inferred from names and types
    only and should be confirmed against the callers.
    """
    start: str              # presumably the name the chain starts at
    chain: list[str]        # presumably the names along the chain
    ends: list[str]         # presumably the names the chain terminates at
    conv_end: bool          # same flag name as NCHWSubgraph.conv_end
    transpose_end: bool     # same flag name as NCHWSubgraph.transpose_end
    id: int                 # numeric identifier of the chain


@dataclass
class NCHWSubgraph:
    id: int
    roots: list[str] = field(default_factory=list)
    body: list[str] = field(default_factory=list)
    leaves: list[str] = field(default_factory=list)
    conv_end: bool = False
    transpose_end: bool = False

    _md: model_dict = None
    _tag_ops: dict[str, OpTag] = field(default_factory=dict, init=False)
    _null_subgraph: ClassVar['NCHWSubgraph'] = field(default=None, init=False, )

    @classmethod
    def NULL(cls, *, reset=False) -> 'NCHWSubgraph':
        if not cls._null_subgraph or reset:
            cls._null_subgraph = NCHWSubgraph(0)
        return NCHWSubgraph._null_subgraph

    def __len__(self):
        return len(self.roots) + len(self.body) + len(self.leaves)

    def __eq__(self, obj):
        if not obj or not isinstance(obj, NCHWSubgraph):
            return False
        return self.id == obj.id

    def __repr__(self):
        return (
            f"NCHWSubgraph("
            f"id={self.id}, "
            f"roots={len(self.roots)}, body={len(self.body)}, leaves={len(self.leaves)}"
            f")"
        )

    def __str__(self):
        return f"NCHWSubgraph(id={self.id}, len={len(self)})"

    def with_roots(self, nodes_: Iterable[str] | str) -> 'NCHWSubgraph':
        self.roots.clear()
        if isinstance(nodes_, str):
            self.roots.append(nodes_)
        else:
            self.roots.extend(nodes_)
        return self

    def with_context(self, md: model_dict, tag_ops: dict[str, OpTag]) -> 'NCHWSubgraph':
        self._md = md
        self._tag_ops = tag_ops
        return self

    def contains_node(self, node: str) -> bool:
        if not node:
            return False
        _contains = node in self.roots or node in self.leaves
        _contains |= node in self.body
        return _contains

    def get_sub_nodes(self) -> list[str]:
        return self.roots + self.body + self.leaves

    def absorb(self, sg_minor: 'NCHWSubgraph') -> Iterable[str]:
        logger.debug(f"[ABSORB SG] :: main={str(self)} ; minor={repr(sg_minor)}")
        if not sg_minor:
            return []
        # TODO: if uniqueness is not critical then it can be sacrificed for speed
        absorbed_roots = set(sg_minor.roots) - set(self.roots)
        absorbed_body = set(sg_minor.body) - set(self.body)
        absorbed_leaves = set(sg_minor.leaves) - set(self.leaves)
        self.roots.extend(absorbed_roots)
        self.body.extend(absorbed_body)
        self.leaves.extend(absorbed_leaves)
        # return sg_minor.get_sub_nodes()
        return list(absorbed_roots) + list(absorbed_body) + list(absorbed_leaves)

    def is_valid_subgraph(self) -> Optional[bool]:
        """
        Checks if this is a VALID NCHW-subgraph:
            1) sub_roots  are : Transpose<NCHW>,
            2) sub_nodes  are : Node<NCHW> and not NCHW-ops
            3) sub_leaves are : Transpose<NHWC>
        """
        if not self._md or not self._tag_ops:
            return None

        # empty subgraph is NOT VALID
        if len(self) == 0:
            return False

        # check SUB_ROOTS
        if not self.has_valid_subroots():
            return False

        # check SUB_NODES
        for sub_node in self.body:
            if self._md.get_op_type(sub_node) in NCHW_OPERATORS:
                return False
            if self._tag_ops.get(sub_node, OpTag.NOT_FOUND()).tag != "NCHW":
                return False

        # check SUB_LEAVES
        if not self.has_valid_subleaves():
            return False

        # subgraph is VALID
        return True

    def has_valid_subroots(self) -> Optional[bool]:
        """
        Checks if this is a VALID NCHW-subgraph:
            *) sub_roots  are : Transpose<NCHW>,
        """
        if not self._md or not self._tag_ops:
            return None

        # empty SUB_ROOTS are NOT VALID
        if not self.roots:
            return False

        # check SUB_ROOTS
        for sub_root in self.roots:
            if not self.is_valid_subroot(sub_root):
                return False

        return True

    def is_valid_subroot(self, sub_root: str):
        if self._md.get_op_type(sub_root) != "Transpose":
            return False
        if self._tag_ops.get(sub_root, OpTag.NOT_FOUND()).tag != "NCHW":
            return False
        return True

    def has_valid_subleaves(self) -> Optional[bool]:
        """
        Checks if this is a VALID NCHW-subgraph:
            3) sub_leaves are : Transpose<NHWC>
        """
        if not self._md or not self._tag_ops:
            return None

        # empty SUB_LEAVES are NOT VALID
        if not self.leaves:
            return False

        # check SUB_LEAVES
        for sub_leaf in self.leaves:
            if not self.is_valid_subleaf(sub_leaf):
                return False

        return True

    def is_valid_subleaf(self, sub_leaf: str):
        if self._md.get_op_type(sub_leaf) != "Transpose":
            return False
        if self._tag_ops.get(sub_leaf, OpTag.NOT_FOUND()).tag != "NHWC":
            return False
        return True

    def clear(self):
        self.roots.clear()
        self.body.clear()
        self.leaves.clear()
        return

    def asdict(self) -> dict[str: int|list[str]]:
        sg_dict = dict(
            id=self.id,
            roots=self.roots,
            body=self.body,
            leaves=self.leaves,
        )
        return sg_dict


class NHWCLayoutConverter:

    def __init__(self, model: onnx.ModelProto, model_name, out_dir, external_data: bool, dbg=False):
        """
        Set up the converter state for ``model`` and run the initial model checks.

        :param model: ONNX model that is converted in place
        :param model_name: prefix used for every file written by this object
        :param out_dir: directory that receives the debug / report files
        :param external_data: forwarded to ``save_model`` when models are written
        :param dbg: default debug switch; enables the extra checks and dumps of ``check_model``
        """
        self.model: onnx.ModelProto = model
        self.model_name = model_name
        self.out_dir = out_dir
        self.external_data = external_data

        self.md: model_dict = model_dict.create(self.model)
        self._init_handlers_to_NHWC(self.md)

        self.tag_ops: dict[str, OpTag] = dict()         # {node_name -> node.output[0].shape, layout_tag}
        self.frozen_layout: dict[str, str] = dict()     # {node_name -> frozen layout tag}
        self.transpose_ops = list()
        self.new_transposes = {'input': {'NCHW': [], 'NHWC': []}, 'output': {'NCHW': [], 'NHWC': []}}
        self.NCHW_inputs = []

        self.nchw_subgraphs: dict[int, NCHWSubgraph] = dict()           # {sub_id -> subgraph}
        self.nodes_to_subgraphs: dict[str, int] = defaultdict(int)      # {node_name -> subgraph.id}
        self.tensor_map: dict[str, dict] = dict()                       # {tensor name -> orig/final shape+layout record}

        self._dbg = dbg
        # check_model() starts with check_model_vinfo(), so no separate call is made before it;
        # it must run last because its debug dump reads tag_ops / frozen_layout.
        self.check_model(dbg)
        return

    def _infer_layout_and_insert_transposes(self, dbg=False) -> list[str]:
        """
        Tag the model with layouts and run all shielding passes.

        Steps, in order: layout tagging (LayoutTagger), tensor-map and node-tag initialisation,
        shielding of NCHW ops, NHWC ops, graph inputs/outputs and 3D MatMuls. The model
        dictionary is refreshed after every shielding pass. With ``dbg`` set, intermediate
        models / layout dumps are written after each step.

        :return: the value returned by ``shield_NCHW_graph_inputs()``
        """
        # LAYOUT TAGGING V2
        layout_tagger = LayoutTagger(
            self.model, self.model_name, self.out_dir,
            external_data=self.external_data, verbose=False,
        )
        layout_tagger.tag_nodes_v2(dbg=dbg)
        if dbg:
            layout_tagger.save_tensor_layout_dict(f"{self.model_name}_1_tagged_layout.json")
            layout_tagger.save_tagged_model(f"{self.model_name}_1_tagged_layout.onnx")

        self.md = layout_tagger.md
        self.transpose_ops = list(self.md.get_nodes_dict("Transpose"))

        # INITIALIZE TENSOR MAP (initializers are not tracked)
        for t_name, tagged in layout_tagger.tensors_layout.items():
            if not self.md.is_initializer(t_name):
                self.tensor_map[t_name] = {
                    "orig_tensor": t_name,
                    "orig_shape": tagged.shape,
                    "orig_layout": tagged.tag.name,
                    "final_shape": tagged.shape,
                    "final_layout": tagged.tag.name,
                }

        # INITIALIZE NODE LAYOUT TAGS
        node_tags = {}
        for n_name, n_layout in layout_tagger.get_nodes_layout().items():
            node_tags[n_name] = OpTag(n_layout.shape, n_layout.tag.name)
        self.tag_ops = node_tags
        self.check_tag_status()
        if dbg:
            self.save_memory_layout_dict(f"{self.model_name}_1_tagged_layout_nodes.json")
            self.save_dbg_model(f"{self.model_name}_1_tagged_layout_nodes.onnx")

        # T-SHIELDING of NCHW ops
        self.shield_NCHW_ops()
        self.md.update_dict(self.model)
        logger.debug("... layout shielding of NCHW ops is DONE")
        if dbg:
            self.save_memory_layout_dict(f"{self.model_name}_2A_after_shielding_NCHW_ops.json")
            self.save_dbg_model(f"{self.model_name}_2A_after_shielding_NCHW_ops.onnx")

        # T-SHIELDING of NHWC ops
        self.shield_NHWC_ops()
        self.md.update_dict(self.model)
        logger.debug("... layout shielding of NHWC ops is DONE")
        if dbg:
            self.save_dbg_model(f"{self.model_name}_2B_after_shielding_NHWC_ops.onnx")
            self.save_memory_layout_dict(f"{self.model_name}_2B_after_shielding_NHWC_ops.json")

        # T-SHIELDING of GRAPH Inputs/Outputs
        nchw_inputs = self.shield_NCHW_graph_inputs()
        self.shield_NCHW_graph_outputs()
        self.md.update_dict(self.model)
        logger.debug("... layout shielding of Graph Inputs/Outputs is DONE")
        if dbg:
            self.save_dbg_model(f"{self.model_name}_3_after_shielding_IO.onnx")
            self.save_memory_layout_dict(f"{self.model_name}_3_after_shielding_IO.json", nchw_inputs)

        # T-SHIELDING of 3D MatMuls
        self.shield_3D_MatMuls()
        self.md.update_dict(self.model)
        logger.debug("... layout shielding of 3D MatMul(A=ini, B=act) ops is DONE")
        if dbg:
            self.save_dbg_model(f"{self.model_name}_4_after_shielding_3D_MatMuls.onnx")
            self.save_memory_layout_dict(f"{self.model_name}_4_after_shielding_3D_MatMuls.json", nchw_inputs)

        self.check_tag_status()

        # Persist the tensor map as it looks before the NHWC conversion
        tensor_map_path = os.path.join(self.out_dir, f"{self.model_name}_tensor_map_before.json")
        with open(tensor_map_path, "w") as f:
            json.dump(self.tensor_map, f, indent=2)

        return nchw_inputs

    def pre_fusion(self, dbg=None) -> None:
        """
        Pre-fusion pass over the activation ops Relu / Clip / LeakyRelu.

        An activation node is processed only when
            - its input[0] has a writer node and that writer is not a DequantizeLinear, and
            - its output[0] is read by exactly one node.
        For such a node two nodes built by ``_make_q_node`` / ``_make_dq_node`` (presumably a
        QuantizeLinear -> DequantizeLinear pair) are appended to the model in front of it, and the
        activation is re-added as a copy reading from the second node's output.
        Scale / zero-point come from the single consumer when it is a QuantizeLinear; otherwise
        dummy initializers (scale 1, zero-point 0) are created.

        When at least one node was changed the graph is re-sorted via
        ``graph_postprocessing_cleanup(check=False)``.

        :param dbg: debug switch; ``None`` means "use the value given to the constructor"
        """
        dbg = self._dbg if dbg is None else dbg

        # PROCYON FIX / PRE-FUSION PASS
        logger.info(f"\033[34mStart PRE-FUSION\033[0m")
        if dbg: self.save_dbg_model(f"{self.model_name}_00_pre_fusion.onnx")

        qdq_fixed = False
        for act_op_name, act_op in self.md.get_nodes_dict("Relu", "Clip", "LeakyRelu").items():
            act_sc, act_zp = None, None  # qdq_scale_0, qdq_zero_point_0
            act_float_dtype, act_int_dtype = DEFAULT_FLOAT_DTYPE, DEFAULT_INT_DTYPE

            # [A.1] CHECK INPUT:
            act_op_input_name = act_op.input[0]
            act_op_supplier = self.md.get_writer(act_op_input_name)
            act_op_input_shape = self.md.get_shape(act_op_input_name)
            # the float dtype of the inserted chain follows the activation's input tensor
            act_float_dtype = self.md.get_onnx_dtype(act_op_input_name)

            if not act_op_supplier:
                logger.warning(
                    f"\033[34m[PRE-FUSION]: \033[0m"
                    f"an activation op-node {act_op.op_type}({act_op.name}) "
                    f"has no node-suppliers writing its input[0]={act_op_input_name}."
                )
                continue

            # input already comes from a DequantizeLinear: nothing to fix
            if act_op_supplier.op_type == "DequantizeLinear":
                continue

            # [A.2] CHECK OUTPUT:
            act_op_output_name = act_op.output[0]
            act_op_consumer = self.md.get_readers(act_op_output_name)
            if not act_op_consumer:
                logger.warning(
                    f"\033[34m[PRE-FUSION]: \033[0m"
                    f"an activation op-node {act_op.op_type}({act_op.name}) "
                    f"has no node-consumers reading from its output[0]={act_op_output_name}."
                )
                continue

            if len(act_op_consumer) > 1:
                logger.warning(
                    f"\033[34m[PRE-FUSION]: \033[0m"
                    f"an activation op-node {act_op.op_type}({act_op.name}) "
                    f"has {len(act_op_consumer)} node-consumers reading from its output[0]={act_op_output_name}."
                )
                continue

            # NOTE(review): the second line of this message repeats the consumer-count text of the
            # warning above and the same "will be QDQ-FIXED" message is logged again below --
            # looks like a copy/paste leftover; confirm the intended wording.
            logger.debug(
                f"\033[34m[PRE-FUSION]: \033[0m"
                f"an activation op-node {act_op.op_type}({act_op.name}) will be QDQ-FIXED with: \n"
                f"has {len(act_op_consumer)} node-consumers reading from its output[0]={act_op_output_name}."
            )

            # [1] GET QDQ PARAMS:
            # exactly one consumer is left at this point; unpack it
            act_op_consumer, *_ = act_op_consumer
            if act_op_consumer.op_type == "QuantizeLinear":
                act_sc = self.md.get_tensor(act_op_consumer.input[1])
                act_zp = self.md.get_tensor(act_op_consumer.input[2])
                # act_float_dtype = self.md.get_onnx_dtype(act_sc)
                act_int_dtype = self.md.get_onnx_dtype(act_zp)
            else:
                # no quantization parameters available: create dummy scale=1 / zero_point=0 initializers
                act_int_dtype = DEFAULT_INT_DTYPE
                act_sc = numpy_helper.from_array(
                    np.array(1).astype(tensor_dtype_to_np_dtype(act_float_dtype)), f"{act_op_name}_dummy_qdq_scale"
                )
                act_zp = numpy_helper.from_array(
                    np.array(0).astype(tensor_dtype_to_np_dtype(act_int_dtype)), f"{act_op_name}_dummy_qdq_zero_point"
                )
                # NOTE(review): act_sc / act_zp are rebound to the *return value* of append() and then
                # passed to _make_q_node / _make_dq_node. Whether a protobuf repeated-field append()
                # returns the appended message or None may depend on the protobuf backend in use --
                # confirm that the helpers receive what they expect here.
                act_sc = self.model.graph.initializer.append(act_sc)
                act_zp = self.model.graph.initializer.append(act_zp)

            logger.debug(
                f"\033[34m[PRE-FUSION]: \033[0m"
                f"an activation op-node {act_op.op_type}({act_op.name}) will be QDQ-FIXED with: \n"
            )

            # [2] SETUP:
            qdq_fixed = True
            q_in_name = f"{act_op_name}_fixed_in_q"
            dq_in_name = f"{act_op_name}_fixed_in_dq"

            # [3] INSERT FIXED QDQ CHAIN:
            # first node: reads the activation's original input, output typed with the integer dtype
            q_in, q_in_out = self._append_node_to_model(*self._make_q_node(
                q_in_name, act_sc, act_zp,
                # act_op_input_name, connect_qdq_name, act_op_input_shape
                input=act_op_input_name,
                output_shape=act_op_input_shape,
                output_dtype=act_int_dtype,
            ))
            # second node: reads the first node's output, output typed with the float dtype
            dq_in, dq_in_out = self._append_node_to_model(*self._make_dq_node(
                dq_in_name, act_sc, act_zp,
                input=q_in_out.name,
                output_shape=act_op_input_shape,
                output_dtype=act_float_dtype,
            ))

            # replace the activation by a copy that reads from the inserted chain; the copy is
            # appended at the end of the node list, the node order is restored by the graph sort
            # in graph_postprocessing_cleanup() below
            act_op_ = copy.deepcopy(act_op)
            act_op_.input[0] = dq_in_out.name
            self.model.graph.node.remove(act_op)
            self.model.graph.node.append(act_op_)

            pass

        if qdq_fixed:
            self.graph_postprocessing_cleanup(check=False, dbg=dbg)
        logger.info(f"\033[34mPRE-FUSION is COMPLETE\033[0m")
        return

    def layout_tagging_and_shielding(self, dbg=None):
        """
        Run layout tagging + shielding and store the result in ``self.NCHW_inputs``.

        Failures of the tagging/shielding step are logged (message and traceback) and not
        re-raised. On success the graph is cleaned up without the graph-status check.
        The memory-layout JSON is written in every case.

        :param dbg: debug switch; ``None`` means "use the value given to the constructor"
        """
        if dbg is None:
            dbg = self._dbg

        # SHIELDING & LAYOUT TAGGING
        try:
            self.NCHW_inputs = self._infer_layout_and_insert_transposes(dbg=dbg)
        except Exception as e:
            failure_trace = traceback.format_exc()
            logger.error(f"Layout conversion has FAILED during shielding and layout tagging with {e}")
            logger.error(failure_trace)
        else:
            logger.info("\033[32mLayout shielding and tagging is COMPLETE\033[0m")
            self.graph_postprocessing_cleanup(check=False, dbg=dbg)
        finally:
            layout_file = f"{self.model_name}_memory_layout.json"
            self.save_memory_layout_dict(layout_file, self.NCHW_inputs)
        return

    def _convert_to_nhwc(self, NCHW_inputs: list[str], dbg=False):
        """
        Collect, patch and transform the NCHW-subgraphs of the model.

        The model dictionary is refreshed after patching and after the transformation; the
        NHWC handlers are re-created after patching. With ``dbg`` set, subgraph dictionaries
        and debug models are written after every stage. The tensor map is always written to
        ``<model_name>_tensor_map.json``.

        :param NCHW_inputs: only forwarded to the memory-layout dump of the last stage
        """
        # COLLECT NCHW-SUBGRAPHS
        self.collect_NCHW_subgraphs()
        logger.debug(f"... subgraph collection is DONE: {len(self.nchw_subgraphs)} NCHW-subgraphs were COLLECTED")
        if dbg:
            self.save_subgraph_dict(f"{self.model_name}_5A_collected_subgraphs.json", self.nchw_subgraphs)
            self.save_dbg_model(f"{self.model_name}_5A_collected_subgraphs.onnx", self.nodes_to_subgraphs)

        # PATCH NCHW-SUBGRAPHS
        patched = self.patch_NCHW_subgraphs()
        self.md.update_dict(self.model)
        self._init_handlers_to_NHWC(self.md)
        logger.debug(f"... subgraph patching is DONE: {len(patched)} NCHW-subgraphs were PATCHED")
        if dbg:
            self.save_subgraph_dict(f"{self.model_name}_5B_patched_subgraphs.json", self.nchw_subgraphs)
            self.save_dbg_model(f"{self.model_name}_5B_patched_subgraphs.onnx", self.nodes_to_subgraphs)

        # TRANSFORM LAYOUT of NCHW-SUBGRAPHS
        transformed = self.transform_subgraphs(self.nchw_subgraphs)
        self.md.update_dict(self.model)
        logger.debug(f"... layout transformation is DONE: {len(transformed)} NCHW-subgraphs were TRANSFORMED")
        if dbg:
            self.save_subgraph_dict(f"{self.model_name}_6_transformed_subgraphs.json", transformed)
            self.save_dbg_model(f"{self.model_name}_6_after_transformation.onnx")
            self.save_memory_layout_dict(f"{self.model_name}_6_after_transformation.json", NCHW_inputs)

        # Persist the final tensor map
        tensor_map_path = os.path.join(self.out_dir, f"{self.model_name}_tensor_map.json")
        with open(tensor_map_path, "w") as f:
            json.dump(self.tensor_map, f, indent=2)

        self.check_tag_status()
        check_binary_shapes(self.md)
        return

    def convert_layout_to_nhwc(self, dbg=None):
        """
        Convert the model to NHWC layout using the inputs recorded in ``self.NCHW_inputs``.

        Failures of the conversion are logged (message and traceback) and not re-raised.
        On success the graph is cleaned up including the graph-status check.
        The NHWC memory-layout JSON is written in every case.

        :param dbg: debug switch; ``None`` means "use the value given to the constructor"
        """
        if dbg is None:
            dbg = self._dbg

        # CONVERT MODEL TO NHWC LAYOUT
        try:
            self._convert_to_nhwc(self.NCHW_inputs, dbg=dbg)
        except Exception as e:
            failure_trace = traceback.format_exc()
            logger.error(f"Layout conversion has FAILED with {e}")
            logger.error(failure_trace)
        else:
            logger.info("\033[32mLayout conversion is COMPLETE\033[0m")
            self.graph_postprocessing_cleanup(dbg=dbg)
        finally:
            layout_file = f"{self.model_name}_nhwc_memory_layout.json"
            self.save_memory_layout_dict(layout_file, self.NCHW_inputs)
        return

    def graph_postprocessing_cleanup(self, check=True, dbg=False):
        """
        Sort the graph and refresh the model dictionary.

        :param check: additionally run ``check_graph_status()``
        :param dbg: additionally save ``<model_name>_debug.onnx``
        """
        graph_sort(self.model, 0)
        self.md.update_dict(self.model)

        if check:
            self.check_graph_status()

        if dbg:
            debug_file = f"{self.model_name}_debug.onnx"
            self.save_dbg_model(debug_file)
        return

    def _init_handlers_to_NHWC(self, md=None):
        """
        (Re)build ``self.handlers_to_NHWC``: {op_type -> callable(node_name)}.

        Every callable forwards to ``self.<Handler>_to_NHWC(md, node_name)``. The handler method
        is looked up on ``self`` only when the callable is invoked.

        :param md: model dictionary bound into the handlers; ``self.md`` when omitted / falsy
        """
        md = md if md else self.md

        def _dispatch_to(method_name):
            # late lookup: the method is resolved at call time, not when the table is built
            return lambda _name: getattr(self, method_name)(md, _name)

        # op types served by a method named after the op type itself
        own_handler_ops = (
            'Reshape', 'Squeeze', 'Unsqueeze', 'MatMul', 'Mul', 'Div', 'Add', 'Sub',
            'Concat', 'Resize', 'Slice', 'LayerNormalization', 'LpNormalization',
            'Flatten', 'Softmax',
        )
        # all reductions share a single handler
        reduce_ops = (
            'ReduceSum', 'ReduceMean', 'ReduceL2', 'ReduceL1',
            'ReduceSumSquare', 'ReduceProd', 'ReduceMin', 'ReduceMax',
        )
        trailing_ops = ('Split', 'Pad')

        handlers = {}
        for op_type in own_handler_ops:
            handlers[op_type] = _dispatch_to(f"{op_type}_to_NHWC")
        for op_type in reduce_ops:
            handlers[op_type] = _dispatch_to("Reduce_to_NHWC")
        for op_type in trailing_ops:
            handlers[op_type] = _dispatch_to(f"{op_type}_to_NHWC")

        self.handlers_to_NHWC = handlers
        return

    def save_dbg_model(self, file_name, subgraphs_idx=None):
        """
        Save a deep copy of the model, annotated with debug attributes, to ``out_dir / file_name``.

        Each node of the copy gets ``mem_layout`` (its layout tag, or the NOT_FOUND tag) and
        ``frozen_layout`` (whether the node name is in ``self.frozen_layout``). When
        ``subgraphs_idx`` is given and non-empty, ``subgraph`` holds the stringified entry for
        the node, or "NONE". ``self.model`` itself is left untouched.
        """
        annotated = copy.deepcopy(self.model)
        for node in annotated.graph.node:
            layout_tag = self.tag_ops.get(node.name, OpTag.NOT_FOUND()).tag
            node.attribute.append(make_attribute("mem_layout", layout_tag))

            is_frozen = node.name in self.frozen_layout
            node.attribute.append(make_attribute("frozen_layout", is_frozen))

            if subgraphs_idx:
                subgraph_label = str(subgraphs_idx.get(node.name, "NONE"))
                node.attribute.append(make_attribute("subgraph", subgraph_label))

        model_path = os.path.join(self.out_dir, file_name)
        save_model(annotated, model_path, external_data=self.external_data)
        logger.debug(f"> Debug model saved: {model_path}")

    def save_subgraph_dict(self, file_name, nchw_subgraphs: dict[int, NCHWSubgraph]) -> None:
        """Write ``{subgraph id -> subgraph.asdict()}`` as JSON to ``out_dir / file_name``."""
        serializable = {}
        for sub_id, sub in nchw_subgraphs.items():
            serializable[sub_id] = sub.asdict()

        target_path = os.path.join(self.out_dir, file_name)
        with open(target_path, 'w') as f:
            json.dump(serializable, f, indent=2)
        return

    def save_memory_layout_dict(self, file_name, NCHW_inputs:list=None, nodes_layout:dict=None):
        """
        Write the per-node memory layout as JSON to ``out_dir / file_name``.

        :param NCHW_inputs: node names flagged with ``is_nchw_input = true`` (defaults to none)
        :param nodes_layout: {node_name -> (shape, layout)}; ``self.tag_ops`` when omitted or empty

        Nodes that are not present in the model dictionary are reported with op_type "MISSING".
        """
        flagged_inputs = NCHW_inputs or []
        layouts = nodes_layout or self.tag_ops

        report = {}
        for node_name, (out_shape, layout) in layouts.items():
            if node_name in self.md.nodes:
                op_type = self.md.nodes[node_name].op_type
            else:
                op_type = "MISSING"
            report[node_name] = {
                "op_type": op_type,
                "layout": layout,
                "out_shape": out_shape,
                "is_nchw_input": node_name in flagged_inputs,
            }

        target_path = os.path.join(self.out_dir, file_name)
        with open(target_path, 'w') as f:
            json.dump(report, f, indent=2)
        return

    def load_memory_layout_dict(self, mem_layout_path):
        """
        Load a memory-layout JSON file and rebuild ``self.tag_ops`` from it.

        The file is expected to hold, per node name, the keys ``out_shape`` and ``layout`` plus
        the NCHW-input flag as written by ``save_memory_layout_dict`` (``is_nchw_input``).

        :param mem_layout_path: path of the JSON file
        :return: list of node names flagged as NCHW inputs
        :raises KeyError: when an entry lacks ``out_shape`` or ``layout``
        """
        with open(mem_layout_path, 'r') as f:
            mem_layout_dict = json.load(f)

        self.tag_ops = dict()
        NCHW_inputs = []
        for node_name, info in mem_layout_dict.items():
            self.tag_ops[node_name] = OpTag(info['out_shape'], info['layout'])
            # save_memory_layout_dict() stores the flag as "is_nchw_input"; the former lookup of
            # "nchw_input" raised KeyError on such files. The old key is still accepted as fallback.
            if info.get('is_nchw_input', info.get('nchw_input', False)):
                NCHW_inputs.append(node_name)
        return NCHW_inputs

    def check_model(self, dbg):
        """
        Run the value-info check; with ``dbg`` set also write the shape / rank / transpose
        reports, log the NCHW-op statistics and save the ``_0_before_tagging`` debug model.
        """
        self.check_model_vinfo()
        if not dbg:
            return

        self.check_tensor_shapes()
        self.check_tensor_ranks()
        self.check_transpose_perms()
        self.check_nchw_ops()
        self.save_dbg_model(f"{self.model_name}_0_before_tagging.onnx")
        return

    def check_model_vinfo(self):
        """Log (info level) every value-info tensor that has more than one writer."""
        writers_by_tensor = self.md.tensor_writers
        for t_name, _tensor in self.md.vinfo.items():
            writers = writers_by_tensor.get(t_name, [])
            if len(writers) <= 1:
                continue
            logger.info(f"Tensor {t_name} has multiple writers: {writers}")

    def _get_shapes_by_dim(self, tensors: Iterable[ValueInfoProto]) -> dict[str, list[dict]]:
        """
        Count the distinct shapes of ``tensors`` and group them under the key ``"<get_dim(shape)>D"``.

        :return: dict sorted by key; each value is a list of ``{"shape": tuple, "count": int}``
        """
        occurrences = Counter(tuple(self.md.get_shape(t)) for t in tensors)

        grouped = defaultdict(list)
        for shape, n_seen in occurrences.items():
            key = f"{get_dim(shape)}D"
            grouped[key].append({"shape": shape, "count": n_seen})

        return dict(sorted(grouped.items()))

    def check_tensor_shapes(self) -> None:
        """
        Write ``check_tensor_shapes.json`` into ``out_dir``: shape statistics (see
        ``_get_shapes_by_dim``) for graph inputs, graph outputs and all value-info tensors.
        """
        all_tensors = self._get_shapes_by_dim(self.md.vinfo.values())
        graph_inputs = self._get_shapes_by_dim(self.md.inputs.values())
        graph_outputs = self._get_shapes_by_dim(self.md.outputs.values())

        report = {
            "graph_inputs": graph_inputs,
            "graph_outputs": graph_outputs,
            "tensors": all_tensors,
        }
        report_path = os.path.join(self.out_dir, "check_tensor_shapes.json")
        with open(report_path, 'w') as f:
            json.dump(report, f, indent=2)
        return

    def _get_shapes_by_dim_rank(self, tensors: Iterable[ValueInfoProto]) -> dict[str, list[dict]]:
        """
        Count the distinct shapes of ``tensors`` and group them under the key
        ``"<get_dim(shape)>D(<get_rank(shape)>)"``.

        :return: dict sorted by key; each value is a list of ``{"shape": tuple, "count": int}``
        """
        occurrences = Counter(tuple(self.md.get_shape(t)) for t in tensors)

        grouped = defaultdict(list)
        for shape, n_seen in occurrences.items():
            key = f"{get_dim(shape)}D({get_rank(shape)})"
            grouped[key].append({"shape": shape, "count": n_seen})

        return dict(sorted(grouped.items()))

    def check_tensor_ranks(self) -> None:
        """
        Write ``check_tensor_ranks.json`` into ``out_dir``: dim/rank statistics (see
        ``_get_shapes_by_dim_rank``) for graph inputs, graph outputs and all value-info tensors.
        """
        all_tensors = self._get_shapes_by_dim_rank(self.md.vinfo.values())
        graph_inputs = self._get_shapes_by_dim_rank(self.md.inputs.values())
        graph_outputs = self._get_shapes_by_dim_rank(self.md.outputs.values())

        report = {
            "graph_inputs": graph_inputs,
            "graph_outputs": graph_outputs,
            "tensors": all_tensors,
        }
        report_path = os.path.join(self.out_dir, "check_tensor_ranks.json")
        with open(report_path, 'w') as f:
            json.dump(report, f, indent=2)
        return

    def check_transpose_perms(self) -> None:
        """
        Write ``check_transpose_perms.json`` into ``out_dir``: the Transpose nodes of the model
        grouped by their ``perm`` attribute, and those groups bucketed by ``"<len(perm)>D"``.
        """
        # perm (as tuple) -> names of the Transpose nodes using it
        nodes_by_perm = defaultdict(list)
        for n_name, node in self.md.nodes.items():
            if node.op_type != "Transpose":
                continue
            nodes_by_perm[tuple(get_attribute(node, "perm"))].append(n_name)

        buckets = defaultdict(list)
        for perm, node_names in nodes_by_perm.items():
            entry = {"perm": perm, "count": len(node_names), "nodes": node_names}
            buckets[f"{len(perm)}D"].append(entry)
        report = dict(sorted(buckets.items()))

        report_path = os.path.join(self.out_dir, "check_transpose_perms.json")
        with open(report_path, 'w') as f:
            json.dump(report, f, indent=2)
        return

    def check_nchw_ops(self) -> None:
        """
        Log (debug level) how many nodes of the NCHW operator types exist and, per op type,
        the ``"<dim>D(<rank>)"`` statistics of their input activations and of their outputs.
        """
        md = self.md
        found_ops = md.get_nodes(op_types=NCHW_OPERATORS)
        logger.debug(f"Found {len(found_ops)} NCHW ops")

        ops_by_type = defaultdict(list)
        for op in found_ops:
            ops_by_type[op.op_type].append(op)

        def _dim_rank(act):
            return f"{md.get_dim(act)}D({md.get_rank(act)})"

        for op_type, nodes in ops_by_type.items():
            # the names in_act_dims / out_act_dims are part of the log text ("{name = }" syntax)
            in_act_dims = dict(Counter(
                [_dim_rank(act) for n in nodes for act in md.get_node_activations(n)]
            ))
            out_act_dims = dict(Counter(
                [_dim_rank(act) for n in nodes for act in md.get_node_outputs(n)]
            ))
            logger.debug(f"\t#{op_type} = {len(nodes)} :: {in_act_dims = } : {out_act_dims = }")

        return

    def get_shape(self, tensor_: str | ValueInfoProto | TensorProto | TensorInfo) -> TensorShape:
        """Return the shape of ``tensor_``; thin wrapper that delegates to ``self.md.get_shape()``."""
        # TODO: replace self.get_shape() with self.md.get_shape()
        return self.md.get_shape(tensor_)

    def get_NHWC_shape(self, tensor_: str | ValueInfoProto | TensorProto | TensorInfo | TensorShape) -> TensorShape:
        """
        Return the shape of ``tensor_`` permuted by ``convert_shape_to_NHWC``.

        A list / tuple argument is taken as the shape itself; anything else is resolved
        through ``self.md.get_shape()``. The permutation itself is discarded.
        """
        if isinstance(tensor_, (list, tuple)):
            source_shape = tensor_
        else:
            source_shape = self.md.get_shape(tensor_)
        nhwc_shape, _perm = self.convert_shape_to_NHWC(source_shape)
        return nhwc_shape

    def get_onnx_dtype(self, tensor_: str | ValueInfoProto | TensorProto | TensorInfo) -> OnnxDType:
        """Return the ONNX dtype of ``tensor_``; thin wrapper that delegates to ``self.md.get_onnx_dtype()``."""
        # TODO: replace self.get_onnx_dtype() with self.md.get_onnx_dtype()
        return self.md.get_onnx_dtype(tensor_)

    def get_out_shape(self, node: str | NodeProto, out_idx=0, *, with_dtype=False) -> TensorShape | tuple[TensorShape, OnnxDType]:
        """
        Return the shape of output ``out_idx`` of ``node``.

        :param node: node or node name, resolved through ``self.md.get_node()``
        :param with_dtype: return ``(shape, onnx_dtype)`` instead of the shape alone
        :return: an empty list when the node cannot be resolved (also with ``with_dtype=True``)
        """
        resolved = self.md.get_node(node)
        if not resolved:
            return list()

        out_vinfo = self.md.vinfo[resolved.output[out_idx]]
        out_shape = self.md.get_shape(out_vinfo)
        if not with_dtype:
            return out_shape

        return out_shape, self.md.get_onnx_dtype(out_vinfo)

    def get_tag(self, node: str | Optional[NodeProto]) -> str:
        """
        Return the layout tag recorded for ``node`` in ``self.tag_ops``.

        The node is resolved through ``self.md.get_node()``; a NodeProto argument is also
        passed as that call's second argument. Unresolved or untagged nodes yield the
        NOT_FOUND tag.
        """
        fallback = node if node and isinstance(node, NodeProto) else None
        resolved = self.md.get_node(node, fallback)
        if resolved and resolved.name in self.tag_ops:
            return self.tag_ops[resolved.name].tag
        return OpTag.NOT_FOUND().tag

    def transpose_vinfo_nhwc(self, vinfo: ValueInfoProto):
        """
        Return a new value info with the name and dtype of ``vinfo`` and its shape permuted
        with [0, 2, 1] for 3 dimensions, [0, 2, 3, 1] otherwise.
        """
        src_shape = self.get_shape(vinfo)
        if len(src_shape) == 3:
            axes = [0, 2, 1]
        else:
            axes = [0, 2, 3, 1]
        elem_type = self.get_onnx_dtype(vinfo)
        permuted_shape = [src_shape[axis] for axis in axes]
        return onnx.helper.make_tensor_value_info(vinfo.name, elem_type, permuted_shape)

    def transpose_vinfo_nchw(self, vinfo: ValueInfoProto):
        """
        Return a new value info with the name and dtype of ``vinfo`` and its shape permuted
        with [0, 2, 1] for 3 dimensions, [0, 3, 1, 2] otherwise.
        """
        src_shape = self.get_shape(vinfo)
        if len(src_shape) == 3:
            axes = [0, 2, 1]
        else:
            axes = [0, 3, 1, 2]
        elem_type = self.get_onnx_dtype(vinfo)
        permuted_shape = [src_shape[axis] for axis in axes]
        return onnx.helper.make_tensor_value_info(vinfo.name, elem_type, permuted_shape)

    def convert_tag_to_NHWC(self, tag: OpTag) -> OpTag:
        """Return a new OpTag tagged 'NHWC' whose shape is ``get_NHWC_shape(tag.shape)``."""
        return OpTag(self.get_NHWC_shape(tag.shape), 'NHWC')

    def convert_shape_to_NHWC(self, shape: TensorShape) -> tuple[TensorShape, Perm]:
        """
        Permute ``shape`` with the entry of NHWC_PERMUTATIONS registered for ``len(shape)``.

        :return: ``(permuted_shape, perm)``; ``([], [])`` when no (non-empty) permutation is
                 registered for that length -- no exception is raised in this case
        """
        perm = NHWC_PERMUTATIONS.get(len(shape), [])
        if perm:
            permuted = [shape[axis] for axis in perm]
            return permuted, perm
        return [], []

    def check_tag_status(self):
        """
        Log node / tag statistics: number of graph nodes, tagged nodes and transpose nodes,
        then the count per layout tag ("Don't care" at warning level, all others at debug level)
        and the total tag count.
        """
        logger.debug(f"\tTotal number of graph nodes: {len(self.md.nodes)}")
        logger.debug(f"\tTotal number of tagged nodes: {len(self.tag_ops)}")
        logger.debug(f"\tTotal number of transpose nodes: {len(self.transpose_ops)}")

        tag_counter = Counter([tag for _shape, tag in self.tag_ops.values()])
        for layout, count in tag_counter.items():
            emit = logger.warning if layout == "Don't care" else logger.debug
            emit(f"\t{layout} tag count {count}")
        logger.debug(f"\tTotal tag count {tag_counter.total()}\n")

    def check_graph_status(self):
        """
        Report the nodes that are still tagged 'NCHW' and exist in the model dictionary.

        Nodes are grouped by op_type; DequantizeLinear is split into '-Initializer' (its
        input[0] is an initializer) and '-Datapath'. The per-type counts are logged in sorted
        order and ``{op_type -> {"count", "nodes"}}`` is written to ``check_graph_status.json``.
        """
        md = self.md
        logger.debug("Graph status - NCHW nodes:")

        g_status = defaultdict(list)
        for node_name, (_shape, tag) in self.tag_ops.items():
            if tag != 'NCHW' or node_name not in md.nodes:
                continue
            node = md.nodes[node_name]
            op_type = node.op_type
            if op_type == 'DequantizeLinear':
                suffix = '-Initializer' if node.input[0] in md.ini else '-Datapath'
                op_type = op_type + suffix
            g_status[op_type].append(node_name)

        # replace each name list by its summary record; keys (and their order) stay untouched
        for op_type in sorted(g_status):
            node_names = g_status[op_type]
            n_nodes = len(node_names)
            g_status[op_type] = {"count": n_nodes, "nodes": node_names}
            logger.debug(f"\t{op_type} {n_nodes}")

        report_path = os.path.join(self.out_dir, "check_graph_status.json")
        with open(report_path, 'w') as f:
            json.dump(g_status, f, indent=2)

    def _shield_NCHW_op_input_simple(
        self, op_name: str, in_shield_output: str,
    ) -> ValueInfoProto:
        """
        Shielding INPUT of a NCHW node "C"

        ===========================================================================================
                                (DQ0)   --[X]->   (C)   --[Y]->   (Q1)
        -------------------------------------------------------------------------------------------

                (DQ0) --[X]->                                                 -> (C)  --[Y]->  (Q1)
        ===>
                (DQ0) --[X]-> (Q0*) -> (dq>T0>q) -> (dq>T1>q) -> (DQ0*) --[X*]-> (C)
                      <NCHW>               ||  <NHWC>   ||              <NCHW>

        ===========================================================================================

        - T0: toNHWC,
        - T1: toNCHW
        - layout[X] == layout[C] == NCHW
        - QDQ<IN>: [sc, zp] := DQ0.[sc, zp]

        -------------------------------------------------------------------------------------------

        The chain is appended to the model and its nodes are tagged in `self.tag_ops`;
        node C itself is not modified here, the caller has to connect it to the
        returned tensor.

        Args:
            op_name: name of the NCHW node C, resolved through `self.md.get_node`.
            in_shield_output: name given to the last tensor of the chain ([X*]).

        Returns:
            Value info of the tensor that closes the chain: the output of the tail
            DQ0* when a DequantizeLinear supplier was found, otherwise the output
            of the T1 group.
        """
        has_qdq, sc0, zp0 = False, None, None  # qdq_scale_0, qdq_zero_point_0
        float_dtype, int_dtype = DEFAULT_FLOAT_DTYPE, DEFAULT_INT_DTYPE
        nodeC = self.md.get_node(op_name)

        # [1] IN-SHIELD I/O and SETUP
        inC: ValueInfoProto = self.md.get_node_activations(nodeC, first=True)      # in_shield_input
        # NOTE(review): this TensorInfo is rebound below (step [3.2]) before it is ever read
        inC_ = TensorInfo(                                                          # in_shield_output
            in_shield_output, self.md.get_shape(inC), self.md.get_onnx_dtype(inC)
        )

        # in_shield_input = inC_.name
        in_shield_input = inC.name
        transpose_nhwc_0 = f"{op_name}_in_transpose_nhwc_0"
        transpose_nchw_1 = f"{op_name}_in_transpose_nchw_1"
        nchw_shape = self.md.get_shape(inC)
        nhwc_shape = self.get_NHWC_shape(nchw_shape)
        # both permutation tables are indexed by get_dim() of the NCHW shape
        nchw_perm = NCHW_PERMUTATIONS[get_dim(nchw_shape)]
        nhwc_perm = NHWC_PERMUTATIONS[get_dim(nchw_shape)]

        # [2] check for DQ0 at C's input
        if dq0 := self.md.get_node_suppliers(nodeC, "DequantizeLinear", first=True):
            # scale / zero-point are reused by every Q/DQ node of the shield
            # NOTE(review): input[2] assumes DQ0 has an explicit zero-point input - confirm
            has_qdq, sc0, zp0 = True, dq0.input[1], dq0.input[2]
            float_dtype = self.get_onnx_dtype(sc0)
            int_dtype = self.get_onnx_dtype(zp0)
        else:
            # NOTE(review): the adjacent f-strings below join without a separating space
            logger.debug(
                f"NCHW SHIELDING: "
                f"hasn't found a DQ-node at the inputs of a NCHW op-node {nodeC.op_type}({nodeC.name})."
                f"Transpose nodes without Q/DQ around will be inserted."
            )
            pass

        # [3.0] insert new Q0*:      (DQ0)  --[osh.in]->   (Q0*)   --[q0*.out]->
        if has_qdq:
            _q0, _q0_out = self._append_node_to_model(*self._make_q_node(
                f"{op_name}_in_shield_head_q0", sc0, zp0,
                input=in_shield_input,
                output_shape=nchw_shape,
                output_dtype=int_dtype,
            ))
            self.tag_ops[_q0.name] = OpTag(self.md.get_shape(_q0_out), 'NCHW')
            # the T0 group below now reads the re-quantized tensor
            in_shield_input = _q0_out.name

        # [3.1] insert T0-shield:           --[q0*.out]->   (dq->T0->q)   --[ish0]->
        shield_out_0, shield_0 = self._append_dq_transpose_q(
            transpose_nhwc_0, nhwc_perm, sc0, zp0,
            input=in_shield_input,
            # make_dq=False if has_qdq else True, # append_dq=False,
            input_shape=nchw_shape, output_shape=nhwc_shape,
            input_layout='NCHW', output_layout='NHWC',
            float_dtype=float_dtype, int_dtype=int_dtype ,
        )

        # [3.2] insert T1-shield:             --[ish0]->    (dq->T1->q)   --[ish1]->
        # without Q/DQ the T1 group is the end of the chain and carries the requested name
        shield_out_1, shield_1 = self._append_dq_transpose_q(
            transpose_nchw_1, nchw_perm, sc0, zp0,
            input=shield_out_0.name,
            output=None if has_qdq else in_shield_output,
            input_shape=nhwc_shape, output_shape=nchw_shape,
            input_layout='NHWC', output_layout='NCHW',
            float_dtype=float_dtype, int_dtype=int_dtype,
        )
        inC_ = shield_out_1

        # [3.3] insert in-shield's tail:          --[ish1]->     (DQ0*)    --[ish.out]-> (C)
        if has_qdq:
            _dq0, _dq0_out = self._append_node_to_model(*self._make_dq_node(
                f"{op_name}_in_shield_tail_dq0", sc0, zp0,
                input=shield_out_1.name,
                output=in_shield_output,
                output_shape=nchw_shape,
                output_dtype=float_dtype,
            ))
            self.tag_ops[_dq0.name] = OpTag(nchw_shape, 'NCHW')
            inC_ = _dq0_out

        # [4] record the origin of the T1-group output, but only when DQ0 is itself fed by a Q node
        if has_qdq:
            in_node1 = self.md.get_node_suppliers(dq0, "QuantizeLinear", first=True)
            if in_node1:
                self.tensor_map[shield_out_1.name] = dict(
                    orig_tensor=dq0.input[0],
                    orig_shape=self.md.get_shape(dq0.input[0]),
                    orig_layout=self.get_tag(dq0),
                    final_shape=self.md.get_shape(shield_out_1),
                    final_layout=self.get_tag(shield_1['q'][0])
                )

        # [5] register both shield transposes by name
        self.transpose_ops.append(transpose_nhwc_0)
        self.transpose_ops.append(transpose_nchw_1)
        return inC_

    def _shield_NCHW_op_output_simple(
            self, op_name: str, in_shield_output: str, out_shield_input: str,
    ) -> None:
        """
        Shielding OUTPUT of a NCHW node "C"

        1) the default case:
        ===========================================================================================
                                (DQ0)   --[X]->   (C)   --[Y]->   (Q1)
        -------------------------------------------------------------------------------------------

         (DQ0)  --[X]->  (C)                                                       --[Y]->  (Q1)

        ==>              (C)  --[Y*]->  (Q1*) -> (dq>T0>q) -> (dq>T1>q) -> (DQ1*)  --[Y]->  (Q1)
                               <NCHW>                ||  <NHWC>   ||               <NCHW>

        ===========================================================================================


        2) a non-qdq op-node R followed by Q-node is allowed at C's output:
        ===========================================================================================
                            (DQ0)   --[X]->   (C) -> (R)   --[Y]->   (Q1)
        -------------------------------------------------------------------------------------------

        (DQ0)  --[X]->  (C) -> (R)                                                     --[Y]->  (Q1)

        ==>             (C) -> (R) --[Y*]-> (Q1*) -> (dq>T0>q) -> (dq>T1>q) -> (DQ1*)  --[Y]->  (Q1)
                                    <NCHW>               ||  <NHWC>   ||            <NCHW>

        ===========================================================================================

        - T0: toNHWC
        - T1: toNCHW
        - layout[C] == layout[R] == layout[Y] == NCHW
        - QDQ<OUT>: [sc, zp] := Q1.[sc, zp]

        -------------------------------------------------------------------------------------------

        Besides inserting the output chain, this method also rewires node C: its first
        input is pointed at `in_shield_output`, and the first output of C (or of R in
        case 2) is renamed to `out_shield_input`.

        Args:
            op_name: name of the NCHW node C, resolved through `self.md.get_node`.
            in_shield_output: tensor name that becomes the first input of C.
            out_shield_input: new name for the tensor [Y*] that feeds the output chain.
        """
        has_qdq, sc1, zp1 = False, None, None  # qdq_scale, qdq_zero_point
        float_dtype, int_dtype = DEFAULT_FLOAT_DTYPE, DEFAULT_INT_DTYPE
        nodeC = self.md.get_node(op_name)

        # TODO: move modification of C-node into separate method
        # [A.1] MODIFY C-NODE: connect to the output of  IN-shield
        # work on a copy; the original node is swapped out of the graph further below
        nodeC_ = copy.deepcopy(nodeC)
        nodeC_.input[0] = in_shield_output

        # [A.2] check for non-Q node (Relu) at C's output:
        #     - to find what 'out_shield_output' is
        #     - where to attach 'out_shield_input' -- at C's output or at non-Q's output
        #     - update the C-node and non-Q nodes accordingly
        if nonQ := self.md.get_node_consumers(nodeC, "QuantizeLinear", exclude=True):
            if len(nonQ) > 1:
                logger.warning(
                    f"NCHW SHIELDING: "
                    f"a NCHW op-node {nodeC.op_type}({nodeC.name}) "
                    f"has {len(nonQ)} non-QuantizeLinear node-consumers."
                )
                # TODO: raise exception?

            # only one consumer is followed (whichever popitem() yields); any others are ignored
            _, nonQ = nonQ.popitem()
            logger.debug(
                f"NCHW SHIELDING: "
                f"a NCHW op-node {nodeC.op_type}({nodeC.name}) "
                f"has a single non-QuantizeLinear node-consumer {nonQ.op_type}({nonQ.name})."
            )

            # [A.3] MODIFY C-NODE: save the updated C-node and move pointer to the matched nonQ-node
            # NOTE(review): unlike step [A.4], self.md.nodes / tag_ops / frozen_layout are not
            #               refreshed for the re-inserted copy of C here - confirm this is intended
            self.model.graph.node.remove(nodeC)
            self.model.graph.node.append(nodeC_)
            nodeC = nonQ
            nodeC_ = copy.deepcopy(nodeC)
            pass

        # OUT SHIELD I/O
        # from here on `nodeC` is the node whose output gets shielded (C itself, or R in case 2)
        outC = self.md.get_node_outputs(nodeC, first=True)            # out_shield_output
        outC_ = TensorInfo(                                                 # out_shield_input
            out_shield_input, self.md.get_shape(outC), self.md.get_onnx_dtype(outC)
        )
        self.model.graph.value_info.append(outC_.make_activation_tensor())

        # [A.4] MODIFY C-NODE: make the input of OUT-shield and connect it the output of C-node
        nodeC_.output[0] = out_shield_input
        self.model.graph.node.remove(nodeC)
        self.model.graph.node.append(nodeC_)
        self.md.nodes[nodeC_.name] = nodeC_
        self.tag_ops[nodeC_.name] = OpTag(self.md.get_shape(outC), 'NCHW')
        self.frozen_layout[nodeC_.name] = 'NCHW'

        # [1] setup OUT-shield
        # out_shield_input = outC_.name
        # the chain ends on the original output name, so existing consumers stay connected
        out_shield_output = outC.name

        transpose_nhwc_0 = f"{op_name}_out_transpose_nhwc_0"
        transpose_nchw_1 = f"{op_name}_out_transpose_nchw_1"
        nchw_shape = self.md.get_shape(outC)
        nhwc_shape = self.get_NHWC_shape(nchw_shape)
        nchw_perm = NCHW_PERMUTATIONS[get_dim(nchw_shape)]
        nhwc_perm = NHWC_PERMUTATIONS[get_dim(nchw_shape)]

        # [2] check for Q1-node at C's outputs
        if q1 := self.md.get_node_consumers(nodeC, "QuantizeLinear", first=True):
            # NOTE(review): input[2] assumes Q1 has an explicit zero-point input - confirm
            has_qdq, sc1, zp1 = True, q1.input[1], q1.input[2]
            float_dtype = self.get_onnx_dtype(sc1)
            int_dtype = self.get_onnx_dtype(zp1)
        else:
            # NOTE(review): the adjacent f-strings below join without a separating space
            logger.debug(
                f"NCHW SHIELDING: "
                f"hasn't found a Q-node consumer at the outputs of a NCHW op-node {nodeC.op_type}({nodeC.name})."
                f"Transpose nodes without Q/DQ around will be inserted."
            )

        # [3.0] insert copy of Q1:      (C) --[osh.in]->    (Q1*)   --[q1*.out]->
        _q1, _q1_out = None, None
        if has_qdq:
            _q1, _q1_out = self._append_node_to_model(*self._make_q_node(
                f"{op_name}_out_shield_head_q1", sc1, zp1,
                input=out_shield_input,
                output_shape=nchw_shape,
                output_dtype=int_dtype,
            ))
            self.tag_ops[_q1.name] = OpTag(self.md.get_shape(_q1_out), 'NCHW')
            # from now on `out_shield_input` names the output of Q1*
            out_shield_input = _q1_out.name

        # [3.1] insert T0-shield:           --[q1*.out]->   (dq->T0->q)  --[osh0]->
        shield_out_0, shield_0 = self._append_dq_transpose_q(
            transpose_nhwc_0, nhwc_perm, sc1, zp1,
            input=out_shield_input,
            input_shape=nchw_shape, output_shape=nhwc_shape,
            input_layout='NCHW', output_layout='NHWC',
            float_dtype=float_dtype, int_dtype=int_dtype,
        )

        # [3.2] insert T1-shield:               --[osh0]->    (dq->T1->q) -> (DQ1*)   --[osh.out]->  (Q1)
        # without Q/DQ the T1 group writes the original output name directly
        shield_out_1, shield_1 = self._append_dq_transpose_q(
            transpose_nchw_1, nchw_perm, sc1, zp1,
            input=shield_out_0.name,
            output=None if has_qdq else out_shield_output,
            input_shape=nhwc_shape, output_shape=nchw_shape,
            input_layout='NHWC', output_layout='NCHW',
            float_dtype=float_dtype, int_dtype=int_dtype,
        )

        # insert out-shield's tail  (DQ1*)
        # note: the tail name is derived from the T1 transpose name, not from op_name
        if has_qdq:
            _dq1, _dq1_out = self._append_node_to_model(*self._make_dq_node(
                f"{transpose_nchw_1}_out_shield_tail_dq1", sc1, zp1,
                input=shield_out_1.name,
                output=out_shield_output,
                output_shape=nchw_shape,
                output_dtype=float_dtype,
            ))
            self.tag_ops[_dq1.name] = OpTag(self.md.get_shape(_dq1_out), 'NCHW')

        # [4] record in tensor_map which original tensors the new shield tensors correspond to
        if has_qdq:         # out_shield_input == _q1_out.name
            # probe for an existing  Q1 -> DQ -> Transpose -> Q  chain (single consumer at each hop)
            has_dq, has_tr, has_q = False, False, False
            out_nodes1 = self.md.get_node_consumers(q1)
            if len(out_nodes1) == 1 and next(iter(out_nodes1.values())).op_type == "DequantizeLinear":
                has_dq = True
                out_nodes2 = self.md.get_node_consumers(next(iter(out_nodes1.values())))
                if len(out_nodes2) == 1 and next(iter(out_nodes2.values())).op_type == "Transpose":
                    has_tr = True
                    out_nodes3 = self.md.get_node_consumers(next(iter(out_nodes2.values())))
                    if len(out_nodes3) == 1 and next(iter(out_nodes3.values())).op_type == "QuantizeLinear":
                        has_q = True
                        final_q = next(iter(out_nodes3.values()))
                        final_q_out = final_q.output[0]
                        pass
                    pass
                pass

            if has_dq and has_tr and has_q:
                # full chain found: the T0 output is mapped to the output of the chain's last Q
                self.tensor_map[shield_out_0.name] = dict(
                    orig_tensor=final_q_out,
                    orig_shape=self.md.get_shape(final_q_out),
                    orig_layout=self.get_tag(final_q),
                    final_shape=self.md.get_shape(shield_out_0),
                    final_layout=self.get_tag(shield_0['q'][0]),
                )
            else:
                # otherwise the T0 output is mapped to the output of Q1
                self.tensor_map[shield_out_0.name] = dict(
                    orig_tensor=q1.output[0],
                    orig_shape=self.md.get_shape(q1.output[0]),
                    orig_layout=self.get_tag(q1),
                    final_shape=self.md.get_shape(shield_out_0),
                    final_layout=self.get_tag(shield_0['q'][0]),
                )

            # the output of Q1* is always mapped to the output of Q1
            self.tensor_map[out_shield_input] = dict(
                orig_tensor=q1.output[0],
                orig_shape=self.md.get_shape(q1.output[0]),
                orig_layout=self.get_tag(q1),
                final_shape=self.md.get_shape(_q1_out),
                final_layout=self.get_tag(_q1),
            )


        # [5] register both shield transposes by name
        self.transpose_ops.append(transpose_nhwc_0)
        self.transpose_ops.append(transpose_nchw_1)
        return

    def shield_NCHW_ops(self):
        """
        Double shielding of NCHW ops:
            inserts a pair of extra transpose nodes at the inputs and outputs of NCHW op-nodes,
            separating them from the rest of the Graph and protecting their layout from modifying.

        Every node returned by ``self.md.get_nodes_dict(*NCHW_OPERATORS)`` is processed in
        two steps: its input is shielded first, then the node is connected to the end of
        that input chain and its output is shielded.

        Raises:
            Exception: any error raised while shielding a node is logged together with
                the stage ("inputs" or "outputs") that failed, and then re-raised unchanged.
        """
        nchw_nodes = self.md.get_nodes_dict(*NCHW_OPERATORS)
        logger.info(
            f"Start layout shielding of NCHW ops: "
            f"{len(nchw_nodes)} NCHW op-node{'s' if len(nchw_nodes) != 1 else ''} will be shielded"
        )

        for op_name, nodeC in nchw_nodes.items():
            # tracks which half is running so the error message names the right one
            stage = "inputs"
            try:
                # SHIELD INPUTS  ----------------------------------------------------------------------
                in_shield_output = self._shield_NCHW_op_input_simple(
                    nodeC.name,  in_shield_output=f"{nodeC.name}_shielded_input"
                )

                # MODIFY CENTRAL NODE  ----------------------------------------------------------------
                # TODO: move modification of C-node into separate method

                # SHIELD OUTPUTS  ---------------------------------------------------------------------
                stage = "outputs"
                self._shield_NCHW_op_output_simple(
                    nodeC.name, in_shield_output.name, out_shield_input=f"{nodeC.name}_shielded_output"
                )

            except Exception as e:
                logger.error(f"Error occurred during shielding {stage} of NCHW Node({op_name}): {e}")
                # bare raise re-raises the active exception with its traceback untouched
                raise

        # inputs and outputs are shielded node by node, so only one final count exists
        logger.debug(f"  * Tags after shielding NCHW ops: {len(self.tag_ops)}")
        return

    @classmethod
    def _make_q_node(
            cls, name,
            scale: TensorProto|str, zero_point: TensorProto|str,
            input=None, output=None,
            output_shape=None,
            output_dtype=DEFAULT_INT_DTYPE,
            make_q_out=True,
            **kwargs
    ) -> tuple[NodeProto, ValueInfoProto]:
        """
        Build (without adding it to any graph) a QuantizeLinear node.

        `scale` / `zero_point` may be given as initializer tensors or as tensor
        names. Missing `input` / `output` names default to "<name>_in" and
        "<name>_out". The node domain is taken from ``kwargs['domain']`` and
        falls back to 'com.microsoft'.

        Returns:
            (node, value_info) - value_info describes the node output and is
            None when `make_q_out` is false or `output_shape` is empty/None.
        """
        def _tensor_name(ref):
            # accept either an initializer or a plain name
            return ref.name if isinstance(ref, TensorProto) else ref

        in_name = input or f"{name}_in"
        out_name = output or f"{name}_out"

        node = make_node(
            'QuantizeLinear',
            name=name,
            inputs=[in_name, _tensor_name(scale), _tensor_name(zero_point)],
            outputs=[out_name],
            domain=kwargs.get('domain', 'com.microsoft'),
        )

        if not (make_q_out and output_shape):
            return node, None
        return node, make_tensor_value_info(out_name, output_dtype, output_shape)

    @classmethod
    def _make_dq_node(
            cls, name,
            scale: TensorProto|str, zero_point: TensorProto|str,
            input=None, output=None,
            output_shape=None,
            output_dtype=DEFAULT_FLOAT_DTYPE,
            make_dq_out=True,
            **kwargs
    ) -> tuple[NodeProto, ValueInfoProto]:
        """
        Build (without adding it to any graph) a DequantizeLinear node.

        `scale` / `zero_point` may be given as initializer tensors or as tensor
        names. Missing `input` / `output` names default to "<name>_in" and
        "<name>_out". The node domain is taken from ``kwargs['domain']`` and
        falls back to 'com.microsoft'.

        Returns:
            (node, value_info) - value_info describes the node output and is
            None when `make_dq_out` is false or `output_shape` is empty/None.
        """
        src = input or f"{name}_in"
        dst = output or f"{name}_out"
        # initializers are referenced by name, plain names are used as they are
        qparams = [
            ref.name if isinstance(ref, TensorProto) else ref
            for ref in (scale, zero_point)
        ]

        node = make_node(
            'DequantizeLinear',
            name=name,
            inputs=[src, *qparams],
            outputs=[dst],
            domain=kwargs.get('domain', 'com.microsoft'),
        )

        value_info = None
        if make_dq_out and output_shape:
            value_info = make_tensor_value_info(dst, output_dtype, output_shape)
        return node, value_info

    @classmethod
    def _make_transpose_node(
            cls, name, perm,
            input=None, output=None,
            output_shape=None,
            output_dtype=DEFAULT_FLOAT_DTYPE,
            make_tr_out=True,
            **kwargs
    ) -> tuple[NodeProto, ValueInfoProto]:
        """
        Build (without adding it to any graph) a Transpose node with the given `perm`.

        Missing `input` / `output` names default to "<name>_in" and "<name>_out".
        The node domain is ``kwargs.get('domain')``, i.e. None when not supplied.

        Returns:
            (node, value_info) - value_info describes the node output and is
            None when `make_tr_out` is false or `output_shape` is empty/None.
        """
        src = input or f"{name}_in"
        dst = output or f"{name}_out"

        node = make_node(
            'Transpose',
            name=name,
            inputs=[src],
            outputs=[dst],
            perm=perm,
            domain=kwargs.get('domain'),
        )

        if make_tr_out and output_shape:
            return node, make_tensor_value_info(dst, output_dtype, output_shape)
        return node, None

    @classmethod
    def _make_dq_transpose_q(
            cls, tr_name:str,
            perm: list[int],
            scale: TensorProto|str, zero_point: TensorProto|str,
            input_shape: list[int],    # shape before Tr
            output_shape: list[int],   # shape after  Tr
            input:str=None, output:str=None,
            make_dq=True, make_q=True,
            float_dtype=DEFAULT_FLOAT_DTYPE,
            int_dtype=DEFAULT_INT_DTYPE,
            **kwargs,
    ) -> dict[str, tuple[NodeProto, ValueInfoProto]]:
        """
        Build (without adding anything to a graph) the group  DQ -> Transpose -> Q.

        The Q/DQ wrappers are only created when both `scale` and `zero_point`
        are given; `make_dq` / `make_q` can switch each wrapper off on top of
        that. `output` names the last tensor of the group: the Q output when a
        Q node is built, the Transpose output otherwise.

        Returns:
            dict with keys 'dq', 'tr' and 'q', each holding a (node, value_info)
            pair; pairs of wrappers that were not built are (None, None).

        Raises:
            ValueError: when `input_shape` or `output_shape` is empty/None.
        """
        if not (input_shape and output_shape):
            raise ValueError("Input and output shapes are required ")

        quantized = scale is not None and zero_point is not None
        with_dq = quantized and make_dq
        with_q = quantized and make_q

        dq_pair = (None, None)
        if with_dq:
            dq_pair = cls._make_dq_node(
                f"{tr_name}_in_dq", scale, zero_point,
                input=input,
                output_shape=input_shape,
                output_dtype=float_dtype,
                **kwargs,
            )
            # the transpose consumes the dequantized tensor
            input = dq_pair[1].name

        # the requested output name belongs to whichever node ends the group
        tr_pair = cls._make_transpose_node(
            tr_name, perm,
            input=input,
            output=None if with_q else output,
            output_shape=output_shape,
            output_dtype=float_dtype,
            **kwargs,
        )

        q_pair = (None, None)
        if with_q:
            q_pair = cls._make_q_node(
                f"{tr_name}_out_q", scale, zero_point,
                input=tr_pair[1].name,
                output=output,
                output_shape=output_shape,
                output_dtype=int_dtype,
                **kwargs,
            )

        return {'dq': dq_pair, 'tr': tr_pair, 'q': q_pair}

    @classmethod
    def _make_reshape_node(
            cls, name,
            input=None, output=None,
            output_shape=None,
            output_dtype=DEFAULT_INT_DTYPE,
            make_re_out=True,
            **kwargs
    ) -> tuple[NodeProto, ValueInfoProto, TensorProto]:
        """
        Build (without adding it to any graph) a Reshape node and its shape initializer.

        The target shape is stored as an int64 initializer named "<name>_shape"
        (empty when `output_shape` is empty/None). Missing `input` / `output`
        names default to "<name>_in" and "<name>_out". ``kwargs['domain']``
        (default 'com.microsoft') and ``kwargs['allowzero']`` (default 0) are
        forwarded to the node.

        Returns:
            (node, value_info, shape_initializer) - value_info is None when
            `make_re_out` is false or the shape is empty.
        """
        src = input or f"{name}_in"
        dst = output or f"{name}_out"
        dims = output_shape or []

        shape_tensor = onnx.numpy_helper.from_array(
            np.array(dims).astype(np.int64),
            f"{name}_shape"
        )

        node = make_node(
            "Reshape",
            name=name,
            inputs=[src, shape_tensor.name],
            outputs=[dst],
            domain=kwargs.get('domain', 'com.microsoft'),
            allowzero=kwargs.get('allowzero', 0),
        )

        value_info = (
            make_tensor_value_info(dst, output_dtype, dims)
            if make_re_out and dims
            else None
        )
        return node, value_info, shape_tensor

    def _append_node_to_model(
            self, n: NodeProto | None = None, n_out: ValueInfoProto | None = None, *n_ini: TensorProto
    ) -> tuple[NodeProto, ValueInfoProto, list[TensorProto, ...]] | tuple[NodeProto, ValueInfoProto]:
        if n:
            self.model.graph.node.append(n)
            n = self.model.graph.node[-1]
        if n_out:
            self.model.graph.value_info.append(n_out)
            n_out = self.model.graph.value_info[-1]
        if n_ini:
            for _ini in n_ini:
                self.model.graph.initializer.append(_ini)
        return n, n_out, *n_ini

    def _append_dq_transpose_q(
            self, tr_name,
            perm: list[int],
            scale: TensorProto | str, zero_point: TensorProto | str,
            input_shape: list[int],    # shape before Tr
            output_shape: list[int],   # shape after  Tr
            input_layout: str, output_layout: str,
            append_q=True, append_dq=True,
            **kwargs,
    ) -> tuple[ValueInfoProto, dict[str, tuple[NodeProto, ValueInfoProto]]]:
        """
        Build the group  DQ -> Transpose -> Q,  add it to the model and tag its nodes.

        The DQ node is tagged with (`input_shape`, `input_layout`); the Transpose
        and the Q node with (`output_shape`, `output_layout`). Q/DQ wrappers are
        only handled when both `scale` and `zero_point` are given. Remaining
        keyword arguments are passed through to `_make_dq_transpose_q`.

        Returns:
            (group_output, group) - value info of the last appended tensor (the Q
            output if a Q node was appended, else the Transpose output) and the
            dict produced by `_make_dq_transpose_q`.
        """
        quantized = not (scale is None or zero_point is None)

        # NOTE(review): append_q / append_dq are not forwarded as make_q / make_dq, so a
        #               wrapper that is skipped here is still built (and wired) by the builder
        #               unless the caller also passes make_q / make_dq - confirm this is intended
        group = self._make_dq_transpose_q(
            tr_name, perm, scale, zero_point,
            input_shape=input_shape, output_shape=output_shape,
            **kwargs
        )

        if quantized and append_dq:
            dq_node, _dq_info = self._append_node_to_model(*group['dq'])
            if dq_node:
                self.tag_ops[dq_node.name] = OpTag(input_shape, input_layout)

        tr_node, tail = self._append_node_to_model(*group['tr'])
        self.tag_ops[tr_node.name] = OpTag(output_shape, output_layout)

        if quantized and append_q:
            # the quantized tensor replaces the transpose output as the group's tail
            q_node, tail = self._append_node_to_model(*group['q'])
            if q_node:
                self.tag_ops[q_node.name] = OpTag(output_shape, output_layout)

        return tail, group

    def _copy_and_replace_input(
            self, node: NodeProto | str, old_input: str, new_input: str,
            remove=True, append=True,
    ) -> NodeProto:
        node = self.md.get_node(node)
        if remove:
            self.model.graph.node.remove(node)
        node_ = copy.deepcopy(node)
        idx_ = list(node.input).index(old_input)
        node_.input[idx_] = new_input

        if append:
            self.model.graph.node.append(node_)
            node_ = self.model.graph.node[-1]
            self.md.nodes[node_.name] = node_
        return node_

    def _copy_and_replace_output(
            self, node: NodeProto, old_output: str, new_output: str,
            remove=True, append=True,
    ) -> NodeProto:
        node = self.md.get_node(node)
        if remove:
            self.model.graph.node.remove(node)
        node_ = copy.deepcopy(node)
        idx_ = list(node.output).index(old_output)
        node_.output[idx_] = new_output

        if append:
            self.model.graph.node.append(node_)
            node_ =  self.model.graph.node[-1]
            self.md.nodes[node_.name] = node_
        return node_

    def shield_NCHW_graph_inputs(self) -> list[str]:
        """
        Double shielding of NCHW Graph Inputs:
            inserts a pair of extra transpose nodes (w/ or w/o qdq nodes)
            between Graph Inputs tagged as NCHW  and the rest of the Graph.

        A graph input is selected when at least one of its readers is a root node
        (per `self.md.is_root`) tagged "NCHW". Three cases are handled per input:

        - the first reader is a Transpose: inputs whose `perm` is not one of the
          NCHW permutations are skipped; otherwise a Reshape to the input's own
          shape is inserted in front of the Transpose and no transposes are added;
        - the first reader is a Q/DQ node: the transposes are wrapped in Q/DQ nodes
          that reuse the reader's scale and zero-point;
        - anything else: plain transposes are inserted and a warning is logged.

        Inputs of rank 1 are skipped.

        Returns:
            For every shielded input, the name of the node recorded for it: the
            reader Transpose in the first case, the to-NCHW shield transpose otherwise.
        """
        # check each input for NCHW layout
        NCHW_inputs = []
        md = self.md

        def is_nchw_root(n):
            # NOTE(review): indexes tag_ops directly, so an untagged root raises KeyError here
            return self.md.is_root(n) and self.tag_ops[n.name].tag == "NCHW"

        # graph input name -> names of all nodes reading it
        nchw_roots = {
            gin: md.tensor_readers[gin]
            for gin in md.inputs
            if any(map(is_nchw_root, self.md.get_nodes(md.tensor_readers[gin])))
        }
        logger.info(
            f"Start layout shielding of Graph Inputs: "
            f"{len(nchw_roots)} graph input{'s' if len(nchw_roots) != 1 else ''} will be shielded"
        )

        # IF node is a NCHW root ==> double shield the NCHW layout of incident GraphInput
        # NOTE(review): `node` is simply the first listed reader; it need not be the reader
        #               that satisfied is_nchw_root above - confirm
        for graph_input, (node, *siblings) in nchw_roots.items():
            input_tensor = md.get_tensor(graph_input)
            node = md.get_node(node)
            # logger.info(f">>> shielding: GraphInput({graph_input}) -> Node({node.name}) and {len(siblings)} others")

            if node.op_type == 'Transpose':
                n_perm = get_attribute(node, "perm")
                if n_perm not in NCHW_PERMUTATIONS.values():
                    continue

                # put a shape-preserving Reshape between the graph input and the Transpose
                re_name = f"{input_tensor.name}_reshape_dummy"
                re, re_out, re_shape_ini = self._make_reshape_node(
                    re_name,
                    input=input_tensor.name, output=f"{re_name}_out",
                    output_shape=self.get_shape(input_tensor),
                    output_dtype=self.get_onnx_dtype(input_tensor),
                    domain="",
                    allowzero=0,
                )
                self._append_node_to_model(re, re_out, re_shape_ini)
                self.tag_ops[re.name] = OpTag(self.get_shape(re_out), "NHWC")
                self._copy_and_replace_input(node, input_tensor.name, new_input=re_out.name)

                # NOTE(review): siblings keep reading the raw graph input in this branch
                NCHW_inputs.append(node.name)
                continue

            if node.op_type not in ("QuantizeLinear", "DequantizeLinear"):
                logger.warning(
                    f"GraphInput({graph_input}) expects an adjacent Q/DQ node, "
                    f"but found {node.op_type}({node.name}). "
                    f"Transpose nodes without Q/DQ around will be inserted."
                )
                has_qdq, qdq_scale, qdq_zero_point = False, None, None
            else:
                # Q/DQ.scale, Q/DQ.zero_point == node.input[1], node.input[2],
                # NOTE(review): input[2] assumes an explicit zero-point input - confirm
                has_qdq, qdq_scale, qdq_zero_point = True, node.input[1], node.input[2]

            # node is a Q/DQ-type node ==>  node.input[0] == Q/DQ.x,
            input_name_x = input_tensor.name
            input_dtype = self.get_onnx_dtype(input_tensor)
            output_dtype = self.get_onnx_dtype(node.output[0])
            # use the dtype of the graph input when it is of the wanted kind,
            # otherwise fall back to the dtype of the reader's output
            floattype = input_dtype if is_floating_dtype(input_dtype) else output_dtype
            uinttype = input_dtype if is_integer_dtype(input_dtype) else output_dtype

            nchw_shape = self.get_shape(input_tensor)
            # NOTE(review): get_NHWC_shape receives the tensor here, while the op-shielding
            #               methods pass a shape - presumably both are accepted; verify
            nhwc_shape = self.get_NHWC_shape(input_tensor)
            if len(nchw_shape) == 1:
                continue
            to_NCHW_perm = NCHW_PERMUTATIONS[len(nchw_shape)]
            to_NHWC_perm = NHWC_PERMUTATIONS[len(nchw_shape)]

            # MAKE SHIELDING CHAIN for GraphInput:
            #           [GIn] -> (_Q*) -> (dq->T0->q) -> (dq->T1->Q*) -> (node)
            # T0: to_nhwc, T1: to_nchw
            transpose_0 = f"{input_name_x}_transpose_nhwc_0"
            transpose_1 = f"{input_name_x}_transpose_nchw_1"

            # add _Q*
            # only a floating-point input is quantized before entering the chain
            if has_qdq and is_floating_dtype(input_dtype):
                _q, _q_out = self._append_node_to_model(*self._make_q_node(
                    f'{transpose_0}_in_q', qdq_scale, qdq_zero_point,
                    input=input_name_x,
                    output_shape=nchw_shape,
                    output_dtype=uinttype,
                ))
                self.tag_ops[_q.name] = OpTag(nchw_shape, 'NCHW')
                input_name_x = _q_out.name
                # tensor_map entries are only written for inputs with a single reader
                if len(self.md.tensor_readers[graph_input]) == 1:
                    # self.tensor_map[_q_out.name] = node.output[0]
                    self.tensor_map[_q_out.name] = dict(
                        orig_tensor=node.output[0],
                        orig_shape=self.md.get_shape(node.output[0]),
                        orig_layout=self.get_tag(node),
                        final_shape=self.md.get_shape(_q_out),
                        final_layout=self.get_tag(_q)
                    )
                pass

            # add NHWC-shield: DQ -> T0 -> Q
            shield_out_0, shield_0 = self._append_dq_transpose_q(
                transpose_0, to_NHWC_perm, qdq_scale, qdq_zero_point,
                input=input_name_x,
                input_shape=nchw_shape, output_shape=nhwc_shape,
                input_layout='NCHW', output_layout='NHWC',
                float_dtype=floattype, int_dtype=uinttype,
            )

            # add NCHW-shield: DQ -> T1 -> Q*
            # the trailing Q is appended only for integer inputs
            shield_out_1, shield_1 = self._append_dq_transpose_q(
                transpose_1, to_NCHW_perm, qdq_scale, qdq_zero_point,
                input=shield_out_0.name,
                append_q=is_integer_dtype(input_dtype),
                input_shape=nhwc_shape, output_shape=nchw_shape,
                input_layout='NHWC', output_layout='NCHW',
                float_dtype=floattype, int_dtype=uinttype,
            )

            # record the origin of the NHWC tensor (single-reader inputs only)
            if has_qdq and is_floating_dtype(input_dtype) and len(self.md.tensor_readers[graph_input]) == 1:
                # float input: mapped to the output of the reader node
                self.tensor_map[shield_out_0.name] = dict(
                    orig_tensor=node.output[0],
                    orig_shape=self.md.get_shape(node.output[0]),
                    orig_layout=self.get_tag(node),
                    final_shape=self.md.get_shape(shield_out_0),
                    final_layout=self.get_tag(shield_0['q'][0])
                )
            elif has_qdq and is_integer_dtype(input_dtype) and len(self.md.tensor_readers[graph_input]) == 1:
                # integer input: mapped to the output of the single Q consumer of the reader, if any
                out_nodes = self.md.get_node_consumers(node)
                if len(out_nodes) == 1 and next(iter(out_nodes.values())).op_type == "QuantizeLinear":
                    q_node = next(iter(out_nodes.values()))
                    self.tensor_map[shield_out_0.name] = dict(
                        orig_tensor=q_node.output[0],
                        orig_shape=self.md.get_shape(q_node.output[0]),
                        orig_layout=self.get_tag(q_node),
                        final_shape=self.md.get_shape(shield_out_0),
                        final_layout=self.get_tag(shield_0['q'][0])
                        )

            # put root node at the bottom of shielding chain:
            self._copy_and_replace_input(node, input_tensor.name, shield_out_1.name)

            # reconnect node's siblings
            for sb in siblings:
                self._copy_and_replace_input(sb, input_tensor.name, shield_out_1.name)

            NCHW_inputs.append(transpose_1)
        return NCHW_inputs

    def shield_NCHW_graph_outputs(self) -> None:
        """
        Double shielding of NCHW Graph Outputs:
            inserts a pair of extra transpose nodes (w/ or w/o qdq nodes)
            between Graph Outputs tagged as NCHW  and the rest of the Graph.

        -------------------------------------------------------------------------------------------

        For every graph output written by a leaf node tagged "NCHW" the writer
        is re-pointed to a fresh intermediate tensor and the chain

            (node) -> (_Q*) -> (dq->T0->q) -> (dq->T1->q) -> (_DQ*) -> [GOut]

        is appended to the model.

        - T0: toNHWC
        - T1: toNCHW
        - _Q* / _DQ* are only added when the leaf is a Q/DQ node and the graph
          output has a floating-point dtype.

        A graph output is skipped when its writer's op type is neither in
        LAYOUT_AWARE_OPS nor in LAYOUT_AGNOSTIC_OPS, or when get_dim() of its
        shape equals 1.

        Side effects visible in this method: self.model.graph (nodes and
        value_info), self.md.nodes, self.tag_ops and self.tensor_map are
        updated in place.

        -------------------------------------------------------------------------------------------
        """
        md = self.md
        def is_nchw_leaf(n):
            return self.md.is_leaf(n) and self.get_tag(n) == "NCHW"

        # {graph_output -> names of the nodes writing it}, restricted to outputs with a NCHW leaf writer
        nchw_leaves = {
            gout: md.tensor_writers[gout]
            for gout in self.md.outputs
            if any(map(is_nchw_leaf, self.md.get_nodes(md.tensor_writers[gout])))
        }
        logger.info(
            f"Start layout shielding of Graph Outputs: "
            f"{len(nchw_leaves)} graph output{'s' if len(nchw_leaves) != 1 else ''} will be shielded"
        )

        # IF node is a NCHW leaf ==> double shield the NCHW layout of incident GraphOutput
        for graph_output, (node_name, *siblings) in nchw_leaves.items():
            assert not siblings, f"Graph Output '{graph_output}' has multiple writers"

            node = self.md.get_node(node_name)
            output_tensor = self.md.vinfo[graph_output]
            if node.op_type not in (LAYOUT_AWARE_OPS | LAYOUT_AGNOSTIC_OPS):
                # FIX: the two message parts used to be glued together without a separating space
                logger.debug(
                    f"GraphOutput({graph_output}) is written by LeafNode::{node.op_type}<{self.get_tag(node)}>({node.name}) "
                    f"that does not support layout conversion."
                )
                continue
            elif node.op_type not in ("QuantizeLinear", "DequantizeLinear"):
                logger.warning(
                    f"GraphOutput({graph_output}) expects an adjacent Q/DQ node, "
                    f"but found LeafNode::{node.op_type}<{self.get_tag(node)}>({node.name}). "
                    f"Transpose nodes without Q/DQ around will be inserted."
                )
                has_qdq, qdq_scale, qdq_zero_point = False, None, None
            else:
                # Q/DQ.scale, Q/DQ.zero_point == node.input[1], node.input[2],
                has_qdq, qdq_scale, qdq_zero_point = True, node.input[1], node.input[2]

            # node is a Q/DQ-type node ==>  node.output[0] == Q/DQ.y,
            output_name_y = output_tensor.name
            output_dtype = self.get_onnx_dtype(output_tensor)
            input_dtype = self.get_onnx_dtype(node.input[0])
            # whichever side of the leaf is floating/integer supplies the float/int dtype of the shields
            floattype = output_dtype if is_floating_dtype(output_dtype) else input_dtype
            uinttype = output_dtype if is_integer_dtype(output_dtype) else input_dtype

            nchw_shape = self.get_shape(output_tensor)
            nhwc_shape = self.get_NHWC_shape(output_tensor)
            if get_dim(nchw_shape) == 1:
                # nothing has been modified yet for this output, so skipping is safe
                continue

            to_NHWC_perm = NHWC_PERMUTATIONS[get_dim(nchw_shape)]
            to_NCHW_perm = NCHW_PERMUTATIONS[get_dim(nchw_shape)]

            # MAKE SHIELDING CHAIN for GraphOutput:
            #           (node) -> (_Q*) -> (dq->T0->q) -> (dq->T1->q) -> (_DQ*) -> [GOut]
            # T0: to_nhwc, T1: to_nchw
            transpose_0 = f"{output_name_y}_transpose_nhwc_0"
            transpose_1 = f"{output_name_y}_transpose_nchw_1"

            # put the leaf node on top of the shielding chain:
            leaf_new_out = f'{transpose_0}_in_dq_in' if is_floating_dtype(output_dtype)  else f'{transpose_0}_in_q_in'
            self.model.graph.value_info.append(make_tensor_value_info(
                leaf_new_out, output_dtype, nchw_shape
            ))
            node_ = self._copy_and_replace_output(node, output_tensor.name, leaf_new_out)
            self.md.nodes[node_name] = node_

            # add _Q*
            if has_qdq and is_floating_dtype(output_dtype):
                _q, _q_out = self._append_node_to_model(*self._make_q_node(
                    f'{transpose_1}_in_q', qdq_scale, qdq_zero_point,
                    input=leaf_new_out,
                    output_shape=nchw_shape,
                    output_dtype=uinttype,
                ))
                self.tag_ops[_q.name] = OpTag(nchw_shape, 'NCHW')
                leaf_new_out = _q_out.name
                in_node = self.md.get_node_suppliers(node_, "QuantizeLinear", first=True)
                if in_node:
                    self.tensor_map[_q_out.name] = dict(
                        orig_tensor=in_node.output[0],
                        orig_shape=self.md.get_shape(in_node.output[0]),
                        orig_layout=self.get_tag(in_node),
                        final_shape=self.md.get_shape(_q_out),
                        final_layout=self.get_tag(_q)
                    )

            # add NHWC-shield: DQ -> T0 -> Q
            shield_out_0, shield_0 = self._append_dq_transpose_q(
                transpose_0, to_NHWC_perm, qdq_scale, qdq_zero_point,
                input=leaf_new_out,
                input_shape=nchw_shape, output_shape=nhwc_shape,
                input_layout='NCHW', output_layout='NHWC',
                float_dtype=floattype, int_dtype=uinttype,
            )

            # add NCHW-shield: DQ -> T1 -> Q
            # (writes the graph output directly unless a trailing _DQ* is appended below)
            shield_out_1, shield_1 = self._append_dq_transpose_q(
                transpose_1, to_NCHW_perm, qdq_scale, qdq_zero_point,
                input=shield_out_0.name,
                output=output_name_y if has_qdq and is_integer_dtype(output_dtype) or not has_qdq else None,
                input_shape=nhwc_shape, output_shape=nchw_shape,
                input_layout='NHWC', output_layout='NCHW',
                float_dtype=floattype, int_dtype=uinttype,
            )

            # add _DQ*
            if has_qdq and is_floating_dtype(output_dtype):
                _dq, *_ = self._append_node_to_model(*self._make_dq_node(
                    f'{transpose_1}_out_dq', qdq_scale, qdq_zero_point,
                    input=shield_out_1.name,
                    output=output_name_y,
                    # output_shape=None,  # output tensor already exists
                    make_dq_out=False,    # output tensor already exists
                ))
                self.tag_ops[_dq.name] = OpTag(nchw_shape, 'NCHW')

            # reconnect other nodes if the leaf node has other edges
            if md.output_nodes.get(node.name, []):
                # FIX: use the module logger instead of the root `logging` module
                logger.debug(f"..... layout shielding of Graph Outputs: reconnect out-nodes of a leaf-node Node({node.name})")
                for out_node_name in md.output_nodes[node.name]:
                    dup_node = copy.deepcopy(md.nodes[out_node_name])
                    for idx, edge in enumerate(dup_node.input):
                        if edge == output_name_y:
                            # dup_node.input[idx] = f'{output_name_y}_transpose_in_q_in'
                            dup_node.input[idx] = node_.output[0]
                    self.model.graph.node.remove(md.nodes[out_node_name])
                    self.model.graph.node.append(dup_node)

            if has_qdq and is_integer_dtype(output_dtype): #leaf is Q
                self.tensor_map[leaf_new_out] = dict(
                    orig_tensor=output_tensor.name,
                    orig_shape=self.md.get_shape(output_tensor.name),
                    orig_layout=self.get_tag(node),
                    final_shape=self.md.get_shape(output_tensor.name),
                    final_layout=self.get_tag(node)
                )
            elif has_qdq and is_floating_dtype(output_dtype): #leaf is DQ
                in_node = self.md.get_node_suppliers(node_, "QuantizeLinear", first=True)
                if in_node:
                    self.tensor_map[shield_out_1.name] = dict(
                        orig_tensor=in_node.output[0],
                        orig_shape=self.md.get_shape(in_node.output[0]),
                        orig_layout=self.get_tag(in_node),
                        final_shape=self.md.get_shape(shield_out_1),
                        final_layout=self.get_tag(shield_1['q'][0])
                    )

        return

    def _shield_NHWC_op_input(
            self, node: NodeProto,
            *,
            in_shield_output: TensorInfo =None,      # == inS^T
            in_transpose_0:str=None,
    ) -> tuple[ValueInfoProto, dict]:
        """
        Build the input-side shield of a NHWC op "S"

        -------------------------------------------------------------------------------------------
                                            [X]  -> (DQ0->S->Q1)   -> [Y]
        ===>
                [X] -> DQ0 -> Q0* -> (dq->T0->q) -> (DQ0*->S'->Q1)
                     <NCHW>              ||            <NHWC>

        -------------------------------------------------------------------------------------------

        - T0: toNHWC
        - layout[S'] == NHWC
        - Q0* / DQ0* are created only when S is supplied by a DequantizeLinear node;
          they reuse that node's scale (input[1]) and zero point (input[2])

        -------------------------------------------------------------------------------------------

        :param node: op-node being shielded, still connected to its original input
        :param in_shield_output: tensor that the shield has to write (input of S')
        :param in_transpose_0: base name of the nodes created by this method
        :return: (tensor, qdq) with qdq == dict(has_qdq, sc, zp);
                 tensor is the output of DQ0* when a DequantizeLinear supplier exists,
                 otherwise the original input activation of the node
        """
        # IN[0]: original input activation and (optional) DQ node feeding it
        src_tensor = self.md.get_node_activations(node, first=True)
        dq_node = self.md.get_node_suppliers(node, "DequantizeLinear", first=True)

        if dq_node:
            quantized = True
            scale, zero_point = dq_node.input[1], dq_node.input[2]
            fp_dtype = self.get_onnx_dtype(scale)
            q_dtype = self.get_onnx_dtype(zero_point)
        else:
            quantized = False
            scale = zero_point = None
            fp_dtype, q_dtype = DEFAULT_FLOAT_DTYPE, DEFAULT_INT_DTYPE

        # IN[1]: shapes on both sides of the shield and the permutation between them
        nchw_shape = self.get_shape(src_tensor)
        nhwc_shape = self.get_shape(in_shield_output)
        perm = NHWC_PERMUTATIONS[len(nchw_shape)]

        # Q0*: re-quantize the NCHW activation in front of the shield
        if quantized:
            head_q, head_q_out = self._append_node_to_model(*self._make_q_node(
                f"{in_transpose_0}_in_q", scale, zero_point,
                input=src_tensor.name,
                output_shape=nchw_shape,
                output_dtype=q_dtype,
            ))
            self.tag_ops[head_q.name] = OpTag(nchw_shape, 'NCHW')
            src_tensor = head_q_out

        # NHWC-shield (dq->T0->q); without Q/DQ it writes the input of S' directly
        shield_out, _shield_nodes = self._append_dq_transpose_q(
            in_transpose_0, perm, scale, zero_point,
            input=src_tensor.name,
            output=in_shield_output.name if not quantized else None,
            input_shape=nchw_shape,
            output_shape=nhwc_shape,
            input_layout='NCHW', output_layout='NHWC',
            float_dtype=fp_dtype, int_dtype=q_dtype,
        )

        qdq_info = dict(has_qdq=quantized, sc=scale, zp=zero_point)
        if not quantized:
            return src_tensor, qdq_info

        # DQ0*: NHWC copy of DQ0, writes the input of S'
        tail_dq, tail_dq_out = self._append_node_to_model(*self._make_dq_node(
            f"{in_transpose_0}_out_dq", scale, zero_point,
            input=shield_out.name,
            output=in_shield_output.name,
            output_shape=nhwc_shape,
            output_dtype=fp_dtype,
        ))
        self.tag_ops[tail_dq.name] = OpTag(nhwc_shape, 'NHWC')
        return tail_dq_out, qdq_info

    def _shield_NHWC_op_output(
            self, node: NodeProto,
            *,
            out_shield_input: TensorInfo =None,   # == outS
            out_transpose_1:str=None,
            qdq: dict
    ) -> None:
        """
        Build the output-side shield of a NHWC op "S"

        -------------------------------------------------------------------------------------------

                    [X] ->  (DQ0->S->Q1)  -> [Y]
        ===>
                          (DQ0*->S'->Q1*) -> (dq->T1->q) -> DQ1* -> Q1 -> [Y]
                                <NHWC>            ||         <NCHW>

        -------------------------------------------------------------------------------------------

        - T1: toNCHW,
        - layout[S'] == NHWC
        - scale / zero point are taken from `qdq` when it is non-empty and are
          overridden by a QuantizeLinear consumer of S when one exists

        -------------------------------------------------------------------------------------------

        :param node: original op-node S (used to look up its output and consumers)
        :param out_shield_input: tensor written by S', i.e. the input of the shield
        :param out_transpose_1: base name of the nodes created by this method
        :param qdq: dict with keys 'has_qdq', 'sc', 'zp' (may be empty)
        """
        # caller-supplied Q/DQ parameters (if any)
        if qdq:
            quantized, scale, zero_point = qdq['has_qdq'], qdq['sc'], qdq['zp']
        else:
            quantized, scale, zero_point = False, None, None

        # OUT[0]: original output of S and (optional) Q node consuming it
        orig_out = self.md.get_node_outputs(node, first=True)
        q_consumer = self.md.get_node_consumers(node, "QuantizeLinear", first=True)
        if q_consumer:
            quantized = True
            scale, zero_point = q_consumer.input[1], q_consumer.input[2]

        if quantized:
            fp_dtype = self.get_onnx_dtype(scale)
            q_dtype = self.get_onnx_dtype(zero_point)
        else:
            fp_dtype, q_dtype = DEFAULT_FLOAT_DTYPE, DEFAULT_INT_DTYPE

        # OUT[1]: shapes on both sides of the shield and the permutation between them
        nhwc_shape = self.md.get_shape(out_shield_input)
        nchw_shape = self.md.get_shape(orig_out)
        perm = NCHW_PERMUTATIONS[len(nchw_shape)]

        # Q1*: NHWC copy of Q1 placed right after S'
        shield_src = out_shield_input
        if quantized:
            head_q, head_q_out = self._append_node_to_model(*self._make_q_node(
                f"{out_transpose_1}_in_q", scale, zero_point,
                input=shield_src.name,
                output_shape=nhwc_shape,
                output_dtype=q_dtype,
            ))
            self.tag_ops[head_q.name] = OpTag(nhwc_shape, 'NHWC')
            shield_src = head_q_out

        # NCHW-shield (dq->T1->q); without Q/DQ it writes the original output of S directly
        shield_out, _shield_nodes = self._append_dq_transpose_q(
            out_transpose_1, perm, scale, zero_point,
            input=shield_src.name,
            output=orig_out.name if not quantized else None,
            input_shape=nhwc_shape,
            output_shape=nchw_shape,
            input_layout='NHWC', output_layout='NCHW',
            float_dtype=fp_dtype, int_dtype=q_dtype,
        )

        if not quantized:
            return

        # DQ1*: de-quantize back into the original (NCHW) output tensor of S
        tail_dq, _tail_dq_out = self._append_node_to_model(*self._make_dq_node(
            f"{out_transpose_1}_out_dq", scale, zero_point,
            input=shield_out.name,
            output=orig_out.name,
            output_shape=nchw_shape,
            output_dtype=fp_dtype,
        ))
        self.tag_ops[tail_dq.name] = OpTag(nchw_shape, 'NCHW')
        return

    def shield_NHWC_ops(self) -> None:
        """
        Single shielding of NHWC nodes:
            every node whose op type is in NHWC_OPERATORS but which is tagged NCHW
            gets one transpose shield on its input and one on its output
            (w/ or w/o qdq nodes), and is itself replaced by a NHWC copy.

        -------------------------------------------------------------------------------------------

                                  [X]  -> (DQ0->S->Q1)  ->  [Y]
        ==>
            [X] -> DQ0 ->  (_->T0->q) -> (DQ0*->S'->Q1*) -> (dq->T1->q) -> DQ1* -> Q1 -> [Y]

        -------------------------------------------------------------------------------------------

        - T0: toNHWC
        - T1: toNCHW
        - layout[X] == layout[S] == layout[Y] == NCHW
        - layout[S'] == NHWC
        - QDQ<IN>:  [sc, zp] := DQ0.[sc, zp]
        - QDQ<OUT>: [sc, zp] := Q1.[sc, zp]
        - S' gets its "axis" / "axes" attribute, or its axes initializer (input[1]),
          converted to NHWC; S' is tagged NHWC and its layout is frozen
        -------------------------------------------------------------------------------------------
        """
        # collect the targets up-front: the model is modified while they are processed
        targets = {}
        for name, n in self.md.nodes.items():
            if n.op_type in NHWC_OPERATORS and self.tag_ops[name][1] == "NCHW":
                targets[name] = n

        logger.info(
            f"Start layout shielding of NHWC ops: "
            f"{len(targets)} NHWC op-node{'s' if len(targets) != 1 else ''} will be shielded"
        )

        for node_name, node in targets.items():
            # tensors of the original (NCHW) node S
            src_in = self.md.get_node_activations(node, first=True)
            src_out = self.md.get_node_outputs(node, first=True)

            # tensors of the NHWC node S'
            nhwc_in = TensorInfo(
                f"{node_name}_nhwc_input", self.get_NHWC_shape(src_in), self.get_onnx_dtype(src_in),
            )
            nhwc_out = TensorInfo(
                f"{node_name}_nhwc_output", self.get_NHWC_shape(src_out), self.get_onnx_dtype(src_out),
            )

            # S' starts as a copy of S wired to the NHWC tensors
            nhwc_node = copy.deepcopy(node)
            nhwc_node.input[0] = nhwc_in.name
            nhwc_node.output[0] = nhwc_out.name

            # value is not read below; the shape lookup is kept as in the original flow
            rank = len(self.md.get_shape(node.input[0]))

            # translate the axis information of S into NHWC terms
            attrs = get_attrs(node)
            if attrs and "axis" in attrs:
                set_attribute(nhwc_node, "axis", self._axis_attribute_to_NHWC(node))
            elif attrs and "axes" in attrs:
                set_attribute(nhwc_node, "axes", self._axes_attribute_to_NHWC(node))
            elif len(node.input) == 2:
                # Reduce.input == [0:data:act, 1:axes:ini]
                nhwc_axes = self._axes_initializer_to_NHWC(node, axes_input_idx=1)
                if nhwc_axes.name not in self.md.ini:
                    self.model.graph.initializer.append(nhwc_axes)
                    self.md.ini[nhwc_axes.name] = self.model.graph.initializer[-1]
                nhwc_node.input[1] = nhwc_axes.name

            # input shield: ... -> (dq->T0->q) -> [input of S']
            self._shield_NHWC_op_input(
                node,
                in_shield_output=nhwc_in,
                in_transpose_0=f"{node.name}_in_transpose_nhwc_0",
            )

            # swap S for S' between the two shields
            self.model.graph.node.remove(node)
            self.model.graph.node.append(nhwc_node)
            self.tag_ops[nhwc_node.name] = OpTag(nhwc_out.shape, 'NHWC')
            self.frozen_layout[nhwc_node.name] = 'NHWC'

            # output shield: [output of S'] -> (dq->T1->q) -> ...
            self.model.graph.value_info.append(nhwc_out.make_activation_tensor())
            self._shield_NHWC_op_output(
                node,
                out_shield_input=nhwc_out,
                out_transpose_1=f"{node.name}_out_transpose_nchw_1",
                qdq={},
            )
        return

    def shield_3D_MatMuls(self) -> None:
        """
        Shield MatMul(A=ini, B=act) nodes by computing them with swapped, transposed operands.

        Candidates are MatMul nodes whose first input satisfies `self.md.is_initializer`
        and whose output shape has length 3 or `get_rank(...) == 3`.

        For every candidate the method
            1) copies the tensor feeding A's writer node, pads its dims up to the length of
               B's shape, swaps its last two axes and stores it as "<name>_T";
               A's value_info and writer node are re-created on top of it;
            2) wraps B into  Q* -> (dq->T->q) -> DQ  (Q* only for a floating-point B), using the
               scale / zero point of B's DequantizeLinear writer or dummy ones (1 / 0);
            3) replaces the MatMul by a copy with swapped operands
               (input[0] := shielded B, input[1] := A) that writes a new
               "<node>_out_transpose_act_Y_head" tensor; the copy is tagged with the
               switched layout (NCHW <-> NHWC) and that layout is frozen;
            4) wraps the new output into  Q -> (dq->T->q) -> DQ, whose DQ writes the
               original output tensor of the MatMul.

        NOTE(review): several statements below use the value returned by a repeated-field
        `append()` as the stored element; whether `append()` returns it depends on the
        protobuf implementation in use -- TODO confirm.
        """
        # MatMul nodes with A=initializer and a 3D (by length or by get_rank) output
        matmul_nodes = {
            name: n
            for name, n in self.md.nodes.items()
            if n.op_type in ["MatMul"] and self.md.is_initializer(n.input[0])
               and (len(self.get_out_shape(n)) == 3 or get_rank(self.get_out_shape(n)) == 3)
        }
        logger.info(
            f"Start layout shielding of MatMul(A=ini, B=act) ops: "
            f"{len(matmul_nodes)} MatMul(A=ini, B=act) op-node{'s' if len(matmul_nodes) != 1 else ''} will be shielded"
        )

        # layout tag given to the shielded MatMul: NCHW and NHWC are exchanged, the rest is kept
        switch_layout = {
            "UNKNOWN": "UNKNOWN",
            "NCHW": "NHWC",
            "NHWC": "NCHW",
            "ANY": "ANY",
        }

        _logstr = f"\033[35mShield[MatMul]\033[0m"
        for node_name, node in matmul_nodes.items():
            layout = self.get_tag(node)
            input_A = self.md.get_tensor(node.input[0])
            input_B = self.md.get_tensor(node.input[1])
            output_Y = self.md.get_tensor(node.output[0])
            in_shape_A, in_dtype_A = self.md.get_shape_dtype(input_A)
            in_shape_B, in_dtype_B = self.md.get_shape_dtype(input_B)
            out_shape, out_dtype = self.md.get_shape_dtype(output_Y)
            shielding_layout = switch_layout[layout]
            # shielding_perm = NHWC_PERMUTATIONS.get(get_dim(out_shape), [])  # should be [0, 2, 1]
            # both permutations swap the last two axes; any other output length leaves the perm empty
            shielding_perm = []
            if len(out_shape) == 3:
                shielding_perm = [0, 2, 1]
            elif len(out_shape) == 4:
                shielding_perm = [0, 1, 3, 2]


            assert self.md.is_initializer(input_A.name), f"MatMul<{layout}>({node.name}).A is not ini"
            assert self.md.is_activation(input_B.name), f"MatMul<{layout}>({node.name}).B is not act"
            # NOTE(review): only checks that out_shape is non-empty although the message says "not 3D" -- confirm intent
            assert len(out_shape), f"MatMul<{layout}>({node.name}){out_shape} is not 3D"
            logger.debug(
                f"{_logstr} >> MatMul<{layout}>({node.name}){out_shape} :: A<ini>({node.input[0]}){in_shape_A} :: B<act>({node.input[1]}){in_shape_B}  {out_shape}"
            )

            # neighbours of the MatMul: writers of A and B, readers of Y
            in_node_A = self.md.get_writer(input_A.name)
            in_node_B = self.md.get_writer(input_B.name)
            # the unpacking needs at least one reader of Y
            out_node_Y, *out_node_siblings = self.md.get_readers(output_Y.name)
            # tensor feeding the writer of A (presumably the constant behind a DequantizeLinear -- TODO confirm)
            in_node_A_ini = self.md.get_tensor(in_node_A.input[0])
            logger.debug(
                f"{_logstr} :: MatMul<{layout}>({node.name})  <- A{in_shape_A}  <- {in_node_A.op_type}<{self.get_tag(in_node_A)}>({in_node_A.name}) <- Ini({in_node_A_ini.name}){self.md.get_shape(in_node_A_ini)}"
            )
            logger.debug(
                f"{_logstr} :: MatMul<{layout}>({node.name})  <- B{in_shape_B}  <- {in_node_B.op_type}<{self.get_tag(in_node_B)}>({in_node_B.name}) <- Act"
            )
            logger.debug(
                f"{_logstr} :: MatMul<{layout}>({node.name})  -> Y{out_shape}  -> {out_node_Y.op_type}<{self.get_tag(out_node_Y)}>({out_node_Y.name}) + {len(out_node_siblings)}"
            )


            # TRANSPOSE INITIALIZER -------------------------------------------
            # pad 2D initializer up to activation tensor
            if (pad_shape := len(in_shape_B) - len(in_shape_A)) > 0:
                in_shape_A = [1] * pad_shape + list(in_shape_A)

            # write the (padded) shape of A into the dims of a copy of the initializer
            new_node_A_ini = copy.deepcopy(in_node_A_ini)
            for i in range(0, len(in_shape_A)):
                if i < len(new_node_A_ini.dims):
                    new_node_A_ini.dims[i] = in_shape_A[i]
                else:
                    new_node_A_ini.dims.append(in_shape_A[i])

            # proto -> ndarray -> transpose -> proto named "<ini>_T"
            new_node_A_ini, ini_A_ndtype = onnxTensorProto_to_array(new_node_A_ini, transpose=False)
            new_node_A_ini = np.transpose(new_node_A_ini, axes=shielding_perm)
            new_node_A_ini = onnxTensorProto_from_array(
                new_node_A_ini, f"{in_node_A_ini.name}_T", og_dtype=ini_A_ndtype,
            )

            # self.model.graph.initializer.remove(in_node_A_ini)
            # NOTE(review): the result of append() is used as the stored initializer below -- see docstring
            new_node_A_ini = self.model.graph.initializer.append(new_node_A_ini)
            logger.debug(f"{_logstr} .. new_ini_A = {new_node_A_ini.name} :: {self.md.get_shape(new_node_A_ini)}")

            # re-create the value_info of A with the transposed shape (same name and dtype)
            self.model.graph.value_info.remove(input_A)
            input_A = self.model.graph.value_info.append(make_tensor_value_info(
                input_A.name, in_dtype_A, self.md.get_shape(new_node_A_ini)
            ))
            logger.debug(f"{_logstr} .. new_input_A = {input_A.name} :: {self.md.get_shape(input_A)}")

            # replace the writer of A by a copy reading the transposed initializer
            new_node_A = copy.deepcopy(in_node_A)
            new_node_A.input[0] = new_node_A_ini.name
            new_node_A.output[0] = input_A.name

            self.model.graph.node.remove(in_node_A)
            new_node_A = self.model.graph.node.append(new_node_A)

            # SHIELD ACT AND OUT ----------------------------------------------
            in_shield_shape_B = permute(in_shape_B, shielding_perm)
            out_shield_shape = permute(out_shape, shielding_perm)

            # base names of the created nodes (transpose_A is only used in the log message at the end)
            transpose_A = f"{node.name}_in_transpose_ini_A"
            transpose_B = f"{node.name}_in_transpose_act_B"
            transpose_Y = f"{node.name}_out_transpose_act_Y"

            # Q/DQ parameters of the B-shield: taken from B's DQ writer, otherwise dummy scale=1 / zero_point=0
            # (has_qdq_B is set here but not read anywhere in this method)
            has_qdq_B, sc_B, zp_B = False, None, None
            if in_node_B.op_type == "DequantizeLinear":
                has_qdq_B = True
                sc_B = self.md.get_tensor(in_node_B.input[1])
                zp_B = self.md.get_tensor(in_node_B.input[2])
                in_dtype_B_int = self.md.get_onnx_dtype(zp_B)
            else:
                in_dtype_B_int = DEFAULT_INT_DTYPE
                sc_B = numpy_helper.from_array(
                    np.array(1).astype(tensor_dtype_to_np_dtype(in_dtype_B)),
                    f"{in_node_B.name}_dummy_in_qdq_scale"
                )
                zp_B = numpy_helper.from_array(
                    np.array(0).astype(tensor_dtype_to_np_dtype(in_dtype_B_int)),
                    f"{in_node_B.name}_dummy_in_qdq_zero_point"
                )
                sc_B = self.model.graph.initializer.append(sc_B)
                zp_B = self.model.graph.initializer.append(zp_B)

            # Q/DQ parameters of the Y-shield: taken from the first reader of Y if it is a Q node, otherwise dummy
            # (has_qdq_Y is set here but not read anywhere in this method)
            has_qdq_Y, sc_Y, zp_Y = False, None, None
            if out_node_Y.op_type == "QuantizeLinear":
                has_qdq_Y = True
                sc_Y = self.md.get_tensor(out_node_Y.input[1])
                zp_Y = self.md.get_tensor(out_node_Y.input[2])
                out_dtype_int = self.md.get_onnx_dtype(zp_Y)
            else:
                out_dtype_int = DEFAULT_INT_DTYPE
                # NOTE(review): the dummy *output* tensors are also named after in_node_B -- presumably only for uniqueness
                sc_Y = numpy_helper.from_array(
                    np.array(1).astype(tensor_dtype_to_np_dtype(out_dtype)),
                    f"{in_node_B.name}_dummy_out_qdq_scale"
                )
                zp_Y = numpy_helper.from_array(
                    np.array(0).astype(tensor_dtype_to_np_dtype(out_dtype_int)),
                    f"{in_node_B.name}_dummy_out_qdq_zero_point"
                )
                sc_Y = self.model.graph.initializer.append(sc_Y)
                zp_Y = self.model.graph.initializer.append(zp_Y)


            ## IN-SHIELD B ::  <-[MM'.A]-- DQ_B <- (q<T[MM.B]<dq) <- Q_B*  <-[MM.B]-- ACT
            # Q_B* is only needed when B is a floating-point tensor
            if is_floating_dtype(in_dtype_B):
                _qB, _qB_out = self._append_node_to_model(*self._make_q_node(
                    f"{transpose_B}_head_q", sc_B, zp_B,
                    input=input_B.name,
                    output_shape=in_shape_B,
                    output_dtype=in_dtype_B_int,
                ))
                self.tag_ops[_qB.name] = OpTag(in_shape_B, layout)
                input_B = _qB_out

            shield_B_out, shield_B = self._append_dq_transpose_q(
                transpose_B, shielding_perm, sc_B, zp_B,
                input=input_B.name,
                input_shape=in_shape_B, output_shape=in_shield_shape_B,
                input_layout=layout, output_layout=shielding_layout,
                float_dtype=in_dtype_B, int_dtype=in_dtype_B_int,
            )
            _dqB, _dqB_out = self._append_node_to_model(*self._make_dq_node(
                f"{transpose_B}_tail_dq", sc_B, zp_B,
                input=shield_B_out.name,
                output_shape=in_shield_shape_B,
                output_dtype=in_dtype_B,
            ))
            self.tag_ops[_dqB.name] = OpTag(in_shield_shape_B, shielding_layout)

            ## TRANSPOSE MM ::
            # new output tensor of the shielded MatMul (permuted output shape)
            mm_Y_out = self.model.graph.value_info.append(make_tensor_value_info(
                f"{transpose_Y}_head", out_dtype, out_shield_shape
            ))
            # operands are swapped: the shielded activation becomes input[0], the transposed A becomes input[1]
            mm = copy.deepcopy(node)
            mm.input[0] = _dqB_out.name     # MM'.A := DQ_B <- q<T[MM.B]<dq <- Q_B <- <ACT>
            mm.input[1] = input_A.name      # MM'.B := T[MM.A] <INI>
            mm.output[0] = mm_Y_out.name    # MM'.Y := T[MM.Y] <OUT>
            self.model.graph.node.remove(node)
            mm = self.model.graph.node.append(mm)
            self.tag_ops[mm.name] = OpTag(out_shield_shape, shielding_layout)
            self.frozen_layout[mm.name] = shielding_layout

            ## OUT-SHIELD Y :: MM' --[Y']-> Q_Y -> (dq>T[MM.Y]>q) -> DQ_Y --[Y]-> OUT
            _qY, _qY_out = self._append_node_to_model(*self._make_q_node(
                f"{transpose_Y}_head_q", sc_Y, zp_Y,
                input=mm_Y_out.name,
                output_shape=out_shield_shape,
                output_dtype=out_dtype_int,
            ))
            self.tag_ops[_qY.name] = OpTag(out_shield_shape, shielding_layout)

            shield_Y_out, shield_Y = self._append_dq_transpose_q(
                transpose_Y, shielding_perm, sc_Y, zp_Y,
                input=_qY_out.name,
                input_shape=out_shield_shape, output_shape=out_shape,
                input_layout=shielding_layout, output_layout=layout,
                float_dtype=out_dtype, int_dtype=out_dtype_int,
            )
            # the tail DQ writes the original output tensor Y, which already exists (make_dq_out=False)
            _dqY, _dqY_out = self._append_node_to_model(*self._make_dq_node(
                f"{transpose_Y}_tail_dq", sc_Y, zp_Y,
                input=shield_Y_out.name,
                output=output_Y.name,
                make_dq_out=False,
            ))

            self.tag_ops[_dqY.name] = OpTag(out_shape, layout)
            logger.debug(
                f"{_logstr} << {out_shield_shape} <- MatMul*<{shielding_layout}>({mm.name})  <- A({mm.input[0]}){self.md.get_shape(_dqB_out)}  <- Tr<{shielding_layout}>({transpose_A}){shielding_perm}"
            )
            logger.debug(
                f"{_logstr} << {out_shield_shape} <- MatMul*<{shielding_layout}>({mm.name})  <- B({mm.input[1]}){self.md.get_shape(input_A)}  <- Ini({new_node_A_ini.name}){self.md.get_shape(new_node_A_ini)}"
            )

            pass

        return

    def get_NCHW_subgraph(self, node: str | Optional[NodeProto]) -> Optional[NCHWSubgraph]:
        """
        Return the NCHW-subgraph a node is assigned to.

        :param node: node, or node name, resolved through `self.md.get_node`
        :return: None when the node cannot be resolved;
                 `NCHWSubgraph.NULL()` when the node's subgraph id is 0;
                 otherwise `self.nchw_subgraphs.get(id)` (None for an id that is not registered)
        """
        resolved = self.md.get_node(node)
        if resolved:
            # subgraph id 0 means "node is not in any subgraph"
            subgraph_id = self.nodes_to_subgraphs[resolved.name]
            if subgraph_id:
                return self.nchw_subgraphs.get(subgraph_id)
            return NCHWSubgraph.NULL()
        return None

    def collect_NCHW_subgraphs(self) -> None:
        self.nchw_subgraphs: dict[int: NCHWSubgraph] = dict()           # {sub_id -> subgraph}
        self.nodes_to_subgraphs: dict[str, int] = defaultdict(int)      # {node_name -> sub_id}
        subgraphs: dict[int: NCHWSubgraph] = dict()                     # {sub_id -> subgraph}
        nodes_idx: dict[str, int] = dict()                              # {node_name -> node_idx}

        # -------------------------------------------------------------------------------------------
        # FIND <NCHW>-TRANSPOSES
        # -------------------------------------------------------------------------------------------
        for idx, node in enumerate(self.model.graph.node):
            nodes_idx[node.name] = idx
            if self.get_tag(node) not in ("NCHW", "ANY"):
                continue
            if node.op_type != "Transpose":
                continue

            sub_id = len(subgraphs) + 1
            sub_gr = NCHWSubgraph(sub_id).with_roots(node.name).with_context(self.md, self.tag_ops)
            subgraphs[sub_id] = sub_gr
            self.nodes_to_subgraphs[node.name] = sub_id

        logger.info(
            f"Start collecting NCHW-subgraphs: "
            f"{len(subgraphs)} NCHW-tagged Transposes were found"
        )

        # -------------------------------------------------------------------------------------------
        # PARTITION MODEL into NCHW-SUBGRAPHS :: UTILS
        # -------------------------------------------------------------------------------------------
        def _is_illegal_subnode(sn: str) -> bool:
            """
            sub_node is ILLEGAL if it:
                1) is Graph Root/Leaf; 2) is a frozen op, eg. shielded;
                3) not Node<NCHW|ANY>; 4) not Transpose<NCHW|NHWC>;
            """
            if self.md.get_op_type(sn) == "Transpose" and self.get_tag(sn) not in ["NCHW", "NHWC"]:
                return True
            if (self.get_tag(sn) not in ["NCHW", "ANY"]) or (sn in self.frozen_layout):
                return True
            if self.md.is_root(sn) or self.md.is_leaf(sn):
                return True
            return False

        def _discard_subgraph(sg: NCHWSubgraph):
            logger.debug(f"[DISCARD SG] :: {sg}")
            for sn in sg.get_sub_nodes():
                self.nodes_to_subgraphs[sn] = NCHWSubgraph.NULL().id
            sg.clear()
            return

        def _merge_subgraphs(sg_main: NCHWSubgraph, sg_minor: NCHWSubgraph) -> Iterable[str]:
            absorbed_nodes = sg_main.absorb(sg_minor)
            for n in absorbed_nodes:
                self.nodes_to_subgraphs[n] = sg_main.id

            subgraphs.pop(sg_minor.id, NCHWSubgraph.NULL())
            sg_minor.clear()
            return absorbed_nodes

        def _is_NCHW_Transpose(n: str) -> bool:
            return self.md.get_op_type(n)  == "Transpose" and self.get_tag(n) == "NCHW"
        def _is_NHWC_Transpose(n: str) -> bool:
            return self.md.get_op_type(n)  == "Transpose" and self.get_tag(n) == "NHWC"
        def _is_ANY_Initializer(n: str) -> bool:
            return self.md.is_initializer_node(n) or self.get_tag(n) == "ANY"

        # -------------------------------------------------------------------------------------------
        # PARTITION MODEL into NCHW-SUBGRAPHS :: BFS PROPAGATION
        # -------------------------------------------------------------------------------------------
        while subgraphs:
            sub_id, sub_gr = subgraphs.popitem()
            sub_nodes = deque(sub_gr.get_sub_nodes())
            discard_sub_gr = False
            # Construct NCHW-subgraph by propagating from its Tr<NCHW> sub_root
            while sub_nodes:
                sub_node: str = sub_nodes.popleft()
                # Discard subgraph if any sub_node is:
                #       Graph Root/Leaf, frozen op, not Node<NCHW|ANY>, or not Tr<NCHW|NHWC>
                if _is_illegal_subnode(sub_node):
                    logger.debug(
                        f"[ILLEGAL SG] :: {str(sub_gr)} "
                        f"has illegal SubNode<{self.get_tag(sub_node)}>({sub_node}) and will be discarded"
                    )
                    discard_sub_gr = True

                # Get sub_node's unvisited Consumers & Suppliers
                sub_cons: list[str] = list(self.md.get_node_consumers(sub_node))
                sub_cons = [n for n in sub_cons if self.nodes_to_subgraphs[n] != sub_id]
                sub_supp: list[str] = list(self.md.get_node_suppliers(sub_node))
                sub_supp = [n for n in sub_supp if self.nodes_to_subgraphs[n] != sub_id]

                # BFS DOWN: Forward/Downward to Consumers
                for cons in sub_cons:
                    if sgid := self.nodes_to_subgraphs[cons]:   # encountered other NCHW-subgraph ==> absorb
                        cons_sg = subgraphs.get(sgid, NCHWSubgraph.NULL())
                        merged_nodes = _merge_subgraphs(sub_gr, cons_sg)
                        sub_nodes.extend(merged_nodes)
                    # elif _is_NHWC_Transpose(cons) or _is_ANY_Initializer(cons):
                    elif _is_NHWC_Transpose(cons):
                        self.nodes_to_subgraphs[cons] = sub_id
                        sub_gr.leaves.append(cons)
                    elif _is_ANY_Initializer(cons):
                        sub_gr.leaves.append(sub_node)
                        break
                    elif self.get_tag(cons) == "NCHW":
                        self.nodes_to_subgraphs[cons] = sub_id
                        sub_gr.body.append(cons)
                        sub_nodes.append(cons)
                    pass

                # BFS UP: Backward/Upward to Suppliers
                for supp in sub_supp:
                    if sgid := self.nodes_to_subgraphs[supp]:   # encountered other NCHW-subgraph ==> absorb
                        supp_sg = subgraphs.get(sgid, NCHWSubgraph.NULL())
                        merged_nodes = _merge_subgraphs(sub_gr, supp_sg)
                        sub_nodes.extend(merged_nodes)
                    elif _is_NCHW_Transpose(supp) or self.md.is_initializer_node(supp):
                        self.nodes_to_subgraphs[supp] = sub_id
                        sub_gr.roots.append(supp)
                    elif _is_ANY_Initializer(supp):
                        sub_gr.roots.append(sub_node)
                        break
                    elif self.get_tag(supp) == "NCHW":
                        self.nodes_to_subgraphs[supp] = sub_id
                        sub_gr.body.append(supp)
                        sub_nodes.append(supp)
                    pass

                pass    # END :: NCHW-subgraph is constructed

            if discard_sub_gr or not sub_gr.leaves:
                _discard_subgraph(sub_gr)
            else:
                self.nchw_subgraphs[sub_id] = sub_gr

            pass    # END :: no more subgraphs to extend

        valid_subgraphs = [
            sg for sg in self.nchw_subgraphs.values()
            if sg.with_context(self.md, self.tag_ops).is_valid_subgraph()
        ]
        logger.debug(
            f"[COLLECT SG] << Number of NCHW-subgraphs collected "
            f": TOTAL={len(self.nchw_subgraphs)} "
            f": VALID={len(valid_subgraphs)} "
            f": TO PATCH={len(self.nchw_subgraphs) - len(valid_subgraphs)}"
        )
        return

    def patch_NCHW_subgraphs(self) -> list[NCHWSubgraph]:
        """
        Patch collected NCHW-subgraphs whose sub-roots and/or sub-leaves are not valid.

        Works on the subgraphs stored in ``self.nchw_subgraphs`` and mutates the model in place:
        nodes, value_info entries and initializers are appended to / removed from
        ``self.model.graph``, and ``self.tag_ops``, ``self.nodes_to_subgraphs`` and the
        ``roots`` / ``leaves`` / ``body`` lists of each patched subgraph are updated.

        Sub-roots (for every subgraph failing ``has_valid_subroots()``):
            * valid sub-roots are kept as they are;
            * initializer-node sub-roots get, per in-subgraph consumer, an NHWC-permuted copy of
              the initializer, a copy of the sub-root node reading it, and a Transpose (NCHW
              permutation) that feeds the consumer and becomes the new sub-root;
            * ``Unsqueeze`` sub-roots are skipped (and thereby dropped from the root list);
            * any other sub-root gets two Transposes (NHWC, then NCHW) inserted after its output,
              wrapped in Q/DQ nodes when a ``QuantizeLinear`` consumer is found.

        Sub-leaves (for every subgraph failing ``has_valid_subleaves()``):
            * valid sub-leaves are kept as they are;
            * ``Squeeze`` sub-leaves are skipped (and thereby dropped from the leaf list);
            * any other sub-leaf gets two Transposes (NHWC, then NCHW) inserted before its first
              input, wrapped in Q/DQ nodes when a ``DequantizeLinear`` supplier is found.

        Returns:
            The list of subgraphs whose sub-roots were selected for patching
            (``unpatched_subroots``) -- a list, not a dict.
        """
        # These three lists are snapshots taken before any patching happens;
        # `unpatched_subgraphs` is only used for the log message below.
        unpatched_subgraphs = [
            sg for sg in self.nchw_subgraphs.values()
            if not sg.with_context(self.md, self.tag_ops).is_valid_subgraph()
        ]
        unpatched_subroots = [
            sg for sg in self.nchw_subgraphs.values()
            if not sg.with_context(self.md, self.tag_ops).has_valid_subroots()
        ]
        unpatched_subleaves = [
            sg for sg in self.nchw_subgraphs.values()
            if not sg.with_context(self.md, self.tag_ops).has_valid_subleaves()
        ]

        logger.info(
            f"Start patching NCHW-subgraphs: "
            f"{len(unpatched_subgraphs)} out of {len(self.nchw_subgraphs)} will be patched, "
            f"{len(unpatched_subroots)} subroots and {len(unpatched_subleaves)} subleaves"
        )

        # -------------------------------------------------------------------------------------------
        # PATCH SUB-ROOTS
        # -------------------------------------------------------------------------------------------
        for sg in unpatched_subroots:
            # Root names that replace `sg.roots` once every root of this subgraph was processed
            new_sub_roots = list()
            for sr_idx, sub_root in enumerate(sg.roots):
                # `sg.roots` holds node names; from here on `sub_root` is the node object
                sub_root = self.md.get_node(sub_root)
                if sg.is_valid_subroot(sub_root.name):
                    new_sub_roots.append(sub_root.name)
                    continue

                if self.md.is_initializer_node(sub_root):
                    # 1) Find which consumers of this root are in subgraph
                    # 2) if a consumer is binary op -- pad subroot input
                    # 3) transpose subroot's input with T0
                    # 4) insert T1 between subroot and consumer
                    sr_ini = self.md.get_tensor(sub_root.input[0])
                    sr_consumers = self.md.get_node_consumers(sub_root)
                    # Set when at least one consumer keeps reading the original sub_root output,
                    # in which case the original node is not removed after the loop.
                    keep_original = False
                    for cons in sr_consumers.values():
                        # Consumers outside of this subgraph keep using the original sub_root
                        if self.nodes_to_subgraphs.get(cons.name, 0) != sg.id:
                            keep_original = True
                            continue
                        # Locate the consumer's input slot that is written by sub_root
                        cons_ini_idx = None
                        for idx, cons_in in enumerate(cons.input):
                            if (cons_supp := self.md.get_writer(cons_in)) and cons_supp.name == sub_root.name:
                                cons_ini_idx = idx
                                break
                        if cons_ini_idx is None:
                            logger.error(f"SubRoot({sub_root.name})'s Consumer({cons.name}) is not supplied by SubRoot")
                            continue

                        # Create new ini-tensor
                        sr_ini_shape = self.md.get_shape(sr_ini)
                        sr_ini_data, sr_ini_dtype = onnxTensorProto_to_array(sr_ini, transpose=False)
                        cons_ini = self.md.get_tensor(cons.input[cons_ini_idx])
                        cons_ini_shape = self.md.get_shape(cons_ini)
                        cons_ini_dtype = self.md.get_onnx_dtype(cons_ini)
                        logger.debug(f"SG({sg.id}) :: <SubRoot[{sr_idx}]>{sub_root.op_type}({sub_root.name}).ini_shape={sr_ini_shape}  -->  <Consumer>{cons.op_type}({cons.name}).input[{cons_ini_idx}].shape={cons_ini_shape}")
                        # NOTE(review): `assert` is stripped under `python -O`; confirm this check is debug-only
                        assert cons_ini_shape == sr_ini_shape, f"SG({sg.id}) :: <SubRoot[{sr_idx}]>{sub_root.op_type}({sub_root.name}).ini_shape={sr_ini_shape}  !=  <Consumer>{cons.op_type}({cons.name}).input[{cons_ini_idx}].shape={cons_ini_shape}"

                        # Pad consumers are left wired to the original sub_root
                        if cons.op_type in {"Pad"}:
                            keep_original = True
                            continue
                        if cons.op_type in {"Add", "Sub", "Mul", "Div", "Pow"}:
                            # The other operand of the binary op (inputs 0/1 are swapped via `1 - idx`)
                            cons_act = cons.input[1 - cons_ini_idx]
                            cons_act_shape = self.md.get_shape(cons_act)
                            # pad scalar and vector bias up to activation tensor
                            if (pad_shape := len(cons_act_shape) - len(cons_ini_shape)) > 0:
                                cons_ini_shape = [1] * pad_shape + list(cons_ini_shape)
                                sr_ini_shape = [1] * pad_shape + list(sr_ini_shape)
                                sr_ini_data = sr_ini_data.reshape(sr_ini_shape)
                            pass

                        # transpose ini-tensor
                        cons_nhwc_ini_shape = self.get_NHWC_shape(cons_ini_shape)
                        # Permutation tables are indexed by the (possibly padded) tensor rank
                        nhwc_perm = NHWC_PERMUTATIONS[len(cons_ini_shape)]
                        nchw_perm = NCHW_PERMUTATIONS[len(cons_ini_shape)]

                        sr_nhwc_ini_data = np.transpose(sr_ini_data, axes=nhwc_perm)
                        # Create new initializer with transposed data and shape
                        new_sr_ini = onnxTensorProto_from_array(
                            sr_nhwc_ini_data,
                            f"{cons.name}_ini_nhwc",  # Give it a new name to avoid conflicts
                            og_dtype=sr_ini_dtype,
                        )
                        self.model.graph.initializer.append(new_sr_ini)

                        # Create a copy of the DQ node with updated input reference
                        new_sub_root = copy.deepcopy(sub_root)
                        new_sub_root.name = f"{sub_root.name}__{cons.name}__ini_nhwc"
                        new_sub_root.input[0] = new_sr_ini.name  # Point to the new initializer
                        new_sub_root.output[0] = f"{new_sub_root.name}_out"

                        # Prepare new tensor value info (for DQ output)
                        new_sr_output = onnx.helper.make_tensor_value_info(
                            new_sub_root.output[0], cons_ini_dtype, cons_nhwc_ini_shape
                        )
                        self.model.graph.node.append(new_sub_root)
                        self.model.graph.value_info.append(new_sr_output)
                        self.tag_ops[new_sub_root.name] = OpTag(cons_nhwc_ini_shape, "NHWC")

                        # Create T1=Transpose<NCHW> and set is as sub_root
                        transpose_1_name = f"{sub_root.name}__{cons.name}__transpose_nchw"
                        transpose_1 = make_node(
                            'Transpose',
                            name=transpose_1_name,
                            inputs=[new_sr_output.name],
                            outputs=[f"{transpose_1_name}_out"],
                            perm=nchw_perm,
                        )
                        # NOTE(review): unlike `new_sr_output`, `transpose_1_out` is never appended to
                        # `graph.value_info` in this method -- only its name is used below. Confirm
                        # whether the missing value_info is intentional.
                        transpose_1_out = onnx.helper.make_tensor_value_info(
                            transpose_1.output[0], cons_ini_dtype, cons_ini_shape
                        )
                        self.model.graph.node.append(transpose_1)
                        self.tag_ops[transpose_1.name] = OpTag(cons_ini_shape, "NCHW")

                        # The new Transpose becomes a sub-root and feeds the consumer directly
                        new_sub_roots.append(transpose_1.name)
                        cons.input[cons_ini_idx] = transpose_1_out.name

                        pass

                    # Every consumer was rewired to a patched copy => original node is dropped
                    if not keep_original:
                        self.model.graph.node.remove(sub_root)

                elif self.md.get_op_type(sub_root) == "Unsqueeze":
                    # skip Unsq -- since patching them with Tr_noop anyway
                    # NOTE(review): the Unsqueeze is not added to `new_sub_roots`, so it is
                    # dropped from `sg.roots` when the list is replaced below.
                    continue
                else:
                    # SUB_ROOT == Reshape-like
                    # `patch_input` is a new tensor name that sub_root will write to, while the
                    # original output tensor (`patch_output`) is produced by the inserted chain.
                    patch_input = f"{sub_root.name}_patch_input"
                    patch_output = self.md.get_node_outputs(sub_root, first=True)
                    # patch_input = self.md.get_node_outputs(sub_root, first=True)
                    # patch_output = f"{sub_root.name}_patch_output"
                    transpose_nhwc_0 = f"{sub_root.name}_patch_transpose_nhwc_0"
                    transpose_nchw_1 = f"{sub_root.name}_patch_transpose_nchw_1"

                    sub_root_out_dtype = self.get_onnx_dtype(sub_root.output[0])
                    nchw_shape = self.md.get_shape(patch_output)
                    nhwc_shape = self.get_NHWC_shape(nchw_shape)
                    nchw_perm = NCHW_PERMUTATIONS[get_dim(nchw_shape)]
                    nhwc_perm = NHWC_PERMUTATIONS[get_dim(nchw_shape)]
                    # Defaults used when no Q-node is found after the sub-root
                    float_dtype, int_dtype = DEFAULT_FLOAT_DTYPE, DEFAULT_INT_DTYPE
                    has_qdq, sc0, zp0 = False, None, None

                    logger.debug(
                        f"SG({sg.id}) :: <SubRoot[{sr_idx}]>{sub_root.op_type}({sub_root.name}).patch_input={patch_input}"
                    )
                    # Reuse scale / zero-point of the QuantizeLinear consumer, if there is one
                    if q0 := self.md.get_node_consumers(sub_root, "QuantizeLinear", first=True):
                        has_qdq, sc0, zp0 = True, q0.input[1], q0.input[2]
                        float_dtype = self.get_onnx_dtype(sc0)
                        int_dtype = self.get_onnx_dtype(zp0)
                    else:
                        logger.debug(
                            f"PATCHING {sg}: "
                            f"hasn't found a Q-node at the outputs of a sub-root {sub_root.op_type}({sub_root.name})."
                            f"Transpose nodes without Q/DQ around will be inserted."
                        )
                        pass

                    # MODIFY SUB_ROOT
                    self.model.graph.value_info.append(make_tensor_value_info(
                        patch_input, sub_root_out_dtype, nchw_shape,
                    ))
                    # The node is replaced by a deep copy writing to `patch_input`;
                    # `self.md.nodes` is re-pointed to that copy.
                    new_sub_root = copy.deepcopy(sub_root)
                    new_sub_root.output[0] = patch_input
                    self.model.graph.node.remove(sub_root)
                    self.model.graph.node.append(new_sub_root)
                    self.md.nodes[sub_root.name] = new_sub_root

                    # insert Q1*:  (Re) --[X]-> (Q0*) --[q0*.out]->
                    if has_qdq:
                        _q1, _q1_out = self._append_node_to_model(*self._make_q_node(
                            f"{sub_root.name}_patch_head_q0", sc0, zp0,
                            input=patch_input,
                            output_shape=nchw_shape,
                            output_dtype=int_dtype,
                        ))
                        self.tag_ops[_q1.name] = OpTag(self.md.get_shape(_q1_out), 'NCHW')
                        patch_input = _q1_out.name

                    # [3.1] insert T0-shield:           --[q0*.out]->   (dq->T0->q)   --[ish0]->
                    shield_out_0, shield_0 = self._append_dq_transpose_q(
                        transpose_nhwc_0, nhwc_perm, sc0, zp0,
                        input=patch_input,
                        input_shape=nchw_shape, output_shape=nhwc_shape,
                        input_layout='NCHW', output_layout='NHWC',
                        float_dtype=float_dtype, int_dtype=int_dtype,
                    )

                    # [3.2] insert T1-shield:             --[ish0]->    (dq->T1->q)   --[ish1]->
                    # Without Q/DQ the second shield writes straight to the original output tensor
                    shield_out_1, shield_1 = self._append_dq_transpose_q(
                        transpose_nchw_1, nchw_perm, sc0, zp0,
                        input=shield_out_0.name,
                        output=None if has_qdq else patch_output.name,
                        input_shape=nhwc_shape, output_shape=nchw_shape,
                        input_layout='NHWC', output_layout='NCHW',
                        float_dtype=float_dtype, int_dtype=int_dtype,
                    )

                    # The original sub_root leaves the subgraph (id 0); the second Transpose
                    # takes over as sub-root and its trailing Q-node joins the body.
                    self.nodes_to_subgraphs[sub_root.name] = 0
                    if sub_root.name in sg.body:
                        sg.body.remove(sub_root.name)
                    new_sub_roots.append(shield_1['tr'][0].name)
                    sg.body.append(shield_1['q'][0].name)
                    self.nodes_to_subgraphs[shield_1['q'][0].name] = sg.id

                    # [3.3] insert in-shield's tail:          --[ish1]->     (DQ0*)    --[ish.out]-> (sub_leaf)
                    if has_qdq:
                        _dq1, _dq1_out = self._append_node_to_model(*self._make_dq_node(
                            f"{sub_root.name}_patch_tail_dq0", sc0, zp0,
                            input=shield_out_1.name,
                            output=patch_output.name,
                            output_shape=nchw_shape,
                            output_dtype=float_dtype,
                        ))
                        self.tag_ops[_dq1.name] = OpTag(nchw_shape, 'NCHW')
                        sg.body.append(_dq1.name)
                        self.nodes_to_subgraphs[_dq1.name] = sg.id
                    pass
                pass

            # Replace the root list and register every (new) root under this subgraph's id
            sg.roots = new_sub_roots
            self.nodes_to_subgraphs |= {
                sr: sg.id for sr in sg.roots
            }

        # -------------------------------------------------------------------------------------------
        # PATCH SUB-LEAVES
        # -------------------------------------------------------------------------------------------
        for sg in unpatched_subleaves:
            # Leaf names that replace `sg.leaves` once every leaf of this subgraph was processed
            new_sub_leaves = list()
            for sl_idx, sub_leaf in enumerate(sg.leaves):
                # `sg.leaves` holds node names; from here on `sub_leaf` is the node object
                sub_leaf = self.md.get_node(sub_leaf)

                # SUB_LEAF == Transpose
                if sg.is_valid_subleaf(sub_leaf.name):
                    new_sub_leaves.append(sub_leaf.name)
                    continue

                if self.md.get_op_type(sub_leaf) == "Squeeze":
                    # skip Sq -- since patching them with Tr_noop anyway
                    # NOTE(review): the Squeeze is not added to `new_sub_leaves`, so it is
                    # dropped from `sg.leaves` when the list is replaced below.
                    continue

                # SUB_LEAF == Reshape-like
                # `patch_input` starts as a tensor object (its `.name` is used below), while
                # `patch_output` is a plain tensor name that sub_leaf will read from.
                patch_input = self.md.get_node_activations(sub_leaf, first=True)
                patch_output = f"{sub_leaf.name}_patch_output"
                transpose_nhwc_0 = f"{sub_leaf.name}_patch_transpose_nhwc_0"
                transpose_nchw_1 = f"{sub_leaf.name}_patch_transpose_nchw_1"

                # NOTE(review): `sub_leaf_out` is not used anywhere below in this method
                sub_leaf_out = self.md.get_tensor(sub_leaf.output[0])

                nchw_shape = self.md.get_shape(patch_input)
                nhwc_shape = self.get_NHWC_shape(nchw_shape)
                nchw_perm = NCHW_PERMUTATIONS[get_dim(nchw_shape)]
                nhwc_perm = NHWC_PERMUTATIONS[get_dim(nchw_shape)]
                # Defaults used when no DQ-node is found before the sub-leaf
                float_dtype, int_dtype = DEFAULT_FLOAT_DTYPE, DEFAULT_INT_DTYPE
                has_qdq, sc0, zp0 = False, None, None

                logger.debug(
                    f"SG({sg.id}) :: <SubLeaf[{sl_idx}]>{sub_leaf.op_type}({sub_leaf.name}).patch_input={patch_input.name}"
                )
                # Reuse scale / zero-point of the DequantizeLinear supplier, if there is one
                if dq0 := self.md.get_node_suppliers(sub_leaf, "DequantizeLinear", first=True):
                    has_qdq, sc0, zp0 = True, dq0.input[1], dq0.input[2]
                    float_dtype = self.get_onnx_dtype(sc0)
                    int_dtype = self.get_onnx_dtype(zp0)
                else:
                    logger.debug(
                        f"PATCHING {sg}: "
                        f"hasn't found a DQ-node at the inputs of a sub-leaf {sub_leaf.op_type}({sub_leaf.name})."
                        f"Transpose nodes without Q/DQ around will be inserted."
                    )
                    pass

                # insert Q0*:  (DQ0) --[X]-> (Q0*) --[q0*.out]->
                if has_qdq:
                    _q0, _q0_out = self._append_node_to_model(*self._make_q_node(
                        f"{sub_leaf.name}_patch_head_q0", sc0, zp0,
                        input=patch_input.name,
                        output_shape=nchw_shape,
                        output_dtype=int_dtype,
                    ))
                    self.tag_ops[_q0.name] = OpTag(self.md.get_shape(_q0_out), 'NCHW')
                    sg.body.append(_q0.name)
                    self.nodes_to_subgraphs[_q0.name] = sg.id
                    patch_input = _q0_out

                # [3.1] insert T0-shield:           --[q0*.out]->   (dq->T0->q)   --[ish0]->
                shield_out_0, shield_0 = self._append_dq_transpose_q(
                    transpose_nhwc_0, nhwc_perm, sc0, zp0,
                    input=patch_input.name,
                    input_shape=nchw_shape, output_shape=nhwc_shape,
                    input_layout='NCHW', output_layout='NHWC',
                    float_dtype=float_dtype, int_dtype=int_dtype,
                )
                # The first shield's DQ joins the body and its Transpose becomes the new sub-leaf
                sg.body.append(shield_0['dq'][0].name)
                self.nodes_to_subgraphs[shield_0['dq'][0].name] = sg.id
                new_sub_leaves.append(shield_0['tr'][0].name)

                # [3.2] insert T1-shield:             --[ish0]->    (dq->T1->q)   --[ish1]->
                shield_out_1, shield_1 = self._append_dq_transpose_q(
                    transpose_nchw_1, nchw_perm, sc0, zp0,
                    input=shield_out_0.name,
                    output=None if has_qdq else patch_output,
                    input_shape=nhwc_shape, output_shape=nchw_shape,
                    input_layout='NHWC', output_layout='NCHW',
                    float_dtype=float_dtype, int_dtype=int_dtype,
                )

                # [3.3] insert in-shield's tail:          --[ish1]->     (DQ0*)    --[ish.out]-> (sub_leaf)
                if has_qdq:
                    _dq0, _dq0_out = self._append_node_to_model(*self._make_dq_node(
                        f"{sub_leaf.name}_patch_tail_dq0", sc0, zp0,
                        input=shield_out_1.name,
                        output=patch_output,
                        output_shape=nchw_shape,
                        output_dtype=float_dtype,
                    ))
                    self.tag_ops[_dq0.name] = OpTag(nchw_shape, 'NCHW')
                    patch_output = _dq0_out.name
                else:
                    patch_output = shield_out_1.name

                # MODIFY SUB_LEAF
                # The node is replaced by a deep copy whose first input reads the patched
                # tensor; `self.md.nodes` is re-pointed to that copy.
                new_sub_leaf = copy.deepcopy(sub_leaf)
                new_sub_leaf.input[0] = patch_output
                self.model.graph.node.remove(sub_leaf)
                self.model.graph.node.append(new_sub_leaf)
                self.md.nodes[sub_leaf.name] = new_sub_leaf

                # The original sub_leaf leaves the subgraph (id 0)
                self.nodes_to_subgraphs[sub_leaf.name] = 0
                if sub_leaf.name in sg.body:
                    sg.body.remove(sub_leaf.name)

                pass

            # Replace the leaf list and register every (new) leaf under this subgraph's id
            sg.leaves = new_sub_leaves
            self.nodes_to_subgraphs |= {
                sl: sg.id for sl in sg.leaves
            }



        # Only the subgraphs selected for sub-root patching are returned
        return unpatched_subroots

    def _apply_graph_modifications(self, modification_dict, tag_ops_update):
        """
        Apply the collected modifications to the ONNX graph.

        Handles, in this order: removal of constant nodes, removal / replacement /
        creation of nodes, removal / replacement / creation of value_info entries,
        removal and creation of initializers, and finally the layout-tag updates.
        Every removal is guarded: an object is only removed when it is still present
        in the corresponding repeated field of ``self.model.graph``.

        Args:
            modification_dict (dict): modifications grouped by action. The keys read here are
                'remove_constant', 'remove_node', 'create_node', 'remove_vinfo',
                'create_vinfo', 'remove_ini', 'create_ini' and 'create_tag'.
                NOTE: entries of 'create_node' that replace a node listed in 'remove_node'
                are popped from the dict in place (callers observe the shrunken dict).
            tag_ops_update (dict): accumulated tag updates for operators; modified in place.

        Returns:
            dict: the same ``tag_ops_update`` object after all modifications were applied.
        """
        graph = self.model.graph

        def _detach(container, item) -> bool:
            """Remove `item` from `container` if it is present; report whether it was removed."""
            if item in container:
                container.remove(item)
                return True
            return False

        # -------------------------------------------------------------------------------------------
        # MODIFY: Constant Nodes
        # -------------------------------------------------------------------------------------------
        logger.debug(f'Remove Constants {len(modification_dict["remove_constant"])}')
        for const_node_name in modification_dict['remove_constant']:
            # The tag is dropped only when the node was really part of the current graph
            if const_node_name in self.md.nodes and _detach(graph.node, self.md.nodes[const_node_name]):
                self.tag_ops.pop(const_node_name, None)

        # -------------------------------------------------------------------------------------------
        # MODIFY: Nodes
        # -------------------------------------------------------------------------------------------
        logger.debug(f"Remove/Add Nodes {len(modification_dict['remove_node'])}/{len(modification_dict['create_node'])}")
        create_node = modification_dict['create_node']
        for rm_node_name in modification_dict['remove_node']:
            old_node = self.md.nodes.get(rm_node_name)
            # Drop the stale node first (if it is still in the graph) in both cases below
            removed = bool(old_node) and _detach(graph.node, old_node)
            if rm_node_name in create_node:
                # Replacement: the node keeps its name, so its tags are left untouched
                if new_node := create_node.pop(rm_node_name):
                    graph.node.append(new_node)
            elif removed:
                # Pure removal: forget the tags of a node that no longer exists
                self.tag_ops.pop(rm_node_name, None)
                tag_ops_update.pop(rm_node_name, None)

        # Whatever is left in 'create_node' are brand-new nodes (replacements were popped above)
        for new_node in create_node.values():
            graph.node.append(new_node)

        # -------------------------------------------------------------------------------------------
        # MODIFY: Activation Tensors /Vinfo/
        # -------------------------------------------------------------------------------------------
        logger.debug(f'Remove/Add value_info {len(modification_dict["remove_vinfo"])}/{len(modification_dict["create_vinfo"])}')
        create_vinfo = modification_dict['create_vinfo']
        # Order-preserving copy: replaced entries are popped from it, 'create_vinfo' stays intact
        pending_vinfo = dict(create_vinfo)
        for key in modification_dict['remove_vinfo']:
            # Drop the stale value_info if it is still in the graph
            if key in self.md.vinfo:
                _detach(graph.value_info, self.md.vinfo[key])
            if key not in create_vinfo:
                continue
            # Replacement: re-add only when the shape has dims and none of them is symbolic
            vinfo = pending_vinfo.pop(key)
            static_dims = [x.dim_param == '' for x in vinfo.type.tensor_type.shape.dim]
            if static_dims and all(static_dims):
                graph.value_info.append(vinfo)

        # Brand-new value_info entries are appended without the shape check above
        for value_info in pending_vinfo.values():
            graph.value_info.append(value_info)

        # -------------------------------------------------------------------------------------------
        # MODIFY: Initializers
        # -------------------------------------------------------------------------------------------
        logger.debug(f"Remove initializers {len(modification_dict['remove_ini'])}")
        for rm_ini_name in modification_dict['remove_ini']:
            # `self.md.ini` is kept in sync only for initializers that were really removed
            rm_ini = self.md.ini.get(rm_ini_name)
            if rm_ini and _detach(graph.initializer, rm_ini):
                self.md.ini.pop(rm_ini_name)

        logger.debug(f"Add Initializers {len(modification_dict['create_ini'])}")
        for ini_name, new_ini in modification_dict['create_ini'].items():
            # Skip names already known to `self.md.ini` to avoid duplicated initializers
            if ini_name not in self.md.ini:
                graph.initializer.append(new_ini)
                self.md.ini[ini_name] = new_ini

        # -------------------------------------------------------------------------------------------
        # MODIFY: Layout Tags
        # -------------------------------------------------------------------------------------------
        logger.debug(f"Add Tags {len(modification_dict['create_tag'])}")
        tag_ops_update.update(modification_dict['create_tag'])

        return tag_ops_update
    
    def _process_chain_nodes(self, mod_subgraph: NCHWSubgraph) -> tuple[dict[ModAction, dict], dict[str, OpTag]]:
        """
        Process all nodes in a subgraph and populate the modification dictionary.
        
        Args:
            mod_subgraph: a NCHW-subgraph whose layout will be transformed into NHWC
        Returns:
            tuple(dict, dict): modifications collected for the subgraph and tag updates.
        """
        actions: tuple[ModAction] = (
            'create_node', 'create_vinfo', 'create_ini', 'create_tag',
            'remove_constant', 'remove_node', 'remove_vinfo', 'remove_ini',
        )
        modification_dict: dict[ModAction, dict] = {action: {} for action in actions}
        tag_ops_update: dict[str, OpTag] = dict()

        nodes = self.md.nodes
        handlers = self.handlers_to_NHWC

        def merge_changes(mod_dict: dict[ModAction, dict], handler_result: dict[ModAction, dict]) -> None:
            for action, updates in handler_result.items():
                if not updates:
                    continue
                bucket = mod_dict[action]
                for key, value in updates.items():
                    existing = bucket.setdefault(key, value)
                    if existing != value:
                        raise ModActionError(f"ModAction('{action}') :: conflicting entries for '{key}'")

        for node_name in mod_subgraph.body:
            node: NodeProto = self.md.get_node(node_name)
            if not node:
                continue

            op_type = node.op_type
            node_tag = self.get_tag(node_name)
            out_edge = node.output[0]
            out_dtype = self.md.get_onnx_dtype(out_edge)
            out_nhwc_shape = self.get_NHWC_shape(out_edge)
            new_out_edge = onnx.helper.make_tensor_value_info(out_edge, out_dtype, out_nhwc_shape)

            if op_type in LAYOUT_AWARE_OPS:
                handler = handlers.get(op_type)
                if handler is None:
                    raise LayoutHandlerError(
                        f"No NHWC handler registered for '{op_type}' op encountered at Node({node_name})"
                    )
                handler_result = handler(node_name)
                if handler_result:
                    merge_changes(modification_dict, handler_result)
                if node_tag:
                    tag_ops_update[node_name] = OpTag(out_nhwc_shape, "NHWC")
                continue

            # layout agnostic path
            if op_type not in LAYOUT_AGNOSTIC_OPS:
                logger.debug(
                    f"NCHW to NHWC handler is undefined for {op_type}({node_name}) node, treating it as layout agnostic",
                )

            if not node_tag or node_tag in {'NHWC', OpTag.NOT_FOUND().tag}:
                continue
            if node_tag != 'NCHW':
                raise LayoutMismatchError(
                    f"Node<{node_tag}>({node_name}) has conflicting memory layout"
                )

            modification_dict['remove_vinfo'][out_edge] = out_edge
            existing = modification_dict['create_vinfo'].get(out_edge)
            if existing and existing != new_out_edge:
                raise ModActionError(
                    f"ModAction('create_vinfo') :: OutTensor({out_edge}) of Node({node_name}) does not match previous definition"
                )
            modification_dict['create_vinfo'][out_edge] = new_out_edge
            tag_ops_update[node_name] = OpTag(out_nhwc_shape, 'NHWC')

            if out_edge in self.tensor_map:
                self.tensor_map[out_edge] |= dict(
                    final_shape=out_nhwc_shape, # == self.md.get_shape(new_out_edge),
                    final_layout='NHWC',           # == tag_ops_update[node_name]
                )
            pass

        if mod_subgraph.leaves:
            removable_nodes = set(mod_subgraph.roots) | set(mod_subgraph.leaves)
            remove_bucket = modification_dict['remove_node']
            create_bucket = modification_dict['create_node']

            for node_name in removable_nodes:
                node = self.md.get_node(node_name)
                if not node:
                    continue
                if (node.op_type != 'Transpose' or
                        len(self.md.input_nodes[node_name]) > 1 or
                        len(self.md.output_nodes[node_name]) != 1):
                    continue

                transpose_node = node
                out_name = transpose_node.output[0]
                out_consumers = list(self.md.get_node_consumers(node_name))
                if not out_consumers:
                    continue

                next_node_name = out_consumers[0]
                next_node = create_bucket.get(next_node_name, nodes.get(next_node_name))
                if next_node is None:
                    continue

                input_node_names = self.md.input_nodes[node_name]
                if (len(input_node_names) == 1 and
                        self.md.nodes[input_node_names[0]].op_type == "DequantizeLinear" and
                        next_node.op_type == "QuantizeLinear" and
                        self.md.output_nodes[next_node.name]):
                    dq_node = self.md.get_node(input_node_names[0])
                    dq_name = dq_node.name
                    dq_proto = create_bucket.get(dq_name, dq_node)

                    # Check if transpose is a leaf and if DQ has upstream QuantizeLinear
                    if node_name in mod_subgraph.leaves:
                        dq_input_nodes = self.md.input_nodes.get(dq_name, [])
                        dq_output_nodes = self.md.output_nodes.get(dq_name, [])
                        if (len(dq_input_nodes) == 1 and len(dq_output_nodes) == 1 and
                                self.md.nodes[dq_input_nodes[0]].op_type == "QuantizeLinear"):
                            # Pattern: upstream_Q → DQ → Transpose (leaf) → Q → downstream
                            # Bypass by making upstream_Q output Q's output tensor
                            up_q_node = self.md.get_node(dq_input_nodes[0])
                            upstream_q_node = create_bucket.get(up_q_node.name, up_q_node)
                            q_output = next_node.output[0]

                            if len(self.md.output_nodes[upstream_q_node.name]) == 1:
                                # Modify upstream_Q to output Q's output tensor
                                upstream_q_modified = copy.deepcopy(upstream_q_node)
                                upstream_q_modified.output[0] = q_output
                                create_bucket[upstream_q_node.name] = upstream_q_modified
                                remove_bucket[upstream_q_node.name] = upstream_q_node.name
                                
                                # Remove the entire chain: DQ → Transpose → Q
                                if len(self.md.output_nodes[dq_proto.name]) == 1 and dq_proto.output[0] not in self.md.outputs:
                                    remove_bucket[dq_proto.name] = dq_proto.name
                                remove_bucket[node_name] = node_name
                                remove_bucket[next_node.name] = next_node.name
                                continue

                    # If transpose is a root: bypass by connecting downstream to DQ's input
                    for final_out in self.md.output_nodes[next_node.name]:
                        downstream = create_bucket.get(final_out, nodes.get(final_out))
                        if downstream is None:
                            continue
                        downstream = copy.deepcopy(downstream)
                        downstream_inputs = list(downstream.input)
                        try:
                            idx = downstream_inputs.index(next_node.output[0])
                        except ValueError:
                            continue
                        downstream_inputs[idx] = dq_proto.input[0]
                        downstream.input[:] = downstream_inputs
                        create_bucket[downstream.name] = downstream
                        remove_bucket[downstream.name] = downstream.name

                    remove_bucket[node_name] = node_name
                    if len(self.md.output_nodes[dq_proto.name]) == 1 and dq_proto.output[0] not in self.md.outputs:
                        remove_bucket[dq_proto.name] = dq_proto.name
                    remove_bucket[next_node.name] = next_node.name
                    continue

                next_node = copy.deepcopy(next_node)
                try:
                    idx = list(next_node.input).index(out_name)
                except ValueError:
                    continue
                next_node.input[idx] = transpose_node.input[0]
                create_bucket[next_node.name] = next_node
                remove_bucket[node_name] = node_name
                remove_bucket[next_node.name] = next_node.name

        return modification_dict, tag_ops_update
                        

    def transform_subgraphs(self, nchw_subgraphs: dict[int, NCHWSubgraph]) -> dict[int, NCHWSubgraph]:
        """Convert the given NCHW-subgraphs to NHWC layout, batch by batch.

        The batch size is one twentieth of the number of subgraphs, clamped
        to the range [10, 50].

        Args:
            nchw_subgraphs: mapping of integer keys to the subgraphs to convert.

        Returns:
            The value returned by ``_modify_subgraphs_batched`` for this input.
        """
        subgraph_count = len(nchw_subgraphs)
        logger.info(
            f"Start layout transformation of NCHW-subgraphs: "
            f"{subgraph_count} subgraphs will be converted to NHWC layout"
        )

        # Adaptive batch size: scale with the workload, then clamp to [10, 50].
        lower_bounded = max(10, subgraph_count // 20)
        batch_size = min(50, lower_bounded)
        logger.info(f"  Using batch size: {batch_size} for subgraph transformation")

        return self._modify_subgraphs_batched(nchw_subgraphs, batch_size)

    def _modify_subgraphs_batched(
        self,
        subgraphs_dict: dict[int, NCHWSubgraph],
        batch_size: int = 50,
    ) -> dict[int, NCHWSubgraph]:
        """Convert subgraphs to NHWC in batches and apply the edits per batch.

        For every subgraph the modification dict and tag updates are collected
        via ``_process_chain_nodes``; a subgraph whose processing raises is
        logged and skipped. The collected edits of a batch are consolidated and
        applied in one go, after which ``self.tag_ops`` and the model dict are
        refreshed. Finally, graph outputs produced by leaf nodes whose tag
        changed are re-declared with the dtype/shape currently known to
        ``self.md``.

        Args:
            subgraphs_dict: subgraphs to process; only the values are iterated.
            batch_size: number of subgraphs whose edits are applied together.

        Returns:
            ``subgraphs_dict`` itself (the same object that was passed in).
        """
        # Group subgraphs into batches
        subgraphs = list(subgraphs_dict.values())
        subgraph_batches: list[list[NCHWSubgraph]] = [
            subgraphs[i:i + batch_size] for i in range(0, len(subgraphs), batch_size)
        ]

        tag_ops_mod: dict[str, OpTag] = dict()
        processed_subgraphs = 0
        logger.info(f"  Processing {len(subgraphs)} subgraphs in {len(subgraph_batches)} batches")
        for batch_idx, batch in enumerate(subgraph_batches):
            batch_modifications: dict[int, dict] = {}
            batch_tag_updates: dict[str, OpTag] = {}

            logger.info(
                f"  [BATCH] Processing batch {batch_idx + 1}/{len(subgraph_batches)} with {len(batch)} subgraphs"
            )

            # Collect the edits of every subgraph in the batch
            for subgraph in batch:
                subgraph_id = subgraph.id

                logger.debug(
                    f"  [PROGRESS] NCHW-Subgraph Transformation: {processed_subgraphs}/{len(subgraphs)} (processing subgraph {subgraph_id})"
                )
                try:
                    modification_dict, tag_ops_update = self._process_chain_nodes(subgraph)

                    # Only keep subgraphs that actually request a change
                    if tag_ops_update or any(modification_dict.values()):
                        batch_modifications[subgraph_id] = modification_dict
                        batch_tag_updates.update(tag_ops_update)

                except Exception as e:
                    # Best effort: a failing subgraph must not abort the whole pass
                    logger.error(f"Failed to process subgraph {subgraph_id}: {e}")
                    logger.debug(f"Subgraph roots: {subgraph.roots}")
                    logger.debug(f"Subgraph members: {subgraph.body}")
                    logger.debug(f"Subgraph leaves: {subgraph.leaves}")
                    continue

                processed_subgraphs += 1

            # Apply all modifications in batch if any exist
            if not batch_modifications:
                continue
            try:
                consolidated_modifications = self._consolidate_batch_modifications(batch_modifications)
                applied_updates = self._apply_graph_modifications(consolidated_modifications, batch_tag_updates)

                if applied_updates:
                    tag_ops_mod.update(applied_updates)
                    self.tag_ops.update(applied_updates)

                    # Update model dict once per batch instead of per subgraph
                    self.md.update_dict(self.model)

            except Exception as e:
                logger.error(f"Failed to apply batch modifications for batch {batch_idx + 1}: {e}")

        # Update output nodes for modified tag ops
        if tag_ops_mod:
            leaf_nodes = self.md.get_leaves()
            for leaf_name in tag_ops_mod:
                leaf_node = leaf_nodes.get(leaf_name)
                if not leaf_node:
                    continue
                for out in leaf_node.output:
                    if out not in self.md.outputs:
                        continue
                    # NOTE(review): remove + append moves the refreshed output to the
                    # end of graph.output — confirm that output order is not relied upon.
                    self.model.graph.output.remove(self.md.outputs[out])
                    dtype_ = self.md.get_onnx_dtype(out)
                    shape_ = self.md.get_shape(out)
                    out_ = onnx.helper.make_tensor_value_info(out, dtype_, shape_)
                    self.model.graph.output.append(out_)

        return subgraphs_dict

    def _consolidate_batch_modifications(self, batch_modifications: dict[int, dict]):
        """
        Merge the per-subgraph modification dicts of one batch into a single dict.

        Only the known action buckets are merged; any other key of a
        per-subgraph dict is ignored. For ``create*`` actions the first value
        registered under a key wins and a differing later value is logged as a
        conflict; for every other action a later value replaces an earlier one.

        Args:
            batch_modifications: Dict mapping subgraph_id to modification_dict

        Returns:
            dict: Consolidated modification dictionary with one (possibly
            empty) bucket per known action
        """
        known_actions = (
            'create_node', 'create_vinfo', 'create_ini', 'create_tag',
            'remove_constant', 'remove_node', 'remove_vinfo', 'remove_ini',
        )
        merged = {name: {} for name in known_actions}

        for sg_id, sg_mods in batch_modifications.items():
            for name in known_actions:
                if name not in sg_mods:
                    continue

                target = merged[name]
                is_create = name.startswith('create')
                for key, value in sg_mods[name].items():
                    # A differing re-creation is reported and the first value is kept
                    if is_create and key in target and target[key] != value:
                        logger.warning(f"Batch consolidation conflict for {name}[{key}] in subgraph {sg_id}")
                        continue
                    target[key] = value

        return merged

    def Reshape_to_NHWC(self, md, node_name):
        """Build the modification dict that converts a ``Reshape`` node to NHWC.

        The output value-info is re-declared with the shape returned by
        ``self.get_NHWC_shape``. When the target shape (input 1) is an
        initializer, a new ``<node_name>_shape`` initializer holding the NHWC
        shape is requested and the node is re-created pointing at it.

        Args:
            md: model dictionary; ``nodes``, ``vinfo`` and ``ini`` are read.
            node_name: key of the ``Reshape`` node in ``md.nodes``.

        Returns:
            dict: action name -> {key: payload} describing the requested edits.
            When the shape input is not an initializer only the value-info
            edits are present.
        """
        op_dict = {'create_node': {}, 'create_vinfo': {}, 'create_ini': {}, 'remove_node': {}, 'remove_vinfo': {}, 'remove_ini': {}, 'create_tag': {}}

        node = md.nodes[node_name]
        output0 = node.output[0]
        out_tensor = md.vinfo[output0]
        out_dtype = self.get_onnx_dtype(out_tensor)

        # converting output
        out_shape_nhwc = self.get_NHWC_shape(out_tensor)
        op_dict['remove_vinfo'][output0] = output0
        op_dict['create_vinfo'][output0] = onnx.helper.make_tensor_value_info(
            output0, out_dtype, out_shape_nhwc
        )
        if output0 in self.tensor_map:
            self.tensor_map[output0] |= dict(
                final_shape=out_shape_nhwc,
                final_layout='NHWC',
            )

        # Reshape.input = [0:data:act, 1:shape:ini/act]
        shape_edge = node.input[1]
        if shape_edge not in md.ini:
            # Shape produced at runtime: there is no constant to rewrite
            logger.debug(
                f"Unexpected call of 'Reshape' NWHC-handler at Node = {node.op_type}({node_name})"
            )
            return op_dict

        # FIXME: handler should be able to check if shape_nhwc is already created during handling of other node
        shape_edge_nhwc = node_name + '_shape'
        # The original initializer is decoded only to reuse its dtype
        shape = numpy_helper.to_array(md.ini[shape_edge])
        op_dict['create_ini'][shape_edge_nhwc] = numpy_helper.from_array(np.array(
            out_shape_nhwc, dtype=shape.dtype), shape_edge_nhwc,
        )

        # replace node with its converted version
        op_dict['remove_node'][node_name] = node_name
        new_node = copy.deepcopy(node)
        new_node.input[1] = shape_edge_nhwc
        op_dict['create_node'][node_name] = new_node

        return op_dict

    def Squeeze_to_NHWC(self, md, node_name):
        """Build the modification dict that converts a ``Squeeze`` node to NHWC.

        The output value-info is re-declared with the shape returned by
        ``self.get_NHWC_shape`` and the node is re-created with its ``axes``
        input replaced by an initializer whose values were mapped through
        ``self.axis_value_to_NHWC`` using the rank of the data input.

        Args:
            md: model dictionary; ``nodes``, ``vinfo`` and ``ini`` are read.
            node_name: key of the ``Squeeze`` node in ``md.nodes``.

        Returns:
            dict: action name -> {key: payload} describing the requested edits.

        Raises:
            KeyError: if the node has no ``axes`` input or that input is not
                an initializer.
        """
        op_dict = {'create_node': {}, 'create_vinfo': {}, 'create_ini': {}, 'remove_node': {}, 'remove_vinfo': {}, 'remove_ini': {}, 'create_tag': {}}

        node = md.nodes[node_name]
        output0 = node.output[0]
        out_tensor = md.vinfo[output0]
        out_dtype = self.get_onnx_dtype(out_tensor)

        # converting output
        out_shape_nhwc = self.get_NHWC_shape(out_tensor)

        op_dict['remove_vinfo'][output0] = output0
        op_dict['create_vinfo'][output0] = onnx.helper.make_tensor_value_info(
            output0, out_dtype, out_shape_nhwc
        )
        if output0 in self.tensor_map:
            self.tensor_map[output0] |= dict(
                final_shape=out_shape_nhwc,
                final_layout='NHWC',
            )

        # axes
        # Squeeze.input = [0:data:act, 1:axes:ini]
        if len(node.input) < 2:
            # 'axes' may be absent (optional input / attribute in older opsets);
            # fail with a descriptive error instead of an opaque IndexError.
            raise KeyError(
                f"{node.op_type}({node_name}): "
                f"input tensor 'axes' is missing (node has {len(node.input)} input(s))"
            )
        axes_edge = node.input[1]
        if axes_edge not in md.ini:
            raise KeyError(
                f"{node.op_type}({node_name}): "
                f"input tensor 'axes'={axes_edge} is an unknown initializer"
            )

        # FIXME: handler should be able to check if axes_nhwc is already created during handling of other node
        in_tensor = md.vinfo[node.input[0]]
        in_shape = self.get_shape(in_tensor)
        in_dim = len(in_shape)
        # Squeeze axes index the *input* tensor, hence the input rank in the name
        axes_edge_nhwc = f"{axes_edge}_{in_dim}D_NHWC"
        if axes_edge_nhwc not in md.ini:
            axes = numpy_helper.to_array(md.ini[axes_edge])
            axes_nhwc = np.array([
                self.axis_value_to_NHWC(axis, in_dim) for axis in axes
            ], dtype=axes.dtype)
            op_dict['create_ini'][axes_edge_nhwc] = numpy_helper.from_array(
                axes_nhwc, axes_edge_nhwc,
            )

        # replace node with its converted version
        op_dict['remove_node'][node_name] = node_name
        new_node = copy.deepcopy(node)
        new_node.input[1] = axes_edge_nhwc
        op_dict['create_node'][node_name] = new_node

        return op_dict

    def Unsqueeze_to_NHWC(self, md, node_name):
        """Build the modification dict that converts an ``Unsqueeze`` node to NHWC.

        The output value-info is re-declared with the shape returned by
        ``self.get_NHWC_shape`` and the node is re-created with its ``axes``
        input replaced by an initializer whose values were mapped through
        ``self.axis_value_to_NHWC`` using the rank of the output tensor.

        Args:
            md: model dictionary; ``nodes``, ``vinfo`` and ``ini`` are read.
            node_name: key of the ``Unsqueeze`` node in ``md.nodes``.

        Returns:
            dict: action name -> {key: payload} describing the requested edits.

        Raises:
            KeyError: if the node has no ``axes`` input or that input is not
                an initializer.
        """
        op_dict = {'create_node': {}, 'create_vinfo': {}, 'create_ini': {}, 'remove_node': {}, 'remove_vinfo': {}, 'remove_ini': {}, 'create_tag': {}}

        node = md.nodes[node_name]
        output0 = node.output[0]
        out_tensor = md.vinfo[output0]
        out_shape = self.get_shape(out_tensor)
        out_dtype = self.get_onnx_dtype(out_tensor)
        out_dim = len(out_shape)

        # converting output
        out_shape_nhwc = self.get_NHWC_shape(out_tensor)
        op_dict['remove_vinfo'][output0] = output0
        op_dict['create_vinfo'][output0] = onnx.helper.make_tensor_value_info(
            output0, out_dtype, out_shape_nhwc
        )
        if output0 in self.tensor_map:
            self.tensor_map[output0] |= dict(
                final_shape=out_shape_nhwc,
                final_layout='NHWC',
            )

        # Unsqueeze.input = [0:data:act, 1:axes:ini]
        if len(node.input) < 2:
            # 'axes' may be absent (it is an attribute in older opsets);
            # fail with a descriptive error instead of an opaque IndexError.
            raise KeyError(
                f"{node.op_type}({node_name}): "
                f"input tensor 'axes' is missing (node has {len(node.input)} input(s))"
            )
        axes_edge = node.input[1]
        if axes_edge not in md.ini:
            raise KeyError(
                f"{node.op_type}({node_name}): "
                f"input tensor 'axes'={axes_edge} is an unknown initializer"
            )

        # FIXME: handler should be able to check if axes_nhwc is already created during handling of other node
        # Unsqueeze axes index the *output* tensor, hence the output rank in the name
        axes_edge_nhwc = f"{axes_edge}_{out_dim}D_NHWC"
        if axes_edge_nhwc not in md.ini:
            axes = numpy_helper.to_array(md.ini[axes_edge])
            axes_nhwc = np.array([
                self.axis_value_to_NHWC(axis, out_dim) for axis in axes
            ], dtype=axes.dtype)
            op_dict['create_ini'][axes_edge_nhwc] = numpy_helper.from_array(
                axes_nhwc, axes_edge_nhwc,
            )

        # replace node with its converted version
        op_dict['remove_node'][node_name] = node_name
        new_node = copy.deepcopy(node)
        new_node.input[1] = axes_edge_nhwc
        op_dict['create_node'][node_name] = new_node
        return op_dict

    def MatMul_to_NHWC(self, md, node_name):
        """Build the modification dict that converts a ``MatMul`` node to NHWC.

        Default conversion: the output value-info is re-declared with the shape
        returned by ``self.get_NHWC_shape`` and, when the output rank is at
        least 3, the node is re-created with its two inputs swapped.

        Shielded conversion: when the two producer nodes map to different
        entries of ``self.nodes_to_subgraphs`` and both are tagged ``"NCHW"``,
        a ``Transpose -> Q -> DQ`` chain (built with ``self._make_q_node`` /
        ``self._make_dq_node``) is requested behind the producer that is
        treated as outside the subgraph, and the MatMul is rewired to consume
        the chain's output before the default conversion is applied.

        Args:
            md: model dictionary; ``nodes``, ``vinfo``, ``ini`` and
                ``input_nodes`` are read.
            node_name: key of the ``MatMul`` node in ``md.nodes``.

        Returns:
            dict: action name -> {key: payload} describing the requested edits.
            All buckets are empty when the node does not have exactly two
            inputs or its second producer is missing from ``md.nodes``.
        """
        op_dict = {'create_node': {}, 'create_vinfo': {}, 'create_ini': {}, 'remove_node': {}, 'remove_vinfo': {}, 'remove_ini': {}, 'create_tag': {}}

        node = md.nodes[node_name]
        _logstr = "NHWC[MatMul]"
        if len(node.input) != 2:
            logger.error(f"{_logstr} !! MatMul({node_name}) has {len(node.input)} != 2 inputs")
            return op_dict

        logger.debug(
            f"{_logstr} >> MatMul({node.name}) :: A({node.input[0]}) :: B({node.input[1]})"
        )

        out_tensor = md.vinfo[node.output[0]]
        out_shape, out_dtype = self.get_out_shape(node, with_dtype=True)
        out_shape_nhwc = self.get_NHWC_shape(out_tensor)

        def default_MatMul_handler(_node):
            """Register the NHWC output and, for rank >= 3, the input-swapped node."""
            # converting output
            op_dict['remove_vinfo'][out_tensor.name] = out_tensor.name
            op_dict['create_vinfo'][out_tensor.name] = onnx.helper.make_tensor_value_info(
                out_tensor.name, out_dtype, out_shape_nhwc
            )
            if out_tensor.name in self.tensor_map:
                self.tensor_map[out_tensor.name] |= dict(
                    final_shape=out_shape_nhwc,
                    final_layout='NHWC',
                )

            # swap inputs
            out_dim = len(out_shape)
            if out_dim >= 3:
                new_node = copy.deepcopy(_node)
                new_node.input[0] = copy.deepcopy(_node.input[1])
                new_node.input[1] = copy.deepcopy(_node.input[0])
                op_dict['remove_node'][_node.name] = _node.name
                op_dict['create_node'][_node.name] = new_node
                logger.debug(
                    f"{_logstr} << MatMul*({_node.name}){out_shape}"
                    f" :: A*({new_node.input[0]}){md.get_shape(_node.input[1])}"
                    f" :: B*({new_node.input[1]}){md.get_shape(_node.input[0])}"
                )
            return op_dict

        in_node0, *in_nodes = md.input_nodes[node_name]
        if not in_nodes:
            # Only one producer node. The lookup below exists for logging only,
            # so it must not raise when input B is not an initializer.
            b_edge = node.input[1]
            ini_label = md.ini[b_edge].name if b_edge in md.ini else b_edge
            logger.warning(
                f"{_logstr} !! MatMul({node_name}) doesn't have in-node B, "
                f"instead input_B = Initializer({ini_label})"
            )
            # FIXME: unhandled situation?
            return default_MatMul_handler(node)

        in_node1, *_ = in_nodes
        if in_node1 not in md.nodes:
            logger.error(
                f"{_logstr} !! Node not found: in-node B({in_node1}) of MatMul({node.name})"
            )
            return op_dict

        # Get in-nodes and log their properties
        in_node_A = md.get_node(in_node0)
        in_node_B = md.get_node(in_node1)      # for now don't care about act/bias
        in_subgr_A = self.nodes_to_subgraphs.get(in_node_A.name)
        in_subgr_B = self.nodes_to_subgraphs.get(in_node_B.name)

        # Both producers map to the same subgraph entry (or to none at all):
        # there is no layout boundary to shield, use the default conversion.
        if in_subgr_A == in_subgr_B:
            return default_MatMul_handler(node)

        # Reassign A and B so that: A in subGr, B outside subGr
        in_idx_A, in_idx_B = 0, 1
        if in_subgr_A is None:
            in_node_A, in_node_B = in_node_B, in_node_A
            in_subgr_A, in_subgr_B = in_subgr_B, in_subgr_A
            in_idx_A, in_idx_B = 1, 0

        in_shape_A, _ = self.get_out_shape(in_node_A, with_dtype=True)
        in_shape_B, _ = self.get_out_shape(in_node_B, with_dtype=True)
        in_layout_A = self.get_tag(in_node_A)
        in_layout_B = self.get_tag(in_node_B)
        logger.debug(
            f"{_logstr}\t {out_shape} <- MatMul({node_name})  <- A{in_shape_A}  <- {in_node_A.op_type}<{in_layout_A}>({in_node_A.name})"
            f" :: ini={md.is_initializer_node(in_node_A)} :: subgr={in_subgr_A}"
        )
        logger.debug(
            f"{_logstr}\t {out_shape} <- MatMul({node_name})  <- B{in_shape_B}  <- {in_node_B.op_type}<{in_layout_B}>({in_node_B.name})"
            f" :: ini={md.is_initializer_node(in_node_B)} :: subgr={in_subgr_B}"
        )

        # The shielding chain is only built when BOTH producers are tagged <NCHW>;
        # in every other case fall back to the default conversion.
        if in_layout_A != "NCHW" or in_layout_B != "NCHW":
            return default_MatMul_handler(node)

        # Check if in_node_B is a DequantizeLinear node as we assume most of the cases it's a DQ node
        if in_node_B.op_type == "DequantizeLinear":
            sc_B = md.get_tensor(in_node_B.input[1])  # scale
            zp_B = md.get_tensor(in_node_B.input[2])  # zero_point
            in_dtype_B_float = md.get_onnx_dtype(sc_B)
            in_dtype_B_int = md.get_onnx_dtype(zp_B)
        else:
            act_float_dtype, act_int_dtype = DEFAULT_FLOAT_DTYPE, DEFAULT_INT_DTYPE
            sc_B = numpy_helper.from_array(
                np.array(1).astype(tensor_dtype_to_np_dtype(act_float_dtype)), f"{in_node_B.name}_dummy_qdq_scale"
            )
            zp_B = numpy_helper.from_array(
                np.array(0).astype(tensor_dtype_to_np_dtype(act_int_dtype)), f"{in_node_B.name}_dummy_qdq_zero_point"
            )
            # `append` returns None, so the tensors must NOT be rebound to its
            # result. Skip names that are already registered so that several
            # MatMuls sharing this producer do not add duplicate initializers.
            known_initializers = {ini.name for ini in self.model.graph.initializer}
            for dummy in (sc_B, zp_B):
                if dummy.name not in known_initializers:
                    self.model.graph.initializer.append(dummy)
            # The dtypes are known by construction of the dummy tensors.
            in_dtype_B_float, in_dtype_B_int = act_float_dtype, act_int_dtype

        # Implement the desired transformation pattern:
        # BEFORE: DQ(in_node_B) ---> MatMul
        # AFTER: DQ(in_node_B) --[transpose_input]-> Transpose Node --[transpose_out]-> Q --[q_out]-> DQ --[dq_out]-> MatMul

        transpose_input = in_node_B.output[0]  # This is the output of the original DQ node
        transpose_name = f"{in_node_B.name}_transpose_nhwc"
        transpose_out = f"{transpose_name}_out"
        q_name = f"{transpose_name}_q"
        dq_name = f"{transpose_name}_dq"

        nhwc_perm = NHWC_PERMUTATIONS.get(get_dim(in_shape_B), [])
        in_nhwc_shape_B = self.get_NHWC_shape(in_shape_B)

        # Create Transpose node
        op_dict['create_node'][transpose_name] = make_node(
            'Transpose',
            name=transpose_name,
            inputs=[transpose_input],
            outputs=[transpose_out],
            perm=nhwc_perm,
        )
        op_dict['create_vinfo'][transpose_out] = onnx.helper.make_tensor_value_info(
            transpose_out, in_dtype_B_float, in_nhwc_shape_B
        )
        op_dict['create_tag'][transpose_name] = OpTag(in_nhwc_shape_B, 'NHWC')

        # Create Q node after Transpose
        q, q_out = self._make_q_node(
            q_name, sc_B.name, zp_B.name,
            input=transpose_out,
            output_shape=in_nhwc_shape_B,
            output_dtype=in_dtype_B_int,
        )
        # where :  q_out is ValueInfoProto  AND q_out.name == "{q.name}_out"

        op_dict['create_node'][q_name] = q
        op_dict['create_vinfo'][q_out.name] = q_out
        op_dict['create_tag'][q_name] = OpTag(in_nhwc_shape_B, 'NHWC')

        # Create DQ node after Q
        # NOTE(review): the Q helper above receives tensor *names* while the DQ
        # helper receives the tensor objects — confirm against both signatures.
        dq, dq_out = self._make_dq_node(
            dq_name, sc_B, zp_B,
            input=q_out.name,
            output_shape=in_nhwc_shape_B,
            output_dtype=in_dtype_B_float,
        )
        op_dict['create_node'][dq_name] = dq
        op_dict['create_vinfo'][dq_out.name] = dq_out
        op_dict['create_tag'][dq_name] = OpTag(in_nhwc_shape_B, 'NHWC')

        new_matmul_node = copy.deepcopy(node)
        new_matmul_node.input[in_idx_A] = in_node_A.output[0]
        new_matmul_node.input[in_idx_B] = dq_out.name
        default_MatMul_handler(new_matmul_node)
        # The default handler only re-creates the node for output rank >= 3.
        # For lower ranks the rewired node must still be registered, otherwise
        # the new Transpose/Q/DQ chain would be left without a consumer.
        if new_matmul_node.name not in op_dict['create_node']:
            op_dict['remove_node'][new_matmul_node.name] = new_matmul_node.name
            op_dict['create_node'][new_matmul_node.name] = new_matmul_node

        logger.debug(f"{_logstr} :: Created transformation chain: DQ({in_node_B.name}) -> Transpose({transpose_name}) -> Q({q_name}) -> DQ({dq_name}) -> MatMul({node_name})")
        logger.debug(f"{_logstr} :: {out_shape_nhwc} <- MatMul*({node_name}).A  <-{self.get_NHWC_shape(in_shape_A)}--  A*")
        logger.debug(f"{_logstr} :: {out_shape_nhwc} <- MatMul*({node_name}).B  <-{in_nhwc_shape_B}--  Tr*{tuple(nhwc_perm)} <- B*")
        return op_dict

    def Mul_to_NHWC(self, md, node_name) -> dict[ModAction, dict]:
        """Convert a ``Mul`` node by delegating to ``common_binary_to_NHWC``."""
        mul_changes = self.common_binary_to_NHWC(md, node_name)
        return mul_changes

    def Div_to_NHWC(self, md, node_name) -> dict[ModAction, dict]:
        """Convert a ``Div`` node by delegating to ``common_binary_to_NHWC``."""
        div_changes = self.common_binary_to_NHWC(md, node_name)
        return div_changes

    def Add_to_NHWC(self, md, node_name) -> dict[ModAction, dict]:
        """Convert an ``Add`` node by delegating to ``common_binary_to_NHWC``."""
        add_changes = self.common_binary_to_NHWC(md, node_name)
        return add_changes

    def Sub_to_NHWC(self, md, node_name) -> dict[ModAction, dict]:
        """Convert a ``Sub`` node by delegating to ``common_binary_to_NHWC``."""
        sub_changes = self.common_binary_to_NHWC(md, node_name)
        return sub_changes

    def common_binary_to_NHWC(self, md, node_name) -> dict[ModAction, dict]:
        op_dict: dict[ModAction, dict] = {'create_node': {}, 'create_vinfo': {}, 'create_ini': {}, 'create_tag': {}, 'remove_constant': {}, 'remove_node': {}, 'remove_vinfo': {}, 'remove_ini': {}}

        node = md.nodes[node_name]
        _logstr = f"NHWC[{node.op_type.capitalize()}]"

        if len(node.input) != 2:
            logger.error(f"{_logstr} !! BinaryNode({node_name}) has {len(node.input)} != 2 inputs")
            return op_dict

        logger.debug(
            f"{_logstr} >> BiNode({node_name}) :: A({node.input[0]}) :: B({node.input[1]})"
        )

        in_node0, *in_nodes = md.input_nodes[node_name]
        if not in_nodes:
            in_ini1 = md.ini[node.input[1]]
            logger.debug(
                f"{_logstr} !! BinaryNode({node_name}) doesn't have input-node B, "
                f"instead input_B = Initializer({in_ini1.name})"
            )
            act_tensor = md.vinfo[node.input[0]]
            bias_tensor = md.ini[node.input[1]]
            act_shape  = self.get_shape(act_tensor)
            bias_shape = self.get_shape(bias_tensor)
            out_tensor = node.output[0]
            out_shape, out_dtype = self.get_out_shape(node, with_dtype=True)
            out_shape_nhwc = self.get_NHWC_shape(out_tensor)
            logger.debug(f"{_logstr}\t {out_shape} <- BiNode({node_name})  <- Act{act_shape}")
            logger.debug(f"{_logstr}\t {out_shape} <- BiNode({node_name})  <- Ini{bias_shape}")

            # pad scalar and vector bias up to activation tensor
            if (pad_shape := len(act_shape) - len(bias_shape)) > 0:
                padded_bias_shape = [1] * pad_shape + list(bias_shape)
            else:
                padded_bias_shape = list(bias_shape)
            bias_nhwc_shape, nhwc_perm = self.convert_shape_to_NHWC(padded_bias_shape)

            bias_nhwc = node.input[1]+'_nhwc'
            bias_array = numpy_helper.to_array(bias_tensor)
            if len(bias_shape) == 0:
                new_bias_array = bias_array
            else:
                padded_bias_array = bias_array.reshape(padded_bias_shape)
                new_bias_array = np.transpose(padded_bias_array, nhwc_perm)

            logger.debug(f"{_logstr} :: {out_shape_nhwc} <- BiNode*({node_name}).A  <-{self.get_NHWC_shape(act_shape)}--  Act*")
            logger.debug(f"{_logstr} :: {out_shape_nhwc} <- BiNode*({node_name}).B  <-{bias_nhwc_shape}--  Ini*")

            op_dict['create_vinfo'][out_tensor] = onnx.helper.make_tensor_value_info(
                out_tensor, out_dtype, out_shape_nhwc
            )
            op_dict['remove_vinfo'][out_tensor] = out_tensor
            if out_tensor in self.tensor_map:
                self.tensor_map[out_tensor] |= dict(
                    final_shape=out_shape_nhwc,
                    final_layout='NHWC',
                )

            op_dict['create_ini'][bias_nhwc] = numpy_helper.from_array(new_bias_array, bias_nhwc,)
            op_dict['remove_ini'][node.input[1]] = node.input[1]
            op_dict['remove_node'][node_name] = node_name
            new_node = copy.deepcopy(node)
            new_node.input[1] = bias_nhwc
            op_dict['create_node'][node_name] = new_node
            logger.debug(f"{_logstr} << BiNode*({node_name}){out_shape_nhwc}")
            return op_dict

        in_node1, *_ = in_nodes
        if in_node1 not in md.nodes:
            logger.error(
                f"{_logstr} !! Node not found: input-node B({in_node1}) of BinaryNode({node.name})"
            )
            return op_dict

        in_node0 = md.nodes[in_node0]
        in_node1 = md.nodes[in_node1]

        out_tensor = node.output[0]
        out_shape, out_dtype = self.get_out_shape(node, with_dtype=True)
        out_shape_nhwc = self.get_NHWC_shape(out_tensor)

        if in_node0.input[0] in md.ini or in_node1.input[0] in md.ini:
            act = None      # DQ-node with Activation input
            bias = None     # DQ-node with Initializer input
            if in_node0.input[0] in md.ini:
                bias = in_node0.name
                act  = in_node1.name
            elif in_node1.input[0] in md.ini:
                act  = in_node0.name
                bias = in_node1.name

            bias_node = md.nodes[bias]
            act_node = md.nodes[act]
            act_shape  = copy.deepcopy(self.tag_ops[act].shape)
            bias_shape = copy.deepcopy(self.tag_ops[bias].shape)
            bias_dtype = self.get_onnx_dtype(bias_node.output[0])
            logger.debug(f"{_logstr}\t {out_shape} <- BiNode({node_name})  <- Act{act_shape} <-  {act_node.op_type}<{self.tag_ops[act].tag}>({act})")
            logger.debug(f"{_logstr}\t {out_shape} <- BiNode({node_name})  <- Bias{bias_shape} <-  {bias_node.op_type}<{self.tag_ops[bias].tag}>({bias}) <- Ini{self.md.get_shape(bias_node.input[0])}")

            # pad scalar and vector bias up to activation tensor
            if (pad_shape := len(act_shape) - len(bias_shape)) > 0:
                bias_shape = [1] * pad_shape + list(bias_shape)
            bias_nhwc_shape, nhwc_perm = self.convert_shape_to_NHWC(bias_shape)
            
            # ==================================================================================== #
            # Direct initializer shape update WITHOUT Transpose node insertion
            # ==================================================================================== #
            
            bias_dq_ini_name = bias_node.input[0]
            bias_dq_ini = md.ini[bias_dq_ini_name]

            # Get the initializer data
            bias_ini_data, bias_ini_dtype = onnxTensorProto_to_array(bias_dq_ini, transpose=False)

            # Pad if necessary
            if pad_shape > 0:
                new_shape = [1] * pad_shape + list(bias_ini_data.shape)
                bias_ini_data = bias_ini_data.reshape(new_shape)

            # Apply NHWC transpose to the actual data
            bias_ini_data_nhwc = np.transpose(bias_ini_data, axes=nhwc_perm)

            # Create new initializer with transposed data and shape
            new_bias_ini = onnxTensorProto_from_array(
                bias_ini_data_nhwc,
                f"{node.name}_ini_nhwc",  # Give it a new name to avoid conflicts
                og_dtype=bias_ini_dtype,
            )

            # Create a copy of the DQ node with updated input reference
            new_bias_node = copy.deepcopy(bias_node)
            new_bias_node.name = f"{bias_node.name}__{node.name}_ini_nhwc"
            new_bias_node.input[0] = new_bias_ini.name  # Point to the new initializer
            new_bias_node.output[0] = f"{node.name}_ini_input_nhwc"
            
            # Convert n.bias and n.output to NHWC:
            # 1) shield bias-node with NHWC-Transpose
            # 2) convert bias-input tensor to NHWC
            # 3) convert output tensor to NHWC
            # FROM: <out> <- N <- <bias> <- B <- <b_ini>
            # TO  : <out*> <- N <- <bias*> <- Tr(NHWC) <- <bias> <- B <- <b_ini>
       
            # Prepare new tensor value info (for DQ output)
            new_bias_vinfo = onnx.helper.make_tensor_value_info(
                new_bias_node.output[0], bias_dtype, bias_nhwc_shape
            )
            
            new_node = copy.deepcopy(node)
            if in_node0.input[0] in md.ini:
                new_node.input[0] = new_bias_node.output[0]
            else:
                new_node.input[1] = new_bias_node.output[0]
            
            # Update final output tensor info to reflect NHWC layout
            op_dict['create_node'][new_bias_node.name] = new_bias_node
            op_dict['create_node'][node_name] = new_node
            op_dict['create_vinfo'][out_tensor] = onnx.helper.make_tensor_value_info(
                out_tensor, out_dtype, out_shape_nhwc
            )
            op_dict['remove_vinfo'][out_tensor] = out_tensor
            if out_tensor in self.tensor_map:
                self.tensor_map[out_tensor] |= dict(
                    final_shape=out_shape_nhwc,
                    final_layout='NHWC',
                )

            op_dict['create_vinfo'][new_bias_node.output[0]] = new_bias_vinfo

            # Add the new initializer and node, remove old ones
            op_dict['create_ini'][new_bias_ini.name] = new_bias_ini
            op_dict['remove_node'][node_name] = node_name
            op_dict['create_tag'][new_bias_node.name] = OpTag(bias_shape, 'ANY')
            logger.debug(f"{_logstr} << BiNode*({node_name}){out_shape_nhwc}")

            return op_dict
        else: #both inputs are activations
            in_node0_layout = self.tag_ops.get(in_node0.name, OpTag.NOT_FOUND()).tag
            in_node1_layout = self.tag_ops.get(in_node1.name, OpTag.NOT_FOUND()).tag
            in_node0_chains = str(self.nodes_to_subgraphs.get(in_node0.name, "NONE"))
            in_node1_chains = str(self.nodes_to_subgraphs.get(in_node1.name, "NONE"))
            logger.debug(f"{_logstr}\t {out_shape} <- BiNode({node_name})  <- Act(A){self.get_shape(in_node0.output[0])}  <- {in_node0.op_type}<{in_node0_layout}>({in_node0.name})")
            logger.debug(f"{_logstr}\t {out_shape} <- BiNode({node_name})  <- Act(B){self.get_shape(in_node1.output[0])}  <- {in_node1.op_type}<{in_node1_layout}>({in_node1.name})")

            if in_node0_layout == "NCHW" and in_node1_layout == "NCHW":
                if (in_node0_chains == "NONE") ^ (in_node1_chains == "NONE"):
                    if in_node0_chains == "NONE":
                        act = in_node1.name
                        bias = in_node0.name
                    else:
                        act = in_node0.name
                        bias = in_node1.name

                    #Insert transpose above Bias DQ
                    bias_node = md.nodes[bias]
                    act_node = md.nodes[act]
                    act_shape  = copy.deepcopy(self.tag_ops[act].shape)
                    bias_shape = copy.deepcopy(self.tag_ops[bias].shape)
                    bias_dtype = self.get_onnx_dtype(bias_node.output[0])
                    bias_input_dtype = self.get_onnx_dtype(bias_node.input[0])

                    bias_nhwc_shape, nhwc_perm = self.convert_shape_to_NHWC(bias_shape)

                    transpose_name = f"{bias_node.name}_transpose"
                    transpose_out = f"{transpose_name}_out"

                    new_bias_node = copy.deepcopy(bias_node)
                    new_bias_node.input[0] = transpose_out

                    op_dict['create_node'][transpose_name] = make_node(
                        'Transpose',
                        name=transpose_name,
                        inputs=[bias_node.input[0]],
                        outputs=[transpose_out],
                        perm=nhwc_perm,
                    )
                    op_dict['create_node'][new_bias_node.name] = new_bias_node
                    op_dict['create_vinfo'][new_bias_node.input[0]] = onnx.helper.make_tensor_value_info(
                        new_bias_node.input[0], bias_input_dtype, bias_nhwc_shape
                    )

                    op_dict['create_vinfo'][bias_node.output[0]] = onnx.helper.make_tensor_value_info(
                        bias_node.output[0], bias_dtype, bias_nhwc_shape
                    )
                    op_dict['remove_vinfo'][bias_node.output[0]] = bias_node.output[0]

                    op_dict['create_vinfo'][out_tensor] = onnx.helper.make_tensor_value_info(
                        out_tensor, out_dtype, out_shape_nhwc
                    )
                    op_dict['remove_vinfo'][out_tensor] = out_tensor
                    if out_tensor in self.tensor_map:
                        self.tensor_map[out_tensor] |= dict(
                            final_shape=out_shape_nhwc,
                            final_layout='NHWC',
                        )

                    op_dict['remove_node'][bias] = bias
                    op_dict['create_tag'][transpose_name] = OpTag(bias_nhwc_shape, 'NHWC')
                    return op_dict

            op_dict['remove_vinfo'][out_tensor] = out_tensor
            op_dict['create_vinfo'][out_tensor] = onnx.helper.make_tensor_value_info(
                out_tensor, out_dtype, out_shape_nhwc
            )
            logger.debug(f"{_logstr} << BiNode*({node_name}){out_shape_nhwc}")
            return op_dict

    def Concat_to_NHWC(self, md, node_name):
        """Plan the graph edits that move a Concat node to NHWC layout.

        The output value-info is re-declared with its NHWC shape, the node is
        replaced by a copy whose ``axis`` attribute is remapped through
        ``axis_value_to_NHWC``, and every input whose producer is tagged
        "NCHW" and has no entry in ``nodes_to_subgraphs`` is routed through a
        newly created Transpose node.

        Side effect: when the output tensor is tracked in ``self.tensor_map``
        its ``final_shape`` / ``final_layout`` entries are updated.

        Args:
            md: model dictionary exposing ``nodes`` and ``vinfo``.
            node_name: name of the Concat node in ``md.nodes``.

        Returns:
            dict with the sections 'create_node', 'create_vinfo', 'create_ini',
            'remove_node', 'remove_vinfo', 'remove_ini' and 'create_tag'.
        """
        plan = {
            section: {}
            for section in (
                'create_node', 'create_vinfo', 'create_ini',
                'remove_node', 'remove_vinfo', 'remove_ini', 'create_tag',
            )
        }

        concat = md.nodes[node_name]
        out_name = concat.output[0]
        out_vinfo = md.vinfo[out_name]
        rank = len(self.get_shape(out_vinfo))
        out_dtype = self.get_onnx_dtype(out_vinfo)

        # Re-declare the output tensor with its NHWC shape.
        nhwc_shape = self.get_NHWC_shape(out_vinfo)
        plan['remove_vinfo'][out_name] = out_name
        plan['create_vinfo'][out_name] = onnx.helper.make_tensor_value_info(
            out_name, out_dtype, nhwc_shape
        )
        if out_name in self.tensor_map:
            self.tensor_map[out_name] |= {
                'final_shape': nhwc_shape,
                'final_layout': 'NHWC',
            }

        # The node is swapped for a copy carrying the remapped axis.
        plan['remove_node'][node_name] = node_name
        converted = copy.deepcopy(concat)
        attr_values = get_attrs(converted)
        nhwc_axis = self.axis_value_to_NHWC(attr_values.get('axis', -1), rank)
        if 'axis' not in attr_values:
            converted.attribute.append(onnx.helper.make_attribute('axis', nhwc_axis))
        else:
            for attribute in converted.attribute:
                if attribute.name == 'axis':
                    attribute.i = nhwc_axis

        # Inputs are enumerated on the original node; rewiring happens on the copy.
        for position, tensor_name in enumerate(concat.input):
            producer = self.md.get_writer(tensor_name)
            producer_tag = self.tag_ops.get(producer.name, OpTag.NOT_FOUND()).tag
            producer_chains = str(self.nodes_to_subgraphs.get(producer.name, "NONE"))
            if producer_tag != "NCHW" or producer_chains != "NONE":
                continue

            # NCHW producer outside any subgraph: shield it with a Transpose.
            transpose_name = f"{producer.name}_transpose"
            transpose_out = f"{transpose_name}_out"
            input_shape = self.md.get_shape(tensor_name)
            input_dtype = self.md.get_onnx_dtype(tensor_name)
            input_nhwc_shape, input_perm = self.convert_shape_to_NHWC(input_shape)
            plan['create_node'][transpose_name] = make_node(
                'Transpose',
                name=transpose_name,
                inputs=[producer.output[0]],
                outputs=[transpose_out],
                perm=input_perm,
            )
            plan['create_vinfo'][transpose_out] = onnx.helper.make_tensor_value_info(
                transpose_out, input_dtype, input_nhwc_shape
            )
            converted.input[position] = transpose_out
            plan['create_tag'][transpose_name] = OpTag(input_nhwc_shape, 'NHWC')

        # Return value is discarded here, exactly as before.
        self.md.get_node_suppliers(concat)
        plan['create_node'][node_name] = converted

        return plan

    def Resize_to_NHWC(self, md, node_name):
        """Plan the graph edits that move a Resize node to NHWC layout.

        The output value-info is re-declared with its NHWC shape. Then either
        the ``sizes`` initializer (input 3) or the ``scales`` initializer
        (input 2) is re-ordered through ``get_NHWC_shape`` and stored under a
        node-specific name, and the node is replaced by a copy pointing at it.

        ``sizes`` is used only when ``scales`` is not an initializer. If
        neither ``scales`` nor a present ``sizes`` is an initializer, a debug
        message is logged and the plan built so far is returned.

        Side effect: when the output tensor is tracked in ``self.tensor_map``
        its ``final_shape`` / ``final_layout`` entries are updated.

        Args:
            md: model dictionary exposing ``nodes``, ``vinfo`` and ``ini``.
            node_name: name of the Resize node in ``md.nodes``.

        Returns:
            dict with the sections 'create_node', 'create_vinfo', 'create_ini',
            'remove_node', 'remove_vinfo', 'remove_ini' and 'create_tag'.
        """
        plan = {
            section: {}
            for section in (
                'create_node', 'create_vinfo', 'create_ini',
                'remove_node', 'remove_vinfo', 'remove_ini', 'create_tag',
            )
        }

        resize = md.nodes[node_name]
        out_name = resize.output[0]
        out_vinfo = md.vinfo[out_name]
        out_shape = self.get_shape(out_vinfo)
        out_dtype = self.get_onnx_dtype(out_vinfo)
        # The rank is queried here but not consumed by the rest of the handler.
        _out_rank = len(out_shape)

        # Re-declare the output tensor with its NHWC shape.
        nhwc_shape = self.get_NHWC_shape(out_vinfo)
        plan['remove_vinfo'][out_name] = out_name
        plan['create_vinfo'][out_name] = onnx.helper.make_tensor_value_info(
            out_name, out_dtype, nhwc_shape
        )
        if out_name in self.tensor_map:
            self.tensor_map[out_name] |= {
                'final_shape': nhwc_shape,
                'final_layout': 'NHWC',
            }

        # Resize.input = [0:data:act, 2:scales:ini, 3:sizes:ini]
        scales_name = resize.input[2]
        sizes_name = None if len(resize.input) <= 3 else resize.input[3]

        if scales_name not in md.ini and sizes_name:
            if sizes_name not in md.ini:
                logger.debug(
                    f"Unexpected call of 'Resize' NWHC-handler at Node = {resize.op_type}({node_name})"
                )
                return plan

            # 'sizes' drives the resize: store a re-ordered copy under a node-specific name.
            sizes_nhwc_name = node_name + '_sizes'
            if sizes_nhwc_name not in md.ini:
                sizes = numpy_helper.to_array(md.ini[sizes_name])
                sizes_reordered = self.get_NHWC_shape(list(sizes))
                plan['create_ini'][sizes_nhwc_name] = numpy_helper.from_array(
                    np.array(sizes_reordered, dtype=sizes.dtype), sizes_nhwc_name,
                )

            plan['remove_node'][node_name] = node_name
            converted = copy.deepcopy(resize)
            converted.input[3] = sizes_nhwc_name
            plan['create_node'][node_name] = converted
            return plan

        # FIXME: handler should be able to check if shape_nhwc is already created during handling of other node
        scales_nhwc_name = node_name + '_scales'
        if scales_nhwc_name not in md.ini:
            scales = numpy_helper.to_array(md.ini[scales_name])
            scales_reordered = self.get_NHWC_shape(list(scales))
            plan['create_ini'][scales_nhwc_name] = numpy_helper.from_array(
                np.array(scales_reordered, dtype=scales.dtype), scales_nhwc_name,
            )

        plan['remove_node'][node_name] = node_name
        converted = copy.deepcopy(resize)
        converted.input[2] = scales_nhwc_name
        plan['create_node'][node_name] = converted

        return plan

    def Slice_to_NHWC(self, md, node_name):
        """Plan the graph edits that move a Slice node to NHWC layout.

        The output value-info is re-declared with its NHWC shape and the
        ``axes`` initializer (input 3) is replaced by a copy whose entries are
        remapped through ``axis_value_to_NHWC`` using the output rank.

        Side effect: when the output tensor is tracked in ``self.tensor_map``
        its ``final_shape`` / ``final_layout`` entries are updated.

        Args:
            md: model dictionary exposing ``nodes``, ``vinfo`` and ``ini``.
            node_name: name of the Slice node in ``md.nodes``.

        Returns:
            dict with the sections 'create_node', 'create_vinfo', 'create_ini',
            'remove_node', 'remove_vinfo', 'remove_ini' and 'create_tag'.

        Raises:
            KeyError: if input 3 (``axes``) is not an initializer in ``md.ini``.
        """
        plan = {
            section: {}
            for section in (
                'create_node', 'create_vinfo', 'create_ini',
                'remove_node', 'remove_vinfo', 'remove_ini', 'create_tag',
            )
        }

        slice_node = md.nodes[node_name]
        out_name = slice_node.output[0]
        out_vinfo = md.vinfo[out_name]
        out_shape = self.get_shape(out_vinfo)
        nhwc_shape = self.get_NHWC_shape(out_vinfo)
        out_dtype = self.get_onnx_dtype(out_vinfo)
        rank = len(out_shape)

        # Re-declare the output tensor with its NHWC shape.
        plan['remove_vinfo'][out_name] = out_name
        plan['create_vinfo'][out_name] = make_tensor_value_info(
            out_name, out_dtype, nhwc_shape
        )
        if out_name in self.tensor_map:
            self.tensor_map[out_name] |= {
                'final_shape': nhwc_shape,
                'final_layout': 'NHWC',
            }

        # Slice.input == [0:data:act, 1:starts:ini, 2:ends:ini, 3:axes:ini, 4:steps:ini]
        axes_name = slice_node.input[3]
        if axes_name not in md.ini:
            raise KeyError(
                f"{slice_node.op_type}({node_name}): "
                f"input tensor 'axes'={axes_name}  is an unknown initializer"
            )

        # FIXME: handler should be able to check if axes_nhwc is already created during handling of other node
        axes_nhwc_name = f"{axes_name}_{rank}D_NHWC"
        if axes_nhwc_name not in md.ini:
            source_axes = numpy_helper.to_array(md.ini[axes_name])
            remapped = [self.axis_value_to_NHWC(axis, rank) for axis in source_axes]
            plan['create_ini'][axes_nhwc_name] = numpy_helper.from_array(
                np.array(remapped, dtype=source_axes.dtype), axes_nhwc_name,
            )

        # Swap the node for a copy that reads the remapped axes.
        plan['remove_node'][node_name] = node_name
        converted = copy.deepcopy(slice_node)
        converted.input[3] = axes_nhwc_name
        plan['create_node'][node_name] = converted
        return plan

    def Reduce_to_NHWC(self, md, node_name):
        """Plan the graph edits that move a Reduce-style node to NHWC layout.

        The output value-info is re-declared with its NHWC shape. The axes are
        remapped through ``axis_value_to_NHWC`` using the rank of input 0:
        in the ``axes`` attribute when the node has one, and in the ``axes``
        initializer when the node has exactly two inputs.

        Side effect: when the output tensor is tracked in ``self.tensor_map``
        its ``final_shape`` / ``final_layout`` entries are updated.

        Args:
            md: model dictionary exposing ``nodes``, ``vinfo`` and ``ini``.
            node_name: name of the node in ``md.nodes``.

        Returns:
            dict with the sections 'create_node', 'create_vinfo', 'create_ini',
            'remove_node', 'remove_vinfo', 'remove_ini' and 'create_tag'.

        Raises:
            KeyError: if the node has two inputs and input 1 is not an
                initializer in ``md.ini``.
        """
        plan = {
            section: {}
            for section in (
                'create_node', 'create_vinfo', 'create_ini',
                'remove_node', 'remove_vinfo', 'remove_ini', 'create_tag',
            )
        }
        reduce_node = md.nodes[node_name]
        out_name = reduce_node.output[0]
        out_vinfo = md.vinfo[out_name]
        out_dtype = self.get_onnx_dtype(out_vinfo)

        # Re-declare the output tensor with its NHWC shape.
        nhwc_shape = self.get_NHWC_shape(out_vinfo)
        plan['remove_vinfo'][out_name] = out_name
        plan['create_vinfo'][out_name] = onnx.helper.make_tensor_value_info(
            out_name, out_dtype, nhwc_shape
        )
        if out_name in self.tensor_map:
            self.tensor_map[out_name] |= {
                'final_shape': nhwc_shape,
                'final_layout': 'NHWC',
            }

        # Axes are interpreted against the rank of the data input.
        in_rank = len(self.get_shape(md.vinfo[reduce_node.input[0]]))
        plan['remove_node'][node_name] = node_name
        converted = copy.deepcopy(reduce_node)
        remapped_attr_axes = [
            self.axis_value_to_NHWC(axis, in_rank)
            for axis in get_attrs(converted).get('axes', [])
        ]
        for attribute in converted.attribute:
            if attribute.name != 'axes':
                continue
            del attribute.ints[:]
            attribute.ints.extend(remapped_attr_axes)

        if len(reduce_node.input) == 2:
            # Reduce.input == [0:data:act, 1:axes:ini]
            axes_name = reduce_node.input[1]
            if axes_name not in md.ini:
                raise KeyError(
                    f"{reduce_node.op_type}({node_name}): "
                    f"input tensor 'axes'={axes_name}  is an unknown initializer"
                )

            # FIXME: handler should be able to check if axes_nhwc is already created during handling of other node
            axes_nhwc_name = f"{axes_name}_{in_rank}D_NHWC"
            converted.input[1] = axes_nhwc_name
            if axes_nhwc_name not in md.ini:
                source_axes = numpy_helper.to_array(md.ini[axes_name])
                remapped = [self.axis_value_to_NHWC(axis, in_rank) for axis in source_axes]
                plan['create_ini'][axes_nhwc_name] = numpy_helper.from_array(
                    np.array(remapped, dtype=source_axes.dtype), axes_nhwc_name,
                )

        plan['create_node'][node_name] = converted

        return plan

    def axis_value_to_NHWC(self, axis, tensor_rank) -> int:
        """Translate one axis index through the rank's ``NCHW_PERMUTATIONS`` entry.

        The axis is first wrapped into ``[0, tensor_rank)`` (so negative axes
        are accepted) and then used as an index into
        ``NCHW_PERMUTATIONS[tensor_rank]``. A rank without an entry falls back
        to an empty list, so the lookup raises ``IndexError``.
        """
        wrapped_axis = (axis + tensor_rank) % tensor_rank
        return NCHW_PERMUTATIONS.get(tensor_rank, [])[wrapped_axis]

    def _axis_attribute_to_NHWC(self, node: NodeProto) -> int:
        """Return the node's ``axis`` attribute remapped via ``axis_value_to_NHWC``.

        The rank comes from the shape of the node's first activation input;
        a missing ``axis`` attribute is treated as -1.

        Raises:
            ValueError: if the node has no activation input.
        """
        first_act = self.md.get_node_activations(node, first=True)
        if not first_act:
            raise ValueError(f"axis_to_NHWC: Node({node.name}) doesn't have input")

        act_rank = len(self.get_shape(first_act))
        declared_axis = get_attrs(node).get('axis', -1)
        return self.axis_value_to_NHWC(declared_axis, act_rank)

    def _axes_attribute_to_NHWC(self, node: NodeProto) -> list[int]:
        """Return the node's ``axes`` attribute remapped via ``axis_value_to_NHWC``.

        The rank comes from the shape of the node's first activation input;
        a missing ``axes`` attribute yields an empty list.

        Raises:
            ValueError: if the node has no activation input.
        """
        first_act = self.md.get_node_activations(node, first=True)
        if not first_act:
            raise ValueError(f"axis_to_NHWC: Node({node.name}) doesn't have input")

        act_rank = len(self.get_shape(first_act))
        remapped = []
        for declared_axis in get_attrs(node).get('axes', []):
            remapped.append(self.axis_value_to_NHWC(declared_axis, act_rank))
        return remapped

    def _axes_initializer_to_NHWC(self, node: NodeProto, axes_input_idx=1) -> onnx.TensorProto:
        """Return an ``axes`` initializer remapped via ``axis_value_to_NHWC``.

        The remapped tensor is named ``"<axes>_<rank>D_NHWC"``, where the rank
        comes from the node's first activation input. If ``self.md.ini``
        already holds a (truthy) entry under that name it is returned as is;
        otherwise a new tensor is built and returned without being registered.

        Args:
            node: node whose ``axes`` input should be converted.
            axes_input_idx: position of the ``axes`` tensor in ``node.input``.

        Raises:
            ValueError: if the node has no activation input.
            KeyError: if the ``axes`` input is not an initializer.
        """
        first_act = self.md.get_node_activations(node, first=True)
        if not first_act:
            raise ValueError(f"axis_to_NHWC: Node({node.name}) doesn't have input")

        rank = len(self.get_shape(first_act))
        source_name = node.input[axes_input_idx]
        if source_name not in self.md.ini:
            raise KeyError(
                f"{node.op_type}({node.name}): "
                f"input tensor 'axes'={source_name}  is an unknown initializer"
            )

        nhwc_name = f"{source_name}_{rank}D_NHWC"
        existing = self.md.ini.get(nhwc_name)
        if existing:
            return existing

        source_axes = numpy_helper.to_array(self.md.ini[source_name])
        remapped = np.array(
            [self.axis_value_to_NHWC(axis, rank) for axis in source_axes],
            dtype=source_axes.dtype,
        )
        return numpy_helper.from_array(remapped, nhwc_name)

    def LayerNormalization_to_NHWC(self, md, node_name):
        """Convert a LayerNormalization node by delegating to ``common_axis_attr_to_NHWC``."""
        edit_plan = self.common_axis_attr_to_NHWC(md, node_name)
        return edit_plan

    def LpNormalization_to_NHWC(self, md, node_name):
        """Convert an LpNormalization node by delegating to ``common_axis_attr_to_NHWC``."""
        edit_plan = self.common_axis_attr_to_NHWC(md, node_name)
        return edit_plan

    def Softmax_to_NHWC(self, md, node_name):
        """Convert a Softmax node by delegating to ``common_axis_attr_to_NHWC``."""
        edit_plan = self.common_axis_attr_to_NHWC(md, node_name)
        return edit_plan

    def Split_to_NHWC(self, md, node_name):
        """Convert a Split node by delegating to ``common_axis_attr_to_NHWC``."""
        edit_plan = self.common_axis_attr_to_NHWC(md, node_name)
        return edit_plan

    def common_axis_attr_to_NHWC(self, md, node_name):
        """Plan the NHWC conversion of a node driven by a single ``axis`` attribute.

        Shared by the LayerNormalization, LpNormalization, Softmax and Split
        handlers. The first output's value-info is re-declared with its NHWC
        shape, and the node is replaced by a copy whose ``axis`` attribute is
        remapped through ``axis_value_to_NHWC`` using the output rank. When
        the attribute is absent, the operator's ONNX default is remapped and
        written explicitly onto the copy.

        Side effect: when the output tensor is tracked in ``self.tensor_map``
        its ``final_shape`` / ``final_layout`` entries are updated.

        Args:
            md: model dictionary exposing ``nodes`` and ``vinfo``.
            node_name: name of the node in ``md.nodes``.

        Returns:
            dict with the sections 'create_node', 'create_vinfo', 'create_ini',
            'remove_node', 'remove_vinfo', 'remove_ini' and 'create_tag'.
        """
        op_dict = {'create_node': {}, 'create_vinfo': {}, 'create_ini': {}, 'remove_node': {}, 'remove_vinfo': {}, 'remove_ini': {}, 'create_tag': {}}

        node = md.nodes[node_name]
        output0 = node.output[0]
        out_tensor = md.vinfo[output0]
        out_shape   = self.get_shape(out_tensor)
        out_dim = len(out_shape)
        out_dtype = self.get_onnx_dtype(out_tensor)

        # converting output
        # NOTE(review): only output[0] is re-declared here; Split has several
        # outputs -- confirm the remaining ones are converted elsewhere.
        out_shape_nhwc = self.get_NHWC_shape(out_tensor)

        op_dict['remove_vinfo'][output0] = output0
        op_dict['create_vinfo'][output0] = onnx.helper.make_tensor_value_info(
            output0, out_dtype, out_shape_nhwc
        )
        if output0 in self.tensor_map:
            self.tensor_map[output0] |= dict(
                final_shape=out_shape_nhwc,
                final_layout='NHWC',
            )

        # replace node with its converted version
        op_dict['remove_node'][node_name] = node_name
        new_node = copy.deepcopy(node)
        attrs = get_attrs(new_node)
        # ONNX default for a missing 'axis': 0 for Split, -1 for the other
        # operators routed here. Using -1 for Split would pin the wrong axis
        # onto the converted node.
        # NOTE(review): Softmax before opset 13 defaults to axis=1 -- the opset
        # is not visible here, so -1 is kept for it; confirm against the model.
        default_axis = 0 if node.op_type == 'Split' else -1
        axis = attrs.get('axis', default_axis)
        new_axis = self.axis_value_to_NHWC(axis, out_dim)
        if 'axis' in attrs:
            for attr in new_node.attribute:
                if attr.name == 'axis':
                    attr.i = new_axis
        else:
            # Make the (remapped) default explicit on the converted node.
            new_node.attribute.append(onnx.helper.make_attribute('axis', new_axis))

        op_dict['create_node'][node_name] = new_node

        return op_dict

    def Flatten_to_NHWC(self, md, node_name):
        """Plan the graph edits that move a Flatten node to NHWC layout.

        The output value-info is re-declared with its NHWC shape, and the node
        is replaced by a copy whose ``axis`` attribute is remapped through
        ``axis_value_to_NHWC`` using the rank of input 0. When the attribute
        is absent, the ONNX default for Flatten (``axis=1``) is remapped and
        written explicitly onto the copy.

        Side effect: when the output tensor is tracked in ``self.tensor_map``
        its ``final_shape`` / ``final_layout`` entries are updated.

        Args:
            md: model dictionary exposing ``nodes`` and ``vinfo``.
            node_name: name of the Flatten node in ``md.nodes``.

        Returns:
            dict with the sections 'create_node', 'create_vinfo', 'create_ini',
            'remove_node', 'remove_vinfo', 'remove_ini' and 'create_tag'.
        """
        op_dict = {'create_node': {}, 'create_vinfo': {}, 'create_ini': {}, 'remove_node': {}, 'remove_vinfo': {}, 'remove_ini': {}, 'create_tag': {}}

        node = md.nodes[node_name]
        output0 = node.output[0]
        out_tensor = md.vinfo[output0]
        in_tensor = md.vinfo[node.input[0]]
        in_shape = self.get_shape(in_tensor)
        in_dim = len(in_shape)
        out_dtype = self.get_onnx_dtype(out_tensor)

        # converting output
        out_shape_nhwc = self.get_NHWC_shape(out_tensor)

        op_dict['remove_vinfo'][output0] = output0
        op_dict['create_vinfo'][output0] = onnx.helper.make_tensor_value_info(
            output0, out_dtype, out_shape_nhwc
        )
        if output0 in self.tensor_map:
            self.tensor_map[output0] |= dict(
                final_shape=out_shape_nhwc,
                final_layout='NHWC',
            )

        # replace node with its converted version
        op_dict['remove_node'][node_name] = node_name
        new_node = copy.deepcopy(node)
        attrs = get_attrs(new_node)
        # ONNX Flatten defaults to axis=1 when the attribute is omitted, so an
        # attribute-less node must be converted exactly like an explicit axis=1.
        axis = attrs.get('axis', 1)
        new_axis = self.axis_value_to_NHWC(axis, in_dim)
        if 'axis' in attrs:
            for attr in new_node.attribute:
                if attr.name == 'axis':
                    attr.i = new_axis
        else:
            # Make the (remapped) default explicit on the converted node.
            new_node.attribute.append(onnx.helper.make_attribute('axis', new_axis))

        op_dict['create_node'][node_name] = new_node

        return op_dict

    def Pad_to_NHWC(self, md, node_name):
        """Plan the graph edits that move a Pad node to NHWC layout.

        The output value-info is re-declared with its NHWC shape. The ``pads``
        initializer (input 1) is split into its begin and end halves using the
        rank of input 0; each half is re-ordered through ``get_NHWC_shape``
        and the result is stored as a new initializer that the replacement
        node reads.

        Side effect: when the output tensor is tracked in ``self.tensor_map``
        its ``final_shape`` / ``final_layout`` entries are updated.

        Args:
            md: model dictionary exposing ``nodes``, ``ini``, ``get_tensor``
                and ``get_shape``.
            node_name: name of the Pad node in ``md.nodes``.

        Returns:
            dict with the sections 'create_node', 'create_vinfo', 'create_ini',
            'remove_node', 'remove_vinfo', 'remove_ini' and 'create_tag'.

        Raises:
            KeyError: if ``pads`` or a supplied ``axes`` input is not an
                initializer.
            ValueError: if an ``axes`` initializer is supplied (not handled).
        """
        plan = {
            section: {}
            for section in (
                'create_node', 'create_vinfo', 'create_ini',
                'remove_node', 'remove_vinfo', 'remove_ini', 'create_tag',
            )
        }

        pad_node = md.nodes[node_name]
        out_name = pad_node.output[0]
        data_tensor = md.get_tensor(pad_node.input[0])
        out_tensor = md.get_tensor(out_name)
        out_shape = self.get_shape(out_tensor)
        nhwc_shape = self.get_NHWC_shape(out_tensor)
        out_dtype = self.get_onnx_dtype(out_tensor)
        data_rank = len(md.get_shape(data_tensor))
        out_rank = len(out_shape)

        # Re-declare the output tensor with its NHWC shape.
        plan['remove_vinfo'][out_name] = out_name
        plan['create_vinfo'][out_name] = make_tensor_value_info(
            out_name, out_dtype, nhwc_shape
        )
        if out_name in self.tensor_map:
            self.tensor_map[out_name] |= {
                'final_shape': nhwc_shape,
                'final_layout': 'NHWC',
            }

        # Pad.input == [0:data:act, 1:pads:ini, 2:constant_value:ini, 3:axes:ini]
        pads_name = pad_node.input[1]
        axes_name = None if len(pad_node.input) <= 3 else pad_node.input[3]
        if pads_name not in md.ini:
            raise KeyError(
                f"Pad({node_name}): "
                f"input tensor 'pads'={pads_name}  is an unknown initializer"
            )
        if axes_name:
            if axes_name not in md.ini:
                raise KeyError(
                    f"Pad({node_name}): "
                    f"input tensor 'axes'={axes_name}  is an unknown initializer"
                )
            # A supplied 'axes' initializer is not supported by this handler.
            raise ValueError(
                f"Pad({node_name}): "
                f"input tensor 'axes'={axes_name}  -- UNHANDLED EXCEPTION --"
            )

        pads_nhwc_name = f"{pads_name}_{out_rank}D_NHWC"
        if pads_nhwc_name not in md.ini:
            pads = numpy_helper.to_array(md.ini[pads_name])
            # 'pads' holds all begin values followed by all end values.
            begin_half = list(pads[0: data_rank])
            end_half = list(pads[data_rank: 2 * data_rank])
            begin_nhwc = self.get_NHWC_shape(begin_half)
            end_nhwc = self.get_NHWC_shape(end_half)
            plan['create_ini'][pads_nhwc_name] = numpy_helper.from_array(
                np.array(begin_nhwc + end_nhwc, dtype=pads.dtype), pads_nhwc_name,
            )

        # Swap the node for a copy that reads the re-ordered pads.
        plan['remove_node'][node_name] = node_name
        converted = copy.deepcopy(pad_node)
        converted.input[1] = pads_nhwc_name
        plan['create_node'][node_name] = converted
        return plan


def save_nhwc_model(model, model_name, out_dir, save_data=False):
    """Sort the graph and write the model to ``<out_dir>/<model_name>_mod_nhwc.onnx``.

    ``save_data`` is forwarded unchanged to ``save_model``.
    """
    target_file = f"{model_name}_mod_nhwc.onnx"
    target_path = os.path.join(out_dir, target_file)
    graph_sort(model, 0)
    save_model(model, target_path, save_data)


def main(**args):
    """Load an ONNX model, convert it to NHWC layout and save the result.

    Keyword Args:
        model_path: path of the ONNX model to load (required).
        output_dir: directory the converted model is written to (required).
        model_name: base name for the output file. Optional; when missing or
            empty it is derived from the file name of ``model_path`` without
            its extension, so callers that only know the path still work.
        load_data: truthy to load and save external tensor data
            (default: off).
        debug: forwarded as ``dbg`` to the converter stages (default: False).

    The model is saved twice to the same path: once after layout tagging and
    shielding, and again after the NHWC conversion.
    """
    model_path = args['model_path']
    # The script's own CLI may not provide 'model_name'; fall back to the file stem.
    model_name = args.get('model_name') or os.path.splitext(os.path.basename(model_path))[0]
    out_dir = args['output_dir']
    # 'or 0' tolerates an explicit None as well as a missing key.
    load_data = int(args.get('load_data') or 0)
    dbg = args.get('debug', False)

    # load model
    model = onnx.load_model(model_path, load_external_data=load_data)

    print('Begin NHWC conversion')
    layout_converter = NHWCLayoutConverter(
        model, f"{model_name}_mod", out_dir,
        external_data=load_data,
        dbg=dbg,
    )

    # PRE-FUSION
    layout_converter.pre_fusion(dbg=dbg)

    # SHIELDING & LAYOUT TAGGING
    layout_converter.layout_tagging_and_shielding(dbg=dbg)
    save_nhwc_model(model, model_name, out_dir, save_data=load_data)

    # CONVERT MODEL TO NHWC LAYOUT
    layout_converter.convert_layout_to_nhwc(dbg=dbg)
    save_nhwc_model(model, model_name, out_dir, save_data=load_data)

if __name__ == "__main__":
    # Script entry point: parse the command line, optionally route logging to
    # a debug file, then run the conversion.
    parser = argparse.ArgumentParser()
    default_output_dir = os.path.dirname(__file__)

    parser.add_argument("-mp", "--model_path", help="path to onnx model and output destination.Required Field")
    parser.add_argument("-mn", "--model_name", help="base name for the output files (default: model file name without extension)", default=None)
    parser.add_argument("-dbg", "--debug", help="Dump dbg log to 'dbg_log.txt'", action="store_true", default=False)
    parser.add_argument('-output', '--output_dir', help="output directory", default=default_output_dir)
    parser.add_argument('-df','--debug_file_name', help="Debug log file name", default="dbg_log.txt")

    args = parser.parse_args()
    if not args.model_path:
        parser.error("Please pass path/to/onnx/model using -mp or --model_path flags.\npython3 parse_onnx_model.py --help\n\t\t\tfor further info.")

    # main() requires a model name; derive it from the model path when not given.
    if not args.model_name:
        args.model_name = os.path.splitext(os.path.basename(args.model_path))[0]

    if args.debug:
        #Run with Debug log enabled
        debug_file_path = os.path.join(args.output_dir, args.debug_file_name)
        # The log file is opened immediately by basicConfig, so its directory must exist.
        if args.output_dir:
            os.makedirs(args.output_dir, exist_ok=True)
        print(f"Saving debug log as : {debug_file_path}")

        logging.basicConfig(
            filename=debug_file_path,
            filemode='w',
            format='[%(asctime)s,%(msecs)03d] [%(levelname)s]: %(message)s',
            # hours:minutes:seconds (the fields were previously swapped)
            datefmt='%H:%M:%S',
            level=logging.DEBUG,
        )

    main(**vars(args))
