import numpy as np
import shutil
import argparse
import json
import os

# Per-op naming rules consumed by copy_and_rename().
#
# Each key is an op type as it appears in the tiling JSON under
# layer_info['op_type'].  The value describes how the output files are named:
#   folder_name        - output sub-folder and prefix of the file name
#   folder_name_act /
#   folder_name_const  - broadcast-add only: folder chosen by the kind of the
#                        second input
#   overlay            - when True the overlay name is added to the file name
#   inputs / wgts      - 1-based position -> label used for that tensor
#   output             - label used for the output tensor
# An empty dict means no naming details are defined; such ops are reported
# and skipped.
op_details = {
    "Add_qdq_EleWise_uint16xuint16xuint16": {
        "folder_name": "elwadd",
        "overlay": True,
        "inputs": {1: "a", 2: "a"},
        "output": "acc",
    },
    "Add_qdq_BroadCast_uint16xuint16xuint16": {
        "folder_name_const": "act_const_matvecadd",
        "folder_name_act": "matvecadd",
        "overlay": True,
        "inputs": {1: "a", 2: "a"},
        "output": "acc",
    },
    "Mul_qdq_EleWise_uint16xuint16xuint16": {
        "folder_name": "elwmul_qdq",
        "overlay": True,
        "inputs": {1: "a", 2: "a"},
        "output": "acc",
    },
    "Conv_qdq_bias_uint16xuint8xuint16": {
        "folder_name": "iconv",
        "overlay": False,
        "inputs": {1: "a"},
        "wgts": {1: "w"},
        "output": "acc",
    },
    "GroupNormalization_qdq_uint16xuint16xuint16": {
        "folder_name": "gpn",
        "overlay": True,
        "inputs": {1: "a"},
        "output": "acc",
    },
    "LayerNormalization_qdq_uint16xuint8xuint16": {
        "folder_name": "lrn",
        "overlay": True,
        "inputs": {1: "a"},
        "output": "acc",
    },
    # No naming details defined yet.
    "MHA_2p1_qdq_uint16xuint16xuint16": {},
    # No naming details defined yet.  Disabled draft entry:
    #   {"folder_name": "mhapsr", "overlay": True,
    #    "inputs": {1: "a", 2: "w"}, "output": "acc"}
    "MHA_3p0_1col_qdq_uint16xuint16xuint16": {},
    "MatMul_qdq_bias_uint16xuint8xuint16": {
        "folder_name": "gemm",
        "overlay": True,
        "inputs": {1: "a"},
        "wgts": {1: "w"},
        "output": "acc",
    },
    "MatMul_qdq_bias_uint16xint8xuint16": {
        "folder_name": "gemm",
        "overlay": True,
        "inputs": {1: "a"},
        "wgts": {1: "w"},
        "output": "acc",
    },
    "MatMul_qdq_Unsqueeze_bias_uint16xuint8xuint16": {
        "folder_name": "gemm",
        "overlay": True,
        "inputs": {1: "a"},
        "wgts": {1: "w"},
        "output": "acc",
    },
    "MatMul_qdq_uint16xuint8xuint16": {
        "folder_name": "gemm",
        "overlay": True,
        "inputs": {1: "a"},
        "wgts": {1: "w"},
        "output": "acc",
    },
    "MatMul_qdq_actxact_uint16xuint16xuint16": {
        "folder_name": "act_act_matmul_qdq",
        "overlay": True,
        "inputs": {1: "a"},
        "wgts": {1: "w"},
        "output": "acc",
    },
    "MatMul_qdq_biasgelu_uint16xuint8xuint16": {
        "folder_name": "gemmgelu",
        "overlay": False,
        "inputs": {1: "a"},
        "wgts": {1: "w"},
        "output": "acc",
    },
    "MatMul_qdq_biasgelu_uint16xint8xuint16": {
        "folder_name": "gemmgelu",
        "overlay": False,
        "inputs": {1: "a"},
        "wgts": {1: "w"},
        "output": "acc",
    },
    "Silu_qdq_uint16xuint16": {
        "folder_name": "silu_qdq",
        "overlay": True,
        "inputs": {1: "a"},
        "output": "acc",
    },
    "Gelu_qdq_uint16xuint16": {
        "folder_name": "gelu_qdq",
        "overlay": True,
        "inputs": {1: "a"},
        "output": "acc",
    },
    # No naming details defined yet.
    "Resize_qdq_uint16xuint16": {},
    # Note: no "output" label is defined for Slice.
    "Slice_qdq_uint16xuint16": {
        "folder_name": "slice",
        "overlay": True,
        "inputs": {1: "a"},
    },
    "Softmax_qdq_uint16xuint16": {
        "folder_name": "softmax_qdq",
        "overlay": True,
        "inputs": {1: "a"},
        "output": "acc",
    },
    "LpNormalization_qdq_uint16xuint16": {
        "folder_name": "l2_norm",
        "overlay": True,
        "inputs": {1: "a"},
        "output": "acc",
    },
    "RoPE_qdq_uint16xuint16": {
        "folder_name": "RoPE_const",
        "overlay": False,
        "inputs": {1: "a"},
        "output": "acc",
    },
}

# Ops that copy_and_rename() does not name with the generic
# "dims of every padding entry except the last" rule; each one has its own
# dedicated branch for building the dimension part of the file name.
special_ops = [
    "RoPE_qdq_uint16xuint16",
    "Add_qdq_EleWise_uint16xuint16xuint16",
    "Add_qdq_BroadCast_uint16xuint16xuint16",
]

# (file name inside a layer folder, suffix appended to the new base name)
_ARTIFACTS = (
    ("txn.bin", ".bin"),
    ("ctrl.bin", "_ctrl.bin"),
    ("param.bin", "_param.bin"),
    ("ctrl_meta.json", "_ctrl_meta.json"),
)


def _datatype_suffix(datatype):
    """Return the width tag used in file names for *datatype*, or None.

    "uint16"/"int16" -> "16", "uint8"/"int8" -> "8", "bfloat16" -> "bf16".
    Any other value yields None so the caller can report it.
    """
    if datatype in ("uint16", "int16"):
        return "16"
    if datatype in ("uint8", "int8"):
        return "8"
    if datatype == "bfloat16":
        return "bf16"
    return None


def _appendix_from_padding(key, val, op_type):
    """Build the file-name appendix for a layer that has host_layer_padding.

    The appendix is the concatenation of one "<label><width>" token per
    padded tensor followed by "_<dim>" tokens whose source depends on the op.
    """
    info = op_details[op_type]
    layer_info = val['layer_info']
    padding_info = val['host_layer_padding']

    appendix = ""
    in_ix = 1
    wgt_ix = 1
    for item in padding_info:
        for k in item:
            if k == "sin" or k == "cos":
                continue
            # .get() so that ops without a 'wgts'/'output' entry take the
            # skip path instead of raising KeyError.
            if "ifm" in k:
                labels = info.get('inputs')
                if not labels:
                    continue
                appendix += labels[in_ix]
                in_ix += 1
            elif "wgt" in k:
                labels = info.get('wgts')
                if not labels:
                    continue
                appendix += labels[wgt_ix]
                wgt_ix += 1
            elif "ofm" in k:
                label = info.get('output')
                if not label:
                    continue
                appendix += label
            else:
                print("ERROR KEY: ", key)

            # The datatype lives in the first layer_info field whose name
            # contains both the padding key and the word "datatype".
            datatype = ""
            for val_key in layer_info:
                if k in val_key and "datatype" in val_key:
                    datatype = layer_info[val_key]
                    break
            suffix = _datatype_suffix(datatype)
            if suffix is None:
                print("ERROR datatype: ", key, " ",
                      k, " ", datatype)
            else:
                appendix += suffix

    if op_type not in special_ops:
        # Generic rule: dims of every entry except the last one; a leading
        # dim equal to the previous entry's trailing dim is not repeated.
        last_dim = 0
        for item in padding_info[:-1]:
            for v in item.values():
                dims = v['dims']
                for ix, d in enumerate(dims):
                    if ix == 0 and d == last_dim:
                        continue
                    appendix += "_" + str(d)
                last_dim = dims[-1]
    elif op_type in ("RoPE_qdq_uint16xuint16",
                     "Add_qdq_EleWise_uint16xuint16xuint16"):
        # 'outputs' is a string; the last two comma-separated items inside
        # its second "[...]" group are used.
        # NOTE(review): presumably that group is the output shape - confirm
        # against the tiling JSON producer.
        o = layer_info['outputs']
        o = o.split("[")[2].split("]")[0].split(", ")
        appendix += "_" + o[-2] + "_" + o[-1]
    elif op_type == "Add_qdq_BroadCast_uint16xuint16xuint16":
        # Only the dims of the last padding entry are used.
        for v in padding_info[-1].values():
            for d in v['dims']:
                appendix += "_" + str(d)
    return appendix


def _appendix_without_padding(key, val, op_type):
    """Build the appendix for a layer without host_layer_padding.

    Only Slice and Conv ops are supported; returns None for any other op
    after reporting it.
    """
    print("OP without host_layer_padding: ", op_type)
    info = op_details[op_type]
    layer_info = val['layer_info']

    if op_type == "Slice_qdq_uint16xuint16":
        input_shape = layer_info['in_act_shape']
        in_datatype = layer_info['in_datatype']
        appendix = info['inputs'][1]
        suffix = _datatype_suffix(in_datatype)
        if suffix is None:
            print("ERROR datatype: ", key, " ",
                  "in_datatype", " ", in_datatype)
        else:
            appendix += suffix
        appendix += "_" + str(input_shape[1] * input_shape[2]) +\
                    "_" + str(input_shape[3])
        return appendix

    if "Conv" in op_type:
        input_shape = layer_info['in_act_shape']
        wgt_shape = layer_info['in_wgt_shape']
        output_shape = layer_info['out_act_shape']
        fields = ("in_datatype", "wgt_datatype", "out_datatype")
        datatypes = [layer_info[field] for field in fields]
        labels = [info['inputs'][1], info['wgts'][1], info['output']]
        appendix = ""
        for label, field, datatype in zip(labels, fields, datatypes):
            appendix += label
            suffix = _datatype_suffix(datatype)
            if suffix is None:
                print("ERROR datatype: ", key, " ",
                      field, " ", datatype)
            else:
                appendix += suffix
        # Shapes are emitted last-dim first, leaving out element 0.
        for d in reversed(input_shape[1:]):
            appendix += "_" + str(d)
        for d in reversed(output_shape[1:]):
            appendix += "_" + str(d)
        appendix += "_" + str(wgt_shape[0]) +\
                    "_" + str(wgt_shape[1])
        return appendix

    print("Unhandled op: ", op_type)
    return None


def copy_and_rename(args):
    """Copy each layer's binaries into per-op folders under descriptive names.

    For every entry of the tiling JSON whose op_type differs from its
    orig_op_type and has naming details in op_details, the files txn.bin,
    ctrl.bin, param.bin and ctrl_meta.json are copied from
    <input_dir>/<layer key>/ to
    <output_dir>/<folder_name>/<folder_name>[_<overlay>]_<appendix><suffix>.

    Args:
        args: object with the attributes
            tiling_json - path of the tiling JSON file,
            input_dir   - folder holding one sub-folder per layer key;
                          None means the folder of tiling_json,
            output_dir  - destination folder; None means "copy_folder"
                          next to tiling_json.

    Raises:
        KeyError: a layer entry lacks a field required for its op.
        OSError: the JSON file or one of the layer files cannot be
            read or copied.
    """
    tiling_json = args.tiling_json
    json_dir = os.path.dirname(tiling_json)
    input_dir = args.input_dir if args.input_dir is not None else json_dir
    if args.output_dir is not None:
        output_dir = args.output_dir
    else:
        output_dir = os.path.join(json_dir, "copy_folder")

    os.makedirs(output_dir, exist_ok=True)

    with open(tiling_json, encoding="utf-8") as f:
        tiling_dict = json.load(f)

    for key, val in tiling_dict.items():
        layer_info = val['layer_info']
        op_type = layer_info['op_type']
        if op_type == layer_info['orig_op_type']:
            continue
        if op_type not in op_details:
            print(op_type, " INFO DOES NOT EXIST")
            continue
        info = op_details[op_type]
        if not info:
            print(op_type, "info is empty")
            continue

        overlay = ""
        if info['overlay']:
            overlay = "_" + val['overlay_info']['overlay']

        if op_type == "Add_qdq_BroadCast_uint16xuint16xuint16":
            # Takes the quoted value of the first field of the second
            # "{...}" group in the 'inputs' string.
            # NOTE(review): presumably the kind of the second input -
            # confirm against the tiling JSON producer.
            kind = layer_info['inputs']
            kind = kind.split("{")[2].split(",")[0]
            kind = kind.split(" ")[1].split("'")[1]
            if kind == "act":
                folder_name = info['folder_name_act']
            elif kind == "const":
                folder_name = info['folder_name_const']
            else:
                # Without a folder name the copies would land directly in
                # output_dir under a name starting with "_"; skip instead.
                print("Wrong input type for add broadcast")
                continue
        else:
            folder_name = info['folder_name']

        if "host_layer_padding" in val:
            appendix = _appendix_from_padding(key, val, op_type)
        else:
            appendix = _appendix_without_padding(key, val, op_type)
            if appendix is None:
                continue

        file_template = folder_name + overlay + "_" + appendix
        subfolder_name = os.path.join(output_dir, folder_name)
        input_folder_template = os.path.join(input_dir, key)
        output_file_template = os.path.join(subfolder_name, file_template)
        os.makedirs(subfolder_name, exist_ok=True)
        for source_name, suffix in _ARTIFACTS:
            shutil.copy(os.path.join(input_folder_template, source_name),
                        output_file_template + suffix)




if __name__ == "__main__":
    # Command-line entry point: --tiling_json is mandatory, while
    # --input_dir and --output_dir are optional and default to None.
    cli = argparse.ArgumentParser()
    for flag, mandatory in (("--tiling_json", True),
                            ("--input_dir", False),
                            ("--output_dir", False)):
        cli.add_argument(flag, required=mandatory)

    copy_and_rename(cli.parse_args())
