import argparse
import os
import json
import subprocess
import shutil
from ml_dtypes import bfloat16
import numpy as np
from collections import defaultdict
from pathlib import Path

# Import-time sanity check; the resulting dtype object is discarded.
# Plain NumPy does not understand the name "bfloat16"; it resolves only after
# ml_dtypes (imported above) has registered it. So this presumably exists to
# fail fast with a TypeError if that registration is missing.
# NOTE(review): confirm this is the intent — the line has no other effect.
np.dtype("bfloat16")


def _find_ir_json(bin_dir):
    """Return the path of the first entry in *bin_dir* whose name contains "IR.json".

    Raises:
        FileNotFoundError: if *bin_dir* has no such entry.
    """
    for fname in os.listdir(bin_dir):
        if "IR.json" in fname:
            return os.path.join(bin_dir, fname)
    raise FileNotFoundError(f"No file containing 'IR.json' found in {bin_dir}")


def _get_pm_id(node_dict, nodename):
    """Return the last value listed under ``attributes["pm_id"]`` of an IR node, else 0."""
    attributes = node_dict.get("attributes")
    if attributes is None:
        print(f"Error: missing 'attributes' for node {nodename}")
        return 0
    pm_ids = list(attributes.get("pm_id", ()))
    return pm_ids[-1] if pm_ids else 0


def _sanitize_nodename(nodename):
    """Map '.', '/' and '#' to '_' to build the node's directory name under wgt_folder."""
    return nodename.translate(str.maketrans("./#", "___"))


def _copy_pm_bins(pm_bins_dir, dest_dir, pm_id):
    """Copy pm_<id>.bin and txn_pm_<id>.bin from *pm_bins_dir* into *dest_dir*.

    Missing source files are reported on stdout and skipped.
    """
    for prefix, label in (("pm_", "PM bin: "), ("txn_pm_", "TXN PM bin: ")):
        bin_file = f"{prefix}{pm_id}.bin"
        src = os.path.join(pm_bins_dir, bin_file)
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(dest_dir, bin_file))
        else:
            print(label, src, " not found")


def _select_out_type(nodename, layer_info):
    """Return "bf16" for listed node kinds whose layer has disable_q[0] == 1, else "uint16"."""
    bf16_markers = ("WSilu", "WGelu", "biasgelu", "_Add_qdq", "_Mul_qdq",
                    "Silu_qdq", "Gelu_qdq", "GroupNorm", "LayerNorm")
    if not any(marker in nodename for marker in bf16_markers):
        return "uint16"
    # A missing/empty disable_q is treated as "quantization enabled".
    disable_q = layer_info.get('attributes', {}).get('disable_q') or [0]
    return "bf16" if disable_q[0] == 1 else "uint16"


def _load_out_scale(fld_name, op_type, nodename):
    """Return the output scale for a node from ``<fld_name>/graph_params.json``.

    Transpose layers do not read graph_params.json and use a scale of 0.
    The key read depends on the op type / node name; later matches below
    take precedence over earlier ones.
    """
    if op_type == "Transpose":
        return 0
    scale_name = 'output_scale'
    if "Quant" in op_type or "Dequant" in op_type:
        scale_name = 'input_scale'
    if "WSilu" in nodename:
        scale_name = 'out_scale'
    if "MatMul_qdq_silu" in nodename:
        # NOTE(review): key spelling looks unusual ("diffcale"); it must match
        # the key written into graph_params.json — confirm with the generator.
        scale_name = 'silu_smax_diffcale'
    if "MHA_2p1" in nodename:
        scale_name = 'sm_out_scale'
    with open(os.path.join(fld_name, "graph_params.json"), 'r') as f:
        return json.load(f)[scale_name]


def _extract_field(text, label):
    """Return the text after the first *label* up to the end of that line, or None."""
    if label not in text:
        return None
    return text.split(label, 1)[1].split("\n", 1)[0]


def _percent_of(count, total):
    """Return *count* as a percentage of *total*, rounded to 3 places (0.0 if total is 0)."""
    return round(float(count) * 100 / total, 3) if total else 0.0


def _record_node(bucket, nodename):
    """Increment ``bucket['count']`` and append *nodename* to ``bucket['nodes']``."""
    bucket['count'] = bucket.get('count', 0) + 1
    bucket.setdefault('nodes', []).append(nodename)


def _parse_di_failure(result, out_type, out_scale, out_shape, err_threshold=2):
    """Build the per-node data-integrity failure record from the runner's stdout.

    Args:
        result: decoded stdout of the xrt runner.
        out_type: "uint16" or "bf16", as passed to the runner.
        out_scale: output scale used to convert a uint16 error to a real
            value (ignored when out_type is "bf16").
        out_shape: number of output elements, used for percentages.
        err_threshold: per-element "Error Percentage" above which an element
            counts toward ``total_error_perc``.

    Returns:
        dict with optional ``max_diff_percentage`` / ``l2_per_element`` /
        ``max_err_uint16`` (strings as printed by the runner) plus
        ``max_diff_bfloat16``, ``mismatch_count``, ``mismatch_percentage``,
        ``max_perc`` and ``total_error_perc``.
    """
    record = {}
    max_diff = _extract_field(result, "Maximum Error = ")
    max_percent = _extract_field(result, "Maximum Error Percentage = ")
    if max_percent is not None:
        record['max_diff_percentage'] = max_percent
    l2_per_elem = _extract_field(result, "L2 norm per element = ")
    if l2_per_elem is not None:
        record['l2_per_element'] = l2_per_elem
    if max_diff is None:
        max_diff = 0
    if out_type != "bf16":
        record['max_err_uint16'] = max_diff
        # Scale the integer-domain error by the output scale, in bfloat16.
        max_diff_bfloat = np.float32(bfloat16(bfloat16(out_scale) *
                                              bfloat16(int(max_diff))))
    else:
        max_diff_bfloat = max_diff
    record['max_diff_bfloat16'] = float(max_diff_bfloat)

    mismatch_count = int(_extract_field(result, "Mismatch count = ") or 0)
    record['mismatch_count'] = mismatch_count
    record['mismatch_percentage'] = _percent_of(mismatch_count, out_shape)

    max_error_percentage = 0
    over_threshold = 0
    for chunk in result.split('ERROR:'):
        if 'Error Percentage : ' not in chunk:
            continue
        err = float(chunk.split('Error Percentage : ', 1)[1].split(',', 1)[0])
        max_error_percentage = max(max_error_percentage, err)
        # Count this element only if *its own* error exceeds the threshold
        # (comparing the running maximum would count every later element).
        if err > err_threshold:
            over_threshold += 1
    record['max_perc'] = max_error_percentage
    record['total_error_perc'] = _percent_of(over_threshold, out_shape)
    return record


def main(args):
    """Run each selected node through the xrt test runner and save a JSON summary.

    For every layer in ``args.tiling_json``:

    * Skip layers whose ``op_type`` equals ``orig_op_type``, unless the op is
      Quant, Dequant or Transpose. When ``args.op_list`` is not "all", also
      skip layers whose op type is not in that comma-separated list.
    * For every node in ``layer_info['nodenames']``:
      - look up its ``pm_id`` in the IR JSON found next to the tiling JSON;
      - copy ``pm_<id>.bin`` / ``txn_pm_<id>.bin`` from
        ``<work_dir>/prebuilt/xclbin`` (``work_dir`` is the parent of the
        tiling JSON's directory) into the node's data directory under
        ``args.wgt_folder``;
      - copy ``patch.json`` to ``ctrl_meta.json`` there;
      - run ``args.xrt`` with that directory as its working directory.
    * Classify the runner's stdout as ERT_CMD_fail, DI_fail, Pass or Others.

    Args:
        args: parsed command-line namespace. Reads ``tiling_json``, ``xrt``,
            ``xclbin``, ``wgt_folder``, ``output`` and ``op_list``.

    Raises:
        FileNotFoundError: if no file containing "IR.json" is next to the
            tiling JSON.
    """
    import math  # only needed here

    op_list = None if args.op_list == "all" else args.op_list.split(',')
    if op_list is not None:
        print(op_list)
    output_file = os.path.abspath(args.output)
    tiling_json = os.path.abspath(args.tiling_json)
    with open(tiling_json, 'r') as f:
        tiling_dict = json.load(f)

    bin_dir = os.path.dirname(tiling_json)
    work_dir = os.path.dirname(bin_dir)
    pm_bins_dir = os.path.join(work_dir, "prebuilt", "xclbin")
    with open(_find_ir_json(bin_dir), 'r') as f:
        ir_info = json.load(f)

    xrt_path = os.path.abspath(args.xrt)
    xclbin_path = os.path.abspath(args.xclbin)
    full_results = {}
    for key, value in tiling_dict.items():
        layer_info = value['layer_info']
        op_type = layer_info['op_type']
        if (op_type == layer_info['orig_op_type']
                and op_type not in ("Quant", "Dequant", "Transpose")):
            continue
        if op_list is not None and op_type not in op_list:
            continue

        shape_key = 'out_ofm_shape' if 'out_ofm_shape' in layer_info else 'out_act_shape'
        out_shape = math.prod(layer_info[shape_key])
        results_dict = {
            'total_count': layer_info['Frequency'],
            'op_type': op_type,
            'inputs': layer_info['inputs'],
            'outputs': layer_info['outputs'],
        }
        pass_dict = {}
        di_fail_dict = {}
        cmd_fail_dict = {}
        others_dict = {}

        for raw_nodename in layer_info['nodenames']:
            pm_id = _get_pm_id(ir_info[raw_nodename], raw_nodename)
            nodename = _sanitize_nodename(raw_nodename)
            fld_name = os.path.abspath(os.path.join(args.wgt_folder, nodename))
            print("Starting for ", fld_name)
            if not Path(fld_name).is_dir():
                print("ERROR: Data is not dumped. Directory does not exist: ", fld_name)
                continue
            _copy_pm_bins(pm_bins_dir, fld_name, pm_id)
            patch_json = os.path.join(fld_name, "patch.json")
            if os.path.exists(patch_json):
                shutil.copy2(patch_json, os.path.join(fld_name, "ctrl_meta.json"))
            out_type = _select_out_type(nodename, layer_info)

            # Run inside the node directory via cwd= instead of os.chdir so the
            # parent's working directory is never left changed on error.
            result = subprocess.run(
                [xrt_path, xclbin_path, "1", "1", out_type, "0",
                 "-id", str(pm_id), "1"],
                stdout=subprocess.PIPE, cwd=fld_name, check=False,
            ).stdout.decode('utf-8', errors='replace')
            print(result)

            if "ERT_CMD_STATE_ERROR" in result:
                _record_node(cmd_fail_dict, nodename)
            elif "FAILED" in result:
                # The scale is only needed to convert uint16 errors.
                out_scale = (_load_out_scale(fld_name, op_type, nodename)
                             if out_type != "bf16" else 0)
                di_fail_dict['count'] = di_fail_dict.get('count', 0) + 1
                di_fail_dict.setdefault('nodes', {})[nodename] = _parse_di_failure(
                    result, out_type, out_scale, out_shape)
            elif "PASS" in result:
                _record_node(pass_dict, nodename)
            else:
                _record_node(others_dict, nodename)
            print("Finished for ", fld_name)

        for label, bucket in (('Pass', pass_dict), ('DI_fail', di_fail_dict),
                              ('ERT_CMD_fail', cmd_fail_dict),
                              ('Others', others_dict)):
            if bucket:
                results_dict[label] = bucket
        full_results[key] = results_dict

    with open(output_file, 'w') as f:
        json.dump(full_results, f, indent=2)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run each dumped node through the xrt test runner and "
                    "summarize pass / data-integrity-failure / command-failure "
                    "results into a JSON file.")
    parser.add_argument("--tiling_json", required=True,
                        help="Tiling JSON; a file whose name contains "
                             "'IR.json' is expected in the same directory.")
    parser.add_argument("--xclbin", required=True,
                        help="xclbin path passed to the xrt runner.")
    parser.add_argument("--xrt", required=True,
                        help="xrt test-runner executable to invoke per node.")
    parser.add_argument("--wgt_folder", default="WAIC_Outputs/DataGen/Consts",
                        help="Directory containing one data sub-directory "
                             "per node.")
    parser.add_argument("--output", default="model_data_run_results.json",
                        help="Path of the JSON results file to write.")
    parser.add_argument("--bin_folder", default="WAIC_Outputs/",
                        help="Currently unused.")
    parser.add_argument("--ifm_ofm_folder",
                        default="WAIC_Outputs/DataGen/Activations/carf",
                        help="Currently unused.")
    parser.add_argument("--op_list", default="all",
                        help="Comma-separated op types to run, or 'all'.")
    args = parser.parse_args()
    main(args)
