import sys
import os
import subprocess
import argparse
import json
import OGOAT.src.Ort.onnx_graph_partitioner as ogp
from OGOAT.src.Ort.cut_graph_check import main as cut_graph_check
import runner.python.model_aie_runner as json_gen


def _model_dir(args):
    """Return the directory that holds ``args.model_path``.

    ``os.path.dirname`` yields "" for a bare filename, which ``os.listdir``
    rejects, so fall back to the current directory in that case.
    """
    return os.path.dirname(args.model_path) or "."


def _find_entry(folder, token):
    """Return the path of the first entry in *folder* whose name contains *token*.

    Entries are scanned in sorted order so the pick is deterministic when more
    than one name matches. Returns "" when nothing matches.
    """
    for name in sorted(os.listdir(folder)):
        if token in name:
            return os.path.join(folder, name)
    return ""


def _apply_pm_bin_allocation(args, cut_graphs_fld):
    """Run static PM bin allocation on every ``*.onnx`` file in *cut_graphs_fld*.

    Best-effort: a failure on one subgraph is reported as a warning and the
    remaining subgraphs are still processed. If the directory does not exist
    a warning is printed and nothing else happens.
    """
    print(
        "#####################  Applying PM bin allocation to subgraphs  #####################"
    )
    if not os.path.exists(cut_graphs_fld):
        print(f"Warning: cut_graphs directory not found at {cut_graphs_fld}")
        return

    # Imported lazily: only needed once there is a cut_graphs directory.
    import onnx
    from OGOAT.src.L1_fusion.static_pm_bin_selection import (
        static_pm_partition_graph,
    )

    # These options are identical for every subgraph. The attributes may be
    # absent from the Namespace, hence getattr with defaults.
    fast_pm_enable = not getattr(args, "disable_fast_pm", False)
    prebuilt_mladf_mha = getattr(args, "prebuilt_mladf_mha", False)
    is_target_procyon = getattr(args, "target", None) == "procyon"
    # NOTE(review): these are the raw CLI values (possibly None), not the
    # paths auto-discovered in main() — confirm that is intended.
    IR_json_file = getattr(args, "IR_json_file", None)
    tiling_json_file = getattr(args, "tiling_json", None)

    subgraph_files = sorted(
        f for f in os.listdir(cut_graphs_fld) if f.endswith(".onnx")
    )
    print(f"Found {len(subgraph_files)} subgraphs to process")

    for subgraph_file in subgraph_files:
        subgraph_path = os.path.join(cut_graphs_fld, subgraph_file)
        print(f"  Processing subgraph: {subgraph_file}")
        try:
            subgraph_model = onnx.load(subgraph_path)
            static_pm_partition_graph(
                subgraph_model,
                subgraph_path,
                fast_pm_enable,
                prebuilt_mladf_mha,
                is_target_procyon,
                tiling_json_file,
                IR_json_file,
                save_model=False,
            )
        except Exception as e:
            # Deliberate best-effort: keep going with the other subgraphs.
            print(
                f"    Warning: Could not apply PM bin allocation to {subgraph_file}: {e}"
            )
        else:
            print(f"    PM bin allocation complete for {subgraph_file}")


def _write_cut_graph_config(
    args, out_dir, cut_graphs_fld, tilings_json, ir_json, act_folder
):
    """Generate a meta json per subgraph and write ``config.json``.

    For every ``*.onnx`` in *cut_graphs_fld*, ``json_gen.main`` is invoked to
    produce ``<stem>.json`` there. ``config.json`` (also in *cut_graphs_fld*)
    holds the shared settings plus a ``{"meta_json": ...}`` entry per stem.

    Raises:
        FileNotFoundError: if *cut_graphs_fld* does not exist.
    """
    if args.xclbin is None:
        # --xclbin is optional on the CLI; abspath(None) would raise.
        print("Warning: --xclbin not given; writing an empty xclbin path")
        xclbin = ""
    else:
        xclbin = os.path.abspath(args.xclbin)

    # Four dirname() steps up from this file (presumably the repository root
    # — confirm if the script is moved).
    base_dir = os.path.abspath(__file__)
    for _ in range(4):
        base_dir = os.path.dirname(base_dir)

    config = {
        "HWbin_path": os.path.abspath(out_dir),
        "Tilings_json": os.path.abspath(tilings_json),
        "Compile": 1,
        "Runtime": 1,
        "xclbin": xclbin,
        "Cache_dir": os.path.abspath(os.path.join(cut_graphs_fld, "_Cache")),
        "prebuilt_bin_dir": os.path.join(base_dir, "prebuilt", "xclbin"),
        "Compile_cfg": {
            "profile": 0,
            "eager_mode": 0,
            "optimize_scratch": 1,
            "enable_preemption": 1,
            "enable_fast_pm": 1,
            "use_inmem": 1,
        },
        "Debug_cfg": {
            "dump_data": 0,
            "enable_trace": 0,
            "is_profiling": 1,
        },
    }

    suffix = ".onnx"
    for name in sorted(os.listdir(cut_graphs_fld)):
        if not name.endswith(suffix):
            continue
        stem = name[: -len(suffix)]
        meta_json_args = argparse.Namespace(
            model_name=os.path.abspath(os.path.join(cut_graphs_fld, name)),
            data_folder=os.path.abspath(act_folder),
            json_folder=os.path.abspath(cut_graphs_fld),
            tiling_name=os.path.abspath(tilings_json),
            ir_name=os.path.abspath(ir_json),
            file_name=stem,
            # NOTE(review): the tensor map json located in main() is not
            # forwarded here — confirm whether it should be.
            tensormap_file_name=None,
        )
        json_gen.main(meta_json_args)
        config[stem] = {
            "meta_json": os.path.abspath(os.path.join(cut_graphs_fld, stem + ".json"))
        }

    config_file_path = os.path.join(cut_graphs_fld, "config.json")
    with open(config_file_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)


def main(args):
    """Cut the fused model into subgraphs and prepare their run configuration.

    Steps:
      1. Resolve the tilings json, fused IR json, tensor-map json and data
         folder. Explicit ``args`` values win; otherwise they are looked up
         next to ``args.model_path``. Missing inputs are reported, not fatal.
      2. Partition the graph with ``ogp.main(args)``.
      3. Apply static PM bin allocation to each subgraph in ``cut_graphs``.
      4. Generate per-subgraph meta jsons and ``cut_graphs/config.json``.
      5. Run ``cut_graph_check(args)``.

    Args:
        args: ``argparse.Namespace``. ``model_path`` is required; the other
            attributes read here may be None.
    """
    model_dir = _model_dir(args)

    if args.tiling_json is not None:
        tilings_json = args.tiling_json
    else:
        tilings_json = _find_entry(model_dir, "tilings")
    if tilings_json == "":
        print("Did not find the tilings json")

    if args.IR_json_file is not None:
        IR_json = args.IR_json_file
    else:
        IR_json = os.path.join(model_dir, "model_mod_nhwc_fused_IR.json")
    if not os.path.isfile(IR_json):
        print("Did not find the fused IR json")

    # Read defensively: callers may pass a Namespace without this attribute.
    tensor_map_json = getattr(args, "tensor_map_json_file", None) or ""
    if not tensor_map_json:
        print("Did not find the tensor map json")

    if args.data_folder is not None:
        data_folder = args.data_folder
    else:
        data_folder = _find_entry(model_dir, "DataGen")
    if data_folder == "":
        print("Did not find the data_folder")
    data_folder = os.path.abspath(data_folder)
    act_folder = os.path.join(data_folder, "Activations", "ort")

    ogp.main(args)

    out_dir = args.out_dir if args.out_dir is not None else model_dir
    cut_graphs_fld = os.path.join(out_dir, "cut_graphs")

    _apply_pm_bin_allocation(args, cut_graphs_fld)
    _write_cut_graph_config(
        args, out_dir, cut_graphs_fld, tilings_json, IR_json, act_folder
    )

    cut_graph_check(args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Automated cutter of onnx graph into NPU and CPU chunks",
        usage='use "%(prog)s --help" for more info',
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "-mp",
        "--model_path",
        required=True,
        help="Path to fused onnx model",
    )
    parser.add_argument(
        "-omp",
        "--orig_model_path",
        required=False,
        help="Path to original onnx model",
    )
    parser.add_argument(
        "--tiling_json",
        required=False,
        help="Path to onnx model",
    )
    parser.add_argument(
        "--IR_json_file",
        required=False,
        help="Path to fused onnx layer descriptions",
    )
    parser.add_argument(
        "--xclbin",
        required=False,
        help="Path to xclbin file",
    )
    parser.add_argument(
        "--out_dir",
        required=False,
        help="Directory to store cut graphs",
    )
    parser.add_argument(
        "--data_folder", required=False, help="Directory with dumped bin files"
    )
    parser.add_argument("--aie_runner", required=False, help="aie_runner path")
    parser.add_argument(
        "--bin_check",
        required=False,
        type=int,
        default=1,
        help="check for generated bin files",
    )
    parser.add_argument(
        "-nl",
        "--node_list",
        help="path to file with node names that \
                should go into sub-graphs",
        default=None,
    )
    args = parser.parse_args()
    main(args)
