
import os
import argparse
import json
from pathlib import Path
import ntpath
import shutil
import csv
import subprocess
import re
import tempfile

import build_aie4
from utils.run_meta_runtime import main as run_meta_runtime_main
from update_meta_runtime import main as update_meta_runtime_main

def clear_folder(folder: str):
    """Delete aie4 build artifacts found directly inside ``folder``.

    Removed entries:
      * directories whose name starts with ``op_`` or ``fused_``
      * directories named ``model_elf`` or ``cut_graphs``
      * files whose name ends with ``_mod_nhwc_fused.onnx_alloc.json``

    Everything else is left untouched. A failure to delete one entry is
    printed and the sweep carries on with the next entry.
    """
    print("Deleting aie4 artifacts!!!")
    dir_prefixes = ("op_", "fused_")
    dir_names = ("model_elf", "cut_graphs")
    alloc_suffix = "_mod_nhwc_fused.onnx_alloc.json"

    for entry in os.listdir(folder):
        target = os.path.join(folder, entry)
        looks_like_artifact_dir = entry.startswith(dir_prefixes) or entry in dir_names
        try:
            if looks_like_artifact_dir and os.path.isdir(target):
                shutil.rmtree(target)
            elif entry.endswith(alloc_suffix) and os.path.isfile(target):
                os.unlink(target)
        except Exception as e:
            # Best effort: report and keep going.
            print("Failed to delete %s. Reason: %s" % (target, e))

def run_aie4(fe_dir, output_dir, ort_dir, model_name, model_full_path, unique_json_path, nodelist_path, is_qdq_fp16: bool = True):
    """Compile the model with aie4 and, when an ELF was produced, run run_meta_runtime.

    Input file locations are derived from ``fe_dir`` and ``model_name``
    (``<model_name>_mod_nhwc_fused.onnx``, ``..._IR.json``,
    ``<model_name>_tensor_map.json`` and ``DataGen/Consts``).
    ``build_aie4.compile_model`` is always invoked; ``run_meta_runtime_main``
    is invoked only if ``<output_dir>/model_elf/control.elf`` exists afterwards.

    Returns:
        True when ``control.elf`` exists and run_meta_runtime was executed,
        False when no ``control.elf`` was found.
    """
    fe_file = lambda suffix: os.path.join(fe_dir, model_name + suffix)
    fused_path = fe_file("_mod_nhwc_fused.onnx")
    ir_json_path = fe_file("_mod_nhwc_fused_IR.json")
    tensor_map_path = fe_file("_tensor_map.json")
    data_dir = os.path.join(fe_dir, "DataGen", "Consts")

    print("Running aie4 compile_model...")
    build_aie4.compile_model(
        fused_path,
        ir_json_path,
        data_dir,
        model_full_path,
        "cert",
        output_dir,
        read_model_data=True,
        unique_nodes_path=unique_json_path,
        tensor_map_json=tensor_map_path,
        node_list=nodelist_path,
        set_qdq_fp16=is_qdq_fp16,
    )
    print("aie4 compile_model is successfully done!")

    elf_path = os.path.join(output_dir, "model_elf", "control.elf")
    if not os.path.exists(elf_path):
        print("No control.elf, may be beacuse it is CPU op")
        return False

    # Only reached when the compile step produced an ELF.
    meta_args = argparse.Namespace(
        tiling_json=os.path.join(output_dir, model_name + "_mod_nhwc_fused.onnx_alloc.json"),
        IR_json=ir_json_path,
        hw_data_path=output_dir,
        out_dir=os.path.join(output_dir, "cut_graphs"),
        elf=elf_path,
        data_folder=ort_dir,
        tensor_map=tensor_map_path,
    )
    print("Running run_meta_runtime...")
    run_meta_runtime_main(meta_args)
    print("run_meta_runtime is successfully done!")
    return True

def get_nodes_list(args, unique_json_path, out_dir):
    """Return the node names selected by the command-line options.

    The source of the list is chosen in this order:

    1. ``args.nodelist``: names are read from that file, one per line.
    2. ``args.sglist``: subgraph names are read from that file; for each,
       ``"<name>.onnx"`` is looked up in
       ``<out_dir>/cut_graphs/context_full.json`` and the entry's
       ``"nodelist"`` (if present) is appended. Unknown subgraphs are
       reported on stdout. If ``context_full.json`` does not exist an empty
       list is returned.
    3. Otherwise: the ``"nodenames"`` of every entry in ``unique_json_path``
       are collected (only the first one per entry when ``args.single`` is
       set), names listed in the ``args.skipped`` JSON file are removed, and
       the result is written to ``nodelist.txt`` next to this script.

    Blank lines in the nodelist/sglist files are ignored.
    """
    def read_names(path):
        # Drop blank lines so they do not turn into empty node/subgraph names.
        with open(path, "r") as fh:
            return [name for name in map(str.strip, fh) if name]

    # Use the nodelist supplied by the caller.
    if args.nodelist:
        return read_names(Path(args.nodelist).expanduser().resolve())

    # Derive the nodelist from the requested subgraphs.
    if args.sglist:
        subgraph_names = read_names(Path(args.sglist).expanduser().resolve())
        context_path = os.path.join(out_dir, "cut_graphs", "context_full.json")
        if not os.path.isfile(context_path):
            print("Failed to open context_full.json: ", context_path)
            return []
        # Parse the context file once rather than once per subgraph.
        with open(context_path, "r", encoding="utf-8") as f1:
            context_full_data = json.load(f1)
        filtered_nodes = []
        for sg_name in subgraph_names:
            sg_onnx = sg_name + ".onnx"
            if sg_onnx not in context_full_data:
                print("Cannot find subgraph in context_full.json: ", sg_name)
                continue
            entry = context_full_data[sg_onnx]
            if "nodelist" in entry:
                filtered_nodes.extend(entry["nodelist"])
        return filtered_nodes

    # Generate nodelist.txt from the unique-nodes JSON.
    with open(unique_json_path, "r", encoding="utf-8") as f1:
        unique_data = json.load(f1)
    all_nodenames = []
    for value in unique_data.values():
        if "nodenames" not in value:
            continue
        names = value["nodenames"]
        if args.single:
            # An entry with an empty name list contributes nothing.
            if names:
                all_nodenames.append(names[0])
        else:
            all_nodenames.extend(names)

    # Filter out skipped nodes if a skip list was passed.
    if args.skipped is not None:
        with open(args.skipped, "r", encoding="utf-8") as f2:
            skipped_data = json.load(f2)
        filtered_nodes = [n for n in all_nodenames if n not in skipped_data]
    else:
        filtered_nodes = list(all_nodenames)

    nodelist_path = Path(__file__).resolve().parent / "nodelist.txt"
    with open(nodelist_path, "w") as out:
        out.writelines(name + "\n" for name in filtered_nodes)
    return filtered_nodes

def run_aie4_for_each_op(args, fe_dir, out_dir, ort_dir, model_name, model_full_path, unique_json_path, report_path, runner_exe):
    """Compile and run every selected node on its own and record the outcome.

    For each node name returned by ``get_nodes_list``:

    1. previous artifacts in ``out_dir`` are removed with ``clear_folder``;
    2. the node name is written to a temporary one-entry nodelist file and
       ``run_aie4`` is called with it (the temp file is always removed);
    3. if ``run_aie4`` raises, the exception text is written to the report's
       note column and the node is skipped;
    4. otherwise ``check_wgt`` runs when ``args.check_wgt`` is set, and then
       either ``run_aie_runner`` is called (ELF produced) or a
       "missing control.elf" row is recorded.
    """
    def append_row(row):
        # check_wgt and run_aie_runner append to the same report through their
        # own file handles, so this function must not keep the report open
        # (and buffered) across those calls or rows end up out of order.
        with open(report_path, "a", newline="") as f:
            csv.writer(f).writerow(row)

    nodelist = get_nodes_list(args, unique_json_path, out_dir)
    for n in nodelist:
        clear_folder(out_dir)
        ft, tmp_path = tempfile.mkstemp(text=True)
        os.close(ft)
        try:
            with open(tmp_path, "w") as tmp:
                tmp.write(n)
            is_npu = run_aie4(fe_dir, out_dir, ort_dir, model_name, model_full_path,
                              unique_json_path, tmp_path, is_qdq_fp16=args.is_qdq_fp16)
        except Exception as e:
            print(e)
            append_row([n, "", "", "", "", f"{e}"])
            # The compile failed: the output folder was just cleaned and may
            # not contain the files check_wgt reads, and a second
            # "missing control.elf" row would misreport the cause.
            continue
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        if args.check_wgt:
            check_wgt([n], model_name, out_dir, report_path)
        if is_npu:
            run_aie_runner(args, out_dir, runner_exe, report_path, n)
        else:
            append_row([n, "", "", "", "", "missing control.elf, may be CPU op"])

def generate_report(result, report_path, op=None):
    """Extract metrics from the runner's output and append them to the CSV report.

    ``result.stdout`` is scanned line by line (blank lines ignored):

    * a line starting with ``=`` closes the current section;
    * the first non-blank line of a section is taken as the section name;
    * lines starting with ``max_diff is``, ``L2_norm is``,
      ``L2_norm per element is`` or ``Error Count is`` contribute the text
      after the prefix to the matching metric list.

    When a section closes (and at end of output), the four metric lists are
    zipped into rows ``[name, max_diff, L2_norm, L2_norm per element,
    Error Count, ""]``. If ``op`` is given it replaces the section name in
    every row.
    """
    # Prefixes are mutually exclusive, so at most one matches a given line.
    prefixes = ("max_diff is", "L2_norm per element is", "L2_norm is", "Error Count is")

    def fresh_buckets():
        return {prefix: [] for prefix in prefixes}

    def drain(section_name, buckets):
        label = section_name if op is None else op
        grouped = zip(
            buckets["max_diff is"],
            buckets["L2_norm is"],
            buckets["L2_norm per element is"],
            buckets["Error Count is"],
        )
        return [[label, diff, l2, l2_per, err, ""] for diff, l2, l2_per, err in grouped]

    collected = []
    section = None
    metrics = fresh_buckets()
    awaiting_name = True

    for raw in result.stdout.splitlines():
        text = raw.strip()
        if not text:
            continue

        if text.startswith("="):
            # Separator: emit what the finished section gathered.
            collected.extend(drain(section, metrics))
            metrics = fresh_buckets()
            awaiting_name = True
            continue

        if awaiting_name:
            section = text
            metrics = fresh_buckets()
            awaiting_name = False
            continue

        for prefix in prefixes:
            if text.startswith(prefix):
                metrics[prefix].append(text[len(prefix):].strip())
                break

    # Output may end without a trailing separator.
    collected.extend(drain(section, metrics))

    with open(report_path, "a", newline="") as f:
        csv.writer(f).writerows(collected)

def run_aie_runner(args, out_dir, runner_exe, report_path, op=None):
    """Execute the runner on ``<out_dir>/cut_graphs/config.json`` and report the metrics.

    With ``args.per_subgraph`` the runner is launched once per top-level key
    of ``config.json`` (a fixed set of configuration keys is skipped) as
    ``runner_exe config.json 1 <key>``; otherwise it is launched once as
    ``runner_exe config.json``. The combined stdout/stderr of every launch is
    handed to ``generate_report``.

    Any failure (missing runner or config file, non-zero exit status, ...)
    is not propagated: it is printed and recorded as a "Runner crashed" row
    in the report, labelled with ``op`` when given, else with the subgraph
    that was being run (empty if none had started).
    """
    print("Running aie_runner...")
    # Must exist before anything can raise: the handler below reads it.
    sub = ""
    try:
        if not os.path.exists(runner_exe):
            raise FileNotFoundError(f"Runner is not found: {runner_exe}")
        config_path = os.path.join(out_dir, "cut_graphs", "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"config.json file is not found: {config_path}")

        if args.per_subgraph:
            with open(config_path, "r") as conf:
                data = json.load(conf)
            # Top-level keys that are settings rather than subgraph entries.
            exclude_keys = ("xclbin", "HWbin_path", "Tilings_json", "Cache_dir", "prebuilt_bin_dir",
                            "Compile", "Runtime", "Compile_cfg", "Debug_cfg")
            for i in data:
                if i in exclude_keys:
                    continue
                print(f"Run subgraph: {i}")
                sub = i
                result = subprocess.run([runner_exe, config_path, "1", i], check=True,
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.STDOUT,
                                        text=True)
                print(f"Collecting max_diff for {i}")
                generate_report(result, report_path, op)
        else:
            result = subprocess.run([runner_exe, config_path], check=True,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.STDOUT,
                                    text=True)
            print("Runner done, collecting max_diff")
            generate_report(result, report_path, op)
        print("Reporting done.")
    except Exception as e:
        print(f"Runner crashed for: {sub}")
        print(e)
        # newline="" as required for csv writers; matches the other report writers.
        with open(report_path, "a", newline="") as r:
            csv.writer(r).writerow([op if op else sub, "", "", "", "", f"Runner crashed:\n {e}"])

def check_wgt(filtered_nodes, model_name, folder, report_path):
    """Append a "Missing wgt.bin" row to the report for every op folder lacking ``wgt.bin``.

    Op folders are the sub-directories of ``folder`` whose name starts with
    ``op_`` but not with ``op_pdi_shape_input``.

    * ``filtered_nodes`` is None: every op folder is checked.
    * otherwise: each node name is looked up in
      ``<folder>/<model_name>_mod_nhwc_fused.onnx_alloc.json`` (first entry
      whose ``"name"`` equals the node name); only op folders ending with
      ``_layer_id_<key of that entry>`` are checked.

    The alloc JSON is read in both modes, so a missing file raises.
    """
    print("Start wgt checking....")
    alloc_path = os.path.join(folder, model_name + "_mod_nhwc_fused.onnx_alloc.json")
    with open(alloc_path, "r") as alloc:
        data = json.load(alloc)

    # Scan the folder once; the candidate set does not depend on the node.
    op_dirs = [
        entry for entry in os.listdir(folder)
        if entry.startswith("op_")
        and not entry.startswith("op_pdi_shape_input")
        and os.path.isdir(os.path.join(folder, entry))
    ]

    def lacks_wgt(op_dir):
        return not os.path.isfile(os.path.join(folder, op_dir, "wgt.bin"))

    if filtered_nodes is not None:
        keys = []
        for name in filtered_nodes:
            for k, v in data.items():
                if v.get("name") == name:
                    print(f"Found {k} : {name}")
                    keys.append(k)
                    break
        missing_wgt = [
            op_dir
            for key in keys
            for op_dir in op_dirs
            if op_dir.endswith(f"_layer_id_{key}") and lacks_wgt(op_dir)
        ]
    else:
        missing_wgt = [op_dir for op_dir in op_dirs if lacks_wgt(op_dir)]

    with open(report_path, "a", newline="") as f:
        csv.writer(f).writerows([op_dir, "", "", "", "", "Missing wgt.bin"] for op_dir in missing_wgt)
    print("End wgt checking!")

def main(args):
    """Drive one debugging session from parsed command-line arguments.

    Resolves the working directories, starts a fresh ``report.csv`` (with
    header) next to this script, and then runs one of two modes:

    * subgraph mode (``args.subgraph`` or ``args.per_subgraph``): either
      refresh the runtime metadata (``args.skip_compiling``) or compile the
      model once with ``run_aie4``; optionally check for missing ``wgt.bin``
      files; then, unless ``args.skip_running``, execute the runner if an
      ELF is available.
    * op-level mode (default): ``run_aie4_for_each_op`` compiles and runs
      every selected node separately.

    Raises:
        ValueError: if a compile step is required but ``args.fe_dir`` was
            not given.
    """
    runner_exe = Path(args.runner).expanduser()
    if args.output_dir is not None:
        out_dir = Path(args.output_dir).expanduser()
    else:
        out_dir = os.path.join(os.path.dirname(__file__), "Outputs")
    fe_dir = Path(args.fe_dir).expanduser() if args.fe_dir is not None else None
    ort_dir = Path(args.ort_dir).expanduser() if args.ort_dir is not None else None

    subgraph_mode = args.subgraph or args.per_subgraph
    needs_compile = not (subgraph_mode and args.skip_compiling)
    # Fail before touching the output folder: every compile path builds its
    # input file names from fe_dir.
    if needs_compile and fe_dir is None:
        raise ValueError("--fe_dir is required unless compiling is skipped with --skip_compiling")

    os.makedirs(out_dir, exist_ok=True)
    if args.clean:
        clear_folder(out_dir)

    # ntpath splits on both '/' and '\\', so either path style works.
    model_name = os.path.splitext(ntpath.basename(args.model_path))[0]
    model_full_path = Path(args.model_path).expanduser()
    unique_json_path = None
    if fe_dir is not None:
        unique_json_path = os.path.join(fe_dir, model_name + "_mod_nhwc_fused_IR_unique_nodes.json")

    report_path = Path(__file__).resolve().parent / "report.csv"
    with open(report_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Op Name", "max_diff", "L2_norm", "L2_norm per element", "Error Count", "Note"])

    if subgraph_mode:
        if args.skip_compiling:
            is_npu = True
            update_meta_runtime_args = argparse.Namespace(
                out_dir=os.path.join(out_dir, "cut_graphs"),
                data_folder=ort_dir,
            )
            print("Updating run_meta_runtime...")
            update_meta_runtime_main(update_meta_runtime_args)
            print("update_meta_runtime is successfully done!")
        else:
            nodelist_path = None
            if args.nodelist:
                nodelist_path = Path(args.nodelist).expanduser().resolve()
            is_npu = run_aie4(fe_dir, out_dir, ort_dir, model_name, model_full_path,
                              unique_json_path, nodelist_path, is_qdq_fp16=args.is_qdq_fp16)
        if args.check_wgt:
            nodelist = None
            if args.nodelist:
                nodelist_path = Path(args.nodelist).expanduser().resolve()
                with open(nodelist_path, "r") as n:
                    nodelist = [line.strip() for line in n]
            check_wgt(nodelist, model_name, out_dir, report_path)
        if args.skip_running:
            print(f"Skipping execution (--skip_running). Artifacts in: {out_dir}")
        elif is_npu:
            run_aie_runner(args, out_dir, runner_exe, report_path)
    else:
        run_aie4_for_each_op(args, fe_dir, out_dir, ort_dir, model_name, model_full_path,
                             unique_json_path, report_path, runner_exe)
    print(f"Results are reported to {report_path}")


def str_to_bool(value):
    """Convert a command-line token to a bool for use as an argparse ``type``.

    ``type=bool`` cannot be used for this: ``bool("False")`` is True because
    every non-empty string is truthy. Accepted spellings (case-insensitive):
    1/true/t/yes/y/on and 0/false/f/no/n/off.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="For AIE4 flow op-level debugging",
        usage='use "%(prog)s --help" for more info',
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "-mp",
        "--model_path",
        required=True,
        help="Path to onnx model",
    )
    parser.add_argument(
        "-fe", "--fe_dir", help="FE output directory", default=None
    )
    parser.add_argument(
        "-ort", "--ort_dir", help="ORT data directory", default=None
    )
    parser.add_argument(
        "-output", "--output_dir", help="output directory", default=None
    )
    parser.add_argument(
        "-r", "--runner", help="aie_runner exe path",
        required=True,
    )
    parser.add_argument(
        "--skipped",
        required=False,
        help="Path to skipped nodes list",
    )
    parser.add_argument(
        "--single",
        help="Select single instance for an op",
        action="store_true",
    )
    parser.add_argument(
        "--subgraph",
        help="Enable subgrpah level reporting",
        action="store_true",
    )
    parser.add_argument(
        "--per_subgraph",
        help="Run per-subgraph separately",
        action="store_true",
    )
    parser.add_argument(
        "--skip_compiling",
        help="Skip compiling",
        action="store_true",
    )
    parser.add_argument(
        "--skip_running",
        help="Skip running (compile only)",
        action="store_true",
    )
    parser.add_argument(
        "--nodelist",
        required=False,
        help="Path to nodelist to use, instead of generation"
    )
    parser.add_argument(
        "--sglist",
        required=False,
        help="Path to subgraph list; the nodelist is built from these subgraphs"
    )
    parser.add_argument(
        "-wgt",
        "--check_wgt",
        help="Check wgt.bin existence",
        action="store_true",
    )
    parser.add_argument(
        "--clean",
        help="Delete artifacts generated by aie4",
        action="store_true",
    )
    parser.add_argument(
        "--ml-timeline", "--ml_timeline",
        action="store_true",
        help="Enable ML Timeline profiling."
    )
    parser.add_argument(
        "-fp16", "--is_qdq_fp16",
        # Not type=bool: that would turn "-fp16 False" into True.
        type=str_to_bool, default=True,
        help="QDQ datatype: True -> FP16, False -> BF16 (Default -> True (QDQ DType is FP16))"
    )

    args = parser.parse_args()

    # Set env variable for toggling ML Timeline
    os.environ["ML_TIMER_LOG_LEVEL"] = "1" if args.ml_timeline else "0"

    # Reject option combinations the flow does not support.
    if args.skip_compiling and not (args.subgraph or args.per_subgraph):
        parser.error("--skip_compiling is not supported with op level debug."
                     "\nPlease use it with --subgraph or per_subgraph option only.")
    if args.skip_running and not args.subgraph:
        parser.error("--skip_running is not supported with op level debug."
                     "\nPlease use it with --subgraph option only.")
    if args.skip_compiling and args.skip_running:
        parser.error("--skip_compiling and --skip_running are mutually exclusive.")
    if (args.subgraph or args.per_subgraph) and (args.sglist or args.nodelist):
        parser.error("Cannot run subgraph and op level report with single run."
                     "\nPlease use --subgraph/--per_subgraph options separately from --nodelist/--sglist options.")
    main(args)
