"""
This script builds an operator:
"""

import importlib
import json
import multiprocessing as mp
import os
import shutil
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from glob import glob
from typing import Any, Dict, List, Union
from tabulate import tabulate

# pylint: disable-next=W0611
import buildscripts  # noqa: F401
from buildscripts.common import OperatorsRegistry
import graph.L2L3_allocator as alloc
from cert_sim.build_aiebu_json import CfgItem
import graph_partitioner.onnx_graph_partitioner as ogp
from graph.allocation_types import SubgraphAllocationMode
from dmacompiler import BackEnd
from utils.build_cli import make_build_parser, parse_args_with_cfg
from utils.fuse_alloc_ir_json import update_alloc_json
from utils.get_subgraph_bo_size import build_subgraph_report_async
from utils.utils_common import log, ReadBins
from utils.build_utils import (
    parse_json_to_dict_with_op, merge_dicts,
    copy_pdi_for_win, unified_di_sim,
    capture_prints_to_file, out_dir_name_from_dict,
    save_intermediate_bin, generate_ctrl_elf,
    build_data_bins, compile_backend,
    capture_logging, prep_read_bins,
    create_output_folder, clean_overlay,
    create_fused_package, values_for_keys,
    __copy_read_bins, gen_model_elf,
    get_os_core_count, default_L3_mappings,
    mark_block_compilable, is_chained_di,
    validate_paths, generate_subgraph_nodelists,
    set_datatype, get_skip_ops, map_subgraphs_to_json_keys,
)
from utils.map_unique_layer_ids import process_unique_alloc

# Root directory taken from the AIE4_ROOT_DIR environment variable. Operator and
# output folders are resolved against it, and compile_operator() changes the
# working directory back to it when it finishes.
# NOTE(review): os.environ.get() yields None when the variable is unset, in
# which case the os.path.join() below raises TypeError at import time.
CURRDIR = os.environ.get("AIE4_ROOT_DIR")
# Directory joined with an operator's testbench name for build_data_bins().
HOSTDIR = os.path.join(CURRDIR, 'host')


def get_supported_ops(raw_json: dict[str, Any], selected_ids: List[str], skip_operators: str = None) -> List[str]:
    '''Filter out block IDs that cannot be compiled.

    A block ID is dropped when:
      * it is not a key of ``raw_json``;
      * its parsed attributes do not carry a truthy ``is_compilable`` flag;
      * its operator is empty, is not registered in ``OperatorsRegistry``, or
        is listed in ``skip_operators``.

    Args:
        raw_json: Graph JSON mapping block IDs to block descriptions.
        selected_ids: Candidate block IDs in the order they should be compiled.
        skip_operators: Optional comma-separated operator names that must be
            treated as unsupported.

    Returns:
        The surviving block IDs, in their original order.
    '''
    supported_ops = set(OperatorsRegistry.get_operators())
    if skip_operators:
        supported_ops = supported_ops - set(skip_operators.split(','))
    log(f"Supported Operator: {supported_ops}")

    # Kept as a list so the summary log below preserves encounter order.
    removed_ids = []
    for bid in selected_ids:
        block = raw_json.get(bid)
        if block is None:
            log(f"[WARN] Block ID '{bid}' not found in JSON. Skipping.")
            removed_ids.append(bid)
            continue

        attrs, operator = parse_json_to_dict_with_op(block)

        if not bool(attrs.get("is_compilable", False)):
            log(f"[WARN] Skipping block {bid}: operator not offloaded '{operator}'.")
            removed_ids.append(bid)
            continue

        if not operator or operator not in supported_ops:
            log(f"[WARN] Skipping block {bid}: unsupported or missing operator '{operator}'.")
            removed_ids.append(bid)

    if removed_ids:
        log(f"[INFO] Following Layers are unsupported operator: {removed_ids}")
        # Set lookup keeps the filtering linear in the number of blocks.
        removed_lookup = set(removed_ids)
        selected_ids = [bid for bid in selected_ids if bid not in removed_lookup]
        log(
            f"[INFO] After filtering unsupported blocks, compiling: {selected_ids}")

    return selected_ids


def process_layer_id_str(block_id: Union[str, List[int]], raw_json: dict[str, Any], skip_operators: str = None) -> List[str]:
    '''Normalise a layer selection into a filtered list of block-ID strings.

    Accepted forms of ``block_id``:
      * a list of ints, e.g. ``[1, 2, 3]``;
      * a comma-separated string, e.g. ``"1,2,7"``;
      * an inclusive range, e.g. ``"3-6"``;
      * a mix of both, e.g. ``"1-3,7,9-10"`` (each comma-separated token that
        contains ``-`` is expanded as an inclusive integer range).

    Args:
        block_id: Layer selection in one of the forms above.
        raw_json: Graph JSON mapping block IDs to block descriptions.
        skip_operators: Optional comma-separated operator names to exclude.

    Returns:
        The selected block IDs (as strings) that survive ``get_supported_ops``.

    Raises:
        ValueError: If ``block_id`` is neither a string nor a list of ints, or
            if a range token is not of the form ``<int>-<int>``.
    '''
    # Case 1: list of ints
    if isinstance(block_id, list) and all(isinstance(bid, int) for bid in block_id):
        selected_ids = [str(bid) for bid in block_id]

    # Case 2: string of comma-separated IDs and/or inclusive ranges
    elif isinstance(block_id, str):
        selected_ids = []
        for token in block_id.split(","):
            token = token.strip()
            if not token:
                continue
            if "-" in token:  # inclusive range "start-end"
                start, end = map(int, token.split("-"))
                selected_ids.extend(str(i) for i in range(start, end + 1))
            else:
                selected_ids.append(token)

    else:
        raise ValueError("block_id must be a string or a list of ints")
    log(f"[INFO] Compiling selected blocks: {selected_ids}")

    return get_supported_ops(raw_json, selected_ids, skip_operators)


def get_combined_kernel_from_ops(operators: list[str]):
    """Aggregate kernel includes and kernel names over a list of operators.

    Returns a tuple ``(includes, names)``: ``includes`` starts with
    ``"super.hh"`` followed by each operator's include headers in first-seen
    order without duplicates; ``names`` is the result of ``merge_dicts`` over
    the operators' ``kernel_names`` entries.
    """
    headers: list[str] = ["super.hh"]
    names_per_op = []

    for op_name in operators:
        cfg = OperatorsRegistry.get_operator(op_name)
        names_per_op.append(cfg["kernel_names"])
        for header in cfg["kernel_includes"]:
            if header in headers:
                continue
            headers.append(header)

    # Collapse the per-operator kernel_names dictionaries into one
    merged_names = merge_dicts(names_per_op)
    return headers, merged_names


def get_combined_kernel_list(buildscript: str):
    """Return the combined kernel includes and names of every operator
    registered for the given build script."""
    return get_combined_kernel_from_ops(
        OperatorsRegistry.get_operators_by_build_script(buildscript)
    )


def generate_combined_pdi(raw_json: dict[str, Any],
                          is_gen_pdi: bool,
                          outfolder: str,
                          device: str,
                          skip_operators: str = None) -> tuple[List[str], Dict[str, Any], str]:
    '''Collect kernel metadata for all supported blocks and resolve the PDI directory.

    Args:
        raw_json: Graph JSON mapping block IDs to block descriptions.
        is_gen_pdi: When true (and not running on Windows) the "pdi" operator
            is compiled; otherwise ``copy_pdi_for_win`` supplies the PDI.
        outfolder: Output folder handed to ``compile_operator``.
        device: Device name handed to ``compile_operator``.
        skip_operators: Optional comma-separated operator names to exclude.

    Returns:
        ``(combined_kernel_includes, combined_kernel_names, pdi_dir)``.
    '''
    combined_kernel_names_list = []
    combined_kernel_includes: list[str] = ["super.hh"]
    selected_ids = get_supported_ops(
        raw_json, list(raw_json.keys()), skip_operators)
    for bid in selected_ids:
        attrs, operator = parse_json_to_dict_with_op(raw_json.get(bid))

        combined_kernel_names_list.append(OperatorsRegistry.get_kernel_names(operator, attrs))
        # Keep includes unique while preserving first-seen order
        for inc in OperatorsRegistry.get_kernel_includes(operator, attrs):
            if inc not in combined_kernel_includes:
                combined_kernel_includes.append(inc)

    # Merge kernel_names dictionaries from all relevant ops
    combined_kernel_names = merge_dicts(combined_kernel_names_list)

    # PDI pre-compile
    if os.name == "nt" or not is_gen_pdi:
        pdi_dir = copy_pdi_for_win(combined_kernel_names, combined_kernel_includes)
    else:
        compile_operator(
            "pdi", {"input": [0]}, args_target="sim",
            out_folder=outfolder,
            combined_kernel_names=combined_kernel_names,
            combined_kernel_includes=combined_kernel_includes,
            device=device, dump_vcd=False,
        )
        # Location of the compiled "pdi" operator's output inside outfolder
        pdi_dir = os.path.join(outfolder, "op_pdi_shape_input_0", "Work", "ps", "asm")

    return combined_kernel_includes, combined_kernel_names, pdi_dir


@dataclass(frozen=True)
class SubgraphPdiInfo:
    """
    Per-subgraph PDI mapping entry.

    This dataclass struct is to store:
    - pdi_path: filesystem path to the selected prebuilt/generated PDI
    - combined_kernel_includes: include headers needed by the combined kernel
    - combined_kernel_names: merged kernel-names dictionary of the combined
      kernel (the value produced by ``merge_dicts`` in
      ``generate_combined_pdi``)
    """
    pdi_path: str
    combined_kernel_includes: List[str]
    combined_kernel_names: Dict[str, Any]


@dataclass
class SubgraphTask:
    """Configuration to compile a single subgraph.

    Instances are created in `_compile_subgraphs_parallel` and consumed by
    `_compile_json_task`, which forwards the fields to `compile_json`.
    """
    # Subgraph index; also used to build the "subgraph_<idx>" suffix
    idx: int
    # Block selection of this subgraph, passed to compile_json as block_id
    subgraph: Any
    # Allocator JSON content, passed to compile_json in place of a file path
    alloc_json: Dict[str, Any]
    # Output folder forwarded to compile_json
    outfolder: str
    # Device name forwarded to compile_json
    device: str
    # Forwarded unchanged to compile_json
    read_model_data: bool
    # Forwarded unchanged to compile_json
    model_data_path: str
    # Forwarded to compile_json as build_pdi
    build_pdi: bool
    # Kernel/PDI info forwarded to compile_json as combined_kernel_mapping
    combined_kernel_mapping: SubgraphPdiInfo

def _compile_json_task(task: SubgraphTask) -> tuple[int, str, str]:
    """
    Process-pool worker: compile one subgraph through compile_json().

    Runs in its own process so that the os.chdir() calls made during
    compilation do not leak into sibling tasks. Returns
    ``(task index, "subgraph_<idx>", fused package directory)``.
    """
    suffix = f"subgraph_{task.idx}"
    compile_kwargs = {
        "args_target": "dataflow_cert",
        "outfolder": task.outfolder,
        "block_id": task.subgraph,
        "device": task.device,
        "build_pdi": task.build_pdi,
        "subgraph_suffix": suffix,
        "read_model_data": task.read_model_data,
        "model_data_path": task.model_data_path,
        # Failures must surface in the parent via the future's exception
        "raise_on_failure": True,
        "combined_kernel_mapping": task.combined_kernel_mapping,
    }
    package_dir = compile_json(task.alloc_json, **compile_kwargs)
    return task.idx, suffix, package_dir


def _compile_subgraphs_parallel(
    subgraphs: Dict[int, list[int]],
    alloc_json: dict[str, Any],
    outfolder: str,
    device: str,
    read_model_data: bool,
    model_data_path: str,
    max_workers: int = 2,
    override_build_pdi: bool = True,
    pdi_mapping: Dict[int, SubgraphPdiInfo] = None,
) -> list[CfgItem]:
    """
    Compile subgraph 0 first (sequentially, build_pdi=override_build_pdi),
    then run the remaining subgraphs in a process pool (build_pdi=False).

    Args:
        subgraphs: Mapping of subgraph index to its block selection; indices
            0..len(subgraphs)-1 are accessed.
        alloc_json: Allocator JSON content handed to compile_json.
        outfolder: Output folder.
        device: Device name.
        read_model_data: Forwarded to compile_json.
        model_data_path: Forwarded to compile_json.
        max_workers: Size of the process pool.
        override_build_pdi: build_pdi value used for subgraph 0.
        pdi_mapping: Optional per-subgraph kernel/PDI info. When omitted,
            compile_json receives combined_kernel_mapping=None and derives the
            kernel metadata itself.

    Returns:
        One CfgItem per subgraph, ordered by subgraph index.

    Raises:
        RuntimeError: If any subgraph fails to compile or its expected
            artifacts are missing.
    """
    def validate_artifacts(path: str) -> None:
        """
        Ensure that required compilation artifacts exist in the folder.
        Raises RuntimeError if any are missing.
        """
        required = ["control.asm", "wgt.bin", "param.bin"]
        missing = [f for f in required if not os.path.exists(os.path.join(path, f))]
        if missing:
            raise RuntimeError(f"Missing expected artifacts: {', '.join(missing)} in {path}")

    def mapping_for(sg_idx: int):
        """Return the PDI info of a subgraph, or None when no mapping was supplied."""
        return None if pdi_mapping is None else pdi_mapping[sg_idx]

    # Run subgraph 0 sequentially
    log("[INFO] Compiling subgraph_0 (build_pdi=True)")
    name0 = "subgraph_0"
    path0 = compile_json(
        alloc_json,
        args_target="dataflow_cert",
        outfolder=outfolder,
        block_id=subgraphs[0],
        device=device,
        build_pdi=override_build_pdi,
        subgraph_suffix=name0,
        read_model_data=read_model_data,
        model_data_path=model_data_path,
        raise_on_failure=True,
        combined_kernel_mapping=mapping_for(0),
    )
    merged: list[CfgItem] = [CfgItem(id=name0, path=path0)]

    # Compile remaining subgraphs in parallel (build_pdi=False)
    if len(subgraphs) == 1:
        return merged

    log(
        f"[INFO] Launching parallel compile for {len(subgraphs) - 1} subgraphs...")
    ctx = mp.get_context("spawn")
    tasks: List[SubgraphTask] = [
        SubgraphTask(
            idx=i,
            subgraph=subgraphs[i],
            alloc_json=alloc_json,
            outfolder=outfolder,
            device=device,
            read_model_data=read_model_data,
            model_data_path=model_data_path,
            build_pdi=False,
            combined_kernel_mapping=mapping_for(i),
        )
        for i in range(1, len(subgraphs))
    ]

    failed_subgraphs: list[tuple[int, str]] = []
    results: list[tuple[int, str, str]] = []
    with ProcessPoolExecutor(max_workers=max_workers, mp_context=ctx) as pool:
        futures = {pool.submit(_compile_json_task, t): t.idx for t in tasks}
        for fut in as_completed(futures):
            idx = futures[fut]
            try:
                res = fut.result()   # (idx, name, path)
                validate_artifacts(res[2])
                results.append(res)
            except Exception as e:  # pylint: disable=broad-exception-caught
                # Keep going so every failing subgraph is reported at once
                log(f"[ERROR] Subgraph {idx} failed to compile: {e}")
                failed_subgraphs.append((idx, str(e)))

    # Raise Error if one of the Subgraph Failed
    if failed_subgraphs:
        msg_lines = ["One or more subgraphs failed compilation:"]
        for idx, err in failed_subgraphs:
            msg_lines.append(f"  - Subgraph {idx}: {err}")
        raise RuntimeError("\n".join(msg_lines))

    # Sort and merge results
    results.sort(key=lambda x: x[0])
    merged.extend(CfgItem(id=name, path=path) for _, name, path in results)
    return merged


def compile_model(
    fused_model_path: str,
    ir_json_path: str,
    model_data_path: str = "/path/to/model/data",
    unfused_model_path: str = "",
    args_target: str = "dataflow",
    outfolder: str = "Output",
    both_L2L3: bool = False,
    block_id: str = None,
    device: str = "mds",
    read_model_data: bool = False,
    only_graph_gen: bool = False,
    build_pdi: bool = True,
    node_list: str = None,
    unique_nodes_path: str = None,
    tensor_map_json: str = None,
    skip_operators: str = None,
    num_workers: int = None,
    gen_node_list: bool = False,
    save_cut_graphs: bool = False,
    include_operators: str = None,
    set_qdq_fp16: bool = True,
) -> None:
    """
    Compile Tiling JSON for a ONNX graph and compile JSON to hw_package.

    Stages, in order:
      1. Validate the supplied paths and call set_datatype(set_qdq_fp16).
      2. Run the allocator to write `<basename of fused model>_alloc.json`
         into `outfolder`; skipped when that file already exists.
      3. Compile every unique node in `dataflow` mode.
      4. If `block_id` is given, compile the selected layers one at a time
         with `args_target`. Otherwise, and only when `args_target == "cert"`,
         partition the graph, run the subgraph-level allocator, compile all
         subgraphs and generate the model ELF.

    Args:
        fused_model_path: Model handed to the allocator and the partitioner.
        ir_json_path: IR JSON handed to update_alloc_json.
        model_data_path: Forwarded to the compile steps; an empty value falls
            back to "/path/to/model/data".
        unfused_model_path: Forwarded to the partitioner.
        args_target: Compile target; "cert" enables the subgraph flow.
        outfolder: Output folder.
        both_L2L3: Forwarded to the allocator as `both_l2l3`.
        block_id: Optional layer selection (see process_layer_id_str).
        device: Device name forwarded to the compile steps.
        read_model_data: Forwarded to the compile steps.
        only_graph_gen: In the subgraph flow, raise SystemExit(0) once the
            partitioner stage has completed.
        build_pdi: Combined with the cert-backend check to decide on PDI
            generation; also forwarded as `override_build_pdi`.
        node_list: Forwarded to the partitioner.
        unique_nodes_path: Handed to process_unique_alloc and get_skip_ops.
        tensor_map_json: Forwarded to the partitioner.
        skip_operators: Comma-separated operators to skip; replaced by the
            result of get_skip_ops when `include_operators` is given.
        num_workers: Worker count for the parallel subgraph compile; defaults
            to min(32, get_os_core_count()).
        gen_node_list: Generate per-subgraph nodelists in the subgraph flow.
        save_cut_graphs: Forwarded to the partitioner.
        include_operators: When not None, used to derive `skip_operators`.
        set_qdq_fp16: Passed to set_datatype.

    Raises:
        SystemExit: With code 0 when `only_graph_gen` is set (subgraph flow).
    """
    # Validate if all the paths provided exist
    validate_paths(
        locals(),
        keys=[
            "fused_model_path",
            "ir_json_path",
            "model_data_path",
            "unfused_model_path",
            "node_list",
            "unique_nodes_path",
            "tensor_map_json",
        ],
    )

    # Set Env Variable for QDQ Data Type
    set_datatype(set_qdq_fp16)

    # Derive skip_operators from include_operators (overrides any value passed in)
    if include_operators is not None:
        skip_operators = get_skip_ops(unique_nodes_path, include_operators)

    # Fallback mechanism for null path
    if not model_data_path:
        model_data_path = "/path/to/model/data"

    # Call Allocator to generate JSON from ONNX
    allocator_json_name = f"{os.path.basename(fused_model_path)}_alloc.json"
    allocator_json_path = os.path.join(outfolder, allocator_json_name)
    if os.path.exists(allocator_json_path):
        log(
            f"[INFO] Skipping Allocator Stage (already exists: {allocator_json_path})")
    else:
        log("[INFO] Allocator Stage Started")
        allocator_args = alloc.Command(
            model_path=fused_model_path,
            fusion_json_path=allocator_json_path,
            both_l2l3=both_L2L3,
            c64=True
        )
        alloc.main(allocator_args)
        update_alloc_json(alloc_json_path=allocator_json_path,
                          ir_json_path=ir_json_path)
        log("[INFO] Allocator Stage Completed")
        log(f"[INFO] Path of generated JSON: {allocator_json_path}")

    # Load allocator JSON ONCE and reuse
    with open(allocator_json_path, "r", encoding="utf-8") as f:
        alloc_json: dict[str, Any] = json.load(f)

    is_cert_backend = args_target == "cert"

    # Compile the generated JSON in `dataflow` mode for the partitioner
    unique_nodes_map = process_unique_alloc(
        unique_nodes_path, allocator_json_path, None, True)

    print("\nUnique nodes in the graph:\n")
    print(tabulate(
        [[k, ", ".join(v)] for k, v in unique_nodes_map.items()],
        headers=["Unique Node", "Node Instances"],
        tablefmt="github"
        ))

    # One dataflow compile per unique node; the same in-memory alloc_json is
    # passed to every call.
    for node in unique_nodes_map:
        compile_json(alloc_json, "dataflow", outfolder,
                     block_id=node, device=device,
                     unique_nodes=unique_nodes_map,
                     skip_operators=skip_operators)

    if block_id:
        # Layer-by-layer flow: one compile_json call (and suffix) per selected layer
        combined_kernel_includes, combined_kernel_names, pdi_path = generate_combined_pdi(alloc_json,
                                                                                          is_cert_backend and build_pdi,
                                                                                          outfolder, device, skip_operators)
        layers = process_layer_id_str(block_id, alloc_json, skip_operators)
        for layer in layers:
            default_L3_mappings(alloc_json, layer)
            suffix = "layer_" + layer
            compile_json(alloc_json, args_target, outfolder, layer,
                         device=device, build_pdi=False, read_model_data=read_model_data,
                         model_data_path=model_data_path, subgraph_suffix=suffix,
                         copy_ifm_ofm=True, combined_kernel_mapping=SubgraphPdiInfo(pdi_path,
                                                                                    combined_kernel_includes,
                                                                                    combined_kernel_names))

    else:
        if is_cert_backend:
            log("[INFO] Partitioner Stage Started")
            subgraphs, subgraph_id_to_name_map, subgraph_to_name, subgraph_ios = ogp.main(fused_model_path,
                                                                                          unfused_model_path,
                                                                                          outfolder,
                                                                                          alloc_json,
                                                                                          tensor_map_json,
                                                                                          save_cut_graphs,
                                                                                          node_list=node_list)

            combined_kernel_includes, combined_kernel_names, pdi_path = generate_combined_pdi(
                alloc_json, is_cert_backend and build_pdi, outfolder, device, skip_operators)

            # Every subgraph currently shares the same combined PDI entry
            sg_pdi_mapping: Dict[int, SubgraphPdiInfo] = {}

            for sg_idx in range(len(subgraphs)):
                sg_pdi_mapping[sg_idx] = SubgraphPdiInfo(
                    pdi_path=pdi_path,
                    combined_kernel_includes=combined_kernel_includes,
                    combined_kernel_names=combined_kernel_names,
                )

            log("[INFO] Subgraph Level Allocator Stage Started")
            allocator_args = alloc.CommandSubgraph(
                model_path=fused_model_path,
                fusion_json_path=allocator_json_path,
                both_l2l3=both_L2L3,
                c64=True,
                fusion_json=alloc_json,
                subgraph_ops=subgraph_to_name,
                subgraph_ios=subgraph_ios,
                allocation_mode=SubgraphAllocationMode.CONTINUOUS
            )
            alloctor_topo_order: Dict[int, list[str]] = alloc.update_allocations(allocator_args)
            subgraph_topo_order: Dict[int, list[int]] = map_subgraphs_to_json_keys(allocator_json_path, alloctor_topo_order)
            update_alloc_json(alloc_json_path=allocator_json_path, ir_json_path=ir_json_path)

            sg_table = [[idx, layers] for idx, layers in subgraph_topo_order.items()]
            print(f"\n{len(subgraph_topo_order)} NPU Subgraphs:\n")
            print("\n" + tabulate(sg_table, headers=["Subgraph Idx", "Layer IDs"], tablefmt="github") + "\n")
            log("[INFO] Partitioner Stage Completed")

            # Generate per-subgraph nodelists.
            # Used for subgraph-level testing and validation. This step does not
            # exit by itself; the early exit is controlled by only_graph_gen below.
            if gen_node_list:
                generate_subgraph_nodelists(
                    outfolder=outfolder,
                    subgraphs=subgraph_topo_order,
                    alloc_json=alloc_json,
                )

            # Explicit early exit after alloc.json generation
            if only_graph_gen:
                raise SystemExit(0)

            # Re-read the allocator JSON from disk: update_alloc_json above
            # wrote to the file, not to the in-memory alloc_json.
            with open(allocator_json_path, "r", encoding="utf-8") as f:
                alloc_json_updated: dict[str, Any] = json.load(f)
            build_subgraph_report_async(
                allocator_json_path, alloc_json_updated, subgraph_topo_order, subgraph_id_to_name_map)
            log("[INFO] Subgraph Level Allocator Stage Completed")

            # Compile the generated JSON for CERT Compilation
            workers = num_workers if num_workers else min(32, get_os_core_count())
            merged_elf_cfg = _compile_subgraphs_parallel(
                subgraphs=subgraph_topo_order,
                alloc_json=alloc_json_updated,
                outfolder=outfolder,
                device=device,
                read_model_data=read_model_data,
                model_data_path=model_data_path,
                max_workers=workers,
                override_build_pdi=build_pdi,
                pdi_mapping=sg_pdi_mapping,
            )

            gen_model_elf(merged_elf_cfg, outfolder)
            # Sanity check: one fused_hw_package_subgraph* entry is expected per subgraph
            num_hw_package = len(glob(os.path.join(outfolder, "fused_hw_package_subgraph*")))
            if num_hw_package != len(subgraphs):
                print(f"[WARN] Number of HW_Package ({num_hw_package}) is not equal to Number of Subgraphs ({len(subgraphs)}).")

def compile_json(
    json_path: Union[str, dict[str, Any]],
    args_target: str,
    outfolder: str,
    block_id: Union[str, List[int]] = None,
    ifm_path: list[str] = None,
    wgt_path: str = "",
    device: str = "mds",
    build_pdi: bool = True,
    subgraph_suffix: str = None,
    read_model_data: bool = False,
    model_data_path: str = "/path/to/model/data",
    copy_ifm_ofm: bool = False,
    unique_nodes: dict = None,
    skip_operators: str = None,
    raise_on_failure: bool = False,
    combined_kernel_mapping: SubgraphPdiInfo = None,
) -> Union[str, None]:
    """
    Parse a graph JSON and compile one or more blocks.
    Functionalities:
    - If `block_id` is provided (see `process_layer_id_str` for accepted forms),
      only those blocks are compiled.
      In this case:
        * The first selected block gets `load_input_from_ddr = True`.
        * The last selected block gets `store_output_to_ddr = True`.
    - If `block_id` is not provided, all blocks in the JSON are compiled in file order.
    - Kernel metadata (names/includes) is aggregated across the chosen set of blocks
      and passed to each `compile_operator` call.

    Args:
        json_path: Path of the graph JSON, or the already-loaded dict.
        args_target: Compile target; "cert"/"dataflow_cert" select the cert
            backend, "dataflow"/"dataflow_cert" count as frontend targets.
        outfolder: Output folder.
        block_id: Optional selection of blocks to compile.
        ifm_path: Optional IFM bin paths for the first block (default: none).
        wgt_path: Optional WGT bin path for the first block.
        device: Device name forwarded to `compile_operator`.
        build_pdi: Whether a PDI may be generated when no
            `combined_kernel_mapping` is supplied.
        subgraph_suffix: Suffix handed to `create_fused_package`.
        read_model_data: Forwarded to `compile_operator`.
        model_data_path: Forwarded to `compile_operator`.
        copy_ifm_ofm: OR-ed with the chained-DI setting and handed to
            `create_fused_package`.
        unique_nodes: Optional mapping from a block ID to the block IDs that
            share its compile result; all of them get their compilable mark
            updated together.
        skip_operators: Optional comma-separated operator names to exclude.
        raise_on_failure: Raise once all blocks were attempted if any failed.
        combined_kernel_mapping: Precomputed kernel metadata and PDI path; when
            None they are derived with `generate_combined_pdi`.

    Returns:
        The value returned by `create_fused_package` for cert backends when at
        least one block compiled; otherwise None.

    Raises:
        TypeError: If `json_path` is neither a str nor a dict.
        AssertionError: If more than one block is selected for "sim"/"cert_sim".
        RuntimeError: If any block failed and `raise_on_failure` is set.
    """
    # Avoid a shared mutable default list across calls
    if ifm_path is None:
        ifm_path = []

    enable_chained_di = is_chained_di()
    # Decide whether we received a path or a dict
    if isinstance(json_path, str):
        with open(json_path, "r", encoding="utf-8") as f:
            raw_json: dict[str, Any] = json.load(f)
    elif isinstance(json_path, dict):
        raw_json = json_path
    else:
        raise TypeError(
            f"json_path must be str or dict[str, Any], not {type(json_path)}"
        )

    is_cert_backend = args_target in ["cert", "dataflow_cert"]
    is_frontend = args_target in ["dataflow", "dataflow_cert"]

    # Resolve which block IDs to compile (preserve order)
    if block_id:
        # NOTE(review): skip_operators is not forwarded on this path, so an
        # explicit block selection is only filtered by registry support.
        # Confirm whether that is intended.
        selected_ids = process_layer_id_str(block_id, raw_json)
    else:
        selected_ids = get_supported_ops(
            raw_json, list(raw_json.keys()), skip_operators)
        log(f"[INFO] Compiling all {len(selected_ids)} blocks from {json_path}...")

    # NOTE: When we want to compile PDI for each subgraph because of PM limitation we need a change here
    combined_pdi_path = None
    if combined_kernel_mapping is not None:
        combined_kernel_includes = combined_kernel_mapping.combined_kernel_includes
        combined_kernel_names = combined_kernel_mapping.combined_kernel_names
        combined_pdi_path = combined_kernel_mapping.pdi_path
    else:
        combined_kernel_includes, combined_kernel_names, _ = generate_combined_pdi(
            raw_json, is_cert_backend and build_pdi, outfolder, device, skip_operators)

    if args_target in ["sim", "cert_sim"]:
        assert len(
            selected_ids) == 1, "Only one layer can be compile in AIESIM or CERT_SIM mode at once"

    # Compile each selected block (in order)
    shim_prm_offset = 0
    shim_wgt_offset = 0
    # Log shim offsets in CSV
    shim_log = []
    op_output_dirs: Dict[int, str] = {}

    # only set DDR flags when a subset is specified
    subset_mode = bool(block_id)
    # keep order, skip missing
    valid_ids = [bid for bid in selected_ids if bid in raw_json]

    # Handle first read bins args
    first_read_bins = ReadBins(read_ifm=0, read_wgt=0)
    if ifm_path:
        first_read_bins.read_ifm = 1
    if wgt_path:
        first_read_bins.read_wgt = 1
    log(f"[INFO] {first_read_bins}")
    intermediate_bin_dir = os.path.join(
        CURRDIR, outfolder, "intermediate_bins")

    failed_blocks: list[tuple[str, str, str]] = []  # (bid, op, err)
    for idx, bid in enumerate(valid_ids):
        block = raw_json[bid]
        config_dict, operator = parse_json_to_dict_with_op(block)
        if not operator:
            log(f"[WARN] Skipping block {bid}: missing 'op' field.")
            continue

        # Determine prev_layer_names
        prev_layer_names = []
        read_bins = ReadBins(read_ifm=0, read_wgt=0)
        if enable_chained_di:
            if "input_name" in config_dict:
                prev_layer_names.append(
                    config_dict["input_name"].replace("/", "_"))
            else:
                # Collect input0_name, input1_name, ... until a key is missing
                i = 0
                key = f"input{i}_name"
                while key in config_dict:
                    prev_layer_names.append(config_dict[key].replace("/", "_"))
                    i += 1
                    key = f"input{i}_name"

            # Determine read_bins
            read_bins = ReadBins(read_ifm=1, read_wgt=0)
            if subset_mode:
                if idx == 0:
                    read_bins = first_read_bins
                    __copy_read_bins(
                        config_dict, intermediate_bin_dir, ifm_path, wgt_path)
            else:
                try:
                    if int(bid) == 0:
                        read_bins = first_read_bins
                        __copy_read_bins(
                            config_dict, intermediate_bin_dir, ifm_path, wgt_path)
                except ValueError:
                    pass  # non-integer bid, ignore special case
        if not enable_chained_di or is_frontend:
            read_bins = ReadBins(read_ifm=0, read_wgt=0)

        # Special handling only for subset mode
        if subset_mode:
            if idx == 0:
                config_dict["load_input_from_ddr"] = True
            if idx == len(valid_ids) - 1:
                config_dict["store_output_to_ddr"] = True

        log(f"\n[INFO] Compiling block {bid}: operator = {operator}")
        try:
            shim_prm_offset_new, shim_wgt_offset_new, op_folder_path = compile_operator(
                operator,
                config_dict,
                args_target=args_target,
                out_folder=outfolder,
                json_mode=True,
                json_block_id=bid,
                shim_prm_offset=shim_prm_offset,
                shim_wgt_offset=shim_wgt_offset,
                combined_kernel_names=combined_kernel_names,
                combined_kernel_includes=combined_kernel_includes,
                read_bins=read_bins,
                prev_layer_names=prev_layer_names,
                device=device,
                read_model_data=read_model_data,
                model_data_path=model_data_path,
            )
            # Record the offsets this block started at, then advance them
            shim_log.append((bid, operator, shim_prm_offset, shim_wgt_offset))
            shim_prm_offset, shim_wgt_offset = shim_prm_offset_new, shim_wgt_offset_new
            op_output_dirs[int(bid)] = op_folder_path
            # Mark this block as compilable in the raw_json
            block_ids = unique_nodes.get(bid, [bid]) if unique_nodes else [bid]
            mark_block_compilable(raw_json, block_ids, True)
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Best effort: record the failure and continue with the next block
            err = str(e)
            failed_blocks.append((str(bid), str(operator), err))
            print(f"[ERR] Failed to compile block {bid} ({operator}): {e}")
            print(traceback.format_exc())
            # Mark this block as not compilable in the raw_json
            block_ids = unique_nodes.get(bid, [bid]) if unique_nodes else [bid]
            mark_block_compilable(raw_json, block_ids, False)

    # If anything failed and caller wants to catch error, raise
    if failed_blocks and raise_on_failure:
        msg_lines = ["compile_json: one or more blocks failed:"]
        for bid, op, err in failed_blocks:
            msg_lines.append(f"  - block {bid} ({op}): {err}")
        raise RuntimeError("\n".join(msg_lines))

    if not op_output_dirs:
        log("[WARN] No blocks compiled; nothing to fuse.")
        return None

    # Metadata
    log(f"[INFO] Operator Compilation Successful for the {len(op_output_dirs)} shapes:")
    for op_name in op_output_dirs.values():
        # basename copes with both "/" and "\\" separators on Windows
        log(f'\t{os.path.basename(op_name)}')

    if is_cert_backend:
        subgraph_fused_dir = create_fused_package(values_for_keys(op_output_dirs),
                                                  outfolder, subgraph_suffix,
                                                  shim_log, enable_chained_di or copy_ifm_ofm,
                                                  valid_ids,
                                                  combined_pdi_path=combined_pdi_path)
        return subgraph_fused_dir
    return None


# Function to run individual operators
def compile_operator(operator: str,
                     shape: Dict,
                     args_target: str = "dataflow",
                     out_folder: str = "Output",
                     gen_standalone_pdi: bool = False,
                     json_mode: bool = False,
                     json_block_id: int = None,
                     shim_prm_offset: int = 0,
                     shim_wgt_offset: int = 0,
                     combined_kernel_names: dict = None,
                     combined_kernel_includes: list[str] = None,
                     read_bins: ReadBins = None,
                     prev_layer_names: list[str] = None,
                     device: str = "mds",
                     read_model_data: bool = False,
                     model_data_path: str = "/path/to/model/data",
                     gen_op_elf: bool = False,
                     dump_vcd: bool = True,) -> Union[tuple[int, int, str], None]:
    """Compile and build an operator, supports CLI or JSON config.

    Args:
        operator: Operator name as registered in `OperatorsRegistry`.
        shape: Operator configuration dictionary.
        args_target: One of "dataflow", "dataflow_cert", "sim", "cert_sim",
            "cert"; any other value raises KeyError.
        out_folder: Output folder.
        gen_standalone_pdi: Also produce the PDI (copied on Windows, compiled
            through the "pdi" operator otherwise) before building.
        json_mode: When true, return the shim offsets and operator folder.
        json_block_id: Block ID handed to `create_output_folder`.
        shim_prm_offset: Incoming parameter offset passed to `run_op`.
        shim_wgt_offset: Incoming weight offset passed to `run_op`.
        combined_kernel_names: Kernel names to use in place of the operator's own.
        combined_kernel_includes: Kernel includes to use in place of the
            operator's own.
        read_bins: Read-bin flags; defaults to a fresh
            `ReadBins(read_ifm=0, read_wgt=0)`.
        prev_layer_names: Names of the producing layers; defaults to an empty list.
        device: Device name forwarded to the backend compile.
        read_model_data: Forwarded to `run_op`.
        model_data_path: Forwarded to `run_op`.
        gen_op_elf: Generate a control ELF from the operator's `Work_AIE4` folder.
        dump_vcd: Forwarded to `compile_backend`.

    Returns:
        `(shim_prm_offset, shim_wgt_offset, op_folder_path)` in JSON mode;
        None otherwise, and also None when the operator is not supported.

    Raises:
        Exception: Any error raised while building is printed and re-raised;
            the working directory is restored to CURRDIR in every case.
    """
    # Fresh per-call defaults (avoid sharing one mutable object across calls)
    if read_bins is None:
        read_bins = ReadBins(read_ifm=0, read_wgt=0)
    if prev_layer_names is None:
        prev_layer_names = []

    # Backend
    backend_mapping = {"dataflow_cert": -2, "dataflow": -1,
                       "cert_sim": 0, "sim": 0, "cert": 2}
    target = backend_mapping[args_target]
    if target == -1:
        backend = BackEnd.Adf
    elif target == -2:
        backend = BackEnd.CertAsm
    else:
        backend = BackEnd(target)
    is_cert_sim = args_target == "cert_sim"
    supported_ops = OperatorsRegistry.get_operators()
    log("[INFO] List of Supported Operators: ", set(supported_ops))
    if operator not in supported_ops:
        # Every literal that mentions the operator must be an f-string,
        # otherwise the placeholder is printed verbatim.
        log(f"[ERR] {operator} is not supported.\n"
            "       Please make sure the following three files are present:\n"
            f"       1. dataflow/{operator}.py\n"
            f"       2. host/{operator}.cpp\n"
            f"       3. buildscripts/build_{operator}.py")
        return None

    op_config = OperatorsRegistry.get_operator(operator)
    build_script = op_config["build_script"]
    kernel_names = combined_kernel_names if combined_kernel_names is not None else \
        OperatorsRegistry.get_kernel_names(operator, shape)
    kernel_includes = combined_kernel_includes if combined_kernel_includes is not None else \
        OperatorsRegistry.get_kernel_includes(operator, shape)
    testbench = op_config["testbench"][0]

    # To optionally generate PDI when testing standalone operators
    if gen_standalone_pdi:
        if os.name == "nt":
            copy_pdi_for_win(kernel_names, kernel_includes)
        else:
            compile_operator("pdi", {"input": [0]}, args_target="sim",
                             out_folder=out_folder,
                             combined_kernel_names=kernel_names,
                             combined_kernel_includes=kernel_includes,
                             device=device, dump_vcd=False)

    try:
        op_folder = create_output_folder(
            operator, shape, out_folder, json_mode, json_block_id)
        op_folder_path = os.path.join(CURRDIR, op_folder)

        # Intermediate bin staging folder (outside the op's working dir)
        intermediate_bin_dir = os.path.join(
            CURRDIR, out_folder, "intermediate_bins")

        # The build runs from inside the operator folder; undone in `finally`
        os.chdir(op_folder)
        clean_overlay(backend=backend)
        # build_script carries a ".py" suffix that is stripped for the import
        op_buildscript = importlib.import_module(
            f"buildscripts.{build_script[:-3]}")

        # Prep read_bins before build (derive current layer/output name)
        curr_layer_name = None
        if isinstance(shape, dict) and "output_name" in shape:
            curr_layer_name = str(shape["output_name"]).replace("/", "_")
        # Call helper only if we have a layer name (keeps behavior aligned with PR #2)
        enable_chained_di = is_chained_di()
        if enable_chained_di and curr_layer_name is not None:
            prep_read_bins(curr_layer_name, prev_layer_names,
                           intermediate_bin_dir, read_bins)

        log_file = os.path.join(os.getcwd(), "dataflow_tiling.txt")
        if hasattr(op_buildscript, "get_op") and callable(op_buildscript.get_op):
            op_impl = op_buildscript.get_op()
            with capture_logging(log_file):
                # cert_sim runs the operator for both CertAsm and Adf
                compile_for_backend = [BackEnd.CertAsm,
                                       BackEnd.Adf] if is_cert_sim else [backend]
                for backend_name in compile_for_backend:
                    layer_file_name = 'aie4_dma.cpp' if (
                        backend_name == BackEnd.CertAsm) else 'dma.hpp'
                    shim_prm_offset, shim_wgt_offset = op_impl.run_op(
                        shape,
                        backend_name,
                        kernel_names,
                        kernel_includes,
                        shim_prm_offset,
                        shim_wgt_offset,
                        layer_file_name=layer_file_name,
                        read_bins=read_bins,
                        read_model_data=read_model_data,
                        model_data_path=model_data_path
                    )

            if target >= 0:
                compile_backend(backend, testbench, os.getcwd(),
                                is_cert_sim, device=device,
                                dump_vcd=dump_vcd, is_standalone_op=True)
            elif args_target == "dataflow_cert":
                build_data_bins(os.path.join(HOSTDIR, testbench), os.getcwd())
            if gen_op_elf:
                work_aie4_path = os.path.join(op_folder_path, "Work_AIE4")
                cfg = [CfgItem(id="aie4_models", path=work_aie4_path)]
                generate_ctrl_elf(cfg, work_aie4_path)

        # Persist intermediate bins after compile
        if enable_chained_di and curr_layer_name is not None and backend == BackEnd.CertAsm:
            save_intermediate_bin(
                os.getcwd(), intermediate_bin_dir, curr_layer_name)

    except Exception:   # pylint: disable=W0718
        print(f'[ERR] Operator failed to compile for ->\n'
              f'      Operator: {operator}\n'
              f'      Shape: {out_dir_name_from_dict(shape)}')
        print(traceback.format_exc())
        raise

    finally:
        # Always restore the working directory, even on failure
        os.chdir(CURRDIR)

    if json_mode:
        return shim_prm_offset, shim_wgt_offset, op_folder_path
    return None


def main() -> None:
    '''Entry point of the build: parse the CLI arguments and dispatch the flow.

    Steps, in order:
      1. Build the argument parser and parse the arguments.
      2. Export the logging/behaviour switches as environment variables.
      3. Optionally wipe, then (re)create, the output folder.
      4. Dispatch on the first input source that is set:
         ``args.json`` -> ``compile_json``, ``args.fused_model`` ->
         ``compile_model``, ``args.build_sim`` -> ``unified_di_sim``.

    If no input source is given, an error is logged and the function
    returns normally (no exception is raised, no exit code is set here).
    '''
    parser = make_build_parser()
    args = parse_args_with_cfg(parser)

    # Set env variables before any logging happens. This has to precede the
    # cleanup step below, because that step already emits log() messages.
    os.environ["LOG_ENABLED"] = str(args.enable_log)
    os.environ["CHAINED_DI"] = str(args.chain_di)
    os.environ["ML_TIMER_LOG_LEVEL"] = str(args.ml_timeline)
    os.environ["SAVE_SUBGRAPH_JSON"] = str(args.save_cut_graphs)
    if args.cont_alloc:
        os.environ["AIE4_ALLOCATOR_MODE"] = "CONTINUOUS"

    # Output Dir: optionally remove a previous build, then make sure it exists
    outfolder = args.outfolder
    if args.clean:
        if os.path.exists(outfolder):
            log(
                f"[INFO] Cleaning up: Removing {outfolder} directory and its contents...")
            # ignore_errors=True: cleanup is best-effort, leftovers are tolerated
            shutil.rmtree(outfolder, ignore_errors=True)
        else:
            log(f"[INFO] {outfolder} directory does not exist. Skipping cleanup.")
    os.makedirs(outfolder, exist_ok=True)

    # Read first block IFM,WGT bin paths (comma-separated list, whitespace trimmed)
    if args.ifm_path:
        ifm_path = [read_bin.strip() for read_bin in args.ifm_path.split(",")]
    else:
        ifm_path = []

    # Dispatch based on input source; the first matching source wins
    if args.json:
        capture_prints_to_file(compile_json,
                               args.json, args.target, outfolder,
                               block_id=args.layer_ids, ifm_path=ifm_path,
                               wgt_path=args.wgt_path, device=args.device,
                               build_pdi=args.build_pdi,
                               filename=f"{outfolder}/build_json_log.txt")
    elif args.fused_model:
        capture_prints_to_file(compile_model,
                               args.fused_model, args.ir_json,
                               args.model_data_path, args.unfused_model,
                               args.target, outfolder, args.both_L2L3,
                               block_id=args.layer_ids, device=args.device,
                               read_model_data=args.read_model_data, only_graph_gen=args.skip,
                               build_pdi=args.build_pdi, node_list=args.node_list,
                               unique_nodes_path=args.unique_nodes,
                               tensor_map_json=args.tensor_map,
                               skip_operators=args.skip_op, num_workers=args.num_workers,
                               gen_node_list=args.gen_node_list,
                               save_cut_graphs=args.save_cut_graphs,
                               include_operators=args.include_op,
                               set_qdq_fp16=args.is_qdq_fp16,
                               filename=f"{outfolder}/build_model_log.txt")
    elif args.build_sim:
        unified_di_sim(args.build_sim, device=args.device)
    else:
        # Nothing to build: only report it, the caller sees a normal return
        log("[ERR] No input source provided. Exiting.")


# Run the build only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
