#import pickle
import os
import sys
from collections import OrderedDict
import subprocess
import argparse
import textwrap
import numpy as np
import onnx
import onnxruntime as ort

#REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
#l1_fusion_path = os.path.join(REPO_ROOT, "OGOAT", "src", "L1_fusion")
#sys.path.append(l1_fusion_path)
#sys.path.append(REPO_ROOT)

from extract_channels import extract_ifm_ofm_channels, extract_const_channels
from utils_ort import (
    append_outputs_to_model,
    append_outputs_to_fused_model,
    save_ifm_ofm,
    save_wgt_bias,
    save_wgt_bias_from_pickle
)

from config import (
    get_fused_model,
    get_ir_json,
    get_node_json,
    get_input_for_model,
    get_outs_from_graph,
    get_input_nodes,
    results_dir,
)

from onnx import numpy_helper

from datetime import datetime
from sys import exit

def remove_tmp_files():
    """Delete the temporary ONNX files this script leaves in the working directory.

    Covers ``new_mod_fused.onnx`` and ``new_model.onnx`` together with their
    ``.onnx.data`` external-data companions. Files that do not exist are
    skipped silently.
    """
    for stem in ('new_mod_fused', 'new_model'):
        for suffix in ('.onnx', '.onnx.data'):
            tmp_path = stem + suffix
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

def run_ortsession(so, model, ir_json, node_json,
                   nodes = "unique", modified = False):
    """Add extra graph outputs to ``model``, save it, and open an ORT session on it.

    Args:
        so: onnxruntime session options passed to ``InferenceSession``.
        model: loaded ONNX model; it is modified in place by the
            ``append_outputs_*`` helpers before being saved.
        ir_json: forwarded to the channel-extraction helpers (only used when
            ``modified`` is true).
        node_json: forwarded to ``extract_ifm_ofm_channels`` (only used when
            ``modified`` is true).
        nodes: forwarded as ``all_nodes`` to ``extract_ifm_ofm_channels``.
        modified: when true, the model is treated as the fused model and the
            extracted ifm/ofm and const channels are appended as outputs;
            otherwise ``append_outputs_to_model`` is used.

    Returns:
        Tuple ``(session, output_names, ifm_ofm_dict, weights_qdq_dict)``.
        The last two are empty lists when ``modified`` is false.
    """
    if modified:
        ifm_ofm_dict = extract_ifm_ofm_channels(ir_json, node_json, all_nodes = nodes)
        weights_qdq_dict = extract_const_channels(ir_json)
        for extra_outputs in (ifm_ofm_dict, weights_qdq_dict):
            append_outputs_to_fused_model(model, extra_outputs)
    else:
        ifm_ofm_dict, weights_qdq_dict = [], []
        append_outputs_to_model(model)

    # The model is written to the working directory with all tensors stored
    # in a single external-data file next to it.
    onnx_path = "new_model" + ".onnx"
    onnx.save_model(
        model,
        onnx_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="new_model" + '.onnx.data',
    )
    session = ort.InferenceSession(onnx_path, so, providers=["CPUExecutionProvider"])
    output_names = [graph_out.name for graph_out in session.get_outputs()]

    return session, output_names, ifm_ofm_dict, weights_qdq_dict

def save_and_get_intermediate_outputs(model, outputs, ort_outs, ifm_ofm_dict, prefix, output_dir):
    """Save session outputs under ``<cwd>/<output_dir>/DataGen/Activations``.

    ``outputs`` (names) and ``ort_outs`` (values) are paired in order into an
    ``OrderedDict`` and handed to ``save_ifm_ofm`` together with ``model``,
    ``ifm_ofm_dict`` and ``prefix``. The target directory is created if it
    does not exist. Despite the name, nothing is returned.
    """
    named_outs = OrderedDict(zip(outputs, ort_outs))
    activations_dir = os.path.join(os.getcwd(), output_dir, 'DataGen', 'Activations')
    os.makedirs(activations_dir, exist_ok=True)
    save_ifm_ofm(model, ifm_ofm_dict, named_outs, activations_dir, prefix)

def get_network_l2_norm_msft(ort_outs_mod, msft_out):
    """Return one L2 distance per entry of ``msft_out``.

    Entry ``i`` is the norm of ``ort_outs_mod[i] - msft_out[i]`` after both
    arrays are cast to float32. The number of results is ``len(msft_out)``;
    ``ort_outs_mod`` must have at least that many entries.
    """
    return [
        np.linalg.norm(
            ort_outs_mod[pos].astype(np.float32) - reference.astype(np.float32)
        )
        for pos, reference in enumerate(msft_out)
    ]


def get_network_l2_norm_orig(ort_outs_mod, ort_outs_orig, count = 1):
    """Return the L2 distances between the first ``count`` output pairs.

    Entry ``i`` is the norm of ``ort_outs_mod[i] - ort_outs_orig[i]`` with both
    arrays cast to float32. Both sequences must hold at least ``count`` items.
    """
    def _distance(pos):
        delta = (ort_outs_mod[pos].astype(np.float32)
                 - ort_outs_orig[pos].astype(np.float32))
        return np.linalg.norm(delta)

    return list(map(_distance, range(count)))


def ort_single_input(data_list, args, so, output_dir):
    """Run the original model on one data point and compare it with the reference outputs.

    Args:
        data_list: data for one data point. Entries from index 1000 onward are
            used as the reference ("MSFT") outputs.
            NOTE(review): the 1000 offset presumably matches the layout
            produced by ``get_input_for_model`` -- confirm against config.py.
        args: option dictionary; the keys ``model_name``, ``edges`` and
            ``data_dump`` are read, plus ``out_dir`` when present.
        so: unused here (``generate_ORT_data`` builds its own session
            options); kept so existing callers keep working.
        output_dir: output directory used only when ``args`` has no
            ``out_dir`` entry; ``args['out_dir']`` takes precedence, as it
            always did.

    Returns:
        Tuple ``(l2_norm_msft_against_orig, l2_norm_against_msft,
        l2_norm_against_orig)``. The first item is a list of L2 norms between
        the reference outputs and the original model's outputs, or ``None``
        when the original model was not run (``data_dump`` contains neither
        ``all`` nor ``ort``) or produced no outputs. The last two items are
        always ``None`` because this function never executes the fused model.
    """
    # Offset of the first reference output inside data_list (see docstring).
    msft_out_offset = 1000

    orig_model_name = args['model_name']
    output_dir = args.get('out_dir', output_dir)
    ir_json = get_ir_json(orig_model_name, output_dir)
    node_json = get_node_json(orig_model_name, output_dir)
    msft_out = data_list[msft_out_offset:]
    edges = args['edges']
    dump_kinds = args['data_dump'].split(',')

    ort_outs_orig = None
    if "all" in dump_kinds or "ort" in dump_kinds:
        ort_outs_orig = generate_ORT_data(orig_model_name,
                                          node_json, ir_json,
                                          data_list, output_dir, nodes = edges, save = True)

    l2_norm_msft_against_orig = None
    if ort_outs_orig:
        l2_norm_msft_against_orig = get_network_l2_norm_orig(msft_out,
                                                             ort_outs_orig,
                                                             len(msft_out))

    # No fused-model outputs exist in this flow, so both fused-model
    # comparisons are reported as None.
    l2_norm_against_msft = None
    l2_norm_against_orig = None

    return l2_norm_msft_against_orig, l2_norm_against_msft, l2_norm_against_orig


def save_results(partial, result_lines, model_name, output_dir):
    """Write ``result_lines`` to a timestamped text file inside ``output_dir``.

    The file is named ``<model>_results_<YYYYmmdd_HHMMSS>.txt`` where
    ``<model>`` is the file name of ``model_name`` without its directory and
    extension; the name is prefixed with ``partial_`` when ``partial`` is
    true. ``output_dir`` is created if it does not exist.

    Args:
        partial: whether the run was interrupted before completion.
        result_lines: list of strings, written one per line.
        model_name: path of the model the results belong to.
        output_dir: directory that receives the results file.
    """
    # Use only the file stem: keeping the directory part of model_name would
    # either escape output_dir (absolute paths make os.path.join drop it) or
    # yield a non-existent "partial_<dir>/..." path.
    model_base = os.path.splitext(os.path.basename(model_name))[0]
    formatted_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
    res_filename = f"{model_base}_results_{formatted_datetime}.txt"
    if partial:
        res_filename = f"partial_{res_filename}"
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, res_filename), "w") as f:
        f.write("\n".join(result_lines))
    print(f"Results saved to {res_filename}")
    print("Total number of files processed: ", len(result_lines))


def generate_ORT_data(orig_model_name,
                      node_json,
                      ir_json,
                      data_list, output_dir,
                      nodes = "unique",
                      save = True):
    """Run the original ONNX model on ``data_list`` and optionally save its outputs.

    Args:
        orig_model_name: path of the original ONNX model to load.
        node_json: forwarded to ``run_ortsession``.
        ir_json: forwarded to ``run_ortsession``.
        data_list: input data handed to ``get_outs_from_graph``.
        output_dir: directory passed to ``save_and_get_intermediate_outputs``.
        nodes: forwarded to ``run_ortsession``.
        save: when true, the outputs are saved with the ``ort`` prefix.

    Returns:
        The outputs returned by ``get_outs_from_graph``, or ``None`` when
        ``data_list`` is ``None``.
    """
    if data_list is None:
        # Without data there is nothing to run; stop here instead of printing
        # the message and then carrying on into the session with no inputs.
        print("Invalid: No data")
        return None

    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    model_orig = onnx.load(orig_model_name)  # Original model

    print("Run orig model...")
    ort_session_orig, outputs_orig, ifm_ofm_dict_orig, _ = \
            run_ortsession(so, model_orig, ir_json, node_json, nodes, False)
    input_nodes = get_input_nodes(model_orig)
    ort_outs_orig = get_outs_from_graph(ort_session_orig,
                                        outputs_orig,
                                        data_list,
                                        input_nodes)
    if save:
        save_and_get_intermediate_outputs(
            model_orig, outputs_orig, ort_outs_orig, ifm_ofm_dict_orig, "ort", output_dir
        )
    # Drop the session reference before returning.
    del ort_session_orig

    return ort_outs_orig


def construct_initializer_dict(model):
    """Map every initializer name in ``model.graph`` to its numpy array.

    When several initializers share a name, the first one encountered is kept.
    """
    arrays_by_name = {}
    for tensor in model.graph.initializer:
        if tensor.name in arrays_by_name:
            continue
        arrays_by_name[tensor.name] = numpy_helper.to_array(tensor)
    return arrays_by_name

def _print_l2_norms(l2_norm_msft_against_orig, l2_norm_against_msft, l2_norm_against_orig):
    """Print the three L2-norm results returned by ``ort_single_input``."""
    print(
        "Orig Model L2 Norm with MSFT is:",
        l2_norm_msft_against_orig,
    )
    print(
        "Fused Model L2 Norm with MSFT is:",
        l2_norm_against_msft,
    )
    print(
        "Fused Model L2 Norm with Orig Model is:",
        l2_norm_against_orig,
    )


def main(args):
    """Dump constants and/or run the original model, depending on ``args['data_dump']``.

    Args:
        args: option dictionary (``vars()`` of the parsed command line). Keys
            read here: ``model_name``, ``idx``, ``all``, ``out_dir`` and
            ``data_dump``; the whole dictionary is also forwarded to
            ``ort_single_input``.

    Behaviour by ``data_dump`` entry (comma separated):
        * ``wgt`` or ``all``: load the fused model, collect its initializers
          and save the constants under ``<cwd>/<out_dir>/DataGen/Consts``.
        * ``ort`` or ``all``: run the original model, either on data point
          ``idx`` or, when ``all`` is truthy, on every data point (one per
          entry of the ``msft_output`` folder next to the model), writing a
          results file at the end.

    In the all-data-points mode an error or Ctrl-C saves the partial results
    and terminates the process (status 1 on error, 0 on Ctrl-C).
    """
    remove_tmp_files()
    model_name = args['model_name']  # Model to test
    data_idx = int(args['idx'])  # Data point number
    runall_data = args['all']  # Run against all data points
    output_dir = args["out_dir"]
    dump_kinds = args['data_dump'].split(',')
    dump_everything = "all" in dump_kinds

    if dump_everything or "wgt" in dump_kinds:
        fused_model_name = get_fused_model(model_name, output_dir)
        ir_json = get_ir_json(model_name, output_dir)
        weights_qdq_dict = extract_const_channels(ir_json)
        model_mod = onnx.load(fused_model_name)  # Modified model from config
        ini_dict = construct_initializer_dict(model_mod)
        consts_dir = os.path.join(os.getcwd(), output_dir, 'DataGen', 'Consts')
        os.makedirs(consts_dir, exist_ok=True)
        save_wgt_bias_from_pickle(weights_qdq_dict, ini_dict, consts_dir)
        print("Saved constants without running fused model")

    if not (dump_everything or "ort" in dump_kinds):
        return

    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    if not runall_data:
        data_list, _ = get_input_for_model(model_name, data_idx)
        _print_l2_norms(*ort_single_input(data_list, args, so, output_dir))
        remove_tmp_files()
        return

    print("Running against all data points for Model", model_name)
    data_folder = os.path.dirname(model_name)
    num_data_points = len(os.listdir(os.path.join(data_folder, "msft_output")))
    result_lines = []
    # Running sum of the norms against the reference outputs. It becomes a
    # numpy array (one slot per model output) once the first value is added.
    sum_l2_norm_against_msft = 0.0
    count = 0
    for idx in range(num_data_points):
        try:
            data_list, file_name_list = get_input_for_model(model_name, idx)
            l2_norm_msft_against_orig, l2_norm_against_msft, \
                    l2_norm_against_orig = ort_single_input(data_list, args, so, output_dir)
            _print_l2_norms(l2_norm_msft_against_orig,
                            l2_norm_against_msft,
                            l2_norm_against_orig)
            # The norm may be None (fused model not run) or a list of
            # per-output norms; adding either to a plain float raised
            # TypeError and aborted the run on the first data point.
            if l2_norm_against_msft is not None:
                sum_l2_norm_against_msft = sum_l2_norm_against_msft + np.asarray(
                    l2_norm_against_msft, dtype=np.float64
                )
                count += 1
                print(
                    "Average L2 Norm against MSFT so far: ",
                    sum_l2_norm_against_msft / count,
                )
            dpl = " ".join(file_name_list)
            # Build the line by concatenation; str.join() would interleave
            # the previous text between the characters of the new text.
            result_lines.append(
                f"Data point: {dpl} (idx {idx})"
                f", Orig Model L2 Norm against MSFT output: {l2_norm_msft_against_orig}"
                f", L2 Norm against MSFT output: {l2_norm_against_msft}"
                f", L2 Norm against Orig Model output: {l2_norm_against_orig}"
            )
        except KeyboardInterrupt:
            print("Script Cancelled. Partial results saved to txt file")
            save_results(True, result_lines, model_name, output_dir)
            sys.exit(0)
        except Exception as e:
            print(
                f"Script Failed. An unexpected error occurred: {e} \nPartial results saved to txt file"
            )
            save_results(True, result_lines, model_name, output_dir)
            # Report the failure through the exit status.
            sys.exit(1)

    save_results(False, result_lines, model_name, output_dir)
    remove_tmp_files()


if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` is not usable with argparse: ``bool("False")`` is True
        because any non-empty string is truthy.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    # Command-line entry point: parse the options and hand them to main() as a dict.
    parser = argparse.ArgumentParser(description="run_ort.py, run the original and fused graph and dump the specified data",
                                  usage='use "%(prog)s --help" for more info',
                                  formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model_name", required=True)
    parser.add_argument(
        "--ld", type=int, help="load data for large models", default=0, required=False
    )
    parser.add_argument(
        "--idx", type=int, help="against datapoint number", default=0, required=False
    )
    parser.add_argument(
        "--edges",
        type=str,
        help=textwrap.dedent('''\
        edges to dump
        all - extract all ifm/ofm-s from fused graph
        unique - extract ifm/ofm from unique nodes
        fused - extract ifm/ofm from fused nodes only
        fused_unique - extract ifm/ofm from unique
                       fused nodes only
        <text file path> - path to a text file with
                           specified channel names
                           to be extracted'''),
        default="all", required=False
    )
    # Accepts "--all True" / "--all False" as before (now parsed correctly)
    # and also a bare "--all", which means True.
    parser.add_argument(
        "--all",
        type=_str2bool,
        nargs="?",
        const=True,
        help="run against all datapoints",
        default=False,
        required=False,
    )
    parser.add_argument(
        "--data_dump",
        type=str,
        help=textwrap.dedent('''\
        data to be dumped,
        all   - dump all specified edges and all consts
        ort   - dump all spcified edges from original model
        wgt - dump all const data
        multiple options can be specified separated by ',' comma'''),
        default="all",
        required=False
    )
    # Optional so that the documented default from config.py can take effect;
    # with required=True the default could never be used.
    parser.add_argument(
        "--out_dir",
        type=str,
        help=textwrap.dedent('''\
        output directory, default is from config.py'''),
        default=results_dir,
        required=False
    )
    args = parser.parse_args()

    # required=True still lets an empty string through, so reject it here.
    if not args.model_name:
        parser.error("Please pass model path with --model_name flags.")

    main(vars(args))
