"""
Memory allocator for tensors with lifetime analysis and spilling support.

This module provides TensorMemoryAllocator which extends BaseMemoryAllocator
with additional capabilities:
- In-place allocation for pointwise operations
- Spilling to secondary storage when memory is constrained
- Lifetime-based allocation and deallocation
"""

import os
from typing import Dict, List, Optional, Tuple

from graph.allocation_types import (
    AllocationConfig,
    AllocationStrategy,
    MemoryBlock,
    TensorAllocation,
)
from graph.base_memory_allocator import BaseMemoryAllocator
from graph.L2_fusion_tiling import POINTWISE_OPS
from graph.runtime_ops import GatherRuntimeAttrs, SliceRuntimeAttrs, ConcatRuntimeAttrs, SplitRuntimeAttrs
from graph.tensor_types import Operation, TensorLifetime, TensorLocation
from graph.utilities import logger


class TensorMemoryAllocator(BaseMemoryAllocator):
    """Allocates tensors in a fixed memory pool using lifetime analysis.

    This allocator extends BaseMemoryAllocator with advanced features:

    Features:
        - Lifetime-based allocation: Tracks when tensors are born and die
        - In-place allocation: Reuses memory for pointwise operations
        - Noop aliasing: Aliases memory for operations that don't modify data
        - Spilling support: Evicts tensors to secondary storage when needed
        - Multiple strategies: FIRST_FIT, BEST_FIT, WORST_FIT allocation
        - Fragmentation management: Merges adjacent free blocks

    Attributes:
        alloc_config: Configuration for memory allocation strategy
        allocated_memory: Total amount of memory currently allocated
        memory_blocks: List of memory blocks being managed
        allocations: Dictionary mapping tensor IDs to their allocations
        spilled_tensor_ids: Set of tensor IDs that have been spilled
        memory_size: Total size of all memory blocks

    Examples:
        Basic allocation with lifetime tracking:

        >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
        >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
        >>> from graph.allocation_types import TensorAllocation
        >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
        >>> from graph.tensor_memory_allocator import TensorMemoryAllocator
        >>>
        >>> # Create allocator with 1KB memory
        >>> config = AllocationConfig(
        ...     strategy=AllocationStrategy.FIRST_FIT,
        ...     alignment=AllocationAlignment.DEFAULT,
        ...     bin=XrtId.DEFAULT
        ... )
        >>> blocks = [MemoryBlock(start=0, size=1024, is_free=True)]
        >>> allocator = TensorMemoryAllocator(config, blocks)
        >>>
        >>> # Allocate tensor with lifetime [0-10]
        >>> t1 = Tensor(id="t1", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
        >>> alloc1 = TensorAllocation(
        ...     tensor=t1,
        ...     range=TensorLifetime(start=0, end=10),
        ...     location=TensorLocation.MEMORY
        ... )
        >>> allocator.allocate(alloc1)
        True
        >>>
        >>> # Allocate another tensor with overlapping lifetime [5-15]
        >>> t2 = Tensor(id="t2", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
        >>> alloc2 = TensorAllocation(
        ...     tensor=t2,
        ...     range=TensorLifetime(start=5, end=15),
        ...     location=TensorLocation.MEMORY
        ... )
        >>> allocator.allocate(alloc2)
        True
        >>> len(allocator.allocations)
        2
        >>>
        >>> # Deallocate first tensor when lifetime ends
        >>> allocator.deallocate("t1")
        True
        >>> len(allocator.allocations)
        1
    """

    def __reinit__(self):
        self.allocated_memory = 0
        self.allocations.clear()
        self.spilled_tensor_ids.clear()
        for block in self.memory_blocks:
            block.is_free = True
            block.tensor_id = None
        self._merge_free_blocks()

    @classmethod
    def from_allocations(
        cls,
        alloc_config: AllocationConfig, memory_blocks: List[MemoryBlock],
        allocations: Dict[str, TensorAllocation] = {}
    ):
        """Construct given allocations"""
        obj = cls(alloc_config, memory_blocks)
        obj.allocated_memory = sum(
            allocation.block.size for allocation in allocations.values())
        obj.allocations: Dict[str, TensorAllocation] = allocations
        return obj

    def allocate(self, allocation: TensorAllocation) -> bool:
        """Allocate memory for a tensor using configured strategy.

        Finds a suitable free block based on the allocation strategy
        (FIRST_FIT, BEST_FIT, or WORST_FIT) and assigns it to the tensor.
        Splits blocks when necessary to minimize waste.

        Args:
            allocation: Tensor allocation request with lifetime information

        Returns:
            True if allocation successful, False if not enough memory

        Examples:
            Different allocation strategies produce different results:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
            >>> from graph.tensor_memory_allocator import TensorMemoryAllocator
            >>>
            >>> # FIRST_FIT: Uses first block that fits
            >>> config_ff = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks_ff = [MemoryBlock(start=0, size=1024, is_free=True)]
            >>> allocator_ff = TensorMemoryAllocator(config_ff, blocks_ff)
            >>> t1 = Tensor(id="t1", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc1 = TensorAllocation(
            ...     tensor=t1,
            ...     range=TensorLifetime(start=0, end=5),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator_ff.allocate(alloc1)
            True
            >>> alloc1.block.start
            0
            >>>
            >>> # Out of memory returns False
            >>> t_large = Tensor(id="t_large", shape=(1000,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc_large = TensorAllocation(
            ...     tensor=t_large,
            ...     range=TensorLifetime(start=0, end=5),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator_ff.allocate(alloc_large)
            False
        """
        assert allocation.tensor.bin == self.alloc_config.bin, "Tensor must be in the same bin"

        if self._update_allocation_if_already_allocated(allocation):
            return True

        tensor = allocation.tensor
        block_idx = self._find_free_block(tensor.size)
        if block_idx == -1:
            return False  # Not enough memory

        # Split the block if necessary
        block = self.memory_blocks[block_idx]
        alloc_size = block.required_size(tensor.size, self.alloc_config)

        if tensor.size == 0 and alloc_size != 0:
            raise RuntimeError(f"{tensor.size} {alloc_size}")

        if block.size > alloc_size:
            # Create new block for remaining space
            new_block = MemoryBlock(
                start=block.start + alloc_size,
                size=block.size - alloc_size,
                is_free=True,
            )
            self.memory_blocks.insert(block_idx + 1, new_block)

        # Allocate the block
        block.size = alloc_size
        block.is_free = False
        block.tensor_id = tensor.id

        # update the allocation
        allocation.block = block
        allocation.location = TensorLocation.MEMORY

        if os.getenv("AIE4_FORCE_ALLOCATOR_MODE_L3") == "CONTINUOUS":
            allocation.is_deallocatable = False

        self.allocations[tensor.id] = allocation
        self.allocated_memory += alloc_size
        return True

    def allocate_in_place(self, allocation: TensorAllocation, op: Operation) -> bool:
        """Try to allocate by overwriting other tensors for pointwise operations.

        For pointwise operations (Add, Mul, etc.), the output can reuse the input's
        memory when the input's lifetime ends exactly where the output begins. This
        optimization reduces memory usage without copying data.

        Args:
            allocation: Tensor allocation request for the output
            op: Operation that produces this tensor (must be pointwise)

        Returns:
            True if in-place allocation successful, False otherwise

        Examples:
            In-place allocation for pointwise operations:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
            >>> from graph.tensor_types import Operation
            >>> from graph.tensor_memory_allocator import TensorMemoryAllocator
            >>>
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks = [MemoryBlock(start=0, size=1024, is_free=True)]
            >>> allocator = TensorMemoryAllocator(config, blocks)
            >>>
            >>> # Allocate input tensor with lifetime [0-10]
            >>> t_input = Tensor(id="input", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc_input = TensorAllocation(
            ...     tensor=t_input,
            ...     range=TensorLifetime(start=0, end=10),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate(alloc_input)
            True
            >>> input_block_start = alloc_input.block.start
            >>>
            >>> # Output tensor with lifetime [10-20] can reuse input's memory
            >>> t_output = Tensor(id="output", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc_output = TensorAllocation(
            ...     tensor=t_output,
            ...     range=TensorLifetime(start=10, end=20),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> op_add = Operation(id="add1", type="Add", inputs=["input"], outputs=["output"])
            >>> allocator.allocate_in_place(alloc_output, op_add)
            True
            >>> alloc_output.block.start == input_block_start
            True
            >>>
            >>> # Non-pointwise operations don't support in-place
            >>> t_conv = Tensor(id="conv_out", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc_conv = TensorAllocation(
            ...     tensor=t_conv,
            ...     range=TensorLifetime(start=20, end=30),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> op_conv = Operation(id="conv1", type="Conv", inputs=["output"], outputs=["conv_out"])
            >>> allocator.allocate_in_place(alloc_conv, op_conv)
            False
        """
        assert allocation.tensor.bin == self.alloc_config.bin, (
            f"Tensor must be in the same bin {self.alloc_config.bin} {allocation.tensor.bin}"
        )

        if op.type not in POINTWISE_OPS:
            return False

        if self._update_allocation_if_already_allocated(allocation):
            return True

        tensor = allocation.tensor
        if tensor.size == 0:
            return False

        # Find a tensor whose lifetime ends where the current one starts, and has same space
        overwritable_candidate = None
        for existing_allocation in self.allocations.values():
            if (
                existing_allocation.location == TensorLocation.MEMORY
                and self._allocations_overwrite(existing_allocation, allocation)
                and existing_allocation.tensor.shape == allocation.tensor.shape
                and existing_allocation.is_deallocatable
            ):
                overwritable_candidate = existing_allocation

        if overwritable_candidate is None:
            return False   # No suitable candidate

        # Update the block
        block = overwritable_candidate.block
        block.is_free = False
        block.tensor_id = tensor.id

        # update the lifetimes
        s, e = overwritable_candidate.range.start, allocation.range.end
        overwritable_candidate.range = TensorLifetime(s, e)

        # update the allocation
        allocation.block = block
        allocation.range = overwritable_candidate.range
        allocation.location = TensorLocation.MEMORY

        del self.allocations[overwritable_candidate.tensor.id]
        self.allocations[tensor.id] = allocation
        self.allocated_memory += 0
        return True

    def allocate_noop_in_place(
        self, allocation: TensorAllocation, input_tensor_id: str
    ) -> bool:
        """
        Allocate by aliasing memory from the input tensor for noop operations.

        Noop operations (e.g., Reshape_noop, Transpose_noop) do not modify data,
        so their output can alias the input's memory instead of copying. This enables
        zero-copy transformations where the output tensor simply points to the same
        memory block as the input, tracking both IDs on the block for proper lifetime
        management during deallocation.

        Args:
            allocation: The allocation for the output tensor
            input_tensor_id: The ID of the input tensor whose memory to alias

        Returns:
            True if aliasing succeeded, False otherwise

        Examples:
            Aliasing memory for no-op operations:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
            >>> from graph.tensor_memory_allocator import TensorMemoryAllocator
            >>>
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks = [MemoryBlock(start=0, size=1024, is_free=True)]
            >>> allocator = TensorMemoryAllocator(config, blocks)
            >>>
            >>> # Allocate input tensor
            >>> t_input = Tensor(id="input", shape=(10, 4), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc_input = TensorAllocation(
            ...     tensor=t_input,
            ...     range=TensorLifetime(start=0, end=20),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate(alloc_input)
            True
            >>>
            >>> # Output is just a reshape - no data copy needed
            >>> t_output = Tensor(id="output", shape=(40,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc_output = TensorAllocation(
            ...     tensor=t_output,
            ...     range=TensorLifetime(start=10, end=20),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate_noop_in_place(alloc_output, "input")
            True
            >>>
            >>> # Output aliases input's memory
            >>> alloc_output.block.start == alloc_input.block.start
            True
            >>> "output" in alloc_output.block.aliased_tensor_ids
            True
            >>>
            >>> # Input not allocated yet - fails
            >>> t_output2 = Tensor(id="output2", shape=(40,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc_output2 = TensorAllocation(
            ...     tensor=t_output2,
            ...     range=TensorLifetime(start=20, end=30),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate_noop_in_place(alloc_output2, "nonexistent")
            False
        """
        assert allocation.tensor.bin == self.alloc_config.bin, (
            f"Tensor must be in the same bin {self.alloc_config.bin} {allocation.tensor.bin}"
        )

        if self._update_allocation_if_already_allocated(allocation):
            return True

        # Find the input tensor allocation
        if input_tensor_id not in self.allocations:
            return False  # Input not allocated yet

        input_allocation = self.allocations[input_tensor_id]

        # Input must already be in memory to alias it
        if input_allocation.location != TensorLocation.MEMORY:
            return False

        # Alias the input's memory block
        block = input_allocation.block
        assert block is not None, "Input block must be allocated"

        # For noop ops, output aliases input - track both tensor IDs
        block.is_free = False
        if allocation.tensor.id not in block.aliased_tensor_ids:
            block.aliased_tensor_ids.append(allocation.tensor.id)

        # The output allocation inherits the input's block but keeps its own lifetime
        allocation.block = block
        allocation.location = TensorLocation.MEMORY

        # Add the output allocation (input remains in allocations)
        self.allocations[allocation.tensor.id] = allocation

        # No additional memory consumed - this is an alias
        self.allocated_memory += 0

        logger.debug(
            "Aliased tensor %s to input tensor %s (block at offset %s, size %s)",
            allocation.tensor.id, input_tensor_id, block.start, block.size
        )

        return True

    def allocate_with_spilling(
        self, allocation: TensorAllocation
    ) -> Tuple[bool, List[TensorAllocation]]:
        """Try to allocate by spilling other tensors to secondary storage.

        When memory is full, this method identifies tensors with overlapping lifetimes
        that can be evicted to secondary storage. Uses a greedy heuristic: spill tensors
        that end later first, as they're less likely to be needed soon. Attempts to
        spill the minimum number of tensors needed to free enough space.

        Args:
            allocation: Tensor allocation request that failed due to insufficient memory

        Returns:
            Tuple of (success: bool, spilled: List[TensorAllocation])
            - success: True if allocation succeeded after spilling
            - spilled: List of allocations that were evicted to make room

        Examples:
            Spilling tensors when memory is constrained:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
            >>> from graph.tensor_memory_allocator import TensorMemoryAllocator
            >>>
            >>> # Create small memory pool (200 bytes)
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks = [MemoryBlock(start=0, size=200, is_free=True)]
            >>> allocator = TensorMemoryAllocator(config, blocks)
            >>>
            >>> # Allocate tensor taking most of memory
            >>> t1 = Tensor(id="t1", shape=(20,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc1 = TensorAllocation(
            ...     tensor=t1,
            ...     range=TensorLifetime(start=0, end=20),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate(alloc1)
            True
            >>>
            >>> # Try to allocate overlapping tensor - requires spilling
            >>> t2 = Tensor(id="t2", shape=(30,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc2 = TensorAllocation(
            ...     tensor=t2,
            ...     range=TensorLifetime(start=10, end=30),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> success, spilled = allocator.allocate_with_spilling(alloc2)
            >>> success
            True
            >>> len(spilled) > 0  # At least one tensor was spilled
            True
            >>> spilled[0].tensor.id
            't1'
            >>>
            >>> # Not enough memory even with spilling - fails
            >>> t3 = Tensor(id="t3", shape=(100,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc3 = TensorAllocation(
            ...     tensor=t3,
            ...     range=TensorLifetime(start=20, end=40),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> success, spilled = allocator.allocate_with_spilling(alloc3)
            >>> success
            False
            >>> len(spilled)
            0
        """
        assert allocation.tensor.bin == self.alloc_config.bin, "Tensor must be in the same bin"

        # Find tensors that can be spilled (currently in memory, overlapping lifetime)
        spillable_candidates = []

        for existing_allocation in self.allocations.values():
            if (
                existing_allocation.location == TensorLocation.MEMORY
                and self._allocations_overlap(allocation, existing_allocation)
            ):
                spillable_candidates.append(existing_allocation)

        # Sort by end time (spill tensors that end later first - greedy heuristic)
        spillable_candidates.sort(key=lambda t: t.range.end, reverse=True)

        # Try spilling combinations until we have enough space
        spilled_size = 0
        allocations_to_spill = []

        for candidate in spillable_candidates:
            allocations_to_spill.append(candidate)
            spilled_size += candidate.block.size

            # Check if we now have enough space
            available_space = self.get_memory_usage()["free"] + spilled_size
            if available_space >= allocation.tensor.size:
                # Spill the selected tensors
                for spill_allocation in allocations_to_spill:
                    self.spill_allocation(spill_allocation)

                # Try allocation again
                return (self.allocate(allocation), allocations_to_spill)

        return (False, [])

    def deallocate(self, tensor_id: str) -> bool:
        """Release a tensor's allocation, respecting aliases and block ownership.

        The outcome depends on the tensor's role:
        - Not deallocatable (e.g. graph I/O): nothing changes; the entry stays
          in ``allocations``
        - Alias of a block: only its ID is dropped from the block's alias list
        - Primary owner with remaining aliases: ownership passes to the first alias
        - Primary owner without aliases: the block is freed and adjacent free
          blocks are merged

        In every deallocatable case the tensor's entry is removed from
        ``allocations``.

        Args:
            tensor_id: ID of the tensor to deallocate

        Returns:
            True if deallocation handled successfully, False if tensor not found

        Examples:
            Complex deallocation with aliases:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
            >>> from graph.tensor_memory_allocator import TensorMemoryAllocator
            >>>
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks = [MemoryBlock(start=0, size=1024, is_free=True)]
            >>> allocator = TensorMemoryAllocator(config, blocks)
            >>>
            >>> # Allocate input tensor
            >>> t_input = Tensor(id="input", shape=(10, 4), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc_input = TensorAllocation(
            ...     tensor=t_input,
            ...     range=TensorLifetime(start=0, end=20),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate(alloc_input)
            True
            >>>
            >>> # Create aliased output
            >>> t_output = Tensor(id="output", shape=(40,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc_output = TensorAllocation(
            ...     tensor=t_output,
            ...     range=TensorLifetime(start=10, end=20),
            ...     location=TensorLocation.MEMORY
            ... )
            >>> allocator.allocate_noop_in_place(alloc_output, "input")
            True
            >>>
            >>> # Deallocate aliased tensor - doesn't free memory
            >>> allocator.deallocate("output")
            True
            >>> alloc_input.block.is_free
            False
            >>>
            >>> # Deallocate primary owner - now frees memory
            >>> allocator.deallocate("input")
            True
            >>> alloc_input.block.is_free
            True
            >>>
            >>> # Non-deallocatable tensor remains in memory
            >>> t_io = Tensor(id="graph_input", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> alloc_io = TensorAllocation(
            ...     tensor=t_io,
            ...     range=TensorLifetime(start=0, end=100),
            ...     location=TensorLocation.MEMORY,
            ...     is_deallocatable=False
            ... )
            >>> allocator.allocate(alloc_io)
            True
            >>> allocator.deallocate("graph_input")  # No-op for non-deallocatable
            True
            >>> "graph_input" in allocator.allocations  # Still present
            True
        """
        if tensor_id not in self.allocations:
            return False

        allocation = self.allocations[tensor_id]
        allocated_block = allocation.block
        assert allocated_block is not None, "Block must be allocated"

        # Pinned tensors (e.g. graph I/O) keep both their block and their
        # entry in ``allocations``.
        if not allocation.is_deallocatable:
            logger.debug(
                "skipping deallocation of %s with size %s since it's not deallocatable, current memory usage: %s",
                tensor_id, allocation.tensor.size, self.get_memory_usage()
            )
            return True

        if allocation.location == TensorLocation.MEMORY:
            for candidate in self.memory_blocks:
                # Only a block at the allocation's offset can be the one we want.
                if candidate.start != allocated_block.start:
                    continue

                owns_block = candidate.tensor_id == tensor_id
                is_alias = tensor_id in candidate.aliased_tensor_ids
                if not (owns_block or is_alias):
                    continue

                # Whatever its role, the tensor must no longer be listed as an alias.
                if is_alias:
                    candidate.aliased_tensor_ids.remove(tensor_id)

                if owns_block:
                    if candidate.aliased_tensor_ids:
                        # Other tensors still use the block: promote the first
                        # alias to primary owner.
                        candidate.tensor_id = candidate.aliased_tensor_ids.pop(0)
                    else:
                        # Last user gone: release the block and coalesce.
                        candidate.is_free = True
                        candidate.tensor_id = None
                        self.allocated_memory -= candidate.size
                        self._merge_free_blocks()
                # A pure alias needs nothing beyond the removal above.
                break

        logger.debug(
            "deallocated %s with size %s, current memory usage: %s",
            tensor_id, allocation.tensor.size, self.get_memory_usage()
        )
        del self.allocations[tensor_id]
        return True

    def get_memory_usage(self) -> Dict[str, int]:
        """Summarise the pool's occupancy as a dictionary of counters.

        Returns:
            Dict with the keys ``"total_size"`` (``memory_size``),
            ``"allocated"`` (``allocated_memory``), ``"free"`` (their
            difference), ``"fragmentation"`` (number of free blocks) and
            ``"active_allocations"`` (number of entries in ``allocations``).
        """
        total = self.memory_size
        used = self.allocated_memory
        free_block_count = sum(1 for blk in self.memory_blocks if blk.is_free)
        return {
            "total_size": total,
            "allocated": used,
            "free": total - used,
            "fragmentation": free_block_count,
            "active_allocations": len(self.allocations),
        }

    def _find_free_block(self, size: int) -> int:
        """Find a suitable free block based on allocation strategy.

        Implements three strategies:
        - FIRST_FIT: Returns first block large enough (fast, may cause fragmentation)
        - BEST_FIT: Returns smallest block that fits (minimizes waste, slower)
        - WORST_FIT: Returns largest block (keeps large blocks available, may waste space)

        Args:
            size: Required size in bytes (before alignment)

        Returns:
            Index of suitable block in memory_blocks list, or -1 if none found

        Examples:
            Comparing allocation strategies:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
            >>> from graph.tensor_memory_allocator import TensorMemoryAllocator
            >>>
            >>> # Create memory with fragmented blocks
            >>> config_first = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks_first = [
            ...     MemoryBlock(start=0, size=100, is_free=True),
            ...     MemoryBlock(start=100, size=200, is_free=True),
            ...     MemoryBlock(start=300, size=150, is_free=True)
            ... ]
            >>> allocator_first = TensorMemoryAllocator(config_first, blocks_first)
            >>>
            >>> # FIRST_FIT: Returns index 0 (first block of 100 bytes)
            >>> allocator_first._find_free_block(80)
            0
            >>>
            >>> # BEST_FIT: Would return index 0 (smallest that fits)
            >>> config_best = AllocationConfig(
            ...     strategy=AllocationStrategy.BEST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks_best = [
            ...     MemoryBlock(start=0, size=100, is_free=True),
            ...     MemoryBlock(start=100, size=200, is_free=True),
            ...     MemoryBlock(start=300, size=150, is_free=True)
            ... ]
            >>> allocator_best = TensorMemoryAllocator(config_best, blocks_best)
            >>> allocator_best._find_free_block(80)
            0
            >>>
            >>> # WORST_FIT: Returns index 1 (largest block of 200 bytes)
            >>> config_worst = AllocationConfig(
            ...     strategy=AllocationStrategy.WORST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>> blocks_worst = [
            ...     MemoryBlock(start=0, size=100, is_free=True),
            ...     MemoryBlock(start=100, size=200, is_free=True),
            ...     MemoryBlock(start=300, size=150, is_free=True)
            ... ]
            >>> allocator_worst = TensorMemoryAllocator(config_worst, blocks_worst)
            >>> allocator_worst._find_free_block(80)
            1
            >>>
            >>> # No block large enough - returns -1
            >>> allocator_first._find_free_block(500)
            -1
        """
        suitable_blocks = []

        for i, block in enumerate(self.memory_blocks):
            if block.is_free and block.size >= block.required_size(size, self.alloc_config):
                suitable_blocks.append((i, block.size))

        if not suitable_blocks:
            return -1

        if self.alloc_config.strategy == AllocationStrategy.FIRST_FIT:
            return suitable_blocks[0][0]
        if self.alloc_config.strategy == AllocationStrategy.BEST_FIT:
            # Find smallest block that fits
            return min(suitable_blocks, key=lambda x: x[1])[0]
        if self.alloc_config.strategy == AllocationStrategy.WORST_FIT:
            # Find largest block that fits
            return max(suitable_blocks, key=lambda x: x[1])[0]
        return -1

    def _merge_free_blocks(self):
        """Merge adjacent free blocks to reduce fragmentation.

        Sorts memory blocks by start address, then iterates to find adjacent free blocks
        that can be combined into larger contiguous regions. This defragmentation process
        is called after deallocations to maintain memory efficiency and enable allocation
        of larger tensors.

        Examples:
            Reducing fragmentation by merging adjacent blocks:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock
            >>> from graph.allocation_types import TensorAllocation
            >>> from graph.tensor_types import Tensor, TensorLifetime, TensorLocation, XrtId
            >>> from graph.tensor_memory_allocator import TensorMemoryAllocator
            >>>
            >>> config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.DEFAULT
            ... )
            >>>
            >>> # Start with two adjacent free blocks
            >>> blocks = [
            ...     MemoryBlock(start=0, size=100, is_free=True),
            ...     MemoryBlock(start=100, size=100, is_free=True),
            ...     MemoryBlock(start=200, size=200, is_free=False, tensor_id="occupied")
            ... ]
            >>> allocator = TensorMemoryAllocator(config, blocks)
            >>> len(allocator.memory_blocks)
            3
            >>>
            >>> # Merge adjacent free blocks
            >>> allocator._merge_free_blocks()
            >>>
            >>> # Now only 2 blocks: one merged [0-200] and one occupied [200-400]
            >>> len(allocator.memory_blocks)
            2
            >>> allocator.memory_blocks[0].start
            0
            >>> allocator.memory_blocks[0].size
            200
            >>> allocator.memory_blocks[0].is_free
            True
        """
        self.memory_blocks.sort(key=lambda block: block.start)
        i = 0
        while i < len(self.memory_blocks) - 1:
            current = self.memory_blocks[i]
            next_block = self.memory_blocks[i + 1]

            if (
                current.is_free
                and next_block.is_free
                and current.start + current.size == next_block.start
            ):
                current.size += next_block.size
                self.memory_blocks.pop(i + 1)
            else:
                i += 1

    def allocate_gather_runtime(
        self,
        allocation: TensorAllocation,
        input_tensor_id: str,
        onnx_attrs: dict
    ) -> bool:
        """Allocate gather_runtime using max-size aliasing with axis normalization.

        Args:
            allocation: Output tensor allocation
            input_tensor_id: ID of input tensor to alias
            onnx_attrs: ONNX attributes containing:
                - axis: int - Gather axis (validated against input shape)

        Returns:
            True if allocation succeeded, False otherwise
        """
        assert allocation.tensor.bin == self.alloc_config.bin, (
            f"Tensor must be in the same bin {self.alloc_config.bin} {allocation.tensor.bin}"
        )

        if self._update_allocation_if_already_allocated(allocation):
            return True

        # Find the input tensor allocation
        if input_tensor_id not in self.allocations:
            return False  # Input not allocated yet

        input_allocation = self.allocations[input_tensor_id]

        # Input must already be in memory to alias it
        if input_allocation.location != TensorLocation.MEMORY:
            return False

        input_shape = input_allocation.tensor.shape

        _ = GatherRuntimeAttrs.build(onnx_attrs, input_shape)

        # Validate size constraint: output must not exceed input size
        if allocation.tensor.size > input_allocation.tensor.size:
            raise ValueError(
                f"gather_runtime: output size ({allocation.tensor.size}) "
                f"exceeds input size ({input_allocation.tensor.size}). "
                f"Input shape: {input_allocation.tensor.shape}, Output shape: {allocation.tensor.shape}"
            )

        block = input_allocation.block
        assert block is not None, "Input block must be allocated"

        block.is_free = False
        if allocation.tensor.id not in block.aliased_tensor_ids:
            block.aliased_tensor_ids.append(allocation.tensor.id)

        allocation.block = block
        allocation.location = TensorLocation.MEMORY

        self.allocations[allocation.tensor.id] = allocation
        return True

    def allocate_slice_runtime(
        self,
        allocation: TensorAllocation,
        input_tensor_id: str,
        onnx_attrs: dict
    ) -> bool:
        """Allocate slice_runtime using offset-based aliasing with axis normalization.

        The output tensor points to an offset within the input tensor's buffer.
        Supports any axis when all dimensions before the slice axis are unity (1).

        Implementation uses two-phase address calculation to handle alignment:
        - Store parent's unaligned address in output_block.start
        - Store byte offset in output_block.offset_from_parent
        - During alignment: final_address = align(block.start) + offset_from_parent

        See MemoryBlock.offset_from_parent documentation in allocation_types.py for
        detailed explanation of the two-phase address calculation approach.

        Args:
            allocation: Output tensor allocation
            input_tensor_id: ID of input tensor to slice from
            onnx_attrs: ONNX attributes containing:
                - starts: list[int] - Start indices (single element)
                - axes: list[int] - Slice axis (validated against input shape)
                - ends: list[int] - End indices (optional)
                - steps: list[int] - Step sizes (optional)

        Returns:
            True if allocation succeeded, False otherwise

        Raises:
            ValueError: If slice validation fails (invalid axes, starts, or size mismatch)

        Examples:
            PSMU_ST1 KV-cache: axes=[2] (W dimension) with N=1, H=1:
            - Input: [1, 1, 64, 64], starts=[1], ends=[64]
            - Output: [1, 1, 63, 64]
            - Equivalent to axis=0 slicing in flattened memory
        """
        assert allocation.tensor.bin == self.alloc_config.bin, (
            f"Tensor must be in the same bin {self.alloc_config.bin} {allocation.tensor.bin}"
        )

        if self._update_allocation_if_already_allocated(allocation):
            return True

        # Find the input tensor allocation
        if input_tensor_id not in self.allocations:
            return False  # Input not allocated yet

        input_allocation = self.allocations[input_tensor_id]

        # Input must already be in memory to alias it
        if input_allocation.location != TensorLocation.MEMORY:
            return False

        # Get input shape directly from ONNX-inferred shape (already 4D NHWC)
        input_shape = input_allocation.tensor.shape

        # Parse and validate slice attributes with shape-aware axis normalization
        attrs = SliceRuntimeAttrs.build(onnx_attrs, input_shape)

        # Get start index and axis
        start_idx = attrs.starts[0]
        axis = attrs.axes[0]

        # Normalize axis to positive index
        normalized_axis = axis if axis >= 0 else len(input_shape) + axis

        # Validate size constraint: output must not exceed input size
        if allocation.tensor.size > input_allocation.tensor.size:
            raise ValueError(
                f"slice_runtime: output size ({allocation.tensor.size}) "
                f"exceeds input size ({input_allocation.tensor.size}). "
                f"Input shape: {input_shape}, Output shape: {allocation.tensor.shape}"
            )

        # Get input block and calculate offset
        input_block = input_allocation.block
        assert input_block is not None, "Input block must be allocated"

        batch_shape = input_shape[normalized_axis + 1:] if normalized_axis + 1 < len(input_shape) else []
        if batch_shape:
            size_per_batch = input_allocation.tensor.update(shape=batch_shape).size
        else:
            # Slicing along last dimension - each element is one unit
            size_per_batch = input_allocation.tensor.update(shape=[1]).size // input_shape[-1] if input_shape[-1] > 0 else 0

        # Byte offset = start_index × size_per_batch
        byte_offset = start_idx * size_per_batch

        # If input itself is a slice (has offset_from_parent), accumulate offsets
        # This handles chained slicing: slice(slice(tensor))
        if input_block.offset_from_parent is not None:
            accumulated_offset = input_block.offset_from_parent + byte_offset
        else:
            accumulated_offset = byte_offset

        # CRITICAL: Store parent's unaligned address in block.start (NOT parent + offset)
        # This allows convert_alloc_to_aligned to calculate: align(block.start) + offset
        # See docstring for detailed explanation of two-phase address calculation
        output_block = MemoryBlock(
            start=input_block.start,  # Parent's address (will be aligned later)
            size=allocation.tensor.size,
            is_free=False,
            tensor_id=allocation.tensor.id,
            aliased_tensor_ids=[],
            offset_from_parent=accumulated_offset,  # Accumulated byte offset from original parent
        )

        # Mark the input block as having an aliased tensor
        if allocation.tensor.id not in input_block.aliased_tensor_ids:
            input_block.aliased_tensor_ids.append(allocation.tensor.id)

        # The output allocation uses the offset block
        allocation.block = output_block
        allocation.location = TensorLocation.MEMORY

        # Store the allocation
        self.allocations[allocation.tensor.id] = allocation

        logger.debug(
            "Allocated slice_runtime for %s aliasing %s: byte_offset=%d, output_size=%d",
            allocation.tensor.id, input_tensor_id, byte_offset, allocation.tensor.size
        )

        return True

    def allocate_split_runtime(
        self,
        allocations: List[TensorAllocation],
        input_tensor_id: str,
        onnx_attrs: dict,
    ) -> bool:
        """Allocate split_runtime using contiguous output views with axis normalization.

        Splits an input tensor into multiple output tensors. Supports any axis when all
        dimensions before the split axis are unity (1).

        Args:
            allocations: List of output tensor allocations (in order)
            input_tensor_id: ID of input tensor to split from
            onnx_attrs: ONNX attributes containing:
                - split: list[int] - Dimension sizes for each output along split axis
                - axis: int - Split axis (validated against input shape)

        Returns:
            True if allocation succeeded, False otherwise

        Raises:
            ValueError: If split validation fails
        """
        # Validate all allocations are in same bin and check if already allocated
        for allocation in allocations:
            assert allocation.tensor.bin == self.alloc_config.bin, (
                f"Tensor must be in the same bin {self.alloc_config.bin} {allocation.tensor.bin}"
            )

            if self._update_allocation_if_already_allocated(allocation):
                # If any output already allocated, can't do split aliasing
                # All outputs must be unallocated to alias into input buffer
                return False

        if input_tensor_id not in self.allocations:
            return False  # Input not allocated yet

        input_allocation = self.allocations[input_tensor_id]

        # Verify input is in memory
        if input_allocation.location != TensorLocation.MEMORY:
            return False

        input_shape = input_allocation.tensor.shape
        attrs = SplitRuntimeAttrs.build(onnx_attrs, input_shape)

        axis = attrs.axis
        normalized_axis = axis if axis >= 0 else len(input_shape) + axis

        if sum(attrs.split) != input_shape[normalized_axis]:
            raise ValueError(
                f"split_runtime: split values sum to {sum(attrs.split)} "
                f"but input dimension at axis={axis} is {input_shape[normalized_axis]}"
            )

        batch_shape = input_shape[normalized_axis + 1:] if normalized_axis + 1 < len(input_shape) else []
        if batch_shape:
            size_per_batch = input_allocation.tensor.update(shape=batch_shape).size
        else:
            size_per_batch = input_allocation.tensor.update(shape=[1]).size // input_shape[-1] if input_shape[-1] > 0 else 0

        output_sizes = [split_val * size_per_batch for split_val in attrs.split]

        if sum(output_sizes) != input_allocation.tensor.size:
            raise ValueError(
                f"split_runtime: output sizes sum to {sum(output_sizes)} "
                f"but input size is {input_allocation.tensor.size}"
            )

        # Create output blocks with offsets
        input_block = input_allocation.block
        assert input_block is not None, "Input block must be allocated"
        cumulative_offset = 0

        for i, allocation in enumerate(allocations):
            # Handle chained operations: if input itself has offset, accumulate
            if input_block.offset_from_parent is not None:
                accumulated_offset = input_block.offset_from_parent + cumulative_offset
            else:
                accumulated_offset = cumulative_offset

            output_block = MemoryBlock(
                start=input_block.start,  # Parent's address (will be aligned later)
                size=output_sizes[i],
                is_free=False,
                tensor_id=allocation.tensor.id,
                aliased_tensor_ids=[],
                offset_from_parent=accumulated_offset,
            )
            cumulative_offset += output_sizes[i]

            # Mark input as having aliased tensor
            if allocation.tensor.id not in input_block.aliased_tensor_ids:
                input_block.aliased_tensor_ids.append(allocation.tensor.id)

            # Store the allocation
            allocation.block = output_block
            allocation.location = TensorLocation.MEMORY
            self.allocations[allocation.tensor.id] = allocation

        logger.debug(
            "Allocated split_runtime for %d outputs from %s: split=%s, axis=%d",
            len(allocations), input_tensor_id, attrs.split, axis
        )

        return True

    def allocate_concat_runtime(
        self,
        allocations: List[TensorAllocation],
        output_tensor_id: str,
        onnx_attrs: dict,
    ) -> bool:
        """Allocate concat_runtime using contiguous input layout with axis normalization.

        Concatenates multiple input tensors by allocating a single output buffer and
        partitioning it into contiguous input views. Supports any axis when all dimensions
        before the concat axis are unity (1), making it equivalent to axis=0 in memory layout.

        This method is called during the REVERSE PASS (after output allocation), ensuring
        concat inputs are laid out contiguously in the output buffer.

        Axis Normalization:
            When all leading dimensions are 1, concat along any axis has the same memory
            layout as axis=0 (flattened concatenation). Examples:
            - axis=-2 with shapes [(1,1,63,64), (1,1,1,64)] → Valid (N=1, H=1)
            - axis=1 with shapes [(1,10,20,64), (1,5,20,64)] → Valid (N=1)
            - axis=-1 with shapes [(1,1,1,63), (1,1,1,1)] → Valid (N=1, H=1, W=1)

        Memory Layout (conceptual for axis=0):
            Output: [N_total, H, W, C] where N_total = sum(N_i)
            Input 0: points to output[0:N_0, :, :, :]
            Input 1: points to output[N_0:N_0+N_1, :, :, :]
            ...

        Args:
            allocations: List of input tensor allocations (in concat order)
            output_tensor_id: ID of the output tensor (already allocated)
            onnx_attrs: ONNX attributes containing:
                - axis: int - Concatenation axis (validated against input shapes)

        Returns:
            True if allocation succeeded, False otherwise

        Raises:
            ValueError: If concat validation fails:
                - Axis cannot be normalized to axis=0 (non-unity leading dimensions)
                - Total input size doesn't match output size
                - Invalid axis range for 4D tensors
        """
        # Validate all input allocations are in same bin
        for allocation in allocations:
            assert allocation.tensor.bin == self.alloc_config.bin, (
                f"Tensor must be in the same bin {self.alloc_config.bin} {allocation.tensor.bin}"
            )

        # Check if any input already allocated (can't do concat aliasing)
        for allocation in allocations:
            if self._update_allocation_if_already_allocated(allocation):
                # If any input already allocated, can't do concat aliasing
                # All inputs must be unallocated to alias into output buffer
                return False

        # Find output tensor allocation (return False if not allocated yet)
        if output_tensor_id not in self.allocations:
            return False  # Output not allocated yet

        output_allocation = self.allocations[output_tensor_id]

        # Verify output is in memory
        if output_allocation.location != TensorLocation.MEMORY:
            return False

        # Get input shapes directly from ONNX-inferred shapes (already 4D NHWC)
        input_shapes = [allocation.tensor.shape for allocation in allocations]

        # Parse and validate concat attributes with shape-aware axis normalization
        # This validates that the axis can be normalized to axis=0
        _ = ConcatRuntimeAttrs.build(onnx_attrs, input_shapes)

        # Calculate physical sizes by flattening each input tensor
        input_sizes = [alloc.tensor.size for alloc in allocations]

        # Verify total input size equals output size
        output_size = output_allocation.tensor.size
        if sum(input_sizes) != output_size:
            raise ValueError(
                f"concat_runtime: input sizes sum to {sum(input_sizes)} bytes "
                f"but output size is {output_size} bytes. "
                f"Input shapes: {input_shapes}, output shape: {output_allocation.tensor.shape}"
            )

        # Create input blocks with offsets into output
        output_block = output_allocation.block
        assert output_block is not None, "Output block must be allocated"
        cumulative_offset = 0

        for i, allocation in enumerate(allocations):
            # Handle chained operations: if output itself has offset, accumulate
            if output_block.offset_from_parent is not None:
                accumulated_offset = output_block.offset_from_parent + cumulative_offset
            else:
                accumulated_offset = cumulative_offset

            input_block = MemoryBlock(
                start=output_block.start,  # Parent's address (will be aligned later)
                size=input_sizes[i],
                is_free=False,
                tensor_id=allocation.tensor.id,
                aliased_tensor_ids=[],
                offset_from_parent=accumulated_offset,
            )
            cumulative_offset += input_sizes[i]

            # Mark output as having aliased tensor
            if allocation.tensor.id not in output_block.aliased_tensor_ids:
                output_block.aliased_tensor_ids.append(allocation.tensor.id)

            # Store the allocation
            allocation.block = input_block
            allocation.location = TensorLocation.MEMORY
            self.allocations[allocation.tensor.id] = allocation

        logger.debug(
            "Allocated concat_runtime for %d inputs → %s",
            len(allocations), output_tensor_id
        )

        return True
