'''VCD Analysis for single and batch VCD files'''
import argparse
import csv
import gzip
import os
import re
import sys
from concurrent.futures import ProcessPoolExecutor
from fractions import Fraction
from multiprocessing import cpu_count
from typing import Tuple, Optional, Dict, List


# ---------------- Configuration ----------------
def build_signals(from_simnow: bool) -> Tuple[List[str], List[str], List[str]]:
    """Return the signal-name substrings that analyze_vcd looks for.

    Returns (signals, instr_mat_parts, instr_vec_parts):
      signals         -- substring matches for the "group" analysis
                         (min first-high / max last-low)
      instr_mat_parts -- substrings that must ALL appear in the instr_matrix signal name
      instr_vec_parts -- substrings that must ALL appear in the instr_vector signal name

    ``from_simnow`` selects the hierarchy prefix "simnow_top.simnow_xtlm"
    instead of "l.aie_logical.aie_xtlm".
    """
    root = ("simnow_top.simnow_xtlm" if from_simnow else "l.aie_logical.aie_xtlm") + ".math_engine"

    # mem_row tiles tile_0_1, tile_1_1, tile_2_1, each with s2mm_state 0, 1, 4, 5
    # (tile-major order: all states of one tile before the next tile).
    signals: List[str] = [
        f"{root}.mem_row.tile_{col}_1.dma.s2mm_state{state}.channel_running"
        for col in (0, 1, 2)
        for state in (0, 1, 4, 5)
    ]

    # Both instruction signals live under the same event-trace scope and are told
    # apart by a second substring.
    trace_scope = f"{root}.array.tile_0_2.cm.event_trace_0"
    return signals, [trace_scope, "instr_matrix"], [trace_scope, "instr_vector"]
# ------------------------------------------------


# Seconds per unit, keyed by the lowercase timescale unit.
_TS_UNITS_TO_S = dict(
    s=1.0,
    ms=1e-3,
    us=1e-6,
    ns=1e-9,
    ps=1e-12,
    fs=1e-15,
)


def _parse_timescale(header_text: str) -> float:
    '''Return the header's ``$timescale`` as seconds per tick.

    Recognizes ``$timescale <int><unit> $end`` with optional whitespace between
    the number and the unit, any letter case, and units s/ms/us/ns/ps/fs.
    Returns 1.0 when no such declaration is found; callers in this file treat
    that value as "unknown: assume 1 tick = 1 ns".
    '''
    found = re.search(r"\$timescale\s+(\d+)\s*([munpf]?s)\s+\$end", header_text, re.IGNORECASE)
    if found is None:
        return 1.0
    magnitude, unit = found.groups()
    return int(magnitude) * _TS_UNITS_TO_S[unit.lower()]


def _open_maybe_gz(path: str):
    '''Open a VCD for reading as text, transparently decompressing ``.gz`` files.

    Both branches decode as UTF-8 and drop undecodable bytes. A compressed trace
    is therefore read exactly like its uncompressed counterpart, independent of
    the platform's default locale encoding. (Previously the gzip branch used the
    locale encoding with strict errors and could raise UnicodeDecodeError.)
    '''
    if path.endswith(".gz"):
        return gzip.open(path, "rt", encoding="utf-8", errors="ignore")
    return open(path, "r", encoding="utf-8", errors="ignore")


def _is_vec_high(bits: str) -> bool:
    '''Return True when any character of a vector/real value string is a literal "1".

    x/z bits do not matter: "x1z" counts as high, "xz" and "000" do not.
    '''
    return "1" in bits


def _is_vec_low(bits: str) -> bool:
    '''Return True when the value has at least one known (0/1) bit and all known bits are 0.

    Characters other than 0/1 (x, z, ...) are ignored, so "x0" is low while
    "xz" (no known bits at all) is neither low nor high.
    '''
    return "0" in bits and "1" not in bits


class _HighTimer:
    '''Accumulate the total number of ticks a signal spends high.'''

    __slots__ = ("start", "total")

    def __init__(self) -> None:
        self.start: Optional[int] = None  # tick at which the current high window opened
        self.total = 0                    # ticks summed over all closed high windows

    def update(self, level: Optional[bool], t: int) -> None:
        '''Open a window when *level* is True, close it when False; ignore None (x/z).'''
        if level is True:
            if self.start is None:
                self.start = t
        elif level is False:
            self.close(t)

    def close(self, t: int) -> None:
        '''Close the open high window (if any) at tick *t*.'''
        if self.start is not None:
            self.total += t - self.start
            self.start = None


def _parse_vcd_header(f) -> Tuple[List[Tuple[str, str]], str]:
    '''Consume header lines from *f* up to and including the ``$enddefinitions`` line.

    *f* is left positioned at the first line after ``$enddefinitions``.

    Returns:
      (decls, header_text): every ``$var`` declaration as (id_code, dotted full name)
      in declaration order, and the raw header text (used for ``$timescale``).
    Raises:
      RuntimeError: if ``$enddefinitions`` is never seen.
    '''
    header_lines: List[str] = []
    decls: List[Tuple[str, str]] = []
    scope: List[str] = []
    for line in f:
        header_lines.append(line)
        s = line.strip()
        if s.startswith("$scope"):
            parts = s.split()
            if len(parts) >= 3:
                scope.append(parts[2])
        elif s.startswith("$upscope"):
            if scope:
                scope.pop()
        elif s.startswith("$var"):
            # $var <type> <width> <id> <name...> $end
            tokens = s.split()
            if len(tokens) >= 5:
                # Take the name from the token list. Searching the raw line for
                # tokens[4] finds short names such as "a" or "re" inside
                # "$var"/"wire" and produces a garbage name.
                name_tokens = tokens[4:]
                if "$end" in name_tokens:
                    name_tokens = name_tokens[:name_tokens.index("$end")]
                # Keep every declaration rather than a dict keyed by id: the VCD
                # format lets several variables share one id code, and any of
                # those aliases may be the name being searched for.
                decls.append((tokens[3], ".".join(scope + [" ".join(name_tokens)])))
        elif "$enddefinitions" in s:
            return decls, "".join(header_lines)
    raise RuntimeError("Malformed VCD: $enddefinitions not found")


def _resolve_group_ids(signals: List[str], decls: List[Tuple[str, str]]) -> Dict[str, int]:
    '''Map every id code whose full name contains a group substring to that substring's index.

    Raises:
      ValueError: listing each substring in *signals* that matches no declaration.
    '''
    sym_map: Dict[str, int] = {}
    missing: List[str] = []
    for idx, sig_sub in enumerate(signals):
        needle = sig_sub.strip()
        hits = [sym for sym, full_name in decls if needle in full_name]
        if not hits:
            missing.append(sig_sub)
        for sym in hits:
            sym_map[sym] = idx
    if missing:
        raise ValueError(f"Could not find signals (substring) in VCD: {missing}")
    return sym_map


def _find_instr_id(parts: List[str], decls: List[Tuple[str, str]], label: str) -> str:
    '''Return the id code of the first declaration whose full name contains ALL of *parts*.

    Raises:
      ValueError: "<label> not found matching substrings: <parts>" when nothing matches.
    '''
    needles = [p.strip() for p in parts]
    for sym, full_name in decls:
        if all(n in full_name for n in needles):
            return sym
    raise ValueError(f"{label} not found matching substrings: {parts}")


def analyze_vcd(path: str, signals: List[str],
                instr_mat_parts: List[str], instr_vec_parts: List[str]
                ) -> Tuple[int, int, int, int, int]:
    """Analyze one VCD file (plain or ``.gz``).

    Group signals (each entry of *signals* is a substring of a full hierarchical
    name): the earliest tick any of them first becomes 1, and the latest tick any
    of them becomes 0. Instruction signals (every substring of *instr_mat_parts* /
    *instr_vec_parts* must appear in the name): total time spent high. A window
    still open at end of file is closed at the last timestamp seen.

    Returns:
      min_first_high_ns, max_last_low_ns, diff_ns, instr_mat_high_ns, instr_vec_high_ns
      (all ints, converted from ticks and truncated toward zero)
    Raises:
      RuntimeError: the header has no ``$enddefinitions``.
      ValueError: *signals* is empty; a signal cannot be found in the header; or
        some group signal never goes to 1 / never goes to 0.
    """
    if not signals:
        raise ValueError("analyze_vcd needs at least one group signal substring")

    with _open_maybe_gz(path) as f:
        decls, header_text = _parse_vcd_header(f)

        sec_per_tick = _parse_timescale(header_text)
        # _parse_timescale returns 1.0 when it finds no $timescale, which is treated
        # as "1 tick = 1 ns". NOTE(review): a genuine "$timescale 1s $end" is
        # therefore also read as 1 ns per tick.
        # Use an exact rational factor: with a float factor, int(ticks * factor)
        # can truncate an exact result (e.g. 10 * 0.09999999999999999 -> 0).
        if sec_per_tick != 1.0:
            ns_per_tick = Fraction(sec_per_tick * 1e9).limit_denominator(10**9)
        else:
            ns_per_tick = Fraction(1)

        sym_map = _resolve_group_ids(signals, decls)
        instr_mat_id = _find_instr_id(instr_mat_parts, decls, "instr_mat")
        instr_vec_id = _find_instr_id(instr_vec_parts, decls, "instr_vector")
        # Only these ids affect the result; everything else is skipped before decoding.
        watched = set(sym_map) | {instr_mat_id, instr_vec_id}

        first_high_times: List[Optional[int]] = [None] * len(signals)
        last_low_times: List[Optional[int]] = [None] * len(signals)
        mat_timer, vec_timer = _HighTimer(), _HighTimer()
        cur_time = 0

        # $dumpvars / $dumpall blocks use the same value-change syntax as the body,
        # so a single code path handles both.
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            head = line[0]
            if head == "#":
                cur_time = int(line[1:])
                continue
            if head in "01xXzZ":
                # Scalar change: <value><id>
                sym = line[1:]
                if sym not in watched:
                    continue
                level = True if head == "1" else (False if head == "0" else None)
            elif head in "bBrR":
                # Vector / real change: b<bits> <id> or r<number> <id>
                fields = line[1:].split()
                if len(fields) != 2 or fields[1] not in watched:
                    continue
                bits, sym = fields
                level = True if _is_vec_high(bits) else (False if _is_vec_low(bits) else None)
            else:
                # $dumpvars, $end, $comment and other keyword lines carry nothing used here.
                continue

            idx = sym_map.get(sym)
            if idx is not None:
                if level is True:
                    if first_high_times[idx] is None:
                        first_high_times[idx] = cur_time
                elif level is False:
                    last_low_times[idx] = cur_time

            if sym == instr_mat_id:
                mat_timer.update(level, cur_time)
            elif sym == instr_vec_id:
                vec_timer.update(level, cur_time)

        # Close high windows still open at EOF using the last timestamp seen.
        mat_timer.close(cur_time)
        vec_timer.close(cur_time)

    no_high = [signals[i] for i, t in enumerate(first_high_times) if t is None]
    if no_high:
        raise ValueError(f"No rising-to-1 event for one or more signals: {no_high}")
    no_low = [signals[i] for i, t in enumerate(last_low_times) if t is None]
    if no_low:
        raise ValueError(f"No falling-to-0 event for one or more signals: {no_low}")

    min_high_ns = int(min(first_high_times) * ns_per_tick)
    max_low_ns = int(max(last_low_times) * ns_per_tick)
    diff_ns = max_low_ns - min_high_ns

    mat_ns = int(mat_timer.total * ns_per_tick)
    vec_ns = int(vec_timer.total * ns_per_tick)
    return min_high_ns, max_low_ns, diff_ns, mat_ns, vec_ns


def run_single(signals: list, instr_mat_parts: list, instr_vec_parts: list, vcd_path: str):
    '''Analyze one VCD file and print each metric on its own "[RESULT] ..." line.'''
    metrics = analyze_vcd(vcd_path, signals, instr_mat_parts, instr_vec_parts)
    first_high_ns, last_low_ns, e2e_ns, mat_ns, vec_ns = metrics
    report = (
        f"[RESULT] min_first_high_ns={first_high_ns}",
        f"[RESULT] max_last_low_ns={last_low_ns}",
        f"[RESULT] E2E Cycles (ns)={e2e_ns}",
        f"[RESULT] Instr-Matrix MAC Efficiency={mat_ns} ns",
        f"[RESULT] Instr-Vector MAC Efficiency={vec_ns} ns",
    )
    for message in report:
        print(message)


def _process_vcd_file(args: Tuple[str, str, List[str], List[str], List[str]]) -> Tuple[str, int | str, int | str, int | str]:
    """Worker for run_batch: analyze one VCD and return its CSV row.

    *args* is (entry, vcd_path, signals, instr_mat_parts, instr_vec_parts).
    On success, prints an "[OK]" line and returns (entry, e2e_ns, mat_ns, vec_ns).
    Any Exception raised while analyzing or reporting is printed to stderr and
    turned into (entry, "ERROR: <message>", "", ""), so one bad trace does not
    abort the whole batch.
    """
    entry, vcd_path, signals, instr_mat_parts, instr_vec_parts = args
    try:
        _first_high, _last_low, e2e_ns, mat_ns, vec_ns = analyze_vcd(
            vcd_path, signals, instr_mat_parts, instr_vec_parts)
        print(f"[OK] {entry}: E2E={e2e_ns} ns | mat={mat_ns} ns | vec={vec_ns} ns")
        row = (entry, e2e_ns, mat_ns, vec_ns)
    except Exception as err:     # pylint: disable=W0718
        print(f"[ERR] {entry}: {err}", file=sys.stderr)
        row = (entry, f"ERROR: {err}", "", "")
    return row


def run_batch(signals: list, instr_mat_parts: list, instr_vec_parts: list, folder: str, csv_name: str = "vcd_analysis.csv", max_workers: int | None = None):
    '''Analyze every ``<folder>/<entry>/trace.vcd`` (or ``trace.vcd.gz``) in parallel.

    Subdirectories are visited in sorted order. Those without a trace are skipped
    with a warning on stderr. The header plus one row per analyzed subdirectory
    (in the same sorted order) is written to ``<folder>/<csv_name>``.
    ``max_workers`` is passed straight to ProcessPoolExecutor.
    '''
    def _locate_trace(subdir: str) -> Optional[str]:
        # Prefer the uncompressed trace, fall back to the gzipped one.
        plain = os.path.join(subdir, "trace.vcd")
        if os.path.exists(plain):
            return plain
        compressed = plain + ".gz"
        return compressed if os.path.exists(compressed) else None

    jobs = []
    for entry in sorted(os.listdir(folder)):
        subdir = os.path.join(folder, entry)
        if not os.path.isdir(subdir):
            continue
        trace = _locate_trace(subdir)
        if trace is None:
            print(f"[WARN] Skipping {entry}: trace.vcd not found", file=sys.stderr)
            continue
        jobs.append((entry, trace, signals, instr_mat_parts, instr_vec_parts))

    header = ("Operator", "E2E Cycles (ns)", "Instr-Matrix MAC Efficiency", "Instr-Vector MAC Efficiency")
    # executor.map yields results in submission order, so rows stay sorted by entry.
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        table = [header, *pool.map(_process_vcd_file, jobs)]

    out_csv = os.path.join(folder, csv_name)
    with open(out_csv, "w", newline="", encoding="utf-8") as fp:
        csv.writer(fp).writerows(table)
    print(f"\n[INFO] Wrote {out_csv}")


def main():
    '''CLI entry point: parse flags, then run batch (--dir) or single-file (--vcd) analysis.

    --dir takes precedence when both are given. With neither, an error is printed
    to stderr and the process exits with status 2.
    '''
    parser = argparse.ArgumentParser(
        description="VCD analyzer: min(first-high) & max(last-low) over N signals + MAC-efficiency (two signals' total high time).")
    parser.add_argument("--vcd", help="Path to a single VCD file (.vcd or .vcd.gz).")
    parser.add_argument("--dir", help='Directory containing subfolders, each with "trace.vcd" (or trace.vcd.gz).')
    parser.add_argument("--from_simnow", action="store_true",
                        help="Use simnow_top.simnow_xtlm prefix instead of l.aie_logical.aie_xtlm")
    parser.add_argument("--workers", type=int, default=cpu_count(),
                        help="Number of parallel workers (default: CPU count)")
    opts = parser.parse_args()

    signals, instr_mat_parts, instr_vec_parts = build_signals(opts.from_simnow)
    if not opts.dir and not opts.vcd:
        print("[ERR] Provide either --vcd <file> or --dir <dir>", file=sys.stderr)
        sys.exit(2)

    if opts.dir:
        run_batch(signals, instr_mat_parts, instr_vec_parts, opts.dir, max_workers=opts.workers)
    else:
        run_single(signals, instr_mat_parts, instr_vec_parts, opts.vcd)


if __name__ == "__main__":
    # Run the CLI only when executed as a script. The guard also matters for
    # run_batch: ProcessPoolExecutor workers started with the "spawn" method
    # import this module and must not re-run main() when they do.
    main()
