"""
Multi-bin packing with constraints using OR-Tools CP-SAT

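Features:
- Multiple bins (software banks), all with the same capacity
- Items may span CONSECUTIVE bins (each item occupies one contiguous range
  of the concatenated bank space, so it can cross bank boundaries)
- Mandatory items (must be placed)
- Optional items (placed so as to maximize their total priority)
- Alignment constraints (an item's absolute start offset must be a multiple
  of its alignment, typically a power of 2)
- Exclusivity constraints (pairs of items that must not share any bin)
- MINIMIZE BANK CONFLICTS (banks used by more than one item) by spreading
  items across banks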
"""

from dataclasses import dataclass
from typing import List, Tuple, Optional
import os
from ortools.sat.python import cp_model
from tiler.buffer_types import BinItem
from tiler.base_tiler import get_os_core_count
from graph.utilities import logger


@dataclass
class ItemPlacement:
    """Where one item landed: its bank range and its offset in the first bank."""
    item: BinItem  # the item that was placed
    start_bank: int  # index of the first bank the item occupies
    end_bank: int  # Inclusive: index of the last bank the item occupies
    offset_in_first_bank: int  # start of the item relative to the start of start_bank
    bank_capacity: int  # capacity of every bank; used to derive absolute addresses

    @property
    def num_banks_used(self) -> int:
        """Count of banks touched by this item (end_bank is inclusive)."""
        first, last = self.start_bank, self.end_bank
        return (last + 1) - first

    @property
    def absolute_offset(self) -> int:
        """Start of the item measured from the beginning of bank 0."""
        bank_base = self.bank_capacity * self.start_bank
        return bank_base + self.offset_in_first_bank


@dataclass
class PackingResult:
    """Outcome of a BinPacker run, plus a DEBUG-level report helper."""
    placements: List[ItemPlacement]  # one entry per placed item
    bank_capacity: int  # capacity of each bank
    num_banks: int  # number of banks that were available
    num_banks_used: int  # highest used bank index + 1
    num_items_placed: int  # count of items that received a placement
    total_waste: int  # capacity of banks 0..num_banks_used-1 minus total placed item size
    bank_conflicts: int  # Number of banks used by multiple items (0 when conflict minimization is off)

    def print_summary(self):
        """Log a human-readable report of this packing at DEBUG level.

        Emits the headline statistics, every placement ordered by absolute
        address, and a per-bank listing that flags banks holding more than
        one item.
        """
        separator = "=" * 70
        logger.debug("\n%s", separator)
        headline = (
            ("Banks used: %s", self.num_banks_used),
            ("Items placed: %s", self.num_items_placed),
            ("Bank conflicts: %s (lower is better)", self.bank_conflicts),
            ("Total waste: %s", self.total_waste),
            ("Bank capacity: %s", self.bank_capacity),
        )
        for fmt, value in headline:
            logger.debug(fmt, value)
        logger.debug("%s\n", separator)

        # Flat listing, ordered by absolute address.
        logger.debug("Placements (absolute addresses):")
        by_address = sorted(self.placements,
                            key=lambda placement: placement.absolute_offset)
        for placement in by_address:
            item = placement.item
            kind = "REQ" if item.must_place else "OPT"
            logger.debug("  addr=%6d %-20s size=%6d bank=%d-%d [%s]",
                         placement.absolute_offset, item.name, item.size,
                         placement.start_bank, placement.end_bank, kind)
        logger.debug("")

        # Bucket every placement into each bank it touches.
        occupants = {bank: [] for bank in range(self.num_banks)}
        for placement in self.placements:
            for bank in range(placement.start_bank, placement.end_bank + 1):
                occupants[bank].append(placement)

        for bank in sorted(occupants):
            residents = occupants[bank]
            if not residents:
                continue

            marker = "" if len(residents) <= 1 else " CONFLICT"
            logger.debug("Bank %s (%s items)%s:", bank, len(residents), marker)

            for placement in residents:
                item = placement.item
                tag = " REQUIRED" if item.must_place else " OPTIONAL"
                if placement.start_bank == placement.end_bank:
                    first = placement.offset_in_first_bank
                    last = first + item.size - 1
                    logger.debug("  [%5d-%5d] %-20s size=%6d align=%4d [%s]",
                                 first, last, item.name,
                                 item.size, item.alignment, tag)
                    continue

                span_info = f"spans banks {placement.start_bank}-{placement.end_bank}"
                if bank == placement.start_bank:
                    logger.debug("  [%5d-...  ] %-20s size=%6d align=%4d [%s] %s",
                                 placement.offset_in_first_bank, item.name, item.size,
                                 item.alignment, tag, span_info)
                elif bank == placement.end_bank:
                    logger.debug("  [    0-...  ] %-20s (continuation) %s",
                                 item.name, span_info)
                else:
                    logger.debug("  [    0-...  ] %-20s (middle) %s",
                                 item.name, span_info)
            logger.debug("")


class BinPacker:
    """Multi-bin packing solver - minimizes bank conflicts.

    Items are laid out in one concatenated address space of
    ``bin_capacity * num_bins`` units. Every placed item occupies a contiguous
    range of that space and may therefore cross bank boundaries. The
    objective is lexicographic:
    1. maximize the total priority of placed items;
    2. then, when ``minimize_bank_conflicts`` is set, minimize the number of
       banks used by more than one item;
    3. then minimize the number of banks used (highest used bank index + 1).
    """

    def __init__(
        self,
        bin_capacity: int,
        num_bins: int,
        items: List[BinItem],
        exclusivity_pairs: Optional[List[Tuple[str, str]]] = None,
        time_limit_seconds: float = 10.0,
        minimize_bank_conflicts: bool = True
    ):
        """
        Args:
            bin_capacity: Capacity of each bin (software bank)
            num_bins: Maximum number of bins available
            items: List of items to pack. Item names are the keys used by
                exclusivity_pairs; if names repeat, the last item wins.
            exclusivity_pairs: Pairs of item names that may not share ANY bin
            time_limit_seconds: Solver time limit
            minimize_bank_conflicts: If True, optimize to reduce bank conflicts
        """
        self.bin_capacity = bin_capacity
        self.num_bins = num_bins
        self.items = items
        self.exclusivity_pairs = exclusivity_pairs or []
        self.time_limit = time_limit_seconds
        self.minimize_bank_conflicts = minimize_bank_conflicts

        # Build item name to index mapping
        self.item_index = {item.name: i for i, item in enumerate(items)}

        logger.debug("Bank capacity: %s, Num banks: %s, Total capacity: %s",
                     bin_capacity, num_bins, bin_capacity * num_bins)
        logger.debug("Total items: %s", len(items))
        logger.debug("Exclusivity pairs: %s", self.exclusivity_pairs)
        logger.debug("Minimize bank conflicts: %s", minimize_bank_conflicts)
        for item in items:
            banks_needed = (item.size + bin_capacity - 1) // bin_capacity
            logger.debug("  Item: %s, Size: %s, Align: %s, Priority: %s, Must place: %s, Banks needed: %s",
                         item.name, item.size, item.alignment, item.priority, item.must_place, banks_needed)

    def solve(self) -> Optional[PackingResult]:
        """Build and solve the CP-SAT packing model.

        Returns:
            A PackingResult, or None if a mandatory item is larger than the
            total capacity, or if the solver finds no feasible solution within
            the time limit.
        """
        n = len(self.items)

        # Handle empty items case
        if n == 0:
            return PackingResult(
                placements=[],
                bank_capacity=self.bin_capacity,
                num_banks=self.num_bins,
                num_banks_used=0,
                num_items_placed=0,
                total_waste=0,
                bank_conflicts=0
            )

        model = cp_model.CpModel()

        # Total capacity across all bins
        total_capacity = self.bin_capacity * self.num_bins

        # placed[i] = 1 if item i is placed
        placed = [model.new_bool_var(f'placed_{i}') for i in range(n)]

        # Absolute (aligned) start position in the concatenated bank space
        abs_start = []
        for i, item in enumerate(self.items):
            start_var = self._new_start_var(model, i, item, total_capacity)
            if start_var is None:
                return None
            abs_start.append(start_var)

        # Absolute end position (exclusive). The domain upper bound is also the
        # capacity constraint for placed items.
        abs_end = [model.new_int_var(0, total_capacity, f'abs_end_{i}') for i in range(n)]

        # Which bank does the item start / end in?
        start_bank = [model.new_int_var(0, self.num_bins - 1, f'start_bank_{i}') for i in range(n)]
        end_bank = [model.new_int_var(0, self.num_bins - 1, f'end_bank_{i}') for i in range(n)]

        # 1. MANDATORY ITEMS: must be placed
        for i, item in enumerate(self.items):
            if item.must_place:
                model.add(placed[i] == 1)

        # 2. Size + NO OVERLAP. An optional interval enforces
        #    abs_end == abs_start + size only when the item is placed. NoOverlap
        #    keeps all placed items disjoint with one global constraint instead
        #    of O(n^2) pairwise disjunctions. Zero-size intervals still count.
        intervals = [
            model.new_optional_interval_var(
                abs_start[i], item.size, abs_end[i], placed[i], f'interval_{i}')
            for i, item in enumerate(self.items)
        ]
        model.add_no_overlap(intervals)

        # 3. Compute start_bank and end_bank from absolute positions;
        #    items span consecutive banks (end_bank >= start_bank).
        for i, item in enumerate(self.items):
            model.add_division_equality(start_bank[i], abs_start[i], self.bin_capacity)

            # Position of the item's last occupied unit. A zero-size item is
            # treated as occupying its start unit. Using abs_end - 1 for it would
            # forbid offset 0 and every bank boundary.
            last_unit = model.new_int_var(0, total_capacity, f'abs_end_m1_{i}')
            model.add(
                last_unit == abs_start[i] + max(item.size - 1, 0)
            ).only_enforce_if(placed[i])
            model.add_division_equality(end_bank[i], last_unit, self.bin_capacity)

            model.add(end_bank[i] >= start_bank[i]).only_enforce_if(placed[i])

        # 4. EXCLUSIVITY CONSTRAINT: certain pairs cannot share any bank
        self._add_exclusivity_constraints(model, placed, start_bank, end_bank)

        # 5. Number of banks used (maximum end_bank + 1)
        max_bank_used = model.new_int_var(0, self.num_bins, 'max_bank_used')
        for i in range(n):
            model.add(max_bank_used >= end_bank[i] + 1).only_enforce_if(placed[i])

        # 6. Bank conflict indicators (only needed when they are optimized)
        if self.minimize_bank_conflicts:
            bank_conflict = self._add_bank_conflict_vars(model, placed, start_bank, end_bank)
        else:
            bank_conflict = []

        # Objective: maximize items placed, then minimize bank conflicts, then
        # minimize banks used. Weights are scaled so each tier strictly
        # dominates the next: conflicts <= num_bins and max_bank_used <= num_bins.
        # The original constants are kept as floors, so small instances keep
        # the exact same objective.
        bank_usage_weight = 1
        conflict_weight = max(10000, (self.num_bins + 1) * bank_usage_weight)
        placement_weight = max(1000000, (self.num_bins + 1) * conflict_weight)

        total_value = sum(placed[i] * item.priority for i, item in enumerate(self.items))
        total_conflicts = sum(bank_conflict)

        if self.minimize_bank_conflicts:
            model.maximize(
                total_value * placement_weight
                - total_conflicts * conflict_weight
                - max_bank_used * bank_usage_weight
            )
        else:
            # Original behavior: just minimize banks used
            model.maximize(
                total_value * placement_weight
                - max_bank_used * bank_usage_weight
            )

        solver = self._make_solver()
        status = solver.solve(model)

        if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
            logger.debug("Solver status: %s", solver.status_name(status))
            return None

        return self._extract_result(solver, placed, abs_start, start_bank,
                                    end_bank, max_bank_used, bank_conflict)

    def _new_start_var(self, model: cp_model.CpModel, index: int, item: BinItem,
                       total_capacity: int):
        """Create item `index`'s absolute start variable, limited to aligned offsets.

        Returns None (after logging an error) if a mandatory item is larger
        than the total capacity.
        """
        name = f'abs_start_{index}'
        max_start = total_capacity - item.size
        if max_start < 0:
            if item.must_place:
                logger.error("ERROR: Item %s too large (size=%s, total_capacity=%s)",
                             item.name, item.size, total_capacity)
                return None
            # Optional item that can never fit: pin its start. It stays
            # unplaced because its end would exceed abs_end's upper bound.
            return model.new_int_var(0, 0, name)

        if item.alignment > 1:
            # Only offsets that are multiples of the alignment are allowed.
            aligned = list(range(0, max_start + 1, item.alignment))
            return model.new_int_var_from_domain(cp_model.Domain.from_values(aligned), name)
        return model.new_int_var(0, max_start, name)

    def _add_exclusivity_constraints(self, model: cp_model.CpModel, placed, start_bank,
                                     end_bank) -> None:
        """Forbid each exclusivity pair from sharing a bank when both items are placed."""
        for name1, name2 in self.exclusivity_pairs:
            if name1 not in self.item_index or name2 not in self.item_index:
                logger.debug("Skipping exclusivity pair (%s, %s): unknown item name",
                             name1, name2)
                continue

            i = self.item_index[name1]
            j = self.item_index[name2]
            if i == j:
                # A self-pair would demand end_bank < start_bank for one item,
                # which makes that item impossible to place.
                logger.warning("Ignoring exclusivity pair (%s, %s): item paired with itself",
                               name1, name2)
                continue

            # Either all of i's banks precede j's, or all of j's precede i's.
            both_placed = [placed[i], placed[j]]
            i_banks_before_j = model.new_bool_var(f'excl_{i}_before_{j}')
            model.add(end_bank[i] < start_bank[j]).only_enforce_if(
                both_placed + [i_banks_before_j])
            model.add(end_bank[j] < start_bank[i]).only_enforce_if(
                both_placed + [i_banks_before_j.Not()])

    def _add_bank_conflict_vars(self, model: cp_model.CpModel, placed, start_bank,
                                end_bank) -> list:
        """Return one bool per bank that is true iff more than one placed item uses it."""
        n = len(placed)

        item_uses_bank = []
        for i in range(n):
            row = []
            for b in range(self.num_bins):
                uses = model.new_bool_var(f'item_{i}_uses_bank_{b}')
                # An unplaced item uses no bank. Without this link, a
                # time-limited (FEASIBLE) solution could count phantom
                # occupants and over-report conflicts.
                model.add_implication(uses, placed[i])
                model.add(start_bank[i] <= b).only_enforce_if(uses)
                model.add(end_bank[i] >= b).only_enforce_if(uses)

                # Placed but not using bank b: b lies before the start bank or after the end bank.
                starts_after_b = model.new_bool_var(f'item_{i}_not_in_bank_{b}')
                model.add(start_bank[i] > b).only_enforce_if(
                    [placed[i], uses.Not(), starts_after_b])
                model.add(end_bank[i] < b).only_enforce_if(
                    [placed[i], uses.Not(), starts_after_b.Not()])

                row.append(uses)
            item_uses_bank.append(row)

        bank_conflict = []
        for b in range(self.num_bins):
            num_items_in_bank = model.new_int_var(0, n, f'num_items_bank_{b}')
            model.add(num_items_in_bank == sum(item_uses_bank[i][b] for i in range(n)))

            # Bank has conflict iff > 1 item
            has_conflict = model.new_bool_var(f'bank_{b}_conflict')
            model.add(num_items_in_bank > 1).only_enforce_if(has_conflict)
            model.add(num_items_in_bank <= 1).only_enforce_if(has_conflict.Not())
            bank_conflict.append(has_conflict)
        return bank_conflict

    def _make_solver(self) -> cp_model.CpSolver:
        """Create a CP-SAT solver configured with the time limit and worker settings."""
        solver = cp_model.CpSolver()
        solver.parameters.max_time_in_seconds = self.time_limit
        solver.parameters.log_search_progress = False

        # Use deterministic settings for pytest (env var PYTEST_CURRENT_TEST is set by pytest)
        # For production: use parallel search for better performance
        is_pytest = 'PYTEST_CURRENT_TEST' in os.environ
        if is_pytest:
            solver.parameters.num_workers = 1
            solver.parameters.random_seed = 42
        else:
            solver.parameters.num_workers = get_os_core_count()
            # Don't set random_seed for production - let solver explore freely

        logger.debug("CP-SAT solver config: num_workers=%s, random_seed=%s, is_pytest=%s",
                     solver.parameters.num_workers,
                     solver.parameters.random_seed if is_pytest else 'default',
                     is_pytest)
        return solver

    def _extract_result(self, solver: cp_model.CpSolver, placed, abs_start, start_bank,
                        end_bank, max_bank_used, bank_conflict) -> PackingResult:
        """Translate solver values into a PackingResult."""
        placements = []
        for i, item in enumerate(self.items):
            if not solver.value(placed[i]):
                continue
            abs_pos = solver.value(abs_start[i])
            s_bank = solver.value(start_bank[i])
            placements.append(ItemPlacement(
                item=item,
                start_bank=s_bank,
                end_bank=solver.value(end_bank[i]),
                offset_in_first_bank=abs_pos - s_bank * self.bin_capacity,
                bank_capacity=self.bin_capacity
            ))

        # Waste is measured against every bank up to the highest one used.
        num_banks_actually_used = solver.value(max_bank_used)
        total_item_size = sum(p.item.size for p in placements)
        total_waste = num_banks_actually_used * self.bin_capacity - total_item_size

        # Empty when conflict minimization is disabled -> reported as 0.
        num_conflicts = sum(solver.value(c) for c in bank_conflict)

        return PackingResult(
            placements=placements,
            bank_capacity=self.bin_capacity,
            num_banks=self.num_bins,
            num_banks_used=num_banks_actually_used,
            num_items_placed=len(placements),
            total_waste=total_waste,
            bank_conflicts=num_conflicts
        )
