"""
This module performs multi-bin memory scheduling for a graph.
Tensors are allocated to different bins.
"""

from __future__ import annotations
from collections import defaultdict, Counter
from typing import Dict, List, Tuple

# pylint: disable=import-error,redefined-outer-name,no-name-in-module
from graph.allocation_types import (
    AllocationConfig,
    AllocationResult,
    MemoryBlock,
    TensorAllocation,
    TensorLocation,
    XrtId
)
from graph.base_memory_allocator import BaseMemoryAllocator
from graph.graph_ops import GraphOps
from graph.L3_fusion_tiling import Builder
from graph.runtime_ops import RuntimeOpType
from graph.tensor_memory_allocator import TensorMemoryAllocator
from graph.tensor_types import Operation, TensorLifetime
from graph.utilities import logger

# One allocation location: the AllocationConfig describing a bin, paired with the
# list of MemoryBlock objects handed to that bin's allocator (see
# MultiBinGraphMemoryScheduler.build, which consumes a list of these, one per bin).
AllocLocn = Tuple[AllocationConfig, List[MemoryBlock]]


class MultiBinGraphMemoryScheduler(GraphOps):
    """Allocates memory across multiple memory bins for heterogeneous memory architectures.

    Extends GraphOps to support multi-bin memory allocation where tensors can reside
    in different memory regions (e.g., L2, L3, DDR) with different characteristics.
    Each bin has its own allocator with independent memory management strategies.

    Features:
        - Multi-bin allocation: Tensors allocated to different memory bins based on their bin assignment
        - Per-bin allocators: Each bin has dedicated TensorMemoryAllocator with custom configuration
        - Cross-bin deallocation: Frees expired tensors across all bins at each time step
        - Noop optimization: Supports in-place aliasing for noop operations across bins
        - Backward propagation: Pre-allocated outputs propagate backward through noop chains
        - Allocation tracking: Records all allocations/deallocations per time step for analysis

    Attributes:
        allocators: Dictionary mapping XrtId (bin) to BaseMemoryAllocator instances

    Examples:
        Multi-bin memory scheduling with IFM and OFM bins:

        >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
        >>> from graph.allocation_types import AllocationStrategy, MemoryBlock, XrtId
        >>> from graph.tensor_types import Tensor, Operation
        >>> from graph.multibin_allocator import MultiBinGraphMemoryScheduler
        >>> from graph.tensor_memory_allocator import TensorMemoryAllocator
        >>>
        >>> # Create allocation configurations for two bins
        >>> ifm_config = AllocationConfig(
        ...     strategy=AllocationStrategy.FIRST_FIT,
        ...     alignment=AllocationAlignment.DEFAULT,
        ...     bin=XrtId.IFM
        ... )
        >>> ofm_config = AllocationConfig(
        ...     strategy=AllocationStrategy.FIRST_FIT,
        ...     alignment=AllocationAlignment.DEFAULT,
        ...     bin=XrtId.OFM
        ... )
        >>>
        >>> # Create memory blocks for each bin
        >>> ifm_blocks = [MemoryBlock(start=0, size=1024, is_free=True)]
        >>> ofm_blocks = [MemoryBlock(start=0, size=4096, is_free=True)]
        >>>
        >>> # Build scheduler with multiple bins
        >>> allocators = {
        ...    ifm_config.bin: TensorMemoryAllocator(ifm_config, ifm_blocks),
        ...    ofm_config.bin: TensorMemoryAllocator(ofm_config, ofm_blocks)
        ... }
        >>> scheduler = MultiBinGraphMemoryScheduler(allocators)
        >>>
        >>> # Verify bins are configured
        >>> XrtId.IFM in scheduler.allocators
        True
        >>> XrtId.OFM in scheduler.allocators
        True
        >>>
        >>> # Add operations and tensors
        >>> t1 = Tensor(id="input", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.IFM)
        >>> t2 = Tensor(id="output", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.OFM)
        >>> op = Operation(id="op1", type="Add", inputs=["input"], outputs=["output"])
        >>> scheduler.add_tensor(t1)
        >>> scheduler.add_tensor(t2)
        >>> scheduler.add_operation(op)
        0
        >>>
        >>> # Schedule memory across bins
        >>> results = scheduler.schedule_memory()
        >>> len(results) > 0
        True
    """

    @staticmethod
    def build(alloc_locns: List[AllocLocn], execution_order: List[Operation] | None = None) -> MultiBinGraphMemoryScheduler:
        """Build an instance of multi-bin graph scheduler given allocation locations.

        Factory method that creates a MultiBinGraphMemoryScheduler with one
        TensorMemoryAllocator per entry of ``alloc_locns``, keyed by the bin of the
        entry's AllocationConfig. Bins must be unique: two entries with the same bin
        would otherwise silently overwrite each other in the allocator mapping.

        Args:
            alloc_locns: List of (AllocationConfig, MemoryBlock list) tuples, one per bin
            execution_order: Optional pre-computed execution order for operations

        Returns:
            MultiBinGraphMemoryScheduler instance with configured allocators

        Raises:
            AssertionError: If duplicate bins found in alloc_locns. The error is raised
                explicitly, so the check is still enforced when Python runs with ``-O``.

        Examples:
            Building scheduler with multiple bins:

            >>> from graph.allocation_types import AllocationConfig, AllocationAlignment
            >>> from graph.allocation_types import AllocationStrategy, MemoryBlock, XrtId
            >>> from graph.tensor_types import Tensor, Operation
            >>> from graph.multibin_allocator import MultiBinGraphMemoryScheduler
            >>>
            >>> # Create allocation configurations for two bins
            >>> ifm_config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.IFM
            ... )
            >>> ofm_config = AllocationConfig(
            ...     strategy=AllocationStrategy.FIRST_FIT,
            ...     alignment=AllocationAlignment.DEFAULT,
            ...     bin=XrtId.OFM
            ... )
            >>>
            >>> # Create memory blocks for each bin
            >>> ifm_blocks = [MemoryBlock(start=0, size=1024, is_free=True)]
            >>> ofm_blocks = [MemoryBlock(start=0, size=4096, is_free=True)]
            >>>
            >>> # Build scheduler with multiple bins
            >>> alloc_locations = [(ifm_config, ifm_blocks), (ofm_config, ofm_blocks)]
            >>> scheduler = MultiBinGraphMemoryScheduler.build(alloc_locations)
            >>>
            >>> # Verify bins are configured
            >>> XrtId.IFM in scheduler.allocators
            True
            >>> XrtId.OFM in scheduler.allocators
            True
            >>>
            >>> # Add operations and tensors
            >>> t1 = Tensor(id="input", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.IFM)
            >>> t2 = Tensor(id="output", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.OFM)
            >>> op = Operation(id="op1", type="Add", inputs=["input"], outputs=["output"])
            >>> scheduler.add_tensor(t1)
            >>> scheduler.add_tensor(t2)
            >>> scheduler.add_operation(op)
            0
            >>>
            >>> # Schedule memory across bins
            >>> results = scheduler.schedule_memory()
            >>> len(results) > 0
            True
        """
        # Count how often each bin occurs so the error can name the offenders.
        bin_counts = Counter(alloc_config.bin for alloc_config, _ in alloc_locns)
        duplicate_bins = [bin_id for bin_id, count in bin_counts.items() if count > 1]
        if duplicate_bins:
            # Raised explicitly instead of via `assert`: an `assert` is removed under
            # `python -O`, after which the dict comprehension below would silently
            # keep only the last allocator configured for a duplicated bin.
            raise AssertionError(f"Duplicate bins in alloc_loc: {duplicate_bins}")

        allocators = {
            alloc_config.bin: TensorMemoryAllocator(alloc_config, memory_blocks)
            for alloc_config, memory_blocks in alloc_locns
        }

        return MultiBinGraphMemoryScheduler(allocators, execution_order)

    def __init__(self, allocators: Dict[XrtId, BaseMemoryAllocator], execution_order: List[Operation] | None = None) -> None:
        """Create a scheduler that owns one allocator per memory bin.

        Args:
            allocators: Mapping from bin identifier (XrtId) to the allocator used for
                tensors assigned to that bin. Stored as-is on ``self.allocators``.
            execution_order: Optional execution order, forwarded unchanged to the
                GraphOps base-class initialiser.
        """
        super().__init__(execution_order=execution_order)
        self.allocators = allocators

    def schedule_memory(
        self,
        enable_spilling: bool = False,
        enable_noop_optim: bool = False,
        enable_runtime_optim: bool = False
    ) -> Dict[int, List[Tuple[AllocationResult, TensorAllocation]]]:
        """Schedule memory allocation for the entire graph across all bins.

        Performs lifetime-based memory allocation with the following phases:
        1. Compute tensor lifetimes and usage counts
        2. Set size=0 for single-use tensors (optimized away)
        3. Backward propagation: Pre-allocated outputs propagate through noop chains
        4. Forward pass: Allocate tensors in execution order with:
           - Cross-bin deallocation of expired tensors
           - Noop aliasing for zero-copy transformations
           - Runtime operation aliasing for gather/slice/concat/split
           - In-place allocation for pointwise operations
           - Regular allocation for remaining tensors

        Args:
            enable_spilling: If True, spill tensors to secondary storage (not supported)
            enable_noop_optim: If True, enable noop operation memory aliasing
            enable_runtime_optim: If True, enable runtime operation memory aliasing

        Returns:
            Dictionary mapping time step to list of (AllocationResult, TensorAllocation) tuples

        Raises:
            AssertionError: If enable_spilling=True (spilling not supported in multi-bin)
            RuntimeError: If allocation fails (must not spill)
        """
        assert (not enable_spilling), "Spilling not supported in multi-bin scheduler"

        lifetimes = self._compute_tensor_lifetimes()
        tensor_usage = self._compute_tensor_usage()
        allocation_results: Dict[
            int, List[Tuple[AllocationResult, TensorAllocation]]
        ] = defaultdict(list)

        # set tensor sizes to zero for tensors that are streamed or reside in control block
        for tensor_id, tensor in self.tensors.items():
            if (
                (tensor_usage[tensor_id] == 1 or tensor.is_constant)
                and not tensor.is_model_io
            ):
                object.__setattr__(tensor, "size", 0)

        for tensor in self.tensors.values():
            logger.debug(
                "%s Tensor %s: shape=%s, size=%s, bin=%s, usage=%s",
                ("Skipping" if tensor.size == 0 else "Allocating"),
                tensor.id, tensor.shape, tensor.size, tensor.bin, tensor_usage[tensor.id]
            )

        # BACKWARD PASS: Propagate pre-allocated outputs backward through noop chains
        # This ensures that when a noop operation produces a pre-allocated output,
        # the noop's input gets allocated to the same address as the pre-allocated output
        if enable_noop_optim:
            self._propagate_preallocated_backward()

        # tensors with lifetimes
        to_allocate = [
            TensorAllocation(
                self.tensors[tensor_id],
                TensorLifetime(start, end),
                TensorLocation.UNKNOWN,
            )
            for tensor_id, (start, end) in lifetimes.items()
        ]
        execution_order = self.get_execution_order()
        to_allocate.sort(key=lambda t: (t.range.start, t.range.end, t.tensor.id in execution_order[t.range.start].outputs))

        # Track tensors pre-allocated by concat/split passes
        pre_allocated_tensors = set()

        if enable_runtime_optim:
            # Process concat operations: allocate outputs, partition into inputs
            self._run_concat_pass(execution_order, to_allocate, pre_allocated_tensors, allocation_results)
            # Process split operations: allocate inputs, partition into outputs
            self._run_split_pass(execution_order, to_allocate, pre_allocated_tensors, allocation_results)

        # Process remaining tensors in execution order
        for allocation in to_allocate:
            start, tensor_id, bin_id = (
                allocation.range.start,
                allocation.tensor.id,
                allocation.tensor.bin,
            )
            logger.debug(
                "\n ======= Layer %s =======\nBin %s Current allocations %s",
                start, bin_id, self.allocators[bin_id].allocations.keys()
            )

            for allocator in self.allocators.values():
                for dealloc in allocator.free_expired_allocations(start):
                    allocation_results[start].append(
                        (AllocationResult.DEALLOCATED, dealloc.copy())
                    )

            allocator = self.allocators[bin_id]
            current_op = execution_order[start]

            # Skip pre_allocated_tensors (already allocated in reverse topological pass)
            if tensor_id in pre_allocated_tensors:
                logger.debug(
                    "Skipping pre-allocated tensors: %s", tensor_id
                )
                continue

            # Try noop aliasing first for noop operations: Only apply to the output of the noop operation (not inputs)
            if (enable_noop_optim and
                Builder.is_op_noop(current_op.type) and
                tensor_id in current_op.outputs and
                    self.tensors[tensor_id].size != 0):
                # find the input with non-zero size
                non_zero_tensors = [inp for inp in current_op.inputs if self.tensors[inp].size != 0]

                # Assumption: noop operations have single input
                assert len(non_zero_tensors) == 1, "Noop operators must have only one input, by assumption"
                input_tensor_id = non_zero_tensors[0]

                if allocator.allocate_noop_in_place(allocation, input_tensor_id):
                    logger.debug(
                        "allocated %s (noop alias to %s) with size %s, stats %s",
                        tensor_id, input_tensor_id, allocation.tensor.size, allocator.get_memory_usage()
                    )
                    allocation_results[start].append(
                        (AllocationResult.ALLOCATED_IN_PLACE, allocation.copy())
                    )
                else:
                    raise RuntimeError("Noop must be allocated in place")

            # Try runtime operation aliasing
            elif (enable_runtime_optim and
                  RuntimeOpType.is_aliasable(current_op.type) and
                  tensor_id in current_op.outputs and
                  self.tensors[tensor_id].size != 0):

                # Find the input with non-zero size
                non_zero_tensors = [inp for inp in current_op.inputs if self.tensors[inp].size != 0]

                # Route to appropriate runtime operation handler
                runtime_op_type = RuntimeOpType.classify(current_op.type)
                allocated = False

                if runtime_op_type == RuntimeOpType.GATHER:
                    assert len(non_zero_tensors) == 1, f"gather_runtime operator {current_op.id} must have only one input, by assumption"
                    input_tensor_id = non_zero_tensors[0]
                    allocated = allocator.allocate_gather_runtime(allocation, input_tensor_id, current_op.attributes)

                elif runtime_op_type == RuntimeOpType.SLICE:
                    assert len(non_zero_tensors) == 1, f"slice_runtime operator {current_op.id} must have only one input, by assumption"
                    input_tensor_id = non_zero_tensors[0]
                    allocated = allocator.allocate_slice_runtime(allocation, input_tensor_id, current_op.attributes)

                elif runtime_op_type == RuntimeOpType.SPLIT:
                    # Split handled in forward split pass (should not reach here)
                    assert allocation.tensor.id in allocator.allocations, f"Split output {tensor_id} should be pre-allocated"
                    continue

                elif runtime_op_type == RuntimeOpType.CONCAT:
                    # Concat handled in reverse topological pass (should not reach here)
                    assert allocation.tensor.id in allocator.allocations, f"Concat input {tensor_id} should be pre-allocated"
                    continue

                # Handle allocation result (common for GATHER and SLICE)
                if allocated:
                    logger.debug(
                        "allocated %s (runtime op %s, alias to %s) with size %s, stats %s",
                        tensor_id, current_op.type, non_zero_tensors[0], allocation.tensor.size, allocator.get_memory_usage()
                    )
                    allocation_results[start].append(
                        (AllocationResult.ALLOCATED_IN_PLACE, allocation.copy())
                    )
                else:
                    raise RuntimeError(f"Runtime op {current_op.type} must be allocated in place: {tensor_id}")

            # Try in-place allocation for pointwise operations
            elif allocator.allocate_in_place(allocation, current_op):
                logger.debug(
                    "allocated %s (in place) with size %s, stats %s",
                    tensor_id, allocation.tensor.size, allocator.get_memory_usage()
                )
                allocation_results[start].append(
                    (AllocationResult.ALLOCATED_IN_PLACE, allocation.copy())
                )
            elif allocator.allocate(allocation):
                logger.debug(
                    "allocated %s with size %s, stats %s",
                    tensor_id, allocation.tensor.size, allocator.get_memory_usage()
                )
                allocation_results[start].append(
                    (AllocationResult.ALLOCATED, allocation.copy())
                )
            else:
                raise RuntimeError("Must not spill")

        for alloc_results in allocation_results.values():
            for results in alloc_results:
                self.allocators[results[1].tensor.bin].convert_alloc_to_aligned(results)
        return allocation_results

    def get_allocation_summary(
        self,
        allocation_results: Dict[int, List[Tuple[AllocationResult, TensorAllocation]]],
    ) -> Dict[str, int | Dict[str, int]]:
        """Produces a summary of allocations and spillings across all bins.

        Aggregates allocation results to provide statistics on:
        - Total layers (time steps) processed
        - Tensors allocated per bin
        - Tensors allocated in-place per bin
        - Tensors spilled per bin

        Args:
            allocation_results: Dictionary from schedule_memory() mapping time step
                               to list of (AllocationResult, TensorAllocation) tuples

        Returns:
            Dictionary with keys:
            - 'total_layers_allocated': Number of time steps with allocations
            - 'tensors_allocated': Dict mapping bin name to count of regular allocations
            - 'tensors_allocated_in_place': Dict mapping bin name to in-place allocation count
            - 'tensors_spilled': Dict mapping bin name to spilled tensor count
        """
        results = allocation_results.values()
        counts = Counter([(al.tensor.bin, ar) for res in results for ar, al in res])

        def get_counts(allocation_result: AllocationResult) -> Dict[str, int]:
            count_res = {}
            for bin_id in self.allocators:
                count_res[bin_id.name] = counts[bin_id, allocation_result]
            return count_res

        return {
            "total_layers_allocated": len(allocation_results.keys()),
            "tensors_allocated": get_counts(AllocationResult.ALLOCATED),
            "tensors_allocated_in_place": get_counts(
                AllocationResult.ALLOCATED_IN_PLACE
            ),
            "tensors_spilled": get_counts(AllocationResult.SPILLED),
        }

    def _run_concat_pass(
        self,
        execution_order: List[Operation],
        to_allocate: List[TensorAllocation],
        pre_allocated_tensors: set,
        allocation_results: Dict[int, List[Tuple[AllocationResult, TensorAllocation]]]
    ) -> None:
        """Process concat operations in reverse topological order.

        Allocates concat outputs first, then partitions them into contiguous input views.
        This ensures concat inputs are laid out contiguously in memory.

        Args:
            execution_order: Topologically sorted operations (will be reversed)
            to_allocate: List of all tensor allocations
            pre_allocated_tensors: Set to track pre-allocated tensors (modified in-place)
            allocation_results: Dictionary to record allocation results (modified in-place)
        """
        for concat_op in reversed(execution_order):
            if not (RuntimeOpType.is_aliasable(concat_op.type) and
                    RuntimeOpType.classify(concat_op.type) == RuntimeOpType.CONCAT):
                continue

            # Skip if inputs already allocated (e.g., by previous operations)
            if any(inp in pre_allocated_tensors for inp in concat_op.inputs):
                raise RuntimeError(
                    f"Concat {concat_op.id} - inputs already pre-allocated"
                )

            # Get input tensor (should have exactly one non-zero input)
            non_zero_outputs = [out for out in concat_op.outputs if self.tensors[out].size > 0]
            assert len(non_zero_outputs) == 1, "Concat operation must have exactly one non-zero output"

            output_tensor_id = non_zero_outputs[0]
            output_bin = self.tensors[output_tensor_id].bin
            allocator = self.allocators[output_bin]

            # Get non-zero inputs
            non_zero_inputs = [inp for inp in concat_op.inputs if self.tensors[inp].size > 0]
            assert non_zero_inputs, "Concat operation must have at least one non-zero input"

            # Find the output allocation
            output_alloc = next((alloc for alloc in to_allocate if alloc.tensor.id == output_tensor_id), None)
            assert output_alloc is not None, "Concat output allocation not found"

            # Find input allocations from to_allocate list
            input_allocs = [alloc for alloc in to_allocate if alloc.tensor.id in non_zero_inputs]
            assert len(input_allocs) == len(non_zero_inputs), "Not all concat input allocations found"

            # Allocate output first (reverse of split which allocates input first)
            if output_tensor_id not in allocator.allocations:
                allocated = allocator.allocate(output_alloc)
                assert allocated, f"Failed to allocate concat output {output_tensor_id}"

            # Allocate concat runtime (partitions output into input views)
            allocated = allocator.allocate_concat_runtime(input_allocs, output_tensor_id, concat_op.attributes)
            assert allocated, f"Failed to partition concat output {output_tensor_id} into inputs"

            # Mark output and inputs as pre-allocated (skip in forward pass)
            pre_allocated_tensors.add(output_tensor_id)
            for inp in non_zero_inputs:
                pre_allocated_tensors.add(inp)

            # Add output to allocation results
            allocation_results[output_alloc.range.start].append(
                (AllocationResult.ALLOCATED, output_alloc.copy())
            )

            # Add inputs to allocation results
            for inp in non_zero_inputs:
                inp_alloc = allocator.allocations[inp]
                allocation_results[inp_alloc.range.start].append(
                    (AllocationResult.ALLOCATED_IN_PLACE, inp_alloc.copy())
                )

            logger.debug(
                "Allocated concat_runtime: %d inputs → %s (concat pass)",
                len(non_zero_inputs), output_tensor_id
            )

    def _handle_split_concat_chaining(
        self,
        split_op: Operation,
        to_allocate: List[TensorAllocation],
        pre_allocated_tensors: set,
        allocation_results: Dict[int, List[Tuple[AllocationResult, TensorAllocation]]]
    ) -> None:
        """Handle split→concat chaining where split outputs are pre-allocated by concat.

        When concat runs first (reverse pass) and allocates split outputs, we need to
        alias the split input to point to the same contiguous region.

        This method should only be called when ALL split outputs are pre-allocated.

        Args:
            split_op: The split operation to process
            to_allocate: List of all tensor allocations
            pre_allocated_tensors: Set of pre-allocated tensor IDs (modified in-place)
            allocation_results: Dictionary to record allocation results (modified in-place)

        Raises:
            RuntimeError: If chaining validation fails (size mismatch, different parents, etc.)
        """

        # Get split input
        non_zero_inputs = [inp for inp in split_op.inputs if self.tensors[inp].size > 0]
        assert len(non_zero_inputs) == 1, "Split operation must have exactly one non-zero input"

        input_tensor_id = non_zero_inputs[0]
        input_tensor = self.tensors[input_tensor_id]
        input_bin = input_tensor.bin
        allocator = self.allocators[input_bin]

        # Get first output's allocation to determine the parent block
        first_output_id = split_op.outputs[0]
        first_output_alloc = allocator.allocations[first_output_id]
        first_output_block = first_output_alloc.block

        # Validate: all outputs should share the same parent (concat output)
        for output_id in split_op.outputs:
            output_alloc = allocator.allocations[output_id]
            output_block = output_alloc.block

            if output_block.start != first_output_block.start:
                raise RuntimeError(
                    f"Split {split_op.id}: Pre-allocated outputs don't share same parent. "
                    f"Split→Concat chaining requires outputs from same concat operation."
                )

        # Validate: input size should match total size of all outputs
        total_output_size = sum(
            allocator.allocations[out].tensor.size
            for out in split_op.outputs
        )

        if input_tensor.size != total_output_size:
            raise RuntimeError(
                f"Split {split_op.id}: Input size ({input_tensor.size}) "
                f"doesn't match total output size ({total_output_size}). "
                f"Split→Concat chaining requires size consistency."
            )

        # Find input allocation
        input_alloc = next((alloc for alloc in to_allocate if alloc.tensor.id == input_tensor_id), None)
        assert input_alloc is not None, f"Split input allocation not found for {input_tensor_id}"

        # Alias split input to the same parent block as outputs
        # The input should point to the beginning of the contiguous output region
        input_block = MemoryBlock(
            start=first_output_block.start,  # Same parent as concat-allocated outputs
            size=input_tensor.size,
            is_free=False,
            tensor_id=input_tensor_id,
            aliased_tensor_ids=[],
            offset_from_parent=first_output_block.offset_from_parent,  # Inherit offset from concat
        )

        # Mark all split outputs as aliased to the input
        for output_id in split_op.outputs:
            if output_id not in input_block.aliased_tensor_ids:
                input_block.aliased_tensor_ids.append(output_id)

        # Store the input allocation
        input_alloc.block = input_block
        input_alloc.location = TensorLocation.MEMORY
        allocator.allocations[input_tensor_id] = input_alloc

        # Mark input as pre-allocated
        pre_allocated_tensors.add(input_tensor_id)

        # Record allocation result
        allocation_results[input_alloc.range.start].append(
            (AllocationResult.ALLOCATED_IN_PLACE, input_alloc.copy())
        )

        logger.debug(
            "Handled split→concat chaining: %s -> %d outputs (aliased to concat output)",
            input_tensor_id, len(split_op.outputs)
        )

    def _run_split_pass(
        self,
        execution_order: List[Operation],
        to_allocate: List[TensorAllocation],
        pre_allocated_tensors: set,
        allocation_results: Dict[int, List[Tuple[AllocationResult, TensorAllocation]]]
    ) -> None:
        """Process split operations in forward topological order.

        Allocates split inputs first, then partitions them into contiguous output views.
        This ensures split outputs are laid out contiguously in memory.

        Args:
            execution_order: Topologically sorted operations
            to_allocate: List of all tensor allocations
            pre_allocated_tensors: Set to track pre-allocated tensors (modified in-place)
            allocation_results: Dictionary to record allocation results (modified in-place)
        """
        for split_op in execution_order:
            if not (RuntimeOpType.is_aliasable(split_op.type) and
                    RuntimeOpType.classify(split_op.type) == RuntimeOpType.SPLIT):
                continue

            # Check if any outputs are pre-allocated (split→concat chaining)
            num_preallocated = sum(1 for out in split_op.outputs if out in pre_allocated_tensors)

            # All outputs pre-allocated - handle split→concat chaining
            if num_preallocated == len(split_op.outputs):
                self._handle_split_concat_chaining(split_op, to_allocate, pre_allocated_tensors, allocation_results)
                continue  # Chaining handled, skip normal split allocation

            # Partial pre-allocation not supported
            if num_preallocated > 0:
                raise RuntimeError(
                    f"Split {split_op.id}: Partial pre-allocation detected. "
                    f"{num_preallocated}/{len(split_op.outputs)} outputs pre-allocated. "
                    f"Split→Concat chaining requires all outputs to be pre-allocated."
                )

            # Get input tensor (should have exactly one non-zero input)
            non_zero_inputs = [inp for inp in split_op.inputs if self.tensors[inp].size > 0]
            assert len(non_zero_inputs) == 1, "Split operation must have exactly one non-zero input"

            input_tensor_id = non_zero_inputs[0]
            input_bin = self.tensors[input_tensor_id].bin
            allocator = self.allocators[input_bin]

            # Ensure split input is allocated
            if input_tensor_id not in allocator.allocations:
                # Find the allocation for this tensor
                input_alloc = next((alloc for alloc in to_allocate if alloc.tensor.id == input_tensor_id), None)
                assert input_alloc is not None, "Split input allocation not found"

                # Allocate input using regular allocation
                if not allocator.allocate(input_alloc):
                    raise RuntimeError(f"Failed to allocate split input {input_tensor_id}")

                allocation_results[input_alloc.range.start].append((AllocationResult.ALLOCATED, input_alloc.copy()))

            # Collect all output allocations for this split operation
            output_allocations = [alloc for alloc in to_allocate if alloc.tensor.id in split_op.outputs]
            assert output_allocations, "No output allocations found for split operation"

            # Allocate split runtime (partitions input into outputs)
            allocated = allocator.allocate_split_runtime(output_allocations, input_tensor_id, split_op.attributes)
            assert allocated, f"Failed to partition split input {input_tensor_id}"

            # Mark input and outputs as pre-allocated (skip in regular forward pass)
            pre_allocated_tensors.add(input_tensor_id)
            for out_alloc in output_allocations:
                pre_allocated_tensors.add(out_alloc.tensor.id)

            # Add to allocation results
            for out_alloc in output_allocations:
                allocation_results[out_alloc.range.start].append(
                    (AllocationResult.ALLOCATED_IN_PLACE, out_alloc.copy())
                )

            logger.debug(
                "Allocated split_runtime: %s -> %d outputs (split pass)",
                input_tensor_id, len(output_allocations)
            )

    def _propagate_preallocated_backward(self) -> None:
        """
        Alias noop inputs onto outputs that already own a memory block (backward pass).

        Operations are visited from the end of the execution order to the start.
        For every noop whose single output is already registered in the allocator
        of the output's bin, the noop's single non-zero-sized input is registered
        in that same allocator, sharing the output's block.

        Because the walk runs in reverse, an input aliased at one noop is already
        registered by the time the noop producing it is visited (assuming
        producers precede consumers in the execution order), so the block is
        handed on through a whole chain:

            tensor1 -> noop1 -> tensor2 -> noop2 -> output(pre-allocated)

        ends with tensor1, tensor2 and output all referring to the same block.

        A noop is left untouched when:
            - its output has size 0,
            - it does not have exactly one non-zero-sized input,
            - its output is not registered in the allocator of the output's bin,
            - input and output are assigned to different bins (logged at debug level),
            - its input is already registered in that allocator (a warning is
              logged if the two block start addresses differ).

        Raises:
            AssertionError: If a noop operation does not have exactly one output.
        """
        # Reverse walk: each consumer is handled before the operation feeding it.
        for op in reversed(self.get_execution_order()):
            if not Builder.is_op_noop(op.type):
                continue

            assert len(op.outputs) == 1, f"Noop {op.id} must have exactly one output"

            dst_id = op.outputs[0]
            dst_tensor = self.tensors[dst_id]
            if dst_tensor.size == 0:
                # Zero-sized outputs are not candidates for aliasing.
                continue

            # The aliasing source must be the one and only non-zero-sized input.
            sized_inputs = [tid for tid in op.inputs if self.tensors[tid].size != 0]
            if len(sized_inputs) != 1:
                continue

            src_id = sized_inputs[0]
            src_tensor = self.tensors[src_id]

            dst_bin = dst_tensor.bin
            bin_allocator = self.allocators[dst_bin]

            # Nothing to propagate unless the output already holds an allocation.
            if dst_id not in bin_allocator.allocations:
                continue
            dst_alloc = bin_allocator.allocations[dst_id]

            # A block belongs to one bin, so cross-bin aliasing is not attempted.
            if src_tensor.bin != dst_bin:
                logger.debug(
                    "Cannot backward-alias %s to %s: different bins (%s vs %s)",
                    src_id, dst_id, src_tensor.bin, dst_bin
                )
                continue

            # An existing input allocation is kept as-is; only report a mismatch.
            if src_id in bin_allocator.allocations:
                src_alloc = bin_allocator.allocations[src_id]
                if src_alloc.block.start != dst_alloc.block.start:
                    logger.warning(
                        "Noop input %s already allocated at %s, but output %s is at %s",
                        src_id, src_alloc.block.start,
                        dst_id, dst_alloc.block.start
                    )
                continue

            # Build the input's allocation on top of the output's block. The
            # (-1, -1) lifetime is a placeholder to be updated by the forward pass.
            alias_allocation = TensorAllocation(
                tensor=src_tensor,
                range=TensorLifetime(start=-1, end=-1),
                location=TensorLocation.MEMORY,
                block=dst_alloc.block,
                is_deallocatable=dst_alloc.is_deallocatable
            )

            # Record the input as an alias on the shared block (no duplicates).
            if src_id not in dst_alloc.block.aliased_tensor_ids:
                dst_alloc.block.aliased_tensor_ids.append(src_id)

            # Register the input so an upstream noop can continue the chain.
            bin_allocator.allocations[src_id] = alias_allocation

            logger.debug(
                "Backward-aliased %s to pre-allocated %s at address %s (noop: %s)",
                src_id, dst_id, dst_alloc.block.start, op.id
            )
