"""
Runtime operation schemas for memory aliasing optimization.

This module defines Pydantic models for runtime operations (gather, slice, concat, split)
that support memory aliasing. These models validate ONNX attributes and provide
type-safe access to operation parameters needed for memory allocation.

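Runtime operations act on the N dimension (batch, axis=0) in NHWC format. Any other
axis is accepted only when every dimension before it is 1, which makes the operation
equivalent to one along axis=0 in the flattened memory layout.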
"""

from __future__ import annotations
import re
from enum import Enum
from typing import List, Tuple
from pydantic import BaseModel, Field, field_validator
from graph.dag import OnnxAttributes

# Regex for runtime operator names of the form <op_name>_runtime[_suffix]:
# at least one character, then the literal "_runtime", then either "_" or the
# end of the string. Applied with re.search in RuntimeOpType.is_runtime_op.
_RUNTIME_OP_PATTERN = r".+_runtime(?:_|$)"


def _validate_axis_normalization(
    axis: int,
    shapes: List[Tuple[int, ...]] | Tuple[int, ...],
    operation_name: str
) -> int:
    """Validate that axis can be normalized to axis=0 for memory aliasing.

    Checks if all dimensions BEFORE the specified axis are unity (1) across all
    input shapes. When this condition is met, the operation along that axis is
    equivalent to operating along axis=0 in flattened memory layout, enabling
    memory aliasing optimization.

    Args:
        axis: The axis to validate (can be negative)
        shapes: Either a single shape tuple or list of shape tuples to validate.
               All shapes must have the same dimensionality.
        operation_name: Name of the operation for error messages (e.g., "concat_runtime")

    Returns:
        Normalized positive axis index

    Raises:
        ValueError: If shapes have different dimensionalities, axis is out of range,
                   or leading dimensions are not unity

    Examples:
        >>> # Single 4D shape validation (for slice)
        >>> _validate_axis_normalization(2, (1, 1, 64, 64), "slice_runtime")
        2

        >>> # Multiple 4D shapes validation (for concat)
        >>> _validate_axis_normalization(-2, [(1, 1, 63, 64), (1, 1, 1, 64)], "concat_runtime")
        2

        >>> # 3D shape validation
        >>> _validate_axis_normalization(1, (1, 64, 64), "slice_runtime")
        1

        >>> # Invalid: non-unity leading dimension
        >>> _validate_axis_normalization(2, (1, 2, 64, 64), "slice_runtime")
        Traceback (most recent call last):
        ...
        ValueError: slice_runtime: axis=2 cannot be normalized to axis=0...
    """
    # Normalize shapes to list of tuples for uniform handling
    if isinstance(shapes, tuple):
        shape_list = [shapes]
    else:
        shape_list = list(shapes)

    if not shape_list:
        raise ValueError(f"{operation_name}: No shapes provided")

    # Validate all shapes have the same dimensionality
    ndim = len(shape_list[0])
    for i, shape in enumerate(shape_list):
        if len(shape) != ndim:
            raise ValueError(
                f"{operation_name}: All shapes must have the same number of dimensions. "
                f"Shape 0 has {ndim} dimensions but shape {i} has {len(shape)} dimensions. "
                f"Shapes: {shapes}"
            )

    # Normalize axis to positive index
    normalized_axis = axis if axis >= 0 else ndim + axis

    # Validate axis range
    if normalized_axis < 0 or normalized_axis >= ndim:
        raise ValueError(
            f"{operation_name}: axis={axis} is out of range for {ndim}D tensors "
            f"(valid range: -{ndim} to {ndim-1})"
        )

    # axis=0 is always valid
    if normalized_axis == 0:
        return normalized_axis

    # Check if all dimensions BEFORE the axis are unity (1) across all shapes
    for i, shape in enumerate(shape_list):
        for dim_idx in range(normalized_axis):
            if shape[dim_idx] != 1:
                shape_context = f"Input {i}" if len(shape_list) > 1 else "Input"
                raise ValueError(
                    f"{operation_name}: axis={axis} cannot be normalized to axis=0. "
                    f"{shape_context} has shape {shape} with non-unity dimension at index {dim_idx} "
                    f"(value={shape[dim_idx]}). All dimensions before the axis must be 1. "
                    f"For axis={axis} (normalized={normalized_axis}), dimensions [0:{normalized_axis}] "
                    f"must all be 1 across all inputs."
                )

    # All leading dimensions are unity - axis can be treated as axis=0
    return normalized_axis


class RuntimeOpType(str, Enum):
    """Kinds of runtime operation recognised by the memory aliasing optimization.

    Each member's value is the lower-case base name that precedes ``_runtime``
    in an operator type string. Runtime operations that support memory
    aliasing must operate on the N dimension (batch axis=0) only in NHWC
    format.
    """
    GATHER = "gather"      # gather_runtime -> max-size aliasing
    SLICE = "slice"        # slice_runtime -> offset-based aliasing
    CONCAT = "concat"      # concat_runtime -> contiguous input aliasing
    SPLIT = "split"        # split_runtime -> contiguous output aliasing
    UNKNOWN = "unknown"    # anything else: not a candidate for aliasing

    @staticmethod
    def is_runtime_op(op_type: str) -> bool:
        """Tell whether an operator type string names a runtime operation.

        A name qualifies when the module-level runtime-op pattern is found in
        it, i.e. it looks like ``<op_name>_runtime`` optionally followed by
        ``_<suffix>``.

        Args:
            op_type: The operator type string to check

        Returns:
            True if the pattern is found in ``op_type``, False otherwise

        Examples:
            >>> RuntimeOpType.is_runtime_op("gather_runtime")
            True
            >>> RuntimeOpType.is_runtime_op("gather_runtime_qdq")
            True
            >>> RuntimeOpType.is_runtime_op("reshape_noop")
            False
            >>> RuntimeOpType.is_runtime_op("Conv")
            False
        """
        found = re.search(_RUNTIME_OP_PATTERN, op_type)
        return bool(found)

    @staticmethod
    def classify(op_type: str) -> RuntimeOpType:
        """Map an operator type string to its RuntimeOpType member.

        The text in front of the first ``_runtime`` marker is lower-cased and
        looked up among the member values. Strings that are not runtime
        operations, or whose base name is not a member value, yield UNKNOWN.

        Args:
            op_type: The operator type string to classify

        Returns:
            GATHER, SLICE, CONCAT or SPLIT for the matching ``*_runtime``
            operator, UNKNOWN for everything else

        Examples:
            >>> RuntimeOpType.classify("gather_runtime")
            <RuntimeOpType.GATHER: 'gather'>
            >>> RuntimeOpType.classify("slice_runtime_qdq")
            <RuntimeOpType.SLICE: 'slice'>
            >>> RuntimeOpType.classify("Conv")
            <RuntimeOpType.UNKNOWN: 'unknown'>
        """
        if not RuntimeOpType.is_runtime_op(op_type):
            return RuntimeOpType.UNKNOWN

        # Anchored, non-greedy capture of the base name: <op_name>_runtime[_suffix]
        parsed = re.match(r"(.+?)_runtime(?:_|$)", op_type)
        if parsed is None:
            return RuntimeOpType.UNKNOWN

        base_name = parsed.group(1).lower()

        # Value lookup raises ValueError when no member carries this value
        try:
            return RuntimeOpType(base_name)
        except ValueError:
            return RuntimeOpType.UNKNOWN

    @staticmethod
    def is_aliasable(op_type: str) -> bool:
        """Tell whether an operator type is a recognised aliasing candidate.

        This is exactly "classify() did not return UNKNOWN"; only the operator
        name is inspected here, not its attributes or shapes.

        Args:
            op_type: The operator type string to check

        Returns:
            True for gather/slice/concat/split runtime operators, False otherwise

        Examples:
            >>> RuntimeOpType.is_aliasable("gather_runtime")
            True
            >>> RuntimeOpType.is_aliasable("unknown_runtime")
            False
            >>> RuntimeOpType.is_aliasable("Conv")
            False
        """
        kind = RuntimeOpType.classify(op_type)
        return kind is not RuntimeOpType.UNKNOWN


class GatherRuntimeAttrs(BaseModel):
    """Validated attribute set for a gather_runtime operation.

    The gather axis may be any axis whose leading dimensions are all 1; that
    condition is checked against the input shape in build(), not on plain
    construction. Intended memory strategy: max-size aliasing
    (allocate max(input_size, output_size)).

    Attributes:
        axis: Gather axis, defaults to 0. Stored exactly as supplied.
    """
    axis: int = Field(
        default=0,
        description="Axis to gather along (validated with input shape in build())"
    )

    @staticmethod
    def build(attributes: OnnxAttributes, input_shape: Tuple[int, ...]) -> "GatherRuntimeAttrs":
        """Parse the attributes, then check the axis against the input shape.

        Args:
            attributes: ONNX attributes mapping, unpacked as keyword arguments
                into the model (may carry 'axis')
            input_shape: Input tensor shape (any dimensionality)

        Returns:
            The parsed GatherRuntimeAttrs; the normalized axis computed during
            validation is not written back to it

        Raises:
            ValueError: If the axis check against ``input_shape`` fails

        Examples:
            >>> attrs = GatherRuntimeAttrs.build({"axis": 0}, (10, 20, 30, 64))
            >>> attrs.axis
            0
        """
        parsed = GatherRuntimeAttrs(**attributes)
        # Called for its raising behaviour only; the return value is unused
        _validate_axis_normalization(
            axis=parsed.axis,
            shapes=input_shape,
            operation_name="gather_runtime",
        )
        return parsed


class SliceRuntimeAttrs(BaseModel):
    """Validated attribute set for a slice_runtime operation.

    Describes a slice along one axis; every list-valued attribute must hold
    exactly one element. The axis may be any axis whose leading dimensions are
    all 1, which is checked against the input shape in build(). Intended
    memory strategy: offset-based aliasing (output points to
    input + byte_offset).

    Attributes:
        axes: Slice axis as a single-element list (required)
        starts: Start index as a single-element list (required)
        ends: End index as a single-element list, or None
        steps: Step size as a single-element list, or None
    """
    axes: list[int] = Field(
        description="Axes to slice along (single element, validated in build())"
    )
    starts: list[int] = Field(
        description="Start indices (must have exactly one element)"
    )
    ends: list[int] | None = Field(
        default=None,
        description="End indices (optional, used for output shape validation)"
    )
    steps: list[int] | None = Field(
        default=None,
        description="Step sizes (optional, defaults to [1])"
    )

    @field_validator("axes")
    @classmethod
    def validate_axes(cls, v: list[int]) -> list[int]:
        """Accept ``axes`` only when it holds exactly one element."""
        count = len(v)
        if count == 1:
            return v
        raise ValueError(
            f"slice_runtime: axes must have exactly one element, got {count} elements"
        )

    @field_validator("starts")
    @classmethod
    def validate_starts(cls, v: list[int]) -> list[int]:
        """Accept ``starts`` only when it holds exactly one element."""
        count = len(v)
        if count == 1:
            return v
        raise ValueError(
            f"slice_runtime: starts must have exactly one element, got {count} elements"
        )

    @field_validator("ends")
    @classmethod
    def validate_ends(cls, v: list[int] | None) -> list[int] | None:
        """Accept ``ends`` when it is None or holds exactly one element."""
        if v is None or len(v) == 1:
            return v
        raise ValueError(
            f"slice_runtime: ends must have exactly one element, got {len(v)} elements"
        )

    @field_validator("steps")
    @classmethod
    def validate_steps(cls, v: list[int] | None) -> list[int] | None:
        """Accept ``steps`` when it is None or holds exactly one element."""
        # NOTE(review): only the length is checked here, not the step value.
        if v is None or len(v) == 1:
            return v
        raise ValueError(
            f"slice_runtime: steps must have exactly one element, got {len(v)} elements"
        )

    @staticmethod
    def build(
        attributes: OnnxAttributes,
        input_shape: Tuple[int, ...]
    ) -> "SliceRuntimeAttrs":
        """Parse the attributes, then check the slice axis against the input shape.

        Args:
            attributes: ONNX attributes mapping, unpacked as keyword arguments
                into the model ('axes', 'starts', and optionally 'ends', 'steps')
            input_shape: Input tensor shape (any dimensionality)

        Returns:
            The parsed SliceRuntimeAttrs; ``axes`` is kept exactly as supplied

        Raises:
            ValueError: If a field fails validation or the axis cannot be
                normalized to axis=0

        Examples:
            >>> attrs = SliceRuntimeAttrs.build(
            ...     {"axes": [2], "starts": [1], "ends": [64]},
            ...     (1, 1, 64, 64)
            ... )
            >>> attrs.axes
            [2]
        """
        parsed = SliceRuntimeAttrs(**attributes)
        # The validators guarantee exactly one entry, so unpacking cannot fail
        (slice_axis,) = parsed.axes
        _validate_axis_normalization(
            axis=slice_axis,
            shapes=input_shape,
            operation_name="slice_runtime",
        )
        return parsed


class ConcatRuntimeAttrs(BaseModel):
    """Validated attribute set for a concat_runtime operation.

    The concatenation axis may be any axis whose leading dimensions are all 1
    in every input; that condition is checked against the input shapes in
    build(). Intended memory strategy: contiguous input aliasing (inputs laid
    out sequentially).

    Attributes:
        axis: Concatenation axis, defaults to 0. Stored exactly as supplied.
    """
    axis: int = Field(
        default=0,
        description="Axis to concatenate along (validated with input shapes in build())"
    )

    @staticmethod
    def build(
        attributes: OnnxAttributes,
        shapes: List[Tuple[int, ...]]
    ) -> "ConcatRuntimeAttrs":
        """Parse the attributes, then check the concat axis against all input shapes.

        The axis is accepted when every dimension in front of it equals 1 in
        every input shape, so that concatenating along it has the same memory
        layout as concatenating along axis=0.

        Args:
            attributes: ONNX attributes mapping, unpacked as keyword arguments
                into the model (may carry 'axis')
            shapes: List of input shapes (any dimensionality, all must match)

        Returns:
            The parsed ConcatRuntimeAttrs; a negative axis stays negative

        Raises:
            ValueError: If the axis cannot be normalized to axis=0
                (non-unity leading dimensions) or the shapes are inconsistent

        Examples:
            Negative axis with N=1 and H=1:
            >>> attrs = ConcatRuntimeAttrs.build(
            ...     {"axis": -2},
            ...     [(1, 1, 63, 64), (1, 1, 1, 64)]
            ... )
            >>> attrs.axis
            -2

            axis=1 with N=1:
            >>> attrs = ConcatRuntimeAttrs.build(
            ...     {"axis": 1},
            ...     [(1, 1, 10, 64), (1, 63, 10, 64)]
            ... )
            >>> attrs.axis
            1

            Rejected: axis=-2 while the dimension at index 1 is 2:
            >>> try:
            ...     ConcatRuntimeAttrs.build(
            ...         {"axis": -2},
            ...         [(1, 2, 63, 64), (1, 2, 1, 64)]
            ...     )
            ... except ValueError as e:
            ...     print("Validation failed")
            Validation failed
        """
        parsed = ConcatRuntimeAttrs(**attributes)
        # Called for its raising behaviour only; the return value is unused
        _validate_axis_normalization(
            axis=parsed.axis,
            shapes=shapes,
            operation_name="concat_runtime",
        )
        return parsed


class SplitRuntimeAttrs(BaseModel):
    """Validated attribute set for a split_runtime operation.

    The split axis may be any axis whose leading dimensions are all 1; that
    condition and the split sizes are checked in build(), not on plain
    construction. Intended memory strategy: contiguous output aliasing
    (outputs laid out sequentially).

    Attributes:
        axis: Split axis, defaults to 0. Stored exactly as supplied.
        split: Size of each output along the split axis (required)
    """
    axis: int = Field(
        default=0,
        description="Axis to split along (validated with input shape in build())"
    )
    split: list[int] = Field(
        description="Size of each output along split axis"
    )

    @staticmethod
    def build(attributes: OnnxAttributes, input_shape: Tuple[int, ...]) -> "SplitRuntimeAttrs":
        """Parse the attributes, then check the axis and the split sizes.

        Checks run in this order: axis against ``input_shape``, ``split`` not
        empty, every ``split`` entry strictly positive.

        Args:
            attributes: ONNX attributes mapping, unpacked as keyword arguments
                into the model ('split', and optionally 'axis')
            input_shape: Input tensor shape (any dimensionality)

        Returns:
            The parsed SplitRuntimeAttrs

        Raises:
            ValueError: If axis or split validation fails

        Examples:
            >>> attrs = SplitRuntimeAttrs.build(
            ...     {"axis": 0, "split": [10, 20]},
            ...     (30, 10, 10, 64)
            ... )
            >>> attrs.split
            [10, 20]
        """
        parsed = SplitRuntimeAttrs(**attributes)
        _validate_axis_normalization(
            axis=parsed.axis,
            shapes=input_shape,
            operation_name="split_runtime",
        )

        sizes = parsed.split
        if not sizes:
            raise ValueError("split_runtime: split must have at least one element")

        # NOTE(review): the sum of the sizes is not compared with
        # input_shape[axis] here - confirm whether a caller enforces that.
        if not all(size > 0 for size in sizes):
            raise ValueError(
                f"split_runtime: all split values must be positive, got {sizes}"
            )

        return parsed
