"""
Utility to extract a subgraph from an ONNX model between two node names (inclusive).

Each boundary may also be given as a tensor name instead of a node name.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Iterable, List, Sequence, Tuple, Set, Optional

import onnx
from onnx import ModelProto, NodeProto
from onnx.utils import extract_model


def load_model(path: Path) -> ModelProto:
    """Load an ONNX model from disk and run the ONNX checker on it.

    The checker is handed the file path instead of the loaded proto. ONNX
    documents that models larger than 2GB must be checked via their path
    (the in-memory form is rejected as too large), and path-based checking
    works for smaller models as well.

    Args:
        path: Location of the ``.onnx`` file.

    Returns:
        The loaded model.

    Raises:
        Whatever ``onnx.load`` raises for an unreadable file, or the checker's
        validation error for a structurally invalid model.
    """
    model_path = str(path)
    # Load first so an unreadable/missing file fails exactly as before.
    model = onnx.load(model_path)
    onnx.checker.check_model(model_path)
    return model


def find_unique_node(graph_nodes: Sequence[NodeProto], node_name: str) -> Tuple[NodeProto, int]:
    """Locate the single node called ``node_name`` and report where it sits.

    Returns:
        ``(node, index)`` where ``index`` is the node's position in ``graph_nodes``.

    Raises:
        ValueError: if no node has that name, or if several nodes share it.
    """
    located = None
    for position, candidate in enumerate(graph_nodes):
        if candidate.name != node_name:
            continue
        if located is not None:
            # A second hit means the name cannot identify a node unambiguously.
            raise ValueError(f"Multiple nodes share the name '{node_name}'. Node names must be unique.")
        located = (candidate, position)

    if located is None:
        raise ValueError(f"Node '{node_name}' not found in graph.")
    return located


def validate_order(start_idx: int, end_idx: int, start_name: str, end_name: str) -> None:
    """Check that the start node does not come after the end node.

    Equal indices are accepted (start and end may be the same node).

    Raises:
        ValueError: if ``start_idx`` is greater than ``end_idx``.
    """
    if start_idx <= end_idx:
        return
    message = f"Start node '{start_name}' appears after end node '{end_name}' in graph order."
    raise ValueError(message)


def tensors_of(nodes: Iterable[NodeProto]) -> Tuple[List[str], List[str]]:
    """Gather the non-empty input and output tensor names of ``nodes``.

    Returns:
        ``(inputs, outputs)``: two flat lists in node order, duplicates kept.
    """
    # Materialise once so a one-shot iterable can feed both passes.
    node_list = list(nodes)
    gathered_inputs = [name for node in node_list for name in node.input if name]
    gathered_outputs = [name for node in node_list for name in node.output if name]
    return gathered_inputs, gathered_outputs


def plan_extract_io(start_node: NodeProto, end_node: NodeProto) -> Tuple[List[str], List[str]]:
    """
    Pick the boundary tensors for an extraction delimited by two nodes.

    The extraction inputs are the start node's non-empty input tensors and the
    extraction outputs are the end node's non-empty output tensors.

    Raises:
        ValueError: if the start node has no usable inputs or the end node has
            no usable outputs (the start node is checked first).
    """
    # filter(None, ...) drops the empty-string placeholders.
    boundary_inputs = list(filter(None, start_node.input))
    boundary_outputs = list(filter(None, end_node.output))

    if len(boundary_inputs) == 0:
        raise ValueError(f"Start node '{start_node.name}' has no input tensors.")
    if len(boundary_outputs) == 0:
        raise ValueError(f"End node '{end_node.name}' has no output tensors.")

    return boundary_inputs, boundary_outputs


def all_tensor_names(model: ModelProto) -> Set[str]:
    """Collect every tensor name referenced by the model's graph.

    Names come from the graph's inputs, outputs, value_info entries and
    initializers, plus the non-empty input/output names of every node.
    """
    graph = model.graph
    collected: Set[str] = set()

    # Declared tensors: each entry exposes its tensor name as `.name`.
    for declared in (graph.input, graph.output, graph.value_info, graph.initializer):
        collected.update(entry.name for entry in declared)

    # Tensors referenced by nodes; empty strings are skipped.
    for node in graph.node:
        collected.update(name for name in (*node.output, *node.input) if name)

    return collected


def _resolve_endpoint(
    graph_nodes: Sequence[NodeProto],
    tensor_set: Set[str],
    identifier: str,
    is_start: bool,
) -> Tuple[List[str], Optional[int]]:
    """Resolve one boundary identifier to (tensor names, node index or None)."""
    role, kind = ("Start", "input") if is_start else ("End", "output")

    # First preference: the identifier names exactly one node with usable tensors.
    node_error: Optional[ValueError] = None
    try:
        node, idx = find_unique_node(graph_nodes, identifier)
    except ValueError as exc:
        node_error = exc
    else:
        tensors = [t for t in (node.input if is_start else node.output) if t]
        if tensors:
            return tensors, idx
        node_error = ValueError(f"{role} node '{identifier}' has no {kind} tensors.")

    # Fallback: use the identifier as a tensor name. No node index is reported,
    # because the boundary was not derived from a node.
    if identifier in tensor_set:
        return [identifier], None

    # Nothing worked. If the identifier did match node(s), surface the actual
    # node problem (duplicate name / no tensors) rather than claiming that it
    # is not a node name at all.
    if any(n.name == identifier for n in graph_nodes):
        raise node_error
    raise ValueError(
        f"'{identifier}' is neither a node name nor a known tensor name."
    ) from node_error


def resolve_boundary_tensors(
    model: ModelProto,
    start_id: str,
    end_id: str,
) -> Tuple[List[str], List[str], Optional[int], Optional[int]]:
    """
    Resolve extraction boundaries from identifiers that may be node names or tensor names.

    Each identifier is resolved independently, start first:

    1. If it names exactly one node that has non-empty tensors on the relevant
       side (inputs for start, outputs for end), those tensors are the boundary.
    2. Otherwise, if it is a known tensor name, that tensor is the boundary.
    3. Otherwise a ValueError is raised.

    Returns:
      (input_tensors, output_tensors, start_idx, end_idx)
      start_idx/end_idx are the node indices when the boundary was derived from
      a node; they are None when the identifier was used as a tensor name.

    Raises:
      ValueError: if an identifier cannot be resolved. When the identifier
        matches node(s) but cannot be used (duplicate node name, or the node
        has no tensors on the relevant side) the message describes that
        problem; otherwise it states that the identifier is neither a node
        name nor a known tensor name.
    """
    graph_nodes = list(model.graph.node)
    tensor_set = all_tensor_names(model)

    input_tensors, start_idx = _resolve_endpoint(graph_nodes, tensor_set, start_id, True)
    output_tensors, end_idx = _resolve_endpoint(graph_nodes, tensor_set, end_id, False)

    return input_tensors, output_tensors, start_idx, end_idx


def extract_between_nodes(
    input_model_path: Path,
    output_model_path: Path,
    start_node_name: str,
    end_node_name: str,
) -> None:
    """
    Write the subgraph bounded by two identifiers to ``output_model_path``.

    Each identifier may be a node name or a tensor name. A node name contributes
    that node's tensors as the boundary (inputs for the start, outputs for the
    end); any other identifier is used as a tensor name directly. The parent
    directory of the output path is created if it does not exist.
    """
    model = load_model(input_model_path)
    boundaries = resolve_boundary_tensors(model, start_node_name, end_node_name)
    input_tensors, output_tensors, start_idx, end_idx = boundaries

    # Graph-order validation only makes sense when both ends are node indices.
    both_are_nodes = start_idx is not None and end_idx is not None
    if both_are_nodes:
        validate_order(start_idx, end_idx, start_node_name, end_node_name)

    destination_dir = output_model_path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)

    source_file = str(input_model_path)
    destination_file = str(output_model_path)
    extract_model(
        source_file,
        destination_file,
        input_names=input_tensors,
        output_names=output_tensors,
    )


def parse_args(argv: Sequence[str]) -> argparse.Namespace:
    """Build the CLI parser and parse ``argv`` into a namespace.

    All four options (``--input``, ``--start``, ``--end``, ``--output``) are
    required; ``--input`` and ``--output`` are converted to ``Path``.
    """
    parser = argparse.ArgumentParser(
        prog="cut_onnx_subgraph",
        description="Extract a subgraph between two node names from an ONNX model.",
    )

    # (flag, extra keyword arguments); registration order fixes the help order.
    option_specs = (
        ("--input", {"type": Path, "help": "Path to the input ONNX model."}),
        ("--start", {"help": "Start node name (inclusive)."}),
        ("--end", {"help": "End node name (inclusive)."}),
        ("--output", {"type": Path, "help": "Output ONNX path for the extracted subgraph."}),
    )
    for flag, extra in option_specs:
        parser.add_argument(flag, required=True, **extra)

    return parser.parse_args(argv)


def main(argv: Sequence[str] | None = None) -> int:
    """Run the command-line tool and return a process exit status.

    ``argv`` defaults to the interpreter's command-line arguments (without the
    program name). Returns 0 after a successful extraction; any exception
    raised during extraction is printed to stderr and reported as 1.
    """
    if argv is None:
        argv = sys.argv[1:]
    args = parse_args(argv)

    try:
        extract_between_nodes(
            input_model_path=args.input,
            output_model_path=args.output,
            start_node_name=args.start,
            end_node_name=args.end,
        )
    except Exception as err:   # pylint: disable=W0718
        # Top-level boundary: turn any failure into a one-line message and exit code.
        print(f"[ERROR] {err}", file=sys.stderr)
        return 1
    else:
        print(f"[OK] Saved subgraph to: {args.output}")
        return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
