"""
Continuous Memory Allocator - allocates tensors in a single contiguous memory block.
This allocator does not reclaim or merge deallocated regions and uses a simple bump allocation strategy.
"""

from typing import Dict, List, Tuple

from graph.allocation_types import (
    MemoryBlock,
    TensorAllocation,
)
from graph.base_memory_allocator import BaseMemoryAllocator
from graph.tensor_types import Operation, TensorLocation
from graph.utilities import logger


class ContinuousMemoryAllocator(BaseMemoryAllocator):
    """Continuous memory allocator using bump allocation strategy.

    This allocator places tensors sequentially in a single contiguous
    memory block using a simple bump allocation strategy. Memory released
    by ``deallocate`` is not made available for reuse, and in-place
    allocation and spilling are not supported.

    Attributes:
        alloc_config: Configuration for memory allocation
        allocated_memory: Total amount of memory currently allocated
        memory_blocks: List of memory blocks being managed
        allocations: Dictionary mapping tensor IDs to their allocations
        spilled_tensor_ids: Set of tensor IDs that have been spilled
        memory_size: Total size of all memory blocks

    Examples:
        Creating a continuous allocator:

        >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
        >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
        >>> from graph.tensor_types import XrtId
        >>> from graph.continuous_memory_allocator import ContinuousMemoryAllocator
        >>> config = AllocationConfig(
        ...     strategy=AllocationStrategy.FIRST_FIT,
        ...     alignment=AllocationAlignment.DEFAULT,
        ...     bin=XrtId.DEFAULT
        ... )
        >>> blocks = [MemoryBlock(start=0, size=1024, is_free=True)]
        >>> allocator = ContinuousMemoryAllocator(config, blocks)
        >>> allocator.memory_size
        1024
        >>> len(allocator.allocations)
        0

        Allocating tensors sequentially:

        >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation
        >>> from graph.allocation_types import TensorAllocation
        >>> tensor1 = Tensor(id="t1", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
        >>> alloc1 = TensorAllocation(
        ...     tensor=tensor1,
        ...     range=TensorLifetime(start=0, end=5),
        ...     location=TensorLocation.MEMORY
        ... )
        >>> allocator.allocate(alloc1)
        True
        >>> alloc1.block.start
        0
        >>> alloc1.location == TensorLocation.MEMORY
        True
    """

    def reset(self) -> None:
        """Reset the allocator to initial state.

        Clears all allocations and marks all memory blocks as free.
        This is useful for reusing the allocator with a fresh state.

        Examples:
            Resetting after allocations:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
            >>> from graph.continuous_memory_allocator import ContinuousMemoryAllocator
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks = [MemoryBlock(start=0, size=1024, is_free=True)]
            >>> allocator = ContinuousMemoryAllocator(config, blocks)
            >>> tensor = Tensor(id="t1", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc = TensorAllocation(
            ...     tensor=tensor,
            ...     range=TensorLifetime(start=0, end=5),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate(alloc)
            True
            >>> len(allocator.allocations)
            1
            >>> allocator.reset()
            >>> len(allocator.allocations)
            0
            >>> allocator.allocated_memory
            0
            >>> all(block.is_free for block in allocator.memory_blocks)
            True
        """
        self.allocated_memory = 0
        self.allocations = {}
        self.spilled_tensor_ids = set()
        for block in self.memory_blocks:
            block.is_free = True
            block.tensor_id = None

    def allocate(self, allocation: TensorAllocation) -> bool:
        """Allocate memory for a tensor using continuous allocation.

        Uses a bump allocator strategy to sequentially allocate memory.
        Finds the first free block large enough to fit the tensor and
        splits it if necessary.

        Args:
            allocation: Tensor allocation request

        Returns:
            True if allocation successful, False if not enough space

        Examples:
            Successful allocation:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
            >>> from graph.continuous_memory_allocator import ContinuousMemoryAllocator
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks = [MemoryBlock(start=0, size=1024, is_free=True)]
            >>> allocator = ContinuousMemoryAllocator(config, blocks)
            >>> tensor = Tensor(id="t1", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc = TensorAllocation(
            ...     tensor=tensor,
            ...     range=TensorLifetime(start=0, end=5),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate(alloc)
            True
            >>> alloc.block.is_free
            False
            >>> alloc.block.tensor_id
            't1'
            >>>
            >>> # Allocating multiple tensors sequentially
            >>> tensor2 = Tensor(id="t2", shape=(20,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc2 = TensorAllocation(
            ...     tensor=tensor2,
            ...     range=TensorLifetime(start=5, end=10),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate(alloc2)
            True
            >>> len(allocator.allocations)
            2
            >>>
            >>> # Allocation fails when out of memory
            >>> tensor_large = Tensor(id="t3", shape=(1000,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc_large = TensorAllocation(
            ...     tensor=tensor_large,
            ...     range=TensorLifetime(start=10, end=15),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate(alloc_large)
            False
        """
        assert allocation.tensor.bin == self.alloc_config.bin, (
            "Tensor must be in the same bin"
        )

        if self._update_allocation_if_already_allocated(allocation):
            return True

        tensor = allocation.tensor
        block_idx = self._find_free_block(tensor.size)
        if block_idx == -1:
            return False  # Not enough memory

        # Split the block if necessary
        block = self.memory_blocks[block_idx]
        alloc_size = block.required_size(tensor.size, self.alloc_config)

        if tensor.size == 0 and alloc_size != 0:
            raise RuntimeError(f"{tensor.size} {alloc_size}")

        if block.size > alloc_size:
            # Create new block for remaining space
            new_block = MemoryBlock(
                start=block.start + alloc_size,
                size=block.size - alloc_size,
                is_free=True,
            )
            self.memory_blocks.insert(block_idx + 1, new_block)

        # Allocate the block
        block.size = alloc_size
        block.is_free = False
        block.tensor_id = tensor.id

        # update the allocation
        allocation.block = block
        allocation.location = TensorLocation.MEMORY

        self.allocations[tensor.id] = allocation
        self.allocated_memory += alloc_size
        return True

    def _find_free_block(self, size: int) -> int:
        """Find the first free block that can fit the requested size.

        Implements first-fit strategy for the bump allocator.

        Args:
            size: Required size in bytes

        Returns:
            Index of suitable block, or -1 if none found

        Examples:
            Finding free blocks:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.tensor_types import XrtId
            >>> from graph.continuous_memory_allocator import ContinuousMemoryAllocator
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks = [
            ...     MemoryBlock(start=0, size=100, is_free=False),
            ...     MemoryBlock(start=100, size=200, is_free=True),
            ...     MemoryBlock(start=300, size=300, is_free=True)
            ... ]
            >>> allocator = ContinuousMemoryAllocator(config, blocks)
            >>> allocator._find_free_block(50)
            1
            >>> allocator._find_free_block(250)
            2
            >>> allocator._find_free_block(500)
            -1
        """
        for i, block in enumerate(self.memory_blocks):
            if block.is_free and block.size >= size:
                return i
        return -1

    def deallocate(self, tensor_id: str) -> bool:
        """Deallocate memory for a tensor.

        In continuous allocator, we remove the allocation from tracking
        but don't merge blocks to maintain the continuous allocation pattern.
        Non-deallocatable tensors are skipped.

        Args:
            tensor_id: ID of tensor to deallocate

        Returns:
            True if deallocation successful or skipped, False if tensor not found

        Examples:
            Deallocating tensors:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
            >>> from graph.continuous_memory_allocator import ContinuousMemoryAllocator
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks = [MemoryBlock(start=0, size=1024, is_free=True)]
            >>> allocator = ContinuousMemoryAllocator(config, blocks)
            >>> tensor = Tensor(id="t1", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc = TensorAllocation(
            ...     tensor=tensor,
            ...     range=TensorLifetime(start=0, end=5),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate(alloc)
            True
            >>> "t1" in allocator.allocations
            True
            >>> allocator.deallocate("t1")
            True
            >>> "t1" in allocator.allocations
            False
            >>>
            >>> # Deallocating non-existent tensor returns False
            >>> allocator.deallocate("nonexistent")
            False
            >>>
            >>> # Non-deallocatable tensors are skipped
            >>> tensor2 = Tensor(id="t2", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc2 = TensorAllocation(
            ...     tensor=tensor2,
            ...     range=TensorLifetime(start=5, end=10),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> alloc2.is_deallocatable = False
            >>> allocator.allocate(alloc2)
            True
            >>> allocator.deallocate("t2")
            True
            >>> "t2" in allocator.allocations
            True
        """
        if tensor_id not in self.allocations:
            return False

        allocation = self.allocations[tensor_id]
        if not allocation.is_deallocatable:
            logger.debug(
                "skipping deallocation of %s with size %s since it's not deallocatable, current memory usage: %s",
                tensor_id, allocation.tensor.size, self.get_memory_usage()
            )
            return True

        allocated_block = allocation.block
        assert allocated_block is not None, "Block must be allocated"

        logger.debug(
            "Deallocated %s with size %d, current memory usage: %s",
            tensor_id,
            allocation.tensor.size,
            self.get_memory_usage(),
        )

        # Remove from active allocations
        del self.allocations[tensor_id]
        return True

    def allocate_in_place(self, allocation: TensorAllocation, op: Operation) -> bool:
        """Reject an in-place allocation request.

        The continuous allocator never places a tensor over existing ones;
        the request is logged at debug level and refused.

        Args:
            allocation: Tensor allocation request; only its tensor ID is
                read, for the log message.
            op: Operation that produces this tensor (unused here).

        Returns:
            Always False - not supported in continuous allocator
        """
        tensor_id = allocation.tensor.id
        logger.debug(
            "In-place allocation not supported in continuous allocator for %s",
            tensor_id,
        )
        return False

    def allocate_with_spilling(
        self, allocation: TensorAllocation
    ) -> Tuple[bool, List[TensorAllocation]]:
        """Reject an allocation request that would need spilling.

        The continuous allocator never evicts tensors to make room; the
        request is logged at debug level and refused.

        Args:
            allocation: Tensor allocation request; only its tensor ID is
                read, for the log message.

        Returns:
            Always ``(False, [])``: the request failed and no tensors were
            spilled.
        """
        tensor_id = allocation.tensor.id
        logger.debug(
            "Spilling allocation not supported in continuous allocator for %s",
            tensor_id,
        )
        # A fresh empty list on every call, so callers may mutate it safely.
        return False, []

    def get_memory_usage(self) -> Dict[str, int]:
        """Get current memory usage statistics.

        Provides detailed information about memory utilization including
        total size, allocated space, free space, fragmentation, and the
        number of active allocations.

        Returns:
            Dictionary with memory usage information:

            - total_size: Total memory pool size
            - allocated: Currently allocated memory
            - free: Available memory for new allocations
            - fragmentation: Number of free blocks
            - active_allocations: Number of active tensor allocations

        Examples:
            Tracking memory usage:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
            >>> from graph.continuous_memory_allocator import ContinuousMemoryAllocator
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks = [MemoryBlock(start=0, size=1024, is_free=True)]
            >>> allocator = ContinuousMemoryAllocator(config, blocks)
            >>> usage = allocator.get_memory_usage()
            >>> usage['total_size']
            1024
            >>> usage['allocated']
            0
            >>> usage['active_allocations']
            0
            >>>
            >>> # After allocation
            >>> tensor = Tensor(id="t1", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc = TensorAllocation(
            ...     tensor=tensor,
            ...     range=TensorLifetime(start=0, end=5),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate(alloc)
            True
            >>> usage = allocator.get_memory_usage()
            >>> usage['active_allocations']
            1
            >>> usage['allocated'] > 0
            True
            >>> usage['free'] < 1024
            True
        """
        free_space = sum(block.size for block in self.memory_blocks if block.is_free)
        fragmentation = len([b for b in self.memory_blocks if b.is_free])

        return {
            "total_size": self.memory_size,
            "allocated": self.allocated_memory,
            "free": free_space,
            "fragmentation": fragmentation,
            "active_allocations": len(self.allocations),
        }
