##
##  Copyright (C) 2023 – 2024 Advanced Micro Devices, Inc.
##
##  Licensed under the Apache License, Version 2.0 (the "License");
##  you may not use this file except in compliance with the License.
##  You may obtain a copy of the License at
##
##  http://www.apache.org/licenses/LICENSE-2.0
##
##  Unless required by applicable law or agreed to in writing, software
##  distributed under the License is distributed on an "AS IS" BASIS,
##  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
##  See the License for the specific language governing permissions and
##  limitations under the License.
##
import sys
import os
import onnx
import OGOAT.src.Ort.onnx_graph as ogm
import OGOAT.src.Ort.partitioner as partitioner
import argparse
import json
import re
import copy
from collections import defaultdict

# Op-type names carrying the "_runtime" suffix.
# NOTE(review): presumably these are ops resolved at runtime rather than
# compiled ahead of time -- confirm against the code that consumes this list.
RUNTIME_NODES = ["Concat_runtime", "Split_runtime", "Gather_runtime", "Slice_runtime"]
# NOTE(review): judging by the name, NPU op types that have no associated
# binary -- confirm against the code that consumes this list.
NPU_NODES_NO_BIN = ["Transpose"]


def replace_slashes(line):
    """Return *line* with every "unsafe" character replaced by an underscore.

    Despite the name, this does not only touch slashes: any character that
    is not a word character, whitespace, or a hyphen is substituted.
    """
    unsafe_chars = r"[^\w\s-]"
    return re.sub(unsafe_chars, "_", line)


def create_slice_groups(onnx_model):
    """Group the slice nodes of a model that share op type and first input.

    A node takes part when its ``op_type`` contains ``"Slice_qdq"``. Two such
    nodes land in the same group when they have the same ``op_type`` and the
    same first input (``node.input[0]``).

    Args:
        onnx_model: Model whose ``graph.node`` sequence is scanned.

    Returns:
        A tuple ``(slice_groups, slice_inputs)``:

        * ``slice_groups`` -- list of groups, each a list of node indices in
          ascending order; groups are ordered by their first member.
        * ``slice_inputs`` -- the distinct first-input names of all groups,
          in first-seen order (deterministic across runs).
    """
    # Insertion-ordered dict: one pass instead of rescanning every node for
    # each slice node, while keeping groups in first-occurrence order.
    groups_by_key = {}
    for idx, node in enumerate(onnx_model.graph.node):
        if "Slice_qdq" not in node.op_type:
            continue
        key = (node.op_type, node.input[0])
        groups_by_key.setdefault(key, []).append(idx)

    slice_groups = list(groups_by_key.values())
    # dict.fromkeys de-duplicates while preserving order, unlike set().
    slice_inputs = list(dict.fromkeys(inp for _, inp in groups_by_key))
    return slice_groups, slice_inputs


def create_idx_node_map(onnx_model):
    """Map each node's position in ``graph.node`` to that node's name."""
    node_names = (node.name for node in onnx_model.graph.node)
    return dict(enumerate(node_names))


def create_node_idx_map(onnx_model):
    """Map each node name to the node's position in ``graph.node``.

    If several nodes share a name, the index of the last one is kept.
    """
    name_to_idx = {}
    for position, node in enumerate(onnx_model.graph.node):
        name_to_idx[node.name] = position
    return name_to_idx


def get_initializers_list(onnx_model):
    """Return the names of all entries in ``graph.initializer``, in order."""
    names = []
    for initializer in onnx_model.graph.initializer:
        names.append(initializer.name)
    return names


def reverse_map(d):
    """Return a new dict with the keys and values of *d* swapped.

    When several keys share a value, the last such key (in iteration order)
    wins.
    """
    return dict(zip(d.values(), d.keys()))


def create_adj_list(onnx_model, node_idx_map):
    """Build a directed adjacency list by walking from the primary inputs.

    The walk starts at the ops returned by ``getSuccOp`` for every primary
    input tensor and follows successor ops depth-first. Only ops reached by
    this walk receive an entry in the result.

    Args:
        onnx_model: Model wrapped in an ``ogm.ONNXGraph`` for traversal.
        node_idx_map: Mapping from op name to integer node index.

    Returns:
        Dict mapping a node index to the list of its successor node indices.
    """
    graph = ogm.ONNXGraph(onnx_model)
    primary_inputs = graph.getPrimaryInputs()

    pending = []
    seen = set()
    successors_by_idx = {}

    # Seed the work stack with every consumer of a primary input.
    for tensor in primary_inputs:
        for consumer in graph.getSuccOp(tensor.name).keys():
            pending.append(consumer)
            seen.add(consumer)

    while pending:
        current = pending.pop()
        current_idx = node_idx_map[current]
        successors = graph.getSuccOp(current)
        successors_by_idx[current_idx] = [
            node_idx_map[name] for name in successors.keys()
        ]

        # Schedule successors that have not been queued before.
        for name in successors.keys():
            if name in seen:
                continue
            pending.append(name)
            seen.add(name)

    return successors_by_idx


def create_adj_list2(onnx_model, node_idx_map):
    """Build a directed adjacency list covering every node of the model.

    Unlike a traversal from the inputs, this visits each entry of
    ``graph.node`` directly and records the indices of its successor ops as
    reported by ``ogm.ONNXGraph.getSuccOp``.

    Args:
        onnx_model: Model whose nodes are enumerated.
        node_idx_map: Mapping from node name to integer node index.

    Returns:
        Dict mapping a node index to the list of its successor node indices.
    """
    graph = ogm.ONNXGraph(onnx_model)
    successors_by_idx = {}
    for node in onnx_model.graph.node:
        current_idx = node_idx_map[node.name]
        successors = graph.getSuccOp(node.name)
        successors_by_idx[current_idx] = [
            node_idx_map[succ_name] for succ_name in successors.keys()
        ]
    return successors_by_idx


def get_undirected_adj_list(adj_list, idx_node_map):
    """Turn a directed adjacency list into an undirected one.

    Args:
        adj_list: Dict mapping a node index to its successor indices.
        idx_node_map: Mapping whose keys are all node indices; every key gets
            an entry in the result, even when it has no edges.

    Returns:
        Dict mapping each node index to its neighbours; every directed edge
        ``a -> b`` contributes ``b`` to ``a``'s list and ``a`` to ``b``'s.
    """
    undirected = {idx: [] for idx in idx_node_map}
    for source, targets in adj_list.items():
        for target in targets:
            undirected[source].append(target)
            undirected[target].append(source)
    return undirected


def label_node_property(onnx_model, node_idx_map, supported_ops, excluded_op_names):
    """Label every node index as ``"CPU"`` or ``"AIE"``.

    A node is labelled ``"CPU"`` when its name appears in
    *excluded_op_names*; every other node is labelled ``"AIE"``.

    Args:
        onnx_model: Model whose ``graph.node`` entries are labelled.
        node_idx_map: Mapping from node name to integer node index.
        supported_ops: Accepted for interface compatibility; not consulted.
        excluded_op_names: Container of node names to keep off the AIE.

    Returns:
        Dict mapping node index to its label string.
    """
    return {
        node_idx_map[node.name]: "CPU" if node.name in excluded_op_names else "AIE"
        for node in onnx_model.graph.node
    }


def label_nodes_by_list(onnx_model, node_idx_map, skip_nodes):
    """Label every node index as ``"CPU"`` or ``"AIE"`` from a skip list.

    Args:
        onnx_model: Model whose ``graph.node`` entries are labelled.
        node_idx_map: Mapping from node name to integer node index.
        skip_nodes: Container of node names that must be labelled ``"CPU"``.

    Returns:
        Dict mapping node index to ``"CPU"`` for skipped nodes and ``"AIE"``
        for all others.
    """
    labels = {}
    for node in onnx_model.graph.node:
        node_idx = node_idx_map[node.name]
        skipped = node.name in skip_nodes
        labels[node_idx] = "CPU" if skipped else "AIE"
    return labels


def update_skip_nodes(onnx_model, skip_nodes, node_list):
    """Add every node not listed in a file to *skip_nodes*, in place.

    The file at *node_list* is read line by line; each line, stripped of
    surrounding whitespace, is treated as the name of a supported node. Any
    model node whose name is absent from that file is recorded in
    *skip_nodes* as ``name -> op_type``.

    Args:
        onnx_model: Model whose ``graph.node`` entries are checked.
        skip_nodes: Dict updated in place with the unsupported nodes.
        node_list: Path of the text file holding one node name per line.

    Returns:
        None. The result is delivered through *skip_nodes*.
    """
    assert node_list
    with open(node_list, "r") as f:
        # A set gives O(1) membership tests instead of scanning a list for
        # every node in the model.
        supported_ops = {line.strip() for line in f}

    for node in onnx_model.graph.node:
        if node.name not in supported_ops:
            skip_nodes[node.name] = node.op_type


def append_to_skip_nodes(onnx_model, skip_nodes, node_list):
    """Record every model node named in *node_list* in *skip_nodes*.

    Args:
        onnx_model: Model whose ``graph.node`` entries are checked.
        skip_nodes: Dict updated in place with ``name -> op_type`` entries.
        node_list: Container of node names to add.

    Returns:
        The same *skip_nodes* object that was passed in.
    """
    for candidate in onnx_model.graph.node:
        if candidate.name not in node_list:
            continue
        skip_nodes[candidate.name] = candidate.op_type
    return skip_nodes


def update_slice_labels(node_labels, slice_groups):
    """Force whole slice groups onto the CPU when any member is on the CPU.

    For each group, if at least one member index is labelled ``"CPU"`` in
    *node_labels*, every member of that group is relabelled ``"CPU"``.
    Groups without a CPU member are left untouched.

    Args:
        node_labels: Dict mapping node index to label; updated in place.
        slice_groups: Iterable of groups, each a list of node indices.

    Returns:
        The same *node_labels* object that was passed in.
    """
    for group in slice_groups:
        # No need to remember already-fixed groups: relabelling a group that
        # is entirely "CPU" is a no-op, so repeated groups are harmless.
        if any(node_labels[idx] == "CPU" for idx in group):
            for idx in group:
                node_labels[idx] = "CPU"
    return node_labels


def partition_onnx_model(
    onnx_model, idx_node_map, supported_ops, skip_nodes, node_list, slice_groups
):  # filename with nodenames to
    # be cut
    excluded_op_names = {
        "1024_DequantizeLinear": "Add",  # Final Add in mxgan
        "/Gather_output_0_DequantizeLinear": "MatMul",  # Final MatMul in mxpzi, shape 1x768
        "input_1_QuantizeLinear": "QuantOP",
        "input_2_QuantizeLinear": "QuantOP",
        "output_1_DequantizeLinear": "DeQuantOP",
        # "/up_blocks.0/Concat_1_Concat2_qdq_uint16": "Concat2_qdq_uint16", #PSD5
        # Mul_3 shape for PSG0 old and new models
        "/encoder/layers.0/deformable_layer/self_attn/Mul_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.1/deformable_layer/self_attn/Mul_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/Mul_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.3/deformable_layer/self_attn/Mul_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.4/deformable_layer/self_attn/Mul_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.5/deformable_layer/self_attn/Mul_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        # Mul_4
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_2_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_3_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_4_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_5_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_7_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_12_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_13_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_14_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_1_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_3_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_5_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_6_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_8_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_10_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_15_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_2_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_1_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_3_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_4_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_5_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_6_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_7_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_8_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_9_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_10_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_11_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_12_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_13_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_14_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_15_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Mul_16_Mul_qdq_BroadCast_uint16xuint16xuint16": "Mul_qdq_BroadCast_uint16xuint16xuint16",
        # Mul_7
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_1_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_6_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_8_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_9_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_10_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_11_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_15_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.0/deformable_layer/self_attn/attn/Mul_16_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_2_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_4_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_7_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_9_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_11_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_12_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_13_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_14_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Mul_16_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_2_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_1_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_3_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_4_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_5_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_6_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_7_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_8_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_9_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_10_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_11_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_12_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_13_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_14_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_15_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Mul_16_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_2_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_1_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_3_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_4_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_5_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_6_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_7_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_8_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_9_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_10_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_11_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_12_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_13_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_14_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_15_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Mul_16_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_2_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_1_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_3_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_4_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_5_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_6_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_7_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_8_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_9_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_10_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_11_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_12_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_13_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_14_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_15_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Mul_16_Mul_qdq_BroadCast_uint8xuint8xuint8": "Mul_qdq_BroadCast_uint8xuint8xuint8",
        # Div_0
        "/encoder/layers.0/Div_Div_qdq_BroadCast_uint16xuint16xuint16": "Div_qdq_BroadCast_uint16xuint16xuint16",
        # Div_1
        "/encoder/Div_2_Div_qdq_BroadCast_uint16xuint16xuint16": "Div_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/Div_3_Div_qdq_BroadCast_uint16xuint16xuint16": "Div_qdq_BroadCast_uint16xuint16xuint16",
        # Div_2
        "/encoder/Div_5_Div_qdq_BroadCast_uint16xuint16xuint16": "Div_qdq_BroadCast_uint16xuint16xuint16",
        # Div_3
        "/encoder/Div_Div_qdq_BroadCast_uint16xuint16xuint16": "Div_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/Div_1_Div_qdq_BroadCast_uint16xuint16xuint16": "Div_qdq_BroadCast_uint16xuint16xuint16",
        # Div_4
        "/encoder/Div_6_Div_qdq_BroadCast_uint8xuint8xuint8": "Div_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/Div_7_Div_qdq_BroadCast_uint8xuint8xuint8": "Div_qdq_BroadCast_uint8xuint8xuint8",
        # Div_5
        "/encoder/Div_4_Div_qdq_BroadCast_uint8xuint8xuint8": "Div_qdq_BroadCast_uint8xuint8xuint8",
        # Sub_0
        "/encoder/layers.0/fusion_layer/attn/Sub_Sub_qdq_BroadCast_uint16xuint16xuint16": "Sub_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.1/fusion_layer/attn/Sub_Sub_qdq_BroadCast_uint16xuint16xuint16": "Sub_qdq_BroadCast_uint16xuint16xuint16",
        # Sub_1
        "/encoder/layers.0/fusion_layer/attn/Sub_1_Sub_qdq_BroadCast_uint16xuint16xuint16": "Sub_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.1/fusion_layer/attn/Sub_1_Sub_qdq_BroadCast_uint16xuint16xuint16": "Sub_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.2/fusion_layer/attn/Sub_1_Sub_qdq_BroadCast_uint16xuint16xuint16": "Sub_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.3/fusion_layer/attn/Sub_1_Sub_qdq_BroadCast_uint16xuint16xuint16": "Sub_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.4/fusion_layer/attn/Sub_1_Sub_qdq_BroadCast_uint16xuint16xuint16": "Sub_qdq_BroadCast_uint16xuint16xuint16",
        "/encoder/layers.5/fusion_layer/attn/Sub_1_Sub_qdq_BroadCast_uint16xuint16xuint16": "Sub_qdq_BroadCast_uint16xuint16xuint16",
        # Sub_3
        "/encoder/layers.2/fusion_layer/attn/Sub_Sub_qdq_BroadCast_uint8xuint8xuint8": "Sub_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.3/fusion_layer/attn/Sub_Sub_qdq_BroadCast_uint8xuint8xuint8": "Sub_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.4/fusion_layer/attn/Sub_Sub_qdq_BroadCast_uint8xuint8xuint8": "Sub_qdq_BroadCast_uint8xuint8xuint8",
        "/encoder/layers.5/fusion_layer/attn/Sub_Sub_qdq_BroadCast_uint8xuint8xuint8": "Sub_qdq_BroadCast_uint8xuint8xuint8",
        # Transpose_0
        "/encoder/layers.1/deformable_layer/self_attn/attn/Transpose_3_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        # Transpose_1
        "/encoder/layers.1/deformable_layer/self_attn/attn/Transpose_2_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        # Transpose_2
        "/encoder/layers.1/fusion_layer/attn/Transpose_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        # Transpose_3
        "/encoder/layers.1/deformable_layer/self_attn/Transpose_1_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/Transpose_1_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        # Transpose_4
        "/encoder/layers.2/deformable_layer/self_attn/Transpose_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        "/encoder/layers.3/deformable_layer/self_attn/Transpose_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        "/encoder/layers.4/deformable_layer/self_attn/Transpose_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        "/encoder/layers.5/deformable_layer/self_attn/Transpose_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        # Transpose_5
        "/encoder/layers.0/deformable_layer/self_attn/attn/Transpose_4_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        "/encoder/layers.1/deformable_layer/self_attn/attn/Transpose_4_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Transpose_4_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Transpose_4_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        # Transpose_6
        "/encoder/layers.1/deformable_layer/self_attn/attn/Transpose_1_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        # Transpose_7
        "/encoder/layers.0/text_enhancer_layer/self_attn/Transpose_3_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        # Transpose_8
        "/encoder/layers.2/fusion_layer/attn/Transpose_3_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        # Transpose_9
        "/encoder/layers.1/deformable_layer/self_attn/attn/Transpose_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        # Transpose_10
        "/encoder/layers.1/fusion_layer/attn/Transpose_5_Transpose_qdq_uint16xuint16": "Transpose_qdq_uint16xuint16",
        # Transpose_11
        "/encoder/layers.0/deformable_layer/self_attn/attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_12
        "/encoder/layers.0/deformable_layer/self_attn/attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_13
        "/encoder/layers.0/fusion_layer/attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.0/fusion_layer/attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.1/fusion_layer/attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/fusion_layer/attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/fusion_layer/attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/fusion_layer/attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/fusion_layer/attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/fusion_layer/attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/fusion_layer/attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/fusion_layer/attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/fusion_layer/attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_14
        "/encoder/layers.0/deformable_layer/self_attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_15
        "/encoder/layers.0/deformable_layer/self_attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.1/deformable_layer/self_attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_16
        "/encoder/layers.3/deformable_layer/self_attn/attn/Transpose_4_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Transpose_4_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_17
        "/encoder/layers.0/deformable_layer/self_attn/attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_18
        "/encoder/layers.0/fusion_layer/attn/Transpose_6_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.1/fusion_layer/attn/Transpose_6_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/fusion_layer/attn/Transpose_6_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/fusion_layer/attn/Transpose_6_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/fusion_layer/attn/Transpose_6_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/fusion_layer/attn/Transpose_6_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_19
        "/encoder/layers.0/fusion_layer/attn/Transpose_7_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.1/fusion_layer/attn/Transpose_7_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/fusion_layer/attn/Transpose_7_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/fusion_layer/attn/Transpose_7_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/fusion_layer/attn/Transpose_7_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/fusion_layer/attn/Transpose_7_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_20
        "/encoder/layers.1/text_enhancer_layer/self_attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/text_enhancer_layer/self_attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/text_enhancer_layer/self_attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/text_enhancer_layer/self_attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/text_enhancer_layer/self_attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_21
        "/encoder/layers.0/fusion_layer/attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.0/fusion_layer/attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.1/fusion_layer/attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.1/fusion_layer/attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/fusion_layer/attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/fusion_layer/attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/fusion_layer/attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/fusion_layer/attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/fusion_layer/attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/fusion_layer/attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/fusion_layer/attn/Transpose_3_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_22
        "/encoder/layers.0/text_enhancer_layer/self_attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.0/text_enhancer_layer/self_attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.1/text_enhancer_layer/self_attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.1/text_enhancer_layer/self_attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/text_enhancer_layer/self_attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/text_enhancer_layer/self_attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/text_enhancer_layer/self_attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/text_enhancer_layer/self_attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/text_enhancer_layer/self_attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/text_enhancer_layer/self_attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/text_enhancer_layer/self_attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/text_enhancer_layer/self_attn/Transpose_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_23
        "/encoder/layers.0/text_enhancer_layer/self_attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.1/text_enhancer_layer/self_attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/text_enhancer_layer/self_attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/text_enhancer_layer/self_attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/text_enhancer_layer/self_attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/text_enhancer_layer/self_attn/Transpose_2_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_24
        "/encoder/layers.0/deformable_layer/self_attn/attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/deformable_layer/self_attn/attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/deformable_layer/self_attn/attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/deformable_layer/self_attn/attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/deformable_layer/self_attn/attn/Transpose_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_25
        "/encoder/layers.0/fusion_layer/attn/Transpose_5_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/fusion_layer/attn/Transpose_5_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/fusion_layer/attn/Transpose_5_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/fusion_layer/attn/Transpose_5_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/fusion_layer/attn/Transpose_5_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # Transpose_26
        "/encoder/layers.0/fusion_layer/attn/Transpose_4_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.1/fusion_layer/attn/Transpose_4_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.2/fusion_layer/attn/Transpose_4_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.3/fusion_layer/attn/Transpose_4_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.4/fusion_layer/attn/Transpose_4_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/encoder/layers.5/fusion_layer/attn/Transpose_4_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # swin2sr_a8w8
        "/conv_after_body/Conv_Conv_qdq_bias_int8xint8xint8": "Conv_qdq_bias_int8xint8xint8",
        "/Resize_Resize_qdq_int8xint8": "Resize_qdq_int8xint8",
        "/Resize_1_Resize_qdq_int8xint8": "Resize_qdq_int8xint8",
        # sam2 encoder
        "/image_encoder/trunk/patch_embed/proj/Conv_out_transpose_nhwc_0_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        "/image_encoder/trunk/patch_embed/proj/Conv_in_transpose_nchw_1_Transpose_qdq_uint8xuint8": "Transpose_qdq_uint8xuint8",
        # sam2 decoder
        "/Add_output_0_QuantizeLinear_Quant_float32xuint8": "Quant_float32xuint8",
        "/Add_output_0_DequantizeLinear_Dequant_uint8xfloat32": "Dequant_uint8xfloat32",
        "/Constant_4_output_0_DequantizeLinear_Dequant_int8xfloat32": "Dequant_uint8xfloat32",
        "/Concat": "concat_runtime",
        "/Concat_output_0_QuantizeLinear_Quant_float32xuint8": "Quant_float32xuint8",
        "/Concat_output_0_DequantizeLinear/duplicated_Dequant_uint8xfloat32": "Dequant_uint8xfloat32",
        "/Mul_6_output_0_QuantizeLinear_Quant_float32xuint8": "Quant_float32xuint8",
        "/Sub_output_0_/MatMul_smooth_mul_Mul_qdq_BroadCast_uint8xint8xuint8": "Mul_qdq_BroadCast_uint8xint8xuint8",
        "/MatMul_MatMul_qdq_uint8xint8xuint8": "MatMul_qdq_uint8xint8xuint8",
        "/Mul_7_output_0_DequantizeLinear_Dequant_uint8xfloat32": "Dequant_uint8xfloat32",
        "/Mul_7_output_0_DequantizeLinear/duplicated_Dequant_uint8xfloat32": "Dequant_uint8xfloat32",
        "low_res_masks_QuantizeLinear_Quant_float32xuint8": "Quant_float32xuint8",
        "/Reshape_5_output_0_DequantizeLinear_Dequant_uint8xfloat32": "Dequant_uint8xfloat32",
        "/Reshape_5_output_0_DequantizeLinear/duplicated_Dequant_uint8xfloat32": "Dequant_uint8xfloat32",
        "/Sin_output_0_QuantizeLinear_Quant_float32xuint8": "Quant_float32xuint8",
        "/Cos_output_0_QuantizeLinear_Quant_float32xuint8": "Quant_float32xuint8",
        "/Concat_Concat2_qdq_uint8": "Concat_qdq_uint8",  # newly created concat because of converting concat_runtime to concat_qdq_uint8
        "/Concat_1_Concat2_qdq_uint8": "Concat_qdq_uint8",
        "/Sub_1_Sub_qdq_BroadCast_float32xuint8xuint8": "Sub_qdq_BroadCast_bfloat16xuint8xuint8",
        # accuracy dropping to complete 0 without these.
        "/Reshape_5_output_0_DequantizeLinear/duplicated_Dequant_uint8xfloat32": "Dequant_uint8xbfloat16",
        "low_res_masks_QuantizeLinear_Quant_float32xuint8": "Quant_bfloat16xuint8",
        # PSO11
        "/Softmax_1_Softmax_qdq_uint16xuint16": "Softmax_qdq_uint16xuint16",
        "/sep_cls_head.2/MatMul/MatMulAddFusion_MatMul_qdq_bias_uint16xuint8xuint16": "MatMul_qdq_bias_uint16xuint8xuint16",
        "/Sub_Sub_qdq_BroadCast_uint16xuint16xuint16": "Sub_qdq_BroadCast_uint16xuint16xuint16",
        "/Div_1_Div_qdq_BroadCast_uint16xuint16xuint16": "Div_qdq_BroadCast_uint16xuint16xuint16",
    }
    node_idx_map = reverse_map(idx_node_map)
    if node_list is not None:
        update_skip_nodes(onnx_model, skip_nodes, node_list)
        node_labels = label_nodes_by_list(onnx_model, node_idx_map, skip_nodes)
    else:
        excluded_op_names.update(skip_nodes)
        node_labels = label_node_property(
            onnx_model, node_idx_map, supported_ops, excluded_op_names
        )
    node_labels = update_slice_labels(node_labels, slice_groups)
    adj_list = create_adj_list2(onnx_model, node_idx_map)
    subgraphs = partitioner.partition_graph(adj_list, node_labels)
    cluster = partitioner.subgraph_labels_to_clusters(subgraphs)
    return subgraphs, cluster, node_labels


def get_cpu_cluster_nodes(onnx_model, idx_node_map, cpu_clusters):
    """Map every CPU cluster label to the short names of its nodes.

    Args:
        onnx_model: Model whose ``graph.node`` entries expose ``name``.
        idx_node_map: Mapping from node index to node name.
        cpu_clusters: Mapping from cluster label to an iterable of node indices.

    Returns:
        dict: ``{label: [short_name, ...]}`` where ``short_name`` is the part of
        the node name after the last ``"/"``. Names are listed in the order of
        the cluster's node ids; ids whose name is not present in the graph
        contribute nothing.

    Raises:
        KeyError: If a node id of a cluster is missing from ``idx_node_map``.
    """
    # Count the graph nodes per name once, instead of rescanning the whole
    # graph for every node id of every cluster. A count (not a set) is kept so
    # that a name shared by several nodes is still reported once per node.
    name_counts = {}
    for node in onnx_model.graph.node:
        name_counts[node.name] = name_counts.get(node.name, 0) + 1

    io_dict = {}
    for label, node_ids in cpu_clusters.items():
        short_names = []
        for node_id in node_ids:
            nodename = idx_node_map[node_id]
            short_name = nodename.rsplit("/", 1)[-1]
            short_names.extend([short_name] * name_counts.get(nodename, 0))
        io_dict[label] = short_names

    return io_dict


def get_graph_channels(model):
    """Return the distinct non-initializer tensor names used by the graph.

    Every node input and output whose name is not reported by
    ``get_initializers_list(model)`` is collected.

    Args:
        model: Model whose ``graph.node`` entries expose ``input`` and ``output``.

    Returns:
        list: The unique tensor names. The list is built from a set, so its
        order is not meaningful.
    """
    # A set makes the initializer test O(1) per tensor instead of a scan of
    # the whole initializer list.
    init_names = set(get_initializers_list(model))
    channels = set()
    for node in model.graph.node:
        for inp in node.input:
            if inp not in init_names:
                channels.add(inp)
        for out in node.output:
            if out not in init_names:
                channels.add(out)
    return list(channels)


def get_graph_in_channels(model):
    """List the non-initializer tensors of the graph, keeping input multiplicity.

    The result first contains every non-initializer node input in graph order,
    *including repeats* (one entry per consuming occurrence). It is then
    extended with each non-initializer node output that has not been listed
    yet, i.e. outputs that no node consumes.

    Args:
        model: Model whose ``graph.node`` entries expose ``input`` and ``output``.

    Returns:
        list: Tensor names as described above.
    """
    init_names = set(get_initializers_list(model))
    channel_list = [
        inp
        for node in model.graph.node
        for inp in node.input
        if inp not in init_names
    ]
    # Track what is already listed in a set so the "not yet present" test does
    # not rescan the (potentially very long) result list for every output.
    seen = set(channel_list)
    for node in model.graph.node:
        for op in node.output:
            if op not in init_names and op not in seen:
                seen.add(op)
                channel_list.append(op)
    return channel_list


def get_input_idx(model, node_idx_map):
    """Map the first output tensor of each graph node to that node's index.

    Only ``output[0]`` of a node is registered. When several nodes share the
    same first output name, the node that appears last in the graph wins.
    """
    return {
        graph_node.output[0]: node_idx_map[graph_node.name]
        for graph_node in model.graph.node
    }


def remove_redundant_noop_nodes(
    onnx_model, filtered_inputs, filtered_outputs, orig_channels_list, init_list
):
    """Move subgraph boundaries across noop / runtime / no-binary nodes.

    A boundary tensor that is not in ``orig_channels_list`` is traced through
    nodes whose op type contains ``"noop"`` or is listed in
    ``NPU_NODES_NO_BIN`` / ``RUNTIME_NODES``. Outputs are traced upstream via
    the producing node, inputs downstream via the consuming nodes. When every
    tensor left at the end of the trace is in ``orig_channels_list``, the
    traced tensors are dropped from the boundary and the remaining ones take
    their place.

    NOTE: ``filtered_inputs`` and ``filtered_outputs`` are modified in place
    (``remove`` / ``append``), and entries added to ``cut_nodes`` or removed
    from those lists are not rolled back when a trace is later abandoned.

    Args:
        onnx_model: Model passed to ``get_node_from_out_channel`` /
            ``get_nodes_from_in_channel`` to look up producer/consumer nodes.
        filtered_inputs: List of subgraph input tensor names (mutated).
        filtered_outputs: List of subgraph output tensor names (mutated).
        orig_channels_list: Tensor names that are acceptable boundaries.
        init_list: Initializer names; these node inputs are never traced.

    Returns:
        tuple: ``(aftercut_inputs, aftercut_outputs, cut_nodes)`` where
        ``cut_nodes`` maps the name of every traversed node to its op type.
    """
    cut_nodes = {}
    all_ctr = []  # channels to remove from subgraph
    # ---- Pass 1: trace each non-original output upstream. ----
    for out in filtered_outputs:
        if out in orig_channels_list:
            continue
        # tmp_channels doubles as a worklist: it is appended to while the
        # loop below is iterating over it.
        tmp_channels = [out]
        ctr = []
        for tmp_channel in tmp_channels:
            if tmp_channel in filtered_inputs:
                # The trace reached a subgraph input: the tensor is consumed
                # from both boundary lists.
                filtered_inputs.remove(tmp_channel)
                ctr.append(tmp_channel)
                continue
            if tmp_channel in orig_channels_list:
                continue
            # NOTE(review): this map is rebuilt for every visited channel
            # although onnx_model is not modified here.
            tensor_map = get_node_from_out_channel(onnx_model)
            # NOTE(review): node is None when the channel has no entry in the
            # map; the op_type access below would then raise AttributeError.
            node = tensor_map.get(tmp_channel, None)
            if (
                not "noop" in node.op_type
                and node.op_type not in NPU_NODES_NO_BIN
                and node.op_type not in RUNTIME_NODES
            ):
                continue
            cut_nodes[node.name] = node.op_type
            for i in node.input:
                if not i in init_list:
                    tmp_channels.append(i)
            ctr.append(tmp_channel)
        tmp_channels = list(set(tmp_channels))
        ctr = list(set(ctr))
        # What is left are the tensors at which the trace stopped.
        tmp_channels = [i for i in tmp_channels if i not in ctr]
        all_in_graph = True
        for i in tmp_channels:
            if i not in orig_channels_list:
                all_in_graph = False
                break
        if not all_in_graph:
            continue
        all_ctr.extend(ctr)
        # The list being iterated by the outer loop grows here; the appended
        # names are all in orig_channels_list, so the outer loop skips them.
        for i in tmp_channels:
            if i not in filtered_outputs:
                filtered_outputs.append(i)
    aftercut_outputs = [i for i in filtered_outputs if i not in all_ctr]

    all_ctr = []  # channels to remove from subgraph
    # ---- Pass 2: trace each non-original input downstream. ----
    for inp in filtered_inputs:
        if inp in orig_channels_list:
            continue
        tmp_channels = [inp]
        ctr = []
        for tmp_channel in tmp_channels:
            if tmp_channel in filtered_outputs:
                # Note that this edits filtered_outputs after aftercut_outputs
                # has already been computed from it.
                filtered_outputs.remove(tmp_channel)
                ctr.append(tmp_channel)
                continue
            if tmp_channel in orig_channels_list:
                continue
            tensor_map = get_nodes_from_in_channel(onnx_model)
            nodes = tensor_map.get(tmp_channel, [])
            # rm stays True only if every consumer is a traversable node.
            rm = True
            for node in nodes:
                if (
                    not "noop" in node.op_type
                    and node.op_type not in NPU_NODES_NO_BIN
                    and node.op_type not in RUNTIME_NODES
                ):
                    rm = False
                    continue
                cut_nodes[node.name] = node.op_type
                tmp_channels.extend(node.output)
            if not rm:
                continue
            ctr.append(tmp_channel)
        tmp_channels = list(set(tmp_channels))
        ctr = list(set(ctr))
        tmp_channels = [i for i in tmp_channels if i not in ctr]
        all_in_graph = True
        for i in tmp_channels:
            if i not in orig_channels_list:
                all_in_graph = False
                break
        if not all_in_graph:
            continue
        all_ctr.extend(ctr)
        for i in tmp_channels:
            if i not in filtered_inputs:
                filtered_inputs.append(i)
    aftercut_inputs = [i for i in filtered_inputs if i not in all_ctr]

    return aftercut_inputs, aftercut_outputs, cut_nodes


def make_list_unique(lst):
    """Return a new list with duplicates dropped, keeping first occurrences.

    Relies on dictionaries preserving insertion order: the keys of
    ``dict.fromkeys(lst)`` are the distinct items in order of first appearance.
    """
    return list(dict.fromkeys(lst))


def check_bot(
    onnx_model,
    orig_channels_list,
    adj_list,
    idx_node_map,
    sub_subgraph_1,
    sub_subgraph_2,
    top_node_idx,
    bot_node_idx,
    channel_to_check,
    protected_channels,
):
    """Slide the cut between two sub-subgraphs downwards until it is valid.

    While the tensor on the cut (``channel_to_check``) is missing from
    ``orig_channels_list`` or is protected, the node directly below the cut is
    moved from the second subgraph into the first and the cut tensor becomes
    the first output of that moved node. The walk stops when the node below
    the cut does not have exactly one successor, is not in the second
    subgraph, or has more than one predecessor.

    Args:
        onnx_model: Model whose ``graph.node`` entries expose ``name``,
            ``input`` and ``output``.
        orig_channels_list: Tensor names that are acceptable cut points.
        adj_list: Mapping from node index to the list of successor indices.
        idx_node_map: Mapping from node index to node name.
        sub_subgraph_1: Node indices above the cut (not modified).
        sub_subgraph_2: Node indices below the cut (not modified).
        top_node_idx: Kept for interface symmetry with ``check_top``; the
            value passed in is overwritten before it is ever read.
        bot_node_idx: Index of the node directly below the cut.
        channel_to_check: Tensor name currently on the cut.
        protected_channels: Tensor names that must not be used as a cut.

    Returns:
        tuple: ``(channel_to_check, sg_1, sg_2)`` with the final cut tensor
        and updated copies of the two subgraphs.
    """
    sg_1 = copy.deepcopy(sub_subgraph_1)
    sg_2 = copy.deepcopy(sub_subgraph_2)
    # Index nodes by name once; on duplicate names the last node wins, which
    # matches a full scan that keeps overwriting its match.
    nodes_by_name = {node.name: node for node in onnx_model.graph.node}
    while (
        channel_to_check not in orig_channels_list
        or channel_to_check in protected_channels
    ):
        # going_down
        successors = adj_list.get(bot_node_idx, [])
        # A node without a successor (graph sink) cannot be walked past;
        # indexing successors[0] for it used to raise IndexError.
        if len(successors) != 1 or bot_node_idx not in sg_2:
            break
        up_count = sum(1 for targets in adj_list.values() if bot_node_idx in targets)
        if up_count > 1:
            break
        sg_2.remove(bot_node_idx)
        sg_1.append(bot_node_idx)
        top_node_idx = bot_node_idx
        bot_node_idx = successors[0]
        top_node = nodes_by_name.get(idx_node_map[top_node_idx])
        bot_node = nodes_by_name.get(idx_node_map[bot_node_idx])
        assert top_node.output[0] in bot_node.input
        channel_to_check = top_node.output[0]
    return channel_to_check, sg_1, sg_2


def check_top(
    onnx_model,
    orig_channels_list,
    adj_list,
    idx_node_map,
    sub_subgraph_1,
    sub_subgraph_2,
    top_node_idx,
    bot_node_idx,
    channel_to_check,
    protected_channels,
):
    """Slide the cut between two sub-subgraphs upwards until it is valid.

    While the tensor on the cut (``channel_to_check``) is missing from
    ``orig_channels_list`` or is protected, the node directly above the cut is
    moved from the first subgraph into the second and the cut tensor becomes
    the first output of that node's single predecessor. The walk stops when
    the node above the cut does not have exactly one predecessor or is not in
    the first subgraph.

    Args:
        onnx_model: Model whose ``graph.node`` entries expose ``name``,
            ``input`` and ``output``.
        orig_channels_list: Tensor names that are acceptable cut points.
        adj_list: Mapping from node index to the list of successor indices.
        idx_node_map: Mapping from node index to node name.
        sub_subgraph_1: Node indices above the cut (not modified).
        sub_subgraph_2: Node indices below the cut (not modified).
        top_node_idx: Index of the node directly above the cut.
        bot_node_idx: Kept for interface symmetry with ``check_bot``; the
            value passed in is overwritten before it is ever read.
        channel_to_check: Tensor name currently on the cut.
        protected_channels: Tensor names that must not be used as a cut.

    Returns:
        tuple: ``(channel_to_check, sg_1, sg_2)`` with the final cut tensor
        and updated copies of the two subgraphs.
    """
    sg_1 = copy.deepcopy(sub_subgraph_1)
    sg_2 = copy.deepcopy(sub_subgraph_2)
    # Index nodes by name once; on duplicate names the last node wins, which
    # matches a full scan that keeps overwriting its match.
    nodes_by_name = {node.name: node for node in onnx_model.graph.node}
    while (
        channel_to_check not in orig_channels_list
        or channel_to_check in protected_channels
    ):
        # going_up
        predecessors = [
            idx for idx, targets in adj_list.items() if top_node_idx in targets
        ]
        # A node without a predecessor (graph source) cannot be walked past;
        # previously the placeholder index -1 was used as if it were a node.
        if len(predecessors) != 1 or top_node_idx not in sg_1:
            break

        sg_1.remove(top_node_idx)
        sg_2.append(top_node_idx)
        bot_node_idx = top_node_idx
        top_node_idx = predecessors[0]
        top_node = nodes_by_name.get(idx_node_map[top_node_idx])
        bot_node = nodes_by_name.get(idx_node_map[bot_node_idx])
        assert top_node.output[0] in bot_node.input
        channel_to_check = top_node.output[0]
    return channel_to_check, sg_1, sg_2


def refine_subgraphs(
    onnx_model,
    orig_channels_list,
    adj_list,
    idx_node_map,
    sub_subgraph_1,
    sub_subgraph_2,
    protected_channels,
):
    """Adjust the cut between two sub-subgraphs so it lies on a usable tensor.

    The last edge found from ``sub_subgraph_1`` into ``sub_subgraph_2`` defines
    the cut, and the first output of its source node is the cut tensor. If that
    tensor is in ``orig_channels_list`` and not protected, the subgraphs are
    returned untouched. Otherwise ``check_bot`` and then ``check_top`` are
    tried; the first attempt that ends on a usable tensor with both sides
    non-empty is returned. If neither succeeds, or no connecting edge exists,
    the original subgraphs are returned.
    """

    def usable_cut(channel):
        # A cut tensor must exist in the original graph and not be protected.
        return channel in orig_channels_list and channel not in protected_channels

    # Locate the connecting edge; a later match overrides an earlier one.
    top_node_name, bot_node_name = "", ""
    top_node_ix, bot_node_ix = -1, -1
    for upper in sub_subgraph_1:
        for lower in sub_subgraph_2:
            if lower not in adj_list[upper]:
                continue
            top_node_ix, bot_node_ix = upper, lower
            top_node_name = idx_node_map[upper]
            bot_node_name = idx_node_map[lower]
    if top_node_name == "" and bot_node_name == "":
        return sub_subgraph_1, sub_subgraph_2

    # Resolve both end points of the edge to their graph nodes.
    by_name = {graph_node.name: graph_node for graph_node in onnx_model.graph.node}
    top_node = by_name.get(top_node_name)
    bot_node = by_name.get(bot_node_name)
    assert top_node.output[0] in bot_node.input
    channel_to_check = top_node.output[0]
    if usable_cut(channel_to_check):
        return sub_subgraph_1, sub_subgraph_2

    # Try moving the cut downwards first, then upwards. Each attempt starts
    # from the original subgraphs but receives the tensor the previous attempt
    # ended on.
    for walker in (check_bot, check_top):
        channel_to_check, sg_1, sg_2 = walker(
            onnx_model,
            orig_channels_list,
            adj_list,
            idx_node_map,
            sub_subgraph_1,
            sub_subgraph_2,
            top_node_ix,
            bot_node_ix,
            channel_to_check,
            protected_channels,
        )
        if usable_cut(channel_to_check) and len(sg_1) != 0 and len(sg_2) != 0:
            return sg_1, sg_2
    return sub_subgraph_1, sub_subgraph_2


def handle_quant_inputs(
    filtered_inputs,
    orig_channels_list,
    fused_tensor_to_innode_map,
    orig_tensor_to_outnodes_map,
):
    """Replace fused-graph quantized inputs by their original-graph tensors.

    Inputs already present in ``orig_channels_list`` are kept unchanged. Any
    other input is kept only if it is produced by a ``Quant_float32xuint16``
    node whose first input is an original tensor consumed by exactly one
    original node, and that node is a ``QuantizeLinear``; the input is then
    replaced by that node's first output. Every other input is dropped.

    Args:
        filtered_inputs: Iterable of subgraph input tensor names.
        orig_channels_list: Tensor names of the original model.
        fused_tensor_to_innode_map: Mapping from tensor name to the fused-graph
            node producing it.
        orig_tensor_to_outnodes_map: Mapping from tensor name to the list of
            original-graph nodes consuming it.

    Returns:
        list: The rewritten input names, in input order.
    """
    new_inputs = []
    for tensor in filtered_inputs:
        if tensor in orig_channels_list:
            new_inputs.append(tensor)
            continue
        quant_node = fused_tensor_to_innode_map.get(tensor)
        # A tensor without a producing node cannot be mapped back; drop it
        # like any other non-matching tensor instead of failing on None.
        if quant_node is None or quant_node.op_type != "Quant_float32xuint16":
            continue
        q_input = quant_node.input[0]
        if q_input not in orig_channels_list:
            continue
        orig_nodes = orig_tensor_to_outnodes_map.get(q_input, [])
        if len(orig_nodes) != 1 or orig_nodes[0].op_type != "QuantizeLinear":
            continue
        new_inputs.append(orig_nodes[0].output[0])
    return new_inputs


def handle_dequant_outputs(
    filtered_outputs,
    orig_channels_list,
    fused_tensor_to_outnodes_map,
    orig_tensor_to_innode_map,
):
    """Replace fused-graph quantized outputs by their original-graph tensors.

    Outputs already present in ``orig_channels_list`` are kept unchanged. Any
    other output is kept only if it is consumed by exactly one fused node, that
    node is a ``Dequant_uint16xfloat32`` whose first output is an original
    tensor, and the original node producing that tensor is a
    ``DequantizeLinear``; the output is then replaced by that node's first
    input. Every other output is dropped.

    Args:
        filtered_outputs: Iterable of subgraph output tensor names.
        orig_channels_list: Tensor names of the original model.
        fused_tensor_to_outnodes_map: Mapping from tensor name to the list of
            fused-graph nodes consuming it.
        orig_tensor_to_innode_map: Mapping from tensor name to the
            original-graph node producing it.

    Returns:
        list: The rewritten output names, in input order.
    """
    new_outputs = []
    for tensor in filtered_outputs:
        if tensor in orig_channels_list:
            new_outputs.append(tensor)
            continue
        consumers = fused_tensor_to_outnodes_map.get(tensor, [])
        if len(consumers) != 1:
            continue
        dequant_node = consumers[0]
        if dequant_node.op_type != "Dequant_uint16xfloat32":
            continue
        q_output = dequant_node.output[0]
        if q_output not in orig_channels_list:
            continue
        orig_node = orig_tensor_to_innode_map.get(q_output)
        # A tensor without a producing original node cannot be mapped back;
        # drop it like any other non-matching tensor instead of failing on None.
        if orig_node is None or orig_node.op_type != "DequantizeLinear":
            continue
        new_outputs.append(orig_node.input[0])
    return new_outputs


def get_cluster_inputs_outputs(
    orig_model,
    onnx_model,
    idx_node_map,
    aie_clusters,
    tensor_map_dict,
    protected_channels,
):
    """Compute boundary tensors for every cluster and split clusters with loops.

    For each cluster the boundary inputs/outputs are derived from the fused
    model, boundary nodes that must not sit on a cluster edge are stripped, and
    the cluster is checked for a path that leaves through one of its outputs
    and re-enters through one of its inputs. Such a cluster is split in two and
    both halves are processed recursively.

    NOTE: the node-id lists stored in ``aie_clusters`` are modified in place
    when boundary nodes are stripped.

    Args:
        orig_model: Original model; its non-initializer tensors define the
            acceptable cut points.
        onnx_model: Fused model the clusters refer to.
        idx_node_map: Mapping from node index to node name.
        aie_clusters: Mapping from cluster label to a list of node indices.
        tensor_map_dict: Container of tensor names; a boundary tensor missing
            from it causes its boundary node to be stripped.
        protected_channels: Tensor names that must not be used as a cut when a
            cluster is split.

    Returns:
        tuple: ``(final_io_dict, removed_nodes)``. ``final_io_dict`` maps a
        renumbered label ``"<prefix>_<n>"`` to a dict with keys ``"inputs"``,
        ``"outputs"`` and ``"ids"``; ``removed_nodes`` maps the name of every
        stripped node to its op type.
    """
    orig_channels_list = get_graph_channels(orig_model)
    # One entry per consuming occurrence, so it can be used to count consumers.
    fused_channels_list = get_graph_in_channels(onnx_model)
    node_idx_map = reverse_map(idx_node_map)
    # first output tensor of a node -> index of that node
    tensor_input_idx_map = get_input_idx(onnx_model, node_idx_map)
    adj_list = create_adj_list2(onnx_model, node_idx_map)
    io_dict = {}
    removed_nodes = {}
    graph_inputs = []
    graph_outputs = []
    fused_tensor_to_innode_map = get_node_from_out_channel(onnx_model)
    fused_tensor_to_outnodes_map = get_nodes_from_in_channel(onnx_model)
    for input in onnx_model.graph.input:
        graph_inputs.append(input.name)
    for output in onnx_model.graph.output:
        graph_outputs.append(output.name)
    for label, node_ids in aie_clusters.items():

        def detect_and_remove_loops(label, node_ids):
            """Record the cluster in io_dict, or split it if it contains a loop."""
            dct = {}
            init_list = get_initializers_list(onnx_model)
            cut_nodes = {}

            def get_filtered_inputs_outputs(node_ids, cut_nodes):
                """Return the cluster's boundary tensors, stripping edge nodes.

                Removes offending boundary nodes from ``node_ids`` (in place),
                records them in ``cut_nodes`` and recurses until a pass removes
                nothing.
                """
                inputs = []
                outputs = []
                filtered_inputs = []
                filtered_outputs = []
                for node_id in node_ids:
                    nodename = idx_node_map[node_id]
                    for node in onnx_model.graph.node:
                        if node.name == nodename:
                            inputs.extend(node.input)
                            outputs.extend(node.output)
                # Boundary inputs: consumed in the cluster but neither produced
                # by it nor an initializer.
                for i in inputs:
                    if i not in outputs and i not in init_list and i != "":
                        filtered_inputs.append(i)
                # Boundary outputs: model outputs, or tensors with more
                # consumers in the whole fused graph than inside the cluster.
                for o in outputs:
                    if o == "":
                        continue
                    elif o in graph_outputs:
                        filtered_outputs.append(o)
                        continue
                    subgraph_out_count = 0
                    for i in inputs:
                        if o == i:
                            subgraph_out_count = subgraph_out_count + 1
                    full_fused_graph_out_count = 0
                    for i in fused_channels_list:
                        if o == i:
                            full_fused_graph_out_count = full_fused_graph_out_count + 1
                    if subgraph_out_count < full_fused_graph_out_count:
                        filtered_outputs.append(o)

                filtered_inputs = make_list_unique(filtered_inputs)
                filtered_outputs = make_list_unique(filtered_outputs)
                # flag stays True only if no node had to be stripped this pass.
                flag = True
                # Strip the producer of a boundary output when it is a
                # noop/runtime node or the tensor is unknown to tensor_map_dict
                # (unless the tensor is a model output), or when it is a
                # "Dequant_" node.
                # NOTE(review): out_node is None if the tensor has no producer
                # in the map; the op_type access would then raise.
                for out_ in filtered_outputs:
                    out_node = fused_tensor_to_innode_map.get(
                        out_, None
                    )  # getting the output node associated with a output
                    if (("noop" in out_node.op_type
                        or "runtime" in out_node.op_type
                        or out_ not in tensor_map_dict
                         ) and (out_ not in graph_outputs)) or out_node.op_type.startswith("Dequant_"):
                        if node_idx_map[out_node.name] in node_ids:
                            flag = False
                            node_ids.remove(node_idx_map[out_node.name])
                            cut_nodes[out_node.name] = out_node.op_type
                # Same rule for the consumers of a boundary input, with
                # "Quant_" nodes and model inputs in the respective roles.
                for in_ in filtered_inputs:
                    in_nodes = fused_tensor_to_outnodes_map.get(
                        in_, []
                    )  # getting the input nodes associated with a input
                    for in_node in in_nodes:
                        if (("noop" in in_node.op_type
                            or "runtime" in in_node.op_type
                            or in_ not in tensor_map_dict
                             ) and (in_ not in graph_inputs)) or in_node.op_type.startswith("Quant_"):
                            if node_idx_map[in_node.name] in node_ids:
                                flag = False
                                node_ids.remove(node_idx_map[in_node.name])
                                cut_nodes[in_node.name] = in_node.op_type
                if flag:
                    return (filtered_inputs, filtered_outputs)
                else:
                    # The boundary changed; recompute it for the smaller cluster.
                    return get_filtered_inputs_outputs(node_ids, cut_nodes)

            aftercut_inputs, aftercut_outputs = get_filtered_inputs_outputs(
                node_ids, cut_nodes
            )
            if not node_ids:  # empty npu subgraph
                removed_nodes.update(cut_nodes)
                return
            else:
                removed_nodes.update(cut_nodes)
            aftercut_inputs = make_list_unique(aftercut_inputs)
            aftercut_outputs = make_list_unique(aftercut_outputs)

            # detect loop in subgraph
            non_model_inputs = [
                input for input in aftercut_inputs if input not in graph_inputs
            ]
            non_model_outputs = [
                output for output in aftercut_outputs if output not in graph_outputs
            ]
            # check if there exists a path from any of the non_model_outputs to any of non_model_inputs in onnx_model
            target_nodes = [
                tensor_input_idx_map[subgraph_input]
                for subgraph_input in non_model_inputs
                if subgraph_input in tensor_input_idx_map
            ]
            loop_present = False
            for subgraph_output in non_model_outputs:
                # NOTE(review): raises KeyError if the tensor is not the first
                # output of any node (the map only holds output[0]).
                start_node = tensor_input_idx_map[subgraph_output]
                reachable_target_nodes = []
                visited = set()
                queue = []

                def bfs(init_node):
                    """Collect the target nodes reachable from init_node."""
                    queue.append(init_node)
                    while len(queue) > 0:
                        node = queue.pop(0)
                        if node in visited:
                            continue
                        visited.add(node)
                        if node in target_nodes:
                            reachable_target_nodes.append(node)
                        for next_node in adj_list.get(node, []):
                            queue.append(next_node)

                bfs(start_node)
                if len(reachable_target_nodes) != 0:
                    loop_present = True
                    print(f"Subgraph_{label} has a loop")
                    # node_1 produces the leaving tensor, node_2 produces the
                    # tensor that re-enters the cluster.
                    node_1 = start_node
                    node_2 = reachable_target_nodes[0]
                    break
            if loop_present:
                sub_subgraph_1 = []
                sub_subgraph_2 = []
                # code to cut the subgraph
                # getting all the nodes dependent on node_2
                visited = set()
                queue = []

                def bfs_2(init_node):
                    """Collect cluster nodes reachable downstream of init_node."""
                    queue.append(init_node)
                    while len(queue) > 0:
                        node = queue.pop(0)
                        if node in visited:
                            continue
                        visited.add(node)
                        if node in node_ids:
                            sub_subgraph_2.append(node)
                        for next_node in adj_list.get(node, []):
                            queue.append(next_node)

                bfs_2(node_2)
                # build undirected version of the graph
                undirected_adj_list = get_undirected_adj_list(adj_list, idx_node_map)
                # remove nodes dependent on node_2 from the undirected graph
                for node in sub_subgraph_2:
                    undirected_adj_list.pop(node, None)
                for next_nodes in undirected_adj_list.values():
                    for node in sub_subgraph_2:
                        if node in next_nodes:
                            next_nodes.remove(node)
                # getting nodes connected to node_1 in the updated undirected graph
                visited = set()
                queue = []

                def bfs_1(init_node):
                    """Collect cluster nodes connected to init_node (undirected)."""
                    queue.append(init_node)
                    while len(queue) > 0:
                        node = queue.pop(0)
                        if node in visited:
                            continue
                        visited.add(node)
                        if node in node_ids:
                            sub_subgraph_1.append(node)
                        for next_node in undirected_adj_list.get(node, []):
                            queue.append(next_node)

                bfs_1(node_1)
                # Cluster nodes reached by neither search join the second half.
                for node in node_ids:
                    if node not in sub_subgraph_1 and node not in sub_subgraph_2:
                        sub_subgraph_2.append(node)
                sub_subgraph_1, sub_subgraph_2 = refine_subgraphs(
                    onnx_model,
                    orig_channels_list,
                    adj_list,
                    idx_node_map,
                    sub_subgraph_1,
                    sub_subgraph_2,
                    protected_channels,
                )

                # Process both halves recursively. The text before the first
                # "_" of a label stays the original cluster label, which the
                # renumbering at the end of this function relies on.
                if "_" in label:
                    detect_and_remove_loops(label + "1", sub_subgraph_1)
                    detect_and_remove_loops(label + "2", sub_subgraph_2)
                else:
                    detect_and_remove_loops(label + "_1", sub_subgraph_1)
                    detect_and_remove_loops(label + "_2", sub_subgraph_2)
                # if "loop" in label:
                #    parts = label.rsplit('_', 1)
                #    if parts[1].isdigit():
                #        last_number = int(parts[1])
                #        last_number += 1
                #    modified_label = f"{parts[0]}_{last_number}"
                #    detect_and_remove_loops(modified_label, sub_subgraph_1)
                #    last_number += 1
                #    modified_label = f"{parts[0]}_{last_number}"
                #    detect_and_remove_loops(modified_label, sub_subgraph_2)
                # else:
                #    detect_and_remove_loops(label+'_1_loop_0', sub_subgraph_1)
                #    detect_and_remove_loops(label+'_2_loop_0', sub_subgraph_2)
            else:
                # removed_nodes.update(cut_nodes)
                # Clusters lacking either inputs or outputs are not recorded.
                if aftercut_inputs and aftercut_outputs:
                    dct["inputs"] = aftercut_inputs
                    dct["outputs"] = aftercut_outputs
                    dct["ids"] = node_ids
                    io_dict[label] = dct

        detect_and_remove_loops(label, node_ids)

    # Renumber the recorded (sub-)clusters: group them by the label part before
    # the first "_" and name each "<prefix>_<position within its group>".
    cluster_map = defaultdict(list)
    final_io_dict = {}
    for label, dct in io_dict.items():
        cluster_map[label.split("_")[0]].append(label)
    for label, dct in io_dict.items():
        new_label = (
            label.split("_")[0]
            + "_"
            + str(cluster_map[label.split("_")[0]].index(label))
        )
        final_io_dict[new_label] = dct

    return final_io_dict, removed_nodes


def check_and_fix_onnx_graph_name(
    model_path: str, default_name: str = "fixed_onnx_graph"
):
    """
    Ensure the ONNX model stored at ``model_path`` has a non-empty graph name.

    When the graph name is empty, ``default_name`` is assigned, the model is
    saved back over ``model_path`` and the ONNX checker is run on the result.
    A model that already has a graph name is left untouched on disk.

    Problems are reported with ``print`` only; nothing is raised to the caller
    for a missing file, a load/save failure or a checker failure.

    Args:
        model_path (str): The file path to the ONNX model.
        default_name (str): The name to assign if the current name is missing or empty.

    Returns:
        None
    """
    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}")
        return

    try:
        model = onnx.load(model_path)

        # Nothing to repair: leave the file on disk exactly as it is.
        if model.graph.name:
            return

        model.graph.name = default_name
        # Serialize over the original file so later readers see the name.
        onnx.save(model, model_path)

        # The checker result is informational; a failure is only reported.
        try:
            onnx.checker.check_model(model)
        except Exception as checker_error:
            print(
                f"Warning: Model checker still failed after fixing the name. Error: {checker_error}"
            )

    except Exception as processing_error:
        print(f"An error occurred while processing the ONNX file: {processing_error}")


def cut_subgraph(onnx_model_path, io_dict, type_name, fld):
    """
    Extract one ONNX sub-model per cluster into ``<fld>/cut_graphs``.

    Args:
        onnx_model_path: Path of the model to cut. The last five characters are
            stripped to build the output file names, so a ``.onnx`` suffix is
            assumed.
        io_dict: Mapping ``label -> {"inputs": [...], "outputs": [...],
            "ids": [...]}`` describing every cluster to extract.
        type_name: Unused by this function; kept for interface compatibility.
        fld: Parent directory in which the ``cut_graphs`` folder is created.

    Returns:
        tuple: ``(io_node_dict, failed_nodes)`` where ``io_node_dict`` maps each
        sub-model file name to its ``inputs``/``outputs`` tensor names and
        ``failed_nodes`` collects the ``ids`` of every cluster whose extraction
        raised.
    """
    folder_name = os.path.join(fld, "cut_graphs")
    os.makedirs(folder_name, exist_ok=True)
    io_node_dict = {}
    failed_nodes = []

    # Repairing the graph name reloads (and possibly rewrites) the whole model,
    # and its outcome does not depend on the cluster, so do it once up front
    # instead of once per cluster. Skipped when there is nothing to extract so
    # that the source model is never touched needlessly.
    if io_dict:
        check_and_fix_onnx_graph_name(onnx_model_path)

    for label, io in io_dict.items():
        input_names = io["inputs"]
        output_names = io["outputs"]
        subgraph_path = onnx_model_path[:-5] + "_cluster_" + label + ".onnx"
        subgraph_path = os.path.basename(subgraph_path)
        full_subgraph_path = os.path.join(folder_name, subgraph_path)
        try:
            onnx.utils.extract_model(
                onnx_model_path,
                full_subgraph_path,
                input_names,
                output_names,
                infer_shapes=False,
            )
        except onnx.checker.ValidationError as e:
            print(
                f"Cut subgraph '{full_subgraph_path}' failed \n                  the model check"
            )
            print(f"Error: {e}")
            failed_nodes.extend(io["ids"])
        except Exception as e:
            print(
                f"Unexpected error occured in cutting of subgraph \n                  {full_subgraph_path}"
            )
            print(f"Error: {e}")
            failed_nodes.extend(io["ids"])
        # NOTE(review): the entry is recorded even when extraction failed, so
        # the caller must cope with a listed sub-model that is not on disk.
        io_node_dict[subgraph_path] = {}
        io_node_dict[subgraph_path]["inputs"] = io["inputs"]
        io_node_dict[subgraph_path]["outputs"] = io["outputs"]

    return io_node_dict, failed_nodes


def get_fuse_op_list(tilings_json, IR_json, use_inmem, bin_check=False):
    """
    Collect the op types to treat as supported and the nodes to skip.

    Args:
        tilings_json: Path to a JSON file mapping a key to an entry holding
            ``layer_info`` (with ``op_type`` and ``nodenames``). Falsy entries
            are ignored.
        IR_json: Path to a JSON file mapping a node key to an entry holding
            ``node_name`` and ``op_type``. Its directory is the root for all
            bin-file lookups.
        use_inmem: Flow selector compared against the string ``"0"``. With
            ``"0"`` bins are looked up under ``DataGen/Consts/<node folder>``;
            otherwise under ``<root>/<tilings key>``.
        bin_check: When truthy, nodes whose generated files are missing are
            reported and added to the skip dictionary.

    Returns:
        tuple: ``(op_list, skip_nodes)``; ``op_list`` is a list of op types
        (duplicates possible) and ``skip_nodes`` maps node key -> op type.
    """
    highlight_on = "\x1b[1;47;41m"
    highlight_off = "\x1b[0m"

    def _warn(*parts):
        # Highlighted console message, framed by the two escape sequences.
        print(highlight_on, *parts, highlight_off)

    def _has_core_bins(directory):
        # txn/param/ctrl bins are all required, plus one of the two metadata files.
        required = ("txn.bin", "param.bin", "ctrl.bin")
        if not all(os.path.isfile(os.path.join(directory, name)) for name in required):
            return False
        either = ("ctrl_meta.json", "patch.json")
        return any(os.path.isfile(os.path.join(directory, name)) for name in either)

    base_dir = os.path.dirname(IR_json)
    with open(tilings_json, "r") as tilings_file:
        tilings = json.load(tilings_file)

    op_list = []
    nodes_with_bins = []
    per_shape_lookup = use_inmem != "0"
    for shape_key, entry in tilings.items():
        if not entry:
            continue
        layer_info = entry["layer_info"]
        op_list.append(layer_info["op_type"])  # supported_ops are custom op types
        if not per_shape_lookup:
            continue
        shape_dir = os.path.join(base_dir, shape_key)
        # A missing directory or an incomplete set of bins means no generated bins.
        if os.path.isdir(shape_dir) and _has_core_bins(shape_dir):
            nodes_with_bins.extend(layer_info["nodenames"])

    with open(IR_json, "r") as ir_file:
        ir_nodes = json.load(ir_file)

    skip_nodes = {}
    for node, val in ir_nodes.items():
        consts_dir = os.path.join(
            base_dir, "DataGen", "Consts", replace_slashes(val["node_name"])
        )
        op_type = val["op_type"]

        # noop / runtime / bin-less ops are accepted without any file check.
        if (
            "noop" in op_type
            or op_type in RUNTIME_NODES
            or op_type in NPU_NODES_NO_BIN
        ):
            op_list.append(op_type)  # supported_ops are custom op types
            continue

        if not bin_check:
            continue

        treat_as_cpu = False
        if use_inmem == "0":
            if not os.path.isdir(consts_dir):
                # The fused pattern dir does not exist -> CPU op
                treat_as_cpu = True
                print("=== Node", node, "treated as CPU op since bins are missing")
            elif not _has_core_bins(consts_dir):
                _warn("Skipping node", node, ":", op_type)
                _warn("Misssing one of generated bin files")
                treat_as_cpu = True
            elif not os.path.isfile(os.path.join(consts_dir, "wgt.bin")):
                _warn("Skipping node", node, ":", op_type)
                _warn("Misssing formatted wgt.bin")
                treat_as_cpu = True
        elif val["node_name"] not in nodes_with_bins:
            treat_as_cpu = True
            print("=== Node", node, "treated as CPU op since bins are missing")

        if treat_as_cpu:
            # The op type recorded here is the one read from the IR json.
            skip_nodes[node] = op_type

    return op_list, skip_nodes


def get_node_from_out_channel(onnx_model):
    """
    Map every output tensor name to the node that lists it as an output.

    Names reported by ``get_initializers_list`` are left out. If the same name
    is listed by several nodes, the node appearing last in the graph wins.
    """
    initializer_names = get_initializers_list(onnx_model)
    return {
        tensor_name: producer
        for producer in onnx_model.graph.node
        for tensor_name in producer.output
        if tensor_name not in initializer_names
    }


def get_nodes_from_in_channel(onnx_model):
    """
    Map every input tensor name to the list of nodes that consume it.

    Names reported by ``get_initializers_list`` are left out. Consumers are
    listed in graph order, and a node appears once per occurrence of the
    tensor in its input list. The result is a ``defaultdict(list)``.
    """
    initializer_names = get_initializers_list(onnx_model)
    consumers_by_tensor = defaultdict(list)
    for consumer in onnx_model.graph.node:
        for tensor_name in consumer.input:
            if tensor_name in initializer_names:
                continue
            consumers_by_tensor[tensor_name].append(consumer)
    return consumers_by_tensor


def get_node_by_name(model: onnx.ModelProto, node_name: str):
    """Return the first node of ``model.graph`` named ``node_name``, or ``None`` if there is none."""
    matching_nodes = (
        candidate for candidate in model.graph.node if candidate.name == node_name
    )
    return next(matching_nodes, None)


def filter_input_nodes(orig_model, in_nodes, ex_outs, out_list):
    """
    Keep the candidate input nodes whose downstream path reaches ``out_list``.

    Starting from each named node, the graph is walked forward level by level,
    following only the first output of every node. A candidate is kept when a
    node producing a tensor in ``out_list`` is met before a node producing a
    tensor in ``ex_outs``. It is dropped when an ``ex_outs`` tensor is met
    first, or when the walk runs out of nodes without meeting either.

    Args:
        orig_model: Model whose ``graph.node`` entries expose ``name``,
            ``input`` and ``output``.
        in_nodes: Names of the candidate nodes, in the order to report them.
        ex_outs: Tensor names that end the walk without keeping the candidate.
        out_list: Tensor names that end the walk and keep the candidate.

    Returns:
        list: The names from ``in_nodes`` that were kept, in their original order.
    """
    keep_tensors = set(out_list)
    stop_tensors = set(ex_outs)

    # Index the graph once instead of rescanning every node per visited node.
    node_by_name = {}
    consumers = defaultdict(list)
    for graph_node in orig_model.graph.node:
        # First match wins, mirroring a linear search by name.
        node_by_name.setdefault(graph_node.name, graph_node)
        # dict.fromkeys de-duplicates while keeping order, so a node that reads
        # the same tensor twice is still listed once as its consumer.
        for tensor in dict.fromkeys(graph_node.input):
            consumers[tensor].append(graph_node)

    filtered_in_nodes = []
    for in_node in in_nodes:
        start = node_by_name.get(in_node)
        if start is None:
            # Unknown name: it cannot reach any output, so it is not kept.
            continue

        frontier = [start]
        expanded = set()
        keep = False
        finished = False
        # Stopping on an empty frontier is what guarantees termination when the
        # walk hits a dead end; the previous loop spun forever in that case.
        while frontier and not finished:
            next_frontier = []
            for node in frontier:
                if not node.output:
                    continue
                out_tensor = node.output[0]
                if out_tensor in keep_tensors:
                    keep = True
                    finished = True
                    break
                if out_tensor in stop_tensors:
                    finished = True
                    break
                if out_tensor in expanded:
                    # Already expanded via another path; re-expanding only
                    # duplicates work and cannot change the outcome.
                    continue
                expanded.add(out_tensor)
                next_frontier.extend(consumers.get(out_tensor, ()))
            frontier = next_frontier

        if keep:
            filtered_in_nodes.append(in_node)

    return filtered_in_nodes


def validate_subgraph_io_counts(subgraph_path, expected_io, folder_name):
    """
    Check that an extracted subgraph exposes the expected number of inputs
    and outputs.

    Args:
        subgraph_path (str): File name of the subgraph, relative to ``folder_name``
        expected_io (dict): Dictionary containing 'inputs' and 'outputs' lists
        folder_name (str): Directory containing the cut graphs

    Returns:
        bool: True when both counts match. False when the file is missing, a
        count differs (the mismatch is printed) or any exception is raised
        while loading or comparing (the error is printed).
    """
    try:
        candidate_file = os.path.join(folder_name, subgraph_path)
        if not os.path.exists(candidate_file):
            return False

        graph = onnx.load(candidate_file).graph

        # (inputs, outputs) as found in the file versus as requested.
        found_inputs, found_outputs = len(graph.input), len(graph.output)
        wanted_inputs = len(expected_io["inputs"])
        wanted_outputs = len(expected_io["outputs"])

        if found_inputs == wanted_inputs and found_outputs == wanted_outputs:
            return True

        print(f"Subgraph {subgraph_path} I/O count mismatch:")
        print(f"  IO inputs: {wanted_inputs}, Subgraph: {found_inputs}")
        print(f"  IO outputs: {wanted_outputs}, Subgraph: {found_outputs}")
        return False

    except Exception as e:
        print(f"Error validating subgraph {subgraph_path}: {e}")
        return False


def get_io_subgraph_nodes(orig_model, aie_io_dict):
    """
    Annotate every cluster with the names of its boundary nodes.

    For each entry of ``aie_io_dict`` two keys are added in place:

    * ``"output_nodes"``: names of the nodes producing the cluster's
      ``"outputs"`` tensors (tensors without a known producer are skipped);
    * ``"input_nodes"``: names of the nodes consuming the cluster's
      ``"inputs"`` tensors. When a tensor has exactly one consumer it is taken
      as is; otherwise the consumers are narrowed with ``filter_input_nodes``.

    Args:
        orig_model: Model used to resolve producers and consumers.
        aie_io_dict: Mapping ``label -> {"inputs": [...], "outputs": [...]}``;
            modified in place.

    Returns:
        None
    """
    if not aie_io_dict:
        return

    all_outputs = []
    for val in aie_io_dict.values():
        all_outputs.extend(val["outputs"])

    # Both lookup tables depend only on the model, which is not modified here,
    # so build them once rather than once per tensor of every cluster.
    producer_by_tensor = get_node_from_out_channel(orig_model)
    consumers_by_tensor = get_nodes_from_in_channel(orig_model)

    for val in aie_io_dict.values():
        input_nodes = []
        output_nodes = []
        out_list = val["outputs"]
        in_list = val["inputs"]
        own_outputs = set(out_list)
        # Outputs that belong to the other clusters.
        ex_outs = [out for out in all_outputs if out not in own_outputs]

        for out_ch in out_list:
            out_node = producer_by_tensor.get(out_ch, None)
            if out_node:
                output_nodes.append(out_node.name)

        for in_ch in in_list:
            in_nodes = consumers_by_tensor.get(in_ch, [])
            in_nodes = [i for i in in_nodes if i is not None]
            in_node_names = [i.name for i in in_nodes]
            if len(in_nodes) == 1:
                input_nodes.extend(in_node_names)
            else:
                # Zero or several consumers: keep those accepted by
                # filter_input_nodes for this cluster's outputs.
                i_nodes = filter_input_nodes(
                    orig_model, in_node_names, ex_outs, out_list
                )
                input_nodes.extend(i_nodes)

        val["input_nodes"] = input_nodes
        val["output_nodes"] = output_nodes


def merge_slice_clusters(subgraph_node_cluster, slice_groups):
    """
    Merge, in place, clusters that hold members of the same slice group.

    For every node index of a cluster, the last group in ``slice_groups`` that
    contains it is looked up. Any other cluster holding a different member of
    that group is appended to the current cluster and then removed from
    ``subgraph_node_cluster``. Members appended during a merge are examined as
    well, so merges chain through further groups.

    Args:
        subgraph_node_cluster: Mapping ``label -> list of node indices``.
        slice_groups: List of lists of node indices.

    Returns:
        None
    """
    absorbed_labels = []
    for label, members in subgraph_node_cluster.items():
        if label in absorbed_labels:
            continue
        # `members` may grow while it is being iterated; the appended node
        # indices are visited too, which is what makes merging transitive.
        for node_idx in members:
            siblings = []
            for candidate_group in slice_groups:
                if node_idx in candidate_group:
                    siblings = candidate_group  # no break: the last match wins
            for sibling_idx in siblings:
                if sibling_idx == node_idx:
                    continue
                for other_label, other_members in subgraph_node_cluster.items():
                    if other_label == label or other_label in absorbed_labels:
                        continue
                    if sibling_idx in other_members:
                        members.extend(other_members)
                        absorbed_labels.append(other_label)
    # Keys are dropped only after iteration so the dict size never changes
    # while it is being walked.
    for label in absorbed_labels:
        del subgraph_node_cluster[label]


def _find_companion_file(explicit_path, search_dir, marker, missing_message):
    """Return ``explicit_path``, or the first entry of ``search_dir`` whose name contains ``marker``.

    Raises:
        FileNotFoundError: with ``missing_message`` when no path was given and
            none was discovered, or when the path is the literal ``"None"``.
    """
    if explicit_path is not None:
        resolved = explicit_path
    else:
        resolved = ""
        for file_name in os.listdir(search_dir):
            if marker in file_name:
                resolved = os.path.join(search_dir, file_name)
                break
    # "" means nothing was discovered. The literal "None" is also rejected
    # because that is the value the earlier check compared against.
    if resolved in ("", "None"):
        raise FileNotFoundError(missing_message)
    return resolved


def main(args):
    """
    Cut the fused ONNX model into NPU sub-models and write the bookkeeping files.

    Steps, in order:

    1. Resolve the tilings, IR and tensor-map JSON files, either from ``args``
       or by scanning the directory of ``args.model_path``.
    2. Build the fused-tensor -> original-tensor map and the supported-op /
       skipped-node lists.
    3. Partition the fused model into clusters, merge slice clusters and split
       them into CPU and NPU ("aie") clusters.
    4. Extract one sub-model per NPU cluster, then delete the sub-models whose
       mapped inputs/outputs are empty, overlap, or do not match the counts
       found in the extracted file.
    5. Write ``cut_graphs/context.json`` and ``skipped_nodes.json``.

    Args:
        args: Namespace providing ``model_path``, ``orig_model_path``,
            ``bin_check``, ``out_dir``, ``tiling_json``, ``IR_json_file``,
            ``tensor_map_json_file``, ``node_list``, ``use_inmem`` and
            ``exclude_nodes``.

    Raises:
        FileNotFoundError: when one of the three JSON companions cannot be found.
    """
    onnx_model_path = args.model_path
    orig_model_path = args.orig_model_path
    bin_check = args.bin_check
    model_dir = os.path.dirname(onnx_model_path)
    if args.out_dir is not None:
        out_fld = args.out_dir
    else:
        out_fld = model_dir
    # os.listdir("") fails, so scan the current directory when the model path
    # has no directory component.
    search_dir = model_dir or "."

    tilings_json = _find_companion_file(
        args.tiling_json, search_dir, "tilings", "Did not find the tilings json"
    )
    IR_json = _find_companion_file(
        args.IR_json_file, search_dir, "IR.json", "Did not find the fused IR json"
    )
    tensor_map_json = _find_companion_file(
        args.tensor_map_json_file,
        search_dir,
        "tensor_map.json",
        "Did not find the tensor map json file",
    )

    with open(tensor_map_json, "r") as f:
        full_tensor_map = json.load(f)

    tensor_map_dict = {}
    for tensor in full_tensor_map:
        tensor_map_dict[tensor] = full_tensor_map[tensor]["orig_tensor"]

    node_list = args.node_list

    op_list, skip_nodes = get_fuse_op_list(
        tilings_json, IR_json, args.use_inmem, bin_check
    )

    if args.exclude_nodes is not None:
        with open(args.exclude_nodes, "r") as f:
            exclude_nodes = json.load(f)
        skip_nodes.update(exclude_nodes)

    onnx_model = onnx.load(onnx_model_path)
    orig_model = onnx.load(orig_model_path)

    # Tensors of the original model that have no mapping map to themselves.
    orig_channels_list = get_graph_channels(orig_model)
    for orig_tensor in orig_channels_list:
        if orig_tensor not in tensor_map_dict:
            tensor_map_dict[orig_tensor] = orig_tensor
    # NOTE(review): hard-coded, model-specific mapping applied to every model;
    # presumably a workaround for one network. Confirm it is still required.
    tensor_map_dict["/image_encoder/trunk/Add_1_QuantizeLinear_Output"] = (
        "/image_encoder/trunk/Add_1_output_0_QuantizeLinear"
    )

    idx_node_map = create_idx_node_map(onnx_model)
    slice_groups, slice_inputs = create_slice_groups(onnx_model)
    node_subgraphs_label, subgraph_node_cluster, target_label = partition_onnx_model(
        onnx_model, idx_node_map, op_list, skip_nodes, node_list, slice_groups
    )
    print("cluster:", subgraph_node_cluster)
    merge_slice_clusters(subgraph_node_cluster, slice_groups)

    # A cluster is classified by the target of its first node.
    aie_clusters = {}
    cpu_clusters = {}
    for subgraph_label, subgraph_nodes in subgraph_node_cluster.items():
        if subgraph_nodes and target_label[subgraph_nodes[0]] == "CPU":
            cpu_clusters[subgraph_label] = subgraph_nodes
            continue
        aie_clusters[str(subgraph_label)] = subgraph_nodes

    aie_io_dict, removed_nodes = get_cluster_inputs_outputs(
        orig_model,
        onnx_model,
        idx_node_map,
        aie_clusters,
        tensor_map_dict,
        slice_inputs,
    )
    print(removed_nodes)
    # NOTE(review): the result is not used below; the call is kept in case it
    # has side effects.
    cpu_node_dict = get_cpu_cluster_nodes(onnx_model, idx_node_map, cpu_clusters)
    skip_nodes.update(removed_nodes)
    io_node_dict, failed_ids = cut_subgraph(
        onnx_model_path, aie_io_dict, "aie", out_fld
    )

    folder_name = os.path.join(out_fld, "cut_graphs")
    remove_subgraphs = []
    for subgraph_path, io in io_node_dict.items():
        # Translate fused tensor names to original names; unmapped ones are dropped.
        io["inputs"] = [
            tensor_map_dict[input_]
            for input_ in io["inputs"]
            if input_ in tensor_map_dict
        ]
        io["outputs"] = [
            tensor_map_dict[output_]
            for output_ in io["outputs"]
            if output_ in tensor_map_dict
        ]
        degenerate_io = (
            (not io["inputs"])
            or (not io["outputs"])
            or (not set(io["inputs"]).isdisjoint(set(io["outputs"])))
        )
        # The I/O counts are validated only when the boundary is not already
        # degenerate (short-circuit), matching the former if/elif pair.
        if degenerate_io or not validate_subgraph_io_counts(
            subgraph_path, io, folder_name
        ):
            full_subgraph_path = os.path.join(folder_name, subgraph_path)
            # NOTE(review): when the file is absent (e.g. extraction failed)
            # the entry stays in io_node_dict and is written to context.json.
            if os.path.exists(full_subgraph_path):
                os.remove(full_subgraph_path)
                remove_subgraphs.append(subgraph_path)
                # File names end in "_cluster_<a>_<b>.onnx"; rebuild "<a>_<b>".
                name_parts = subgraph_path[:-5].split("_")
                label = name_parts[-2] + "_" + name_parts[-1]
                remove_nodes = [idx_node_map[idx] for idx in aie_io_dict[label]["ids"]]
                skip_nodes = append_to_skip_nodes(onnx_model, skip_nodes, remove_nodes)
    for subgraph in remove_subgraphs:
        io_node_dict.pop(subgraph)
    with open(os.path.join(folder_name, "context.json"), "w") as f:
        json.dump(io_node_dict, f, indent=2)

    print("Total #subgraphs (CPU+NPU) : ", len(subgraph_node_cluster))
    print("NPU #subgraphs : ", len(aie_clusters))
    print("CPU #subgraphs : ", len(cpu_clusters))
    print(
        "Total #subgraphs after loop removal (CPU+NPU) : ",
        len(cpu_clusters) + len(io_node_dict),
    )
    print("NPU #subgraphs after loop removal : ", len(io_node_dict))
    print("CPU #subgraphs after loop removal : ", len(cpu_clusters))
    print(
        "#subgraphs removed after loop removal because they are noop, runtime op and doesn't have input/output tensors : ",
        len(remove_subgraphs),
    )

    failed_nodes = [idx_node_map[ix] for ix in failed_ids]
    skip_nodes = append_to_skip_nodes(onnx_model, skip_nodes, failed_nodes)
    with open(os.path.join(out_fld, "skipped_nodes.json"), "w") as file:
        json.dump(skip_nodes, file, indent=2)


# Command-line entry point: parse the options read by main() and run it.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Automated cutter of onnx graph into NPU and CPU chunks",
        usage='use "%(prog)s --help" for more info',
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "-mp",
        "--model_path",
        required=True,
        help="Path to fused onnx model",
    )
    parser.add_argument(
        "-omp",
        "--orig_model_path",
        required=True,
        help="Path to original onnx model",
    )
    parser.add_argument(
        "--tiling_json",
        required=False,
        help="Path to tilings json",
    )
    parser.add_argument(
        "--IR_json_file",
        required=False,
        help="Path to fused onnx layer descriptions",
    )
    # main() reads args.tensor_map_json_file; without this option the script
    # failed with AttributeError before doing any work.
    parser.add_argument(
        "--tensor_map_json_file",
        required=False,
        help="Path to tensor map json (fused tensor -> original tensor)",
    )
    parser.add_argument(
        "--out_dir",
        required=False,
        help="Directory to store cut graphs",
    )
    parser.add_argument(
        "--bin_check",
        required=False,
        type=int,
        default=1,
        help="check for generated bin files",
    )
    # main() forwards args.use_inmem, which is compared against the string "0".
    # NOTE(review): "0" is assumed to be the intended default flow; confirm.
    parser.add_argument(
        "--use_inmem",
        required=False,
        default="0",
        help='"0": look for bins under DataGen/Consts; otherwise per tilings key',
    )
    parser.add_argument(
        "-nl",
        "--node_list",
        help="path to file with node names that \
                should go into sub-graphs",
        default=None,
    )
    parser.add_argument(
        "-en",
        "--exclude_nodes",
        required=False,
        help="Path to JSON file with nodes to exclude",
        default=None,
    )
    args = parser.parse_args()
    main(args)
