import argparse
import stat
import sys
import os
import json
import subprocess
import shutil
import numpy as np
import tarfile
import ctypes
from OGOAT.src.Ort.run_ort import main as run_ort_main
from OGOAT.src.Ort.onnx_graph_partitioner import main as onnx_graph_partitioner_main
from OGOAT.src.Ort.run_cut_aie_runtime import main as run_cut_aie_runtime_main

# from dataflow.xclbin.xclbin_build import main as xclbin_build
from runner.python.model_aie_runner import main as aie_runner

# Per-node artifacts that ``main`` keeps when it prunes the sub-directories of
# the output directory; any other entry inside a node sub-directory (apart from
# the node's own ``<node>.json``) is deleted.
files_to_copy = [
    # binary payloads
    *("txn.bin", "param.bin", "ifm.bin", "ofm.bin", "wgt.bin", "ctrl.bin"),
    # JSON metadata
    *("patch.json", "tiling.json"),
]


def on_rm_error(func, path, exc_info):
    """Error hook for ``shutil.rmtree``: make *path* writable, then retry.

    ``shutil.rmtree`` invokes this as ``onerror(func, path, exc_info)`` when
    ``func(path)`` fails. The mode of *path* is set to ``stat.S_IWRITE`` and
    ``func(path)`` is attempted exactly once more; a second failure propagates
    to the caller. ``exc_info`` is accepted only to match the hook signature.
    """
    writable_mode = stat.S_IWRITE
    os.chmod(path, writable_mode)
    # Retry the operation that originally failed.
    func(path)


def _prune_node_dirs(output_dir, entries):
    """Delete everything but the kept artifacts from each per-node sub-directory.

    Args:
        output_dir: Directory whose immediate sub-directories are pruned.
        entries: Names listed in ``output_dir`` (e.g. from ``os.listdir``).
            ``DataGen`` and entries that are not directories are skipped.

    Inside a sub-directory ``<d>``, an entry survives only if its name is in
    ``files_to_copy`` or equals ``<d>.json``. Other files are removed with
    ``os.remove``; other directories are removed with ``shutil.rmtree`` using
    ``on_rm_error`` as the error hook.
    """
    always_keep = frozenset(files_to_copy)
    for name in entries:
        if name == "DataGen":
            continue
        subdir = os.path.join(output_dir, name)
        if not os.path.isdir(subdir):
            continue
        keep = always_keep | {name + ".json"}
        for entry in os.listdir(subdir):
            if entry in keep:
                continue
            path = os.path.join(subdir, entry)
            if os.path.isdir(path):
                shutil.rmtree(path, onerror=on_rm_error)
            else:
                os.remove(path)


def _import_waic_runtime():
    """Import and return ``waic_runtime`` from ``prebuilt/runtime`` next to this file.

    The runtime build must already have been done and its libraries copied
    into ``prebuilt/runtime``, e.g.::

        cmake -B ./runtime_build -S .
        cmake --build ./runtime_build --config Release
        cp runtime/Release/* prebuilt/runtime

    ``sys.path`` is restored after the import attempt, even if it fails.
    """
    build_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "prebuilt", "runtime"
    )
    orig_sys_path = sys.path
    try:
        sys.path = [build_dir] + orig_sys_path  # search the prebuilt dir first
        import waic_runtime
    finally:
        sys.path = orig_sys_path
    return waic_runtime


def _find_artifacts(output_dir, entries):
    """Locate compiler artifacts among *entries* of *output_dir*.

    Returns:
        dict with keys ``ir_json``, ``tiling_json``, ``tensor_map_json`` and
        ``fused_onnx``, each mapped to a full path, or ``""`` when no entry
        matched. Names are matched by substring (``fused_onnx`` by the
        ``fused.onnx`` suffix); when several entries match, the last one in
        *entries* order wins.
    """
    found = {
        "ir_json": "",
        "tiling_json": "",
        "tensor_map_json": "",
        "fused_onnx": "",
    }
    for name in entries:
        path = os.path.join(output_dir, name)
        if "mod_nhwc_fused_IR.json" in name:
            found["ir_json"] = path
        if "tilings.json" in name:
            found["tiling_json"] = path
        if name.endswith("fused.onnx"):
            found["fused_onnx"] = path
        if "tensor_map.json" in name:
            found["tensor_map_json"] = path
    return found


def _run_flexml_partitioner(args):
    """Cut the graph through the FlexML/VAIML shared library at ``args.cpp_fe``.

    Loads the library with ``ctypes`` and builds a compile result from
    ``<output_dir>/serialized.txt`` via ``vaiml_dmac_compile_result_from_file``.
    It then calls ``dmac_partition_graph(output_dir, external_data_path,
    config_json, compile_result)``, where ``external_data_path`` is the parent
    of ``args.output_dir``. The compile result is always released with
    ``vaiml_dmac_compile_result_destroy``, even if the partition call raises.
    """
    print("Flexml graph cutting")
    flexml_lib = ctypes.cdll.LoadLibrary(args.cpp_fe)

    # The C interface takes UTF-8 encoded byte strings.
    output_dir_b = args.output_dir.encode(encoding="utf-8")
    serialized_name = os.path.join(args.output_dir, "serialized.txt").encode(
        encoding="utf-8"
    )
    external_data_dir = os.path.dirname(args.output_dir)
    external_data_path = external_data_dir.encode(encoding="utf-8")

    # Configuration JSON matching the C++ VaimlInterface.
    config_dict = {
        "recipes": "DMAC",
        "keep_outputs": False,
        "output_type": "frontend-mlir",
        "prebuilt_mladf_mha": args.prebuilt_mladf_mha,
        "target": args.target if args.target else "",
        "optimize_level": args.optimization_level,
    }
    config_str = json.dumps(config_dict).encode(encoding="utf-8")

    # Create the DMACCompileResult instance needed by the partitioner interface.
    from_file = flexml_lib.vaiml_dmac_compile_result_from_file
    from_file.restype = ctypes.POINTER(ctypes.c_void_p)
    from_file.argtypes = [ctypes.c_char_p]
    compile_result = from_file(serialized_name)
    if not compile_result:
        # A NULL ctypes pointer is falsy; surface it before native code sees it.
        print(
            "WARNING: vaiml_dmac_compile_result_from_file returned NULL for "
            f"{os.path.join(args.output_dir, 'serialized.txt')}"
        )

    try:
        flexml_lib.dmac_partition_graph.argtypes = [
            ctypes.c_char_p,  # outputs_dir
            ctypes.c_char_p,  # external_data_path
            ctypes.c_char_p,  # config
            ctypes.c_void_p,  # DMACCompileResult
        ]

        # Log parameters being passed to the vaiml partitioner.
        print("[DEBUG VAIML Partitioner] Calling dmac_partition_graph with:")
        print(f"[DEBUG VAIML Partitioner]   output_dir: {args.output_dir}")
        print(
            f"[DEBUG VAIML Partitioner]   external_data_path: {external_data_dir}"
        )
        print(f"[DEBUG VAIML Partitioner]   target: {args.target}")
        print(
            f"[DEBUG VAIML Partitioner]   optimization_level: {args.optimization_level}"
        )
        print(
            f"[DEBUG VAIML Partitioner]   config: {json.dumps(config_dict, indent=2)}"
        )
        print(f"[DEBUG VAIML Partitioner]   compile_result: {compile_result}")

        flexml_lib.dmac_partition_graph(
            output_dir_b,
            external_data_path,
            config_str,
            compile_result,
        )
    finally:
        flexml_lib.vaiml_dmac_compile_result_destroy(compile_result)


def main(args):
    """Prepare runtime artifacts for a compiled model in ``args.output_dir``.

    Steps, all relative to ``args.output_dir``:

    1. Without ``cpp_fe``: prune per-node sub-directories, then run ``run_ort``
       (``data_dump=args.data_dump``).
    2. Import ``waic_runtime`` from ``prebuilt/runtime`` next to this file.
    3. Read the ``*tilings.json`` file. When ``use_inmem == "0"``, call
       ``waic_runtime.wgt_formatting`` and ``waic_runtime.txn_update_all``
       for every node id it lists.
    4. Without ``cpp_fe``: cut subgraphs with ``run_cut_aie_runtime``.
    5. Copy the prebuilt xclbin (``out.xclbin``, or ``mha.xclbin`` with
       ``prebuilt_mladf_mha``) to ``cut_graphs/out.xclbin``.
    6. With an existing ``cpp_fe`` library: cut the graph through it.

    Args:
        args: ``argparse.Namespace`` as produced by ``WAIC_runtime_main``.

    Raises:
        FileNotFoundError: no ``*tilings.json`` file is present in the output dir.
    """
    output_dir = args.output_dir
    file_list = os.listdir(output_dir)

    if args.cpp_fe is None:
        _prune_node_dirs(output_dir, file_list)
        run_ort_args = argparse.Namespace(
            model_name=args.model_path,
            ld=args.load_data,
            all=False,
            idx=0,
            data_dump=args.data_dump,
            edges="all",
            out_dir=output_dir,
        )
        run_ort_main(vars(run_ort_args))

    waic_runtime = _import_waic_runtime()

    wgt_dir = os.path.join(output_dir, "DataGen", "Consts")
    if not os.path.exists(wgt_dir):
        # Non-fatal: the weights dir is only read on the use_inmem == "0" path.
        print(f"ERROR: Const data dir does not exist: {wgt_dir}")

    # NOTE(review): artifacts are discovered from the listing taken before
    # run_ort ran, so top-level files it creates are not considered — confirm
    # this is intended.
    artifacts = _find_artifacts(output_dir, file_list)
    tiling_json = artifacts["tiling_json"]
    if not tiling_json:
        raise FileNotFoundError(
            f"No '*tilings.json' file found in output dir: {output_dir}"
        )
    with open(tiling_json, encoding="utf-8") as f:
        tiling_dict = json.load(f)
    id_list = list(tiling_dict.keys())

    if args.use_inmem == "0":
        print(
            "#####################  Formatting weights for all nodes  #####################"
        )
        waic_runtime.wgt_formatting(
            wgt_dir,
            output_dir,
            tiling_json,
            id_list,
            0,
        )

        print("#####################  Copying .bin files  #####################")
        waic_runtime.txn_update_all(
            wgt_dir,
            output_dir,
            tiling_json,
            True,
        )

    xclbin_fname = "mha.xclbin" if args.prebuilt_mladf_mha else "out.xclbin"

    if args.cpp_fe is None:
        print("#####################  Cut subgraphs  #####################")
        # Separate name so the CLI ``args`` is not shadowed by the cut options.
        cut_args = argparse.Namespace(
            model_path=artifacts["fused_onnx"],
            orig_model_path=args.model_path,
            xclbin=os.path.join(output_dir, "cut_graphs", "out.xclbin"),
            out_dir=output_dir,
            data_folder=None,
            aie_runner=None,
            bin_check=1,
            tiling_json=tiling_json,
            IR_json_file=artifacts["ir_json"],
            tensor_map_json_file=artifacts["tensor_map_json"],
            node_list=args.node_list,
            exclude_nodes=args.exclude_nodes,
            cpp_fe=args.cpp_fe,
            use_inmem=args.use_inmem,
            disable_fast_pm=args.disable_fast_pm,
            prebuilt_mladf_mha=args.prebuilt_mladf_mha,
            target=args.target,
        )
        run_cut_aie_runtime_main(cut_args)

    # Copy the prebuilt xclbin into cut_graphs/ under the fixed name out.xclbin.
    source_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "prebuilt", "xclbin", xclbin_fname
    )
    dest_dir = os.path.join(output_dir, "cut_graphs")
    os.makedirs(dest_dir, exist_ok=True)
    shutil.copy2(source_file, os.path.join(dest_dir, "out.xclbin"))

    if args.cpp_fe:
        if os.path.exists(args.cpp_fe):
            _run_flexml_partitioner(args)
        else:
            print(
                "WARNING: --cpp_fe library not found, skipping Flexml graph "
                f"cutting: {args.cpp_fe}"
            )

    # NOTE: xclbin generation, meta.json generation and .tar.bz2 packaging
    # steps are currently disabled in this pipeline.


def WAIC_runtime_main(argv=None):
    """Parse the WAIC runtime command line and run ``main``.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before this parameter existed.
    """
    parser = argparse.ArgumentParser(
        description="Windows AI Compiler (WAIC) runtime",
        usage='use "%(prog)s --help" for more info',
        formatter_class=argparse.RawTextHelpFormatter,
    )

    default_output_dir = os.path.join(os.path.dirname(__file__), "WAIC_Outputs")
    # Required args
    parser.add_argument(
        "-mp",
        "--model_path",
        required=True,
        help="Path to onnx model (or JSON) and output destination",
    )
    parser.add_argument(
        "-ld",
        "--load_data",
        help="path to additional model data file for large models. Optional Field. Default value = 0",
        default="0",
    )
    parser.add_argument(
        "-output", "--output_dir", help="output directory", default=default_output_dir
    )
    parser.add_argument(
        "-clean",
        "--delete_dir",
        # main() does not read this flag; the cleanup step is disabled.
        help="delete output directory if it already exists (currently ignored)",
        action="store_true",
    )
    parser.add_argument(
        "-dmp",
        "--data_dump",
        help="Data dump option for run_ort. Default value = wgt",
        default="wgt",
    )
    parser.add_argument(
        "-nl",
        "--node_list",
        # Implicit concatenation: RawTextHelpFormatter prints help verbatim, so
        # a backslash continuation would embed the next line's indentation.
        help="path to file with node names that should go into sub-graphs",
        default=None,
    )
    parser.add_argument(
        "--cpp_fe",
        help="Path to the shared library interface to compile with flexml",
    )
    parser.add_argument(
        "-en",
        "--exclude_nodes",
        help="path to exclude nodes json file",
        default=None,
    )
    parser.add_argument(
        "-prebuilt_mladf_mha",
        "--prebuilt_mladf_mha",
        action="store_true",
        help="Copy xclbin from waic_artifacts dir",
    )
    parser.add_argument(
        "-use_inmem", "--use_inmem", help="Set use_inmem option", default="0"
    )
    parser.add_argument(
        "--disable_fast_pm",
        action="store_true",
        help="To disable fast pm load, Default = False",
    )
    parser.add_argument(
        "--target",
        help="Target for compilation (e.g., procyon). Use a specific fusion seq file according to target",
        default="",
    )
    parser.add_argument(
        "-O",
        "--optimization_level",
        help="Optimization level (0, 1, 2, or 3)",
        type=int,
        default=1,
    )

    args = parser.parse_args(argv)
    args.output_dir = os.path.abspath(args.output_dir)

    # NOTE: the output directory is neither created nor cleaned here; main()
    # expects it to exist already.
    main(args)


if __name__ == "__main__":
    WAIC_runtime_main()
