"""Common functions for tilers"""

from typing import Callable, Optional
from collections.abc import Iterable

from utils.utils_common import (
    BaseMappingWithL1,
    ceildiv,
    SPATIAL_SPLIT_MODES,
    BaseShape,
    BaseMapping,
    L1Alloc_from_dict
)

from scheduler.conv.conv_config_builders import (
    ConvMapping,
)

# TODO: define a base class for tilers.
# This module provides the shared mapping-generation and mapping-sorting functions.


def sorted_mappings(
    shape: BaseShape, mappings: list[BaseMapping], enable_over_compute: bool
) -> list[BaseMapping]:
    """Sort mappings by projected cost, cheapest (best) mapping first.

    Each mapping is scored with a cost tuple (see ``mapping_key``). The
    result is in ascending order of that tuple, compared lexicographically,
    so the mapping with the lowest projected latency comes first. The input
    list is not modified; a new list is returned.

    Args:
        shape: Problem shape. ``ifm``/``ofm`` are unpacked as (Y, X, C) and
            ``kernel`` as (Ky, Kx).
        mappings: Candidate mappings. Each must have a populated ``l1_alloc``
            with ``"ifm"`` and ``"wgt"`` entries exposing ``is_double_buffered``.
        enable_over_compute: When True, the Ci-padding and over-compute-ratio
            terms are omitted from the cost tuple.

    Returns:
        A new list with the same mappings, best first.
    """

    def mapping_key(
        mapping: BaseMapping,
    ) -> tuple[bool, float, bool, int, float] | tuple[bool, float, int]:
        """
        Return the cost tuple for ``mapping``; smaller tuples sort first.

        Terms are listed in order of decreasing priority (earlier terms dominate):

                0. Input single buffered
                    True when neither the ifm nor the wgt L1 buffer is double
                    buffered. This prevents data movement and compute from being
                    overlapped in a significant way, so such mappings sort last.

                1. Data re-fetch ratio
                    In the data movement implementation, the outer loop will traverse the
                    Y / X axes of the image. This means the weight tensor traversal
                    along Cout is the fast moving dimension. Excess iterations along X / Y
                    will cause the tensor to be re-fetched incurring additional cost.

                2. Ci padding (only when ``enable_over_compute`` is False)
                    Some mappings may require padding along the Ci axis. This is
                    penalized to prefer mappings that do not require padding.

                3. Total loop count
                    Breaking the large problem into smaller sub-problems will have an
                    associated overhead at all levels of the stack. This quantity
                    accounts for the penalty of this excess control.

                4. Over-compute ratio (only when ``enable_over_compute`` is False)
                    Some mappings may compute additional data along a particular axis
                    in the spatial split, which will be discarded. This quantity
                    penalizes inefficient usage of MACs. Over-compute is partially
                    accounted for by the total loop count. This ratio can act as a tie-breaker
                    for cases where the loop count is the same.
        """

        Yi, Xi, Ci = shape.ifm
        Yo, Xo, Co = shape.ofm
        Ky, Kx = shape.kernel
        _, _, Cis = mapping.ifm_subv
        Yos, Xos, Cos = mapping.ofm_subv
        _, Y_split, X_split, Co_split = mapping.spatial_split

        Y_loop, X_loop, Co_loop, Ci_loop = mapping.iters

        is_input_single_buffered = (
            not mapping.l1_alloc["ifm"].is_double_buffered
            and not mapping.l1_alloc["wgt"].is_double_buffered
        )

        # Tensor sizes in bytes, as if each tensor were moved exactly once.
        wgt_subv_size = (Ky * Kx * Cis * Cos * mapping.wgt_bits) // 8
        ifm_size = (Yi * Xi * Ci * mapping.ifm_bits) // 8
        wgt_size = ((Co_loop * Co_split) * Ci_loop) * wgt_subv_size
        ofm_size = (Yo * Xo * Co * mapping.ofm_bits) // 8
        total_size = ifm_size + wgt_size + ofm_size
        # Weights are fetched again on every Y/X outer-loop iteration.
        fetch_size = ifm_size + (wgt_size * Y_loop * X_loop) + ofm_size
        refetch_ratio = fetch_size / total_size

        total_loop_count = Ci_loop * Y_loop * X_loop * Co_loop

        # MACs executed over the padded extents vs. MACs the layer requires.
        image_pad = (Yos * Y_split * Y_loop) * (Xos * X_split * X_loop)
        channel_pad = Ky * Kx * Cis * Ci_loop
        ch_out_pad = Cos * Co_split * Co_loop
        total_macs = Yo * Xo * Kx * Ky * Ci * Co
        computed_macs = image_pad * channel_pad * ch_out_pad
        overcompute_ratio = computed_macs / total_macs
        # True when Ci is not a multiple of the Ci subvolume and must be padded.
        needs_ci_pad = Ci % Cis != 0

        if not enable_over_compute:
            key = (
                is_input_single_buffered,
                refetch_ratio,
                needs_ci_pad,  # False (no padding) sorts first
                total_loop_count,
                overcompute_ratio,
            )
        else:
            key = (
                is_input_single_buffered,
                refetch_ratio,
                total_loop_count,
            )
        return key

    return sorted(mappings, key=mapping_key)


def _ifm_channels(shape: BaseShape, ifm_subv) -> tuple[int, int]:
    """Return (Ci, Cis): the full and per-subvolume input channel counts.

    For multi-input ops (``ifm_subv`` is a sequence of subvolumes, and
    ``shape.ifm`` a sequence of shapes) the largest channel count across the
    inputs is used for both values.
    """
    if isinstance(ifm_subv[0], Iterable):
        Ci = max(ifm[2] for ifm in shape.ifm)
        Cis = max(subv[2] for subv in ifm_subv)
    else:
        _, _, Cis = ifm_subv
        _, _, Ci = shape.ifm
    return Ci, Cis


def _select_best_l1_mapping(
    mappings: list[BaseMapping],
    allocate_L1_buffers: Callable[..., Optional[dict[str, tuple[int, int, int]]]],
) -> list[BaseMapping]:
    """Allocate L1 for already-sorted mappings and return the best one.

    Mappings are visited in the given order. The first mapping whose ifm AND
    wgt buffers are both double buffered is returned immediately. Otherwise
    the first ifm-double-buffered mapping is preferred, then the first
    wgt-double-buffered one, then the first mapping that allocated at all.

    Returns:
        A one-element list holding the chosen mapping, with ``l1_alloc`` filled in.

    Raises:
        ValueError: if no mapping could be allocated in L1.
    """
    best_single_buffered = None
    best_ifm_buffered = None
    best_wgt_buffered = None
    for mapping in mappings:
        l1_alloc = allocate_L1_buffers(mapping.ifm_subv, mapping.ofm_subv)
        if not l1_alloc:
            # Any falsy result (None or empty) means the mapping does not fit.
            continue
        # Convert the raw allocation dict explicitly when copying with the update.
        mapping = mapping.model_copy(update={"l1_alloc": L1Alloc_from_dict(l1_alloc)})
        ifm_double = mapping.l1_alloc["ifm"].is_double_buffered
        wgt_double = mapping.l1_alloc["wgt"].is_double_buffered
        if ifm_double and wgt_double:
            # Fully double buffered: cannot do better, stop allocating.
            return [mapping]
        if best_ifm_buffered is None and ifm_double:
            best_ifm_buffered = mapping
        if best_wgt_buffered is None and wgt_double:
            best_wgt_buffered = mapping
        if best_single_buffered is None:
            best_single_buffered = mapping

    for candidate in (best_ifm_buffered, best_wgt_buffered, best_single_buffered):
        if candidate is not None:
            return [candidate]
    raise ValueError(f"No valid L1 allocations found for {len(mappings)} mappings")


def generate_mappings(
    shape: BaseShape,
    kernel_granularities: tuple,
    ofm_subvs: list,
    is_split_valid: Callable[..., bool],
    get_input_subv: Callable[
        ..., list[tuple[int, int, int]] | list[list[tuple[int, int, int]]]
    ],
    allocate_L1_buffers: Callable[..., Optional[dict[str, tuple[int, int, int]]]],
    key: Optional[Callable] = None,
) -> list[BaseMapping]:
    """Find all valid mappings for the given shape and constraints.

    Every (spatial split, OFM subvolume, IFM subvolume) combination accepted by
    ``is_split_valid`` and produced by ``get_input_subv`` becomes a candidate
    ``BaseMappingWithL1``, with loop counts and padded OFM extents derived from
    ``shape``.

    Without ``key``, L1 buffers are allocated for every candidate. Candidates
    for which the allocator returns ``None`` are dropped, and all remaining
    mappings are returned in generation order.

    With ``key``, the expensive L1 allocation is deferred. Candidates are built
    with an empty ``l1_alloc`` placeholder and sorted by ``key``, then allocated
    in that order until a fully double-buffered mapping is found. Only the
    single best mapping is returned (see ``_select_best_l1_mapping``).

    Args:
        shape: Problem shape. ``shape.ofm`` is unpacked as (Yo, Xo, Co).
            ``shape.ifm`` is (Yi, Xi, Ci), or a sequence of those for
            multi-input ops.
        kernel_granularities: Forwarded to ``get_input_subv`` and stored as
            ``kernel_gran`` on each mapping.
        ofm_subvs: Candidate (Yos, Xos, Cos) OFM subvolumes.
        is_split_valid: Called as ``(shape.ofm, ofm_subv, split)``; candidates
            for which it returns False are skipped.
        get_input_subv: Called as ``(shape, ofm_subv, kernel_granularities)``;
            returns the IFM subvolume candidates for that OFM subvolume.
        allocate_L1_buffers: Called as ``(ifm_subv, ofm_subv)``; returns an
            allocation dict, or ``None`` when the candidate does not fit.
        key: Optional sort key that enables the lazy "best mapping only" mode.

    Returns:
        All valid mappings, or a single-element list when ``key`` is given.

    Raises:
        ValueError: with ``key``, if no candidate could be allocated in L1.
    """
    Yo, Xo, Co = shape.ofm
    mappings = []
    for split in SPATIAL_SPLIT_MODES:
        _, Yo_split, Xo_split, Co_split = split
        for ofm_subv in ofm_subvs:
            if not is_split_valid(shape.ofm, ofm_subv, split):
                continue
            Yos, Xos, Cos = ofm_subv
            ifm_subv_list = get_input_subv(shape, ofm_subv, kernel_granularities)
            for ifm_subv in ifm_subv_list:
                if key:
                    # Sort first; L1 allocation happens lazily after sorting.
                    l1_alloc = {}
                else:
                    l1_alloc = allocate_L1_buffers(ifm_subv, ofm_subv)
                    if l1_alloc is None:
                        continue
                Ci, Cis = _ifm_channels(shape, ifm_subv)
                Yo_loop = ceildiv(Yo, (Yos * Yo_split))
                Xo_loop = ceildiv(Xo, (Xos * Xo_split))
                Co_loop = ceildiv(Co, (Cos * Co_split))
                Ci_loop = ceildiv(Ci, Cis)
                mappings.append(
                    BaseMappingWithL1(
                        # OFM extent actually computed, including over-compute.
                        ofm_pad=(
                            Yo_loop * Yos * Yo_split,
                            Xo_loop * Xos * Xo_split,
                            Co_loop * Cos * Co_split,
                        ),
                        ifm_pad=shape.ifm,
                        ifm_subv=ifm_subv,
                        ofm_subv=ofm_subv,
                        spatial_split=split,
                        iters=(Yo_loop, Xo_loop, Co_loop, Ci_loop),
                        kernel_gran=kernel_granularities,
                        ifm_bits=shape.ifm_bits,
                        wgt_bits=shape.wgt_bits,
                        ofm_bits=shape.ofm_bits,
                        bias_bits=shape.bias_bits,
                        l1_alloc=l1_alloc,
                    )
                )
    if key:
        # Allocate L1 after sorting to short-circuit expensive allocations.
        return _select_best_l1_mapping(sorted(mappings, key=key), allocate_L1_buffers)
    return mappings


def create_conv_mappings_from_base(
    base_mappings: list[BaseMapping],
    ifm_L2_strategy: str = "pin",
    wgt_L2_strategy: str = "stream",
    ofm_L2_strategy: str = "stream",
) -> list[ConvMapping]:
    """Convert each BaseMapping into a ConvMapping using the given L2 strategies.

    Args:
        base_mappings: Mappings to convert; input order is preserved.
        ifm_L2_strategy: L2 strategy passed for the ifm of every mapping.
        wgt_L2_strategy: L2 strategy passed for the weights of every mapping.
        ofm_L2_strategy: L2 strategy passed for the ofm of every mapping.

    Returns:
        One ConvMapping per input mapping, built with
        ``ConvMapping.from_base_mapping``.
    """
    return [
        ConvMapping.from_base_mapping(
            base_mapping,
            ifm_L2_strategy=ifm_L2_strategy,
            wgt_L2_strategy=wgt_L2_strategy,
            ofm_L2_strategy=ofm_L2_strategy,
        )
        for base_mapping in base_mappings
    ]
