"""
Map gemm shapes to the AIE-4 dataflow architecture.
The external-facing functions are documented below.

    generate_gemm_mappings - enumerate all possible ways to map a gemm shape
    onto the compute array and sort them by projected latency (fastest
    mappings first)
"""

from typing import Optional
from functools import lru_cache

from utils.utils_common import (
    iceil,
    ceildiv,
)
from scheduler.common import (
    LinearOpType,
)

from scheduler.conv.conv_config_builders import (
    ConvShape,
    ConvMapping
)

# Import buffer allocator
from tiler.tiler_common import (
    generate_mappings,
    sorted_mappings,
    create_conv_mappings_from_base,
)
from tiler.conv_tiler import allocate


from tiler.load_kernel_metadata import KernelMetadataForOp, OpType
from tiler.kernel_metadata.build_kernel_metadata_vars import (
    mmult_qdq_int16x8_params,
    mmult_qdq_int16x16_params,
    SubVolumeParams,
    TimeSplitParams,
)


def sorted_gemm_mappings(
    shape: ConvShape, mappings: list[ConvMapping], enable_over_compute: bool
) -> list[ConvMapping]:
    """Order GEMM mappings by projected latency.

    GEMM-named wrapper around ``tiler.tiler_common.sorted_mappings``. The
    ordering itself, and how ``enable_over_compute`` influences it, are
    defined entirely by that helper.
    NOTE(review): the module docstring describes the result as fastest
    mapping first -- confirm against ``sorted_mappings``.

    Args:
        shape: the GEMM shape the mappings were generated for.
        mappings: candidate mappings to order.
        enable_over_compute: forwarded unchanged to ``sorted_mappings``.

    Returns:
        The list produced by ``sorted_mappings``.
    """
    ordered = sorted_mappings(shape, mappings, enable_over_compute)
    return ordered


def kernel_gran_params(shape: ConvShape, mode: str):
    """Return the kernel constants for ``shape`` in the given GEMM ``mode``.

    Args:
        shape: read for ``linear_op_type`` (and ``ifm[2]`` when the op is
            ``conv_A8W8_noqdq``).
        mode: ``'wgt'`` (activation x weight) or ``'act'`` (activation x
            activation).

    Returns:
        A dict with keys, in this order:
            ``ofm_subvs``         - candidate output subvolumes,
            ``kernel_gran``       - kernel granularity pair,
            ``Ci_gran``           - first element of ``kernel_gran``,
            ``bits``              - bit width per 'ifm'/'wgt'/'ofm'/'bias',
            ``extra_act_buffers`` - True only in act mode.

    Raises:
        ValueError: for an unknown ``mode`` or a ``linear_op_type`` that the
            mode does not support.
    """
    if mode not in ('wgt', 'act'):
        raise ValueError(f"Invalid mode {mode}")

    op_type = shape.linear_op_type
    if mode == 'act':
        if op_type not in [LinearOpType.gemm_A16A16_v2, LinearOpType.gemm_A16A16_v1]:
            raise ValueError("Unsupported dtype for act mode")
        candidate_subvs = [(1, 16, 64)]
        gran = (64, 64)
        widths = {"ifm": 16, "wgt": 16, "ofm": 16, "bias": 16}
        needs_act_buffers = True
    else:
        if op_type == LinearOpType.conv_A8W8_noqdq:
            candidate_subvs = [(1, 64, 64), (2, 32, 64), (4, 16, 64), (8, 8, 64)]
            # Narrow input channels fall back to a finer Ci granularity.
            gran = (64, 64) if shape.ifm[2] >= 64 else (8, 64)
            widths = {"ifm": 8, "wgt": 8, "ofm": 8, "bias": 16}
        elif op_type in (LinearOpType.gemm_A16W8_qdq, LinearOpType.gemm_A16W4_qdq):
            candidate_subvs = [(1, 32, 64)]
            gran = (64, 64)
            weight_width = 8 if op_type == LinearOpType.gemm_A16W8_qdq else 4
            widths = {"ifm": 16, "wgt": weight_width, "ofm": 16, "bias": 32}
        else:
            raise ValueError("Unsupported dtype for wgt mode")
        needs_act_buffers = False

    return {
        "ofm_subvs": candidate_subvs,
        "kernel_gran": gran,
        "Ci_gran": gran[0],
        "bits": widths,
        "extra_act_buffers": needs_act_buffers,
    }


@lru_cache(maxsize=2)
def get_loader(is_act_mode: bool) -> KernelMetadataForOp:
    """Return the kernel-metadata loader for the requested GEMM flavour.

    Args:
        is_act_mode: True selects the ``ACTIVATED_MMULT_QDQ_INT16X16``
            metadata (act x act GEMM); False selects
            ``ACTIVATED_MMULT_QDQ_INT16X8`` (act x wgt GEMM).

    Returns:
        A ``KernelMetadataForOp`` for that op type. Results are memoised by
        ``lru_cache``, so each flavour's loader is constructed once and
        shared across calls.
    """
    op_type = (
        OpType.ACTIVATED_MMULT_QDQ_INT16X16
        if is_act_mode
        else OpType.ACTIVATED_MMULT_QDQ_INT16X8
    )
    return KernelMetadataForOp(op_type)


def allocate_L1_buffer(
    shape: ConvShape,
    ifm_subv: tuple[int, int, int],
    ofm_subv: tuple[int, int, int],
    bits: dict,
    is_act_mode: bool,
) -> Optional[dict]:
    """Size the L1 buffers for one ifm/ofm subvolume pair and allocate them.

    Args:
        shape: GEMM/conv shape; ``kernel``, ``ifm`` and ``linear_op_type``
            are read.
        ifm_subv: (Yis, Xis, Cis) input subvolume.
        ofm_subv: (Yos, Xos, Cos) output subvolume.
        bits: bit widths keyed by 'ifm', 'wgt', 'ofm' and optionally 'bias'
            (a missing 'bias' is treated as 0).
        is_act_mode: True for act x act GEMM (int16x16 kernel metadata),
            False for act x wgt GEMM (int16x8 kernel metadata).

    Returns:
        The result of ``allocate`` for the computed sizes, or None when Cis
        exceeds 256 or the kernel-metadata loader rejects the parameters.
    """
    align = 128
    Yis, Xis, Cis = ifm_subv
    Yos, Xos, Cos = ofm_subv
    print(f"function allocate_L1_buffer recieved ifm_subv={ifm_subv}, ofm_subv={ofm_subv}")

    # FIX: Temporary constraint to limit Cis (Remove once tiler is more optimized)
    if Cis > 256:
        print(f"Skipping allocation: Cis {Cis} exceeds 256")
        return None

    Ky, Kx = shape.kernel
    full_ci = shape.ifm[2]
    ifm_bits, wgt_bits, ofm_bits = bits['ifm'], bits['wgt'], bits['ofm']
    bias_bits = bits.get('bias', 0)

    ifm_bytes = iceil((Yis * Xis * Cis * ifm_bits) // 8, align)
    ofm_bytes = (Yos * Xos * Cos * ofm_bits) // 8
    if full_ci < 64:
        # Small Ci: the filter is sized for a 64-channel slice.
        # NOTE(review): this drops Cis and Kx from the product -- presumably
        # Kx * Ci is packed into the 64 channels; confirm with the kernel.
        filter_bytes = iceil((Cos * Ky * 64 * wgt_bits) // 8, align)
    else:
        filter_bytes = iceil((Cos * Cis * Ky * Kx * wgt_bits) // 8, align)

    subvolume = SubVolumeParams(H=Yos, W=Xos, Ci=Cis, Co=Cos)
    time_split = TimeSplitParams(H=Yos, W=Xos, Ci=Cis, Co=Cos)
    transpose_bytes = 0

    if is_act_mode:
        # act x act GEMM: fixed scratch sizes; the whole filter buffer is the
        # second-operand buffer.
        wgt_bytes = filter_bytes
        tdm_bytes, vec_bytes, qdq_bytes = 512, 2048, 128
        if shape.linear_op_type == LinearOpType.gemm_A16A16_v2:
            transpose_bytes = 256
        params = mmult_qdq_int16x16_params(
            subvolume=subvolume,
            time_split=time_split,
            has_actv_sum=1,
            has_vector_coeffs=2,
            vector_coeffs=2
        )
    else:
        # act x wgt GEMM: the weight buffer holds the filter, three
        # bias-sized buffers and 128 extra bytes.
        bias_bytes = iceil((Cos * bias_bits) // 8, align)
        wgt_bytes = filter_bytes + 3 * bias_bytes + 128
        # The qdq kernels need a core spill buffer (tdm slot), an ifm temp
        # buffer (vec slot) and a coefficient temp buffer (qdq slot).
        qdq_kernel = shape.linear_op_type in [LinearOpType.gemm_A16W8_qdq, LinearOpType.gemm_A16W4_qdq]
        tdm_bytes = 1536 if qdq_kernel else 0
        vec_bytes = ifm_bytes if qdq_kernel else 0
        qdq_bytes = Cos * 4 * 4 * 2 if qdq_kernel else 0
        params = mmult_qdq_int16x8_params(
            subvolume=subvolume,
            time_split=time_split,
            has_actv_sum=1,
            vector_coeff=2,
        )

    try:
        get_loader(is_act_mode).validate(params)
    except Exception as err:  # pylint: disable=broad-except
        print(f"Metadata validation failed: {err}. Params: {params} is invalid.")
        return None

    return allocate(ifm_bytes, wgt_bytes, ofm_bytes, tdm_bytes, transpose_bytes, vec_bytes, qdq_bytes)


def generate_gemm_mappings(shape: ConvShape, mode: str, enable_over_compute: bool = True) -> list[ConvMapping]:
    """Enumerate, convert and sort every GEMM mapping for ``shape``.

    Shared mapping loop for the ``'wgt'`` (act x wgt) and ``'act'``
    (act x act) modes. The mode-specific constants come from
    ``kernel_gran_params``. Candidate enumeration is delegated to
    ``tiler_common.generate_mappings`` with the three GEMM-specific callbacks
    defined below.

    Args:
        shape: GEMM shape to map.
        mode: ``'wgt'`` or ``'act'``; selects kernel constants and L1 layout.
        enable_over_compute: forwarded to ``sorted_gemm_mappings``.

    Returns:
        The ``ConvMapping`` objects (ifm pinned in L2, wgt/ofm streamed),
        ordered by ``sorted_gemm_mappings``.

    Raises:
        ValueError: from ``kernel_gran_params`` for an unknown mode or an
            unsupported ``shape.linear_op_type``.
    """
    p = kernel_gran_params(shape, mode)
    ofm_subvs = p['ofm_subvs']
    kernel_gran = p['kernel_gran']
    bits = p['bits']
    # True only in 'act' mode. It selects the act x act branch of the L1
    # allocator and the fixed Cis = 128 input subvolume below.
    extra_act = p['extra_act_buffers']

    def gemm_is_split_valid(
        ofm_shape: tuple[int, int, int],
        ofm_subv: tuple[int, int, int],
        split: tuple[int, int, int, int],
    ) -> bool:
        """Return True if ``split`` passes the Yo-split and ifm-reuse checks."""
        Yo, _, Co = ofm_shape
        _, _, Cos = ofm_subv
        _, _, _, Co_split = split
        # Co_loop is the number of Co iterations; it determines the ifm reuse
        # that has to be realised in the memtile.
        Co_loop = ceildiv(Co, (Cos * Co_split))
        # With more than one output row, only split[1] in {1, 3} is accepted
        # (presumably the Y split factor -- confirm against generate_mappings).
        is_yo_split_valid = (Yo == 1) or (Yo > 1 and split[1] in [1, 3])
        num_consumers = 2 if Co_split in [12, 6, 4] else 4
        max_lock_value = 63
        max_chain_length = 8
        # Accept if some reuse-chain length in 1..max_chain_length divides
        # Co_loop evenly with (Co_loop // length) * num_consumers within the
        # lock limit. Length 1 is the "no chain needed" case. A Co_loop with no
        # small enough divisor (e.g. a large prime) is rejected.
        is_valid = any(
            Co_loop % chain_len == 0
            and (Co_loop // chain_len) * num_consumers <= max_lock_value
            for chain_len in range(1, max_chain_length + 1)
        )
        return is_valid and is_yo_split_valid

    def gemm_get_input_subv(
        shape: ConvShape,
        ofm_subv: tuple[int, int, int],
        kernel_granularities: tuple[int, int]
    ) -> list[tuple[int, int, int]]:
        """Return the distinct (Yis, Xis, Cis) input subvolumes for ``ofm_subv``.

        Cis steps through multiples of the Ci granularity up to Ci rounded up,
        then is forced to the kernel-supported value. That maps many steps onto
        the same Cis. Duplicates are dropped (first-seen order kept) so the
        same candidate is not enumerated, allocated and emitted more than once.
        """
        _, _, Ci = shape.ifm
        Ci_gran, _ = kernel_granularities
        Yos, Xos, _ = ofm_subv
        candidates = []
        for Cis in range(Ci_gran, iceil(Ci, Ci_gran) + 1, Ci_gran):
            if extra_act:
                # act x act kernel: only steps up to 128 qualify, and all of
                # them run with Cis = 128.
                if Cis > 128:
                    continue
                Cis = 128
            else:
                # NOTE: A16W8 and A16W4 gemm kernels only support Cis >= 256,
                # so Cis is raised to at least 256. This branch also serves
                # conv_A8W8_noqdq. The padding from Ci -> Cis should be handled
                # by the scheduler.
                Cis = max(Cis, 256)
            candidates.append((Yos, Xos, Cis))
        # dict preserves insertion order, giving an order-stable de-duplication.
        return list(dict.fromkeys(candidates))

    def gemm_allocate_L1_buffers(ifm_subv, ofm_subv):
        print(f"Trying ifm_subv={ifm_subv}, ofm_subv={ofm_subv}")
        return allocate_L1_buffer(shape, ifm_subv, ofm_subv, bits, extra_act)

    # Generate base mappings using common interface
    base_mappings = generate_mappings(
        shape,
        kernel_gran,
        ofm_subvs,
        gemm_is_split_valid,
        gemm_get_input_subv,
        gemm_allocate_L1_buffers,
    )

    # Convert BaseMapping to ConvMapping
    gemm_mappings = create_conv_mappings_from_base(
        base_mappings,
        ifm_L2_strategy='pin',
        wgt_L2_strategy='stream',
        ofm_L2_strategy='stream'
    )

    return sorted_gemm_mappings(shape, gemm_mappings, enable_over_compute)
