# fmt: on
import argparse
import csv
import ctypes
import glob
import json
import logging
import multiprocessing
import ntpath
import os
import shutil
import subprocess
import sys
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from enum import Enum
from pathlib import Path

import yaml

from dataflow.xclbin.auto_xclbin import gen_xclbin
from HW_requirements.collect_pm_ids_script import collect_pm_ids
from dataflow.xclbin.auto_xclbin import gen_xclbin

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from config_loader import waic_config

## WAIC Components


# TODO: Using sys.path.append is an anti-pattern as it
# can cause confusing name shadowing, etc.
# Unfortunately, changing this here has knock-on effects for
# running other entry points of the code.
# Absolute directory that contains this file; every component path below is
# derived from it.
REPO_ROOT = os.path.abspath(os.path.dirname(__file__))
# Component source directories. Each of these is appended to sys.path further
# down so that the modules inside them can be imported by bare name.
l1_fusion_path = os.path.join(REPO_ROOT, "OGOAT", "src", "L1_fusion")
l2_fusion_path = os.path.join(REPO_ROOT, "OGOAT", "src", "L2_fusion")
tiler_path = os.path.join(REPO_ROOT, "OGOAT", "src", "Tiler")
scheduler_path = os.path.join(REPO_ROOT, "OGOAT", "src", "Scheduling_Engine")
misc_tools_path = os.path.join(REPO_ROOT, "OGOAT", "misc_tools")
dmacompiler_path = os.path.join(REPO_ROOT, "dmacompiler")
build_sys_path = os.path.join(REPO_ROOT, "dataflow")
# Not added to sys.path here; run_qhw4 appends it on demand.
qhw4_path = os.path.join(REPO_ROOT, "aie4_models")
# Built by string formatting (forward slashes, trailing "//") rather than
# os.path.join; the exact string is what ends up on sys.path.
HW_req = f"{REPO_ROOT}/HW_requirements//"
# File name of the status CSV that the stage runners create and update.
error_status_filename = "error_status.csv"
# True when sys.platform starts with "win".
IS_WINDOWS = True if sys.platform.startswith("win") else False

# Extend the import search path with the component directories. Entries are
# appended, so earlier sys.path entries still take precedence.
sys.path.append(l1_fusion_path)
sys.path.append(l2_fusion_path)
sys.path.append(tiler_path)
sys.path.append(scheduler_path)
sys.path.append(misc_tools_path)
sys.path.append(dmacompiler_path)
sys.path.append(build_sys_path)
sys.path.append(HW_req)

import OGOAT.src.L1_fusion.main as L1_fusion
import OGOAT.src.Scheduling_Engine.main as Scheduler

# import OGOAT.src.Tiler.run_tiler as Tiler
# from dataflow.bilinearresize.bilinear_resize_build import run_bilinear_resize_op
# from dataflow.concat.concat_build import run_concat_op
# from dataflow.conv.conv_build import run_conv_op
# from dataflow.dataflow_common import (
#     build_sim_overlay,
#     clean_overlay,
#     disable_fast_pm_backend,
#     process_simulation_results,
# )
# from dataflow.depthtospace.depthtospace_build import run_depthtospace_op
# from dataflow.max_avg_pool.pooling_build import run_pooling_op
# from dataflow.mha.mha_build import run_qkt_sm_common
# from dataflow.q_dq.q_dq_build import run_q_dq_op
# from dataflow.resize.resize_build import run_resize_op
# from dataflow.slice.slice_build import run_slice_qdq_op
# from dataflow.slice_neg.slice_neg_build import run_slice_neg_op
# from dataflow.transpose.transpose_build import run_transpose_op
# from dmacompiler import BackEnd, set_dev_gen, DevGen
from OGOAT.misc_tools.flow_summary import main as layer_summary
from OGOAT.src.Scheduling_Engine.main import (
    init_args_Scheduler,
    run_Scheduler_single_op,
)
from OGOAT.src.Tiler.device import Device
from OGOAT.src.Tiler.kernel import Kernel, KernInfoFromYaml

# from OGOAT.src.Tiler.run_tiler import init_args_Tiler, run_Tiler_single_op


# def directive(ident: str, val: int, backend: BackEnd) -> str:
#     if backend == BackEnd.Adf:
#         return f'--Xpreproc="-D{ident}={val}"'
#     else:
#         return f"-D{ident}={val}"


def is_json_file(filename: str) -> bool:
    """Return True when *filename* ends with ``.json`` (case-insensitive)."""
    lowered = filename.lower()
    return lowered.endswith(".json")


def is_ir_json_file(filename: str) -> bool:
    """Return True when *filename* ends with ``ir_unique_nodes.json``.

    The comparison is done on the lower-cased name, so the suffix matches
    regardless of the case used in *filename*.
    """
    lowered = filename.lower()
    return lowered.endswith("ir_unique_nodes.json")


def is_onnx_file(filename: str) -> bool:
    """Return True when *filename* ends with ``.onnx`` (case-insensitive)."""
    lowered = filename.lower()
    return lowered.endswith(".onnx")


def extract_fields(file_name: str):
    """Load and return the JSON content of *file_name*.

    Args:
        file_name: Path of the JSON file to read.

    Returns:
        The decoded JSON value, or ``None`` when the path does not exist or
        the file content cannot be decoded as JSON.
    """
    if not os.path.exists(file_name):
        return None
    try:
        with open(file_name, "r") as f:
            return json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        # Callers in this file treat None as "does not exist or is corrupted"
        # and skip the op, so a malformed file must yield None instead of
        # aborting the whole loop with an exception.
        return None


def clear_folder(folder: str):
    """Delete every entry directly inside *folder*, keeping *folder* itself.

    Regular files and symbolic links are unlinked; directories are removed
    recursively. A failure on one entry is printed and the remaining entries
    are still processed.
    """
    for entry_name in os.listdir(folder):
        target = os.path.join(folder, entry_name)
        try:
            unlinkable = os.path.isfile(target) or os.path.islink(target)
            if unlinkable:
                os.unlink(target)
                continue
            if os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print("Failed to delete %s. Reason: %s" % (target, err))


def extract_compile_info(file_path: str, backend):
    """Read host-file and compile-flag information from an op JSON file.

    Args:
        file_path: Path of a JSON file with a ``testbench_args`` object that
            holds ``HOST_NAME`` and ``COMPILE_FLAGS``.
        backend: Passed through unchanged to ``directive`` for every flag.

    Returns:
        A ``(host_files, compile_flags)`` tuple: the ``HOST_NAME`` value and
        a list with one ``directive(key, value, backend)`` result per
        ``COMPILE_FLAGS`` entry, in the mapping's iteration order.
    """
    with open(file_path, "r") as json_file:
        testbench_args = json.load(json_file)["testbench_args"]

    host_files = testbench_args["HOST_NAME"]
    flag_table = testbench_args["COMPILE_FLAGS"]
    compile_flags = [
        directive(flag_name, flag_value, backend)
        for flag_name, flag_value in flag_table.items()
    ]
    return host_files, compile_flags


def run_build_and_sim(
    test_dir,
    combine_kernels,
    bypass_tiler,
    layers,
    build_txn,
    backend,
    args,
):
    """Build (and optionally simulate) the op directories under *test_dir*.

    For every selected op the file ``<test_dir>/<op>/<op>.json`` is loaded.
    Ops whose type is NOT in *bypass_tiler* go through the generic path:
    host sources are copied next to the JSON and ``build_sim_overlay`` is
    invoked. Ops whose type IS in *bypass_tiler* are dispatched to the
    op-specific builder found in ``op_type_to_function``. In both cases the
    ``Sim`` column of the status CSV in *test_dir* is set to ``pass`` or
    ``fail`` for the op.

    Args:
        test_dir: Directory whose sub-directories are named after ops.
        combine_kernels: String flag; ``"true"`` or ``"1"`` (case-insensitive)
            loads a ``kernel_list.json`` file from *test_dir*.
        bypass_tiler: Container of op types that use the op-specific builders.
        layers: ``"all"`` to process every sub-directory of *test_dir*,
            otherwise the name of a single op.
        build_txn: ``"all"`` or an op name; a match selects ``txn_mode`` 1.
        backend: Backend selector; its ``value`` and ``Adf`` attributes are
            read here.
        args: Options object; reads ``disable_fast_pm``, ``frontend_only``,
            ``dump_waves``, ``kernel_debug`` and ``assert_on_error``.

    Side effects:
        Changes the process working directory (``os.chdir``) and does not
        restore it.

    NOTE(review): ``directive``, ``disable_fast_pm_backend``,
    ``build_sim_overlay``, ``process_simulation_results`` and the builder
    callables used in ``op_type_to_function`` are not defined or imported in
    the visible part of this file (their imports appear only as commented-out
    lines) — confirm they are provided elsewhere before calling this.
    """
    if args.disable_fast_pm:
        disable_fast_pm_backend()
    logging.info(f"Run AIECompiler and AIESim")
    df_dir = [f.name for f in os.scandir(test_dir) if f.is_dir()]
    op_list = df_dir if layers == "all" else [layers]

    # Dispatch table for ops listed in bypass_tiler: op type -> builder.
    op_type_to_function = {
        "Conv": run_conv_op,
        "Transpose": run_transpose_op,
        "Concat": run_concat_op,
        "Resize": run_resize_op,
        "Slice": run_slice_qdq_op,
        "Slice_neg": run_slice_neg_op,
        "DepthToSpace": run_depthtospace_op,
        "MHA": run_qkt_sm_common,
        "Quant": run_q_dq_op,
        "Dequant": run_q_dq_op,
        "BilinearResize": run_bilinear_resize_op,
        "MaxPool": run_pooling_op,
    }
    if combine_kernels.lower() in ["true", "1"]:

        # The pattern contains no "**", so only entries directly inside
        # test_dir are matched. The first file ending in "kernel_list.json"
        # is used; indexing raises IndexError when there is none.
        kernel_file = [
            f
            for f in glob.glob(os.path.join(test_dir, "*.*"), recursive=True)
            if f.endswith("kernel_list.json")
        ][0]
        with open(kernel_file) as f:
            kernel_dict = json.load(f)
    else:
        kernel_dict = {}
        # Placeholder; kernel_file is not read again in this function.
        kernel_file = 0

    for op in op_list:
        op_json_dir = os.path.join(test_dir, str(op))
        op_json_file = os.path.join(op_json_dir, f"{op}.json")
        data = extract_fields(op_json_file)
        if data is None:
            sys.stderr.write(f"error: {op_json_file} does not exist or is corrupted\n")
            continue
        op_type = data["layer_info"]["orig_op_type"]
        # A Resize whose first "mode" attribute is "linear" is handled as the
        # separate BilinearResize op type.
        if (
            op_type == "Resize"
            and data["layer_info"]["attributes"]["mode"][0] == "linear"
        ):
            op_type = "BilinearResize"
        if op_type not in bypass_tiler:
            # Generic path: copy host sources into the op directory and build.
            try:
                os.chdir(
                    op_json_dir
                )  # because build system only work on current dir (subprocess)
                _name, compile_flags = extract_compile_info(op_json_file, backend)
                compile_flags.append(directive("TXN_MODE", backend.value, backend))
                # HOST_NAME is interpreted relative to REPO_ROOT: its directory
                # is the copy source and its base name is the host file.
                final_src_path = os.path.dirname(f"{REPO_ROOT}/{_name}")
                host_name = os.path.basename(_name)

                source_dir = f"{final_src_path}"
                destination_dir = op_json_dir
                # Ensure the destination directory exists
                # os.makedirs(destination_dir, exist_ok=True)
                # Copy all .cpp, .hpp, and .h files
                # (recursively; files ending in dma.hpp / graph.hpp /
                # super.cc / super.hh are excluded, and everything is copied
                # flat into destination_dir)
                for file_name in glob.glob(
                    os.path.join(source_dir, "**", "*.*"), recursive=True
                ):
                    if file_name.endswith(
                        (".cpp", ".hpp", ".h")
                    ) and not file_name.endswith(
                        ("dma.hpp", "graph.hpp", "super.cc", "super.hh")
                    ):
                        shutil.copy(file_name, destination_dir)

                if args.frontend_only is False:
                    build_sim_overlay(
                        backend,
                        host_name,
                        compile_flags,
                        args.dump_waves,
                        args.kernel_debug,
                    )
                    # Simulation results are only collected for the Adf
                    # backend.
                    if backend == backend.Adf:
                        results_list = [""]
                        simtime_list = [0.0]
                        process_simulation_results(
                            "AIESimulator.log", 0, results_list, simtime_list
                        )
                        print(
                            f"{op} result: {results_list[0]}, SIM TIME: {simtime_list[0]}"
                        )
                        # Path is relative to the current directory, which was
                        # set to op_json_dir above, so the log lands one level
                        # up from the op directory.
                        with open(
                            os.path.join("..", "sim_result.log"), "a"
                        ) as sim_log_file:
                            sim_log_file.write(
                                f"{op} result: {results_list[0]}, SIM TIME: {simtime_list[0]}\n"
                            )
                # Reached whenever no exception was raised above, including
                # the case where frontend_only skipped the build entirely.
                update_csv_row(
                    test_dir + "/" + error_status_filename,
                    condition_key="Op",
                    condition_value=str(op),
                    update_key="Sim",
                    update_value="pass",
                )
            except Exception as e:
                # With assert_on_error set, the failure aborts here via
                # AssertionError; otherwise it is logged and recorded.
                assert (
                    not args.assert_on_error
                ), f"Unable to run non-OGOAT OPs in sim mode"
                print(f"Error running build sim for {op}, {e}")
                update_csv_row(
                    test_dir + "/" + error_status_filename,
                    condition_key="Op",
                    condition_value=str(op),
                    update_key="Sim",
                    update_value="fail",
                )
        else:
            # Op-specific path: stage the build scripts, then call the
            # builder registered for this op type.
            try:
                if op_type not in ["Conv", "MHA"]:
                    os.chdir(op_json_dir)
                    op_dir_type = op_type
                    # Quant* and Dequant* share the "q_dq" script directory.
                    if op_type.startswith(("Quant", "Dequant")):
                        op_dir_type = "q_dq"

                    if os.name == "nt":  # Windows
                        copy_scripts_cmd = (
                            f"xcopy /Y {build_sys_path}\\dataflow_common.py . & "
                            f"xcopy /Y /S /I {build_sys_path}\\{op_dir_type.lower()}\\* ."
                        )
                    else:  # Linux
                        copy_scripts_cmd = (
                            f"cp {build_sys_path}/dataflow_common.py . && "
                            f"cp -r {build_sys_path}/{op_dir_type.lower()}/* ."
                        )

                    print(copy_scripts_cmd)
                    # The exit status of the shell command is not checked.
                    os.system(copy_scripts_cmd)

                elif op_type in ["MHA"]:
                    os.chdir(op_json_dir)

                    # The overlay script is taken from mha_2p1 when the
                    # layer's op_type contains "MHA_2p1", else from mini_mha.
                    if os.name == "nt":  # Windows
                        if "MHA_2p1" in data["layer_info"]["op_type"]:
                            copy_scripts_cmd = (
                                f"xcopy /Y {build_sys_path}\\dataflow_common.py . & "
                                f"xcopy /Y {build_sys_path}\\mha\\mha_2p1\\overlay.py . & "
                                f"xcopy /Y {build_sys_path}\\mha\\mha_build.py ."
                            )
                        else:
                            copy_scripts_cmd = (
                                f"xcopy /Y {build_sys_path}\\dataflow_common.py . & "
                                f"xcopy /Y {build_sys_path}\\mha\\mini_mha\\overlay.py . & "
                                f"xcopy /Y {build_sys_path}\\mha\\mha_build.py ."
                            )
                    else:  # Linux
                        if "MHA_2p1" in data["layer_info"]["op_type"]:
                            copy_scripts_cmd = (
                                f"cp {build_sys_path}/dataflow_common.py . && "
                                f"cp {build_sys_path}/mha/mha_2p1/overlay.py . && "
                                f"cp {build_sys_path}/mha/mha_build.py ."
                            )
                        else:
                            copy_scripts_cmd = (
                                f"cp {build_sys_path}/dataflow_common.py . && "
                                f"cp {build_sys_path}/mha/mini_mha/overlay.py . && "
                                f"cp {build_sys_path}/mha/mha_build.py ."
                            )

                    os.system(copy_scripts_cmd)
                else:
                    # Remaining case is Conv: no scripts are copied and the
                    # builder is pointed at the dataflow source directory
                    # instead of the op directory.
                    op_json_dir = f"{build_sys_path}/{op_type.lower()}/"
                txn_mode = 1 if (build_txn == "all" or build_txn == op) else 0
                # A KeyError for an op type missing from the table is caught
                # by the handler below like any other failure.
                op_type_to_function[op_type](
                    op_json_file, op_json_dir, txn_mode, kernel_dict, args.frontend_only
                )
                update_csv_row(
                    test_dir + "/" + error_status_filename,
                    condition_key="Op",
                    condition_value=str(op),
                    update_key="Sim",
                    update_value="pass",
                )
            except Exception as e:
                assert (
                    not args.assert_on_error
                ), f"Unable to run non-OGOAT OPs in txn mode"
                print(f"Error running build sim for {op}, {e}")
                print(traceback.format_exc())
                update_csv_row(
                    test_dir + "/" + error_status_filename,
                    condition_key="Op",
                    condition_value=str(op),
                    update_key="Sim",
                    update_value="fail",
                )


def run_cmd_line(cmd_line, curr_dir):
    """Run *cmd_line* in *curr_dir* and print what it produced.

    The command is executed without a shell and with its output captured as
    text. On a zero exit status the captured stdout is printed. A non-zero
    exit status is reported (error plus captured stderr) and is not raised
    to the caller.

    Args:
        cmd_line: Argument list handed to ``subprocess.run``.
        curr_dir: Working directory for the child process.
    """
    run_options = {
        "cwd": curr_dir,
        "check": True,
        "capture_output": True,
        "text": True,
        "shell": False,
    }
    try:
        completed = subprocess.run(cmd_line, **run_options)
    except subprocess.CalledProcessError as failure:
        print(f"Error: {failure}")
        print("Script stderr:")
        print(failure.stderr)
    else:
        print("Subprocess output:")
        print(completed.stdout)


def run_Backend(test_dir, bypass_tiler, layers, BackEnd):
    """Run each op's ``data_flow.py`` script for the selected layers.

    For every selected op directory under *test_dir* the op JSON is loaded;
    ops whose ``orig_op_type`` is in *bypass_tiler* are skipped. For the
    others the working directory is changed to the op directory and, when a
    ``data_flow.py`` exists there, it is executed with ``-b <backend name>``.

    Args:
        test_dir: Directory whose sub-directories are named after ops.
        bypass_tiler: Container of op types to skip.
        layers: ``"all"`` for every sub-directory, otherwise one op name.
        BackEnd: Backend value; the part of ``str(BackEnd)`` after the first
            ``"."`` is passed to the script.
    """
    print("Run DMA Compiler backend")
    subdir_names = [entry.name for entry in os.scandir(test_dir) if entry.is_dir()]
    selected_ops = subdir_names if layers == "all" else [layers]

    for op in selected_ops:
        print(f"Generating compiled graph for layer id: {op}")
        op_dir = os.path.join(test_dir, str(op))
        op_json_file = os.path.join(op_dir, f"{op}.json")
        layer_data = extract_fields(op_json_file)
        if layer_data is None:
            sys.stderr.write(f"error: {op_json_file} does not exist or is corrupted\n")
            continue

        if layer_data["layer_info"]["orig_op_type"] in bypass_tiler:
            continue

        # because build system only work on current dir (subprocess)
        os.chdir(op_dir)
        dataflow_script = os.path.join(op_dir, "data_flow.py")
        if not os.path.exists(dataflow_script):
            continue

        backend_name = str(BackEnd).split(".")[1]
        run_cmd_line([sys.executable, dataflow_script, "-b", backend_name], op_dir)


def run_TS_single_op(
    args,
    Tiler_arg,
    all_kern_info,
    layer_dict,
    layer_id,
    device,
    output_dir,
    bypass_scheduler,
    test_dir,
    kernel_file,
    op,
    mode,
):
    """Run the Tiler and then the Scheduler for one op and join the results.

    ``waic_config.mode`` is set to *mode* before either stage runs
    (presumably so a pool worker sees the same mode as the parent process —
    confirm against ``config_loader``).

    Returns:
        The Scheduler result followed by the Tiler result, combined with
        ``+``.
    """
    waic_config.mode = mode
    tiler_part = run_Tiler_single_op(
        Tiler_arg,
        all_kern_info,
        layer_dict,
        layer_id,
        device,
        output_dir,
    )
    scheduler_part = run_Scheduler_single_op(
        args,
        bypass_scheduler,
        test_dir,
        kernel_file,
        op,
    )
    return scheduler_part + tiler_part


def run_Tiler_and_Scheduler(args, test_dir, model_name, bypass_scheduler):
    """Run Tiler + Scheduler for every layer of the IR JSON in a process pool.

    Each entry of the IR JSON is submitted as one ``run_TS_single_op`` task.
    As tasks complete, one row per op is appended to the status CSV in
    ``args.output_dir``. Finally the per-layer tiling files reported by the
    tasks are merged into a single ``*_tilings.json`` next to the IR JSON.

    Args:
        args: Options object; reads ``j`` (worker count, falsy = let the
            executor choose) and ``output_dir``.
        test_dir: Passed to ``init_args_Tiler``; afterwards replaced by the
            value returned from ``init_args_Scheduler``.
        model_name: Passed to ``init_args_Tiler``.
        bypass_scheduler: Forwarded to every task.

    Raises:
        Exception: when ``Tiler_arg["device"]`` is not a key of
            ``OGOAT/Collaterals/device.yaml``.

    NOTE(review): ``init_args_Tiler`` is not defined or imported in the
    visible part of this file (its import is commented out) — confirm it is
    provided elsewhere.
    """

    Tiler_arg = init_args_Tiler(args, test_dir, model_name)
    # Overwrites the test_dir parameter with the Scheduler's view of it.
    test_dir, kernel_file = init_args_Scheduler(args)
    parent_dir = os.path.dirname(os.path.abspath(__file__))
    # Captured here and handed to every task, which re-applies it to
    # waic_config inside run_TS_single_op.
    mode = waic_config.mode
    with open(os.path.join(parent_dir, "OGOAT/Collaterals/device.yaml")) as f:
        devices_dict = yaml.safe_load(f)

    supported_devices = list(devices_dict.keys())
    with open(Tiler_arg["ir_json"]) as f:
        mdict = json.load(f)

    device = None
    if Tiler_arg["device"] in supported_devices:
        device = Device(Tiler_arg["device"])
    else:
        raise Exception("invalid device selection")

    output_dir = os.path.dirname(os.path.abspath(Tiler_arg["ir_json"]))
    # Drops the last 26 characters of the IR file stem before appending
    # "_tilings.json" — presumably a fixed-length IR suffix; TODO confirm the
    # expected input file naming.
    output_name = (
        os.path.splitext(os.path.basename(Tiler_arg["ir_json"]))[0][:-26]
        + "_tilings.json"
    )

    output_path = os.path.join(output_dir, output_name)

    # Maps result[3] -> result[4] for tasks that returned more than four
    # items; the values are opened as JSON file paths after the pool finishes.
    all_tilings = {}
    all_kern_info = KernInfoFromYaml()

    # Use ProcessPoolExecutor with user's -j setting or automatic CPU core detection
    num_layers = len(mdict)
    max_workers = args.j
    if args.j == 1:
        print(
            f"Running Tiler and Scheduler sequentially (1 worker) for {num_layers} layers"
        )
    elif args.j:
        print(
            f"Running Tiler and Scheduler with {args.j} workers for {num_layers} layers"
        )
    else:
        print(
            f"Running Tiler and Scheduler with automatic worker count for {num_layers} layers"
        )

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # One task per IR entry; layer_id is passed twice because it is both
        # the Tiler's layer id and the Scheduler's op argument.
        future_to_task = {
            executor.submit(
                run_TS_single_op,
                args,
                Tiler_arg.copy(),
                all_kern_info,
                layer_dict,
                layer_id,
                device,
                output_dir,
                bypass_scheduler,
                test_dir,
                kernel_file,
                layer_id,
                mode,
            ): (layer_id, layer_dict)
            for layer_id, layer_dict in mdict.items()
        }
        # Results are handled in completion order, in the parent process.
        for future in as_completed(future_to_task):
            try:
                result = future.result()
                if result is None:
                    print("Warning: Task returned None result")
                    continue

                # Result layout as used here: [0] op, [1] op type,
                # [2] scheduler status, then either
                #   [3] tiler status                       (len <= 4), or
                #   [3] tiling key, [4] tiling file path,
                #   [5] tiler status                       (len > 4).
                tiler_res = "Nan"
                if len(result) > 4:
                    all_tilings[result[3]] = result[4]
                    tiler_res = result[5]
                else:
                    tiler_res = result[3]
                append_csv_row(
                    args.output_dir + "/" + error_status_filename,
                    new_row={
                        "Op": result[0],
                        "Op type": result[1],
                        "Scheduler": result[2],
                        "Tiler": tiler_res,
                    },
                )
            except Exception as e:
                # A failed task is only printed; no CSV row is written for it.
                print(f"Exception in parallel task: {e}")
                continue

    # load tiling files from all layer together and dump it to output_path
    merged_data = {}
    # Merge Json data from all layers
    for key, file_path in all_tilings.items():
        if os.path.exists(file_path):
            with open(file_path) as file:
                merged_data[key] = json.load(file)
        else:
            print(f"File not found: {file_path}")
    with open(output_path, "w") as f:
        json.dump(merged_data, f, indent=2)


def run_Scheduler(args, bypass_scheduler):
    """Run the Scheduler for every op directory using a process pool.

    One ``run_Scheduler_single_op`` task is submitted per sub-directory of
    the Scheduler test directory. For every task that returns a result, a
    row with its op, op type and Scheduler status is appended to the status
    CSV in ``args.output_dir``. Failing tasks are printed and skipped.

    Args:
        args: Options object; reads ``j`` (worker count, falsy = let the
            executor choose) and ``output_dir``.
        bypass_scheduler: Forwarded to every task.
    """
    print("Run Scheduler Stage")
    test_dir, kernel_file = init_args_Scheduler(args)
    op_dirs = [entry.name for entry in os.scandir(test_dir) if entry.is_dir()]
    print("...............", op_dirs)

    # Use ProcessPoolExecutor with user's -j setting or automatic CPU core detection
    op_count = len(op_dirs)
    worker_count = args.j
    if worker_count == 1:
        banner = f"Running Scheduler sequentially (1 worker) for {op_count} operations"
    elif worker_count:
        banner = f"Running Scheduler with {worker_count} workers for {op_count} operations"
    else:
        banner = f"Running Scheduler with automatic worker count for {op_count} operations"
    print(banner)

    with ProcessPoolExecutor(max_workers=worker_count) as pool:
        pending = {}
        for op in op_dirs:
            submitted = pool.submit(
                run_Scheduler_single_op,
                args,
                bypass_scheduler,
                test_dir,
                kernel_file,
                op,
            )
            pending[submitted] = op

        for finished in as_completed(pending):
            try:
                outcome = finished.result()
                if outcome is None:
                    print("Warning: Scheduler task returned None result")
                    continue

                append_csv_row(
                    args.output_dir + "/" + error_status_filename,
                    new_row={
                        "Op": outcome[0],
                        "Op type": outcome[1],
                        "Scheduler": outcome[2],
                    },
                )
            except Exception as err:
                print(f"Exception in scheduler parallel task: {err}")
                continue


def run_Tiler(args, test_dir, model_name):
    """Run the Tiler stage for *model_name* and record the outcome.

    The ``Tiler`` column of the row whose ``Model name`` equals *model_name*
    in the status CSV (inside ``args.output_dir``) is set to ``pass`` or
    ``fail``.

    Args:
        args: Options object; reads ``output_dir`` and ``assert_on_error``.
        test_dir: Passed to ``init_args_Tiler``.
        model_name: Model file name used to locate the CSV row.

    Raises:
        AssertionError: when the Tiler fails and ``args.assert_on_error`` is
            truthy. The original failure is attached as the cause.
    """
    print("Run Tiler Stage")
    Tiler_arg = init_args_Tiler(args, test_dir, model_name)
    try:
        Tiler.main(Tiler_arg.copy())
        update_csv_row(
            args.output_dir + "/" + error_status_filename,
            condition_key="Model name",
            condition_value=model_name,
            update_key="Tiler",
            update_value="pass",
        )
    except Exception as exc:
        # Only ordinary errors are handled here, so KeyboardInterrupt and
        # SystemExit propagate instead of being recorded as a Tiler failure.
        if args.assert_on_error:
            # Raised explicitly (not via ``assert``) so the check still
            # fires when Python runs with -O.
            raise AssertionError("Unable to run tiler") from exc
        print("Unable to run tiler")
        print(traceback.format_exc())
        update_csv_row(
            args.output_dir + "/" + error_status_filename,
            condition_key="Model name",
            condition_value=model_name,
            update_key="Tiler",
            update_value="fail",
        )


def run_L2():
    """Placeholder for the L2 fusion stage; currently does nothing."""
    return None


def run_L1(args):
    """Run the L1 fusion stage and record the outcome in the status CSV.

    When ``args.cpp_fe`` is truthy, ``L1_fusion.run_VAIML_fusion(args)`` is
    called and the function returns without touching the CSV. Otherwise an
    option dict is built from ``vars(args)`` plus fixed defaults and passed
    to ``L1_fusion.main``. Afterwards the ``<model>_mod_ort.onnx`` file in
    ``args.output_dir`` is removed if present and the ``L1 fusion`` column of
    the model's CSV row is set to ``pass`` (or ``fail`` if that clean-up /
    CSV step raised).

    An exception raised by ``L1_fusion.main`` itself is not caught here and
    propagates to the caller.

    Args:
        args: Options object (argparse-style namespace).

    Raises:
        AssertionError: when the clean-up / CSV step fails and
            ``args.assert_on_error`` is truthy. The original failure is
            attached as the cause.
    """
    # L1 fusion
    logging.info("Run L1 Fusion Stage")
    if args.cpp_fe:
        L1_fusion.run_VAIML_fusion(args)
        return

    L1_fusion_arg = vars(args).copy()
    # Fixed settings for this flow.
    L1_fusion_arg["shape_infer_method"] = "onnx_shape_infer"
    L1_fusion_arg["input_names"] = ""
    L1_fusion_arg["input_dims"] = ""
    L1_fusion_arg["assign_new_dtypes"] = "0"
    L1_fusion_arg["low_precision_act_dtype"] = "uint16"
    L1_fusion_arg["high_precision_act_dtype"] = "uint16"
    L1_fusion_arg["low_precision_wgt_dtype"] = "uint8"
    L1_fusion_arg["high_precision_wgt_dtype"] = "uint16"
    # Values taken from args. Reading them as attributes means a missing
    # option fails here with AttributeError.
    L1_fusion_arg["skip_step"] = args.skip_step
    L1_fusion_arg["fast_pm"] = not args.disable_fast_pm
    L1_fusion_arg["shape_inference_outputs"] = args.shape_inference_outputs
    L1_fusion_arg["no_dtype_downcast"] = args.no_dtype_downcast
    L1_fusion_arg["no_dtype_freeze"] = args.no_dtype_freeze
    L1_fusion_arg["prebuilt_mladf_mha"] = args.prebuilt_mladf_mha
    L1_fusion_arg["assign_pmid_before_partition"] = args.assign_pmid_before_partition

    if args.fusion_seq is not None:
        L1_fusion_arg["fusion_seq"] = args.fusion_seq
    if args.target is not None:
        L1_fusion_arg["target"] = args.target

    L1_fusion.main(L1_fusion_arg.copy())
    try:
        mod_ort_file = os.path.join(
            args.output_dir,
            os.path.basename(args.model_path.replace(".onnx", "_mod_ort.onnx")),
        )
        if os.path.exists(mod_ort_file):
            os.remove(mod_ort_file)
        update_csv_row(
            args.output_dir + "/" + error_status_filename,
            condition_key="Model name",
            condition_value=os.path.basename(args.model_path),
            update_key="L1 fusion",
            update_value="pass",
        )
    except Exception as exc:
        # Only ordinary errors are handled here, so KeyboardInterrupt and
        # SystemExit propagate instead of being recorded as a failure.
        if args.assert_on_error:
            # Raised explicitly (not via ``assert``) so the check still
            # fires when Python runs with -O.
            raise AssertionError("Unable to find _mod_ort.onnx file") from exc
        print("System cannot find _mod_ort.onnx file")
        update_csv_row(
            args.output_dir + "/" + error_status_filename,
            condition_key="Model name",
            condition_value=os.path.basename(args.model_path),
            update_key="L1 fusion",
            update_value="fail",
        )


def run_cpp_ME(args):
    """Run the prebuilt C++ ME library on the model's unique-nodes IR JSON.

    The shared library is looked up in the ``prebuilt`` directory next to
    this file (``run_waic.dll`` on Windows, ``librun_waic.so`` otherwise).
    Its ``run_waic`` entry point is called with a C-style ``argc``/``argv``
    built from *args*, and the returned status is printed.

    Args:
        args: Options object; reads ``model_path``, ``output_dir``,
            ``prebuilt_mladf_mha``, ``target`` and ``mode``.

    Raises:
        RuntimeError: when the shared library file does not exist.
    """
    print("Skip Python ME and run CPP ME")

    # Path to CPP ME DLL
    library_name = "run_waic.dll" if IS_WINDOWS else "librun_waic.so"
    this_dir = os.path.dirname(os.path.abspath(__file__))
    cpp_me_lib_path = os.path.join(this_dir, "prebuilt", library_name)
    if not os.path.exists(cpp_me_lib_path):
        raise RuntimeError(
            f"Could not find CPP ME library with path: {cpp_me_lib_path}"
        )

    # ntpath is used so both "/" and "\\" are treated as separators.
    model_file = ntpath.split(args.model_path[:])[-1]
    model_stem = os.path.splitext(model_file)[0]
    ir_json_path = os.path.join(
        args.output_dir, model_stem + "_mod_nhwc_fused_IR_unique_nodes.json"
    )

    waic_args = ["run_waic", "--ir_json", ir_json_path, "-df", "--continue_on_error"]
    waic_args.extend(["--output_dir", args.output_dir])
    if args.prebuilt_mladf_mha:
        waic_args.append("--prebuilt_mladf_mha")
    if args.target is not None and args.target == "procyon":
        waic_args.extend(["--target", "procyon"])
    if args.mode == "dev":
        waic_args.extend(["--verify_txn_on_HW", "--verbose", "DEBUG"])

    # Convert to ctypes format
    encoded_args = [token.encode() for token in waic_args]
    argc = len(encoded_args)
    argv = (ctypes.c_char_p * argc)(*encoded_args)

    # Load the shared library and use it via the c interface
    waic_me = ctypes.CDLL(cpp_me_lib_path)
    status = waic_me.run_waic(argc, argv)
    print("CPP ME completed with status:", status)


def create_error_status_csv(output_dir, error_status_filename, model_name):
    """Create (or overwrite) the status CSV with a header and one model row.

    Args:
        output_dir: Directory in which the CSV is written.
        error_status_filename: File name of the CSV inside *output_dir*.
        model_name: Value stored in the ``Model name`` column of the single
            initial row; all other columns of that row are left empty.
    """
    column_names = [
        "Model name",
        "L1 fusion",
        "Op",
        "Op type",
        "Tiler",
        "Scheduler",
        "Sim",
        "HW",
        "Iteration TIme (us)",
        "L2 norm",
    ]
    target_path = os.path.join(output_dir, error_status_filename)
    with open(target_path, mode="w", newline="") as handle:
        status_writer = csv.DictWriter(handle, fieldnames=column_names)
        status_writer.writeheader()
        status_writer.writerow({"Model name": model_name})


def update_csv_row(file_path, condition_key, condition_value, update_key, update_value):
    """Set *update_key* to *update_value* in every matching row of a CSV file.

    The file is read completely, every row whose *condition_key* column
    equals *condition_value* is modified, and the file is rewritten.

    The header is taken from the file itself, so a header-only file (no data
    rows) is handled and simply rewritten unchanged. If *update_key* is not
    an existing column and at least one row matched, it is appended as a new
    column. A completely empty file is left untouched.

    Args:
        file_path: Path of the CSV file to update in place.
        condition_key: Column used to select rows.
        condition_value: Value the selected rows must have in that column.
        update_key: Column to modify.
        update_value: New value for that column.
    """
    # newline="" is the csv-module-recommended way to open CSV files; it
    # keeps line breaks embedded in quoted fields intact.
    with open(file_path, mode="r", newline="") as file:
        reader = csv.DictReader(file)
        fieldnames = list(reader.fieldnames or [])
        rows = list(reader)

    if not fieldnames:
        # No header and therefore no rows: nothing can match.
        return

    any_updated = False
    for row in rows:
        if row.get(condition_key) == condition_value:
            row[update_key] = update_value
            any_updated = True

    if any_updated and update_key not in fieldnames:
        fieldnames.append(update_key)

    with open(file_path, mode="w", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def append_csv_row(file_path, new_row):
    """Append *new_row* to a CSV file, keeping the file's existing header.

    The file is read completely, *new_row* is added after the existing rows,
    and the file is rewritten. The header is taken from the file itself, so
    appending to a header-only file keeps all of its columns; columns that
    *new_row* does not provide are written empty. Only when the file has no
    header at all are the keys of *new_row* used as the header.

    Args:
        file_path: Path of the CSV file to extend in place.
        new_row: Mapping of column name to value for the new row.

    Raises:
        ValueError: when *new_row* contains a key that is not a column of
            the file (raised by ``csv.DictWriter``).
    """
    # newline="" is the csv-module-recommended way to open CSV files; it
    # keeps line breaks embedded in quoted fields intact.
    with open(file_path, mode="r", newline="") as file:
        reader = csv.DictReader(file)
        fieldnames = list(reader.fieldnames or [])
        rows = list(reader)

    if not fieldnames:
        fieldnames = list(new_row.keys())
    rows.append(new_row)

    with open(file_path, mode="w", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def run_qhw4(args):
    """Run the qhw4 (AIE4) flow: ORT run, model compile, then meta runtime.

    The required project modules are imported lazily inside the function.
    All artefact paths are derived from ``args.output_dir`` and the model
    file's stem.

    Args:
        args: Options object; reads ``model_path``, ``output_dir``,
            ``data_dump`` and the ``aie4_*`` options.
    """
    from OGOAT.src.Ort.run_ort import main as run_ort_main
    from dmacompiler import BackEnd, set_dev_gen, DevGen

    sys.path.append(qhw4_path)
    from aie4_models import build_aie4
    from aie4_models.utils.run_meta_runtime import main as run_meta_runtime_main

    set_dev_gen(DevGen.Aie4)

    # Step 1: ORT run; the entry point takes a plain dict of options.
    ort_options = {
        "model_name": args.model_path,
        "ld": 0,
        "all": False,
        "idx": 0,
        "data_dump": args.data_dump,
        "edges": "all",
        "out_dir": args.output_dir,
    }
    run_ort_main(ort_options)

    # Step 2: derive every artefact path from the output dir and model stem.
    model_full_path = Path(args.model_path).expanduser().resolve()
    model_stem = os.path.splitext(ntpath.split(args.model_path[:])[-1])[0]
    fused_stem = model_stem + "_mod_nhwc_fused"

    fused_path = os.path.join(args.output_dir, fused_stem + ".onnx")
    ir_json_path = os.path.join(args.output_dir, fused_stem + "_IR.json")
    unique_json_path = os.path.join(
        args.output_dir, fused_stem + "_IR_unique_nodes.json"
    )
    tensor_map_path = os.path.join(args.output_dir, model_stem + "_tensor_map.json")
    data_dir = os.path.join(args.output_dir, "DataGen", "Consts")

    # Step 3: compile the model.
    print("Running qhw4 compile_model...")
    compile_options = {
        "read_model_data": True,
        "unique_nodes_path": unique_json_path,
        "tensor_map_json": tensor_map_path,
        "skip_operators": args.aie4_skip_op,
        "num_workers": args.aie4_num_workers,
        "include_operators": args.aie4_include_op,
        "block_id": args.aie4_layer_ids,
        "set_qdq_fp16": args.aie4_is_qdq_fp16,
    }
    build_aie4.compile_model(
        fused_path,
        ir_json_path,
        data_dir,
        model_full_path,
        "cert",
        args.output_dir,
        **compile_options,
    )
    print("qhw4 compile_model is successfully done!")

    # Step 4: run the meta runtime on the compile outputs.
    tiling_path = os.path.join(args.output_dir, fused_stem + ".onnx_alloc.json")
    cut_graphs_path = os.path.join(args.output_dir, "cut_graphs")
    elf_path = os.path.join(args.output_dir, "model_elf", "control.elf")
    run_meta_runtime_args = argparse.Namespace(
        tiling_json=tiling_path,
        IR_json=ir_json_path,
        hw_data_path=args.output_dir,
        out_dir=cut_graphs_path,
        elf=elf_path,
        data_folder="",
        tensor_map=tensor_map_path,
    )
    print("Running run_meta_runtime...")
    run_meta_runtime_main(run_meta_runtime_args)
    print("run_meta_runtime is successfully done!")


def main(args):
    """Run the WAIC flow selected by the parsed command-line arguments.

    The entry stage depends on ``args.model_path``:

    * IR JSON: the file is copied into the output directory (unless a file
      with that name is already there), then L2 fusion and the
      Tiler/Scheduler stage run.
    * Any other JSON: the file is copied to ``dummy_node/dummy_node.json``
      inside the output directory and only the Scheduler runs.
    * Anything else must be an ONNX file: L1 fusion runs first, followed by
      exactly one of the QHW4 flow, an early stop (``"2+"`` in
      ``args.skip_step``), the C++ ME flow, or L2 fusion plus
      Tiler/Scheduler. The first three return immediately.

    Afterwards the backend and build/sim steps run for ``build_txn == "all"``
    and ``run_sim == "all"`` (or the work is handed to
    ``Run_Single_build_and_sim_lsf`` when ``args.lsf`` is set), and the
    hardware test runs when ``args.test`` is ``"build_run"`` or ``"run"``.

    Args:
        args: Parsed command-line namespace. ``args.old_fusion_flow`` is set
            by this function.

    Raises:
        ValueError: If the model path is neither a JSON nor an ONNX file.
    """
    logging.info(args)
    output_dir = args.output_dir
    model_name = ntpath.basename(args.model_path)

    create_error_status_csv(output_dir, error_status_filename, model_name)

    # bypass_tiler = Tiler.TilerConfig().bypass_tiler
    bypass_scheduler = Scheduler.SchedulerConfig().bypass_scheduler
    args.old_fusion_flow = args.test == "build_run"

    if is_json_file(args.model_path):
        if is_ir_json_file(args.model_path):
            # copy file to outputs directory
            destination_path = os.path.join(output_dir, model_name)
            if not os.path.exists(destination_path):
                shutil.copy(args.model_path, destination_path)
            else:
                print(
                    f"File {model_name} already exists in {output_dir}, skipping copy."
                )
            # L2 fusion
            run_L2()
            run_Tiler_and_Scheduler(args, output_dir, model_name, bypass_scheduler)
        else:
            # NOTE: Temp method to bypass the first three stages of WAIC
            logging.info("Running from scheduler stage and beyond")
            dummy_folder = os.path.join(output_dir, "dummy_node")
            if os.path.exists(dummy_folder):
                shutil.rmtree(dummy_folder)  # delete the history nodes
            # Always (re)create the folder. It used to be created only when it
            # had already existed, so the copy below failed on a first run.
            os.makedirs(dummy_folder, exist_ok=True)
            shutil.copy(
                args.model_path,
                os.path.join(dummy_folder, "dummy_node.json"),
            )
            run_Scheduler(args, bypass_scheduler)
    else:
        logging.info("Running full E2E WAIC stack")
        if not is_onnx_file(args.model_path):
            raise ValueError(
                f"Invalid file type. Expects either onnx or json (tiler output) as input to WAIC. {args.model_path}: "
            )
        run_L1(args)
        if args.qhw4_runner:
            print("QHW4 flow...")
            run_qhw4(args)
            return
        if "2+" in args.skip_step:
            print("Termination after L1 fusion requested.")
            return
        if args.cpp_me:
            run_cpp_ME(args)
            return
        run_L2()
        run_Tiler_and_Scheduler(args, output_dir, model_name, bypass_scheduler)

    if not is_json_file(args.model_path):
        layer_summary(args, bypass_scheduler)

    # NOTE(review): `bypass_tiler` is not assigned in this function (its
    # assignment above is commented out), so the calls below depend on a
    # module-level `bypass_tiler`; if there is none they raise NameError for
    # `-txn all` / `-sim all`. Confirm which op list is intended here.
    if args.lsf:
        Run_Single_build_and_sim_lsf(args)
    else:
        if args.build_txn == "all":
            # DMA Compiler
            run_Backend(output_dir, bypass_tiler, args.build_txn, BackEnd.TxnHostPatch)
            # build txn
            run_build_and_sim(
                output_dir,
                args.combine_kernels,
                bypass_tiler,
                args.build_txn,
                args.build_txn,
                BackEnd.TxnHostPatch,
                args,
            )
        if args.run_sim == "all":
            # DMA Compiler
            run_Backend(output_dir, bypass_tiler, args.run_sim, BackEnd.Adf)
            # run sim
            run_build_and_sim(
                output_dir,
                args.combine_kernels,
                bypass_tiler,
                args.run_sim,
                args.build_txn,
                BackEnd.Adf,
                args,
            )

    print(f"Results dumped in : {output_dir}")

    if args.test in ("build_run", "run"):
        from HW_requirements.test_script import HW_test

        # The two test modes differ only in the xclbin arguments passed on.
        builds_xclbin = args.test == "build_run"
        HW_test(
            output_dir,
            HW_req,
            xclbin=1 if builds_xclbin else 0,
            xclbin_path="" if builds_xclbin else args.output_dir,
            overlay=args.overlay,
            use_bsub=not args.local,
            host=args.HW_IP,
            perf_testing=args.perf_testing,
            golden_io=args.golden_io,
            rename=args.rename,
            profile_perf=args.profile_perf,
            rel_err_pc=args.rel_err_pc,
            disable_fast_pm=args.disable_fast_pm,
        )
        op_file = glob.glob(args.output_dir + "/output*.json")
        if op_file:
            csv_path = args.output_dir + "/" + error_status_filename
            # (CSV column to update, key of the value in each result entry)
            columns = (
                ("HW", "Pass or Fail"),
                ("Iteration TIme (us)", "iterations time(us)"),
                ("L2 norm", "L2 norm"),
            )
            with open(op_file[0], "r") as file:
                for entry in json.load(file):
                    for update_key, result_key in columns:
                        update_csv_row(
                            csv_path,
                            condition_key="Op",
                            condition_value=entry["Shape"],
                            update_key=update_key,
                            update_value=entry[result_key],
                        )


def Run_Single_build_and_sim(args):
    """Run the backend and build/sim steps for an explicit layer selection.

    When ``args.build_txn`` is not ``"none"`` the TxnHostPatch backend and
    build run, optionally followed by the hardware test (``args.test`` equal
    to ``"build_run"`` or ``"run"``), whose results are written into the
    error-status CSV. When ``args.run_sim`` is not ``"none"`` the Adf backend
    and simulation run afterwards.

    Args:
        args: Parsed command-line namespace.
    """
    logging.info(args)
    bypass_tiler = [
        "Conv",
        "Concat",
        "Transpose",
        "Slice",
        "Slice_qdq",
        "Slice_neg",
        "Resize",
        "MHA",
        "DepthToSpace",
        "Quant",
        "Dequant",
        "BilinearResize",
        "MaxPool",
    ]
    if args.build_txn != "none":
        # DMA Compiler
        run_Backend(args.output_dir, bypass_tiler, args.build_txn, BackEnd.TxnHostPatch)
        # build txn
        run_build_and_sim(
            args.output_dir,
            args.combine_kernels,
            bypass_tiler,
            args.build_txn,
            args.build_txn,
            BackEnd.TxnHostPatch,
            args,
        )
        if args.test in ("build_run", "run"):
            # The two test modes differ only in the xclbin arguments (and the
            # banner, which is printed for "build_run" only).
            builds_xclbin = args.test == "build_run"
            if builds_xclbin:
                print("Run test on HW")
            from HW_requirements.test_script import HW_test

            HW_test(
                args.output_dir,
                HW_req,
                xclbin=1 if builds_xclbin else 0,
                xclbin_path="" if builds_xclbin else args.output_dir,
                overlay=args.overlay,
                use_bsub=not args.local,
                host=args.HW_IP,
                perf_testing=args.perf_testing,
                golden_io=args.golden_io,
                rename=args.rename,
                profile_perf=args.profile_perf,
                rel_err_pc=args.rel_err_pc,
                disable_fast_pm=args.disable_fast_pm,
            )
            op_file = glob.glob(args.output_dir + "/output*.json")
            if op_file:
                csv_path = args.output_dir + "/" + error_status_filename
                # (CSV column to update, key of the value in each result entry)
                columns = (
                    ("HW", "Pass or Fail"),
                    ("Iteration TIme (us)", "iterations time(us)"),
                    ("L2 norm", "L2 norm"),
                )
                with open(op_file[0], "r") as file:
                    for entry in json.load(file):
                        for update_key, result_key in columns:
                            update_csv_row(
                                csv_path,
                                condition_key="Op",
                                condition_value=entry["Shape"],
                                update_key=update_key,
                                update_value=entry[result_key],
                            )
    if args.run_sim != "none":
        # DMA Compiler
        run_Backend(args.output_dir, bypass_tiler, args.run_sim, BackEnd.Adf)
        # run sim
        run_build_and_sim(
            args.output_dir,
            args.combine_kernels,
            bypass_tiler,
            args.run_sim,
            args.build_txn,
            BackEnd.Adf,
            args,
        )


def Run_Single_build_and_sim_lsf(args):
    """Submit one ``bsub`` job per selected layer.

    The selection comes from ``args.run_sim`` (preferred, submitted with
    ``-sim``) or, if that is ``"none"``, from ``args.build_txn`` (``-txn``).
    It may be:

    * a layer number or a comma-separated list of numbers (empty items from
      stray commas are ignored);
    * ``"all"``: every key of the IR JSON that has a matching entry in the
      output directory;
    * anything else: treated as an op type; keys whose ``op_type`` contains
      it and that have a matching entry in the output directory are selected.

    Args:
        args: Parsed command-line namespace.

    Raises:
        RuntimeError: If neither ``-sim`` nor ``-txn`` was given, or if the
            output directory does not contain exactly one IR JSON file.
    """
    if args.run_sim != "none":
        mode, arg_val = "-sim", args.run_sim
    elif args.build_txn != "none":
        mode, arg_val = "-txn", args.build_txn
    else:
        raise RuntimeError("At least -sim or -txn should be provided.")

    if arg_val.replace(",", "").isnumeric():
        # Covers both "7" and "1,2,3". Empty items ("1,,2", "1,") are dropped
        # so no job is submitted with a missing layer id.
        layer_id_list = [token for token in arg_val.split(",") if token]
    else:
        ir_json = [
            name for name in os.listdir(args.output_dir) if is_ir_json_file(name)
        ]
        if len(ir_json) != 1:
            # Previously reported as "Multiple IR json file ..." even when
            # none was found.
            raise RuntimeError(
                f"Expected exactly one IR json file in the output directory "
                f"{args.output_dir}, found {len(ir_json)}"
            )
        with open(os.path.join(args.output_dir, ir_json[0]), "r") as f:
            ir_data = json.load(f)

        select_all = arg_val == "all"
        layer_id_list = [
            key
            for key, val in ir_data.items()
            if (select_all or arg_val in val["op_type"])
            and os.path.exists(os.path.join(args.output_dir, key))
        ]

    # NOTE(review): the values below are interpolated unquoted into a shell
    # command line, so an output path containing spaces would be split.
    for layer in layer_id_list:
        os.system(
            f'bsub -R "select[osdistro=rhel && (osver=ws8)]" -R "rusage[mem=16384]" python WAIC.py -o {args.overlay} -mp dummy.onnx -output {args.output_dir} --combine_kernels {args.combine_kernels} {mode} {layer}'
        )


class BuildRunManager:
    """
    Build/run workflow manager for WAIC.

    Splits the WAIC flow into a build stage and a run stage, dispatched by
    `execute()` on the value of the `-t` option. This flow is triggered when
    the `-mode release` option in WAIC.py is passed.

    Build stage (`-t build`):
        1. L1 Fusion
        2. L2 Fusion
        3. Tiler
        4. Scheduler
        5. Backend
        6. Build & Simulation
        7. Xclbin generation

    Run stage (`-t hw_run`):
        1. No Copy: runs on a Windows machine where all required artifacts are
           already present in the output directory; the test is executed in
           place without copying files.
        2. With Copy (requires the `HW_IP` parameter passed to WAIC.py): the
           legacy flow that copies the necessary files to the target IP and
           then executes the hardware run remotely.
    """

    def __init__(self, input_args: argparse.Namespace):
        """Store the parsed arguments, derive paths and create the status CSV."""
        self.args = input_args
        # TODO: Have universal logger with file dump option available

        # Locations inside the repository, relative to this file.
        self.project_dir = Path(__file__).resolve().parent
        ogoat_src = self.project_dir / "OGOAT" / "src"
        hw_requirements = self.project_dir / "HW_requirements"
        self.ogoat_src_dir = ogoat_src
        self.l1_fusion_dir = ogoat_src / "L1_fusion"
        self.l2_fusion_dir = ogoat_src / "L2_fusion"
        self.tiler_dir = ogoat_src / "Tiler"
        self.scheduler_dir = ogoat_src / "Scheduling_Engine"
        self.dolphin_script = hw_requirements / "dolphin_test_ver4.py"
        self.xrt_exe = hw_requirements / "xrt_flow_test_patch_datatype_debug.exe"
        self.output_json_file = Path(input_args.output_dir) / "output.json"

        # Frequently used argument values.
        self.output_dir = input_args.output_dir
        self.model_path = input_args.model_path
        self.model_name = ntpath.basename(self.model_path)
        self.overlay = input_args.overlay

        self.scheduler = Scheduler.SchedulerConfig()
        # self.tiler = Tiler.TilerConfig()

        create_error_status_csv(self.output_dir, error_status_filename, self.model_name)

    def execute(self) -> int:
        """Run the stage named by ``args.test`` and return its exit code.

        Returns 2 for an unknown stage or when the stage raises; the
        exception is printed together with its traceback.
        """
        # TODO: Enhance the args validation
        try:
            if self.args.test == "build":
                return self._build()
            if self.args.test == "hw_run":
                return self._run()
            print("Unknown flow fix to handle")
            return 2
        except Exception as e:
            print(f"An error occurred in test execution {e}")
            traceback.print_exc()
            return 2

    def _evaluate_results(self) -> bool:
        """Return True when no entry of the output JSON reports a non-pass.

        Entries without a "Pass or Fail" key are ignored. A missing file or
        any other error while reading/evaluating yields False.
        """
        verdict_key = "Pass or Fail"
        try:
            with open(self.output_json_file, "r") as file:
                entries = json.load(file)
            verdicts = [item[verdict_key] for item in entries if verdict_key in item]
            return all(verdict.lower() == "pass" for verdict in verdicts)
        except FileNotFoundError:
            print("Output file not generated")
            return False
        except Exception as e:
            print(f"Unexpected error {e}")
            return False

    def _build(self) -> int:
        """Run the build stage and return the xclbin exit code (currently 2)."""
        # Step 1 L1 and L2 Fusion (L1 is skipped when starting from an IR JSON)
        starts_from_ir = is_ir_json_file(self.model_path)
        if not starts_from_ir:
            self.args.build_bins_flow = False
            run_L1(self.args)
        run_L2()

        # Step 2 Tiler and Scheduler
        run_Tiler_and_Scheduler(
            self.args, self.output_dir, self.model_name, self.scheduler.bypass_scheduler
        )

        # NOTE(review): `self.tiler` is not assigned in __init__ (the line is
        # commented out there), so the access below looks like it raises
        # AttributeError unless the attribute is set elsewhere - confirm.
        txn_selection = self.args.build_txn
        # Step 3 DMA Compiler
        run_Backend(
            self.output_dir,
            self.tiler.bypass_tiler,
            txn_selection,
            BackEnd.TxnHostPatch,
        )
        # Step 4 Build TXN
        run_build_and_sim(
            self.output_dir,
            self.args.combine_kernels,
            self.tiler.bypass_tiler,
            txn_selection,
            txn_selection,
            BackEnd.TxnHostPatch,
            self.args,
        )

        # Step 5 XCL bin generation
        # This entire flow is deprecated with new prebuild check for CI, so
        # the gen_xclbin(overlay, output_dir, local, is_ir_json) call is
        # disabled and a fixed exit code is returned instead.
        return 2

    def _run(self):
        """Run the hardware test and return an exit code for the local flow.

        Local flow: returns the dolphin script's return code when it is
        non-zero, otherwise 0 if `_evaluate_results()` passes and 1 if not.
        The copy flow returns None.
        """
        # Flow with traditional copy
        # NOTE(review): `self.args` is the parsed namespace, so comparing it
        # with the string "HW_IP" looks like it is never true, which would
        # make this branch unreachable. `--HW_IP` has a non-empty default, so
        # the intended condition is unclear - confirm before changing.
        if self.args == "HW_IP":
            from HW_requirements.test_script import HW_test

            HW_test(
                self.output_dir,
                HW_req,
                xclbin=0,
                xclbin_path=self.args.output_dir,
                overlay=self.args.overlay,
                use_bsub=not self.args.local,
                host=self.args.HW_IP,
                perf_testing=self.args.perf_testing,
                golden_io=self.args.golden_io,
                rename=self.args.rename,
                profile_perf=self.args.profile_perf,
                rel_err_pc=self.args.rel_err_pc,
                disable_fast_pm=self.args.disable_fast_pm,
            )
            return None

        # Local flow: stage the test script and executable next to the
        # artifacts, then run the script through PowerShell.
        log_file_path = os.path.join(self.output_dir, "hw_run.log")
        for artifact in (self.dolphin_script, self.xrt_exe):
            shutil.copy(artifact, self.output_dir)
        collect_pm_ids(self.output_dir, self.output_dir)

        optional_flags = (
            (" --perf_testing", self.args.perf_testing),
            (" --profile_perf", self.args.profile_perf),
            (" --rel_err_pc", self.args.rel_err_pc),
        )
        command = "python dolphin_test_ver4.py" + "".join(
            flag for flag, enabled in optional_flags if enabled
        )

        with open(log_file_path, "w") as log_file:
            result = subprocess.run(
                ["powershell", "-Command", command],
                cwd=self.output_dir,
                stdout=log_file,
                stderr=subprocess.STDOUT,
                text=True,
            )
        # TODO: Redirect via logger
        with open(log_file_path, "r") as log_file:
            log_contents = log_file.read()
            print("Output:")
            print(log_contents)

        if result.returncode != 0:
            print(f"Error: Dolphin run failed with return code {result.returncode}")
            return result.returncode

        return 0 if self._evaluate_results() else 1


def _str2bool(x: str) -> bool:
    """Parse a CLI boolean; mirrors the aie4 helper so both CLIs accept the same values.

    Accepts (case-insensitively) true/1/yes/y/t and false/0/no/n/f.

    Raises:
        argparse.ArgumentTypeError: For any other value.
    """
    spellings = dict.fromkeys(("true", "1", "yes", "y", "t"), True)
    spellings.update(dict.fromkeys(("false", "0", "no", "n", "f"), False))

    parsed = spellings.get(str(x).lower())
    if parsed is None:
        raise argparse.ArgumentTypeError("Expected true/false, 1/0 or yes/no")
    return parsed


def waic_main_func():
    parser = argparse.ArgumentParser(
        description="Windows AI Compiler (WAIC) - build and run MLOPs on AIESim",
        usage='use "%(prog)s --help" for more info',
        formatter_class=argparse.RawTextHelpFormatter,
    )

    default_output_dir = os.path.join(os.path.dirname(__file__), "WAIC_Outputs")
    # Required args
    parser.add_argument(
        "-mp",
        "--model_path",
        required=True,
        help="Path to onnx model (or JSON) and output destination",
    )
    parser.add_argument(
        "-to_i8",
        "--int4_to_int8",
        help="path to additional model data file for large models. Optional Field. Default value = 0",
        default="0",
    )
    parser.add_argument(
        "-o",
        "--overlay",
        default="8x4",
        choices=["4x4", "8x4"],
        help="Name of overlay to run",
    )
    # [Optional] => HW Args
    parser.add_argument(
        "-clean",
        "--delete_dir",
        help="delete output directory if it already exists",
        action="store_true",
    )
    parser.add_argument(
        "-ck", "--combine_kernels", help="Use combine kernel file", default="False"
    )
    parser.add_argument(
        "-txn",
        "--build_txn",
        default="none",
        help="Generate bin for each OP, 'all', 'none', '<layer number>'",
    )
    parser.add_argument(
        "-sim",
        "--run_sim",
        default="none",
        help="run sim for each Op, 'all', 'none', '<layer number>'",
    )
    parser.add_argument(
        "-t",
        "--test",
        choices=["build_run", "run", "none", "build", "hw_run"],
        help="Run the test flow on HW if new binaries are generated",
    )
    parser.add_argument(
        "-d",
        "--device",
        help="Name of device to run; e.g., strix or phoenix etc. Default = 'strix'",
        choices=["strix", "med", "swv"],
        default="strix",
    )
    parser.add_argument(
        "-skip",
        "--skip_step",
        nargs="+",
        default=[],
        action="extend",
        help="Skip WAIC step, none, 1.0: skip L1 shape inference, 2+: skip everything after L1 fusion",
    )
    parser.add_argument(
        "--cpp_me",
        help="call CPP ME instead of Python ME",
        action="store_true",
    )
    parser.add_argument(
        "--fusion_seq",
        help="Force a specific fusion sequence file to be used instead of the default one",
    )
    parser.add_argument(
        "--target",
        help="Use a specific fusion seq file according to target",
    )
    parser.add_argument(
        "-O",
        "--optimization_level",
        help="Set an optimization level",
        choices=["0", "1", "2", "3"],
        default="1",
    )
    parser.add_argument("-lsf", "--lsf", help="use lsf", action="store_true")
    parser.add_argument(
        "-HW_IP", "--HW_IP", help="Set HW IP address", default="10.228.45.202"
    )
    parser.add_argument(
        "--perf_testing", action="store_true", help="Enable performance testing mode."
    )
    parser.add_argument(
        "-golden_io",
        "--golden_io",
        nargs="*",
        help=(
            "Enable golden IO testing mode. "
            "Specify subfolders (e.g., 'conv', 'psmu', 'mha'). "
            "Include 'update' to replace golden files using DES -> SRC."
            "If no subfolders are given, all available subfolders will be used."
        ),
    )

    parser.add_argument(
        "-rename",
        "--rename",
        help="Rename layers folder in WAIC_Outputs.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "-profile_perf",
        "--profile_perf",
        help="xrt recort_timer profiling",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "-rel_err_pc",
        "--rel_err_pc",
        help="Use average relative error for HW test",
        action="store_true",
        default=False,
    )
    # [Optional] => DBG log related args
    parser.add_argument(
        "-bfm",
        "--tiler_bfm",
        help="Use tiler bfm instead of actual tiler",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "-bfm_mode",
        "--tiler_bfm_mode",
        choices=["M4K1N8", "M1K1N32"],
        help="Tensor Split",
        default="M4K1N8",
    )
    parser.add_argument(
        "-vcd",
        "--dump_waves",
        help="Dump vcd trace from AIESIM run",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "-kdbg",
        "--kernel_debug",
        help="Enable kernel debug print and large program memory",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--frontend_only",
        help="Runs Front end only (till DMA Compiler stage)",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--cpp_fe",
        help="Path to the shared library interface to compile with flexml",
    )
    parser.add_argument(
        "-dbg",
        "--debug",
        help="Dump dbg log to 'dbg_log.txt'",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "-df", "--debug_file_name", help="Debug log file name", default="dbg_log.txt"
    )
    parser.add_argument(
        "-v",
        "--verbose",
        choices=["debug", "info", "error"],
        help="Verbosity for debug logs",
        default="debug",
    )
    parser.add_argument(
        "--call_DMAC",
        help="Call DMAC directly for OGOAT OPs instead of dumping .py files",
        action="store_true",
        default=False,
    )
    # [Optional] => Profile and call stack graph generation
    parser.add_argument(
        "-p",
        "--profile",
        help="Profile auto scheduler",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "-pf",
        "--profile_graph_name",
        help="Profile graph file name",
        default="dbg_call_graph.png",
    )
    parser.add_argument(
        "-output", "--output_dir", help="output directory", default=default_output_dir
    )
    parser.add_argument(
        "--local",
        action="store_true",
        help="Don't use bsub to build for HW on LSF cluster, build on local machine",
    )
    parser.add_argument(
        "-shape_params",
        "--in_shape_params",
        required=False,
        type=str,
        help="Dynamic shape parameters for inputs as a JSON string. Optional Field. Default value = '{}'",
        default="{}",
    )
    parser.add_argument(
        "--disable_fast_pm",
        action="store_true",
        help="To disable fast pm load, Default = False",
        default=False,
    )
    parser.add_argument(
        "--fixed_input_values",
        required=False,
        type=str,
        help="Fixed input values to the neural network. JSON syntax: input name -> value. Optional Field. Default value = '{}'",
        default="{}",
    )
    parser.add_argument(
        "--default_shape_params_values",
        required=False,
        type=str,
        help=f"YML file specifying default shape parameters and graph input values. Default: {L1_fusion.default_shape_params_values}",
        default=L1_fusion.default_shape_params_values,
    )
    parser.add_argument(
        "--assert_on_error",
        help="Error out if there's an assertion",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "-j",
        type=int,
        help="Number of workers for parallel Tiler, Scheduler execution (default: auto-detect CPU cores, use -j 1 for sequential execution)",
        default=None,  # Let ProcessPoolExecutor auto-detect by default
    )
    parser.add_argument(
        "--qdq_optimization",
        type=int,
        choices=[0, 1],
        default=0,
        help="Enable QDQ optimization at end of L1 fusion. Default is 0 (disabled).",
    )
    parser.add_argument(
        "--qdq_int16_cleanup",
        type=int,
        choices=[0, 1],
        default=1,
        help="Enable int16 QDQ cleanup pass for mixed precision models. Removes int16 QDQ pairs adjacent to int8 QDQ. Default is 1 (enabled).",
    )
    parser.add_argument(
        "-m",
        "--mode",
        choices=["dev", "release"],
        default="release",
        help="dev mode will generate all 6 waic bins [ ctrl.bin, ifm.bin, ofm.bin, param.bin, txn.bin, wgt.bin ] and release mode will generate 3 bins [ctrl.bin, param.bin, txn.bin ]",
    )
    parser.add_argument(
        "-infer_batch",
        "--shape_inference_outputs",
        default=3000,
        type=int,
        help="max batch size during onnx runtime inferencing",
    )
    parser.add_argument(
        "-no_dtype_downcast",
        "--no_dtype_downcast",
        action="store_true",
        help="Disable dtype downcasting during L1 fusion",
    )
    parser.add_argument(
        "-no_dtype_freeze",
        "--no_dtype_freeze",
        action="store_true",
        help="Disable dtype freeze during L1 fusion",
    )

    parser.add_argument(
        "-prebuilt_mladf_mha",
        "--prebuilt_mladf_mha",
        action="store_true",
        help="Copy the prebuilt bins from artifacts dir for MHA_3p0_1col_Transpose_qdq_uint16xuint16xuint16 node",
    )
    parser.add_argument(
        "-pre_pm_assign",
        "--assign_pmid_before_partition",
        action="store_true",
        default=False,
        help="Enable static PM bin allocation before graph partitioning in L1 fusion. Default is False (PM bin allocation done per-subgraph in WAIC_runtime.py)",
    )
    aie4_group = parser.add_argument_group("aie4_options")
    aie4_group.add_argument(
        "--qhw4_runner",
        help="Enable qhw4 flow",
        action="store_true",
        default=False,
    )
    aie4_group.add_argument(
        "-dmp",
        "--data_dump",
        help="Data dump option for run_ort. Default value = wgt",
        default="wgt",
    )
    aie4_group.add_argument(
        "-workers",
        "--aie4_num_workers",
        type=int,
        default=None,
        help="AIE4 number of workers for parallel subgraph compilation.",
    )
    aie4_group.add_argument(
        "-include_op",
        "--aie4_include_op",
        type=str,
        default=None,
        help="AIE4 Comma Separated List of Operators that should be included while compiling.",
    )
    aie4_group.add_argument(
        "-skip_op",
        "--aie4_skip_op",
        type=str,
        default=None,
        help="AIE4 Comma Separated List of Operators that should be skipped while compiling.",
    )
    aie4_group.add_argument(
        "--aie4_layer_ids",
        type=str,
        default=None,
        help="AIE4 Key of block in JSON to compile. Compiles all blocks if not set.",
    )
    aie4_group.add_argument(
        "-fp16",
        "--aie4_is_qdq_fp16",
        type=_str2bool,
        default=True,
        help="AIE4 QDQ datatpye is FP16 or BF16? (Default -> True (QDQ DType is BF16))",
    )

    args = parser.parse_args()

    waic_config.mode = args.mode
    if args.dump_waves:
        assert args.run_sim, f"dump_waves should only be enabled with sim"
    if args.frontend_only:
        assert (
            args.run_sim == "all"
        ), f"frontend_only expected to enabled only for sim mode. run_sim: {args.run_sim}"

    # Since some stage chdir into directories we need the root output dir path to not be relative
    args.output_dir = os.path.abspath(args.output_dir)

    if os.path.exists(args.output_dir):
        if args.delete_dir:
            print(f"Output dir already exist. Deleting it!! {args.output_dir}")
            clear_folder(args.output_dir)
        else:
            print(f"Output dir already exist. It is not cleaned!! {args.output_dir}. ")
    else:
        os.makedirs(args.output_dir)

    if args.debug:
        # Run with Debug log enabled
        debug_file_path = os.path.join(args.output_dir, args.debug_file_name)

        class DEBUG_VERBOSE(Enum):
            """Verbosity names mapped onto the matching ``logging`` level constants."""

            debug = logging.DEBUG
            info = logging.INFO
            error = logging.ERROR

            @classmethod
            def str2enum(cls, string_val):
                """Return the member whose name equals ``string_val``.

                Raises:
                    ValueError: if no member of this enum has that name.
                """
                # Single lookup in the name -> member mapping; members are
                # never None, so None reliably signals "name not found".
                member = cls.__members__.get(string_val)
                if member is None:
                    raise ValueError(
                        "String not found in str2enum. Str: " + str(string_val)
                    )
                return member

        verbose = DEBUG_VERBOSE.str2enum(args.verbose).value
        print(f"Saving debug log as : {debug_file_path}")

        logging.basicConfig(
            filename=debug_file_path,
            filemode="w",
            format="[%(asctime)s,%(msecs)d] [%(levelname)s]: %(message)s",
            datefmt="%M:%H:%S",
            level=verbose,
        )
    elif 0:  # direct all logging to stdout
        root = logging.getLogger()
        root.setLevel(logging.DEBUG)

        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        handler.setFormatter(formatter)
        root.addHandler(handler)

    if (args.build_txn != "none" and args.build_txn != "all") or (
        args.run_sim != "none" and args.run_sim != "all"
    ):
        if args.lsf:
            Run_Single_build_and_sim_lsf(args)
        else:
            Run_Single_build_and_sim(args)

    elif args.profile:
        ## Profiling
        from pycallgraph2 import Config, GlobbingFilter, PyCallGraph
        from pycallgraph2.output import GraphvizOutput

        print(
            "Saving profile graph as as :", os.getcwd() + "/" + args.profile_graph_name
        )
        config = Config()
        config.trace_filter = GlobbingFilter(
            exclude=[
                "pycallgraph.*",
                "custom_dict.*",
            ]
        )
        graphviz = GraphvizOutput(output_file=args.profile_graph_name)
        with PyCallGraph(output=graphviz, config=config):
            main(args)
    else:
        # Run without profiling
        if args.test == "hw_run":
            build_run_manager = BuildRunManager(args)
            exit_code = build_run_manager.execute()
            sys.exit(exit_code)
        else:
            main(args)


# Script entry point: invoke the CLI driver only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    waic_main_func()
