"""
Parse ml_timeline_metadata.json + record_timer_ts.json and emit ml_timeline_profile.csv

CSV columns:
["Layer ID", "E2E Cycle uC_0", "E2E Cycle uC_2", "E2E Cycle uC_4"]

- metadata JSON maps uc*.asm -> layer_id(str) -> {"start": <ts_id>, "end": <ts_id>}
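- record_timer JSON has:
    header.num_buffer_segments = number of uCs present (1..3)
    record_timer_ts = flat list of {"id": <ts_id>, "cycle": <cycle>}, laid out as
      contiguous, equal-sized per-uC blocks in uc0 / uc2 / uc4 order
  For each layer, each UC contributes a start/end pair identified by metadata ts ids;
  the layer's E2E cycle count for that UC is cycle(end) - cycle(start).
  Columns of uCs that are not present are left blank.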
"""

from __future__ import annotations

import argparse
import csv
import json
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional


# Canonical uC program names. The first `num_buffer_segments` entries are the
# active uCs for a given record_timer file.
UC_ORDER = [
    "uc0.asm",
    "uc2.asm",
    "uc4.asm",
]
# Output CSV header: the layer ID, then one E2E-cycle column per uC, listed in
# the same order as UC_ORDER.
CSV_FIELDS = [
    "Layer ID",
    "E2E Cycle uC_0",
    "E2E Cycle uC_2",
    "E2E Cycle uC_4",
]


def _load_json(path: str | Path) -> Any:
    """
    Read a UTF-8 encoded JSON file and return the decoded value.

    Args:
        path: location of the JSON file (str or Path).

    Returns:
        Whatever `json.load` produces for the file contents.

    Raises:
        FileNotFoundError: if nothing exists at `path`.
    """
    json_path = Path(path)
    if json_path.exists():
        with json_path.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    raise FileNotFoundError(f"JSON not found: {json_path}")


def _normalize_uc_list(num_segments: int) -> List[str]:
    """
    Return the active uC names for a `num_buffer_segments` value.

    The result is a new list with the first `num_segments` entries of UC_ORDER.

    Raises:
        ValueError: if `num_segments` is below 1 or above 3.
    """
    out_of_range = num_segments < 1 or num_segments > 3
    if out_of_range:
        raise ValueError(f"num_buffer_segments must be 1..3, got {num_segments}")
    active_ucs = UC_ORDER[:num_segments]
    return active_ucs


def _cycles_by_id_per_uc(record_timer: dict, uc_list: List[str]) -> Dict[str, Dict[int, List[int]]]:
    """
    Split ``record_timer["record_timer_ts"]`` into per-UC streams.

    Assumption: the flat list consists of contiguous, equal-sized blocks, one
    per UC, in ``uc_list`` order:
      [all entries for uc0] [all entries for uc2] [all entries for uc4]
    so the total entry count must be divisible by the number of UCs.

    Args:
        record_timer: decoded record_timer JSON. ``record_timer_ts`` must be a
            non-empty list of ``{"id": ..., "cycle": ...}`` dicts.
        uc_list: active UC names, in block order. Must be non-empty.

    Returns:
        ``{uc: {ts_id: [cycle0, cycle1, ...]}}``. Each cycle list keeps the
        order in which that id appears inside the UC's block.

    Raises:
        ValueError: if any of the following holds:
            - ``uc_list`` is empty;
            - ``record_timer_ts`` is missing or empty;
            - its length is not divisible by ``len(uc_list)``;
            - an entry is not a dict with integer-convertible ``id``/``cycle``.
    """
    n_uc = len(uc_list)
    if n_uc == 0:
        # Without this guard the divisibility check below raises ZeroDivisionError.
        raise ValueError("uc_list must contain at least one UC")

    ts = record_timer.get("record_timer_ts")
    if not isinstance(ts, list) or not ts:
        raise ValueError("record_timer_ts missing or empty")

    if len(ts) % n_uc != 0:
        raise ValueError(
            f"record_timer_ts length ({len(ts)}) not divisible by num_buffer_segments ({n_uc}). "
            "Can't split per-UC reliably."
        )

    chunk = len(ts) // n_uc
    out: Dict[str, Dict[int, List[int]]] = {uc: {} for uc in uc_list}

    for i, uc in enumerate(uc_list):
        per_id = out[uc]
        for item in ts[i * chunk:(i + 1) * chunk]:
            if not isinstance(item, dict) or "id" not in item or "cycle" not in item:
                raise ValueError(f"Bad record_timer_ts entry: {item}")
            try:
                tid = int(item["id"])
                cyc = int(item["cycle"])
            except (TypeError, ValueError) as err:
                # e.g. a null or non-numeric id/cycle: report it like any other bad entry.
                raise ValueError(f"Bad record_timer_ts entry: {item}") from err
            per_id.setdefault(tid, []).append(cyc)

    return out


def _numeric_key(k: str) -> int:
    """
    Sort key that orders numeric layer-ID strings by their integer value.

    Strings that `int()` cannot parse map to the sentinel 10**18, so they sort
    after realistic numeric IDs.
    """
    try:
        rank = int(k)
    except ValueError:
        rank = 10**18  # sentinel: non-numeric IDs go last
    return rank


def generate_ml_timeline_csv(metadata_path: str | Path, record_timer_path: str | Path, outdir: str | Path) -> Path:
    """
    Read the metadata and record_timer JSON files and write ml_timeline_profile.csv to outdir.

    Active uCs:
        Determined by ``header.num_buffer_segments``. Columns of inactive uCs
        are left blank.

    Cell values:
        For each layer ID defined in the metadata of the active uCs, a uC's
        column holds ``cycle(end_id) - cycle(start_id)``. The ts ids come from
        ``metadata[uc][layer_id]``. The cycles come from that uC's slice of
        ``record_timer_ts``, using the first occurrence of each id.

    Row order:
        Rows are sorted numerically by layer ID. Non-numeric IDs come last,
        ordered as strings.

    Args:
        metadata_path: path to ml_timeline_metadata.json.
        record_timer_path: path to the record_timer JSON.
        outdir: output directory. It is created if needed.

    Returns:
        Path to the created CSV.

    Raises:
        FileNotFoundError: if either JSON file does not exist.
        ValueError: if the header, metadata, or cycle data is missing or inconsistent.
    """
    metadata = _load_json(metadata_path)
    record_timer = _load_json(record_timer_path)

    # `or {}` also tolerates an explicit `"header": null`.
    num_segments = (record_timer.get("header") or {}).get("num_buffer_segments")
    if num_segments is None:
        raise ValueError("record_timer['header']['num_buffer_segments'] missing")
    uc_list = _normalize_uc_list(int(num_segments))

    # Check UC presence before collecting layer IDs, so a missing UC is reported
    # as such rather than as "No layer IDs found".
    for uc in uc_list:
        if uc not in metadata:
            raise ValueError(f"Metadata missing UC: {uc}")
    # A null UC entry is treated as "no layers"; the checks below then reject it.
    layers_by_uc: Dict[str, Dict[str, Any]] = {uc: metadata[uc] or {} for uc in uc_list}

    # Build per-UC (ts_id -> [cycles...]) map
    cycles_map = _cycles_by_id_per_uc(record_timer, uc_list)

    # Union of layer IDs across the active UCs.
    # The ID string itself is a tie-breaker, which keeps the output order
    # deterministic for two reasons:
    #   - _numeric_key maps every non-numeric ID to the same value;
    #   - the iteration order of a set of str varies between runs (hash randomization).
    layer_ids: List[str] = sorted(
        {lid for layers in layers_by_uc.values() for lid in layers},
        key=lambda lid: (_numeric_key(lid), lid),
    )
    if not layer_ids:
        raise ValueError("No layer IDs found in metadata for the active UCs.")

    # Strict check: every active UC must define every layer ID.
    for uc, layers in layers_by_uc.items():
        missing = [lid for lid in layer_ids if lid not in layers]
        if missing:
            raise ValueError(f"Metadata for {uc} missing layer IDs: {missing}")

    # uC name -> CSV column ("uc0.asm" -> "E2E Cycle uC_0", ...). Both constants
    # list the uCs in the same order.
    column_for_uc = dict(zip(UC_ORDER, CSV_FIELDS[1:]))

    # Compute E2E cycles per layer per UC: cycle(end_id) - cycle(start_id)
    rows: List[Dict[str, Any]] = []
    for lid in layer_ids:
        row: Dict[str, Any] = dict.fromkeys(CSV_FIELDS, "")
        row["Layer ID"] = lid

        for uc in uc_list:
            se = layers_by_uc[uc][lid]
            try:
                start_id = int(se["start"])
                end_id = int(se["end"])
            except (KeyError, TypeError, ValueError) as err:
                raise ValueError(
                    f"[{uc}][layer {lid}] Bad start/end entry in metadata: {se!r}"
                ) from err

            # cycles for these ids in this UC
            start_cycles = cycles_map[uc].get(start_id, [])
            end_cycles = cycles_map[uc].get(end_id, [])

            if not start_cycles or not end_cycles:
                raise ValueError(
                    f"[{uc}][layer {lid}] Missing cycles for start_id={start_id} or end_id={end_id} "
                    f"(start_cycles={len(start_cycles)}, end_cycles={len(end_cycles)})"
                )

            # If an id occurs several times (e.g. repeated runs), use its first
            # occurrence. Change to average/min/max here if desired.
            row[column_for_uc[uc]] = end_cycles[0] - start_cycles[0]

        rows.append(row)

    # Create the output directory only after all data has been validated, so a
    # failed run does not leave an empty directory behind.
    out_dir = Path(outdir)
    out_dir.mkdir(parents=True, exist_ok=True)
    csv_path = out_dir / "ml_timeline_profile.csv"

    with csv_path.open("w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=CSV_FIELDS)
        w.writeheader()
        w.writerows(rows)

    return csv_path


def get_uc_cycles_from_csv(
    csv_path: str | Path, *, layer_id: int | str = 0
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """
    Read ml_timeline_profile.csv and return the per-uC E2E cycles of one layer.

    Returns (E2E Cycle uC_0, E2E Cycle uC_2, E2E Cycle uC_4) for the first row
    whose whitespace-stripped "Layer ID" equals ``str(layer_id)``. The default
    ``layer_id=0`` keeps the original behaviour of reading layer 0.

    Returns None for any missing column/value.

    Raises:
        FileNotFoundError: if `csv_path` does not exist.
        ValueError: if no row matches `layer_id`, or a non-blank value is not an integer.
    """
    csv_path = Path(csv_path)
    if not csv_path.exists():
        raise FileNotFoundError(f"CSV not found: {csv_path}")

    def _to_int(v: Optional[str]) -> Optional[int]:
        v = (v or "").strip()
        return int(v) if v else None

    wanted = str(layer_id).strip()
    # newline="" is required by the csv module for correct line-ending handling.
    # utf-8-sig strips a leading BOM (e.g. one written by spreadsheet tools) that
    # would otherwise corrupt the "Layer ID" header; without a BOM it behaves like utf-8.
    with csv_path.open("r", newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            if (row.get("Layer ID") or "").strip() == wanted:
                return (
                    _to_int(row.get("E2E Cycle uC_0")),
                    _to_int(row.get("E2E Cycle uC_2")),
                    _to_int(row.get("E2E Cycle uC_4")),
                )

    raise ValueError(f"Layer ID {wanted} not found in CSV.")


def main(argv: Optional[List[str]] = None) -> int:
    """
    Command-line entry point.

    Options (all required):
        -metadata / --metadata
        -record_timer / --record_timer
        -out / --outdir

    Arguments are parsed from ``argv`` (argparse falls back to ``sys.argv[1:]``
    when it is None). The function writes the CSV via generate_ml_timeline_csv,
    prints the written path, and returns exit status 0.
    """
    cli_options = (
        (("-metadata", "--metadata"), "Path to ml_timeline_metadata.json"),
        (("-record_timer", "--record_timer"), "Path to record_timer JSON (cycle counts)"),
        (("-out", "--outdir"), "Directory to write ml_timeline_profile.csv"),
    )
    parser = argparse.ArgumentParser(
        description="Generate ml_timeline_profile.csv from metadata + record_timer JSONs."
    )
    for flags, help_text in cli_options:
        parser.add_argument(*flags, required=True, help=help_text)
    parsed = parser.parse_args(argv)

    written_csv = generate_ml_timeline_csv(parsed.metadata, parsed.record_timer, parsed.outdir)
    print(f"[INFO] Wrote: {written_csv}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
