"""Partitions the onnx model into subgraphs and optimizes them."""
##
#  Copyright (C) 2023 – 2024 Advanced Micro Devices, Inc.
##
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
##
#  http://www.apache.org/licenses/LICENSE-2.0
##
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
##
import os
import argparse
import json
import re
import copy
from collections import defaultdict
from typing import Union, Any
import onnx
import graph_partitioner.onnx_graph as ogm
from graph_partitioner import partitioner
from utils.utils_common import log, is_subgraph_debug

# Op types that remove_redundant_noop_nodes treats the same way as "noop"
# ops: nodes of these types may be cut away from a subgraph boundary.
RUNTIME_NODES = ["concat_runtime",
                 "Split_runtime",
                 "gather_runtime",
                 "slice_runtime"]
# Additional op types given the same treatment by remove_redundant_noop_nodes
# (presumably NPU ops that need no generated binary -- TODO confirm).
NPU_NODES_NO_BIN = ["Transpose"]


def replace_slashes(line):
    """
    Sanitizes a string by substituting underscores for special characters.

    Every character that is not a word character, whitespace or a hyphen
    (this includes forward and backward slashes) is replaced with ``_``.
    """
    special_chars = re.compile(r'[^\w\s-]')
    return special_chars.sub('_', line)


def create_idx_node_map(onnx_model):
    """
    Builds a dict keyed by a node's position in ``graph.node`` whose value
    is that node's name.
    """
    idx_to_name = {}
    for position, graph_node in enumerate(onnx_model.graph.node):
        idx_to_name[position] = graph_node.name
    return idx_to_name


def create_node_idx_map(onnx_model):
    """
    Builds a dict keyed by node name whose value is the node's position in
    ``graph.node``. When a name occurs more than once the last position wins.
    """
    name_to_idx = {}
    for position, graph_node in enumerate(onnx_model.graph.node):
        name_to_idx[graph_node.name] = position
    return name_to_idx


def get_initializers_list(onnx_model):
    """
    Collects the names of all graph initializers, in graph order.
    """
    names = []
    for initializer in onnx_model.graph.initializer:
        names.append(initializer.name)
    return names


def reverse_map(d):
    """
    Returns a new dict with keys and values swapped.

    If several keys share a value, the key seen last is kept.
    """
    inverted = {}
    for original_key, original_val in d.items():
        inverted[original_val] = original_key
    return inverted


def create_adj_list(onnx_model, node_idx_map):
    """
    Builds a directed adjacency list (node index -> successor indices) by a
    depth-first walk that starts at the consumers of the primary inputs.

    Only ops reached by that walk get an entry in the result.
    """
    graph_view = ogm.ONNXGraph(onnx_model)

    primary_inputs = graph_view.getPrimaryInputs()
    pending = []
    seen = set()
    adjacency = {}
    # Seed the walk with every op that consumes a primary input tensor.
    for primary in primary_inputs:
        for consumer in graph_view.getSuccOp(primary.name):
            pending.append(consumer)
            seen.add(consumer)

    while pending:
        current = pending.pop()
        current_idx = node_idx_map[current]
        successors = graph_view.getSuccOp(current)
        adjacency[current_idx] = [node_idx_map[succ] for succ in successors]

        for succ in successors:
            if succ in seen:
                continue
            pending.append(succ)
            seen.add(succ)

    return adjacency


def create_adj_list2(onnx_model, node_idx_map):
    """
    Builds a directed adjacency list (node index -> successor indices) with
    one entry for every node of the ONNX graph.
    """
    graph_view = ogm.ONNXGraph(onnx_model)
    adjacency = {}
    for graph_node in onnx_model.graph.node:
        successors = graph_view.getSuccOp(graph_node.name)
        successor_ids = []
        for successor_name, _ in successors.items():
            successor_ids.append(node_idx_map[successor_name])
        adjacency[node_idx_map[graph_node.name]] = successor_ids

    return adjacency


def get_undirected_adj_list(adj_list, idx_node_map):
    """
    Turns a directed adjacency list into an undirected one.

    Every index in ``idx_node_map`` gets an entry; each directed edge is
    recorded on both of its endpoints.
    """
    undirected = {idx: [] for idx in idx_node_map}
    for source, successors in adj_list.items():
        for target in successors:
            undirected[source].append(target)
            undirected[target].append(source)
    return undirected


# pylint: disable=R1724
def label_node_property(onnx_model, node_idx_map, supported_ops, excluded_op_names):
    """
    Labels each node in the ONNX model AIE or CPU.

    A node is labelled "CPU" when its name is in ``excluded_op_names`` and
    "AIE" otherwise. The previous implementation carried extra op-type
    checks after an ``if/else`` whose both arms ended in ``continue``; that
    code could never run and has been removed without changing the result.

    Args:
        onnx_model: Model whose ``graph.node`` entries are labelled.
        node_idx_map: Mapping from node name to node index.
        supported_ops: Unused; kept so existing callers keep working.
        excluded_op_names: Container of node names forced onto the CPU.

    Returns:
        Dict mapping node index to "CPU" or "AIE".
    """
    _ = supported_ops
    return {
        node_idx_map[node.name]: "CPU" if node.name in excluded_op_names else "AIE"
        for node in onnx_model.graph.node
    }
# pylint: enable=R1724


def label_nodes_by_list(
    onnx_model, node_idx_map, skip_nodes
):
    """
    Assigns "CPU" or "AIE" to every node index using the skip-node mapping.

    A node goes to the CPU only when its name is a key of ``skip_nodes``
    and the stored op type equals the node's own op type.
    """
    labels = {}
    for graph_node in onnx_model.graph.node:
        is_skipped = (
            graph_node.name in skip_nodes
            and skip_nodes[graph_node.name] == graph_node.op_type
        )
        labels[node_idx_map[graph_node.name]] = "CPU" if is_skipped else "AIE"
    return labels


def update_skip_nodes(onnx_model, skip_nodes, node_list):
    """
    Updates the skip nodes based on a list of node names.

    Reads one node name per line from the file ``node_list`` (surrounding
    whitespace is stripped) and marks every model node whose name is NOT in
    that file as skipped.

    Args:
        onnx_model: Model whose ``graph.node`` entries are inspected.
        skip_nodes: Dict of node name -> op type; modified in place.
        node_list: Path of the text file listing the node names to keep.
    """
    assert node_list
    # A set gives O(1) membership tests instead of scanning a list per node.
    with open(node_list, 'r', encoding="utf-8") as f:
        supported_ops = {line.strip() for line in f}

    for node in onnx_model.graph.node:
        if node.name not in supported_ops:
            skip_nodes[node.name] = node.op_type


def append_to_skip_nodes(onnx_model, skip_nodes, node_list):
    """
    Marks every model node whose name is in ``node_list`` as skipped.

    ``skip_nodes`` (node name -> op type) is updated in place and also
    returned.
    """
    additions = {
        graph_node.name: graph_node.op_type
        for graph_node in onnx_model.graph.node
        if graph_node.name in node_list
    }
    skip_nodes.update(additions)
    return skip_nodes


def partition_onnx_model(onnx_model, idx_node_map,
                         supported_ops, skip_nodes,
                         node_list):  # filename with nodenames to be cut
    """
    Labels the model's nodes as AIE/CPU and partitions the graph accordingly.

    With a ``node_list`` file the labels come from the (updated) skip nodes;
    without one they come from the built-in exclusion table merged with
    ``skip_nodes``.

    Returns the subgraph labels, the clusters derived from them and the
    per-node labels.
    """
    node_idx_map = reverse_map(idx_node_map)
    if node_list is None:
        # Node names that are always excluded, extended by the skip nodes.
        excluded_op_names = {
            "1024_DequantizeLinear": "Add",  # Final Add in mxgan
            "/Gather_output_0_DequantizeLinear": "MatMul",  # Final MatMul in mxpzi, shape 1x768
            "input_1_QuantizeLinear": "QuantOP",
            "input_2_QuantizeLinear": "QuantOP",
            "output_1_DequantizeLinear": "DeQuantOP",
        }
        excluded_op_names.update(skip_nodes)
        node_labels = label_node_property(
            onnx_model, node_idx_map, supported_ops, excluded_op_names)
    else:
        update_skip_nodes(onnx_model, skip_nodes, node_list)
        node_labels = label_nodes_by_list(
            onnx_model, node_idx_map, skip_nodes)

    adjacency = create_adj_list2(onnx_model, node_idx_map)
    subgraphs = partitioner.partition_graph(adjacency, node_labels)
    clusters = partitioner.subgraph_labels_to_clusters(subgraphs)
    return subgraphs, clusters, node_labels


def get_cpu_cluster_nodes(onnx_model, idx_node_map, cpu_clusters):
    """
    Lists, per CPU cluster label, the short names of the cluster's nodes.

    The short name is the part of the node name after its last "/".
    """
    short_names_by_label = {}
    for label, node_ids in cpu_clusters.items():
        short_names = []
        for node_id in node_ids:
            wanted = idx_node_map[node_id]
            short_names.extend(
                graph_node.name.rsplit("/", 1)[-1]
                for graph_node in onnx_model.graph.node
                if graph_node.name == wanted
            )
        short_names_by_label[label] = short_names

    return short_names_by_label


def get_graph_channels(model):
    """
    Returns the distinct tensor names used as a node input or output,
    leaving out initializers.
    """
    initializers = get_initializers_list(model)
    channels = set()
    for graph_node in model.graph.node:
        channels.update(
            tensor for tensor in graph_node.input if tensor not in initializers)
        channels.update(
            tensor for tensor in graph_node.output if tensor not in initializers)
    return list(channels)


def get_graph_in_channels(model):
    """
    Lists the non-initializer tensors consumed by the model's nodes.

    A tensor appears once for every time it is used as a node input, so
    duplicates are kept. Node outputs that were not listed that way are
    appended afterwards, once each.
    """
    initializers = get_initializers_list(model)
    channels = [
        tensor
        for graph_node in model.graph.node
        for tensor in graph_node.input
        if tensor not in initializers
    ]
    for graph_node in model.graph.node:
        for tensor in graph_node.output:
            if tensor in initializers or tensor in channels:
                continue
            channels.append(tensor)
    return list(channels)


def get_input_idx(model, node_idx_map):
    """
    Maps the first output tensor of every node to the index of that node.
    """
    return {
        graph_node.output[0]: node_idx_map[graph_node.name]
        for graph_node in model.graph.node
    }


def remove_redundant_noop_nodes(onnx_model,
                                filtered_inputs,
                                filtered_outputs,
                                orig_channels_list,
                                init_list):
    """
    Removes redundant noop nodes from the ONNX model.

    Starting from every subgraph output (then every subgraph input) that is
    not a tensor of the original model, the boundary is walked through nodes
    whose op type contains "noop" or is listed in NPU_NODES_NO_BIN or
    RUNTIME_NODES. When every tensor the walk ends on belongs to the original
    model, the walked-through nodes are recorded as cut and the boundary
    tensors are replaced by the tensors reached.

    Note: ``filtered_inputs`` and ``filtered_outputs`` are modified in place.

    Args:
        onnx_model: Model searched for producer/consumer nodes.
        filtered_inputs: List of subgraph input tensor names.
        filtered_outputs: List of subgraph output tensor names.
        orig_channels_list: Tensor names present in the original model.
        init_list: Initializer names; never followed during the walk.

    Returns:
        Tuple ``(aftercut_inputs, aftercut_outputs, cut_nodes)`` where
        ``cut_nodes`` maps node name -> op type.
    """
    cut_nodes = {}
    all_ctr = []  # channels to remove from subgraph
    # Pass 1: walk upwards (towards producers) from each output.
    for out in filtered_outputs:
        if out in orig_channels_list:
            continue
        tmp_channels = [out]
        ctr = []
        # tmp_channels grows while it is being iterated; it acts as a worklist.
        for tmp_channel in tmp_channels:
            if tmp_channel in filtered_inputs:
                filtered_inputs.remove(tmp_channel)
                ctr.append(tmp_channel)
                continue
            if tmp_channel in orig_channels_list:
                continue
            tensor_map = get_node_from_out_channel(onnx_model)
            # NOTE(review): node is None when no node produces tmp_channel,
            # which would make the op_type access below fail -- confirm this
            # cannot happen for the inputs used here.
            node = tensor_map.get(tmp_channel, None)
            if "noop" not in node.op_type and \
                    node.op_type not in NPU_NODES_NO_BIN and\
                    node.op_type not in RUNTIME_NODES:
                continue
            cut_nodes[node.name] = node.op_type
            for i in node.input:
                if i not in init_list:
                    tmp_channels.append(i)   # pylint: disable=W4701
            ctr.append(tmp_channel)
        tmp_channels = list(set(tmp_channels))
        ctr = list(set(ctr))
        # What is left are the tensors the walk stopped at.
        tmp_channels = [i for i in tmp_channels if i not in ctr]
        all_in_graph = True
        for i in tmp_channels:
            if i not in orig_channels_list:
                all_in_graph = False
                break
        # Only accept the cut when every stopping tensor is an original one.
        if not all_in_graph:
            continue
        all_ctr.extend(ctr)
        for i in tmp_channels:
            if i not in filtered_outputs:
                filtered_outputs.append(i)
    aftercut_outputs = [i for i in filtered_outputs if i not in all_ctr]

    all_ctr = []  # channels to remove from subgraph
    # Pass 2: walk downwards (towards consumers) from each input.
    for inp in filtered_inputs:
        if inp in orig_channels_list:
            continue
        tmp_channels = [inp]
        ctr = []
        # tmp_channels grows while it is being iterated; it acts as a worklist.
        for tmp_channel in tmp_channels:
            if tmp_channel in filtered_outputs:
                filtered_outputs.remove(tmp_channel)
                ctr.append(tmp_channel)
                continue
            if tmp_channel in orig_channels_list:
                continue
            tensor_map = get_nodes_from_in_channel(onnx_model)
            nodes = tensor_map.get(tmp_channel, [])
            # rm stays True only if every consumer is a removable node.
            rm = True
            for node in nodes:
                if "noop" not in node.op_type and \
                        node.op_type not in NPU_NODES_NO_BIN and\
                        node.op_type not in RUNTIME_NODES:
                    rm = False
                    continue
                cut_nodes[node.name] = node.op_type
                tmp_channels.extend(node.output)
            if not rm:
                continue
            ctr.append(tmp_channel)
        tmp_channels = list(set(tmp_channels))
        ctr = list(set(ctr))
        # What is left are the tensors the walk stopped at.
        tmp_channels = [i for i in tmp_channels if i not in ctr]
        all_in_graph = True
        for i in tmp_channels:
            if i not in orig_channels_list:
                all_in_graph = False
                break
        # Only accept the cut when every stopping tensor is an original one.
        if not all_in_graph:
            continue
        all_ctr.extend(ctr)
        for i in tmp_channels:
            if i not in filtered_inputs:
                filtered_inputs.append(i)
    aftercut_inputs = [i for i in filtered_inputs if i not in all_ctr]

    return aftercut_inputs, aftercut_outputs, cut_nodes


def make_list_unique(lst):
    """
    Returns a new list without duplicates, keeping first-occurrence order.
    """
    # dict keys are unique and remember insertion order.
    return list(dict.fromkeys(lst))


def check_bot(onnx_model,
              orig_channels_list,
              adj_list,
              idx_node_map,
              sub_subgraph_1,
              sub_subgraph_2,
              top_node_idx,
              bot_node_idx,
              channel_to_check):
    """
    Moves the cut between two sub-subgraphs downwards.

    While the boundary tensor is not part of the original model, the bottom
    node is shifted from the second group to the first and the walk follows
    its single successor. The walk stops at a node with more than one
    successor or more than one predecessor, or one that is not in the second
    group. Works on copies; the argument lists are left untouched.

    Returns the final boundary tensor and the two adjusted groups.
    """
    upper_group = copy.deepcopy(sub_subgraph_1)
    lower_group = copy.deepcopy(sub_subgraph_2)
    while channel_to_check not in orig_channels_list:
        # going_down
        successors = adj_list[bot_node_idx]
        if len(successors) > 1 or bot_node_idx not in lower_group:
            break
        fan_in = sum(1 for targets in adj_list.values()
                     if bot_node_idx in targets)
        if fan_in > 1:
            break
        lower_group.remove(bot_node_idx)
        upper_group.append(bot_node_idx)
        top_node_idx, bot_node_idx = bot_node_idx, successors[0]
        producer_name = idx_node_map[top_node_idx]
        consumer_name = idx_node_map[bot_node_idx]
        producer = None
        consumer = None
        for graph_node in onnx_model.graph.node:
            if graph_node.name == producer_name:
                producer = graph_node
            if graph_node.name == consumer_name:
                consumer = graph_node
        assert producer.output[0] in consumer.input
        channel_to_check = producer.output[0]
    return channel_to_check, upper_group, lower_group


def check_top(onnx_model,
              orig_channels_list,
              adj_list,
              idx_node_map,
              sub_subgraph_1,
              sub_subgraph_2,
              top_node_idx,
              bot_node_idx,
              channel_to_check):
    """
    Moves the cut between two sub-subgraphs upwards.

    While the boundary tensor is not part of the original model, the top
    node is shifted from the first group to the second and the walk follows
    its predecessor. The walk stops at a node with more than one predecessor
    or one that is not in the first group. Works on copies; the argument
    lists are left untouched.

    Returns the final boundary tensor and the two adjusted groups.
    """
    upper_group = copy.deepcopy(sub_subgraph_1)
    lower_group = copy.deepcopy(sub_subgraph_2)
    while channel_to_check not in orig_channels_list:
        # going_up
        predecessors = [source for source, targets in adj_list.items()
                        if top_node_idx in targets]
        # -1 mirrors the "no predecessor found" marker.
        predecessor = predecessors[-1] if predecessors else -1
        if len(predecessors) > 1 or top_node_idx not in upper_group:
            break

        upper_group.remove(top_node_idx)
        lower_group.append(top_node_idx)
        top_node_idx, bot_node_idx = predecessor, top_node_idx
        producer_name = idx_node_map[top_node_idx]
        consumer_name = idx_node_map[bot_node_idx]
        producer = None
        consumer = None
        for graph_node in onnx_model.graph.node:
            if graph_node.name == producer_name:
                producer = graph_node
            if graph_node.name == consumer_name:
                consumer = graph_node
        assert producer.output[0] in consumer.input
        channel_to_check = producer.output[0]
    return channel_to_check, upper_group, lower_group


def refine_subgraphs(onnx_model,
                     orig_channels_list,
                     adj_list,
                     idx_node_map,
                     sub_subgraph_1,
                     sub_subgraph_2):
    """
    Adjusts the cut between two sub-subgraphs so that the tensor crossing it
    is a tensor of the original model.

    The cut is first moved downwards (check_bot), then upwards (check_top).
    The first attempt that ends on an original tensor with both groups
    non-empty wins; otherwise the groups are returned unchanged.
    """
    upper_name = ""
    lower_name = ""
    upper_idx = -1
    lower_idx = -1
    # Remember the last edge found that goes from group 1 into group 2.
    for candidate_top in sub_subgraph_1:
        for candidate_bot in sub_subgraph_2:
            if candidate_bot in adj_list[candidate_top]:
                upper_name = idx_node_map[candidate_top]
                lower_name = idx_node_map[candidate_bot]
                upper_idx = candidate_top
                lower_idx = candidate_bot
    if (upper_name, lower_name) == ("", ""):
        return sub_subgraph_1, sub_subgraph_2

    producer = None
    consumer = None
    for graph_node in onnx_model.graph.node:
        if graph_node.name == upper_name:
            producer = graph_node
        if graph_node.name == lower_name:
            consumer = graph_node
    assert producer.output[0] in consumer.input

    boundary = producer.output[0]
    if boundary in orig_channels_list:
        return sub_subgraph_1, sub_subgraph_2

    # The boundary tensor returned by one attempt is handed to the next.
    for move_cut in (check_bot, check_top):
        boundary, group_1, group_2 = move_cut(onnx_model,
                                              orig_channels_list,
                                              adj_list,
                                              idx_node_map,
                                              sub_subgraph_1,
                                              sub_subgraph_2,
                                              upper_idx,
                                              lower_idx,
                                              boundary)
        if boundary in orig_channels_list \
                and len(group_1) != 0 and len(group_2) != 0:
            return group_1, group_2
    return sub_subgraph_1, sub_subgraph_2


def handle_quant_inputs(filtered_inputs, orig_channels_list, fused_tensor_to_innode_map, orig_tensor_to_outnodes_map):
    """
    Translates subgraph inputs into tensors of the original model.

    Inputs already present in the original model are kept. Any other input
    is kept only if it is produced by a "Quant_float32xuint16" node whose
    first input is an original tensor consumed by exactly one
    "QuantizeLinear" node; it is then replaced by that node's first output.
    Everything else is dropped.
    """
    resolved = []
    for tensor in filtered_inputs:
        if tensor in orig_channels_list:
            resolved.append(tensor)
            continue
        producer = fused_tensor_to_innode_map.get(tensor, None)
        if producer.op_type != "Quant_float32xuint16":
            continue
        quant_source = producer.input[0]
        if quant_source not in orig_channels_list:
            continue
        consumers = orig_tensor_to_outnodes_map.get(quant_source, [])
        if len(consumers) != 1:
            continue
        if consumers[0].op_type != "QuantizeLinear":
            continue
        resolved.append(consumers[0].output[0])
    return resolved


def handle_dequant_outputs(filtered_outputs, orig_channels_list, fused_tensor_to_outnodes_map, orig_tensor_to_innode_map):
    """
    Translates subgraph outputs into tensors of the original model.

    Outputs already present in the original model are kept. Any other
    output is kept only if it is consumed by exactly one
    "Dequant_uint16xfloat32" node whose first output is an original tensor
    produced by a "DequantizeLinear" node; it is then replaced by that
    node's first input. Everything else is dropped.
    """
    resolved = []
    for tensor in filtered_outputs:
        if tensor in orig_channels_list:
            resolved.append(tensor)
            continue
        consumers = fused_tensor_to_outnodes_map.get(tensor, [])
        if len(consumers) != 1:
            continue
        if consumers[0].op_type != "Dequant_uint16xfloat32":
            continue
        dequant_result = consumers[0].output[0]
        if dequant_result not in orig_channels_list:
            continue
        producer = orig_tensor_to_innode_map.get(dequant_result, None)
        if producer.op_type != "DequantizeLinear":
            continue
        resolved.append(producer.input[0])
    return resolved


# pylint: disable=W0640
def get_cluster_inputs_outputs(orig_model, onnx_model,
                               idx_node_map, aie_clusters, tensor_map_dict):
    """Get input and output tensors of each NPU cluster

    For every cluster in ``aie_clusters`` the boundary tensors are computed,
    boundary nodes that cannot stay in the cluster are cut away, and clusters
    whose outputs can reach their own inputs through the rest of the graph
    (a loop) are split in two and processed again recursively.

    Note: the node-id lists inside ``aie_clusters`` are modified in place.

    Args:
        orig_model: Original model; only its tensor names are used.
        onnx_model: Model whose nodes are being clustered.
        idx_node_map: Mapping from node index to node name.
        aie_clusters: Mapping from cluster label to a list of node indices.
        tensor_map_dict: Container of tensor names; a boundary tensor missing
            from it causes its boundary node to be cut from the cluster.

    Returns:
        Tuple ``(final_io_dict, removed_nodes)``. ``final_io_dict`` maps a
        relabelled cluster name to a dict with keys 'inputs', 'outputs' and
        'ids'; ``removed_nodes`` maps cut node names to their op types.
    """
    orig_channels_list = get_graph_channels(orig_model)
    fused_channels_list = get_graph_in_channels(onnx_model)
    node_idx_map = reverse_map(idx_node_map)
    tensor_input_idx_map = get_input_idx(onnx_model, node_idx_map)
    adj_list = create_adj_list2(onnx_model, node_idx_map)
    io_dict = {}
    removed_nodes = {}
    graph_inputs = []
    graph_outputs = []
    fused_tensor_to_innode_map = get_node_from_out_channel(onnx_model)
    fused_tensor_to_outnodes_map = get_nodes_from_in_channel(onnx_model)
    for graph_input in onnx_model.graph.input:
        graph_inputs.append(graph_input.name)
    for output in onnx_model.graph.output:
        graph_outputs.append(output.name)
    for label, node_ids in aie_clusters.items():
        # Processes one cluster; calls itself on the two halves after a split.
        def detect_and_remove_loops(label, node_ids):
            dct = {}
            init_list = get_initializers_list(onnx_model)
            cut_nodes = {}

            # Computes the cluster's boundary tensors. Boundary nodes that
            # must be cut are removed from node_ids and the computation is
            # repeated until nothing more is removed.
            def get_filtered_inputs_outputs(node_ids, cut_nodes):
                inputs = []
                outputs = []
                filtered_inputs = []
                filtered_outputs = []
                for node_id in node_ids:
                    nodename = idx_node_map[node_id]
                    for node in onnx_model.graph.node:
                        if node.name == nodename:
                            inputs.extend(node.input)
                            outputs.extend(node.output)
                # Cluster inputs: consumed inside, not produced inside, and
                # neither an initializer nor an empty name.
                for i in inputs:
                    if i not in outputs and i not in init_list and i != '':
                        filtered_inputs.append(i)
                # Cluster outputs: produced inside and used more often in the
                # whole graph than inside the cluster.
                for o in outputs:
                    if o == '':
                        continue
                    subgraph_out_count = 0
                    for i in inputs:
                        if o == i:
                            subgraph_out_count = subgraph_out_count + 1
                    full_fused_graph_out_count = 0
                    for i in fused_channels_list:
                        if o == i:
                            full_fused_graph_out_count = full_fused_graph_out_count + 1
                    if subgraph_out_count < full_fused_graph_out_count:
                        filtered_outputs.append(o)

                filtered_inputs = make_list_unique(filtered_inputs)
                filtered_outputs = make_list_unique(filtered_outputs)
                # flag stays True when no boundary node had to be cut.
                flag = True
                for out_ in filtered_outputs:
                    out_node = fused_tensor_to_innode_map.get(out_, None)  # getting the output node associated with a output
                    if ("noop" in out_node.op_type or "runtime" in out_node.op_type or out_ not in tensor_map_dict):
                        if node_idx_map[out_node.name] in node_ids:
                            flag = False
                            node_ids.remove(node_idx_map[out_node.name])
                            cut_nodes[out_node.name] = out_node.op_type
                for in_ in filtered_inputs:
                    in_nodes = fused_tensor_to_outnodes_map.get(in_, [])  # getting the input nodes associated with a input
                    for in_node in in_nodes:
                        if ("noop" in in_node.op_type or "runtime" in in_node.op_type or in_ not in tensor_map_dict):
                            if node_idx_map[in_node.name] in node_ids:
                                flag = False
                                node_ids.remove(node_idx_map[in_node.name])
                                cut_nodes[in_node.name] = in_node.op_type
                if flag:
                    return (filtered_inputs, filtered_outputs)
                # Something was cut: recompute the boundary for what is left.
                return get_filtered_inputs_outputs(node_ids, cut_nodes)

            aftercut_inputs, aftercut_outputs = get_filtered_inputs_outputs(node_ids, cut_nodes)
            if not node_ids:  # empty npu subgraph
                removed_nodes.update(cut_nodes)
                return
            removed_nodes.update(cut_nodes)
            aftercut_inputs = make_list_unique(aftercut_inputs)
            aftercut_outputs = make_list_unique(aftercut_outputs)

            # detect loop in subgraph
            non_model_inputs = [input for input in aftercut_inputs if input not in graph_inputs]
            non_model_outputs = [output for output in aftercut_outputs if output not in graph_outputs]
            # check if there exists a path from any of the non_model_outputs to any of non_model_inputs in onnx_model
            target_nodes = [tensor_input_idx_map[subgraph_input]
                            for subgraph_input in non_model_inputs
                            if subgraph_input in tensor_input_idx_map]
            loop_present = False
            for subgraph_output in non_model_outputs:
                start_node = tensor_input_idx_map[subgraph_output]
                reachable_target_nodes = []
                visited = set()
                queue = []

                # Breadth-first search collecting the target nodes reachable
                # from init_node.
                def bfs(init_node):
                    queue.append(init_node)
                    while len(queue) > 0:
                        node = queue.pop(0)
                        if node in visited:
                            continue
                        visited.add(node)
                        if node in target_nodes:
                            reachable_target_nodes.append(node)
                        for next_node in adj_list.get(node, []):
                            queue.append(next_node)
                bfs(start_node)
                if len(reachable_target_nodes) != 0:
                    loop_present = True
                    log(f"Subgraph_{label} has a loop")
                    node_1 = start_node
                    node_2 = reachable_target_nodes[0]
                    break
            if loop_present:
                sub_subgraph_1 = []
                sub_subgraph_2 = []
                # code to cut the subgraph
                # getting all the nodes dependent on node_2
                visited = set()
                queue = []

                # Collects the cluster nodes reachable from init_node.
                def bfs_2(init_node):
                    queue.append(init_node)
                    while len(queue) > 0:
                        node = queue.pop(0)
                        if node in visited:
                            continue
                        visited.add(node)
                        if node in node_ids:
                            sub_subgraph_2.append(node)
                        for next_node in adj_list.get(node, []):
                            queue.append(next_node)
                bfs_2(node_2)
                # build undirected version of the graph
                undirected_adj_list = get_undirected_adj_list(adj_list, idx_node_map)
                # remove nodes dependent on node_2 from the undirected graph
                for node in sub_subgraph_2:
                    undirected_adj_list.pop(node, None)
                for next_nodes in undirected_adj_list.values():
                    for node in sub_subgraph_2:
                        if node in next_nodes:
                            next_nodes.remove(node)
                # getting nodes connected to node_1 in the updated undirected graph
                visited = set()
                queue = []

                # Collects the cluster nodes connected to init_node in the
                # pruned undirected graph.
                def bfs_1(init_node):
                    queue.append(init_node)
                    while len(queue) > 0:
                        node = queue.pop(0)
                        if node in visited:
                            continue
                        visited.add(node)
                        if node in node_ids:
                            sub_subgraph_1.append(node)
                        for next_node in undirected_adj_list.get(node, []):
                            queue.append(next_node)
                bfs_1(node_1)
                # Nodes claimed by neither half are assigned to the second one.
                for node in node_ids:
                    if node not in sub_subgraph_1 and node not in sub_subgraph_2:
                        sub_subgraph_2.append(node)
                sub_subgraph_1, sub_subgraph_2 = refine_subgraphs(onnx_model, orig_channels_list, adj_list, idx_node_map, sub_subgraph_1, sub_subgraph_2)
                sub_subgraph_1 = sorted(sub_subgraph_1)
                sub_subgraph_2 = sorted(sub_subgraph_2)

                # Labels become "<label>_1"/"<label>_2" on the first split
                # and get a further digit appended on each later split.
                if "_" in label:
                    detect_and_remove_loops(label+'1', sub_subgraph_1)
                    detect_and_remove_loops(label+'2', sub_subgraph_2)
                else:
                    detect_and_remove_loops(label+'_1', sub_subgraph_1)
                    detect_and_remove_loops(label+'_2', sub_subgraph_2)
            else:
                # removed_nodes.update(cut_nodes)
                # Clusters without inputs or without outputs are not recorded.
                if aftercut_inputs and aftercut_outputs:
                    dct['inputs'] = aftercut_inputs
                    dct['outputs'] = aftercut_outputs
                    dct['ids'] = node_ids
                    io_dict[label] = dct
        detect_and_remove_loops(label, node_ids)
    # Relabel as "<original label>_<n>", n being the position among the
    # recorded pieces that share the same original label.
    cluster_map = defaultdict(list)
    final_io_dict = {}
    for label, dct in io_dict.items():
        cluster_map[label.split('_')[0]].append(label)
    for label, dct in io_dict.items():
        new_label = label.split('_')[0]+'_'+str(cluster_map[label.split('_')[0]].index(label))
        final_io_dict[new_label] = dct

    return final_io_dict, removed_nodes
# pylint: disable=W0640


def cut_subgraph(onnx_model_path, io_dict, type_name, fld, save_subgraphs):
    """
    Extracts one ONNX file per cluster from the model at ``onnx_model_path``.

    Files go to ``<fld>/cut_graphs`` and are only written when
    ``save_subgraphs`` is set (a ``_debug`` copy is written in subgraph debug
    mode). Node ids of clusters whose extraction fails are collected instead
    of raising.

    Returns a dict keyed by subgraph file name holding the cluster's
    'inputs' and 'outputs', plus the list of failed node ids.
    """
    _ = type_name
    out_dir = os.path.join(fld, "cut_graphs")
    os.makedirs(out_dir, exist_ok=True)
    failed_nodes = []
    io_node_dict = {}

    if is_subgraph_debug():
        onnx_model = onnx.load(onnx_model_path, load_external_data=False)

    for label, io in io_dict.items():
        input_names = io['inputs']
        output_names = io['outputs']
        # Drop the last five characters of the path (".onnx" is assumed).
        subgraph_path = os.path.basename(
            onnx_model_path[:-5] + "_cluster_" + label + ".onnx")
        full_subgraph_path = os.path.join(out_dir, subgraph_path)

        if is_subgraph_debug():
            debug_path = full_subgraph_path.replace(".onnx", "_debug.onnx")
            extracted = onnx.utils.Extractor(onnx_model).extract_model(
                input_names, output_names)
            onnx.save(extracted, debug_path)

        if save_subgraphs:
            try:
                onnx.utils.extract_model(onnx_model_path,
                                         full_subgraph_path,
                                         input_names,
                                         output_names,
                                         check_model=False)
            except onnx.checker.ValidationError as e:
                log(f"""Cut subgraph '{full_subgraph_path}' failed
                      the model check""")
                print(f"Error: {e}")
                failed_nodes.extend(io['ids'])
            except Exception as e:     # pylint: disable=W0718
                print(f"""Unexpected error occured in cutting of subgraph
                      {full_subgraph_path}""")
                print(f"Error: {e}")
                failed_nodes.extend(io['ids'])
        io_node_dict[subgraph_path] = {
            'inputs': io['inputs'],
            'outputs': io['outputs'],
        }

    return io_node_dict, failed_nodes


def get_fuse_op_list(tilings_json: Union[str, dict[str, Any]]):
    """
    Get the list of supported optypes and skip nodes from the tilings JSON file.

    Args:
        tilings_json: Path of a JSON file, or the already-loaded dict. Keys
            are layer ids (convertible to int); each value is either None or
            a dict with the keys 'name', 'op' and 'is_compilable'.

    Returns:
        Tuple ``(op_list, skip_nodes, node_layerid_map)``:
        ``op_list`` holds the distinct op types of compilable entries in
        first-seen order, ``skip_nodes`` maps the name of every
        non-compilable entry to its op type, and ``node_layerid_map`` maps
        each entry's name to its integer layer id.

    Raises:
        TypeError: If ``tilings_json`` is neither a str nor a dict.
    """
    # Decide whether we received a path or a dict
    if isinstance(tilings_json, str):
        with open(tilings_json, "r", encoding="utf-8") as f1:
            tiling_dict: dict[str, Any] = json.load(f1)
    elif isinstance(tilings_json, dict):
        tiling_dict = tilings_json
    else:
        raise TypeError(
            f"tilings_json must be str or dict[str, Any], not {type(tilings_json)}"
        )
    op_list = []
    skip_nodes = {}
    node_layerid_map = {}
    for key, val in tiling_dict.items():
        # Null entries carry no information; skip them before reading any
        # field (indexing None would raise TypeError).
        if val is None:
            continue
        node_layerid_map[val['name']] = int(key)
        if val["is_compilable"]:
            if val['op'] not in op_list:
                op_list.append(val['op'])  # supported_ops are custom op types
        else:
            skip_nodes[val['name']] = val["op"]
    return op_list, skip_nodes, node_layerid_map


def get_node_from_out_channel(onnx_model):
    """
    Get the node from the original model by its output channel.

    Returns a dict mapping every node output tensor name that is not the
    name of a graph initializer to the node listing it as an output. When
    several nodes list the same output name, the last one in
    ``onnx_model.graph.node`` order is kept.
    """
    # A set keeps the membership test below O(1) per tensor instead of
    # scanning a list of initializer names for every node output.
    initializer_names = {init.name for init in onnx_model.graph.initializer}
    tensor_to_innode_map = {}
    for node in onnx_model.graph.node:
        for out_ch in node.output:
            if out_ch not in initializer_names:
                tensor_to_innode_map[out_ch] = node
    return tensor_to_innode_map


def get_nodes_from_in_channel(onnx_model):
    """
    Get the nodes from the original model by its input channel.

    Returns a ``defaultdict(list)`` mapping every node input tensor name
    that is not the name of a graph initializer to the list of nodes that
    list it as an input, in ``onnx_model.graph.node`` order. Because the
    result is a defaultdict, looking up an unknown tensor yields (and
    stores) an empty list.
    """
    # A set keeps the membership test below O(1) per tensor instead of
    # scanning a list of initializer names for every node input.
    initializer_names = {init.name for init in onnx_model.graph.initializer}
    tensor_to_outnodes_map = defaultdict(list)
    for node in onnx_model.graph.node:
        for in_ch in node.input:
            if in_ch not in initializer_names:
                tensor_to_outnodes_map[in_ch].append(node)
    return tensor_to_outnodes_map


def get_node_by_name(model: onnx.ModelProto, node_name: str):
    """
    Look up a node of the given model by its name.

    Returns the first entry of ``model.graph.node`` whose ``name`` equals
    ``node_name``, or None when no node carries that name.
    """
    matching_nodes = (candidate for candidate in model.graph.node
                      if candidate.name == node_name)
    return next(matching_nodes, None)


def _subgraph_label(subgraph_path):
    """
    Return the cluster label encoded in a subgraph file name.

    The label is the last two '_'-separated fields of the name once its
    final five characters have been dropped (presumably the '.onnx'
    suffix -- TODO confirm against cut_subgraph).
    """
    fields = subgraph_path[:-5].split('_')
    return fields[-2] + '_' + fields[-1]


def main(onnx_model_path,
         orig_model_path,
         out_fld,
         tilings_json: Union[str, dict[str, Any]],
         tensor_map_json: str = None,
         save_subgraphs=False,
         node_list: str = None):
    """
    Callable API: partition the fused model and describe the NPU subgraphs.

    Args:
        onnx_model_path: path of the fused ONNX model that gets partitioned.
        orig_model_path: path of the original ONNX model; its graph channels
            seed the identity entries of the tensor map.
        out_fld: output folder. 'skipped_nodes.json' is written here and
            'context.json' / 'context_info.json' into its 'cut_graphs'
            sub-folder.
        tilings_json: path or dict forwarded to get_fuse_op_list.
        tensor_map_json: optional path to a JSON file whose entries provide
            an 'orig_tensor' value per tensor name.
        save_subgraphs: forwarded to cut_subgraph.
        node_list: forwarded to partition_onnx_model.

    Returns:
        A tuple ``(partitioner_output, subgraph_id_to_name_map,
        subgraph_ops, subgraph_ios)`` indexed by the position of each
        remaining NPU subgraph: the layer ids of its nodes, its subgraph
        path, its node names and its ``(inputs, outputs)`` pair.
    """
    op_list, skip_nodes, node_layerid_map = get_fuse_op_list(tilings_json)

    onnx_model = onnx.load(onnx_model_path, load_external_data=False)
    orig_model = onnx.load(orig_model_path, load_external_data=False)

    tensor_map_dict = {}
    if tensor_map_json:
        with open(tensor_map_json, 'r', encoding="utf-8") as f:
            full_tensor_map = json.load(f)
        for tensor in full_tensor_map:
            tensor_map_dict[tensor] = full_tensor_map[tensor]["orig_tensor"]
    else:
        print("Did not find the tensor map json file")

    # Every channel of the original model that has no explicit mapping
    # maps onto itself.
    for orig_tensor in get_graph_channels(orig_model):
        tensor_map_dict.setdefault(orig_tensor, orig_tensor)

    idx_node_map = create_idx_node_map(onnx_model)
    _, subgraph_node_cluster, target_label = partition_onnx_model(
        onnx_model, idx_node_map, op_list, skip_nodes, node_list)
    log("cluster:", subgraph_node_cluster)

    # Split the clusters by the target label of their first node; empty
    # clusters end up on the NPU ("aie") side.
    aie_clusters = {}
    cpu_clusters = {}
    for subgraph_label, subgraph_nodes in subgraph_node_cluster.items():
        if subgraph_nodes and target_label[subgraph_nodes[0]] == "CPU":
            cpu_clusters[subgraph_label] = subgraph_nodes
        else:
            aie_clusters[str(subgraph_label)] = subgraph_nodes

    aie_io_dict, removed_nodes = get_cluster_inputs_outputs(orig_model,
                                                            onnx_model,
                                                            idx_node_map,
                                                            aie_clusters,
                                                            tensor_map_dict)
    log(removed_nodes)
    skip_nodes.update(removed_nodes)
    io_node_dict, failed_ids = cut_subgraph(onnx_model_path, aie_io_dict, "aie", out_fld, save_subgraphs)

    folder_name = os.path.join(out_fld, "cut_graphs")
    remove_subgraphs = []
    fused_tensor_to_innode_map = get_node_from_out_channel(onnx_model)
    fused_tensor_to_outnodes_map = get_nodes_from_in_channel(onnx_model)
    new_io_node_dict = {}
    for subgraph_path, io in io_node_dict.items():
        label = _subgraph_label(subgraph_path)
        subgraph_nodes = [idx_node_map[idx] for idx in aie_io_dict[label]['ids']]
        subgraph_node_names = set(subgraph_nodes)
        # Tensor names as used in the fused model, kept before `io` is
        # rewritten to original-model names below.
        fused_inputs = io['inputs']
        fused_outputs = io['outputs']

        # For each boundary tensor pick the first node of this subgraph that
        # consumes (inputs) or produces (outputs) it.
        input_nodes = []
        for subgraph_input in fused_inputs:
            consumers = [node.name
                         for node in fused_tensor_to_outnodes_map[subgraph_input]
                         if node.name in subgraph_node_names]
            input_nodes.append(consumers[0])
        output_nodes = []
        for subgraph_output in fused_outputs:
            producers = [node.name
                         for node in (fused_tensor_to_innode_map[subgraph_output],)
                         if node.name in subgraph_node_names]
            output_nodes.append(producers[0])

        # Translate the boundary tensors to original-model names, dropping
        # those without a mapping. This mutates the io_node_dict entry.
        io['inputs'] = [tensor_map_dict[input_] for input_ in fused_inputs if input_ in tensor_map_dict]
        io['outputs'] = [tensor_map_dict[output_] for output_ in fused_outputs if output_ in tensor_map_dict]

        # A subgraph is dropped when it has no mapped input, no mapped
        # output, or a tensor that is both an input and an output.
        if (not io['inputs']) or (not io['outputs']) or (not set(io['inputs']).isdisjoint(set(io['outputs']))):
            subgraph_file = os.path.join(folder_name, subgraph_path)
            if os.path.exists(subgraph_file):
                os.remove(subgraph_file)
            remove_subgraphs.append(subgraph_path)
            skip_nodes = append_to_skip_nodes(onnx_model, skip_nodes, subgraph_nodes)
        else:
            new_io_node_dict[subgraph_path] = {
                'inputs': dict(zip(fused_inputs, input_nodes)),
                'outputs': dict(zip(fused_outputs, output_nodes)),
            }

    for subgraph in remove_subgraphs:
        io_node_dict.pop(subgraph)

    # NOTE(review): cut_subgraph may already create this folder; creating it
    # here is then a no-op, and otherwise keeps the writes below from failing.
    os.makedirs(folder_name, exist_ok=True)
    with open(os.path.join(folder_name, "context.json"), 'w', encoding="utf-8") as f:
        json.dump(new_io_node_dict, f, indent=2)
    with open(os.path.join(folder_name, "context_info.json"), 'w', encoding="utf-8") as f:
        json.dump(io_node_dict, f, indent=2)
    print("Total #subgraphs (CPU+NPU) : ", len(subgraph_node_cluster))
    print("NPU #subgraphs : ", len(aie_clusters))
    print("CPU #subgraphs : ", len(cpu_clusters))
    print("Total #subgraphs after loop removal (CPU+NPU) : ", len(cpu_clusters) + len(io_node_dict))
    print("NPU #subgraphs after loop removal : ", len(io_node_dict))
    print("CPU #subgraphs after loop removal : ", len(cpu_clusters))
    print("#subgraphs removed after loop removal because they are noop, runtime op and doesn't have input/output tensors : ", len(remove_subgraphs))

    # Nodes of subgraphs that failed to be cut are recorded as skipped too.
    failed_nodes = [idx_node_map[ix] for ix in failed_ids]
    skip_nodes = append_to_skip_nodes(onnx_model,
                                      skip_nodes, failed_nodes)
    with open(os.path.join(out_fld, 'skipped_nodes.json'), 'w', encoding="utf-8") as file:
        json.dump(skip_nodes, file, indent=2)

    partitioner_output = []
    subgraph_id_to_name_map = {}
    subgraph_ops, subgraph_ios = {}, {}
    for count, subgraph_path in enumerate(io_node_dict):
        cluster_io = aie_io_dict[_subgraph_label(subgraph_path)]
        node_names = [idx_node_map[idx] for idx in cluster_io['ids']]
        partitioner_output.append([node_layerid_map[name] for name in node_names])
        subgraph_id_to_name_map[count] = subgraph_path
        subgraph_ops[count] = node_names
        subgraph_ios[count] = (cluster_io['inputs'], cluster_io['outputs'])

    return partitioner_output, subgraph_id_to_name_map, subgraph_ops, subgraph_ios


# Command-line entry point: parses the arguments and forwards them to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Automated cutter of onnx graph into NPU and CPU chunks",
        usage='use "%(prog)s --help" for more info',
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "-mp",
        "--model_path",
        required=True,
        help="Path to fused onnx model",
    )
    parser.add_argument(
        "-omp",
        "--orig_model_path",
        required=True,
        help="Path to original onnx model",
    )
    # NOTE(review): --tiling_json, --bin_check and --exclude_nodes are parsed
    # but never read below; they are kept so existing command lines still
    # parse.
    parser.add_argument(
        "--tiling_json",
        required=False,
        help="Path to tiling JSON file (currently unused)",
    )
    # main() raises a TypeError without this file, so it is mandatory.
    parser.add_argument(
        "--IR_json_file",
        required=True,
        help="Path to fused onnx layer descriptions",
    )
    # main() joins paths onto this directory, so it is mandatory.
    parser.add_argument(
        "--out_dir",
        required=True,
        help="Directory to store cut graphs",
    )
    parser.add_argument(
        "--bin_check",
        required=False,
        type=int,
        default=1,
        help="check for generated bin files",
    )
    parser.add_argument(
        "-nl",
        "--node_list",
        help="path to file with node names that "
             "should go into sub-graphs",
        default=None,
    )
    parser.add_argument(
        "-en",
        "--exclude_nodes",
        required=False,
        help="Path to JSON file with nodes to exclude",
        default=None,
    )
    args = parser.parse_args()
    main(args.model_path,
         args.orig_model_path,
         args.out_dir,
         args.IR_json_file,
         node_list=args.node_list)
