import sys
import os
import argparse
import json
import shutil


# Characters in a tensor name that are mapped to "_" when deriving its .bin file name.
_REF_NAME_TABLE = str.maketrans({"/": "_", ".": "_", ":": "_"})


def _ref_file_name(node_name):
    """Return the reference-data file name for a tensor name ('/', '.', ':' -> '_', plus '.bin')."""
    return node_name.translate(_REF_NAME_TABLE) + ".bin"


def _attach_ref_files(tensors, full_data_folder):
    """Set ``file_name`` on each tensor dict whose reference .bin file exists in ``full_data_folder``.

    Tensors whose file is missing are left untouched and a message is printed.
    """
    for tensor in tensors:
        file_name = _ref_file_name(tensor["name"])
        full_file_name = os.path.join(full_data_folder, file_name)
        if os.path.isfile(full_file_name):
            tensor["file_name"] = full_file_name
        else:
            print("IO file " + file_name + " does not exist")


def meta_json_update(
    json_file, data_folder
):
    """Add reference-data file paths to a meta JSON file, rewriting it in place.

    The meta JSON is expected to be an object with ``"inputs"`` and ``"outputs"``
    lists of tensor dicts, each carrying a ``"name"``. For every tensor whose
    ``<data_folder>/<sanitized name>.bin`` exists, an absolute ``"file_name"``
    entry is added. The file is always re-serialized with ``indent=2``, even
    when ``data_folder`` is None.

    Args:
        json_file: Path to the meta JSON file to update.
        data_folder: Directory holding the reference ``.bin`` files, or None to
            skip the lookup.

    Raises:
        OSError: If ``json_file`` cannot be read or written.
        json.JSONDecodeError: If ``json_file`` is not valid JSON.
        KeyError: If ``data_folder`` is given and the meta lacks
            ``"inputs"``/``"outputs"`` or a tensor lacks ``"name"``.
    """
    with open(json_file, "r", encoding="utf-8") as f:
        meta = json.load(f)

    if data_folder is not None:
        full_data_folder = os.path.abspath(data_folder)
        _attach_ref_files(meta["inputs"], full_data_folder)
        _attach_ref_files(meta["outputs"], full_data_folder)

    # Write back with the same encoding used for reading (was platform default).
    with open(json_file, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)


def main(args):
    """Update every cut-graph meta JSON listed in ``<out_dir>/context.json``.

    Each top-level key of ``context.json`` names a cut graph. Its extension is
    stripped and ``.json`` appended to locate the meta file. That file is then
    passed to ``meta_json_update`` together with ``args.data_folder``.

    Args:
        args: Parsed CLI namespace with ``out_dir`` and ``data_folder`` attributes.

    Raises:
        SystemExit: With status 1 if ``out_dir`` is empty.
        OSError / json.JSONDecodeError: If ``context.json`` cannot be read or parsed.
    """
    out_dir = args.out_dir
    if not out_dir:
        # Stop here: abspath("") would silently resolve to the current directory.
        print("Did not find the out dir")
        sys.exit(1)

    cut_graphs_fld = os.path.abspath(out_dir)
    context_file_path = os.path.join(cut_graphs_fld, "context.json")
    with open(context_file_path, "r", encoding="utf-8") as f:
        context_data = json.load(f)

    # Keys are presumably cut-graph file names such as "<name>.onnx" (TODO confirm).
    # Strip whatever extension is present instead of blindly chopping 5 chars;
    # dict.fromkeys keeps first-seen order and drops duplicate names.
    graph_names = dict.fromkeys(os.path.splitext(key)[0] for key in context_data)

    for graph_name in graph_names:
        meta_file_path = os.path.join(cut_graphs_fld, graph_name + ".json")
        meta_json_update(meta_file_path, args.data_folder)

if __name__ == "__main__":
    # CLI entry point: declare the options, parse argv, and hand off to main().
    cli = argparse.ArgumentParser(
        description="Update meta json with model data",
        formatter_class=argparse.RawTextHelpFormatter,
        usage='use "%(prog)s --help" for more info',
    )
    cli.add_argument(
        "-out", "--out_dir",
        help="Directory to store cut graphs",
        required=True,
    )
    cli.add_argument(
        "-df", "--data_folder",
        help="Directory with ref data for input/output",
        required=False,
    )
    main(cli.parse_args())
