"""
Common buffer allocation strategies for AIE4 tilers
"""

from typing import Optional, List, Tuple
from graph.allocation_types import is_non_overlapping
from graph.utilities import logger
from tiler.base_tiler import HW_CONFIG, align_up
from tiler.buffer_types import BufferSpec, BufferAllocation, BinItem, BufferPair
from tiler.swbank_alloc import BinPacker as BinPackerV1
from tiler.swbank_alloc_v2 import BinPacker as BinPackerV2
from utils.utils_common import overlay_3x4_core_stack_addr, floordiv


class BaseL1BufferAllocator:
    """
    Generic L1 buffer allocator with multiple strategies.

    Every strategy returns a list of ``BufferAllocation`` on success or ``None``
    on failure, and every returned layout has passed ``_validate_allocations``
    (no overlapping regions, all regions inside ``[0, core_stack_addr)``).

    Strategies:
    - ``_try_sequential_allocation``: back-to-back placement from address 0.
    - ``_try_cpsat_allocation``: bin packing via ``BinPackerV1``.
    - ``_try_cpsat_allocation_v2``: bin packing via ``BinPackerV2`` with
      ``minimize_bank_conflicts=True``.
    """

    def __init__(
        self,
        core_stack_addr: int | None = None,
        core_bank_size: int = HW_CONFIG.CORE_BANK_SIZE,
        memory_align: int = HW_CONFIG.MEMORY_ALIGNMENT,
        use_software_banks: bool = True,
    ):
        """
        Args:
            core_stack_addr: Upper bound (exclusive) of the address range that
                buffers may occupy. ``None`` -- or any other falsy value such
                as ``0`` -- selects ``_get_default_stack_addr()``.
            core_bank_size: Bank size used as the packing bin size when
                ``use_software_banks`` is False.
            memory_align: Stored on the instance; the methods of this class
                align each buffer to its own ``alignment`` instead.
            use_software_banks: When True, bins are sized by
                ``HW_CONFIG.CORE_BANK_MEM_SIZE_SOFTWARE`` rather than
                ``core_bank_size``.
        """
        # NOTE: `or` is deliberate here -- an explicit 0 also falls back to the default.
        self.core_stack_addr = core_stack_addr or self._get_default_stack_addr()
        self.core_bank_size = core_bank_size
        self.sw_bank_size = (
            HW_CONFIG.CORE_BANK_MEM_SIZE_SOFTWARE if use_software_banks else core_bank_size
        )
        self.memory_align = memory_align
        self.use_software_banks = use_software_banks

    @staticmethod
    def _get_default_stack_addr() -> int:
        """Return the default core stack address from ``overlay_3x4_core_stack_addr()``."""
        return overlay_3x4_core_stack_addr()

    def _validate_allocations(self, allocations: List[BufferAllocation]) -> bool:
        """
        Validate that buffer allocations don't overlap and fit within available memory.

        Args:
            allocations: Placed buffers; each contributes the half-open region
                ``[addr, addr + size)``.

        Returns:
            True when ``allocations`` is empty, or when the regions pass
            ``is_non_overlapping`` and lie within ``[0, core_stack_addr)``.
        """
        if not allocations:
            return True

        # Regions sorted by start address (ties broken by end address).
        memory_regions: List[Tuple[int, int]] = sorted(
            (alloc.addr, alloc.addr + alloc.size) for alloc in allocations
        )

        if not is_non_overlapping(memory_regions):
            return False

        # With sorted, non-overlapping regions the first region has the lowest
        # start and the last region bounds the top of the used range.
        lowest_start = memory_regions[0][0]
        highest_end = memory_regions[-1][1]
        return lowest_start >= 0 and highest_end <= self.core_stack_addr

    def _check_buffers(self, buffers: List[BufferSpec]) -> bool:
        """
        Check if buffers are valid for allocation.

        Args:
            buffers: Buffer specifications requested for placement.

        Returns:
            False when ``buffers`` is empty, when any single buffer is larger
            than ``core_stack_addr``, or when two buffers share a name;
            True otherwise. Oversized and duplicate cases are logged as errors.
        """
        if not buffers:
            return False

        for buf in buffers:
            if buf.size > self.core_stack_addr:
                logger.error(
                    "Buffer %s size %d exceeds core stack size %d", buf.name, buf.size, self.core_stack_addr
                )
                return False

        # Names identify buffers in the resulting allocations, so they must be unique.
        seen = set()
        duplicates = set()
        for buf in buffers:
            if buf.name in seen:
                duplicates.add(buf.name)
            else:
                seen.add(buf.name)
        if duplicates:
            logger.error(
                "Duplicate buffer names in allocation request: %s", sorted(duplicates, key=str)
            )
            return False

        return True

    @staticmethod
    def _build_bin_items(buffers: List[BufferSpec]) -> List[BinItem]:
        """Convert buffer specs into ``BinItem``s; only ping buffers are marked ``must_place``."""
        return [
            BinItem(
                name=buf.name,
                size=buf.size,
                alignment=buf.alignment,
                must_place=buf.is_ping,
                priority=buf.priority,
            )
            for buf in buffers
        ]

    @staticmethod
    def _build_exclusivity(
        buffers: List[BufferSpec], exclusivity_pairs: Optional[List[BufferPair]]
    ) -> Optional[List[Tuple[str, str]]]:
        """
        Translate buffer pairs into ``(name, name)`` tuples for the bin packers.

        Returns:
            The list of name pairs (empty when no pairs were supplied), or
            ``None`` if any pair references a buffer that is not in ``buffers``.
        """
        if not exclusivity_pairs:
            return []

        known_names = {buf.name for buf in buffers}
        name_pairs: List[Tuple[str, str]] = []
        for buf1, buf2 in exclusivity_pairs:
            # Validate before recording so an unknown name never reaches the solver.
            if buf1.name not in known_names or buf2.name not in known_names:
                logger.error(
                    "Exclusivity pair (%s, %s) contains unknown buffer", buf1.name, buf2.name
                )
                return None
            name_pairs.append((buf1.name, buf2.name))
        return name_pairs

    def _try_sequential_allocation(
        self, buffers: List[BufferSpec], double_buffer: bool
    ) -> Optional[List[BufferAllocation]]:
        """
        Try sequential allocation with optional double buffering.

        Buffers flagged ``is_ping`` are placed back-to-back starting at address
        0, each aligned up to its own ``alignment``. When ``double_buffer`` is
        True, buffers flagged ``is_pong`` are placed after them in the same way;
        otherwise they are left out of the result.

        Returns:
            The allocations in placement order, or ``None`` if the buffers fail
            ``_check_buffers``, the layout runs past ``core_stack_addr``, or it
            fails ``_validate_allocations``.
        """
        if not self._check_buffers(buffers):
            return None

        # All ping buffers first, then (optionally) all pong buffers.
        placement_order = [buf for buf in buffers if buf.is_ping]
        if double_buffer:
            placement_order.extend(buf for buf in buffers if buf.is_pong)

        allocations: List[BufferAllocation] = []
        current_addr = 0
        for buf in placement_order:
            aligned_addr = align_up(current_addr, buf.alignment)
            allocations.append(BufferAllocation(name=buf.name, size=buf.size, addr=aligned_addr))
            current_addr = aligned_addr + buf.size

        if current_addr > self.core_stack_addr:
            return None
        return allocations if self._validate_allocations(allocations) else None

    def _try_cpsat_allocation(
        self, buffers: List[BufferSpec], exclusivity_pairs: Optional[List[BufferPair]] = None
    ) -> Optional[List[BufferAllocation]]:
        """
        Try constrained satisfaction allocation using ``BinPackerV1``.

        Args:
            buffers: Buffers to place; ping buffers are passed as ``must_place``.
            exclusivity_pairs: Optional buffer pairs forwarded to the packer as
                exclusivity constraints. Both members must appear in ``buffers``.

        Returns:
            Allocations whose address is ``bin_idx * sw_bank_size + offset`` for
            each placement in the solver result, or ``None`` if the inputs are
            invalid, the solver returns a falsy result, or the layout fails
            ``_validate_allocations``.
        """
        if not self._check_buffers(buffers):
            return None

        exclusivity = self._build_exclusivity(buffers, exclusivity_pairs)
        if exclusivity is None:
            return None

        # Solve the bin packing problem
        packer = BinPackerV1(
            bin_capacity=self.sw_bank_size,
            num_bins=floordiv(self.core_stack_addr, self.sw_bank_size),
            items=self._build_bin_items(buffers),
            exclusivity_pairs=exclusivity,
            time_limit_seconds=4.0
        )
        result = packer.solve()
        if not result:
            return None
        result.print_summary()

        # Convert per-bin offsets into absolute addresses.
        allocations = []
        for bin_idx, placements in result.bins:
            bin_base = bin_idx * self.sw_bank_size
            for item, offset in placements:
                allocations.append(
                    BufferAllocation(name=item.name, size=item.size, addr=bin_base + offset)
                )

        # Final validation
        return allocations if self._validate_allocations(allocations) else None

    def _try_cpsat_allocation_v2(
        self,
        buffers: List[BufferSpec],
        exclusivity_pairs: Optional[List[BufferPair]] = None
    ) -> Optional[List[BufferAllocation]]:
        """
        Try constrained satisfaction allocation with spanning support and bank conflict minimization.
        - Allows items to span consecutive banks if they don't fit in one
        - Minimizes bank conflicts for better AIE hardware performance

        Args:
            buffers: Buffers to place; ping buffers are passed as ``must_place``.
            exclusivity_pairs: Optional buffer pairs forwarded to the packer as
                exclusivity constraints. Both members must appear in ``buffers``.

        Returns:
            One allocation per entry of the solver's ``placements`` (addressed by
            its ``absolute_offset``), or ``None`` if the inputs are invalid, the
            solver returns a falsy result, or the layout fails
            ``_validate_allocations``.
        """
        if not self._check_buffers(buffers):
            return None

        exclusivity = self._build_exclusivity(buffers, exclusivity_pairs)
        if exclusivity is None:
            return None

        # Solve the bin packing problem with v2 (supports spanning and conflict minimization)
        packer = BinPackerV2(
            bin_capacity=self.sw_bank_size,
            num_bins=floordiv(self.core_stack_addr, self.sw_bank_size),
            items=self._build_bin_items(buffers),
            exclusivity_pairs=exclusivity,
            time_limit_seconds=4.0,
            minimize_bank_conflicts=True,
        )
        result = packer.solve()
        if not result:
            return None
        result.print_summary()

        # In v2, result has 'placements' instead of 'bins'; absolute_offset is
        # used directly as the buffer address.
        allocations = [
            BufferAllocation(
                name=placement.item.name,
                size=placement.item.size,
                addr=placement.absolute_offset,
            )
            for placement in result.placements
        ]

        # Final validation
        return allocations if self._validate_allocations(allocations) else None
