"""
Map depthwise conv (DWC) shapes to the AIE-4 dataflow architecture.
External facing functions are documented below.

    generate_dwc_mappings - enumerate all possible ways to map a DWC shape
    onto the compute array and sort them by estimated cost, cheapest
    (fastest) mappings first
"""

from typing import Optional

from utils.utils_common import (
    ceildiv,
    iceil,
)
from scheduler.common import (
    LinearOpType,
)

from scheduler.conv.conv_config_builders import (
    ConvShape,
    ConvMapping,
    conv_input,
)

from scheduler.dwc.dwc_common import (
    dwc_get_aligned_Xis,
)

from tiler.tiler_common import (
    generate_mappings,
    create_conv_mappings_from_base,
)

from tiler.l1_buffer_allocator import (
    L1BufferAllocator,
    BufferSpec,
    BufferPair
)

# Use shared_process_cache_factory for proper cross-process caching
# This is safe for Windows spawn mode - Manager created lazily on first call
from tiler.cache_decorators import shared_process_cache_factory

# Create the cache decorator applied to ``allocate`` below.
# ``shared_process_cache_factory`` returns (decorator, second object); the
# second is presumably the Manager backing the shared cross-process cache
# (TODO confirm against tiler.cache_decorators) and is not used by the code
# in this module. Per the note above the import, the Manager is only created
# on the first decorated call (first allocate() call), which keeps importing
# this module safe under Windows spawn mode.
_allocate_cache_decorator, _allocate_cache_manager = shared_process_cache_factory()


def sorted_dwc_mappings(shape: ConvShape, mappings: list[ConvMapping]) -> list[ConvMapping]:
    '''Return ``mappings`` ordered by a heuristic cost key, cheapest first.

    The key compares, most significant element first:
      1. whether BOTH the IFM and WGT L1 buffers are single-buffered
         (mappings with at least one double-buffered input sort first),
      2. the total outer loop count ``Y_loop * X_loop * Co_loop`` (fewer first),
      3. the OFM subvolume volume ``Yos * Xos * Cos`` (smaller first),
      4. ``Yos``, then ``Xos``, then ``Cos`` (smaller first) as tie-breakers.

    The sort is stable, so mappings with equal keys keep their input order.
    ``shape`` is accepted but not used. A new list is returned and
    ``mappings`` itself is left untouched.
    '''
    del shape  # unused

    def cost(candidate: ConvMapping) -> tuple[bool, int, int, int, int, int]:
        """Cost tuple for one mapping; earlier elements dominate later ones."""
        y_sub, x_sub, c_sub = candidate.ofm_subv
        y_iter, x_iter, co_iter, _ = candidate.iters
        alloc = candidate.l1_alloc
        # True only when neither input buffer got a pong (double) buffer.
        single_buffered_inputs = not (
            alloc["ifm"].is_double_buffered or alloc["wgt"].is_double_buffered
        )
        return (
            single_buffered_inputs,
            y_iter * x_iter * co_iter,
            y_sub * x_sub * c_sub,
            y_sub,
            x_sub,
            c_sub,
        )

    ordered = list(mappings)
    ordered.sort(key=cost)
    return ordered


@_allocate_cache_decorator
def allocate(ifm_L1_size: int, wgt_L1_size: int, ofm_L1_size: int, tdm_L1_size: int, vec_L1_size: int, qdq_L1_size: int):
    """Place the DWC L1 buffers using the CP-SAT based allocator.

    Nine buffers are requested, highest priority first: the IFM, WGT, OFM
    ping buffers (priorities 8, 7, 6); the single-instance TDM, VEC, QDQ
    buffers (5, 4, 3); then the WGT, IFM, OFM pong buffers (2, 1, 0).
    Each tensor's ping/pong pair must not share banks, and TDM must not share
    banks with either IFM buffer.

    The function is wrapped by ``_allocate_cache_decorator`` (built by
    ``shared_process_cache_factory``), so repeated calls go through that
    cache (presumably keyed on the arguments - confirm in cache_decorators).

    Returns:
        ``None`` if ``allocate_cpsat`` returns ``None``. Otherwise a dict with
        keys ``"ifm"``, ``"wgt"``, ``"ofm"``, ``"tdm"``, ``"vec"``, ``"qdq"``,
        each mapped to ``(size, ping_addr, pong_addr)``. Only ``ifm`` and
        ``wgt`` carry the address looked up for their pong buffer; the other
        entries always use ``None`` for ``pong_addr``.
    """
    allocator = L1BufferAllocator()

    # (name, size, BufferSpec role flag, priority), listed from highest to
    # lowest priority; this is also the order the specs are handed to the solver.
    layout = (
        ("IFM", ifm_L1_size, "is_ping", 8),
        ("WGT", wgt_L1_size, "is_ping", 7),
        ("OFM", ofm_L1_size, "is_ping", 6),
        ("TDM", tdm_L1_size, "is_ping", 5),
        ("VEC", vec_L1_size, "is_ping", 4),
        ("QDQ", qdq_L1_size, "is_ping", 3),
        ("WGT_pong", wgt_L1_size, "is_pong", 2),
        ("IFM_pong", ifm_L1_size, "is_pong", 1),
        ("OFM_pong", ofm_L1_size, "is_pong", 0),
    )
    specs = {
        name: BufferSpec(name, size, priority=prio, **{role: True})
        for name, size, role, prio in layout
    }

    # Buffer pairs that must not be placed in the same L1 bank.
    no_shared_bank = (
        ("IFM", "IFM_pong"),
        ("WGT", "WGT_pong"),
        ("OFM", "OFM_pong"),
        ("IFM", "TDM"),
        ("IFM_pong", "TDM"),
    )
    exclusions = [BufferPair(specs[first], specs[second]) for first, second in no_shared_bank]

    solution = allocator.allocate_cpsat(list(specs.values()), exclusions)
    if solution is None:
        return None

    addr = solution.get_addr
    # NOTE(review): OFM_pong is requested from the solver above, but its
    # address is not reported here ("ofm" always carries None as pong_addr).
    # Confirm whether reserving that space is intentional.
    return {
        "ifm": (ifm_L1_size, addr("IFM"), addr("IFM_pong")),
        "wgt": (wgt_L1_size, addr("WGT"), addr("WGT_pong")),
        "ofm": (ofm_L1_size, addr("OFM"), None),
        "tdm": (tdm_L1_size, addr("TDM"), None),
        "vec": (vec_L1_size, addr("VEC"), None),
        "qdq": (qdq_L1_size, addr("QDQ"), None),
    }


def _dwc_subv_candidates(extent: int, gran: int) -> list[int]:
    """Return the multiples of ``gran`` in ``[gran, extent]``, ascending.

    When ``extent < gran`` there is no such multiple; ``[gran]`` is returned
    instead so the dimension still contributes one candidate (the subvolume
    is then larger than the tensor along that dimension).
    """
    return list(range(gran, extent + 1, gran)) or [gran]


def generate_dwc_mappings(shape: ConvShape, enable_over_compute: bool) -> list[ConvMapping]:
    '''Generate all candidate DWC mappings for ``shape``, cheapest first.

    Args:
        shape: Conv shape to map. Only ``LinearOpType.dwc_A16W8_qdq`` is
            supported.
        enable_over_compute: Currently ignored.

    Returns:
        The mappings produced by ``generate_mappings`` (with ``Ci_loop``
        forced to 1), expanded by ``create_conv_mappings_from_base`` with the
        'stream' L2 strategy for IFM, WGT and OFM, then ordered by
        ``sorted_dwc_mappings``.

    Raises:
        ValueError: if ``shape.linear_op_type`` is not supported.
    '''
    _ = enable_over_compute
    if shape.linear_op_type == LinearOpType.dwc_A16W8_qdq:
        Yo, Xo, Cout = shape.ofm

        # OFM subvolume granularities.
        # NOTE(review): earlier comments here stated the constraints as
        # "Yos % 2 == 0, Xos % 4 == 0", i.e. the Yos/Xos values swapped
        # relative to the code below. The code values are kept; confirm
        # against the DWC A16W8 kernel which assignment is intended.
        Yos_gran = 4
        Xos_gran = 2
        Cos_gran = 64

        # Enumerate every (Yos, Xos, Cos) whose components are multiples of
        # their granularity and do not exceed the output extent (Yos outer,
        # Cos inner). A dimension smaller than its granularity falls back to
        # just its granularity on its own, so the other dimensions are still
        # fully enumerated - e.g. a tiny spatial output keeps every Cos
        # choice instead of collapsing to one (largest-Cos) subvolume that
        # may not fit in L1.
        ofm_subvs = [
            (Yos, Xos, Cos)
            for Yos in _dwc_subv_candidates(Yo, Yos_gran)
            for Xos in _dwc_subv_candidates(Xo, Xos_gran)
            for Cos in _dwc_subv_candidates(Cout, Cos_gran)
        ]

        kernel_gran = (64, 64)  # Co_gran, Ci_gran
        ifm_bits = 16
        wgt_bits = 8
        bias_bits = 32
        ofm_bits = 16
    else:
        raise ValueError("Unsupported DWC type for DWC mapping generation in tiler")

    # NOTE: This is the memory alignment required for vector loads and stores
    memory_align = 128

    Ky, Kx = shape.kernel

    def allocate_L1_buffers(
        ifm_subv: tuple[int, int, int],
        ofm_subv: tuple[int, int, int],
    ) -> Optional[dict]:
        '''Size the L1 buffers for one IFM/OFM subvolume pair and place them.

        Makes a single call to ``allocate`` and returns its address map, or
        ``None`` when ``allocate`` could not place the buffers. Whether the
        pong buffers end up placed is decided inside ``allocate``.
        '''
        Yis, Xis, Cis = ifm_subv
        Yos, Xos, Cos = ofm_subv
        ifm_L1_size = iceil((Yis * Xis * Cis * ifm_bits) // 8, memory_align)
        ofm_L1_size = iceil((Yos * Xos * Cos * ofm_bits) // 8, memory_align)
        # No separate TDM / VEC / QDQ buffers for this op type; the QDQ
        # parameter space (qdq_param_size) is counted in the weight buffer.
        qdq_L1_size = 0
        tdm_L1_size = 0
        vec_L1_size = 0
        if shape.linear_op_type == LinearOpType.dwc_A16W8_qdq:
            no_vec_coeff = 2
            Ky_gran = 3
            Kx_gran = 4
            # no_vec_coeff 32-bit values per output channel.
            bias_buffer_size = iceil((no_vec_coeff * Cos * bias_bits // 8), memory_align)
            qdq_param_size = 128
            # Kernel taps padded up to at least Ky_gran x Kx_gran per channel.
            filter_buffer_size = iceil((max(Kx, Kx_gran) * max(Ky, Ky_gran) * Cos * wgt_bits // 8), memory_align)
            wgt_L1_size = filter_buffer_size + bias_buffer_size + qdq_param_size
        else:
            raise ValueError("Unsupported DWC type for dwc mapping generation in tiler")

        return allocate(ifm_L1_size, wgt_L1_size, ofm_L1_size, tdm_L1_size, vec_L1_size, qdq_L1_size)

    def dwc_is_split_valid(
        ofm_shape: tuple[int, int, int],
        ofm_subv: tuple[int, int, int],
        split: tuple[int, int, int, int],
    ) -> bool:
        '''Return True if ``split`` is usable for ``ofm_subv`` on ``ofm_shape``.

        ``split`` is unpacked as ``(_, Y_split, X_split, Co_split)``; its first
        entry is ignored. The split is accepted when every loop count is at
        least 1 and ``Co_split`` is one of the supported values.
        '''
        Yo, Xo, Co = ofm_shape
        Yos, Xos, Cos = ofm_subv
        _, Y_split, X_split, Co_split = split
        Y_loop = ceildiv(Yo, (Yos * Y_split))
        X_loop = ceildiv(Xo, (Xos * X_split))
        Co_loop = ceildiv(Co, (Cos * Co_split))
        Co_splits_supported = [1, 2, 6, 3]
        # NOTE: Co split modes 4 and 12 are deliberately excluded. With them,
        # each core within a column would read a unique Cos block; since both
        # IFM and WGT are Co-unique for DWC, both would have to be unicast.
        # Time-sharing one channel between two tensors is nearly impossible
        # and breaks HW pipelining.
        Co_split_valid = Co_split in Co_splits_supported
        loops_valid = (Y_loop >= 1) and (X_loop >= 1) and (Co_loop >= 1)
        return loops_valid and Co_split_valid

    def dwc_get_input_subv(
        shape: ConvShape,
        ofm_subv: tuple[int, int, int],
        kernel_gran: tuple[int, int]
    ) -> list[tuple[int, int, int]]:
        '''Return the IFM subvolume needed to produce ``ofm_subv``.

        Exactly one candidate is returned: Yis from ``conv_input`` for the Y
        kernel/stride, Xis from ``dwc_get_aligned_Xis`` (Kx granularity 4 for
        dwc_A16W8_qdq, else 0), and Cis == Cos (depthwise: no cross-channel
        reduction). ``kernel_gran`` is unpacked but otherwise unused.
        '''
        Yos, Xos, Cos = ofm_subv
        Ky, Kx = shape.kernel
        Sy, Sx = shape.stride
        _, _ = kernel_gran
        Yis = conv_input(Yos, Ky, Sy)
        Kx_gran = 0
        if shape.linear_op_type == LinearOpType.dwc_A16W8_qdq:
            Kx_gran = 4
        Xis = dwc_get_aligned_Xis(Xos, Sx, Kx, Kx_gran)
        Cis = Cos
        return [(Yis, Xis, Cis)]

    base_mappings_cache = generate_mappings(
        shape,
        kernel_gran,
        ofm_subvs,
        dwc_is_split_valid,
        dwc_get_input_subv,
        allocate_L1_buffers,
    )

    # DWC has no input-channel accumulation loop, so force Ci_loop to 1 on
    # deep copies of the generated base mappings.
    base_mappings = []
    for mapping in base_mappings_cache:
        Y_loop, X_loop, Co_loop, _ = mapping.iters
        base_mappings.append(
            mapping.model_copy(
                update={'iters': (Y_loop, X_loop, Co_loop, 1)},
                deep=True,
            )
        )

    conv_mappings = create_conv_mappings_from_base(base_mappings,
                                                   ifm_L2_strategy='stream',
                                                   wgt_L2_strategy='stream',
                                                   ofm_L2_strategy='stream')
    return sorted_dwc_mappings(shape, conv_mappings)
