import sys
import os
util_path = (os.path.dirname(os.path.abspath(__file__))+"/../src/L1_fusion/L1_utils/")
sys.path.append(util_path)
import numpy as np
import pickle
from ml_dtypes import bfloat16
import json
import onnx
from onnx import numpy_helper, TensorProto
from onnx.checker import check_model
import onnxruntime as ort

np.dtype("bfloat16")

from onnx.helper import (
    make_model, make_node, make_graph,
    make_tensor_value_info,
    make_empty_tensor_value_info,
    np_dtype_to_tensor_dtype,
    make_tensor_type_proto,
    make_value_info,
    make_tensor)

from utils import construct_dict

# dtype name -> ONNX TensorProto element type. Add entries here to support more names.
_DTYPE_NAME_TO_TENSORPROTO = {
    'uint16': TensorProto.UINT16,
    'float32': TensorProto.FLOAT,
    'float16': TensorProto.FLOAT16,
    'bfloat16': TensorProto.BFLOAT16,
    'float64': TensorProto.DOUBLE,
    'int8': TensorProto.INT8,
    'uint8': TensorProto.UINT8,
    'int16': TensorProto.INT16,
    'int32': TensorProto.INT32,
    'int64': TensorProto.INT64,
    'bool': TensorProto.BOOL,
}


def dtype_to_TensorProto(type):
    """Map a dtype name to the matching ONNX ``TensorProto`` element type.

    Args:
        type: dtype name such as ``'float32'`` or ``'uint16'``, or a
            ``numpy.dtype`` (normalised to its ``.name``). The parameter name
            shadows the ``type`` builtin; it is kept for keyword callers.

    Returns:
        The ``TensorProto`` data-type enum value.

    Raises:
        AssertionError: if ``type`` is not a supported dtype. It is raised
            explicitly rather than via ``assert`` so the check still fires
            under ``python -O``; the exception type matches earlier versions.
    """
    if isinstance(type, np.dtype):
        # np.dtype('uint16') == 'uint16' is True, so dtype objects were accepted
        # by the old comparison chain; keep accepting them by their name.
        type = type.name
    try:
        return _DTYPE_NAME_TO_TENSORPROTO[type]
    except (KeyError, TypeError):  # TypeError: unhashable argument
        raise AssertionError('undefined dtype to TensorProto') from None

def flatten_list(list2d):
    """Concatenate the sub-sequences of *list2d* into one new flat list.

    Each element of *list2d* is iterated in order, and its items are appended
    to the result.
    """
    return [item for sublist in list2d for item in sublist]

def mod_dict_json(dict1):
    """Append a dotted alias of each key to that key's value list, in place.

    Every ``'/'`` in a key is replaced with ``'.'``, and the resulting string is
    appended to ``dict1[key]``. The same (now mutated) dict is returned, so the
    call can be used inline, e.g. as the argument to ``json.dump``.
    """
    for name, params in dict1.items():
        params.append(name.replace("/", "."))
    return dict1

def Reverse(lst):
    """Return a reversed copy of *lst*; the input is left untouched.

    Slicing preserves the sequence type: a list gives a list, a tuple a tuple.
    """
    return lst[::-1]

from extract_graph_config import out_folder, model_name, out_node_names, in_node_names, new_model_name

# Create the output directory; exist_ok avoids the check-then-create race.
os.makedirs(out_folder, exist_ok=True)

model = onnx.load(model_name)

# Index the source model with the project helper construct_dict. As used below:
# ini_dict maps initializer name -> initializer, nodes_dict maps node key ->
# NodeProto, in_nodes_dict maps node key -> predecessor node keys, and
# value_info_dict maps tensor name -> ValueInfo. out_nodes_dict, input_names
# and output_names are not used by this script.
ini_dict, nodes_dict, in_nodes_dict, out_nodes_dict, value_info_dict, input_names, output_names = construct_dict(model)

# Keep only the model-level metadata needed to rebuild the sliced model.
ir_version = model.ir_version
opset_imports = model.opset_import
producer_name = model.producer_name
producer_ver = model.producer_version
metadata_props = model.metadata_props[:]
graph_name = model.graph.name

# Drop our name for the full model to save memory.
# NOTE(review): opset_imports, metadata_props and the dicts above still reference
# protobuf sub-messages of `model`, which may keep its storage alive — confirm
# the actual memory saving.
del model

# Accumulators for the extracted sub-graph, filled in by the steps below.
g_inputs = []        # ValueInfos for the new graph inputs
g_output = []        # ValueInfos for the new graph outputs
g_nodes = {}         # node key -> NodeProto kept in the sub-graph
n_s_nodes = []       # next frontier of the backward search
g_node_inputs = []   # every tensor name consumed by a kept node
g_node_outputs = {}  # tensor name -> name of the kept node producing it

# Locate the consumers of every input trace and the producers of every output trace.
input_nodes = {}   # input trace tensor name -> keys of nodes consuming it
output_nodes = {}  # output trace tensor name -> keys of nodes producing it
input_traces = list(in_node_names.keys())
output_traces = list(out_node_names.keys())
for node in nodes_dict:
    node_proto = nodes_dict[node]
    for x in input_traces:
        if x in node_proto.input:
            input_nodes.setdefault(x, []).append(node)
    for x in output_traces:
        if x in node_proto.output:
            output_nodes.setdefault(x, []).append(node)

# Validate the configuration explicitly rather than with `assert`, so a bad
# trace name is still reported when Python runs with -O.
missing_inputs = set(input_traces) - set(input_nodes)
if missing_inputs:
    raise ValueError(f"input trace {missing_inputs} not found in graph")
missing_outputs = set(output_traces) - set(output_nodes)
if missing_outputs:
    raise ValueError(f"output trace {missing_outputs} not found in graph")
    
# Backward breadth-first search starting from the producers of the output traces.
# Every node reached is kept in g_nodes (in discovery order). When a kept node
# consumes an input trace, the predecessor producing that trace is not
# enqueued, so the search does not continue past the cut.
dest_nodes = flatten_list(list(output_nodes.values()))
src_nodes = flatten_list(list(input_nodes.values()))
src_node_set = set(src_nodes)
while dest_nodes:  # run until the frontier is exhausted; no fixed depth cap
    for s_elem in dest_nodes:
        if s_elem in g_nodes:
            continue  # already kept via another path
        g_nodes[s_elem] = nodes_dict[s_elem]
        n_s_nodes.extend(in_nodes_dict[s_elem])

        if s_elem in src_node_set:
            # A node may consume several input traces; terminate all of them,
            # not just one, or the search leaks past the remaining cuts.
            terminate_traces = [key for key, consumers in input_nodes.items() if s_elem in consumers]
            for pred in in_nodes_dict[s_elem]:
                pred_outputs = nodes_dict[pred].output
                if any(trace in pred_outputs for trace in terminate_traces):
                    # pred was just added from in_nodes_dict[s_elem], so it is present.
                    n_s_nodes.remove(pred)
    dest_nodes = n_s_nodes
    n_s_nodes = []
# The loop only exits once the frontier is empty, so no node remains unvisited.


# Declare the new graph's outputs. A trace configured with an empty spec gets an
# untyped ValueInfo; otherwise its spec supplies (elem_type, shape).
for out_name, out_spec in out_node_names.items():
    if out_spec:
        g_output.append(make_tensor_value_info(out_name, out_spec[0], out_spec[1]))
    else:
        g_output.append(make_empty_tensor_value_info(out_name))

# Declare the new graph's inputs. Graph inputs must be typed, so every
# in_node_names entry is expected to carry (elem_type, shape).
g_inputs.extend(
    make_tensor_value_info(in_name, in_spec[0], in_spec[1])
    for in_name, in_spec in in_node_names.items()
)


def _topological_node_order(nodes):
    """Return the NodeProtos of *nodes* so producers precede their consumers.

    Args:
        nodes: dict of node key -> NodeProto, iterated in the preferred order.

    Returns:
        List of NodeProtos in which every node follows the nodes in *nodes*
        that produce its inputs. If the preferred order already satisfies this,
        it is returned unchanged. Uses an explicit stack, so deep graphs cannot
        hit the recursion limit.

    NOTE(review): tensors captured implicitly by control-flow subgraphs
    (If/Loop/Scan bodies) do not appear in node.input and are not considered.
    """
    producer = {}
    for key, node in nodes.items():
        for tensor in node.output:
            if tensor:  # '' marks an omitted optional output
                producer[tensor] = key
    seen = set()
    ordered = []
    for root in nodes:
        if root in seen:
            continue
        seen.add(root)
        stack = [(root, iter(nodes[root].input))]
        while stack:
            key, pending = stack[-1]
            for tensor in pending:
                dep = producer.get(tensor)
                if dep is not None and dep not in seen:
                    seen.add(dep)
                    stack.append((dep, iter(nodes[dep].input)))
                    break
            else:
                # All producers of this node's inputs are already placed.
                stack.pop()
                ordered.append(nodes[key])
    return ordered


# Every tensor name consumed by a kept node.
for n in g_nodes.values():
    g_node_inputs.extend(n.input)
consumed_tensors = set(g_node_inputs)

# Copy only the initializers the kept nodes read (set lookup keeps this linear).
initializers = [ini_dict[name] for name in ini_dict if name in consumed_tensors]

# Map each tensor produced by a kept node to that node's name.
for n in g_nodes.values():
    for m in n.output:
        g_node_outputs[m] = n.name

# Carry over recorded type/shape info for tensors produced inside the sub-graph.
value_info = [value_info_dict[t] for t in g_node_outputs if t in value_info_dict]

# The backward search discovers nodes roughly consumer-first, so reversing gives
# a preferred producer-first order. Reversal alone is not guaranteed to be
# topologically sorted (a node reached by paths of different lengths can land
# after its consumer), and ONNX requires topological node order, so fix it up.
g_nodes_r = _topological_node_order(dict(reversed(list(g_nodes.items()))))

graph = make_graph(g_nodes_r, graph_name, g_inputs, g_output,
                   initializer=initializers, doc_string=None, value_info=value_info)

# Rebuild a model around the extracted graph. It keeps the source model's IR
# version, opset imports, producer info and metadata properties.
nmodel = make_model(graph, ir_version=ir_version, opset_imports=opset_imports)
nmodel.producer_name = producer_name
nmodel.producer_version = producer_ver
nmodel.metadata_props.extend(metadata_props)

# os.path.join works whether or not out_folder ends with a separator; plain
# string concatenation put files next to the folder when the slash was missing.
onnx.save(nmodel, os.path.join(out_folder, new_model_name))

# Record the port parameters, each list extended with a '/'->'.' alias of its name.
# NOTE: mod_dict_json mutates in_node_names / out_node_names in place.
with open(os.path.join(out_folder, 'input_port_param.json'), "w") as outfile0:
    json.dump(mod_dict_json(in_node_names), outfile0)

with open(os.path.join(out_folder, 'output_port_param.json'), "w") as outfile1:
    json.dump(mod_dict_json(out_node_names), outfile1)
