"""
Map broadcast shapes to the AIE-4 dataflow architecture.
External facing functions are documented below.

    generate_broadcast_mappings - enumerate the valid ways to map a broadcast
    shape onto the compute array and sort them by broadcast_mapping_key
    (preferred mappings first)
"""

import itertools
import math
from copy import deepcopy
from typing import Optional
from functools import lru_cache


from utils.utils_common import (
    BaseShape,
    BaseMapping,
    BaseMappingWithL1,
    iceil,
)

from dmacompiler import DevGen, set_dev_gen

from scheduler.conv.conv_config_builders import (
    is_split_valid,
)
from tiler.conv_tiler import allocate
from tiler.tiler_common import (
    generate_mappings,
)
from tiler.load_kernel_metadata import KernelMetadataForOp, OpType
from tiler.kernel_metadata.build_kernel_metadata_vars import (
    Broadcast2DParams,
    SubVolume2DParams,
    BroadcastQuantizationShifts,
)

# Select the AIE-4 device generation in dmacompiler. This runs as an
# import-time side effect of loading this module.
set_dev_gen(DevGen.Aie4)


class BroadcastShape(BaseShape):
    """Dataclass for Broadcast Shape.

    Extends BaseShape with the broadcast-specific fields below. The `ifm` and
    `ofm` attributes read elsewhere in this module are not declared here and
    are presumably inherited from BaseShape (confirm against its definition).
    """
    # Input element size in bytes. A value of 1 makes
    # broadcast_allocate_l1_buffers reserve additional temporary bf16 space.
    ifm_bytes: int
    # Output element size in bytes (not read in this module).
    ofm_bytes: int
    # Presumably the bit widths of the input, weight, output and bias
    # tensors; none of these are read in this module -- TODO confirm.
    ifm_bits: int
    wgt_bits: int
    ofm_bits: int
    bias_bits: int
    # Presumably signedness flags for activations (A), weights (W) and
    # output (O); not read in this module -- TODO confirm.
    sign_A: int
    sign_W: int
    sign_O: int
    # Operator name; get_loader matches it against the OpType values to pick
    # the kernel metadata loader.
    op_name: str
    # Not read in this module; meaning must be confirmed against the callers.
    b_on_wgt: int
    call_kernel: str
    # Truthy when the second input is a scalar broadcast: the last dimension
    # of its subvolume is then forced to 4 and the flag is forwarded (as a
    # bool) to Broadcast2DParams.
    has_scalar_broadcast: int


@lru_cache(maxsize=2)
def get_loader(current_op: str) -> KernelMetadataForOp:
    """Return the kernel metadata loader for the op named `current_op`.

    The name is compared against the `value` of every `OpType` member and the
    first match is wrapped in a `KernelMetadataForOp`. Results are memoized
    for the two most recently used op names.

    Raises:
        ValueError: if no `OpType` member has a value equal to `current_op`.
    """
    matching_op = next(
        (candidate for candidate in OpType if current_op == candidate.value),
        None,
    )
    if matching_op is None:
        raise ValueError(
            f"Could not find kernel metadata loader for op {current_op}")
    return KernelMetadataForOp(matching_op)


def compute_paddings(fm, subv, split, iters) -> list[int]:
    """Compute, per dimension, the padding inside the last partially used subvolume.

    For each dimension of the full feature map `fm`, the padded extent is
    `subv * split * iters`. Subvolumes that consist entirely of padding (idle
    cores) are not counted: only the padding that shares a subvolume with
    valid data is reported.

    Args:
        fm: true extent of each feature-map dimension.
        subv: subvolume extent for each dimension.
        split: number of spatial splits for each dimension.
        iters: number of iterations for each dimension. May be longer than
            `fm`; only the first `len(fm)` entries are read.

    Returns:
        One padding value per dimension of `fm`. The value is negative when
        the padded extent does not cover the true extent.
    """
    paddings = []
    for dim_idx, true_dim in enumerate(fm):
        step = subv[dim_idx]
        excess = step * split[dim_idx] * iters[dim_idx] - true_dim
        # Drop every trailing subvolume that is pure padding. A subvolume is
        # pure padding as soon as the excess is at least one full step, which
        # includes the case where the excess is an exact multiple of the step
        # (previously one idle subvolume was left in the result in that case).
        # The padded extent is a multiple of `step`, so repeatedly subtracting
        # `step` is the same as taking the remainder.
        if step > 0 and excess >= step:
            excess %= step
        paddings.append(excess)
    return paddings


def generate_ofm_subvs(ofm: tuple[int, ...], total_elements=4096):
    """Enumerate candidate ofm subvolume shapes holding at most `total_elements` elements.

    Per-dimension candidate sizes are 1, 2, 4, 8, 16 and every multiple of 32
    up to `total_elements`, limited to the extent of that ofm dimension.
    The last dimension only keeps multiples of 64, and every dimension except
    the last two is pinned to 1.

    Args:
        ofm: extents of the output feature map, one per dimension.
        total_elements: upper bound on the element count of a subvolume.

    Returns:
        A list of unique subvolume shapes (tuples with one entry per ofm
        dimension), in no particular order.
    """
    candidate_sizes = [1, 2, 4, 8, 16, *range(32, total_elements + 1, 32)]
    per_dim_choices = [
        [size for size in candidate_sizes if size <= extent] for extent in ofm
    ]
    # The innermost dimension is only tiled in multiples of 64.
    per_dim_choices[-1] = [size for size in per_dim_choices[-1] if size % 64 == 0]
    # If a broadcast is of the form A=(Y, X, C) and B=(Y, 1, C) the subvolume
    # along Y and any higher dimension must be 1: the kernels handle
    # broadcasting along the lower dimensions and the dataflow handles it
    # along the higher ones. This is safe because it happens after folding.
    leading_dims = max(len(ofm) - 2, 0)
    per_dim_choices[:leading_dims] = [[1] for _ in range(leading_dims)]

    unique_subvs = {
        combo
        for combo in itertools.product(*per_dim_choices)
        if math.prod(combo) <= total_elements
    }
    return list(unique_subvs)


def broadcast_mapping_key(mapping: BaseMappingWithL1, shape: BroadcastShape):
    """Build the sort key used to rank broadcast mappings.

    The key is a tuple compared element by element, so with an ascending sort
    the preferred mappings are those that, in order:

    1. keep the padding of every ofm dimension below 64,
    2. have every input subvolume with an element count divisible by 4,
    3. have the smallest product of the Y, X and Ci iteration counts.

    Args:
        mapping: candidate mapping; its `iters`, `ifm_subv`, `ofm_subv` and
            `spatial_split` attributes are read.
        shape: broadcast shape; only `shape.ofm` is read.

    Returns:
        A `(bool, bool, int)` tuple.
    """
    loops_y, loops_x, _, loops_ci = mapping.iters

    has_unaligned_ifm = any(
        math.prod(ifm_subv) % 4 != 0 for ifm_subv in mapping.ifm_subv
    )
    loop_product = loops_ci * loops_y * loops_x

    # The first entry of spatial_split is skipped so that the remaining
    # entries line up with the ofm dimensions.
    paddings = compute_paddings(
        shape.ofm,
        mapping.ofm_subv,
        mapping.spatial_split[1:],
        mapping.iters,
    )
    has_large_padding = not all(pad < 64 for pad in paddings)

    # We can complicate this if need be to improve tiling
    return (has_large_padding, has_unaligned_ifm, loop_product)


def broadcast_is_split_valid(
    ofm_shape: tuple[int, int, int],
    ofm_subv: tuple[int, int, int],
    split: tuple[int, int, int, int],
    enable_over_compute: bool,
) -> bool:
    """Report whether `split` is a valid spatial split for a broadcast op.

    Broadcast has no rules of its own here: the decision is delegated
    unchanged to the conv scheduler's `is_split_valid`.

    Args:
        ofm_shape: extents of the output feature map.
        ofm_subv: extents of the output subvolume.
        split: spatial split to check.
        enable_over_compute: forwarded as-is to `is_split_valid`.

    Returns:
        The result of `is_split_valid` for the given arguments.
    """
    split_ok = is_split_valid(ofm_shape, ofm_subv, split, enable_over_compute)
    return split_ok


def broadcast_get_input_subv(
    shape: BroadcastShape,
    ofm_subv: tuple[int, int, int],
    _: tuple,  # kernel granularities, unused
) -> list[tuple[int, int, int]] | list[list[tuple[int, int, int]]]:
    """Derive the two input subvolume shapes that produce `ofm_subv`.

    Each input subvolume is the ofm subvolume clamped, dimension by
    dimension, to the extent of that input, so a broadcast dimension of
    extent 1 stays 1.

    Args:
        shape: broadcast shape; `shape.ifm` must hold exactly two inputs.
        ofm_subv: output subvolume shape.
        _: kernel granularities, unused.

    Returns:
        A list with a single candidate, itself a two-element list
        `[ifm1_subv, ifm2_subv]`.
    """
    assert len(shape.ifm) == 2, "Broadcast shape must have 2 input tensors"

    def clamp_to_ofm_subv(ifm_dims):
        return tuple(
            min(ifm_dims[axis], limit) for axis, limit in enumerate(ofm_subv)
        )

    first_subv = clamp_to_ofm_subv(shape.ifm[0])
    second_subv = clamp_to_ofm_subv(shape.ifm[1])
    if shape.has_scalar_broadcast:
        # c=1 must be 32-bit aligned
        second_subv = (second_subv[0], second_subv[1], 4)
    return [[first_subv, second_subv]]


def broadcast_allocate_l1_buffers(
    ifm_subvs: list[tuple[int, int, int]],  # ifm subvs
    ofm_subv: tuple[int, int, int],
    shape: BroadcastShape,
) -> Optional[dict[str, tuple[int, int, int | None]]]:
    """Size and allocate the L1 buffers for one broadcast subvolume.

    The input, weight (qdq parameters + scratch) and output buffer sizes are
    computed from the subvolume shapes, the kernel parameters are validated
    against the kernel metadata of `shape.op_name`, and the sizes are handed
    to `allocate`.

    Args:
        ifm_subvs: subvolume shapes of the two inputs; a broadcast input has
            extent 1 along its broadcast dimensions.
        ofm_subv: output subvolume shape.
        shape: broadcast shape; `ifm_bytes`, `has_scalar_broadcast` and
            `op_name` are read.

    Returns:
        The result of `allocate`, or None when the kernel metadata validation
        rejects the parameters (the failure is printed).
    """
    in0_subv, in1_subv = ifm_subvs[0], ifm_subvs[1]

    out_elems = math.prod(ofm_subv)
    in0_elems = math.prod(in0_subv)
    in1_elems = math.prod(in1_subv)

    # dq has loop range of 12 for 32 elements; this much space is needed
    # between ifm A and B to prevent overwriting. Data is upscaled to bf16,
    # so the element size is always 2 bytes.
    dq_gap_bytes = 12 * 32 * 2

    if shape.ifm_bytes == 1:
        # 1-byte inputs additionally need temp buffers for the bf16 copy.
        ifm_l1_bytes = dq_gap_bytes + sum(
            iceil(elems, 128) + iceil(elems * 2, 128)
            for elems in (in0_elems, in1_elems)
        )
    else:
        ifm_l1_bytes = (
            iceil(in0_elems + in1_elems, 128) * shape.ifm_bytes + dq_gap_bytes
        )

    # 128 bytes for qdq parameters and 1024 bytes for qdq scratch buffer;
    # this can shrink later once the kernels are available.
    wgt_l1_bytes = 128 + 1024
    # always need bf16 output buffer
    ofm_l1_bytes = out_elems * 2

    no_shifts = BroadcastQuantizationShifts(I0=0, I1=0, O0=0)
    params = Broadcast2DParams(
        subvolume=SubVolume2DParams(R=in0_subv[1], C=in0_subv[2]),
        subvolume_in1=SubVolume2DParams(R=in1_subv[1], C=in1_subv[2]),
        R=ofm_subv[1],
        C=ofm_subv[2],
        has_scalar_broadcast=bool(shape.has_scalar_broadcast),
        quantization_shifts=no_shifts,
    )

    try:
        get_loader(shape.op_name).validate(params)
    except Exception as e:  # pylint: disable=broad-except
        print(
            f"Metadata validation failed: {e}. Params: {params} is invalid.")
        return None
    else:
        return allocate(ifm_l1_bytes, wgt_l1_bytes, ofm_l1_bytes, 0, 0, 0, 0)


def generate_broadcast_mappings(
    shape: BroadcastShape,
    enable_over_compute: bool,
    kernel_gran: int,
    _: int,  # kernel loop range, unused
) -> list[BaseMapping]:
    """Generate the valid mappings of a broadcast op, best candidates first.

    Candidate ofm subvolumes are enumerated from `shape.ofm`, the lower-rank
    input is left-padded with 1s so both inputs have the same number of
    dimensions, and the generic `generate_mappings` driver is run with the
    broadcast-specific callbacks. The result is sorted with
    `broadcast_mapping_key`.

    Args:
        shape: broadcast shape with exactly two inputs. It is deep-copied
            before being modified, so the caller's object is left untouched.
        enable_over_compute: forwarded to the split validity check.
        kernel_gran: kernel granularity, forwarded as a one-element tuple.
        _: kernel loop range, unused.

    Returns:
        The non-empty list of mappings sorted by `broadcast_mapping_key`.

    Raises:
        ValueError: if no valid mapping exists.
    """
    assert len(shape.ifm) == 2, "Broadcast shape must have 2 input tensors"

    ofm_subvs = generate_ofm_subvs(shape.ofm)
    # Make a copy of the shape so we can pad ifm dimensions with 1,
    # to ensure shapes are equal ndims
    shape = deepcopy(shape)
    # The rank difference must be computed as an integer before it is
    # compared: `x := a - b > 0` binds the boolean result of the comparison,
    # which would pad by at most one dimension regardless of the actual gap.
    rank_gap = len(shape.ifm[0]) - len(shape.ifm[1])
    if rank_gap > 0:
        shape.ifm[1] = [1] * rank_gap + list(shape.ifm[1])
    elif rank_gap < 0:
        shape.ifm[0] = [1] * -rank_gap + list(shape.ifm[0])

    # Create wrapper functions that capture the necessary variables
    def _broadcast_is_split_valid(
        ofm_shape: tuple[int, int, int],
        ofm_subv: tuple[int, int, int],
        split: tuple[int, int, int, int],
    ) -> bool:
        return broadcast_is_split_valid(ofm_shape, ofm_subv, split, enable_over_compute)

    def _broadcast_allocate_l1_buffers(
        ifm_subvs: list[tuple[int, int, int]],
        ofm_subv: tuple[int, int, int]
    ) -> Optional[dict[str, tuple[int, int, int]]]:
        return broadcast_allocate_l1_buffers(ifm_subvs, ofm_subv, shape)

    def _broadcast_mapping_key(mapping: BaseMappingWithL1):
        return broadcast_mapping_key(mapping, shape)

    # We don't need to make a BroadcastMapping
    base_mappings = generate_mappings(
        shape,
        (kernel_gran,),
        ofm_subvs,
        _broadcast_is_split_valid,
        broadcast_get_input_subv,
        _broadcast_allocate_l1_buffers,
        key=_broadcast_mapping_key,
    )
    sorted_mappings = sorted(base_mappings, key=_broadcast_mapping_key)
    if not sorted_mappings:
        raise ValueError("No valid mappings found")
    return sorted_mappings
