"""
This module provides type definitions for tensor and related types.
"""

from __future__ import annotations
import operator
from dataclasses import dataclass, field, replace
from enum import Enum
from functools import reduce
from typing import Any, List, Tuple
import copy
from math import ceil

# pylint: disable-next=import-error,no-name-in-module
from graph.utilities import tensor_dtype_to_size, logger
from graph.dag import scoped_id, scoped_tensor, OnnxAttributes


class XrtId(Enum):
    """Enumeration of the XRT memory bins a tensor can be assigned to.

    Members (integer value in parentheses):
        - DEFAULT (-1): default memory bin
        - OFM (0): output feature map bin
        - IFM (1): input feature map bin
        - WEIGHT (2): weight data bin
        - PARAM (3): parameter data bin
        - SCRATCH (4): scratch/temporary memory bin

    Examples:
        >>> XrtId.OFM.value
        0
        >>> XrtId(3)
        <XrtId.PARAM: 3>
        >>> [member.name for member in XrtId]
        ['DEFAULT', 'OFM', 'IFM', 'WEIGHT', 'PARAM', 'SCRATCH']
    """

    DEFAULT = -1
    OFM = 0
    IFM = 1
    WEIGHT = 2
    PARAM = 3
    SCRATCH = 4


def iceil64(x: int) -> int:
    """Round an integer up to the next multiple of 64.

    Values that are already multiples of 64 are returned unchanged.
    Used by the shape-padding helpers in this module to align a
    dimension to 64.

    Args:
        x: Integer value to align

    Returns:
        The smallest multiple of 64 that is greater than or equal to x

    Examples:
        >>> iceil64(0)
        0
        >>> iceil64(1)
        64
        >>> iceil64(64)
        64
        >>> iceil64(65)
        128
        >>> iceil64(127)
        128
    """
    # Adding 63 pushes any non-multiple into the next 64-wide bucket;
    # shifting right then left by 6 bits clears the low six bits.
    bumped = x + 63
    return (bumped >> 6) << 6


def get_padded_shape(shape: Tuple[int, ...] | List[int]) -> Tuple[int, ...]:
    """Pad the 3rd-from-last dimension of a shape up to a multiple of 64.

    Shapes with three or more dimensions (e.g., NCHW) keep every
    dimension except the 3rd-from-last, which is rounded up with
    ``iceil64``. Two-dimensional shapes get a new leading dimension of
    64 instead. A list input is converted to a tuple first.

    Args:
        shape: Original tensor shape (tuple or list)

    Returns:
        The padded shape

    Raises:
        ValueError: If shape has fewer than 2 dimensions

    Examples:
        >>> get_padded_shape((1, 32, 224, 224))
        (1, 64, 224, 224)
        >>> get_padded_shape((1, 128, 7, 7))
        (1, 128, 7, 7)
        >>> get_padded_shape((3, 10, 10))
        (64, 10, 10)
        >>> get_padded_shape((224, 224))
        (64, 224, 224)
        >>> get_padded_shape([2, 1, 100, 56, 56])
        (2, 1, 128, 56, 56)
    """
    dims = tuple(shape) if isinstance(shape, list) else shape
    rank = len(dims)

    # Reject unsupported ranks up front.
    if rank < 2:
        raise ValueError(f"Invalid shape {dims}")

    # A plain 2-D shape gains a leading dimension of 64.
    if rank == 2:
        return (64, ) + dims

    leading, channels, trailing = dims[:-3], dims[-3], dims[-2:]
    return leading + (iceil64(channels), ) + trailing


def get_padded_shape_rev(shape: Tuple[int, ...] | List[int]) -> Tuple[int, ...]:
    """Pad the last dimension of a shape up to a multiple of 64.

    Intended for channels-last (NHWC) layouts, where the final dimension
    holds the channels. All other dimensions are left untouched. A list
    input is converted to a tuple first.

    Args:
        shape: Original tensor shape (tuple or list), channels last

    Returns:
        The shape with its last dimension rounded up via ``iceil64``

    Examples:
        >>> get_padded_shape_rev((1, 224, 224, 3))
        (1, 224, 224, 64)
        >>> get_padded_shape_rev((1, 56, 56, 64))
        (1, 56, 56, 64)
        >>> get_padded_shape_rev((1, 28, 28, 100))
        (1, 28, 28, 128)
        >>> get_padded_shape_rev([8, 65])
        (8, 128)
    """
    dims = tuple(shape) if isinstance(shape, list) else shape
    padded_last = iceil64(dims[-1])
    return dims[:-1] + (padded_last, )


def flatten_batch_dims(shape: List[int]) -> List[int]:
    """Collapse all leading batch dimensions into a single one.

    The last two dimensions (the matrix dims of a MatMul operand) are
    kept as they are. Every dimension before them is multiplied into one
    batch dimension, and the result is left-padded with 1s so the rank
    of the shape does not change.

    Shapes with two or fewer dimensions have no batch dims; the input
    object itself is returned in that case.

    Args:
        shape: Original tensor shape

    Returns:
        Shape of the same rank with the batch dims collapsed

    Examples:
        >>> flatten_batch_dims([1, 12, 48, 45, 47])
        [1, 1, 576, 45, 47]
        >>> flatten_batch_dims([2, 3, 4, 5])
        [1, 6, 4, 5]
        >>> flatten_batch_dims([12, 48, 45, 47])
        [1, 576, 45, 47]
        >>> flatten_batch_dims([48, 45, 47])
        [48, 45, 47]
        >>> flatten_batch_dims([45, 47])
        [45, 47]
    """
    rank = len(shape)
    if rank < 3:
        return shape

    *leading, rows, cols = shape

    # Fold every batch dimension into one product.
    collapsed = 1
    for dim in leading:
        collapsed = collapsed * dim

    # rank - 3 ones keep the output the same rank as the input.
    ones_prefix = [1] * (rank - 3)
    return ones_prefix + [collapsed, rows, cols]


@dataclass(frozen=True, slots=True)
class Tensor:
    """Immutable description of a tensor used for allocation and execution.

    Two tensors compare equal, and hash identically, whenever their
    ``id`` matches; no other field takes part in equality or hashing.

    Attributes:
        id: Unique identifier for the tensor
        shape: Tensor dimensions (e.g., (batch, channels, height, width))
        dtype: Data type string (e.g., "TensorProto.FLOAT", "TensorProto.INT8").
            If None is passed, it is replaced by "TensorProto.INT8".
        bin: Memory bin assignment (DEFAULT, IFM, OFM, etc.)
        size: Derived in ``__post_init__`` as the ceiling of
            (product of dims) * ``tensor_dtype_to_size(dtype)``; not an
            ``__init__`` argument and excluded from ``repr``
        is_channel_multiple_of_64: When True, the last dimension is rounded
            up to a multiple of 64 for the size computation only; ``shape``
            itself is stored unchanged
        is_model_io: Whether this tensor is a graph input/output
        is_constant: True for initializers/constants

    Examples:
        >>> tensor = Tensor(
        ...     id="input",
        ...     shape=(1, 3, 224, 224),
        ...     dtype="TensorProto.FLOAT"
        ... )
        >>> tensor.id
        'input'
        >>> tensor.shape
        (1, 3, 224, 224)
        >>> tensor.bin
        <XrtId.DEFAULT: -1>

        >>> # Tensor placed in a specific memory bin
        >>> ifm_tensor = Tensor(
        ...     id="feature_map",
        ...     shape=(1, 64, 56, 56),
        ...     dtype="TensorProto.INT8",
        ...     bin=XrtId.IFM
        ... )
        >>> ifm_tensor.bin
        <XrtId.IFM: 1>

        >>> # Equality only looks at the id
        >>> tensor == Tensor(id="input", shape=(1, 1), dtype="TensorProto.INT8")
        True

        >>> # Model I/O tensor
        >>> output = Tensor(
        ...     id="output",
        ...     shape=(1, 1000),
        ...     dtype="TensorProto.FLOAT",
        ...     is_model_io=True
        ... )
        >>> output.is_model_io
        True
    """

    id: str
    shape: Tuple[int, ...]
    dtype: str
    bin: XrtId = XrtId.DEFAULT
    size: int = field(init=False, repr=False)
    is_channel_multiple_of_64: bool = False
    is_model_io: bool = False
    is_constant: bool = False  # True for initializers/constants

    def __post_init__(self) -> None:
        # The dataclass is frozen, so derived values have to be written
        # through object.__setattr__ instead of plain assignment.

        # An empty/falsy shape is treated as [0], giving a zero element count.
        dims = self.shape or [0]
        if self.is_channel_multiple_of_64:
            # Only the dims used for sizing are padded; self.shape is untouched.
            dims = get_padded_shape_rev(dims)

        if self.dtype is None:
            logger.info("Tensor %s has no dtype defined, setting it to INT8", self.id)
            object.__setattr__(self, "dtype", "TensorProto.INT8")

        element_count = reduce(operator.mul, dims)
        # ceil() rounds up any fractional result of the multiplication.
        total = ceil(element_count * tensor_dtype_to_size(self.dtype))
        object.__setattr__(self, "size", total)

    def __hash__(self) -> int:
        # Hash on the id only, matching __eq__.
        return hash(self.id)

    def __eq__(self, other: Any) -> bool:
        # Anything that is not a Tensor is never equal; Tensors match by id.
        if not isinstance(other, Tensor):
            return False
        return other.id == self.id

    def copy(self) -> Tensor:
        """Return a deep copy of this tensor.

        Returns:
            A new Tensor instance produced by ``copy.deepcopy``

        Examples:
            >>> tensor = Tensor(id="t1", shape=(10, 10), dtype="TensorProto.FLOAT")
            >>> duplicate = tensor.copy()
            >>> duplicate.id
            't1'
            >>> duplicate is tensor
            False
        """
        # `copy` here is the module-level import, not this method.
        return copy.deepcopy(self)

    def update(self, **kwargs: Any) -> Tensor:
        """Return a new Tensor with the given fields replaced.

        The instance is frozen, so nothing is modified in place; the call
        is forwarded to ``dataclasses.replace``, which builds a fresh
        instance (and therefore runs ``__post_init__`` again).

        Args:
            **kwargs: Field names and their new values. ``size`` cannot be
                supplied because it is not an ``__init__`` argument.

        Returns:
            New Tensor instance with the requested fields changed

        Examples:
            >>> tensor = Tensor(id="t1", shape=(10, 10), dtype="TensorProto.FLOAT")
            >>> updated = tensor.update(bin=XrtId.IFM)
            >>> updated.bin
            <XrtId.IFM: 1>
            >>> tensor.bin
            <XrtId.DEFAULT: -1>
        """
        return replace(self, **kwargs)

    @staticmethod
    def empty() -> Tensor:
        """Build a placeholder tensor with an empty shape and undefined dtype.

        Returns:
            A Tensor with id "empty_tensor", shape () and dtype
            "TensorProto.UNDEFINED"

        Examples:
            >>> tensor = Tensor.empty()
            >>> tensor.id
            'empty_tensor'
            >>> tensor.bin
            <XrtId.DEFAULT: -1>
            >>> tensor.dtype
            'TensorProto.UNDEFINED'
        """
        return Tensor(id="empty_tensor", shape=(), dtype="TensorProto.UNDEFINED")


class TensorLocation(str, Enum):
    """Where a tensor resides during execution.

    Because the enum also derives from ``str``, each member compares
    equal to its plain string value.

    Members:
        - UNKNOWN ("unknown"): location not yet determined
        - MEMORY ("memory"): tensor is in primary memory (fast access)
        - SPILLED ("spilled"): tensor is in secondary storage (slower access)

    Examples:
        >>> TensorLocation.MEMORY.value
        'memory'
        >>> TensorLocation("spilled")
        <TensorLocation.SPILLED: 'spilled'>
        >>> TensorLocation.UNKNOWN == "unknown"
        True
        >>> [member.name for member in TensorLocation]
        ['UNKNOWN', 'MEMORY', 'SPILLED']
    """

    UNKNOWN = "unknown"
    MEMORY = "memory"
    SPILLED = "spilled"


@dataclass(frozen=True, slots=True)
class TensorLifetime:
    """Represents a tensor lifetime interval [start, end] in the execution graph.

    The lifetime defines when a tensor is first produced (start) and
    when it is last consumed (end), measured in operation indices.

    Attributes:
        start: Index of first operation producing/using this tensor
        end: Index of last operation consuming this tensor

    Special Values:
        - start=-1, end=-1: Pre-allocated tensor with fixed lifetime

    Examples:
        >>> lifetime = TensorLifetime(start=0, end=10)
        >>> lifetime.start
        0
        >>> lifetime.end
        10

        >>> # Pre-allocated tensor
        >>> prealloc = TensorLifetime(start=-1, end=-1)
        >>> prealloc.start
        -1
        >>> prealloc.end
        -1
    """

    start: int
    end: int


@dataclass(frozen=True)
class Operation:
    """Represents an operation in the computational graph.

    An operation defines a computation that consumes input tensors
    and produces output tensors.

    Equality is the dataclass-generated field-by-field comparison.
    Hashing uses only ``id``: the ``inputs``/``outputs`` lists and the
    ``attributes`` dict are unhashable, so a hash over all fields would
    raise ``TypeError``. Equal operations always share an id, so the
    id-based hash stays consistent with equality.

    Attributes:
        id: Unique identifier for the operation
        type: Operation type (e.g., "Conv", "Add", "MatMul", "Reshape")
        inputs: List of input tensor IDs
        outputs: List of output tensor IDs
        attributes: Optional dictionary of ONNX attributes (e.g., axis, starts, split);
            defaults to a new empty dict per instance

    Examples:
        >>> op = Operation(
        ...     id="conv1",
        ...     type="Conv",
        ...     inputs=["input", "weights"],
        ...     outputs=["conv1_out"]
        ... )
        >>> op.id
        'conv1'
        >>> op.type
        'Conv'
        >>> op.inputs
        ['input', 'weights']
        >>> op.outputs
        ['conv1_out']

        >>> # Operations can be used in sets and as dict keys
        >>> len({op, op.copy()})
        1

        >>> # Pointwise operation
        >>> add_op = Operation(
        ...     id="add1",
        ...     type="Add",
        ...     inputs=["tensor_a", "tensor_b"],
        ...     outputs=["sum"]
        ... )
        >>> add_op.type
        'Add'
        >>> len(add_op.inputs)
        2

        >>> # Runtime operation with attributes
        >>> slice_op = Operation(
        ...     id="slice1",
        ...     type="slice_runtime",
        ...     inputs=["input"],
        ...     outputs=["output"],
        ...     attributes={"axes": [0], "starts": [10], "ends": [30]}
        ... )
        >>> slice_op.attributes["starts"]
        [10]
    """

    id: str
    type: str
    inputs: List[str]  # tensor IDs that are inputs
    outputs: List[str]  # tensor IDs that are outputs
    attributes: OnnxAttributes = field(default_factory=dict)  # ONNX attributes

    def __hash__(self) -> int:
        # With eq=True and frozen=True the dataclass would otherwise generate
        # a hash over every field, which fails on the list/dict fields.
        # An explicitly defined __hash__ is left in place by @dataclass.
        return hash(self.id)

    def copy(self) -> Operation:
        """Returns a deep copy of the operation.

        Returns:
            Deep copy of the Operation instance (lists and attributes
            are copied too, via ``copy.deepcopy``)

        Examples:
            >>> op = Operation(id="conv1", type="Conv", inputs=["in"], outputs=["out"])
            >>> copy_op = op.copy()
            >>> copy_op.id
            'conv1'
            >>> copy_op is op
            False
            >>> copy_op == op
            True
        """
        # `copy` here is the module-level import, not this method.
        return copy.deepcopy(self)

    @staticmethod
    def from_dict(**kwargs: Any) -> Operation:
        """Builds an Operation instance from a dictionary in local scope.

        Only the four keys listed below are read; any other key that is
        passed (including 'attributes') is ignored, so the resulting
        operation always has empty attributes.

        Args:
            **kwargs: Dictionary with keys 'name', 'op_type', 'inputs', 'outputs'

        Returns:
            Operation instance

        Raises:
            KeyError: If one of the required keys is missing

        Examples:
            >>> op = Operation.from_dict(
            ...     name="conv1",
            ...     op_type="Conv",
            ...     inputs=["input", "weights"],
            ...     outputs=["output"]
            ... )
            >>> op.id
            'conv1'
            >>> op.type
            'Conv'
            >>> op.attributes
            {}
        """
        return Operation(kwargs["name"], kwargs["op_type"], kwargs["inputs"], kwargs["outputs"])

    @staticmethod
    def from_dict_scope(**kwargs: Any) -> Operation:
        """Builds an Operation instance with scoped identifiers.

        The operation name is passed through ``scoped_id`` and every
        input/output tensor ID through ``scoped_tensor``, each together
        with the given scope. Useful for subgraph isolation.

        Args:
            **kwargs: Dictionary with 'scope', 'name', 'op_type', 'inputs', 'outputs'

        Returns:
            Operation instance with scoped identifiers

        Raises:
            KeyError: If one of the required keys is missing
        """
        scope = kwargs["scope"]
        return Operation(
            scoped_id(scope, kwargs["name"]),
            kwargs["op_type"],
            [scoped_tensor(scope, tensor_id) for tensor_id in kwargs["inputs"]],
            [scoped_tensor(scope, tensor_id) for tensor_id in kwargs["outputs"]],
        )
