"""
Base Tiler utilities and constants for AIE4 operator tiling
Provides common functionality shared across tiler implementations
"""

import os
from functools import reduce
import math
import operator
from typing import Tuple, Any
from dataclasses import dataclass
from utils.utils_common import iceil, ceildiv
from kernel.common.kernel_params_helper import DimsHelper

# ============================================================================
# Hardware Configuration Constants
# ============================================================================


@dataclass(frozen=True)
class AIE4HardwareConfig:  # pylint: disable=invalid-name,too-many-instance-attributes
    """
    Immutable bundle of AIE-4 overlay and core-memory constants.

    Every field has a default, so ``AIE4HardwareConfig()`` yields the stock
    configuration. The dataclass is frozen, so instances are read-only.
    """

    # --- Overlay geometry ---
    AIE_COLS: int = 3
    AIE_ROWS: int = 4
    # Default core count for the stock 3 x 4 overlay (evaluates to 12).
    # This is a plain default computed once in the class body, not a derived
    # property: overriding AIE_COLS / AIE_ROWS at construction does not
    # update it.
    NUM_CORES: int = AIE_COLS * AIE_ROWS

    # --- Core memory ---
    CORE_BANK_SIZE: int = 32 * 1024  # 32 KB per bank
    CORE_BANK_MEM_SIZE_SOFTWARE: int = 16 * 1024  # 16 KB software accessible
    MEMORY_ALIGNMENT: int = 128  # Memory alignment for vector loads/stores


# Global hardware config instance.
# Built from the dataclass defaults; AIE4HardwareConfig is declared frozen, so
# this shared object cannot be mutated after import. compute_buffer_size reads
# HW_CONFIG.MEMORY_ALIGNMENT as its default alignment.
HW_CONFIG = AIE4HardwareConfig()


# ============================================================================
# Common Utility Functions
# ============================================================================


def compute_total_elements(*dimensions: int) -> int:
    """
    Multiply all given dimension sizes together

    Args:
        dimensions: Zero or more dimension sizes, passed positionally

    Returns:
        Product of the dimensions; 1 when no dimensions are given
    """
    element_count = 1
    for extent in dimensions:
        element_count = element_count * extent
    return element_count


def lcm(a: int, b: int) -> int:
    """
    Least common multiple of two integers

    Args:
        a: First integer
        b: Second integer

    Returns:
        Non-negative LCM of a and b; 0 when either argument is 0
        (including when both are 0)
    """
    # A hand-rolled abs(a * b) // gcd(a, b) divides by gcd(0, 0) == 0 when
    # both inputs are zero. math.lcm defines that case as 0 instead.
    return math.lcm(a, b)


def compute_subvolume_requirements(
    kernel_subv_requirement: int, memtile_subv_requirement: int
) -> int:
    """
    Smallest subvolume that satisfies both the kernel and memory tile granularity

    Args:
        kernel_subv_requirement: Minimum subvolume for kernel
        memtile_subv_requirement: Minimum subvolume for memory tile alignment

    Returns:
        LCM of the two requirements, as computed by lcm()
    """
    # A size usable by both sides must be a multiple of each requirement;
    # the smallest such size is their least common multiple.
    smallest_common = lcm(kernel_subv_requirement, memtile_subv_requirement)
    return smallest_common


def compute_max_subvolume(
    available_memory: int, element_bytes: int, min_subvolume: int
) -> int:
    """
    Largest multiple of min_subvolume whose elements fit in available memory

    Args:
        available_memory: Available memory in bytes
        element_bytes: Size of each element in bytes
        min_subvolume: Minimum subvolume requirement (granularity of the result)

    Returns:
        Element count rounded down to a multiple of min_subvolume; 0 when not
        even one min_subvolume fits

    Raises:
        ValueError: If min_subvolume is zero
        ZeroDivisionError: If element_bytes is zero
    """
    if min_subvolume == 0:
        raise ValueError("min_subvolume cannot be zero")

    # How many whole elements the memory can hold
    element_capacity = available_memory // element_bytes
    # How many whole min_subvolume chunks those elements make up
    whole_chunks = element_capacity // min_subvolume
    return whole_chunks * min_subvolume


def distribute_work_uniform(
    total_work: int, subvolume: int, num_cores: int
) -> Tuple[int, int, bool]:
    """
    Distribute work across cores with equal subvolume per core

    Each iteration processes up to ``num_cores * subvolume`` elements. Any
    leftover work is handled by one extra iteration that may use fewer cores
    and may leave its last active core with less than a full subvolume.

    Args:
        total_work: Total amount of work (elements)
        subvolume: Subvolume size per core per iteration
        num_cores: Number of available cores

    Returns:
        Tuple of (total_iterations, active_cores_last_iter, partial_last_iter).
        Returns (0, 0, False) when subvolume or num_cores is not positive,
        since no schedule can be formed.
    """
    # Without a positive chunk size and at least one core there is nothing to
    # schedule. Guarding num_cores also avoids dividing by a zero
    # work_per_iter below.
    if subvolume <= 0 or num_cores <= 0:
        return (0, 0, False)

    work_per_iter = num_cores * subvolume
    full_iterations, remainder = divmod(total_work, work_per_iter)

    if remainder == 0:
        # Work divides evenly: every iteration uses all cores on full subvolumes
        return (full_iterations, num_cores, False)

    # Trailing iteration: only the cores needed to cover the leftover are
    # active (ceildiv comes from utils.utils_common; presumed ceiling division).
    active_cores_last = ceildiv(remainder, subvolume)
    # The last active core is short-changed when the leftover is not a whole
    # number of subvolumes.
    partial_last = (remainder % subvolume) != 0
    return (full_iterations + 1, active_cores_last, partial_last)


def validate_subvolume(subvolume: int, min_subvolume: int, max_subvolume: int) -> None:
    """
    Check a proposed subvolume against its bounds and granularity

    Checks run in this order: lower bound, upper bound, then divisibility by
    min_subvolume (skipped when min_subvolume is not positive). The first
    failing check determines the error message.

    Args:
        subvolume: Proposed subvolume
        min_subvolume: Minimum allowed subvolume
        max_subvolume: Maximum allowed subvolume

    Raises:
        ValueError: If subvolume is invalid
    """
    if subvolume < min_subvolume:
        problem = f"Subvolume {subvolume} < minimum {min_subvolume}"
    elif subvolume > max_subvolume:
        problem = f"Subvolume {subvolume} > maximum {max_subvolume}"
    elif min_subvolume > 0 and subvolume % min_subvolume != 0:
        problem = f"Subvolume {subvolume} not multiple of {min_subvolume}"
    else:
        # All checks passed
        return
    raise ValueError(problem)


def compute_buffer_size(
    dimensions: Tuple[int, ...],
    element_bits: int,
    alignment: int = HW_CONFIG.MEMORY_ALIGNMENT,
) -> int:
    """
    Compute aligned buffer size for given dimensions

    Args:
        dimensions: Tensor dimensions (e.g., (Y, X, C))
        element_bits: Bits per element (need not be a multiple of 8)
        alignment: Memory alignment requirement in bytes

    Returns:
        Buffer size in bytes: the payload rounded up to whole bytes, then
        passed through iceil() with the given alignment
    """
    total_bits = compute_total_elements(*dimensions) * element_bits
    # Round UP to whole bytes. Flooring would drop a trailing partial byte for
    # sub-byte element types (e.g. 1028 bits would become 128 bytes instead of
    # 129), which can under-size the buffer when the floored value already
    # sits on an alignment boundary.
    size_bytes = (total_bits + 7) // 8
    # iceil comes from utils.utils_common; presumed to round size_bytes up to
    # a multiple of alignment.
    return iceil(size_bytes, alignment)


def compute_overcompute_ratio(
    actual_dims: Tuple[int, ...], padded_dims: Tuple[int, ...]
) -> float:
    """
    Measure how much extra computation padding introduces

    Args:
        actual_dims: Actual tensor dimensions
        padded_dims: Padded tensor dimensions

    Returns:
        padded_volume / actual_volume, or infinity when the actual volume is
        zero (regardless of the padded volume)
    """
    # Volumes are computed in the same order as the arguments: actual, padded
    actual_volume, padded_volume = (
        compute_total_elements(*dims) for dims in (actual_dims, padded_dims)
    )
    return padded_volume / actual_volume if actual_volume != 0 else float("inf")


def compute_refetch_ratio(
    ifm_size: int, wgt_size: int, ofm_size: int, loop_counts: Tuple[int, ...]
) -> float:
    """
    Ratio of bytes fetched (with weight refetches) to unique bytes

    Only the weight tensor is modelled as refetched: when loop_counts has at
    least three entries, the weights are counted once per combination of
    loop_counts[1] and loop_counts[2]. Input and output maps are counted once.

    Args:
        ifm_size: Input feature map size in bytes
        wgt_size: Weight tensor size in bytes
        ofm_size: Output feature map size in bytes
        loop_counts: Tuple of loop iteration counts; indices 1 and 2 are used
            as the Y and X loop counts

    Returns:
        Ratio of total fetched data to unique data; 1.0 when the unique data
        size is zero
    """
    unique_bytes = ifm_size + wgt_size + ofm_size

    if len(loop_counts) < 3:
        # Too few loop levels to identify Y/X loops: weights are fetched once
        weight_bytes_fetched = wgt_size
    else:
        # Weight is refetched for each Y, X iteration
        y_iters = loop_counts[1]
        x_iters = loop_counts[2]
        weight_bytes_fetched = wgt_size * y_iters * x_iters

    fetched_bytes = ifm_size + weight_bytes_fetched + ofm_size

    return fetched_bytes / unique_bytes if unique_bytes != 0 else 1.0


def align_up(addr: int, alignment: int) -> int:
    """Round addr up to the next multiple of alignment.

    Args:
        addr: The address to align.
        alignment: The alignment boundary. Values of 1 or less disable
            alignment.

    Returns:
        The smallest multiple of alignment that is >= addr, or addr unchanged
        when alignment <= 1.
    """
    if alignment > 1:
        # Bump past the boundary, then truncate back down to a multiple
        bumped = addr + alignment - 1
        whole_blocks = bumped // alignment
        return whole_blocks * alignment
    return addr


def get_os_core_count() -> int:
    """
    Number of CPU cores usable by this process for tiling computations.

    Uses the scheduler affinity mask of the current process when
    os.sched_getaffinity exists. Otherwise falls back to os.cpu_count(),
    and to 1 if that is unavailable.
    """
    try:
        usable_cpus = len(os.sched_getaffinity(0))
    except AttributeError:
        # os.sched_getaffinity is missing here; use the machine-wide count
        machine_cpus = os.cpu_count()
        return machine_cpus if machine_cpus else 1
    return usable_cpus


def _prod_func(x):
    """
    Product function that works with both standard iterables and pydantic models.
    When iterating over pydantic models, they yield (field_name, value) tuples,
    so we need to extract just the values.
    """
    try:
        # First check if it's a plain iterable of numbers
        return math.prod(x)
    except TypeError:
        # If that fails, try extracting values from (name, value) tuples
        # This handles pydantic BaseModel iteration which yields tuples
        try:
            values = [v for k, v in x]
            return math.prod(values)
        except (TypeError, ValueError) as e:
            raise TypeError(
                f"prod() argument must be an iterable of numbers, got {type(x).__name__}"
            ) from e


def _ceil_func(x: int | float, divisor=None):
    """
    Ceil func that works with multiple arguments
    """
    if divisor is None:
        return math.ceil(x)

    assert isinstance(x, int)
    return iceil(int(x), int(divisor))


# Global namespace handed to eval() by npeval. Together with the
# caller-supplied locals, these are the only names an expression can resolve.
_eval_syms = {
    # Present-but-None so eval() does not insert the real builtins module;
    # the few builtins expressions need are re-exposed explicitly below.
    "__builtins__": None,
    # Math functions and constants (from the math module)
    "log": math.log,
    "log10": math.log10,
    "log2": math.log2,
    "exp": math.exp,
    "pow": math.pow,  # math.pow, not the builtin pow
    "sqrt": math.sqrt,
    "sin": math.sin,
    "asin": math.asin,
    "cos": math.cos,
    "acos": math.acos,
    "tan": math.tan,
    "atan": math.atan,
    "atan2": math.atan2,
    "tanh": math.tanh,
    "pi": math.pi,
    # Selected builtins
    "abs": abs,
    "min": min,
    "max": max,
    "round": round,
    "int": int,
    # Project helpers defined in this module
    "ceil": _ceil_func,  # ceil(x) or ceil(x, divisor)
    "floor": math.floor,
    "prod": _prod_func,  # accepts plain iterables or (name, value) pairs
    "DimsHelper": DimsHelper,
}


def npeval(expr: str, local_vars_dict: dict) -> Any:
    """
    Evaluate an expression string against the restricted symbol table.

    Args:
        expr: Python expression text to evaluate
        local_vars_dict: Names made available to the expression as locals

    Returns:
        Whatever value the expression evaluates to

    Note:
        Globals are limited to _eval_syms (builtins disabled), which narrows
        the available names but does not make eval() a sandbox. Only pass
        expressions from trusted sources.
    """
    allowed_globals = _eval_syms
    result = eval(expr, allowed_globals, local_vars_dict)  # pylint: disable=eval-used
    return result
