
import os
import argparse
import json
from pathlib import Path
import ntpath
import shutil
import csv
import subprocess
import re
import tempfile

import build
from WAIC_runtime import main as run_waic_runtime_main
#from utils.run_meta_runtime import main as run_meta_runtime_main
#from update_meta_runtime import main as update_meta_runtime_main

def clear_folder(folder: str):
    """Remove the ``cut_graphs`` sub-directory of *folder*, if there is one.

    Only an entry named exactly ``cut_graphs`` that is a directory is removed;
    all other entries of *folder* are left untouched. A failure while deleting
    is printed rather than raised. Listing *folder* itself is not guarded, so
    a missing *folder* raises from ``os.listdir``.
    """
    print("Deleting aie2p artifacts!!!")
    for entry in os.listdir(folder):
        if entry != "cut_graphs":
            continue
        target = os.path.join(folder, entry)
        try:
            if os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            # Best effort: a locked/partially removed tree should not stop the run.
            print("Failed to delete %s. Reason: %s" % (target, err))

def get_op_type(n, ir_json_path):
    """Return the ``op_type`` recorded for node *n* in the IR JSON file.

    The JSON is expected to map node names to per-node dicts. An empty string
    is returned when *n* is not a key of that mapping or its entry has no
    ``op_type`` field.
    """
    with open(ir_json_path, "r", encoding="utf-8") as ir_file:
        ir_data = json.load(ir_file)
    if n not in ir_data:
        return ""
    node = ir_data[n]
    return node["op_type"] if "op_type" in node else ""

def get_parent_and_children_nodes(n, ir_json_path):
    """Return the neighbour node names of *n* if *n* is a runtime/noop op.

    When the IR entry for *n* has an ``op_type`` containing the substring
    "runtime" or "noop", its ``children_names`` followed by its
    ``parent_names`` (each only when present) are returned. For any other
    node, including one missing from the JSON, the result is an empty list.
    """
    with open(ir_json_path, "r", encoding="utf-8") as ir_file:
        ir_data = json.load(ir_file)
    if n not in ir_data:
        return []
    node = ir_data[n]
    if "op_type" not in node:
        return []
    kind = node["op_type"]
    if not ("runtime" in kind or "noop" in kind):
        return []
    print("Found runtime op in IR json")
    neighbours = []
    for field in ("children_names", "parent_names"):
        if field in node:
            neighbours.extend(node[field])
    return neighbours

def get_nodes_list(args, unique_json_path):
    """Return the node names to debug, in processing order.

    With ``args.nodelist`` set, names are read from that file, one per line.
    Surrounding whitespace is stripped, and blank lines are skipped so they do
    not turn into empty node names.

    Otherwise names are collected from every entry of the JSON at
    *unique_json_path* that has a ``"nodenames"`` field. Only the first name
    of each entry is taken when ``args.single`` is set; otherwise all of them
    are. Names contained in the JSON loaded from ``args.skipped`` (if given)
    are dropped. The result is also written to ``nodelist.txt`` next to this
    script, so it can be edited and fed back through ``--nodelist``.
    """
    if args.nodelist:
        nodelist_path = Path(args.nodelist).expanduser().resolve()
        # utf-8-sig also accepts files saved with a BOM (e.g. by Windows editors).
        with open(nodelist_path, "r", encoding="utf-8-sig") as fh:
            return [name for name in (line.strip() for line in fh) if name]

    with open(unique_json_path, "r", encoding="utf-8") as f1:
        unique_data = json.load(f1)
    all_nodenames = []
    for value in unique_data.values():
        if "nodenames" not in value:
            continue
        names = value["nodenames"]
        if not names:
            # An empty list has no [0]; skip it rather than raise IndexError under --single.
            continue
        if args.single:
            all_nodenames.append(names[0])
        else:
            all_nodenames.extend(names)

    # Filter out skipped nodes if a skip list was passed.
    if args.skipped is not None:
        with open(args.skipped, "r", encoding="utf-8") as f2:
            skipped_data = json.load(f2)
        filtered_nodes = [name for name in all_nodenames if name not in skipped_data]
    else:
        filtered_nodes = list(all_nodenames)

    nodelist_path = Path(__file__).resolve().parent / "nodelist.txt"
    with open(nodelist_path, "w", encoding="utf-8") as fh:
        fh.writelines(name + "\n" for name in filtered_nodes)
    return filtered_nodes

def run_aie2p_for_each_op(args, out_dir, model_path, ir_json_path, unique_json_path, report_path, runner_exe):
    """Push each selected node through the WAIC runtime and aie_runner.

    For every node returned by ``get_nodes_list`` this:
      1. removes the previous ``cut_graphs`` output from *out_dir*;
      2. writes a temporary node list holding the node plus whatever
         ``get_parent_and_children_nodes`` returns for it;
      3. calls ``run_waic_runtime_main`` with that node list
         (``data_dump="wgt"``);
      4. runs ``run_aie_runner`` if ``<out_dir>/cut_graphs`` contains an
         ``.onnx`` file. Otherwise it appends a row to *report_path* marking
         the node as a CPU op.

    Rows are appended to *report_path*; the caller is expected to have
    written the CSV header already.
    """
    nodelist = get_nodes_list(args, unique_json_path)
    subgraph_path = os.path.join(out_dir, "cut_graphs")
    for n in nodelist:
        clear_folder(out_dir)
        fd, tmp_path = tempfile.mkstemp(text=True)
        os.close(fd)
        is_npu = False
        try:
            with open(tmp_path, "w") as tmp:
                tmp.write("\n".join([n] + get_parent_and_children_nodes(n, ir_json_path)))

            run_waic_runtime_args = argparse.Namespace(
                model_path=model_path,
                output_dir=out_dir,
                node_list=tmp_path,
                use_inmem=0,
                cpp_fe=None,
                load_data=0,
                data_dump="wgt",
                prebuilt_mladf_mha=0,
                disable_fast_pm=0,
                target="",
                exclude_nodes=None
            )
            run_waic_runtime_main(run_waic_runtime_args)
            # cut_graphs was just deleted by clear_folder; if the runtime did not
            # recreate it, listing it would raise and abort the whole loop.
            is_npu = os.path.isdir(subgraph_path) and any(
                name.endswith(".onnx") for name in os.listdir(subgraph_path)
            )
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        op_type = get_op_type(n, ir_json_path)
        if is_npu:
            run_aie_runner(out_dir, runner_exe, report_path, n, op_type)
        else:
            # Open per write instead of keeping one handle for the whole loop:
            # run_aie_runner appends to the same file through its own handle, and
            # a long-lived buffered handle would flush these rows out of order.
            with open(report_path, "a", newline="") as f:
                csv.writer(f).writerow([op_type, n, "", "", "", "", "", "missing *.onnx, treated as CPU op"])

# Metric line prefixes printed by aie_runner, in report-column order
# (max_diff, L2_norm, L2_norm per element, Error Count, Number of ops).
_RUNNER_METRIC_PREFIXES = (
    "max_diff is",
    "L2_norm is",
    "L2_norm per element is",
    "Error Count is",
    "Total number of Ops",
)


def _parse_runner_output(stdout, op, op_type):
    """Turn aie_runner stdout into report rows.

    The output is read as sections separated by lines starting with "=". The
    first non-empty line of each section is taken as the subgraph name. Later
    lines starting with a known metric prefix are collected, and each section
    yields one row per position across the five metric lists (zip, so the
    shortest list bounds the count). When *op* is not None it replaces the
    parsed subgraph name in the rows.
    """
    rows = []
    subgraph = None
    metrics = [[] for _ in _RUNNER_METRIC_PREFIXES]
    next_is_name = True

    def flush_section():
        name = op if op is not None else subgraph
        rows.extend([op_type, name, *values, ""] for values in zip(*metrics))
        for values in metrics:
            values.clear()

    for raw_line in stdout.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith("="):
            flush_section()
            next_is_name = True
            continue
        if next_is_name:
            subgraph = line
            next_is_name = False
            continue
        for values, prefix in zip(metrics, _RUNNER_METRIC_PREFIXES):
            if line.startswith(prefix):
                value = line[len(prefix):].strip()
                if prefix == "Total number of Ops":
                    # Accept "Ops : N", "Ops: N" and "Ops N"; slicing at a fixed
                    # width used to chop digits when the spacing differed.
                    value = value.lstrip(":").strip()
                values.append(value)
                break
    flush_section()
    return rows


def _append_runner_note(report_path, op_type, op, note):
    """Best-effort: append a metrics-free row carrying *note* for *op*."""
    try:
        with open(report_path, "a", newline="") as f:
            csv.writer(f).writerow([op_type, op, "", "", "", "", "", note])
    except OSError as err:
        print(err)


def run_aie_runner(out_dir, runner_exe, report_path, op, op_type):
    """Run aie_runner on ``<out_dir>/cut_graphs/config.json`` and report metrics.

    The runner is invoked as ``runner_exe config_path`` with stderr merged
    into stdout. Its output is parsed into one CSV row per reported metric
    set, and the rows are appended to *report_path*.

    Nothing is raised. Problems are printed, and when the runner or config is
    missing, the runner exits non-zero, or no metrics could be parsed, a row
    with an explanatory note is appended instead, so every op shows up in the
    report.
    """
    print("Running aie_runner...")
    try:
        if not os.path.exists(runner_exe):
            raise FileNotFoundError(f"Runner is not found: {runner_exe}")
        config_path = os.path.join(out_dir, "cut_graphs", "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"config.json file is not found: {config_path}")
        result = subprocess.run([runner_exe, config_path], check=True,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                                text=True)
        print("Runner done, collecting max_diff...")
        rows = _parse_runner_output(result.stdout, op, op_type)
        if not rows:
            _append_runner_note(report_path, op_type, op, "no metrics found in aie_runner output")
            return
        with open(report_path, "a", newline="") as f:
            csv.writer(f).writerows(rows)
        print("Reporting done.")
    except FileNotFoundError as e:
        print(e)
        _append_runner_note(report_path, op_type, op, str(e))
    except subprocess.CalledProcessError as e:
        print(e)
        # The runner's own log is the only clue to why it failed; show it.
        if e.output:
            print(e.output)
        _append_runner_note(report_path, op_type, op, f"aie_runner failed with exit code {e.returncode}")
    except Exception as e:
        print(e)

def main(args):
    """Run the AIE2P op-level debugging flow end to end.

    1. Runs ``run_waic_runtime_main`` on the whole model (``data_dump="ort"``).
    2. Writes a fresh ``report.csv`` header next to this script.
    3. Calls ``run_aie2p_for_each_op`` to append one or more rows per node.

    Outputs go to ``args.output_dir`` when given, otherwise to
    ``../WAIC_Outputs`` relative to this script's directory.
    """
    # Build the default path from components; a literal "..\\" only resolves
    # as a parent directory on Windows and becomes a weird dir name elsewhere.
    out_dir = os.path.join(os.path.dirname(__file__), "..", "WAIC_Outputs")
    runner_exe = Path(args.runner).expanduser()
    if args.output_dir is not None:
        out_dir = Path(args.output_dir).expanduser()
    os.makedirs(out_dir, exist_ok=True)
    if args.clean:
        clear_folder(out_dir)
    model_path = args.model_path
    run_waic_runtime_args = argparse.Namespace(
        model_path=model_path,
        output_dir=out_dir,
        node_list=None,
        use_inmem=0,
        cpp_fe=None,
        load_data=0,
        data_dump="ort",
        prebuilt_mladf_mha=0,
        disable_fast_pm=0,
        target="",
        exclude_nodes=None
    )
    print("Running waic_runtime to dump ort data ...")
    run_waic_runtime_main(run_waic_runtime_args)
    # ntpath.split accepts both "\\" and "/" separators on any platform.
    model_name = os.path.splitext(ntpath.split(model_path)[-1])[0]
    # NOTE(review): these IR files are presumably written by the runtime call
    # above into out_dir — confirm against WAIC_runtime.
    ir_json_path = os.path.join(out_dir, model_name + "_mod_nhwc_fused_IR.json")
    unique_json_path = os.path.join(out_dir, model_name + "_mod_nhwc_fused_IR_unique_nodes.json")
    report_path = Path(__file__).resolve().parent / "report.csv"
    with open(report_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Op Type", "Op Name", "max_diff", "L2_norm", "L2_norm per element", "Error Count", "Number of ops", "Note"])
    run_aie2p_for_each_op(args, out_dir, model_path, ir_json_path, unique_json_path, report_path, runner_exe)
    print(f"Results are reported to {report_path}")

    
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="For AIE2P flow op-level debugging",
        usage='use "%(prog)s --help" for more info',
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "-mp",
        "--model_path",
        required=True,
        help="Path to onnx model",
    )
    parser.add_argument(
        "-output", "--output_dir", help="output directory", default=None
    )
    parser.add_argument(
        "-r", "--runner", help="aie_runner exe path",
        required=True,
    )
    parser.add_argument(
        "--skipped",
        required=False,
        help="Path to skipped nodes list",
    )
    parser.add_argument(
        "--single",
        help="Select single instance for an op",
        action="store_true",
    )
    # NOTE(review): args.subgraph is not read anywhere in this script; kept so
    # existing command lines that pass --subgraph keep working.
    parser.add_argument(
        "--subgraph",
        help="Enable subgraph level reporting",
        action="store_true",
    )
    parser.add_argument(
        "--nodelist",
        required=False,
        help="Path to nodelist to use, instead of generation"
    )
    parser.add_argument(
        "--clean",
        help="Delete artifacts generated by aie2p",
        action="store_true",
    )
    args = parser.parse_args()
    main(args)
