#!/usr/bin/env python3

'''
This script estimates the end-to-end (e2e) latency of a model, given the log
file and the WAIC output dir.

The log file must be obtained from a run with `-t build_run -profile_perf`,
and <output_dir>/*_mod_nhwc_fused_IR_unique_nodes.json must be consistent
with that log file.

The output file is out.csv.
TODO:
  Currently, this script only processes one log file. However, some nodes
may need to be run several times. As of this writing, this script cannot
combine multiple log files. We may need to investigate why we cannot get
consistent results from the HW.

'''

import glob
import os
import sys
import json
import re
import argparse

def parse_args():
    """Parse the command line and return the resulting namespace.

    Both options are mandatory:
      --WAIC_output  the WAIC output directory
      -l / --log     the profiling log file
    """
    cli = argparse.ArgumentParser(description="Batch process")
    cli.add_argument(
        "--WAIC_output",
        type=str,
        required=True,
        help="The WAIC output dir",
    )
    cli.add_argument(
        "-l",
        "--log",
        type=str,
        required=True,
        help="the log file",
    )
    return cli.parse_args()

def abort(msg):
    """Print *msg* to stdout, then terminate the process with exit status 1."""
    print(msg)
    raise SystemExit(1)

def find_unique_node_json_file(output_dir):
    """Return the path of the single unique-nodes JSON file in *output_dir*.

    The file is located with the glob
    ``<output_dir>/*_mod_nhwc_fused_IR_unique_nodes.json``.  When nothing
    matches, or more than one file matches, ``abort`` is called, which
    prints a message and exits the process.
    """
    pattern = os.path.join(output_dir, "*_mod_nhwc_fused_IR_unique_nodes.json")
    candidates = glob.glob(pattern)

    if not candidates:
        abort(f"{output_dir}/*_mod_nhwc_fused_IR_unique_nodes.json does not exist")
    if len(candidates) > 1:
        abort(f"{output_dir}/*_mod_nhwc_fused_IR_unique_nodes.json have multiple match")

    (only_match,) = candidates[:1]
    return only_match

def get_node_freq(output_dir):
    """Load the unique-nodes JSON from *output_dir* and return node -> frequency.

    Each top-level key of the JSON object is a node name; its value must
    carry a "Frequency" entry, which is converted with ``int()``.
    """
    json_path = find_unique_node_json_file(output_dir)
    with open(json_path) as fh:
        nodes = json.load(fh)
    return {name: int(info["Frequency"]) for name, info in nodes.items()}

'''
hint: The table looks like this:

+-------------+---------------+-----------------+----------------+-------------------+-------------+-------------+--------------------------+--------------+-----------+
|    Shape    | Maximum Error | Maximum Error % |    L2 norm     | L2 norm p/element |  RMS error  |  RMA error  | Average Relative Error % | XRT time(us) | Pass/Fail |
+-------------+---------------+-----------------+----------------+-------------------+-------------+-------------+--------------------------+--------------+-----------+
|    Add_0    |     31970     |    100.000000   | 3193261.250000 |     20.789461     | 8147.771973 | 2288.969727 |         7.256557         |    89.77     |    Fail   |
|    Add_1    |       13      |     0.041785    |   538.806091   |      0.004385     |   1.537064  |   0.944743  |         0.002999         |    89.77     |    Pass   |
...
+-------------+---------------+-----------------+----------------+-------------------+-------------+-------------+--------------------------+--------------+-----------+
'''

def get_xrt_time(logfile):
    """Extract per-layer XRT times from the result table in *logfile*.

    The table is located by its header row, i.e. the first line matching
    "Average Relative Error.*XRT time".  After that:
      * the '+---' border right below the header is skipped;
      * each data row is split on whitespace (with '|' removed); the first
        field is the layer name, the second-to-last is the XRT time (kept
        as a string) and the last is the Pass/Fail verdict;
      * the next '+---' border ends the table.

    Returns:
        (run_succ, run_failed): dicts mapping layer name -> XRT time string
        for rows marked "Pass" and for every other row, respectively.
    """
    header_re = re.compile("Average Relative Error.*XRT time")

    in_table = False
    rows_seen = 0
    passed = {}
    failed = {}

    with open(logfile) as log:
        for raw in log:
            # A trailing '^M' ('\r') may end up on a line of its own; after
            # strip() such a line is empty and is simply ignored.
            text = raw.strip()
            if not text:
                continue

            # Until the header row has been seen, nothing else matters.
            if not in_table:
                if header_re.search(text):
                    in_table = True
                continue

            if text.startswith("+"):
                # The border directly under the header precedes any data
                # row; any later border is the bottom of the table.
                if rows_seen == 0:
                    continue
                break

            rows_seen += 1
            cells = re.sub(' +', ' ', text.replace("|", "").strip()).split(' ')
            layer, xrt_time, verdict = cells[0], cells[-2], cells[-1]

            target = passed if verdict == "Pass" else failed
            target[layer] = xrt_time

    return passed, failed

def get_hw_timer(logfile):
    """Collect per-node HW timer results from *logfile*.

    Recognised log entries (after stripping surrounding whitespace):

        record_timer_ts record_timer_ts_Dequant_9.json        : 10.72us
        Failed to collect record timer result for record_timer_ts_Dequant_7.json

    Returns:
        (succ_dict, fail_dict): ``succ_dict`` maps node name -> timer value
        in microseconds (float); ``fail_dict`` maps node name -> True for
        every node whose timer result could not be collected.

    Lines that do not fit either shape are ignored.  This includes lines
    that start like a timer entry but are malformed; they are not stored
    under a mangled key.
    """
    # Anchored patterns with an escaped '.json'. The whitespace around ':'
    # is optional, so "X.json: 5us" yields node "X" rather than "X.json:".
    succ_re = re.compile(
        r"^record_timer_ts\s+record_timer_ts_(?P<node>\S+?)\.json\s*:\s*"
        r"(?P<time>\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)\s*us$"
    )
    fail_re = re.compile(
        r"^Failed to collect record timer result for "
        r"record_timer_ts_(?P<node>\S+?)\.json\b"
    )

    succ_dict = {}
    fail_dict = {}

    with open(logfile) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            m = succ_re.match(line)
            if m:
                succ_dict[m.group("node")] = float(m.group("time"))
                continue

            m = fail_re.match(line)
            if m:
                fail_dict[m.group("node")] = True

    return succ_dict, fail_dict

def emit_result(args, node_freq, xrt_succ, xrt_fail, timer_succ, timer_fail):
    """Write the per-node report to out.csv and print a summary.

    Args:
        args: parsed command line.  Only ``args.WAIC_output`` is read, and
            only to name the IR file in the consistency warning.
        node_freq: node name -> frequency (from the unique-nodes JSON).
        xrt_succ: node name -> XRT time for rows that passed.
        xrt_fail: node name -> XRT time for rows that failed.
        timer_succ: node name -> HW timer value in us.
        timer_fail: container of node names whose HW timer was not collected.

    Each line of out.csv holds six space-separated fields: node name,
    frequency, HW timer (0 when absent), HW status, XRT time (0.0 when
    absent) and XRT status.  Status values are 'pass', 'fail' or 'na'.
    """
    all_nodes = set(node_freq)
    xrt_node_set = set(xrt_succ) | set(xrt_fail)

    # `xrt_node_set not in all_nodes` would ask whether the whole set is an
    # *element* of all_nodes, which is always true. Use set difference.
    unknown_nodes = xrt_node_set - all_nodes
    if unknown_nodes:
        ir_file = find_unique_node_json_file(args.WAIC_output)
        print("It looks like the log file is not consistent with the output folder")
        print(f"following nodes are not defined in {ir_file}")
        print(",".join(sorted(unknown_nodes)))

    weighted_timer = 0.0
    succ_num = 0
    with open("out.csv", "w") as f:
        for k, freq in node_freq.items():
            timer_status = "na"
            if k in timer_fail:
                timer_status = "fail"
            if k in timer_succ:
                # A collected timer value wins over a reported failure.
                timer_status = "pass"
                succ_num += 1
                weighted_timer += freq * timer_succ[k]

            xrt_status = "na"
            xrt_time = 0.0
            if k in xrt_succ:
                xrt_status = "pass"
                xrt_time = xrt_succ[k]
            if k in xrt_fail:
                # For XRT, a failing row takes precedence over a passing one.
                xrt_status = "fail"
                xrt_time = xrt_fail[k]

            f.write(f"{k} {freq} {timer_succ.get(k, 0)} {timer_status} {xrt_time} {xrt_status}\n")

    print("the result is written to out.csv with 6 space-separated fields")
    print(" - layer name")
    print(" - layer frequency")
    print(" - HW timer in us, 0 if it was not successful")
    print(" - HW run status: fail: failed, pass: succ, na: no data")
    print(" - XRT time in us")
    print(" - XRT status: fail: failed, pass: succ, na: no data")
    print(f"{succ_num} out of {len(node_freq)} were run successfully, weighted time (us): {weighted_timer}")

def main():
    """Entry point: gather node frequencies, XRT times and HW timers, then report."""
    args = parse_args()
    log_path = args.log

    # 1. node frequency from the unique-nodes JSON in the WAIC output dir
    node_freq = get_node_freq(args.WAIC_output)
    # 2. XRT times from the result table in the log
    xrt_succ, xrt_fail = get_xrt_time(log_path)
    # 3. HW timer values from the record_timer_ts lines in the log
    timer_succ, timer_fail = get_hw_timer(log_path)
    # 4. write out.csv and print the summary
    emit_result(args, node_freq, xrt_succ, xrt_fail, timer_succ, timer_fail)

# Only run when executed as a script, so importing this module (e.g. to
# reuse the parsing helpers) does not parse sys.argv or exit the process.
if __name__ == "__main__":
    main()
