import onnx
import onnxruntime as ort
import numpy as np
import sys
import textwrap
import subprocess
from collections import OrderedDict
from onnx import numpy_helper
import os
import argparse
import utils

from config import (
    get_fused_model,
    get_ir_json,
    get_node_json,
    get_input_for_model,
    results_dir,
)

# Build a dict of the model's initializers (weights), keyed by initializer name.
def construct_initializer_dict(model):
    """Return a dict mapping each initializer name to its array.

    Args:
        model: Object exposing ``graph.initializer``; every entry needs a
            ``name`` attribute and is converted with
            ``numpy_helper.to_array``.

    Returns:
        dict: ``{initializer name: converted array}``. When a name occurs
        more than once, the first occurrence is kept.
    """
    weights_by_name = {}
    for tensor in model.graph.initializer:
        if tensor.name in weights_by_name:
            # A repeated name never overwrites the entry stored first.
            continue
        weights_by_name[tensor.name] = numpy_helper.to_array(tensor)
    return weights_by_name
  
def collect_edges_to_dump(nodes, requested_names):
    """Resolve requested node/tensor names to the tensor names to dump.

    Each requested name is stripped of surrounding whitespace and then
    compared against every node, in order:

    * if it equals ``node.name``, the node's first output is selected and
      the scan continues with the next node;
    * if it is one of ``node.output``, the name itself is selected and the
      scan for this name stops.

    Blank names are skipped: an empty string would otherwise match every
    node whose ``name`` is empty. Each tensor name is returned at most once,
    so callers never register the same graph output twice.

    Args:
        nodes: Iterable of objects with ``name`` and ``output`` attributes
            (for example ``model.graph.node``). It is iterated once per
            requested name, so it must be re-iterable.
        requested_names: Iterable of strings, for example the lines of the
            edges text file (trailing newlines are allowed).

    Returns:
        list[str]: Selected tensor names, unique, in first-seen order.
    """
    selected = []
    seen = set()

    def _select(edge):
        # Ignore empty tensor names and anything already selected.
        if edge and edge not in seen:
            seen.add(edge)
            selected.append(edge)

    for raw_name in requested_names:
        name = raw_name.strip()
        if not name:
            continue
        for node in nodes:
            if node.name == name and node.output:
                # A node-name match selects only the node's first output.
                _select(node.output[0])
            if name in node.output:
                _select(name)
                break
    return selected


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", required=True)
    parser.add_argument(
        "--idx", type=int, help="against datapoint number",
        default=0, required=False
    )
    parser.add_argument(
        "--edges", type=str, help=textwrap.dedent('''\
        edges to dump
        <text file path> - path to a text file with
                           specified channel names
                           to be extracted'''),
        required=True
    )
    # Accepted for command-line compatibility; this script does not read it.
    parser.add_argument("--txt_output", action='store_true')

    args = parser.parse_args()

    model_name = args.model_name
    curr_dir = os.getcwd()
    model_path = os.path.abspath(model_name)

    # Run the graph as written: with optimizations enabled, intermediate
    # tensors could be fused away and no longer be available as outputs.
    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

    model_orig = onnx.load(model_path)
    first_node = [graph_input.name for graph_input in model_orig.graph.input]

    # --idx is already an int (argparse type=int); the second return value
    # (file names) is not needed here.
    data_list, _ = get_input_for_model(model_name, args.idx)

    # Dictionary which contains all weights; weights can be extracted
    # by calling ini_dict['<weights_name>'].
    ini_dict = construct_initializer_dict(model_orig)

    # The context manager closes the edges file once its lines are consumed.
    with open(args.edges, 'r') as channels_file:
        channels_file_edges = collect_edges_to_dump(
            model_orig.graph.node, channels_file)

    # Expose the selected tensors as graph outputs, skipping any that the
    # model already lists as an output.
    existing_outputs = {out.name for out in model_orig.graph.output}
    model_orig.graph.output.extend([
        onnx.ValueInfoProto(name=edge)
        for edge in channels_file_edges
        if edge not in existing_outputs
    ])

    ort_session_orig = ort.InferenceSession(
        model_orig.SerializeToString(),
        providers=["CPUExecutionProvider"],
        sess_options=so,
    )
    outputs_orig = [x.name for x in ort_session_orig.get_outputs()]
    # Graph inputs are paired positionally with the loaded data arrays.
    input_feed = dict(zip(first_node, data_list))
    ort_outs_orig = ort_session_orig.run(outputs_orig, input_feed)
    ort_outs_orig_dict = OrderedDict(zip(outputs_orig, ort_outs_orig))

    utils.save_edges(channels_file_edges, ort_outs_orig_dict, curr_dir)

    # Drop the session reference before the script exits.
    del ort_session_orig
