"""
This module provides type definitions for memory allocation and related operations.
"""

from __future__ import annotations
import copy
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain, tee
from typing import Any, List, Optional, Tuple, Dict
from pydantic import BaseModel, Field
from typing_extensions import Self

# pylint: disable-next=import-error,no-name-in-module
from graph.common import SRAM_TOTAL  # type: ignore

# pylint: disable-next=import-error
from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId  # type: ignore


def round_up(x: int, r: int) -> int:
    """Return the smallest multiple of ``r`` that is not less than ``x``.

    Args:
        x: Value to be rounded upwards
        r: Granularity to round to; must be strictly positive

    Returns:
        The least multiple of r that is >= x (x itself when already a multiple)

    Raises:
        ValueError: If r is zero or negative

    Examples:
        >>> round_up(10, 8)
        16
        >>> round_up(16, 8)
        16
        >>> round_up(0, 8)
        0
        >>> round_up(1, 64)
        64
        >>> round_up(100, 32)
        128
    """
    if r <= 0:
        raise ValueError("r must be positive")
    # Bias by r - 1 so the floor division acts as a ceiling division.
    multiples, _ = divmod(x + r - 1, r)
    return multiples * r


def is_non_overlapping(seq: List[Tuple[int, int]]) -> bool:
    """
    Check that no two consecutive intervals in ``seq`` overlap.

    Only neighbouring intervals are compared, so the input must already be
    sorted. Intervals that merely touch (one ends where the next starts)
    are considered non-overlapping.

    Args:
        seq: Sorted list of (start, end) tuples representing memory intervals

    Returns:
        True if no intervals overlap, False otherwise

    Examples:
        >>> is_non_overlapping([(0, 10), (10, 20), (20, 30)])
        True
        >>> is_non_overlapping([(0, 10), (5, 15)])
        False
        >>> is_non_overlapping([(0, 10)])
        True
        >>> is_non_overlapping([])
        True
    """
    intervals = iter(seq)
    try:
        previous = next(intervals)
    except StopIteration:
        # Nothing to compare: an empty sequence is trivially non-overlapping.
        return True
    for current in intervals:
        # Each interval must begin at or after the end of its predecessor.
        if not current[0] >= previous[1]:
            return False
        previous = current
    return True


@dataclass
class MemoryBlock:
    """Describes one contiguous region of memory.

    Blocks order by ``size`` alone (see ``__lt__``), while equality is the
    field-by-field comparison generated by ``@dataclass``.

    Attributes:
        start: Address at which the block begins
        size: Length of the block in bytes
        is_free: True while the block is not occupied
        tensor_id: ID of the tensor placed in this block (None if free)
        aliased_tensor_ids: IDs of all tensors aliasing this block
        offset_from_parent: Byte offset from the parent tensor, used by
            slice/split operations (None otherwise)

    Examples:
        >>> block = MemoryBlock(start=0, size=1024, is_free=True)
        >>> block.start
        0
        >>> block.size
        1024
        >>> block.is_free
        True
        >>> block.tensor_id is None
        True

        >>> # Allocated block
        >>> allocated = MemoryBlock(start=1024, size=512, is_free=False, tensor_id="tensor_1")
        >>> allocated.is_free
        False
        >>> allocated.tensor_id
        'tensor_1'
    """

    start: int
    size: int
    is_free: bool
    tensor_id: Optional[str] = None
    aliased_tensor_ids: List[str] = field(default_factory=list)  # Track all tensors aliasing this block
    offset_from_parent: Optional[int] = None  # For slice/split operations: byte offset from parent tensor

    # IMPLEMENTATION NOTE: Two-Phase Address Calculation for Offset-Based Aliasing
    # ============================================================================
    # For slice/split operations with offset_from_parent:
    #
    # Phase 1 - Allocation (before alignment):
    #   - block.start stores the PARENT's unaligned address (NOT parent + offset)
    #   - block.offset_from_parent stores the byte offset from parent
    #   - This allows us to recalculate after parent alignment
    #
    # Phase 2 - Alignment (convert_alloc_to_aligned):
    #   - Parent blocks are aligned: new_start = align(block.start)
    #   - Offset blocks are adjusted: new_start = align(block.start) + offset_from_parent
    #   - No parent references or lookups needed - pure arithmetic!
    #
    # Example:
    #   Parent allocated at 100, slice starts at element 10 (byte offset 40)
    #   Phase 1: output_block.start = 100, output_block.offset_from_parent = 40
    #   Phase 2: output_block.start = align(100) + 40 = 4096 + 40 = 4136
    #
    # This approach handles nested offsets automatically:
    #   - Each offset block stores its parent's address in block.start
    #   - During alignment, align(block.start) gives aligned parent address
    #   - Add offset_from_parent to get final address
    #   - Works for chains: input→slice→split, etc.

    def __lt__(self, other: Self) -> bool:
        """Order blocks by size only; addresses and occupancy are ignored."""
        return self.size < other.size

    def aligned_start(self, alloc_config: AllocationConfig) -> int:
        """Return ``start`` rounded up to the alignment in ``alloc_config``.

        ``offset_from_parent`` is not taken into account here.

        Args:
            alloc_config: Allocation configuration supplying the alignment

        Returns:
            The first address >= start that is a multiple of the alignment

        Examples:
            >>> block = MemoryBlock(start=100, size=1024, is_free=True)
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT
            ... )
            >>> block.aligned_start(config)
            100

            >>> # With 4KB alignment
            >>> config_4k = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.KB_4
            ... )
            >>> block.aligned_start(config_4k)
            4096
        """
        alignment = alloc_config.alignment.value
        return round_up(self.start, alignment)

    def required_size(self, size: int, alloc_config: AllocationConfig) -> int:
        """Return how many bytes of this block a request of ``size`` consumes.

        The result is the requested size plus the gap between ``start`` and
        the aligned start address. A zero-byte request needs nothing, so no
        alignment padding is charged for it.

        Args:
            size: Requested size in bytes
            alloc_config: Allocation configuration supplying the alignment

        Returns:
            Requested size plus alignment padding (0 when size is 0)

        Examples:
            >>> block = MemoryBlock(start=100, size=2048, is_free=True)
            >>> config = AllocationConfig(strategy=AllocationStrategy.FIRST_FIT)
            >>> block.required_size(1024, config)
            1024

            >>> # With 4KB alignment, padding is added
            >>> config_4k = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.KB_4
            ... )
            >>> block.required_size(1024, config_4k)
            5020
        """
        if size == 0:
            return 0
        # Bytes skipped at the front of the block to reach the aligned address.
        padding = self.aligned_start(alloc_config) - self.start
        return size + padding

    def copy(self) -> Self:
        """Return an independent deep copy of this block.

        Returns:
            A new MemoryBlock sharing no mutable state with the original
        """
        duplicate = copy.deepcopy(self)
        return duplicate


@dataclass(frozen=True, slots=True)
class ParamBlock:
    """Represents a parameter memory block for L2 fusion.

    Parameter blocks store constant parameters that are loaded once
    and remain unchanged during execution.

    Instances are immutable (``frozen=True``) and use ``__slots__``
    (``slots=True``). The class itself performs no validation of its
    fields; ``MemoryConfig`` checks bounds and overlap for the blocks it
    is given.

    Attributes:
        start: Starting address of the parameter block
        size: Size of the block in bytes

    Examples:
        >>> params = ParamBlock(start=0, size=256)
        >>> params.start
        0
        >>> params.size
        256
    """

    # First address of the block.
    start: int
    # Length in bytes; MemoryConfig treats the block as (start, start + size).
    size: int


@dataclass(frozen=True, slots=True)
class WeightBlock:
    """Represents a weight memory block for neural network layers.

    Weight blocks store trained model weights that are loaded
    during layer execution.

    Instances are immutable (``frozen=True``) and use ``__slots__``
    (``slots=True``). The class itself performs no validation of its
    fields; ``MemoryConfig`` checks bounds and overlap for the blocks it
    is given.

    Attributes:
        start: Starting address of the weight block
        size: Size of the block in bytes

    Examples:
        >>> weights = WeightBlock(start=1024, size=4096)
        >>> weights.start
        1024
        >>> weights.size
        4096
    """

    # First address of the block.
    start: int
    # Length in bytes; MemoryConfig treats the block as (start, start + size).
    size: int


@dataclass(frozen=True, slots=True)
class WeightPingPongBlock:
    """Represents ping-pong weight buffers for double-buffering.

    Ping-pong buffering allows loading the next layer's weights
    while the current layer is executing, improving throughput.

    The pair is immutable (``frozen=True``). Nothing in this class
    relates the two buffers to each other (sizes and addresses are
    independent); ``MemoryConfig`` validates ``ping`` and ``pong`` as two
    separate memory regions.

    Attributes:
        ping: First weight buffer
        pong: Second weight buffer (alternate)

    Examples:
        >>> ping = WeightBlock(start=0, size=1024)
        >>> pong = WeightBlock(start=1024, size=1024)
        >>> pp_block = WeightPingPongBlock(ping=ping, pong=pong)
        >>> pp_block.ping.start
        0
        >>> pp_block.pong.start
        1024
    """

    # First buffer of the pair.
    ping: WeightBlock
    # Second (alternate) buffer of the pair.
    pong: WeightBlock


class MemoryConfig(BaseModel):
    """Memory layout configuration for L2 fusion.

    This class defines the complete memory layout including parameter blocks,
    weight ping-pong buffers, and general memory blocks. It validates that
    all regions are non-overlapping and within SRAM bounds.

    Attributes:
        - params: Tuple of 3 parameter blocks for constant data
        - weights: Tuple of 3 ping-pong weight buffer pairs for double-buffering
        - memory: List of general-purpose memory blocks

    Validation:
        - All memory regions must have a non-negative size
        - All memory regions must be non-overlapping
        - All regions must start at address >= 0
        - All regions must end at address <= SRAM_TOTAL

    Raises:
        ValueError: If any of the validation rules above is violated

    Examples:
        >>> # Create parameter blocks
        >>> param1 = ParamBlock(start=0, size=256)
        >>> param2 = ParamBlock(start=256, size=256)
        >>> param3 = ParamBlock(start=512, size=256)

        >>> # Create ping-pong weight buffers
        >>> weights1 = WeightPingPongBlock(
        ...     ping=WeightBlock(start=1024, size=512),
        ...     pong=WeightBlock(start=1536, size=512)
        ... )
        >>> weights2 = WeightPingPongBlock(
        ...     ping=WeightBlock(start=2048, size=512),
        ...     pong=WeightBlock(start=2560, size=512)
        ... )
        >>> weights3 = WeightPingPongBlock(
        ...     ping=WeightBlock(start=3072, size=512),
        ...     pong=WeightBlock(start=3584, size=512)
        ... )

        >>> # Create general memory blocks
        >>> mem_blocks = [
        ...     MemoryBlock(start=4096, size=8192, is_free=True),
        ...     MemoryBlock(start=12288, size=4096, is_free=True)
        ... ]

        >>> # Create configuration
        >>> config = MemoryConfig(
        ...     params=(param1, param2, param3),
        ...     weights=(weights1, weights2, weights3),
        ...     memory=mem_blocks
        ... )
        >>> len(config.params)
        3
        >>> len(config.weights)
        3
        >>> len(config.memory)
        2
    """

    params: Tuple[ParamBlock, ParamBlock, ParamBlock] = Field(frozen=True)
    weights: Tuple[WeightPingPongBlock, WeightPingPongBlock, WeightPingPongBlock] = (
        Field(frozen=True)
    )
    memory: List[MemoryBlock] = Field(frozen=True)

    # pylint: disable-next=arguments-differ
    def model_post_init(self, __context: Any) -> None:  # noqa
        """Validate the combined layout once the model fields are populated.

        Every parameter block, every ping and pong weight buffer and every
        general memory block is reduced to a ``(start, end)`` pair; the
        sorted pairs are then checked against the rules listed on the class.

        Args:
            __context: Post-init context supplied by pydantic (unused)

        Raises:
            ValueError: If a region has negative size, regions overlap, or a
                region lies outside ``[0, SRAM_TOTAL]``
        """
        blocks = chain(
            self.params,
            chain.from_iterable((weight.ping, weight.pong) for weight in self.weights),
            self.memory,
        )
        memory_regions = sorted(
            (block.start, block.start + block.size)  # type: ignore[attr-defined]
            for block in blocks
        )
        # A negative size gives end < start. Such a region can slip through
        # the neighbour-only overlap test and also breaks the assumption that
        # the last sorted region has the greatest end address.
        has_negative_size = any(end < start for start, end in memory_regions)
        # params always contributes three regions, so indexing is safe.
        if (
            has_negative_size
            or not is_non_overlapping(memory_regions)
            or memory_regions[0][0] < 0
            or memory_regions[-1][1] > SRAM_TOTAL
        ):
            raise ValueError("Memory region is inconsistent")


@dataclass(slots=True)
class TensorAllocation:
    """Represents a tensor allocation request with lifetime and location.

    This class combines a tensor with its lifetime range and desired
    memory location, forming a complete allocation request.

    Attributes:
        tensor: The tensor to allocate
        range: Lifetime range (start/end operation indices)
        location: Desired memory location (MEMORY, SPILL, etc.)
        block: Allocated memory block (None if not yet allocated)
        is_deallocatable: Whether this allocation can be deallocated

    Examples:
        >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
        >>> tensor = Tensor(id="t1", shape=(10, 10), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
        >>> lifetime = TensorLifetime(start=0, end=5)
        >>> alloc = TensorAllocation(
        ...     tensor=tensor,
        ...     range=lifetime,
        ...     location=TensorLocation.MEMORY
        ... )
        >>> alloc.tensor.id
        't1'
        >>> alloc.range.start
        0
        >>> alloc.block is None
        True
        >>> alloc.is_deallocatable
        True
    """

    tensor: Tensor
    range: TensorLifetime
    location: TensorLocation
    block: Optional[MemoryBlock] = None
    is_deallocatable: bool = True

    def copy(self) -> Self:
        """Performs a deep copy of self.

        Returns:
            Deep copy of the TensorAllocation
        """
        return copy.deepcopy(self)

    @staticmethod
    def empty() -> Self:
        """An empty tensor.

        Returns:
            An empty TensorAllocation
        """
        return TensorAllocation(
            tensor=Tensor.empty(),
            range=TensorLifetime(start=-1, end=-1),
            location=TensorLocation.UNKNOWN,
            block=None,
            is_deallocatable=False
        )


class AllocationStrategy(str, Enum):
    """Memory allocation strategy for finding suitable memory blocks.

    This enum only names the strategies; the selection logic that acts on
    them is not part of this class. Because the enum mixes in ``str``,
    every member is also a string equal to its value.

    Strategies:
        - FIRST_FIT: Allocate in the first available block that fits (fastest)
        - BEST_FIT: Allocate in the smallest sufficient block (most memory efficient)
        - WORST_FIT: Allocate in the largest available block (reduces fragmentation)

    Examples:
        >>> strategy = AllocationStrategy.FIRST_FIT
        >>> strategy.value
        'first_fit'
        >>> AllocationStrategy.BEST_FIT
        <AllocationStrategy.BEST_FIT: 'best_fit'>
        >>> list(AllocationStrategy)
        [<AllocationStrategy.FIRST_FIT: 'first_fit'>, <AllocationStrategy.BEST_FIT: 'best_fit'>, <AllocationStrategy.WORST_FIT: 'worst_fit'>]
    """

    # Member values are the lowercase form of the member names.
    FIRST_FIT = "first_fit"
    BEST_FIT = "best_fit"
    WORST_FIT = "worst_fit"


class AllocationAlignment(int, Enum):
    """Memory allocation alignment requirements.

    Each member's value is the alignment in bytes. ``MemoryBlock.aligned_start``
    passes ``alignment.value`` to ``round_up``, so every value must be a
    positive integer. Because the enum mixes in ``int``, members can also be
    used directly in integer arithmetic.

    Alignment:
        - DEFAULT: No special alignment (1 byte)
        - KB_4: 4KB alignment (4096 bytes)

    Examples:
        >>> AllocationAlignment.DEFAULT.value
        1
        >>> AllocationAlignment.KB_4.value
        4096
    """

    # Rounding up to a multiple of 1 leaves any address unchanged.
    DEFAULT = 1
    # 4 * 1024 bytes.
    KB_4 = 4096


@dataclass(frozen=True, slots=True)
class AllocationConfig:
    """Configuration for memory allocation strategy and constraints.

    Instances are immutable (``frozen=True``) and use ``__slots__``
    (``slots=True``). Only ``strategy`` is required; the other fields
    have defaults.

    Attributes:
        strategy: Allocation strategy to use (FIRST_FIT, BEST_FIT, or WORST_FIT)
        alignment: Memory alignment requirement
        bin: XRT memory bin ID

    Examples:
        >>> from graph.tensor_types import XrtId
        >>> config = AllocationConfig(strategy=AllocationStrategy.FIRST_FIT)
        >>> config.strategy
        <AllocationStrategy.FIRST_FIT: 'first_fit'>
        >>> config.alignment
        <AllocationAlignment.DEFAULT: 1>
        >>> config.bin
        <XrtId.DEFAULT: -1>

        >>> # With 4KB alignment
        >>> config_aligned = AllocationConfig(
        ...     strategy=AllocationStrategy.BEST_FIT,
        ...     alignment=AllocationAlignment.KB_4,
        ...     bin=XrtId.IFM
        ... )
        >>> config_aligned.alignment.value
        4096
    """

    # Required: which block-selection strategy to use.
    strategy: AllocationStrategy
    # Read by MemoryBlock.aligned_start via ``alignment.value``; defaults to 1 byte.
    alignment: AllocationAlignment = AllocationAlignment.DEFAULT
    # Defaults to XrtId.DEFAULT; not read anywhere else in this module.
    bin: XrtId = XrtId.DEFAULT


class AllocationResult(str, Enum):
    """Result of a memory allocation operation.

    Used as the first element of the ``Alloc`` tuple alias defined in this
    module. Because the enum mixes in ``str``, every member is also a
    string equal to its value.

    Results:
        - ALLOCATED: Successfully allocated in primary memory
        - ALLOCATED_WITH_SPILLING: Allocated after spilling other tensors
        - ALLOCATED_IN_PLACE: Reused input buffer (for pointwise ops)
        - DEALLOCATED: Memory was deallocated
        - SPILLED: Tensor was spilled to secondary storage

    Examples:
        >>> AllocationResult.ALLOCATED.value
        'allocated'
        >>> AllocationResult.ALLOCATED_WITH_SPILLING
        <AllocationResult.ALLOCATED_WITH_SPILLING: 'allocated_with_spilling'>
    """

    # Member values are the lowercase form of the member names.
    ALLOCATED = "allocated"
    ALLOCATED_WITH_SPILLING = "allocated_with_spilling"
    ALLOCATED_IN_PLACE = "allocated_in_place"
    DEALLOCATED = "deallocated"
    SPILLED = "spilled"


# One allocation event: the outcome paired with the allocation it refers to.
Alloc = Tuple[AllocationResult, TensorAllocation]
# A list of allocation events.
AllocList = List[Alloc]
# Allocation events grouped under an integer key.
# NOTE(review): the key is presumably an operation/step index - confirm against callers.
AllocDict = Dict[int, AllocList]


class SubgraphAllocationMode(str, Enum):
    """Allocation mode for subgraphs in graph partitioning.

    This enum only names the modes; the code that implements them is not
    part of this class. Because the enum mixes in ``str``, every member is
    also a string equal to its value.

    Modes:
        - LOCAL: Each subgraph has its own local memory allocation
        - GLOBAL: All subgraphs share a global memory pool
        - CONTINUOUS: Subgraphs use contiguous memory without fragmentation

    Examples:
        >>> SubgraphAllocationMode.LOCAL.value
        'local'
        >>> SubgraphAllocationMode.GLOBAL
        <SubgraphAllocationMode.GLOBAL: 'global'>
        >>> list(SubgraphAllocationMode)
        [<SubgraphAllocationMode.LOCAL: 'local'>, <SubgraphAllocationMode.GLOBAL: 'global'>, <SubgraphAllocationMode.CONTINUOUS: 'continuous'>]
    """

    # Member values are the lowercase form of the member names.
    LOCAL = "local"
    GLOBAL = "global"
    CONTINUOUS = "continuous"
