import os
import re
import shlex
import shutil
import subprocess
import sys
from typing import Iterable, List, NoReturn, Set

from build_aie4 import compile_operator
# pylint: disable-next=W0611
import buildscripts  # noqa: F401
from utils.build_utils import set_datatype, is_qdq_fp16
from utils.unique_pdi_variants import pdi_variants as PDI_VARIANTS


# Every path below is resolved once, at import time, against the current
# working directory.
_OUTPUT_ROOT = os.path.join(os.getcwd(), "Output")

# Output directory used when compiling the synthetic 'pdi' operator
# (see build_prebuilt_pdi_for_variant).
OUTPUT_DIR = os.path.join(_OUTPUT_ROOT, "op_pdi_shape_input_0")

# Where the compile step leaves the generated .asm files.
ASM_SOURCE_DIR = os.path.join(OUTPUT_DIR, "Work", "ps", "asm")

# Generated .asm files that must all be present; each one is copied into the
# variant's prebuilt directory.
ASM_FILES: List[str] = [
    "aie_asm_elfs.asm",
    "aie_asm_init.asm",
    "aie_asm_enable.asm",
    "pdi.asm",
]


def abort_with_error(message: str) -> NoReturn:
    """
    Raise a RuntimeError carrying ``message``, highlighted in red.

    Despite the name, this does not exit the process. It raises
    ``RuntimeError("Error: <message>")`` wrapped in ANSI red/reset escape
    codes. Raising (rather than calling ``sys.exit``) lets callers such as
    ``main()`` catch a failure per variant with ``except Exception`` and keep
    going. If nothing catches it, the script still terminates with a non-zero
    exit status.

    Parameters
    ----------
    message:
        Human-readable description of the failure.

    Raises
    ------
    RuntimeError:
        Always.
    """
    red = "\033[91m"
    reset = "\033[0m"
    raise RuntimeError(f"{red}Error: {message}{reset}")


def run_command(cmd: Iterable[str], cwd: str | None = None) -> None:
    """
    Execute a shell command, raising an exception if it fails.

    Parameters
    ----------
    cmd:
        Command and arguments as an iterable (e.g. list of strings).
    cwd:
        Optional working directory to run the command in. Defaults to the
        current working directory.
    """
    cwd_to_use = cwd or os.getcwd()
    cmd_str = " ".join(shlex.quote(arg) for arg in cmd)
    print(f"[RUN] {cmd_str} (cwd={cwd_to_use})")
    subprocess.run(list(cmd), cwd=cwd_to_use, check=True)


def stage_and_commit_prebuilt_pdi(repo_root: str, pdi_variant: str) -> None:
    """
    Stage and commit the updated prebuilt PDI directory.

    The directory is ``prebuilt/<dtype>_pdi_with_<pdi_variant>`` relative to
    ``repo_root``. ``<dtype>`` is ``fp16`` when ``is_qdq_fp16()`` is true and
    ``bf16`` otherwise.

    Only that directory is committed, so unrelated changes that happen to be
    staged are left alone. If staging produced no changes (the rebuilt
    artifacts are identical to what is already committed), the commit is
    skipped instead of failing with "nothing to commit".

    Parameters
    ----------
    repo_root:
        Path to the git repository root.
    pdi_variant:
        Name of the PDI variant (e.g. 'conv_a16', 'matmul_a16w4').

    Raises
    ------
    subprocess.CalledProcessError:
        If any git command fails.
    """
    dtype_prefix = "fp16" if is_qdq_fp16() else "bf16"
    prebuilt_dir = f"prebuilt/{dtype_prefix}_pdi_with_{pdi_variant}"
    commit_message = f"Prebuilt PDI updated for PDI with {pdi_variant} with data type {dtype_prefix}"

    run_command(["git", "add", prebuilt_dir], cwd=repo_root)

    # `git diff --cached --quiet` exits 0 when nothing is staged for the path
    # and 1 when there are staged changes; any other status is a real error.
    diff = subprocess.run(
        ["git", "diff", "--cached", "--quiet", "--", prebuilt_dir],
        cwd=repo_root,
        check=False,
    )
    if diff.returncode == 0:
        print(f"No changes staged in {prebuilt_dir}; skipping commit.")
        return
    if diff.returncode != 1:
        raise subprocess.CalledProcessError(diff.returncode, diff.args)

    # Restrict the commit to prebuilt_dir so previously staged, unrelated
    # changes are not swept into this automated commit.
    run_command(
        ["git", "commit", "-m", commit_message, "--", prebuilt_dir],
        cwd=repo_root,
    )


def copy_asm_files(source_dir: str, dest_dir: str) -> None:
    """
    Copy every file named in ASM_FILES from source_dir to dest_dir.

    All files are checked for existence before anything is copied. A missing
    file therefore leaves dest_dir untouched instead of holding a mix of
    freshly copied and stale .asm files. Copies use ``shutil.copy2``, which
    also preserves file metadata.

    Parameters
    ----------
    source_dir:
        Directory where the generated .asm files are located.
    dest_dir:
        Destination directory to copy files into. Created if needed.

    Raises
    ------
    RuntimeError:
        If any required .asm file is missing. Each missing path is printed to
        stderr first, and the error itself is raised via abort_with_error.
    """
    print(f"Copying ASM files from {source_dir} to {dest_dir}")

    sources = [os.path.join(source_dir, filename) for filename in ASM_FILES]
    missing = [src for src in sources if not os.path.isfile(src)]
    for src in missing:
        print(f"Missing file: {src}", file=sys.stderr)
    if missing:
        abort_with_error("One or more .asm files were missing. Aborting.")

    os.makedirs(dest_dir, exist_ok=True)
    for filename, src in zip(ASM_FILES, sources):
        dst = os.path.join(dest_dir, filename)
        print(f"Copying {src} → {dst}")
        shutil.copy2(src, dst)

    print(f"Copied ASM files to {dest_dir}")


def extract_aiecompiler_info(log_path: str) -> tuple[str, str]:
    """
    Extract the AIE compiler banner and the PRIME build ID from a compiler log.

    Parameters
    ----------
    log_path:
        Path to the AIE compiler log (e.g. ``AIECompiler.log``). Undecodable
        bytes are ignored.

    Returns
    -------
    tuple[str, str]
        ``(header_block, prime_build)``, where:

        * ``header_block`` holds the lines that appear before the first line
          starting with ``INFO:`` and contain "AI Engine Compiler", "Version"
          or "Copyright". They are right-stripped and newline-joined; the
          value is "" if no line matches.
        * ``prime_build`` is the path component right after
          ``/proj/primebuilds/``, taken from the first line that contains
          ``Cmd Line :`` and that prefix. It is "" if no such line exists.

    Raises
    ------
    OSError:
        If the log file cannot be opened.
    """
    header_markers = ("AI Engine Compiler", "Version", "Copyright")
    # The build path is a POSIX path, so the component ends at "/" (not
    # os.sep, which is "\\" on Windows) or at whitespace/newline when the ID
    # is the last thing on the line.
    prime_re = re.compile(r"/proj/primebuilds/([^/\s]*)")

    with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
        lines = f.readlines()

    # The banner precedes the first INFO: message; stop there so that later
    # lines mentioning "Version" are not mistaken for the header.
    header_lines = []
    for line in lines:
        if line.strip().startswith("INFO:"):
            break
        if any(marker in line for marker in header_markers):
            header_lines.append(line.rstrip())
    header_block = "\n".join(header_lines)

    prime_build = ""
    for line in lines:
        if "Cmd Line :" not in line:
            continue
        match = prime_re.search(line)
        if match:
            prime_build = match.group(1)
            break

    return header_block, prime_build


def extract_pm_size(
    output_dir: str,
    filename: str = "0_0.calltree",
    target: str = "_main_no_exit_init",
) -> int:
    """
    Return the "func desc" value reported for ``target`` in a calltree file.

    The file read is ``<output_dir>/Work/aie/0_0/Release/<filename>``. Only
    rows after the line containing "Call tree stack and functions sizes" are
    considered. A row is a run of optional ``|`` tree markers, six integer
    columns, then a function name. The sixth integer of the first row whose
    name equals ``target`` is returned.

    For a calltree row like::

        0 768 0 0 168 28502 _main_no_exit_init

    this returns 28502.

    Parameters
    ----------
    output_dir:
        Compile output directory containing ``Work/aie/0_0/Release``.
    filename:
        Calltree file name inside that directory.
    target:
        Function name whose value is returned.

    Raises
    ------
    FileNotFoundError:
        If the calltree file does not exist.
    ValueError:
        If ``target`` does not appear in the calltree table.
    """
    path = os.path.join(output_dir, "Work", "aie", "0_0", "Release", filename)
    if not os.path.isfile(path):
        raise FileNotFoundError(path)

    # ^\s*(?:\|\s*)*       -> optional leading tree pipes
    # (?:\d+\s+){5}        -> first five integer columns (not used here)
    # (?P<value>\d+)\s+    -> sixth integer column ("func desc"), returned
    # (?P<name>\S+)        -> function name
    row_re = re.compile(
        r"^\s*(?:\|\s*)*(?:\d+\s+){5}(?P<value>\d+)\s+(?P<name>\S+)"
    )

    in_table = False
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            if not in_table:
                # Skip everything up to and including the table header line.
                in_table = "Call tree stack and functions sizes" in line
                continue

            match = row_re.match(line)
            if match and match.group("name") == target:
                return int(match.group("value"))

    raise ValueError(f"{target} not found in calltree table")


def save_dict_as_md(data: dict, dest_dir: str, output_dir: str, filename: str = "kernel.md") -> str:
    """
    Write kernel metadata for a prebuilt PDI as a markdown file.

    The file has four sections, each body in a fenced code block:

    * "TA Build": PRIME build ID parsed from ``<output_dir>/AIECompiler.log``.
    * "AIE Compiler": compiler banner parsed from the same log.
    * "Kernels Included": one ``key: value`` line per entry of ``data``,
      ordered by value.
    * "PM Size": value returned by ``extract_pm_size(output_dir)``.

    Parameters
    ----------
    data : dict
        Entries listed under "Kernels Included". Values must be mutually
        comparable because entries are sorted by value.
    dest_dir : str
        Destination directory. Created (including parents) if not present.
    output_dir : str
        Compile output directory holding ``AIECompiler.log`` and the calltree.
    filename : str
        Markdown filename. Default: 'kernel.md'.

    Returns
    -------
    str
        Full path to the written markdown file.

    Raises
    ------
    OSError:
        If ``AIECompiler.log`` cannot be read.
    FileNotFoundError, ValueError:
        Propagated from ``extract_pm_size`` when the calltree is missing or
        lacks the expected entry.
    """
    # Gather everything first so a parsing failure neither creates an empty
    # dest_dir nor leaves a truncated markdown file behind.
    aie_compiler_header, prime_build = extract_aiecompiler_info(os.path.join(output_dir, "AIECompiler.log"))
    kernel_lines = [f"{key}: {value}" for key, value in sorted(data.items(), key=lambda kv: kv[1])]
    pm_size = extract_pm_size(output_dir)

    lines = [
        "# Kernel Metadata",
        # TA Build
        "## TA Build",
        "```",
        f"{prime_build}",
        "```",
        "",
        # AIE Compiler
        "## AIE Compiler",
        "```",
        *aie_compiler_header.strip().splitlines(),
        "```",
        "",
        # Kernels included in the PDI
        "## Kernels Included",
        "```",
        *kernel_lines,
        "```",
        # Program memory size
        "## PM Size",
        "```",
        str(pm_size),
        "```",
    ]

    os.makedirs(dest_dir, exist_ok=True)
    file_path = os.path.join(dest_dir, filename)
    with open(file_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))

    return file_path


def build_prebuilt_pdi_for_variant(pdi_variant_name: str,  pdi_config: dict, is_fp16_dtpye: bool) -> None:
    """
    Build and store prebuilt PDI artifacts for a single PDI variant.

    This performs:
      1. Removal of the previous build output (the parent of OUTPUT_DIR).
      2. Compilation of the synthetic 'pdi' operator (simulator target) with
         the variant's combined kernels.
      3. Writing kernel.md (build/compiler info, kernel list, PM size) into
         prebuilt/<dtype>_pdi_with_<variant>, where <dtype> is "fp16" or
         "bf16".
      4. Copying the generated ASM files into that same directory.
      5. Git stage + commit of the updated prebuilt directory.

    Parameters
    ----------
    pdi_variant_name:
        Name of the PDI variant (e.g. 'conv_a16', 'matmul_a16w4').
    pdi_config:
        Variant configuration. Must provide "combined_kernel_names" and
        "combined_kernel_includes", which are passed to compile_operator;
        "combined_kernel_names" is also written to kernel.md.
    is_fp16_dtpye:
        True selects the "fp16" destination directory, False selects "bf16".
        NOTE: the git step derives its directory from is_qdq_fp16() instead,
        so the caller must have called set_datatype() with the same value
        beforehand (main() does).
    """
    # Clean the directory the artifacts are read from below. OUTPUT_DIR is
    # anchored at import time; a bare relative "Output" would follow the
    # current cwd and could leave stale artifacts in OUTPUT_DIR.
    shutil.rmtree(os.path.dirname(OUTPUT_DIR), ignore_errors=True)

    # Compile the synthetic 'pdi' operator
    compile_operator(
        "pdi",
        {"input": [0]},
        args_target="sim",
        combined_kernel_names=pdi_config["combined_kernel_names"],
        combined_kernel_includes=pdi_config["combined_kernel_includes"],
        dump_vcd=False,
    )

    dtype_prefix = "fp16" if is_fp16_dtpye else "bf16"
    dest_dir = f"prebuilt/{dtype_prefix}_pdi_with_{pdi_variant_name}"
    save_dict_as_md(pdi_config["combined_kernel_names"], dest_dir, OUTPUT_DIR)
    copy_asm_files(ASM_SOURCE_DIR, dest_dir)
    stage_and_commit_prebuilt_pdi(os.getcwd(), pdi_variant_name)

    print(
        "\033[92m[SUCCESS] Changes have been staged and committed. "
        "You only need to push now.\033[0m"
    )


def main() -> None:
    """
    Build and commit prebuilt PDI artifacts for every (datatype, variant) pair.

    The fp16 datatype is processed first, then bf16. For each datatype,
    set_datatype() is called and every entry of PDI_VARIANTS is rebuilt,
    archived and committed. A failing variant does not stop the run. Failures
    are collected, printed as a summary at the end, and then reported through
    abort_with_error.
    """
    failures = []

    for use_fp16 in (True, False):
        set_datatype(use_fp16)
        for variant, config in PDI_VARIANTS.items():
            print(f"\n=== Building prebuilt PDI for variant: {variant} for data type {use_fp16} ===")
            try:
                build_prebuilt_pdi_for_variant(variant, config, use_fp16)
            except Exception as exc:
                # Record and keep going; every failure is summarised below.
                failures.append((use_fp16, variant, exc))
                print(f"[ERROR] Failed for variant={variant}, fp16_dtype={use_fp16}: {exc}")

    if not failures:
        print("\n================ COMBINED PDI GENERATION SUMMARY: ALL SUCCESS ================\n")
        return

    print("\n================ COMBINED PDI GENERATION SUMMARY: FAILURES ================")
    for use_fp16, variant, exc in failures:
        print(f"- fp16_dtype={use_fp16}, variant={variant} -> {type(exc).__name__}: {exc}")
    print("=========================================================\n")
    abort_with_error(f"{len(failures)} PDI builds failed. See summary above.")


if __name__ == "__main__":
    main()
