"""This is a Python script to provide artifacts for aie runner."""

import os
import argparse
import json
import shutil


# Per-direction key conventions used when reading a tensor's hardware info from
# a tiling-json node. The "inputs"/"outputs" keys match the ones used by the
# context data, the fused IR nodes and the generated meta json.
_IO_SPECS = {
    "inputs": {
        # the tiling key whose value is the tensor name is expected to look
        # like "input<idx>_..."; <idx> selects the indexed L3 / padded slot
        "port_prefix": "input",
        "l3_prefix": "ifm",
        "padded_prefix": "padded_input",
        "dtype_prefix": "in_dtype_",
        # if any of these keys is present the node uses the un-indexed L3 slot
        "single_l3_markers": ("const_idx", "input_name"),
    },
    "outputs": {
        "port_prefix": "output",
        "l3_prefix": "ofm",
        "padded_prefix": "padded_output",
        "dtype_prefix": "out_dtype_",
        "single_l3_markers": ("output_name",),
    },
}


def _sanitize_tensor_name(name):
    """Turn a tensor name into the file stem of its reference ``.bin`` file."""
    return name.replace("/", "_").replace(".", "_").replace(":", "_")


def _find_tiling_node_idx(tiling_data, node_name):
    """Return the tiling_data key whose entry has ``"name" == node_name``."""
    for key, val in tiling_data.items():
        if val["name"] == node_name:
            return key
    raise KeyError("node {!r} not found in tiling json".format(node_name))


def _find_ir_port(ports, tensor_name, node_name):
    """Return the IR port dict named ``tensor_name`` from a node's port list."""
    for port in ports:
        if port["name"] == tensor_name:
            return port
    raise KeyError(
        "tensor {!r} not found in IR node {!r}".format(tensor_name, node_name)
    )


def _find_key_for_value(mapping, value):
    """Return the first key of ``mapping`` whose value equals ``value``, or None."""
    for key, val in mapping.items():
        if val == value:
            return key
    return None


def _port_index(port_key, port_prefix, tensor_name):
    """Extract ``<idx>`` from a tiling key such as ``"input<idx>_..."``."""
    if port_key is None:
        raise KeyError(
            "tensor {!r} is not referenced by its tiling node".format(tensor_name)
        )
    return port_key.split("_")[0][len(port_prefix):]


def _build_io_entry(
    tensor, node_name, direction, ir_data, tiling_data, tensor_map_dict,
    full_data_folder,
):
    """Build one meta-json ``inputs``/``outputs`` entry for ``tensor``.

    Raises:
        KeyError: if the tensor / node cannot be found in the IR or tiling data.
    """
    spec = _IO_SPECS[direction]
    entry = {"name": tensor}

    # onnx related info from the fused IR
    port = _find_ir_port(ir_data[node_name][direction], tensor, node_name)
    entry["onnx_shape"] = port["shape"]
    entry["onnx_dtype"] = port["dtype"]
    param_name = port["param_name"]

    # original name and layout from the tensor map
    entry["onnx_format"] = ""
    if tensor in tensor_map_dict:
        mapped = tensor_map_dict[tensor]
        entry["name"] = mapped["orig_tensor"]
        if mapped["orig_layout"] == "NCHW" and mapped["final_layout"] == "NHWC":
            entry["onnx_format"] = "NCHW"

    # hw related info from the tiling json
    node = tiling_data[_find_tiling_node_idx(tiling_data, node_name)]
    # only needed by the indexed branches below, so a miss is not fatal here
    port_key = _find_key_for_value(node, tensor)

    # NOTE: the indexed ("ofm<idx>" / "padded_output<idx>") paths are untested
    # for outputs.
    l3_key = spec["l3_prefix"]
    if not any(marker in node for marker in spec["single_l3_markers"]):
        l3_key += _port_index(port_key, spec["port_prefix"], tensor)
    entry["L3_alloc"] = node["L3"][l3_key]

    padded_key = spec["padded_prefix"]
    if padded_key not in node:
        padded_key += _port_index(port_key, spec["port_prefix"], tensor)
    entry["hw_shape"] = node[padded_key]

    entry["hw_dtype"] = node[spec["dtype_prefix"] + param_name]
    entry["hw_format"] = "NHWC"

    # reference data file, named after the (original) tensor name
    if full_data_folder is not None:
        file_name = _sanitize_tensor_name(entry["name"]) + ".bin"
        full_file_name = os.path.join(full_data_folder, file_name)
        if os.path.isfile(full_file_name):
            entry["file_name"] = full_file_name
        else:
            print("IO file " + file_name + " does not exist")
    return entry


def meta_json_gen(
    json_file, tiling_data, ir_data, context_data, data_folder, tensor_map_dict
):
    """
    Generate the meta json file for one subgraph.

    Args:
        json_file: path of the meta json to write.
        tiling_data: parsed tilings json; each value has a "name" plus
            per-tensor hardware keys (L3 allocation, padded shapes, dtypes).
        ir_data: parsed fused IR json, keyed by node name, each node having
            "inputs"/"outputs" port lists (name, shape, dtype, param_name).
        context_data: this subgraph's context entry with "inputs"/"outputs"
            (tensor name -> IR node name), "subgraph_index" and "bo_sizes".
        data_folder: directory with reference ``.bin`` files, or None to skip.
        tensor_map_dict: tensor name -> {"orig_tensor", "orig_layout",
            "final_layout"}; may be empty.

    Raises:
        KeyError: if a tensor's node or port cannot be found in ``ir_data`` /
            ``tiling_data`` (previously a stale value from the prior tensor
            could be reused silently).
    """
    full_data_folder = None if data_folder is None else os.path.abspath(data_folder)

    meta = {
        "meta_major_version": 2,
        "meta_minor_version": 0,
        "device": "QHW4",
    }
    for direction in ("inputs", "outputs"):
        meta[direction] = [
            _build_io_entry(
                tensor, node_name, direction, ir_data, tiling_data,
                tensor_map_dict, full_data_folder,
            )
            for tensor, node_name in context_data[direction].items()
        ]

    meta["subgraph_index"] = context_data["subgraph_index"]
    meta["bo_sizes"] = context_data["bo_sizes"]

    with open(json_file, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)


def _write_runner_config(cut_graphs_fld, elf_file, prefix_map):
    """Write ``config.json`` for the aie runner into ``cut_graphs_fld``."""
    config = {
        "HWbin_path": "",
        "Tilings_json": "",
        "Compile": 0,
        "Runtime": 1,
        # use the filename from cut_graphs: main() copies the elf there
        "xclbin": os.path.join(cut_graphs_fld, os.path.basename(elf_file)),
        "Cache_dir": cut_graphs_fld,
        "prebuilt_bin_dir": "",
        "Debug_cfg": {
            "dump_data": int(os.environ.get("AIE4_SUBGRAPH_DUMP_DATA", 0)),
            "enable_trace": 0,
        },
    }
    # one entry per subgraph, pointing at the meta json main() generates
    for fname in prefix_map:
        config[fname] = {
            "meta_json": os.path.abspath(os.path.join(cut_graphs_fld, fname + ".json"))
        }
    config_file_path = os.path.join(cut_graphs_fld, "config.json")
    with open(config_file_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)


def _copy_subgraph_artifacts(hw_data_path, cut_graphs_fld, prefix_map):
    """Copy each subgraph's param.bin / wgt.bin into cut_graphs_fld under runner names."""
    hw_root = os.path.abspath(hw_data_path)
    # (file inside fused_hw_package_subgraph_<idx>, suffix appended to the prefix)
    renames = (
        ("param.bin", "super_instr_bo_fname.bin"),
        ("wgt.bin", "const_bo_fname.bin"),
    )
    for fname, sub_index in prefix_map.items():
        source_dir = os.path.join(
            hw_root, "fused_hw_package_subgraph_" + str(sub_index)
        )
        for src_name, dest_suffix in renames:
            shutil.copy2(
                os.path.abspath(os.path.join(source_dir, src_name)),
                os.path.abspath(os.path.join(cut_graphs_fld, fname + dest_suffix)),
            )


def _load_tensor_map(tensor_map_json):
    """Load the tensor map json, or return {} when no tensor map was given.

    Both ``None`` and the literal string ``"None"`` mean "no tensor map";
    main() already reports the latter as not found.

    Raises:
        FileNotFoundError: if a path is given but does not exist.
    """
    if tensor_map_json is None or tensor_map_json == "None":
        return {}
    if not os.path.exists(tensor_map_json):
        raise FileNotFoundError(
            "tensor map json path is not found: " + str(tensor_map_json)
        )
    with open(tensor_map_json, "r", encoding="utf-8") as f:
        return json.load(f)


def _promote_context_info(cut_graphs_fld):
    """Move context.json to context_full.json, then context_info.json to context.json."""
    context_file_path = os.path.join(cut_graphs_fld, "context.json")
    contextinfo_file_path = os.path.join(cut_graphs_fld, "context_info.json")
    # check first so a missing file does not leave context.json already moved
    if not os.path.exists(contextinfo_file_path):
        raise FileNotFoundError(
            "context info json is not found: " + contextinfo_file_path
        )
    # os.replace also overwrites an existing destination on Windows
    os.replace(context_file_path, os.path.join(cut_graphs_fld, "context_full.json"))
    os.replace(contextinfo_file_path, context_file_path)


def main(args):
    """
    Prepare artifacts for the aie runner.

    Reads ``<out_dir>/context.json`` and writes ``config.json`` and one meta
    json per subgraph. It copies the elf and each subgraph's param.bin/wgt.bin
    into ``out_dir``. Finally, it swaps context.json for context_info.json,
    keeping the old file as context_full.json.

    Args:
        args: namespace with ``tiling_json``, ``IR_json``, ``tensor_map``,
            ``hw_data_path``, ``out_dir``, ``elf`` and ``data_folder``.
    """
    # These checks only warn; a bad value fails later when the file is used.
    tilings_json = args.tiling_json
    if tilings_json == "":
        print("Did not find the tilings json")

    ir_json = args.IR_json
    if ir_json == "None":
        print("Did not find the fused IR json")

    tensor_map_json = args.tensor_map
    if tensor_map_json == "None":
        print("Did not find the tensor map json")

    hw_data_path = args.hw_data_path
    if hw_data_path == "":
        print("Did not find the hw data path")

    out_dir = args.out_dir
    if out_dir == "":
        print("Did not find the out dir")

    elf_file = args.elf
    if elf_file == "":
        print("Did not find the elf file")

    data_folder = args.data_folder

    # read context.json
    cut_graphs_fld = os.path.abspath(out_dir)
    context_file_path = os.path.join(cut_graphs_fld, "context.json")
    with open(context_file_path, "r", encoding="utf-8") as f:
        context_data = json.load(f)

    # context keys are "<prefix>.onnx" (see the lookup below): strip the suffix
    prefix_map = {key[:-5]: val["subgraph_index"] for key, val in context_data.items()}

    _write_runner_config(cut_graphs_fld, elf_file, prefix_map)

    # copy elf
    os.makedirs(cut_graphs_fld, exist_ok=True)
    shutil.copy2(os.path.abspath(elf_file), cut_graphs_fld)

    # copy/rename artifacts
    _copy_subgraph_artifacts(hw_data_path, cut_graphs_fld, prefix_map)

    # generate meta json files
    with open(tilings_json, "r", encoding="utf-8") as f:
        tiling_data = json.load(f)
    with open(ir_json, "r", encoding="utf-8") as f:
        ir_data = json.load(f)
    tensor_map_dict = _load_tensor_map(tensor_map_json)

    for fname in prefix_map:
        meta_json_gen(
            os.path.join(cut_graphs_fld, fname + ".json"),
            tiling_data,
            ir_data,
            context_data[fname + ".onnx"],
            data_folder,
            tensor_map_dict,
        )

    # update context.json
    _promote_context_info(cut_graphs_fld)


def _build_arg_parser():
    """Return the command-line parser whose options feed main()."""
    parser = argparse.ArgumentParser(
        description="Prepare for aie runner for the artifacts",
        usage='use "%(prog)s --help" for more info',
        formatter_class=argparse.RawTextHelpFormatter,
    )
    # (short flag, long flag, required, help text) in registration order
    option_table = (
        ("-t", "--tiling_json", True, "Path to alloc json"),
        ("-ir", "--IR_json", True, "Path to IR json"),
        ("-elf", "--elf", True, "Full path of elf file"),
        ("-out", "--out_dir", True, "Directory to store cut graphs"),
        ("-hw", "--hw_data_path", True, "Directory with hw data for each subgraph"),
        ("-df", "--data_folder", False, "Directory with ref data for input/output"),
        ("-tmap", "--tensor_map", True, "Path to the tensor map json"),
    )
    for short_flag, long_flag, is_required, help_text in option_table:
        parser.add_argument(
            short_flag,
            long_flag,
            required=is_required,
            help=help_text,
        )
    return parser


if __name__ == "__main__":
    main(_build_arg_parser().parse_args())
