import os
import glob
import re
import yaml
import json
import numpy
import argparse
import logging
import onnx
from collections import defaultdict
from onnx.helper import make_attribute
from onnx import AttributeProto
from typing import Callable, Iterable

from kernel_func_list import kernel_func_list

# Op-type families that partition_graph still tags with a "pm_id" attribute even
# though they have no entry in the kernel metadata. The match is done on the part
# of node.op_type before the first "_" (so an op type such as "Concat_xyz" would
# match "Concat").
DATAFLOW_OPS = [
    "Concat",
    "Transpose",
    "Reshape",
    "Slice",
    # Gather / GatherElements are currently disabled (not tagged):
    #"Gather",
    #"GatherElements",
    "DepthToSpace",
    "Split",
]

def get_kernel_dict():
    """Load and merge every ``*_kernel_metadata.yaml`` file under ``Collaterals``.

    The ``Collaterals`` directory is resolved three directory levels above this
    file. Files are merged in sorted filename order, so when two files define
    the same top-level key the result is deterministic (the later file wins).

    Returns:
        dict: merged top-level mappings of all metadata files; empty when no
        metadata file is found.
    """
    parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    pattern = os.path.join(parent_dir, 'Collaterals', '*_kernel_metadata.yaml')

    kernel_dict = {}
    # glob() returns matches in arbitrary, filesystem-dependent order; sorting
    # makes the override order of duplicate keys reproducible across machines.
    for file in sorted(glob.glob(pattern)):
        with open(file, encoding="utf-8") as f:
            # safe_load returns None for an empty document; treat it as no entries.
            kernel_dict.update(yaml.safe_load(f) or {})
    return kernel_dict

def get_kernel_size(yaml_file_name):
    """Load a kernel PM-size table from ``Collaterals/<yaml_file_name>``.

    The ``Collaterals`` directory is resolved three directory levels above this
    file (the same location ``get_kernel_dict`` uses).

    Args:
        yaml_file_name: file name inside ``Collaterals``, e.g. 'kernels_size.yaml'.

    Returns:
        dict: a new dict holding the file's top-level mapping (empty if the file
        contains no YAML document).

    Raises:
        OSError: the file cannot be opened.
    """
    parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    yaml_file = os.path.join(parent_dir, 'Collaterals', yaml_file_name)
    with open(yaml_file, encoding="utf-8") as f:
        loaded = yaml.safe_load(f)
    # An empty YAML document loads as None; normalise it to an empty table.
    return dict(loaded or {})

def calc_kernel_size(kernel, kernel_size_tree, current_subkernels_list):
    """Return the PM size ``kernel`` would add, not counting code already present.

    ``kernel_size_tree`` is a node of the form
    ``{'size': <number>, 'subkernels': {name: node, ...} or empty/None}``.
    If ``kernel`` is already in ``current_subkernels_list`` it costs nothing, and
    its subtree is not visited. Otherwise the cost is its own ``size`` plus the
    recursive cost of each of its subkernels.
    """
    if kernel in current_subkernels_list:
        return 0

    own_size = kernel_size_tree['size']
    children = kernel_size_tree['subkernels'] or {}
    # sum(..., start) adds left to right starting from own_size, matching a
    # running total initialised with the kernel's own size.
    return sum(
        (calc_kernel_size(child_name, child_tree, current_subkernels_list)
         for child_name, child_tree in children.items()),
        own_size,
    )

def get_kernel_calls(kernel, kernel_size_tree):
    """Return ``kernel`` followed by every subkernel in its size tree, depth-first.

    Names are listed in pre-order: a kernel before its subkernels, and siblings
    in the order of the ``'subkernels'`` mapping.
    """
    calls = []
    # Explicit stack of (name, tree) pairs still to visit.
    pending = [(kernel, kernel_size_tree)]
    while pending:
        name, tree = pending.pop()
        calls.append(name)
        children = tree['subkernels']
        if children:
            # Push in reverse so the first child is popped (visited) next.
            pending.extend(reversed(list(children.items())))
    return calls

def dataflow_op_kernel_names(tilings_json_path):
    """Map ONNX node names to the kernel names listed for them in a tilings JSON.

    Every top-level value of the JSON object that is itself an object is read
    for its ``kernel_names`` list and its ``layer_info.nodenames`` list. Each
    name in ``nodenames`` is mapped to that entry's ``kernel_names``. Entries are
    skipped when they have no kernel names, no node names, or a missing, null or
    non-object ``layer_info``.

    Args:
        tilings_json_path: path to the tilings JSON file.

    Returns:
        dict: node name -> kernel name list. Node names from the same entry share
        the same list object.
    """
    with open(tilings_json_path, "r", encoding="utf-8") as f:
        tilings_data = json.load(f)

    node_to_kernels = {}
    for op_data in tilings_data.values():
        if not isinstance(op_data, dict):
            continue

        kernel_names = op_data.get("kernel_names", [])

        # layer_info may be absent or null; only an object can carry nodenames.
        layer_info = op_data.get("layer_info")
        if not isinstance(layer_info, dict):
            continue
        nodenames = layer_info.get("nodenames", [])

        if kernel_names and nodenames:
            for node_name in nodenames:
                node_to_kernels[node_name] = kernel_names

    return node_to_kernels

def get_subgraph_pm_size(subgraph_model, kernel_dict, kernel_size_dict, dataflow_ops_kernel_map, current_subgraph_kernels, current_subgraph_kernel_calls):
    """Estimate the PM size a subgraph would add on top of the current PM.

    Walks ``subgraph_model.graph.node``. For every op with an entry in
    ``kernel_dict``, it takes the node's kernels from ``dataflow_ops_kernel_map``
    (by node name) or otherwise from the op's metadata ``kernel_list``. It then
    sums ``calc_kernel_size`` for each kernel not already present. The input
    lists are copied and never mutated.

    Args:
        subgraph_model: loaded ONNX subgraph.
        kernel_dict: op type -> kernel metadata (``['kernel_path']['kernel_list']``).
        kernel_size_dict: kernel name -> size tree.
        dataflow_ops_kernel_map: node name -> kernel name list override.
        current_subgraph_kernels: kernels already in the current PM.
        current_subgraph_kernel_calls: kernel/subkernel names already callable
            from the current PM.

    Returns:
        The additional size, in the units of ``kernel_size_dict``.

    Raises:
        Exception: a (non-empty) kernel name has no entry in ``kernel_size_dict``.
    """
    kernel_list = current_subgraph_kernels.copy()
    kernel_call_list = current_subgraph_kernel_calls.copy()
    subgraph_pm_size = 0
    for node in subgraph_model.graph.node:
        node_op = node.op_type
        if node_op not in kernel_dict:
            continue

        if node.name in dataflow_ops_kernel_map:
            node_kernels = dataflow_ops_kernel_map[node.name]
        else:
            node_kernels = kernel_dict[node_op]['kernel_path']['kernel_list']

        for node_kernel in node_kernels:
            # Falsy entries (empty string / None) name no kernel. Skip them before
            # the size lookup so they do not trigger the "not found" error.
            if not node_kernel:
                continue
            if node_kernel not in kernel_size_dict:
                raise Exception(f"Kernel PM size not found in the metadata file: {node_kernel}")
            if node_kernel in kernel_list:
                continue
            kernel_list.append(node_kernel)
            subgraph_pm_size += calc_kernel_size(node_kernel, kernel_size_dict[node_kernel], kernel_call_list)
            kernel_call_list.extend(get_kernel_calls(node_kernel, kernel_size_dict[node_kernel]))

    return subgraph_pm_size

def _base_pm_size(kernel_size_dict):
    """Return the fixed size charged to every PM before any op kernel is added."""
    return (kernel_size_dict['_waic_main_init'] + kernel_size_dict['_main']
            + kernel_size_dict['BufferPort'] + kernel_size_dict['super_kernel_loop'])


def _collect_new_kernels(node_kernels, node_kernel_includes, kernel_size_dict,
                         current_kernels, current_includes, current_calls):
    """Compute what a node would add to the PM described by the ``current_*`` lists.

    Args:
        node_kernels: kernel names the node needs.
        node_kernel_includes: include entries the node needs.
        kernel_size_dict: kernel name -> size tree.
        current_kernels, current_includes, current_calls: contents of the PM so far.

    Returns:
        tuple (new_kernels, new_includes, new_calls, new_size): the entries not
        yet present (each listed once) and the size they add.

    Raises:
        Exception: a (non-empty) kernel name has no entry in ``kernel_size_dict``.
    """
    new_kernels = []
    new_includes = []
    new_calls = []
    new_size = 0
    for node_kernel in node_kernels:
        # Falsy entries (empty string / None) name no kernel. Skip them before the
        # size lookup so they do not trigger the "not found" error.
        if not node_kernel:
            continue
        if node_kernel not in kernel_size_dict:
            raise Exception(f"Kernel PM size not found in the metadata file: {node_kernel}")
        if node_kernel in current_kernels or node_kernel in new_kernels:
            continue
        size_tree = kernel_size_dict[node_kernel]
        new_kernels.append(node_kernel)
        # Subkernels already callable from this PM (or from earlier new kernels) are free.
        new_size += calc_kernel_size(node_kernel, size_tree, current_calls + new_calls)
        new_calls.extend(get_kernel_calls(node_kernel, size_tree))

    for include in node_kernel_includes:
        if include and include not in current_includes and include not in new_includes:
            new_includes.append(include)

    return new_kernels, new_includes, new_calls, new_size


def _select_unique_pms(all_subgraph_kernels_list):
    """Merge PMs whose kernel set is contained in another PM's kernel set.

    Returns:
        (kept_pm_ids, pm_slot): ``kept_pm_ids[k]`` is the raw PM index used as
        output PM ``k``. ``pm_slot[i]`` is the output PM index that raw PM ``i``
        is folded into (the first kept set that contains it).
    """
    kernel_sets = [frozenset(kernels) for kernels in all_subgraph_kernels_list]
    # Deduplicate while keeping first-seen order.
    unique_sets = list(dict.fromkeys(kernel_sets))
    # Keep only sets that are not a proper subset of another set.
    maximal_sets = [s for s in unique_sets if not any(s < other for other in unique_sets)]
    # Each kept set is represented by the first raw PM that has exactly that set.
    kept_pm_ids = [kernel_sets.index(s) for s in maximal_sets]

    pm_slot = {}
    for raw_id, kernel_set in enumerate(kernel_sets):
        for slot, maximal in enumerate(maximal_sets):
            if kernel_set <= maximal:
                pm_slot[raw_id] = slot
                break
    return kept_pm_ids, pm_slot


def partition_graph(model, model_path, subgraphs_folder_path, tiling_json_file, fast_pm_enable):
    """Greedily pack subgraph kernels into PMs and write the PM kernel-list JSON.

    Subgraph ``.onnx`` files in ``subgraphs_folder_path`` are visited in the order
    of the last two integers in their filenames. Kernels are added to the current
    PM until its estimated size would exceed the threshold; then a new PM is
    started. PMs whose kernel set is a subset of another PM's set are then folded
    into it. The result is written to
    ``<model_path without extension>_IR_kernel_list.json``.

    Args:
        model: loaded ONNX model. It selects the size table (conv vs. no conv),
            and any existing ``pm_id`` attributes on its nodes are remapped.
        model_path: path of the model; determines the output JSON file name.
        subgraphs_folder_path: folder holding the subgraph ``.onnx`` files.
        tiling_json_file: optional tilings JSON with per-node kernel lists.
        fast_pm_enable: written inverted as ``disable_fast_pm`` in the JSON.

    Raises:
        Exception: a kernel has no entry in the PM size metadata.
    """
    kernel_dict = get_kernel_dict()
    dataflow_ops_kernel_map = dataflow_op_kernel_names(tiling_json_file) if tiling_json_file else {}

    # The size table depends on whether any op of the model uses the conv kernel.
    conv_kernel_flag = any(
        'run_conv_a16w8_qdq' in kernel_dict[node.op_type]['kernel_path']['kernel_list']
        for node in model.graph.node
        if node.op_type in kernel_dict
    )
    if conv_kernel_flag:
        kernel_size_dict = get_kernel_size('kernels_size.yaml')
    else:
        kernel_size_dict = get_kernel_size('kernels_size_without_conv.yaml')

    # PMs are filled up to 16384 minus a 400 margin (presumably headroom for
    # code the size tables do not account for).
    unknown_pm_size = 400
    threshold = 16384 - unknown_pm_size
    base_pm_size = _base_pm_size(kernel_size_dict)

    current_subgraph_id = 0
    all_subgraph_kernels_list = []
    all_subgraph_kernel_includes_list = []
    current_subgraph_kernels = []
    current_subgraph_kernel_includes = []
    current_subgraph_kernel_calls = []
    current_subgraph_size = base_pm_size
    gpn_nodes_in_model = False

    subgraphs = [f for f in os.listdir(subgraphs_folder_path) if f.endswith(".onnx")]
    ordered_subgraphs = sorted(subgraphs, key=lambda f: tuple(map(int, re.findall(r'\d+', f)[-2:])))

    for subgraph_path in ordered_subgraphs:
        subgraph_model = onnx.load(os.path.join(subgraphs_folder_path, subgraph_path))

        # If the whole subgraph would not fit in the current non-empty PM, start
        # a new PM before placing any of its nodes.
        next_subgraph_pm_size = get_subgraph_pm_size(subgraph_model, kernel_dict, kernel_size_dict, dataflow_ops_kernel_map, current_subgraph_kernels, current_subgraph_kernel_calls)
        if current_subgraph_size + next_subgraph_pm_size > threshold and current_subgraph_kernels:
            all_subgraph_kernels_list.append(current_subgraph_kernels)
            all_subgraph_kernel_includes_list.append(current_subgraph_kernel_includes)
            current_subgraph_id += 1
            current_subgraph_kernels = []
            current_subgraph_kernel_calls = []
            current_subgraph_kernel_includes = []
            current_subgraph_size = base_pm_size

        for node in subgraph_model.graph.node:
            node_op = node.op_type
            if 'GroupNormalization' in node_op:
                gpn_nodes_in_model = True

            if node_op not in kernel_dict:
                # Kernel-less noop/runtime/dataflow ops still record their PM.
                if "noop" in node_op or "runtime" in node_op or node_op.split("_")[0] in DATAFLOW_OPS:
                    node.attribute.append(make_attribute("pm_id", current_subgraph_id))
                continue

            if node.name in dataflow_ops_kernel_map:
                node_kernels = dataflow_ops_kernel_map[node.name]
            else:
                node_kernels = kernel_dict[node_op]['kernel_path']['kernel_list']
            node_kernel_includes = kernel_dict[node_op]['kernel_path']['kernel_include']

            new_kernels, new_kernel_includes, new_kernel_calls, new_kernels_size = _collect_new_kernels(
                node_kernels, node_kernel_includes, kernel_size_dict,
                current_subgraph_kernels, current_subgraph_kernel_includes, current_subgraph_kernel_calls)

            # Start a new PM when this node's kernels no longer fit. Like the
            # per-subgraph check above, an empty PM is never closed: a node too
            # large for a fresh PM stays in it rather than leaving an empty PM.
            if current_subgraph_size + new_kernels_size > threshold and current_subgraph_kernels:
                all_subgraph_kernels_list.append(current_subgraph_kernels)
                all_subgraph_kernel_includes_list.append(current_subgraph_kernel_includes)
                current_subgraph_id += 1
                current_subgraph_kernels = []
                current_subgraph_kernel_calls = []
                current_subgraph_kernel_includes = []
                current_subgraph_size = base_pm_size
                new_kernels, new_kernel_includes, new_kernel_calls, new_kernels_size = _collect_new_kernels(
                    node_kernels, node_kernel_includes, kernel_size_dict, [], [], [])

            current_subgraph_kernels.extend(new_kernels)
            current_subgraph_kernel_calls.extend(new_kernel_calls)
            current_subgraph_kernel_includes.extend(new_kernel_includes)
            current_subgraph_size += new_kernels_size
            node.attribute.append(make_attribute("pm_id", current_subgraph_id))

    all_subgraph_kernels_list.append(current_subgraph_kernels)
    all_subgraph_kernel_includes_list.append(current_subgraph_kernel_includes)

    kept_pm_ids, pm_slot = _select_unique_pms(all_subgraph_kernels_list)

    # NOTE(review): the pm_id attributes above are appended to nodes of the
    # per-file subgraph models, which are not saved here. This loop only remaps
    # pm_id attributes already present on ``model``'s own nodes. Confirm whether
    # the subgraph assignments were meant to be copied onto ``model``.
    for node in model.graph.node:
        for attr in node.attribute:
            if attr.name == "pm_id":
                attr.i = pm_slot[attr.i]

    json_file = {}
    for slot, raw_pm_id in enumerate(kept_pm_ids):
        kernel_names = {}
        for s in all_subgraph_kernels_list[raw_pm_id]:
            try:
                kernel_names[s] = kernel_func_list.index(s)
            except ValueError:
                print(f"Error: '{s}' not found in the kernel func list!")
        json_file["pm_" + str(slot)] = {
            # Kernels ordered by their position in kernel_func_list.
            'kernel_list': dict(sorted(kernel_names.items(), key=lambda item: item[1])),
            'kernel_include': all_subgraph_kernel_includes_list[raw_pm_id],
        }

    json_file["group_norm_in_model"] = gpn_nodes_in_model
    json_file["disable_fast_pm"] = not fast_pm_enable

    # splitext handles any extension; identical to the old [:-5] for ".onnx".
    json_file_name = os.path.splitext(model_path)[0] + '_IR_kernel_list.json'
    with open(json_file_name, "w", encoding="utf-8") as outfile:
        json.dump(json_file, outfile, indent=4)
    
def main(args):
    """Load the model named in ``args`` and run ``partition_graph`` on it.

    Args:
        args: mapping with the keys 'model_path', 'load_data', 'subgraphs_path'
            and 'tiling_json' (for example ``vars()`` of the argparse namespace).
            'load_data' is converted with ``int`` and passed to onnx as
            ``load_external_data``.
    """
    path_to_model = args['model_path']
    external_data_flag = int(args['load_data'])
    subgraph_dir = args['subgraphs_path']
    tiling_file = args['tiling_json']

    loaded_model = onnx.load_model(path_to_model, load_external_data=external_data_flag)
    partition_graph(loaded_model, path_to_model, subgraph_dir, tiling_file, fast_pm_enable=True)

if __name__ == "__main__":
    # Command-line entry point: parse arguments, validate the required ones, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--debug", help="Print lots of debugging statements", action="store_const", dest="loglevel", const=logging.DEBUG)
    parser.add_argument("-mp", "--model_path", help="path to fused model. Required Field")
    # main() converts this value with int() and passes it as onnx's
    # load_external_data, so it is a 0/1 flag, not a path.
    parser.add_argument("-ld", "--load_data", help="load the model's external data, for large models (0 or 1). Optional Field. Default value = 0", default="0")
    parser.add_argument("-sp", "--subgraphs_path", help="path to subgraphs folder. Required Field")
    parser.add_argument('-tj', "--tiling_json", help="path to tiling json file. Required Field")

    args = parser.parse_args()
    if not args.model_path:
        parser.error("Please pass path/to/fused/model using -mp or --model_path flags.\npython3 graph_partitioning.py --help\n\t\t\tfor further info.")
    logging.basicConfig(level=args.loglevel)
    logging.debug("Debug mode is enabled!")
    if not args.tiling_json:
        parser.error("Please pass path/to/tiling/json using -tj or --tiling_json flags.\npython3 graph_partitioning.py --help\n\t\t\tfor further info.")
    if not args.subgraphs_path:
        parser.error("Please pass path/to/subgraphs/folder using -sp or --subgraphs_path flags.\npython3 graph_partitioning.py --help\n\t\t\tfor further info.")
    main(vars(args))

