"""Abstract memory allocator"""

from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple, Iterable
from graph.allocation_types import (
    AllocationConfig,
    AllocationResult,
    MemoryBlock,
    TensorAllocation,
    round_up
)
from graph.tensor_types import Operation, TensorLocation
from graph.utilities import logger


class BaseMemoryAllocator(ABC):
    """Abstract base class for tensor memory allocation strategies.

    This class defines the interface that all memory allocators must implement
    to manage tensor memory allocation and deallocation. Allocators track
    memory blocks, handle allocation strategies, and support features like
    in-place allocation and memory spilling.

    Attributes:
        alloc_config: Configuration for memory allocation behavior
        allocated_memory: Total amount of memory currently allocated
        memory_blocks: List of memory blocks being managed
        allocations: Dictionary mapping tensor IDs to their allocations
        spilled_tensor_ids: Set of tensor IDs that have been spilled
        memory_size: Total size of all memory blocks

    Examples:
        The base class is abstract, so creating an allocator requires a concrete subclass:

        >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
        >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
        >>> config = AllocationConfig(
        ...     strategy=AllocationStrategy.FIRST_FIT,
        ...     alignment=AllocationAlignment.DEFAULT
        ... )
        >>> blocks = [MemoryBlock(start=0, size=1024, is_free=True)]

        Concrete allocators inherit from BaseMemoryAllocator and implement
        abstract methods. The allocator tracks total memory size:

        >>> # This would be implemented by a concrete class
        >>> # allocator.memory_size would equal sum of block sizes
        >>> sum(b.size for b in blocks)
        1024

        Multiple memory blocks are supported:

        >>> blocks = [
        ...     MemoryBlock(start=0, size=512, is_free=True),
        ...     MemoryBlock(start=512, size=512, is_free=True)
        ... ]
        >>> sum(b.size for b in blocks)
        1024
    """

    def __init__(
        self, alloc_config: AllocationConfig, memory_blocks: List[MemoryBlock]
    ):
        self.alloc_config: AllocationConfig = alloc_config
        self.allocated_memory: int = 0
        self.memory_blocks: List[MemoryBlock] = memory_blocks
        self.allocations: Dict[str, TensorAllocation] = {}
        self.spilled_tensor_ids: Set[str] = set()
        self.memory_size: int = sum(block.size for block in self.memory_blocks)

    @abstractmethod
    def allocate(self, allocation: TensorAllocation) -> bool:
        """
        Allocate memory for a tensor.

        Abstract: the base class provides no implementation, so every concrete
        allocator must define its own placement logic here. Implementations
        can call ``_update_allocation_if_already_allocated`` first to handle
        requests for tensors that are already present in ``self.allocations``.

        Args:
            allocation: Tensor allocation request

        Returns:
            True if allocation successful, False otherwise
        """

    @abstractmethod
    def deallocate(self, tensor_id: str) -> bool:
        """
        Deallocate memory for a tensor.

        Abstract: must be implemented by every concrete allocator. Within this
        base class it is invoked by ``spill_allocation`` and
        ``free_expired_allocations``; both currently ignore the returned flag.

        Args:
            tensor_id: ID of tensor to deallocate

        Returns:
            True if deallocation successful, False otherwise
        """

    @abstractmethod
    def allocate_in_place(self, allocation: TensorAllocation, op: Operation) -> bool:
        """
        Try to allocate by overwriting other tensors.

        Abstract: must be implemented by every concrete allocator.
        Implementations can call ``_update_allocation_if_already_allocated``
        first to handle tensors that are already present in
        ``self.allocations``.
        NOTE(review): presumably ``_allocations_overwrite`` is the intended
        lifetime check for deciding which tensors may be overwritten - confirm
        against the concrete allocators.

        Args:
            allocation: Tensor allocation request
            op: Operation that produces this tensor

        Returns:
            True if in-place allocation successful, False otherwise
        """

    @abstractmethod
    def allocate_with_spilling(
        self, allocation: TensorAllocation
    ) -> Tuple[bool, List[TensorAllocation]]:
        """
        Try to allocate by spilling other tensors.

        Abstract: must be implemented by every concrete allocator; the base
        class does not choose which tensors to evict.
        NOTE(review): presumably implementations evict through
        ``spill_allocation`` so that ``spilled_tensor_ids`` stays in sync -
        confirm against the concrete allocators.

        Args:
            allocation: Tensor allocation request

        Returns:
            Tuple of (success, list of spilled allocations)
        """

    @abstractmethod
    def get_memory_usage(self) -> Dict[str, int]:
        """
        Get current memory usage statistics.

        Abstract: must be implemented by every concrete allocator; the keys of
        the returned dictionary are defined by the implementation. Within this
        base class the result is only used by ``spill_allocation``, which
        passes it to a debug log call.

        Returns:
            Dictionary with memory usage information
        """

    def get_allocated_region_bounds(self) -> Tuple[int, int]:
        """Get the bounds of the allocated memory region.

        Returns the minimum and maximum addresses of all currently allocated
        (non-free) memory blocks. This represents the span of active allocations,
        not the total available memory. The maximum address is aligned according
        to the allocator's alignment configuration.

        Returns:
            Tuple of (min_address, max_address) of allocated blocks.
            Returns (0, 0) if no blocks are allocated.

        Examples:
            Empty memory blocks return (0, 0):

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.base_memory_allocator import BaseMemoryAllocator
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT
            ... )
            >>> class TestAllocator(BaseMemoryAllocator):
            ...     def allocate(self, allocation): return True
            ...     def deallocate(self, tensor_id): return True
            ...     def allocate_in_place(self, allocation, op): return True
            ...     def allocate_with_spilling(self, allocation): return (True, [])
            ...     def get_memory_usage(self): return {}
            >>> allocator = TestAllocator(config, [])
            >>> allocator.get_allocated_region_bounds()
            (0, 0)

            With allocated blocks at [100-200] and [500-600]:

            >>> blocks = [
            ...     MemoryBlock(start=100, size=100, is_free=False),
            ...     MemoryBlock(start=500, size=100, is_free=False)
            ... ]
            >>> allocator = TestAllocator(config, blocks)
            >>> allocator.get_allocated_region_bounds()
            (100, 600)

            Free blocks are ignored:

            >>> blocks = [
            ...     MemoryBlock(start=100, size=100, is_free=False),
            ...     MemoryBlock(start=300, size=100, is_free=True),
            ...     MemoryBlock(start=500, size=100, is_free=False)
            ... ]
            >>> allocator = TestAllocator(config, blocks)
            >>> allocator.get_allocated_region_bounds()
            (100, 600)

            All free blocks return (0, 0):

            >>> blocks = [MemoryBlock(start=100, size=100, is_free=True)]
            >>> allocator = TestAllocator(config, blocks)
            >>> allocator.get_allocated_region_bounds()
            (0, 0)
        """
        if not self.memory_blocks:
            return 0, 0

        allocated_blocks = [block for block in self.memory_blocks if not block.is_free]
        if not allocated_blocks:
            return 0, 0

        min_alloc = min(block.start for block in allocated_blocks)
        max_alloc = max(block.start + block.size for block in allocated_blocks)
        return min_alloc, round_up(max_alloc, self.alloc_config.alignment.value)

    def spill_allocation(self, allocation: TensorAllocation):
        """Spill a tensor from memory to secondary storage."""
        if (
            allocation.location == TensorLocation.MEMORY
            and allocation.tensor.id in self.allocations
        ):
            logger.debug(
                "spilling %s, current allocations %s, current memory usage %s",
                allocation.tensor.id, self.allocations.values(), self.get_memory_usage()
            )

            # Free the memory
            self.deallocate(allocation.tensor.id)

            # Mark as spilled
            self.mark_allocation_as_spilled(allocation)

    def mark_allocation_as_spilled(self, allocation: TensorAllocation):
        """Mark an allocation as spilled to secondary storage.

        Updates the allocation's location to SPILLED and adds its tensor ID
        to the set of spilled tensors. This is typically called after the
        memory has been freed.

        Args:
            allocation: The tensor allocation to mark as spilled

        Examples:
            Marking an allocation as spilled:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLocation, TensorLifetime
            >>> from graph.base_memory_allocator import BaseMemoryAllocator
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT
            ... )
            >>> class TestAllocator(BaseMemoryAllocator):
            ...     def allocate(self, allocation): return True
            ...     def deallocate(self, tensor_id): return True
            ...     def allocate_in_place(self, allocation, op): return True
            ...     def allocate_with_spilling(self, allocation): return (True, [])
            ...     def get_memory_usage(self): return {}
            >>> allocator = TestAllocator(config, [])
            >>> tensor = Tensor(id="t1", shape=(2, 3), dtype="TensorProto.FLOAT")
            >>> lifetime = TensorLifetime(start=0, end=10)
            >>> alloc = TensorAllocation(tensor=tensor, range=lifetime, location=TensorLocation.MEMORY)
            >>> alloc.location = TensorLocation.MEMORY
            >>> allocator.mark_allocation_as_spilled(alloc)
            >>> alloc.location == TensorLocation.SPILLED
            True
            >>> "t1" in allocator.spilled_tensor_ids
            True
            >>>
            >>> # Multiple allocations can be spilled
            >>> tensor2 = Tensor(id="t2", shape=(3, 4), dtype="TensorProto.FLOAT")
            >>> alloc2 = TensorAllocation(tensor=tensor2, range=lifetime, location=TensorLocation.MEMORY)
            >>> alloc2.location = TensorLocation.MEMORY
            >>> allocator.mark_allocation_as_spilled(alloc2)
            >>> len(allocator.spilled_tensor_ids)
            2
            >>> "t2" in allocator.spilled_tensor_ids
            True
        """
        allocation.location = TensorLocation.SPILLED
        self.spilled_tensor_ids.add(allocation.tensor.id)

    def free_expired_allocations(self, current_location: int) -> List[TensorAllocation]:
        """Free memory from tensors whose lifetime has ended."""
        expired_allocations: List[TensorAllocation] = []

        for _, allocation in self.allocations.items():
            if (
                allocation.range.end < current_location
                and allocation.location == TensorLocation.MEMORY
            ):
                expired_allocations.append(allocation)

        for allocation in expired_allocations:
            self.deallocate(allocation.tensor.id)

        return expired_allocations

    def convert_to_aligned(
        self, allocation_results: Iterable[List[Tuple[AllocationResult, TensorAllocation]]]
    ) -> None:
        """Converts the allocation results to aligned memory addresses"""
        for results in allocation_results:
            for allocation in results:
                self.convert_alloc_to_aligned(allocation)

    def convert_alloc_to_aligned(
        self, allocation: Tuple[AllocationResult, TensorAllocation]
    ) -> None:
        """Converts an allocation result to aligned memory addresses.

        This method applies alignment to memory blocks in two different ways:

        1. Regular blocks (offset_from_parent is None):
           - Simply align the block.start address to the configured alignment
           - Example: start=100 → align(100) = 4096 (with 4KB alignment)

        2. Offset blocks (offset_from_parent is not None):
           - These are slice/split outputs that point to offsets within parent tensors
           - block.start holds the PARENT's unaligned address (set during allocation)
           - Final address = align(block.start) + offset_from_parent
           - Example: block.start=100, offset=40 → align(100) + 40 = 4096 + 40 = 4136
           - This works because align(block.start) gives the aligned parent address

        The two-phase approach (store parent address during allocation, calculate offset
        after alignment) avoids needing parent references, tensor_id lookups, or graph
        traversal. It's pure arithmetic that handles nested offsets automatically.

        Args:
            allocation: Tuple of (AllocationResult, TensorAllocation)

        Examples:
            Regular block alignment:
            >>> # Block at address 100 aligns to 4096 (4KB boundary)
            >>> # block.start: 100 → 4096

            Offset block alignment:
            >>> # Slice output: parent at 100, offset 40 bytes
            >>> # During allocation: block.start = 100, offset_from_parent = 40
            >>> # After alignment: block.start = align(100) + 40 = 4096 + 40 = 4136
        """
        result, allocation = allocation
        if result in [
            AllocationResult.DEALLOCATED,
            AllocationResult.SPILLED,
        ]:
            return

        assert allocation.block is not None, "Block must be allocated"
        assert allocation.location == TensorLocation.MEMORY, (
            "Tensor must be allocated in memory"
        )

        if allocation.tensor.size == 0:
            assert allocation.block.size == 0, f"Size must be zero for {allocation.tensor.id}"

        # Handle offset blocks (slice/split outputs)
        if allocation.block.offset_from_parent is not None:
            # This block points to an offset within a parent tensor
            # block.start currently holds the parent's unaligned address
            # Calculate: aligned_parent_address + offset
            aligned_parent_start = allocation.block.aligned_start(self.alloc_config)
            final_address = aligned_parent_start + allocation.block.offset_from_parent
            setattr(allocation.block, "start", final_address)
            return

        # Regular block - just align normally
        new_start = allocation.block.aligned_start(self.alloc_config)
        setattr(allocation.block, "start", new_start)

    def _allocations_overlap(self, t1: TensorAllocation, t2: TensorAllocation) -> bool:
        """Check if two tensor allocations have overlapping lifetimes.

        Two allocations overlap if their lifetime ranges intersect. This is used
        to determine if two tensors can share the same memory location.

        Args:
            t1: First tensor allocation
            t2: Second tensor allocation

        Returns:
            True if the lifetimes overlap, False otherwise

        Examples:
            Overlapping lifetimes [0-10] and [5-15]:

            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime
            >>> from graph.base_memory_allocator import BaseMemoryAllocator
            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT
            ... )
            >>> class TestAllocator(BaseMemoryAllocator):
            ...     def allocate(self, allocation): return True
            ...     def deallocate(self, tensor_id): return True
            ...     def allocate_in_place(self, allocation, op): return True
            ...     def allocate_with_spilling(self, allocation): return (True, [])
            ...     def get_memory_usage(self): return {}
            >>> allocator = TestAllocator(config, [])
            >>> t1 = TensorAllocation(
            ...     tensor=Tensor(id="t1", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=0, end=10), location=TensorLocation.MEMORY
            ... )
            >>> t2 = TensorAllocation(
            ...     tensor=Tensor(id="t2", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=5, end=15), location=TensorLocation.MEMORY
            ... )
            >>> allocator._allocations_overlap(t1, t2)
            True

            Non-overlapping lifetimes [0-5] and [10-15]:

            >>> t3 = TensorAllocation(
            ...     tensor=Tensor(id="t3", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=0, end=5), location=TensorLocation.MEMORY
            ... )
            >>> t4 = TensorAllocation(
            ...     tensor=Tensor(id="t4", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=10, end=15), location=TensorLocation.MEMORY
            ... )
            >>> allocator._allocations_overlap(t3, t4)
            False

            Adjacent lifetimes [0-5] and [5-10] do overlap:

            >>> t5 = TensorAllocation(
            ...     tensor=Tensor(id="t5", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=0, end=5), location=TensorLocation.MEMORY
            ... )
            >>> t6 = TensorAllocation(
            ...     tensor=Tensor(id="t6", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=5, end=10), location=TensorLocation.MEMORY
            ... )
            >>> allocator._allocations_overlap(t5, t6)
            True

            Contained lifetime [5-10] within [0-15]:

            >>> t7 = TensorAllocation(
            ...     tensor=Tensor(id="t7", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=0, end=15), location=TensorLocation.MEMORY
            ... )
            >>> t8 = TensorAllocation(
            ...     tensor=Tensor(id="t8", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=5, end=10), location=TensorLocation.MEMORY
            ... )
            >>> allocator._allocations_overlap(t7, t8)
            True
        """
        s1, e1 = t1.range.start, t1.range.end
        s2, e2 = t2.range.start, t2.range.end
        return s1 <= e2 and s2 <= e1

    def _allocations_overwrite(self, t1: TensorAllocation, t2: TensorAllocation) -> bool:
        """Check if one allocation can overwrite another (in-place allocation).

        Returns True if t1's lifetime ends exactly where t2's lifetime starts,
        allowing t2 to reuse t1's memory location. This is a key optimization
        for in-place operations.

        Args:
            t1: First tensor allocation (whose memory can be reused)
            t2: Second tensor allocation (that can reuse memory)

        Returns:
            True if t2 can overwrite t1's memory, False otherwise

        Examples:
            Perfect overwrite: t1 ends at 10, t2 starts at 10:

            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime
            >>> from graph.base_memory_allocator import BaseMemoryAllocator
            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT
            ... )
            >>> class TestAllocator(BaseMemoryAllocator):
            ...     def allocate(self, allocation): return True
            ...     def deallocate(self, tensor_id): return True
            ...     def allocate_in_place(self, allocation, op): return True
            ...     def allocate_with_spilling(self, allocation): return (True, [])
            ...     def get_memory_usage(self): return {}
            >>> allocator = TestAllocator(config, [])
            >>> t1 = TensorAllocation(
            ...     tensor=Tensor(id="t1", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=0, end=10), location=TensorLocation.MEMORY
            ... )
            >>> t2 = TensorAllocation(
            ...     tensor=Tensor(id="t2", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=10, end=20), location=TensorLocation.MEMORY
            ... )
            >>> allocator._allocations_overwrite(t1, t2)
            True

            Gap between lifetimes: t1 ends at 5, t2 starts at 10:

            >>> t3 = TensorAllocation(
            ...     tensor=Tensor(id="t3", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=0, end=5), location=TensorLocation.MEMORY
            ... )
            >>> t4 = TensorAllocation(
            ...     tensor=Tensor(id="t4", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=10, end=20), location=TensorLocation.MEMORY
            ... )
            >>> allocator._allocations_overwrite(t3, t4)
            False

            Overlap: t1 ends at 15, t2 starts at 10:

            >>> t5 = TensorAllocation(
            ...     tensor=Tensor(id="t5", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=0, end=15), location=TensorLocation.MEMORY
            ... )
            >>> t6 = TensorAllocation(
            ...     tensor=Tensor(id="t6", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=10, end=20), location=TensorLocation.MEMORY
            ... )
            >>> allocator._allocations_overwrite(t5, t6)
            False

            Reverse order doesn't work - need to recreate allocations:

            >>> t1_rev = TensorAllocation(
            ...     tensor=Tensor(id="t1r", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=0, end=10), location=TensorLocation.MEMORY
            ... )
            >>> t2_rev = TensorAllocation(
            ...     tensor=Tensor(id="t2r", shape=(2,), dtype="TensorProto.FLOAT"),
            ...     range=TensorLifetime(start=10, end=20), location=TensorLocation.MEMORY
            ... )
            >>> allocator._allocations_overwrite(t2_rev, t1_rev)
            False
        """
        s1, e1 = t1.range.start, t1.range.end
        s2, _ = t2.range.start, t2.range.end
        return s2 >= s1 and e1 == s2

    def _update_allocation_if_already_allocated(
        self, allocation: TensorAllocation
    ) -> bool:
        """
        Update allocation with existing data if tensor is already allocated.

        This is a helper function for derived classes to use in their allocate()
        and allocate_in_place() methods to handle the case where a tensor has
        already been allocated.

        Args:
            allocation: Tensor allocation request to update

        Returns:
            True if tensor was already allocated (allocation updated),
            False if tensor needs to be allocated

        Example:
            def allocate(self, allocation: TensorAllocation) -> bool:
                if self._update_allocation_if_already_allocated(allocation):
                    return True
                # ... proceed with actual allocation ...
        """
        tensor_id = allocation.tensor.id

        if tensor_id not in self.allocations:
            return False

        # Tensor already allocated - update the allocation parameter
        existing = self.allocations[tensor_id]
        allocation.block = existing.block
        allocation.location = existing.location
        allocation.is_deallocatable = existing.is_deallocatable

        logger.debug(
            "Tensor %s is already allocated, block: %s",
            tensor_id,
            allocation.block
        )

        return True
