"""
Map conv shapes to the AIE-4 dataflow architecture.
External facing functions are documented below.

    generate_conv_mappings - enumerate all possible ways to map a conv shape
    onto the compute array and sort them by projected latency
    (fastest mappings first)
"""

from typing import Optional
from functools import lru_cache

from utils.utils_common import (
    iceil,
)
from scheduler.common import (
    LinearOpType,
)

from scheduler.conv.conv_config_builders import (
    ConvShape,
    ConvMapping,
    is_split_valid,
    conv_input,
    align_Xis,
)

from tiler.load_kernel_metadata import KernelMetadataForOp, OpType
from tiler.tiler_common import (
    generate_mappings,
    sorted_mappings,
    create_conv_mappings_from_base,
)

from tiler.l1_buffer_allocator import (
    L1BufferAllocator,
    BufferSpec,
    BufferPair
)

from tiler.kernel_metadata.build_kernel_metadata_vars import (
    conv_int8x8_params,
    ConvSubVolumeParams,
    ConvTimeSplitParams,
    conv_qdq_int16x8_params,
    SyncTypeParams,
    SyncType
)

# Use shared_process_cache_factory for proper cross-process caching
# This is safe for Windows spawn mode - Manager created lazily on first call
from tiler.cache_decorators import shared_process_cache_factory

# Create cache decorator with factory pattern (Manager created on first allocate() call)
# The factory result is unpacked into a (decorator, manager) pair; the decorator
# is applied to allocate() below.
# NOTE(review): _allocate_cache_manager is not referenced anywhere else in this
# view - presumably it is kept so the manager stays reachable; confirm.
_allocate_cache_decorator, _allocate_cache_manager = shared_process_cache_factory()


def sorted_conv_mappings(shape: ConvShape, mappings: list[ConvMapping], enable_over_compute: bool) -> list[ConvMapping]:
    '''Order conv mappings by projected latency.

    Conv-specific entry point for the shared sorted_mappings helper: the three
    arguments are forwarded unchanged and the helper's result is returned
    as-is.
    '''
    ordered = sorted_mappings(shape, mappings, enable_over_compute)
    return ordered


@lru_cache(maxsize=2)
def get_loader(linear_op_type: LinearOpType) -> KernelMetadataForOp:
    """Return the kernel metadata loader matching a conv op type.

    conv_A8W8_noqdq selects the BIASED_CONV_INT8X8 kernel metadata; every
    other op type falls through to ACTIVATED_CONV_QDQ_INT16X8. Results are
    memoized per op type by lru_cache (two entries), so repeated calls reuse
    the same loader object while it stays cached.
    """
    is_int8_conv = linear_op_type == LinearOpType.conv_A8W8_noqdq
    op_type = OpType.BIASED_CONV_INT8X8 if is_int8_conv else OpType.ACTIVATED_CONV_QDQ_INT16X8
    return KernelMetadataForOp(op_type)


@_allocate_cache_decorator
def allocate(
    ifm_L1_size: int,
    wgt_L1_size: int,
    ofm_L1_size: int,
    tdm_L1_size: int,
    wght_transpose_sb_L1_size: int,
    vec_L1_size: int,
    qdq_L1_size: int,
):
    """Place the conv L1 buffers and report their sizes and addresses.

    Every argument is the size of one L1 buffer (the caller in this file
    passes byte counts). IFM, WGT and OFM are each allocated twice (ping and
    pong); TDM, WGHT_T_SB, VEC and QDQ are allocated once.

    Returns:
        None when allocate_cpsat() finds no placement. Otherwise a dict keyed
        by "ifm", "wgt", "ofm", "tdm", "wght_t_sb", "vec" and "qdq", each
        holding a (size, ping_address, pong_address) tuple. The pong address
        is None for every entry except "ifm" and "wgt".

    NOTE(review): OFM_pong takes part in the allocation, but its address is
    not reported ("ofm" carries None as pong address) - confirm this is
    intentional.
    """
    solver = L1BufferAllocator()

    # (name, size, priority) - listed from highest (9) to lowest (0) priority.
    # All ping buffers come first, the pong copies get the lowest priorities.
    ping_table = (
        ("IFM", ifm_L1_size, 9),
        ("WGT", wgt_L1_size, 8),
        ("OFM", ofm_L1_size, 7),
        ("TDM", tdm_L1_size, 6),
        ("WGHT_T_SB", wght_transpose_sb_L1_size, 5),
        ("VEC", vec_L1_size, 4),
        ("QDQ", qdq_L1_size, 3),
    )
    pong_table = (
        ("WGT_pong", wgt_L1_size, 2),
        ("IFM_pong", ifm_L1_size, 1),
        ("OFM_pong", ofm_L1_size, 0),
    )

    # Insertion order matters: the specs are handed to the allocator in it.
    specs = {}
    for label, size, prio in ping_table:
        specs[label] = BufferSpec(label, size, is_ping=True, priority=prio)
    for label, size, prio in pong_table:
        specs[label] = BufferSpec(label, size, is_pong=True, priority=prio)

    # Buffer pairs that cannot share banks
    bank_conflicts = (
        ("IFM", "IFM_pong"),   # ifm ping vs ifm pong
        ("WGT", "WGT_pong"),   # wgt ping vs wgt pong
        ("OFM", "OFM_pong"),   # ofm ping vs ofm pong
        ("IFM", "TDM"),        # ifm ping vs tdm
        ("IFM_pong", "TDM"),   # ifm pong vs tdm
    )
    exclusions = [BufferPair(specs[first], specs[second]) for first, second in bank_conflicts]

    # CP-SAT based placement; None signals that the buffers could not be placed
    placement = solver.allocate_cpsat(list(specs.values()), exclusions)
    if placement is None:
        return None

    addr_of = placement.get_addr
    return {
        "ifm": (ifm_L1_size, addr_of("IFM"), addr_of("IFM_pong")),
        "wgt": (wgt_L1_size, addr_of("WGT"), addr_of("WGT_pong")),
        "ofm": (ofm_L1_size, addr_of("OFM"), None),
        "tdm": (tdm_L1_size, addr_of("TDM"), None),
        "wght_t_sb": (wght_transpose_sb_L1_size, addr_of("WGHT_T_SB"), None),
        "vec": (vec_L1_size, addr_of("VEC"), None),
        "qdq": (qdq_L1_size, addr_of("QDQ"), None),
    }


def generate_conv_mappings(shape: ConvShape, enable_over_compute: bool) -> list[ConvMapping]:
    '''Generate all possible ways to map a conv shape onto the array

    Picks the candidate output subvolumes, the kernel granularity and the
    element bit widths for the conv data type, then hands them to
    generate_mappings together with three callbacks defined below (split
    validation, input subvolume enumeration and L1 buffer allocation). The
    base mappings are passed through create_conv_mappings_from_base and the
    result is returned sorted by sorted_conv_mappings.

    Args:
        shape: conv problem description; this function reads
            shape.linear_op_type, shape.ifm, shape.kernel and shape.stride.
        enable_over_compute: forwarded to is_split_valid and to
            sorted_conv_mappings.

    Returns:
        The list produced by sorted_conv_mappings.

    Raises:
        ValueError: if shape.linear_op_type is neither conv_A8W8_noqdq nor
            conv_A16W8_qdq.
    '''

    # NOTE: On AIE-4, the output subvolume is computed directly in the
    # accumulator registers. This means there is a very small number of
    # possible output subvolume sizes, directly determined by the
    # size of the accumulator register file. We list the possible
    # cases below in (Yos, Xos, Cos) format.
    # NOTE: Here we define (Ci_gran, Co_gran) for each data type
    # NOTE: Here we define bits per element for each data type
    if shape.linear_op_type == LinearOpType.conv_A8W8_noqdq:
        ofm_subvs = [(1, 64, 64), (2, 32, 64), (4, 16, 64), (8, 8, 64)]
        kernel_gran = (64, 64)
        # Fewer than 64 input channels: drop the Ci granularity to 8
        if shape.ifm[2] < 64:
            kernel_gran = (8, 64)
        ifm_bits = 8
        wgt_bits = 8
        bias_bits = 16
        ofm_bits = 8
    elif shape.linear_op_type == LinearOpType.conv_A16W8_qdq:
        # NOTE: Looking at the kernel grans, these should be valid subvolumes
        # ofm_subvs = [(1, 32, 64), (2, 16, 64), (4, 8, 64), (8, 4, 64), (16, 2, 64), (32, 1, 64)]
        # https://gitenterprise.xilinx.com/IPSP/AIE_SOL/blob/main/AIE4/kernel_lib/kernels/activated_conv_qdq_int16x8/activated_conv_qdq_int16x8.json#L107
        # But Looking at the minimums above, Xos can only be 8, 16, or 32
        # Hence we limit the ofm_subvs to below until further clarification from kernel spec
        # Also revisit kernel params generation upon clarification
        ofm_subvs = [(1, 32, 64), (2, 16, 64), (4, 8, 64)]
        kernel_gran = (64, 64)
        # Fewer than 64 input channels: drop the Ci granularity to 8
        if shape.ifm[2] < 64:
            kernel_gran = (8, 64)
        ifm_bits = 16
        wgt_bits = 8
        bias_bits = 32
        ofm_bits = 16
    else:
        raise ValueError("Unsupported CONV type for conv mapping generation in tiler")

    # NOTE: This is the memory alignment required for vector loads and stores
    memory_align = 128

    # Values captured by the nested callbacks below
    _, _, Ci = shape.ifm
    Ky, Kx = shape.kernel
    Sy, Sx = shape.stride

    def allocate_L1_buffers(
        ifm_subv: tuple[int, int, int],
        ofm_subv: tuple[int, int, int],
    ) -> Optional[dict]:
        '''Allocate buffers in L1 if possible

        Computes the L1 buffer sizes for the given (Yis, Xis, Cis) input and
        (Yos, Xos, Cos) output subvolumes, builds the kernel parameters for
        the data type and validates them against the kernel metadata.
        Returns None when validation fails; otherwise returns whatever
        allocate() returns for the computed sizes (which may also be None).
        '''

        # NOTE: Buffer placement itself is delegated to allocate(), which is
        # called once with the sizes computed here.
        # NOTE(review): an earlier comment here described a chain of fallback
        # allocation strategies (double buffering with bank splitting, then
        # single buffering); no such fallback is visible in this function.

        # Compute buffer sizes
        # (sizes are bit counts divided by 8; iceil presumably rounds up to a
        # multiple of memory_align - confirm against utils_common)
        Yis, Xis, Cis = ifm_subv
        Yos, Xos, Cos = ofm_subv
        ifm_L1_size = iceil((Yis * Xis * Cis * ifm_bits) // 8, memory_align)
        # NOTE: unlike the IFM size, the OFM size is not passed through iceil
        ofm_L1_size = (Yos * Xos * Cos * ofm_bits) // 8
        filter_buffer_size = iceil((Cos * Cis * Ky * Kx * wgt_bits) // 8, memory_align)
        # For fewer than 64 input channels the filter size is computed from
        # Cos * Ky * 64 instead of Cos * Cis * Ky * Kx
        # NOTE(review): presumably a kernel packing requirement - confirm
        # against the kernel spec
        if Ci < 64:
            filter_buffer_size = iceil((Cos * Ky * 64 * wgt_bits) // 8, memory_align)
        # Defaults; only the buffers needed by the data type are resized below
        bias_buffer_size = 0
        wgt_L1_size = 0
        qdq_L1_size = 0
        tdm_L1_size = 0
        vec_L1_size = 0
        wght_transpose_sb_L1_size = 0
        if shape.linear_op_type == LinearOpType.conv_A8W8_noqdq:
            # NOTE: refer the kernel spec json for the scratch buffer requirements
            bias_buffer_size = iceil((Cos * bias_bits) // 8, memory_align)
            # NOTE: The 128 bytes of QDQ params are part of the weight subvolume
            qdq_param_size = 128
            wgt_L1_size = filter_buffer_size + bias_buffer_size + qdq_param_size
            params = conv_int8x8_params(
                subvolume=ConvSubVolumeParams(
                    H=Yos, W=Xos, Ci=Cis, Co=Cos, Kh=Ky, Kw=Kx, Sh=Sy, Sw=Sx
                ),
                time_split=ConvTimeSplitParams(H=Yos, W=Xos, Ci=Cis, Co=Cos, Kh=Ky),
                hardened_loop=1,
                has_relu6=1,
                has_lrelu=1,
                has_bias=1,
                H_outer=0,
                do_bias=1,
                activation="ReLU",
                lrelu_alpha=0
            )
        elif shape.linear_op_type == LinearOpType.conv_A16W8_qdq:
            # refer the kernel spec json for the scratch buffer requirements
            no_vec_coeff = 3
            # NOTE: not passed through iceil, unlike the A8W8 bias buffer above
            bias_buffer_size = no_vec_coeff * Cos * bias_bits // 8
            qdq_param_size = 128
            wgt_L1_size = filter_buffer_size + bias_buffer_size + qdq_param_size
            max_accus_v = 16 + 7
            # 23 * 2 * 2 * 64 * 4 = 23552; used as the TDM buffer size
            spill_buff_size = max_accus_v * 2 * 2 * 64 * 4
            tdm_L1_size = spill_buff_size
            # 64 * 4 * 4 = 1024; used as the QDQ buffer size
            coeff_tmp_buffer_size = 64 * 4 * 4
            qdq_L1_size = coeff_tmp_buffer_size
            params = conv_qdq_int16x8_params(
                subvolume=ConvSubVolumeParams(
                    H=Yos, W=Xos, Ci=Cis, Co=Cos, Kh=Ky, Kw=Kx, Sh=Sy, Sw=Sx
                ),
                time_split=ConvTimeSplitParams(H=Yos, W=Xos, Ci=Cis, Co=Cos, Kh=Ky),
                has_actv_sum=1,
                vector_coeff=2,
                hardened_loop=-1,
                H_outer=0,
                sync_type=SyncTypeParams(
                    I0=SyncType.ASYNC,
                    I1=SyncType.ASYNC,
                    I2=SyncType.ASYNC,
                    O0=SyncType.ASYNC,
                )
            )
        else:
            # Not reachable after the op type check at the top of
            # generate_conv_mappings; kept as a safeguard
            raise ValueError("Unsupported CONV type for conv mapping generation in tiler")

        # validate
        # Any exception raised by validation means this subvolume is rejected:
        # the failure is printed and None is returned instead of propagating.
        try:
            get_loader(shape.linear_op_type).validate(params)
        except Exception as e:  # pylint: disable=broad-except
            print(f"Metadata validation failed: {e}. Params: {params} is invalid.")
            return None

        return allocate(ifm_L1_size, wgt_L1_size, ofm_L1_size, tdm_L1_size, wght_transpose_sb_L1_size, vec_L1_size, qdq_L1_size)

    def conv_is_split_valid(
        ofm_shape: tuple[int, int, int],
        ofm_subv: tuple[int, int, int],
        split: tuple[int, int, int, int],
    ) -> bool:
        '''Check a split with is_split_valid, using the outer enable_over_compute flag'''
        return is_split_valid(ofm_shape, ofm_subv, split, enable_over_compute)

    def conv_get_input_subv(
        shape: ConvShape,
        ofm_subv: tuple[int, int, int],
        kernel_gran: tuple[int, int]
    ) -> list[tuple[int, int, int]]:
        '''List the candidate (Yis, Xis, Cis) input subvolumes for an output subvolume

        The shape and kernel_gran parameters shadow the outer variables of
        the same name. Yis comes from conv_input and is the same for every
        candidate; Xis comes from align_Xis and is computed per Cis. Only
        candidates with both Cis and Cos in (8, 16, 32, 64) are returned.
        '''
        Yos, Xos, Cos = ofm_subv
        Ky, Kx = shape.kernel
        Sy, Sx = shape.stride
        Ci_gran, _ = kernel_gran
        Ci = shape.ifm[2]
        Yis = conv_input(Yos, Ky, Sy)
        # Generate all valid Cis values
        valid_input_subvs = []
        # Step through multiples of Ci_gran up to iceil(Ci, Ci_gran) inclusive
        for Cis in range(Ci_gran, iceil(Ci, Ci_gran) + 1, Ci_gran):
            # NOTE: align_Xis is also called for Cis values rejected just below
            Xis = align_Xis(Xos, Cis, Kx, Sx)
            if Cis not in (8, 16, 32, 64) or Cos not in (8, 16, 32, 64):
                continue
            valid_input_subvs.append((Yis, Xis, Cis))
        return valid_input_subvs

    # Enumerate the base mappings; the three nested functions above are
    # passed to generate_mappings as callbacks
    base_mappings = generate_mappings(
        shape,
        kernel_gran,
        ofm_subvs,
        conv_is_split_valid,
        conv_get_input_subv,
        allocate_L1_buffers,
    )
    # L2 strategies are fixed here: IFM is 'pin', WGT and OFM are 'stream'
    conv_mappings = create_conv_mappings_from_base(base_mappings,
                                                   ifm_L2_strategy='pin',
                                                   wgt_L2_strategy='stream',
                                                   ofm_L2_strategy='stream')
    return sorted_conv_mappings(shape, conv_mappings, enable_over_compute)
