"""VCD Analysis for single and batch VCD files - PARALLEL VERSION (FIXED)"""

import os
import re
import sys
import csv
import gzip
import argparse
import mmap
from multiprocessing import Pool, cpu_count
from typing import Tuple, Optional, Dict, List, Union


# ---------------- Configuration ----------------
def build_signals(from_simnow: bool):
    """Build the signal-name substrings used by the analysis.

    Returns a 3-tuple:
      * the "group" signal substrings (min first-high / max last-low):
        ``channel_running`` of ``s2mm_state0/1/4/5`` on mem-row tiles
        ``tile_0_1``, ``tile_1_1`` and ``tile_2_1``;
      * the two substrings that must both occur in the instr-matrix signal;
      * the two substrings that must both occur in the instr-vector signal.

    ``from_simnow`` selects the ``simnow_top.simnow_xtlm`` hierarchy prefix
    instead of ``l.aie_logical.aie_xtlm``.
    """
    if from_simnow:
        prefix = "simnow_top.simnow_xtlm"
    else:
        prefix = "l.aie_logical.aie_xtlm"

    # Tile-major order: all four states of tile 0, then tile 1, then tile 2.
    signals: List[str] = [
        f"{prefix}.math_engine.mem_row.tile_{col}_1.dma.s2mm_state{state}.channel_running"
        for col in range(3)
        for state in (0, 1, 4, 5)
    ]

    trace_scope = f"{prefix}.math_engine.array.tile_0_2.cm.event_trace_0"
    instr_mat_parts = [trace_scope, "instr_matrix"]
    instr_vec_parts = [trace_scope, "instr_vector"]
    return signals, instr_mat_parts, instr_vec_parts


# ------------------------------------------------


# Seconds represented by one unit of each time suffix a $timescale may carry.
_TS_UNITS_TO_S = dict(
    s=1.0,
    ms=1e-3,
    us=1e-6,
    ns=1e-9,
    ps=1e-12,
    fs=1e-15,
)


def _parse_timescale(header_text: str) -> float:
    """Return the number of seconds per VCD tick declared in *header_text*.

    Looks for the first ``$timescale <int> <unit> $end`` declaration
    (case-insensitive, whitespace/newlines allowed between the parts) and
    multiplies the integer by the unit's size in seconds. Returns ``1.0``
    when no such declaration is present.
    """
    found = re.search(
        r"\$timescale\s+(\d+)\s*([munpf]?s)\s+\$end", header_text, re.IGNORECASE
    )
    if found is None:
        return 1.0
    magnitude, unit = found.groups()
    return int(magnitude) * _TS_UNITS_TO_S[unit.lower()]


def _open_maybe_gz(path: str):
    """Open a VCD file for text reading, transparently handling gzip.

    Paths ending in ``.gz`` are opened through :mod:`gzip`, everything else
    through the built-in ``open``. Both branches decode as UTF-8 and drop
    undecodable bytes, so a compressed trace reads exactly like its plain
    counterpart regardless of the platform's default encoding.

    Returns a text-mode file object; the caller is responsible for closing
    it (normally via ``with``).
    """
    opener = gzip.open if path.endswith(".gz") else open
    return opener(path, "rt", encoding="utf-8", errors="ignore")


def find_chunk_boundaries(file_path: str, num_chunks: int) -> List[Tuple[int, int]]:
    """Split a VCD file into byte ranges that can be scanned independently.

    The file is cut into ``num_chunks`` roughly equal ``(start, end)`` byte
    ranges. Every interior boundary is moved forward to the start of the next
    line that begins with ``#`` (a VCD timestamp line), so a chunk never
    begins in the middle of a timestep: whoever scans a chunk sees the
    timestamp before any value change that belongs to it. If no further
    ``#`` line exists the boundary collapses to the end of the file, which
    yields empty trailing ranges.

    Args:
        file_path: Path of the (uncompressed) file to split.
        num_chunks: Desired number of ranges; values below 1 are treated as 1.

    Returns:
        Exactly ``max(1, num_chunks)`` contiguous ``(start, end)`` pairs with
        ``start <= end``, the first starting at 0 and the last ending at the
        file size. Some pairs may be empty (``start == end``).
    """
    file_size = os.path.getsize(file_path)
    num_chunks = max(1, num_chunks)

    if file_size == 0:
        # mmap refuses to map an empty file; there is nothing to split anyway.
        return [(0, 0)] * num_chunks

    chunk_size = file_size // num_chunks
    boundaries = [0]

    with open(file_path, "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mmapped:
            for i in range(1, num_chunks):
                target = min(i * chunk_size, file_size)
                # Search from one byte before the target so that a timestamp
                # line starting exactly at the target is picked up as well.
                hit = mmapped.find(b"\n#", max(target - 1, 0))
                boundaries.append(hit + 1 if hit != -1 else file_size)

    boundaries.append(file_size)

    # Targets only grow, so the boundaries are already non-decreasing.
    return list(zip(boundaries, boundaries[1:]))


def _time_before_offset(file_obj, offset: int) -> int:
    """Return the tick of the last ``#<time>`` line that starts before *offset*.

    Returns 0 when there is no such line (offset at the file start, or no
    timestamp line preceded by a newline exists before it).
    """
    if offset <= 0 or os.fstat(file_obj.fileno()).st_size == 0:
        return 0
    with mmap.mmap(file_obj.fileno(), 0, access=mmap.ACCESS_READ) as mapped:
        search_end = min(offset, len(mapped))
        while True:
            hit = mapped.rfind(b"\n#", 0, search_end)
            if hit == -1:
                return 0
            line_end = mapped.find(b"\n", hit + 2)
            if line_end == -1:
                line_end = len(mapped)
            token = mapped[hit + 2:line_end].strip()
            if token.isdigit():
                return int(token)
            # A '#' line that is not a timestamp; keep looking further back.
            search_end = hit + 1


def process_chunk(args: Tuple[str, int, int, Dict[str, int], int]) -> Dict:
    """Scan one byte range of a VCD file for the group signals.

    Args:
        args: ``(file_path, start_byte, end_byte, sym_map, chunk_id)`` where
            ``sym_map`` maps a VCD identifier code to the index of the group
            signal it represents. ``start_byte`` must be the start of a line.

    Returns:
        A dict with
          * ``"chunk_id"``: the id passed in,
          * ``"first_high_times"``: signal index -> tick of the first change
            to high seen in this range,
          * ``"last_low_times"``: signal index -> tick of the last change to
            low seen in this range,
          * ``"lines_processed"``: number of lines read (blank ones included).

    A scalar is high on ``1`` and low on ``0``; a ``b``/``r`` vector is high
    when its value contains a ``1`` and low when it contains a ``0`` but no
    ``1``. Everything between ``$dumpvars`` and its ``$end`` is skipped.
    """
    file_path, start_byte, end_byte, sym_map, chunk_id = args

    first_high_times: Dict[int, int] = {}
    last_low_times: Dict[int, int] = {}
    lines_processed = 0
    in_dumpvars = False

    with open(file_path, "rb") as f:
        # A range may begin between a '#<time>' line and the value changes
        # that belong to it. Seed the clock from the last marker before the
        # range instead of assuming tick 0, which would stamp those changes
        # with a bogus time.
        cur_time = _time_before_offset(f, start_byte)

        f.seek(start_byte)
        pos = start_byte  # running offset; avoids an f.tell() call per line

        while pos < end_byte:
            raw_line = f.readline()
            if not raw_line:
                break
            pos += len(raw_line)
            lines_processed += 1

            line = raw_line.decode("utf-8", errors="ignore").strip()
            if not line:
                continue

            # Initial-value dump: not counted as value changes.
            if line.startswith("$dumpvars"):
                in_dumpvars = True
                continue
            if in_dumpvars:
                if line.startswith("$end"):
                    in_dumpvars = False
                continue

            first_char = line[0]

            # Timestamp line.
            if first_char == "#":
                try:
                    cur_time = int(line[1:])
                except ValueError:
                    pass  # malformed marker: keep the previous time
                continue

            # Scalar change: "<0|1><id>".
            if first_char in "01":
                idx = sym_map.get(line[1:])
                if idx is not None:
                    if first_char == "1":
                        first_high_times.setdefault(idx, cur_time)
                    else:
                        last_low_times[idx] = cur_time
                continue

            # Vector / real change: "<b|r><value> <id>".
            if first_char in "br":
                parts = line[1:].split(None, 1)
                if len(parts) != 2:
                    continue
                bits, sym = parts
                idx = sym_map.get(sym)
                if idx is not None:
                    if "1" in bits:
                        first_high_times.setdefault(idx, cur_time)
                    elif "0" in bits:
                        last_low_times[idx] = cur_time

    return {
        "chunk_id": chunk_id,
        "first_high_times": first_high_times,
        "last_low_times": last_low_times,
        "lines_processed": lines_processed,
    }


def process_instruction_signals(
    path: str, instr_mat_id: str, instr_vec_id: str, ns_per_tick: float
) -> Tuple[int, int]:
    """Accumulate how long the two instruction signals stay high.

    Makes a single sequential pass over the value-change section of the VCD
    (everything after ``$enddefinitions``, with the ``$dumpvars`` block
    skipped). For each of the two identifier codes a window opens on the
    first change to high and closes on the next change to low; windows still
    open at end of file are closed at the last timestamp seen.

    A scalar is high on ``1`` and low on ``0``; a ``b``/``r`` vector is high
    when its value contains a ``1`` and low when it contains a ``0`` but no
    ``1``.

    Args:
        path: VCD file (plain or ``.gz``).
        instr_mat_id: Identifier code of the instr-matrix signal.
        instr_vec_id: Identifier code of the instr-vector signal.
        ns_per_tick: Factor applied to the accumulated tick counts.

    Returns:
        ``(matrix_high, vector_high)``: accumulated ticks times
        ``ns_per_tick``, truncated to ``int``.
    """
    print("[INFO] Processing instruction signals (serial pass)...")

    now = 0
    final_time = 0
    # Per signal: tick at which the current high window opened (None while
    # low) and the ticks accumulated by already-closed windows.
    open_since: Dict[str, Optional[int]] = {"mat": None, "vec": None}
    high_ticks: Dict[str, int] = {"mat": 0, "vec": 0}
    header_done = False
    skipping_dumpvars = False

    with _open_maybe_gz(path) as f:
        for raw in f:
            s = raw.strip()

            # Everything up to and including $enddefinitions is header.
            if not header_done:
                header_done = "$enddefinitions" in s
                continue

            if not s:
                continue

            if s.startswith("$dumpvars"):
                skipping_dumpvars = True
                continue
            if skipping_dumpvars:
                skipping_dumpvars = not s.startswith("$end")
                continue

            kind = s[0]

            if kind == "#":
                now = final_time = int(s[1:])
                continue

            # Reduce the line to (identifier, went_high, went_low).
            if kind in "01":
                sym = s[1:]
                went_high = kind == "1"
                went_low = not went_high
            elif kind in "br":
                fields = s[1:].split(None, 1)
                if len(fields) != 2:
                    continue
                bits, sym = fields
                went_high = "1" in bits
                went_low = not went_high and "0" in bits
            else:
                continue

            # Matrix takes precedence, as in an if/elif on the identifier.
            if sym == instr_mat_id:
                key = "mat"
            elif sym == instr_vec_id:
                key = "vec"
            else:
                continue

            opened = open_since[key]
            if went_high:
                if opened is None:
                    open_since[key] = now
            elif went_low and opened is not None:
                high_ticks[key] += now - opened
                open_since[key] = None

    # Close windows that were still open when the file ended.
    for key, opened in open_since.items():
        if opened is not None:
            high_ticks[key] += final_time - opened

    return int(high_ticks["mat"] * ns_per_tick), int(high_ticks["vec"] * ns_per_tick)


def analyze_vcd_parallel(
    path: str,
    signals: List[str],
    instr_mat_parts: List[str],
    instr_vec_parts: List[str],
    num_workers: Optional[int] = None,
    with_efficiency: bool = False,
) -> Tuple[int, int, int, int, int]:
    """Analyze one VCD file, scanning the value changes in parallel.

    Steps:
      1. Parse the header serially (scopes, ``$var`` declarations, timescale).
      2. Resolve every entry of ``signals`` (substring of the hierarchical
         name) and the two instruction signals to VCD identifier codes.
      3. Split the file into byte ranges and scan them in a process pool.
      4. Merge per-range results into min(first high) / max(last low).
      5. Optionally accumulate instruction-signal high time in a serial pass.

    Gzip-compressed inputs are decompressed into a temporary file (in the
    default temp directory) for step 3, because byte-range splitting only
    makes sense on the uncompressed text; the file is removed afterwards.

    Args:
        path: VCD file, plain or ``.gz``.
        signals: Substrings identifying the group signals.
        instr_mat_parts: Substrings that must all occur in the instr-matrix name.
        instr_vec_parts: Substrings that must all occur in the instr-vector name.
        num_workers: Pool size; defaults to the CPU count.
        with_efficiency: Run step 5; otherwise the last two results are 0.

    Returns:
        ``(min_first_high_ns, max_last_low_ns, diff_ns, instr_mat_high_ns,
        instr_vec_high_ns)``.

    Raises:
        ValueError: ``num_workers`` is below 1, a signal cannot be resolved,
            or a group signal never goes high / never goes low.
    """
    if num_workers is None:
        num_workers = cpu_count()
    if num_workers < 1:
        raise ValueError(f"num_workers must be at least 1, got {num_workers}")

    print(f"[INFO] Using {num_workers} workers for parallel processing")

    # Step 1: Parse header (serial - must be done first)
    print("[INFO] Parsing VCD header...")
    id2name: Dict[str, str] = {}
    scope: List[str] = []
    header_lines: List[str] = []

    with _open_maybe_gz(path) as f:
        for line in f:
            header_lines.append(line)
            s = line.strip()

            if s.startswith("$scope"):
                parts = s.split()
                if len(parts) >= 3:
                    scope.append(parts[2])
            elif s.startswith("$upscope"):
                if scope:
                    scope.pop()
            elif s.startswith("$var"):
                # "$var <type> <size> <id> <reference ...> $end": split off the
                # four leading fields so the reference is taken by position
                # rather than by searching the line for its text (which could
                # match inside the type, size or identifier fields).
                tokens = s.split(None, 4)
                if len(tokens) == 5:
                    sym_id = tokens[3]
                    reference, terminator, _ = tokens[4].partition("$end")
                    if terminator:
                        name_part = reference.strip()
                    else:
                        # No $end on this line: keep just the first word.
                        name_part = tokens[4].split()[0]
                    id2name[sym_id] = ".".join(scope + [name_part])
            elif "$enddefinitions" in s:
                break

    sec_per_tick = _parse_timescale("".join(header_lines))
    # NOTE: a value of exactly 1.0 is treated as "one tick is one ns"; that is
    # also what _parse_timescale yields when the header has no $timescale.
    ns_per_tick = sec_per_tick * 1e9 if sec_per_tick != 1.0 else 1.0

    print(f"[INFO] Parsed {len(id2name)} signals from VCD header.")

    # Step 2: Resolve signal IDs (first declaration containing the substring)
    sym_map: Dict[str, int] = {}
    for i, sig_sub in enumerate(signals):
        sig_sub_s = sig_sub.strip()
        for sym, fullname in id2name.items():
            if sig_sub_s in fullname:
                sym_map[sym] = i
                break

    if len(sym_map) != len(signals):
        missing = [
            s for s in signals if all(s.strip() not in fn for fn in id2name.values())
        ]
        if missing:
            raise ValueError(f"Could not find signals (substring) in VCD: {missing}")

    def _match_both(parts: List[str], name: str) -> bool:
        return all(p.strip() in name for p in parts)

    instr_mat_id = next(
        (sym for sym, fn in id2name.items() if _match_both(instr_mat_parts, fn)), None
    )
    instr_vec_id = next(
        (sym for sym, fn in id2name.items() if _match_both(instr_vec_parts, fn)), None
    )

    if instr_mat_id is None:
        raise ValueError(f"instr_mat not found matching substrings: {instr_mat_parts}")
    if instr_vec_id is None:
        raise ValueError(
            f"instr_vector not found matching substrings: {instr_vec_parts}"
        )

    print(f"[INFO] Resolved {len(sym_map)} group signals and 2 instruction signals.")

    # Step 3: Split file and process GROUP SIGNALS in parallel
    num_chunks = num_workers * 2  # 2x workers for better load balancing
    scratch_path: Optional[str] = None
    try:
        if path.endswith(".gz"):
            # Byte offsets into a gzip stream are meaningless to the line
            # scanner, so give the workers an uncompressed copy to read.
            import shutil
            import tempfile

            print("[INFO] Decompressing gzip input for chunked scan...")
            with tempfile.NamedTemporaryFile(suffix=".vcd", delete=False) as scratch:
                scratch_path = scratch.name
                with gzip.open(path, "rb") as packed:
                    shutil.copyfileobj(packed, scratch)

        chunk_source = scratch_path if scratch_path is not None else path

        print("[INFO] Splitting file into chunks...")
        chunks = find_chunk_boundaries(chunk_source, num_chunks)

        print(
            f"[INFO] Split file into {len(chunks)} chunks. Starting parallel processing for group signals..."
        )

        # Prepare arguments for parallel processing (NO instruction signal IDs)
        tasks = [
            (chunk_source, start, end, sym_map, i)
            for i, (start, end) in enumerate(chunks)
        ]

        with Pool(processes=num_workers) as pool:
            results = pool.map(process_chunk, tasks)
    finally:
        if scratch_path is not None:
            try:
                os.remove(scratch_path)
            except OSError as exc:
                print(
                    f"[WARN] Could not remove temporary file {scratch_path}: {exc}",
                    file=sys.stderr,
                )

    total_lines = sum(r["lines_processed"] for r in results)
    print(f"[INFO] Processed {total_lines:,} lines across {len(chunks)} chunks")

    # Step 4: Merge GROUP SIGNAL results
    first_high_times: List[int] = []
    last_low_times: List[int] = []

    for i in range(len(signals)):
        highs = [
            r["first_high_times"][i] for r in results if i in r["first_high_times"]
        ]
        lows = [r["last_low_times"][i] for r in results if i in r["last_low_times"]]

        if not highs:
            raise ValueError(f"No rising-to-1 event for signal index {i}")
        if not lows:
            raise ValueError(f"No falling-to-0 event for signal index {i}")

        first_high_times.append(min(highs))
        last_low_times.append(max(lows))

    # Convert to nanoseconds
    min_high_ns = int(min(first_high_times) * ns_per_tick)
    max_low_ns = int(max(last_low_times) * ns_per_tick)
    diff_ns = int(max_low_ns - min_high_ns)

    # Step 5: Process INSTRUCTION SIGNALS serially (cannot be safely parallelized)
    mat_ns, vec_ns = 0, 0
    if with_efficiency:
        mat_ns, vec_ns = process_instruction_signals(
            path, instr_mat_id, instr_vec_id, ns_per_tick
        )

    return min_high_ns, max_low_ns, diff_ns, mat_ns, vec_ns


def run_single(
    signals: List[str],
    instr_mat_parts: List[str],
    instr_vec_parts: List[str],
    vcd_path: str,
    num_workers: Optional[int] = None,
    with_efficiency: bool = False,
) -> None:
    """Analyze a single VCD file and print the results to stdout.

    All arguments are forwarded unchanged to ``analyze_vcd_parallel``; the
    five values it returns are printed as ``[RESULT]`` lines.
    """
    first_high_ns, last_low_ns, e2e_ns, mat_ns, vec_ns = analyze_vcd_parallel(
        path=vcd_path,
        signals=signals,
        instr_mat_parts=instr_mat_parts,
        instr_vec_parts=instr_vec_parts,
        num_workers=num_workers,
        with_efficiency=with_efficiency,
    )

    report = (
        f"\n[RESULT] min_first_high_ns={first_high_ns}",
        f"[RESULT] max_last_low_ns={last_low_ns}",
        f"[RESULT] E2E Cycles (ns)={e2e_ns}",
        f"[RESULT] Instr-Matrix MAC Efficiency={mat_ns} ns",
        f"[RESULT] Instr-Vector MAC Efficiency={vec_ns} ns",
    )
    for entry in report:
        print(entry)


def _process_vcd_file(
    args: Tuple[str, str, List[str], List[str], List[str], Optional[int], bool],
) -> Tuple[str, Union[int, str], Union[int, str], Union[int, str]]:
    """Analyze one VCD file of a batch and turn the outcome into a CSV row.

    ``args`` is ``(entry, vcd_path, signals, instr_mat_parts,
    instr_vec_parts, num_workers, with_efficiency)``.

    Returns ``(entry, e2e_ns, mat_ns, vec_ns)`` on success. Any exception is
    reported on stderr and converted into ``(entry, "ERROR: <msg>", "", "")``
    so that one bad file does not abort the batch.
    """
    entry, vcd_path, signals, mat_parts, vec_parts, workers, efficiency = args

    try:
        outcome = analyze_vcd_parallel(
            vcd_path, signals, mat_parts, vec_parts, workers, efficiency
        )
        # The first two values (absolute first-high / last-low) are not reported.
        _, _, e2e_ns, mat_ns, vec_ns = outcome
        print(f"[OK] {entry}: E2E={e2e_ns} ns | mat={mat_ns} ns | vec={vec_ns} ns")
        return (entry, e2e_ns, mat_ns, vec_ns)
    except Exception as exc:  # pylint: disable=W0718
        print(f"[ERR] {entry}: {exc}", file=sys.stderr)
        return (entry, f"ERROR: {exc}", "", "")


def run_batch(
    signals: List[str],
    instr_mat_parts: List[str],
    instr_vec_parts: List[str],
    folder: str,
    csv_name: str = "vcd_analysis.csv",
    workers_per_file: Optional[int] = None,
    with_efficiency: bool = False,
) -> None:
    """Analyze every ``<folder>/<sub>/trace.vcd`` and write a CSV summary.

    Sub-directories of ``folder`` are visited in sorted order. Each must hold
    ``trace.vcd`` or, failing that, ``trace.vcd.gz``; otherwise it is skipped
    with a warning on stderr. Files are analyzed one after another (each
    analysis is parallel internally) and the rows, preceded by a header row,
    are written to ``<folder>/<csv_name>``.
    """
    # Pass 1: collect one job per sub-directory that has a trace.
    jobs: list = []
    for entry in sorted(os.listdir(folder)):
        subdir = os.path.join(folder, entry)
        if not os.path.isdir(subdir):
            continue

        plain = os.path.join(subdir, "trace.vcd")
        # The plain trace wins over the compressed one when both exist.
        vcd_path = next(
            (cand for cand in (plain, plain + ".gz") if os.path.exists(cand)), None
        )
        if vcd_path is None:
            print(f"[WARN] Skipping {entry}: trace.vcd not found", file=sys.stderr)
            continue

        jobs.append(
            (
                entry,
                vcd_path,
                signals,
                instr_mat_parts,
                instr_vec_parts,
                workers_per_file,
                with_efficiency,
            )
        )

    # Pass 2: run the jobs sequentially, header row first.
    header = (
        "Operator",
        "E2E Cycles (ns)",
        "Instr-Matrix MAC Efficiency",
        "Instr-Vector MAC Efficiency",
    )
    rows: List[Tuple[str, Union[int, str], Union[int, str], Union[int, str]]] = [
        header
    ]
    for job in jobs:
        rows.append(_process_vcd_file(job))

    # Pass 3: write the summary next to the analyzed sub-directories.
    out_csv = os.path.join(folder, csv_name)
    with open(out_csv, "w", newline="", encoding="utf-8") as fp:
        csv.writer(fp).writerows(rows)
    print(f"\n[INFO] Wrote {out_csv}")


def main() -> None:
    """Command-line entry point.

    Parses the options, builds the signal lists for the selected hierarchy
    prefix and runs either the batch analysis (``--dir``, which takes
    precedence) or the single-file analysis (``--vcd``). Exits with status 2
    when neither is given.
    """
    parser = argparse.ArgumentParser(
        description="VCD analyzer (PARALLEL VERSION): min(first-high) & max(last-low) over N signals + MAC-efficiency."
    )
    parser.add_argument(
        "--vcd", help="Path to a single VCD file (.vcd or .vcd.gz)."
    )
    parser.add_argument(
        "--dir",
        help='Directory containing subfolders, each with "trace.vcd" (or trace.vcd.gz).',
    )
    parser.add_argument(
        "--from_simnow",
        action="store_true",
        help="Use simnow_top.simnow_xtlm prefix instead of l.aie_logical.aie_xtlm",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=None,
        help="Number of parallel workers (default: CPU count)",
    )
    parser.add_argument(
        "--with_efficiency",
        action="store_true",
        help="Compute instruction MAC efficiency",
    )
    opts = parser.parse_args()

    signals, instr_mat_parts, instr_vec_parts = build_signals(opts.from_simnow)

    # Batch mode wins when both --dir and --vcd are supplied.
    if opts.dir:
        run_batch(
            signals,
            instr_mat_parts,
            instr_vec_parts,
            opts.dir,
            workers_per_file=opts.workers,
            with_efficiency=opts.with_efficiency,
        )
        return

    if opts.vcd:
        run_single(
            signals,
            instr_mat_parts,
            instr_vec_parts,
            opts.vcd,
            num_workers=opts.workers,
            with_efficiency=opts.with_efficiency,
        )
        return

    print("[ERR] Provide either --vcd <file> or --dir <dir>", file=sys.stderr)
    sys.exit(2)


# Run the CLI only when this file is executed as a script, not when it is
# imported (multiprocessing workers may re-import this module).
if __name__ == "__main__":
    main()
