import sys
import os
import onnx
import argparse
import json
import math
import re

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from runner_lib import onnx_graph as ogm
from runner_lib import fuse


def update_info(layer, op, key, padded_shape):
    """Fill ``op["attrs"]`` with shape/dtype info taken from a tiling layer.

    ``layer[key]`` is a Python-repr style string (single quotes) describing a
    list of tensors; it is turned into JSON by swapping quotes. Every ``act``
    tensor contributes its dims (flattened into one list of strings) and its
    dtype. When ``padded_shape`` is non-empty, the dims of an act tensor come
    from ``padded_shape[n]`` instead, where ``n`` counts both act and const
    tensors seen so far. Only entries whose name mentions
    input/output/ifm/ofm/sin/cos are used.

    Args:
        layer: tiling.json ``layer_info`` dict.
        op: op dict whose ``"attrs"`` is updated in place.
        key: ``"inputs"`` or ``"outputs"``; any other value is treated as outputs.
        padded_shape: list of ``{name: {"dims": [...]}}`` dicts, possibly empty.
    """
    tensors = json.loads(layer[key].replace("'", '"'))
    act_tags = ("input", "output", "ifm", "ofm", "sin", "cos")
    flat_shape = []
    dtypes = []
    tensor_idx = 0

    for entry in tensors:
        missing = next(
            (field for field in ("type", "shape", "dtype") if field not in entry),
            None,
        )
        if missing is not None:
            print(f"Error: missing '{missing}' in {key} for {op['name']}")
            continue

        kind = entry["type"]
        if kind == "const":
            # Consts are not reported, but still advance the padded-shape index.
            tensor_idx += 1
        elif kind == "act":
            if len(padded_shape) > 0:
                for name, spec in padded_shape[tensor_idx].items():
                    if any(tag in name for tag in act_tags):
                        flat_shape.extend(str(dim) for dim in spec["dims"])
            else:
                flat_shape.extend(str(dim) for dim in entry["shape"])
            dtypes.append(entry["dtype"])
            tensor_idx += 1

    is_input = key == "inputs"
    shape_attr = "input_shape" if is_input else "output_shape"
    dtype_attr = "input_datatype" if is_input else "output_datatype"
    op["attrs"][shape_attr] = {"type": "int", "value": flat_shape}
    op["attrs"][dtype_attr] = {"type": "str", "value": dtypes}


def update_info_noop(tensor_map, op, key):
    """Fill shape/dtype attrs for an op that has no tiling entry.

    The information is read from ``tensor_map`` for each tensor named in the
    op's ``in_args`` (``key == "inputs"``) or ``out_args`` (anything else).
    All dims are flattened into one list of strings.
    """
    is_input = key == "inputs"
    tensor_names = op["in_args"] if is_input else op["out_args"]

    flat_shape = []
    dtypes = []
    for name in tensor_names:
        info = tensor_map[name]
        dims = info["shape"]
        dtype = info["dtype"]
        flat_shape.extend(str(dim) for dim in dims)
        dtypes.append(dtype)

    shape_attr = "input_shape" if is_input else "output_shape"
    dtype_attr = "input_datatype" if is_input else "output_datatype"
    op["attrs"][shape_attr] = {"type": "int", "value": flat_shape}
    op["attrs"][dtype_attr] = {"type": "str", "value": dtypes}


def get_inp_out_padded_shapes(padded_shapes):
    """Split ``host_layer_padding`` entries into input and output groups.

    Names containing input/ifm/sin/cos go to the input list. Otherwise, names
    containing ofm/output go to the output list. Any other name is dropped.

    * dict form: each matching ``name: spec`` pair becomes its own
      single-entry dict ``{name: spec}``.
    * list form: each item (a dict) is appended whole, once per matching key
      it contains.

    Returns:
        ``(inputs, outputs)`` lists, in source order.
    """

    def _is_input(name):
        return any(tag in name for tag in ("input", "ifm", "sin", "cos"))

    def _is_output(name):
        return "ofm" in name or "output" in name

    if isinstance(padded_shapes, dict):
        candidates = [({name: spec}, name) for name, spec in padded_shapes.items()]
    else:
        candidates = [(item, name) for item in padded_shapes for name in item]

    inputs, outputs = [], []
    for entry, name in candidates:
        if _is_input(name):
            inputs.append(entry)
        elif _is_output(name):
            outputs.append(entry)
    return inputs, outputs


def add_pm_id(op, data):
    """Copy the node's ``pm_id`` list from the IR JSON into ``op["attrs"]``.

    Args:
        op: op dict; ``op["attrs"]["pm_id"]`` is set in place.
        data: parsed IR JSON, a dict keyed by node name whose values may
            carry ``attributes.pm_id``.

    When the node has no pm_id, the attribute still gets the legacy
    placeholder value ``["[]"]``. This covers an unknown node, a node with no
    ``attributes``, or attributes without ``pm_id``. An unknown node
    previously raised KeyError and aborted the whole run; it is now reported
    and skipped like the other malformed cases.
    """
    if not isinstance(data, dict):
        print("Error: incorrect format")
        return

    # NOTE(review): the placeholder is the *string* "[]" (historically str([])),
    # kept as-is for compatibility with downstream consumers.
    pm_id = {"pm_id": {"type": "int", "value": ["[]"]}}
    node_dict = data.get(op["name"])
    if node_dict is None:
        print(f"Error: '{op['name']}' not found in IR data")
    elif "attributes" in node_dict:
        attributes = node_dict["attributes"]
        if "pm_id" in attributes:
            pm_id["pm_id"]["value"] = [str(i) for i in attributes["pm_id"]]
    else:
        print("Error: missing 'attributes'")
    op["attrs"].update(pm_id)


def add_bkend(op, data):
    """Record the backend(s) of every tiling entry that contains this op.

    For each tiling.json entry whose ``layer_info.nodenames`` includes
    ``op["name"]`` and which has ``bkend_info``, its ``bkend_info.bkend`` is
    appended to ``op["attrs"]["bkend"]["value"]``. The attribute is left
    untouched when no entry matches.

    Null/empty/non-dict entries are skipped, as the other tiling walkers in
    this file do. Previously a null entry raised TypeError.
    """
    if not isinstance(data, dict):
        print("Error: incorrect format")
        return

    backends = []
    for value in data.values():
        if not isinstance(value, dict) or "layer_info" not in value:
            continue
        layer = value["layer_info"]
        if "nodenames" in layer:
            nodenames = layer["nodenames"]
        else:
            print("Error: missing 'nodenames'")
            nodenames = {}
        if op["name"] in nodenames and "bkend_info" in value:
            backends.append(value["bkend_info"]["bkend"])
            op["attrs"]["bkend"] = {"type": "str", "value": backends}


def add_conv7x7_special(op, data):
    """Tag the op with ``conv_special_flag = ["conv7x7_fold", ...]`` when needed.

    The tag is added for each tiling.json entry that contains the op (by
    ``layer_info.nodenames``) and whose
    ``layer_info.attributes.conv7x7_special_format[0]`` equals True. The
    attribute is not created when no entry qualifies.

    Null/empty/non-dict entries and layers without ``attributes`` are now
    skipped instead of raising TypeError/KeyError.
    """
    if not isinstance(data, dict):
        print("Error: incorrect format")
        return

    flags = []
    for value in data.values():
        if not isinstance(value, dict) or "layer_info" not in value:
            continue
        layer = value["layer_info"]
        if "nodenames" in layer:
            nodenames = layer["nodenames"]
        else:
            print("Error: missing 'nodenames'")
            nodenames = {}
        if op["name"] not in nodenames:
            continue

        special_flag = False
        layer_attr = layer.get("attributes")
        if layer_attr and "conv7x7_special_format" in layer_attr:
            special_flag = layer_attr["conv7x7_special_format"][0]
        # Equality (not truthiness) on purpose: accepts True/1, rejects e.g. "false".
        if special_flag == True:  # noqa: E712
            flags.append("conv7x7_fold")
            op["attrs"]["conv_special_flag"] = {"type": "str", "value": flags}


def add_shape_type(op, data, tensor_map, bin_dir):
    """Fill the op's input/output shape and dtype attributes.

    Uses the first tiling.json entry whose ``layer_info.nodenames`` contains
    ``op["name"]``. Padding is read via ``_load_padded_shapes``. If no entry
    matches and the op type mentions "noop" or "runtime", the info is taken
    from ``tensor_map`` instead.

    Host-layer padding (which may require opening a per-layer tiling.json on
    disk) is now loaded only for the matching entry. Previously it was loaded
    for every entry on every op, and a bad per-layer file could crash
    unrelated ops.
    """
    if not isinstance(data, dict):
        print("Error: incorrect format")
        return

    found_op = False
    for key, value in data.items():
        if not value:
            continue
        if "layer_info" not in value:
            print("Error: missing 'layer_info'")
            continue
        layer = value["layer_info"]
        if "nodenames" in layer:
            nodenames = layer["nodenames"]
        else:
            print("Error: missing 'nodenames'")
            nodenames = {}
        if op["name"] not in nodenames:
            continue

        padded_shapes_in, padded_shapes_out = _load_padded_shapes(key, value, bin_dir)
        update_info(layer, op, "inputs", padded_shapes_in)
        update_info(layer, op, "outputs", padded_shapes_out)
        found_op = True
        break

    # Ops absent from tiling.json (e.g. noop/runtime ops) fall back to tensor_map.
    if not found_op and ("noop" in op["type"] or "runtime" in op["type"]):
        update_info_noop(tensor_map, op, "inputs")
        update_info_noop(tensor_map, op, "outputs")


def _load_padded_shapes(key, value, bin_dir):
    """Return ``(inputs, outputs)`` host-layer padding for tiling entry ``key``.

    Uses the entry's own ``host_layer_padding`` when present. Otherwise uses
    the ``host_layer_padding`` of ``<bin_dir>/<key>/tiling.json`` when that
    file exists. Returns two empty lists when neither is available.
    """
    if "host_layer_padding" in value:
        return get_inp_out_padded_shapes(value["host_layer_padding"])
    small_tiling_json = os.path.join(bin_dir, key, "tiling.json")
    if os.path.isfile(small_tiling_json):
        with open(small_tiling_json, "r") as f:
            small_tiling_info = json.load(f)
        if "host_layer_padding" in small_tiling_info:
            return get_inp_out_padded_shapes(small_tiling_info["host_layer_padding"])
    return [], []


def add_onnx_idx(op):
    """Store cumulative argument offsets as ``op["attrs"]["onnx_arg_idx"]``.

    The value starts at "0". The running total after ``in_args`` and then
    after ``const_args`` is appended as a string. A missing group is reported
    and contributes no offset.
    """
    offsets = ["0"]
    running = 0
    for field in ("in_args", "const_args"):
        if field not in op:
            print(f"Error: missing '{field}' in {op}")
            continue
        running += len(op[field])
        offsets.append(str(running))
    op["attrs"]["onnx_arg_idx"] = {"type": "int", "value": offsets}


def add_format(op, key):
    """Initialise ``op["attrs"][key]`` with one empty format string per argument.

    ``key == "input_format"`` sizes the list by ``in_args``; any other key
    sizes it by ``out_args``. Later passes append suffixes to these strings.
    """
    args = op["in_args"] if key == "input_format" else op["out_args"]
    op["attrs"][key] = {"type": "str", "value": [""] * len(args)}


def update_nchw_flag(op, tensor_map, nchw_info, key):
    """Mark arguments whose layout was converted from NCHW to NHWC.

    For every arg (``in_args`` for ``"input_format"``, else ``out_args``), the
    tensor's ``tensor_map`` "format" is reset to "". When ``nchw_info`` says
    the tensor went NCHW -> NHWC, "_NCHW" is appended to the matching entry of
    ``op["attrs"][key]["value"]`` and the tensor's format is set to "NCHW".
    ``op["attrs"][key]`` is expected to have been created by ``add_format``.
    """
    arg_field = "in_args" if key == "input_format" else "out_args"
    if arg_field not in op:
        return

    for pos, tensor_name in enumerate(op[arg_field]):
        tensor_map[tensor_name]["format"] = ""
        if tensor_name not in nchw_info:
            continue
        layout = nchw_info[tensor_name]
        if (layout["orig_layout"], layout["final_layout"]) != ("NCHW", "NHWC"):
            continue
        op["attrs"][key]["value"][pos] += "_" + "NCHW"
        tensor_map[tensor_name]["format"] = "NCHW"


def add_convertion(op, data, key):
    """Append dtype-conversion suffixes to the op's per-argument formats.

    The float32 op is matched against its bfloat16 tiling entry: the same
    name, with the op_type's "float32" replaced by "bfloat16". For each arg,
    if the tiling tensor at the same index is an ``act`` with
    ``hw_dtype == "bfloat16"``, the format is suffixed. Inputs get "_f2bf";
    outputs get "_bf2f". noop/runtime ops are skipped.

    The matching tiling entry is looked up once rather than once per
    argument, so "Can't find ..." is printed at most once. Null, non-dict, or
    ``layer_info``/``nodenames``-less entries are skipped instead of raising.
    """
    if "noop" in op["type"] or "runtime" in op["type"]:
        return
    is_input = key == "input_format"
    args_key = "in_args" if is_input else "out_args"
    if not op.get(args_key):
        return
    tensor_list = "inputs" if is_input else "outputs"
    convert = "_f2bf" if is_input else "_bf2f"
    tiling_op_type = op["type"].replace("float32", "bfloat16")

    layer = None
    for value in data.values():
        if not isinstance(value, dict) or "layer_info" not in value:
            continue
        info = value["layer_info"]
        if info.get("op_type") == tiling_op_type and op["name"] in info.get(
            "nodenames", ()
        ):
            layer = info
            break
    if layer is None:
        print("Can't find the key in tiling.json for ", op["name"])
        return

    # tiling.json stores the tensor list as a single-quoted Python-repr string.
    tensors = json.loads(layer[tensor_list].replace("'", '"'))
    formats = op["attrs"][key]["value"]
    for i in range(len(op[args_key])):
        tensor = tensors[i]
        if tensor["type"] != "act":
            print("Input is not act for node ", op["name"])
            continue
        if tensor.get("hw_dtype") == "bfloat16":
            formats[i] = formats[i] + convert


def add_layername(op):
    """Store the op's own name as its ``layer_name`` attribute."""
    name = op["name"]
    op["attrs"]["layer_name"] = {"type": "str", "value": [name]}


def add_tkey(op, data):
    """Set ``op["attrs"]["tkey"]`` to the tiling.json key describing this op.

    The first entry matching both ``op_type`` and node name is preferred.
    Otherwise the first entry whose ``nodenames`` contain the op name is
    used. noop/runtime ops are skipped.

    Entries that are null, empty, non-dict, or lack ``layer_info`` are skipped.
    A missing ``nodenames`` counts as no match. Previously both cases raised
    KeyError.
    """
    op_type = op["type"]
    if "noop" in op_type or "runtime" in op_type:
        return
    if not isinstance(data, dict):
        print("Error: incorrect format")
        return

    layers = [
        (key, value["layer_info"])
        for key, value in data.items()
        if isinstance(value, dict) and "layer_info" in value
    ]
    name = op["name"]
    tkey = next(
        (
            key
            for key, info in layers
            if info.get("op_type") == op_type and name in info.get("nodenames", ())
        ),
        None,
    )
    if tkey is None:
        # Second chance: match on node name only, ignoring op_type.
        tkey = next(
            (key for key, info in layers if name in info.get("nodenames", ())),
            None,
        )

    if tkey is None:
        print("Can't find the key in tiling.json for ", name)
    else:
        op["attrs"]["tkey"] = {"type": "str", "value": [tkey]}


def update_op_info(op, data, ir_data, nchw_info, tensor_map, bin_dir):
    """Run every metadata annotation pass over one op, in dependency order.

    Shapes/dtypes and the tiling key come first. Then the per-argument format
    strings are created (``add_format``). Only after that are they suffixed
    with layout (``update_nchw_flag``) and dtype-conversion
    (``add_convertion``) markers. The remaining passes add independent
    attributes.
    """
    add_shape_type(op, data, tensor_map, bin_dir)
    add_tkey(op, data)

    directions = ("input_format", "output_format")
    for direction in directions:
        add_format(op, direction)
    for direction in directions:
        update_nchw_flag(op, tensor_map, nchw_info, direction)
    for direction in directions:
        add_convertion(op, data, direction)

    add_layername(op)
    add_onnx_idx(op)
    add_pm_id(op, ir_data)
    add_bkend(op, data)
    add_conv7x7_special(op, data)


def _find_file_in_dir(directory, pattern):
    """Return the first file in ``directory`` whose name contains ``pattern``.

    Entries are scanned in sorted order so the choice is deterministic.
    Returns ``None`` when nothing matches. An empty ``directory`` means the
    current one.
    """
    for entry in sorted(os.listdir(directory or ".")):
        if pattern in entry:
            return os.path.join(directory, entry)
    return None


def main(args):
    """Build the runtime metadata JSON for an ONNX model.

    Reads tiling.json (``--tiling_name``), the IR JSON and the tensor map.
    The IR and tensor map are taken from ``--ir_name`` /
    ``--tensormap_file_name`` when given; otherwise the first ``*IR.json`` /
    ``*tensor_map.json`` next to tiling.json is used. Every op produced by
    ``fuse.prepare_metadata`` is annotated, and the result is written to
    ``<json_folder>/<file_name>.json``.

    Fixes: the explicit ``--ir_name`` / ``--tensormap_file_name`` options
    were previously ignored. The old ``"" in args`` membership test was
    always False. A missing IR/tensor-map file now raises a descriptive
    FileNotFoundError instead of failing on ``open("")``.
    """
    model_name = args.model_name
    if not model_name.endswith(".onnx"):
        print(f"Error: model '{model_name}' is not an .onnx file")
        return

    dir_name = args.json_folder
    tiling_name = args.tiling_name
    data_folder = getattr(args, "data_folder", None)
    meta_file_name = args.file_name

    bin_dir = os.path.dirname(tiling_name)
    with open(tiling_name, "r") as f:
        tiling_info = json.load(f)

    ir_name = getattr(args, "ir_name", None) or _find_file_in_dir(bin_dir, "IR.json")
    if not ir_name:
        raise FileNotFoundError(f"No IR.json given or found in '{bin_dir}'")
    with open(ir_name, "r") as f:
        ir_info = json.load(f)

    tensormap_file_name = getattr(
        args, "tensormap_file_name", None
    ) or _find_file_in_dir(bin_dir, "tensor_map.json")
    if not tensormap_file_name:
        raise FileNotFoundError(f"No tensor_map.json given or found in '{bin_dir}'")
    with open(tensormap_file_name, "r") as f:
        nchw_info = json.load(f)

    json_name = os.path.join(dir_name, meta_file_name + ".json")
    os.makedirs(dir_name, exist_ok=True)
    print(model_name)
    onnx_model = onnx.load(model_name)
    meta_info = fuse.prepare_metadata(ogm.ONNXGraph(onnx_model), dir_name)
    op_list = meta_info[0]
    tensor_map = meta_info[2]

    for op in op_list:
        update_op_info(op, tiling_info, ir_info, nchw_info, tensor_map, bin_dir)

    # Attach the reference input/output data file to each graph I/O tensor.
    if data_folder is not None:
        full_data_folder = os.path.abspath(data_folder)
        sanitize = str.maketrans("/.:", "___")
        for key, value in tensor_map.items():
            if value["packed_buffer_label"] not in ("in", "out"):
                continue
            if key not in nchw_info:
                continue
            file_name = nchw_info[key]["orig_tensor"].translate(sanitize) + ".bin"
            full_file_name = os.path.join(full_data_folder, file_name)
            if os.path.isfile(full_file_name):
                value["file_name"] = full_file_name
            else:
                print("IO file " + file_name + " does not exist")

    fuse.save_tensors_to_json(json_name, *meta_info)


if __name__ == "__main__":
    # Command-line entry point: parse arguments and generate the metadata JSON.
    cli = argparse.ArgumentParser()
    for required_flag in ("--model_name", "--json_folder", "--tiling_name"):
        cli.add_argument(required_flag, required=True)
    cli.add_argument("--ir_name", required=False, default=None)
    cli.add_argument("--data_folder", required=False)
    cli.add_argument("--file_name", default="meta", required=False)
    cli.add_argument("--tensormap_file_name", required=False, default=None)

    main(cli.parse_args())
