import struct
import math

from kernel.common.kernel_params_helper import (
    DimsHelper,
)

from utils.utils_common import (
    log,
)

def derive_hardened_loop(Cis: int, folded_Kx: int, Ky: int) -> int:
    """
    Choose a hardened_loop value that satisfies the constraint:
    Cis * folded_Kx * Ky >= 64 * min(4, hardened_loop & 7)

    The product Cis * folded_Kx * Ky is measured in whole 64-element units
    (floor division). A unit count of 1, 2 or 3 is returned unchanged, so
    hardened_loop & 7 equals that count. Any other count (0 or fewer, or
    4 and above) maps to 0.

    NOTE(review): 0 satisfies the constraint trivially. Presumably the
    hardware treats 0 as its default/unrestricted setting; confirm.

    Args:
        Cis: Channel input size
        folded_Kx: Folded kernel width
        Ky: Kernel height

    Returns:
        hardened_loop value in the range 0..3
    """
    units_of_64 = (Cis * folded_Kx * Ky) // 64
    # Outside the 1..3 window the original branch table always yields 0.
    if not 1 <= units_of_64 <= 3:
        return 0
    return int(units_of_64)

def generate_conv_noqdq_a8w8_params(
    Ci: int,
    Yos: int, Xos: int, Cos: int,
    Yis: int, Xis: int, Cis: int,
    Ky: int, Kx: int,
    Sy: int, Sx: int,
    Y_loop: int, X_loop: int, Co_loop: int, Ci_loop: int,
    mode: int,
    full_iters: bool = True,
) -> bytes:
    '''
    Pack the layer parameters for the no-QDQ a8w8 convolution kernel.

    NOTE: All the layer params calculations are with respect to the output Xos and Yos dimensions
    TODO: The kernel supports a Cos < 64 but it has to be folded into Xos

    Args:
        Ci: Input channel count. It is used only to pick the weight-subvolume
            size formula (a 64-wide size is used when Ci < 64).
        Yos, Xos, Cos: Output height, width and channels. Cos must be 64.
            Xos // (64 // Cos) must be one of 8, 16, 32, 64, and
            Yos * Xos * Cos must equal 4096.
        Yis, Xis, Cis: Input height, width and channels (presumably of the
            input subvolume; confirm against callers).
        Ky, Kx: Kernel height and width.
        Sy, Sx: Vertical and horizontal stride.
        Y_loop, X_loop, Co_loop: Loop counts. They are multiplied into
            outer_time_iters when full_iters is True.
        Ci_loop: Loop count stored as inner_time_iters.
        mode: Packed verbatim into the params.
        full_iters: If True, outer_time_iters = Co_loop * Y_loop * X_loop.
            Otherwise it is Co_loop alone.

    Returns:
        56 bytes, little-endian, with no padding
        (struct format '<4I3H2B1I2i2I3i'):
        hardened_loop, mode, raw_wgt_subv_size, raw_bias_size (u32 each);
        outer_time_iters, inner_time_iters, inner_loop (u16 each);
        step_align, norm_ch_g (u8 each);
        dims_YXi num0 (u32), inc0, inc1 (i32 each);
        dims_KCi num0, num1 (u32 each), inc0, inc1, inc2 (i32 each).

    Raises:
        AssertionError: if the shape constraints above are violated.
        struct.error: if a computed field does not fit its packed width.
    '''
    # NOTE(review): these checks use assert, so they vanish under `python -O`.
    assert Cos == 64, "Cos must be 64"
    wgt_bits = 8
    bias_bits = 16
    # Always 1 while Cos == 64. It is kept so the Cos < 64 folding (see TODO)
    # only needs the assert relaxed.
    folded_Xos = 64 // Cos
    folded_Kx = Kx + (folded_Xos - 1) * Sx
    Xis_aligned = Xis
    Cis_ilb = min(64, Cis)
    Xos_ilb = min(64, Xos // folded_Xos)
    kernel_dims = DimsHelper(-64)
    # Integer checks. The original used float division and float equality;
    # these accept exactly the same integer inputs.
    assert Xos % folded_Xos == 0 and Xos // folded_Xos in (8, 16, 32, 64), (
        f"Xos // folded_Xos must be one of 8, 16, 32, 64 "
        f"(got Xos={Xos}, folded_Xos={folded_Xos})"
    )
    assert Yos * Xos * Cos == 4096, (
        f"Yos * Xos * Cos must equal 4096 (got Yos={Yos}, Xos={Xos}, Cos={Cos})"
    )
    step_Ci = Xis_aligned * Yis * Cis_ilb
    step_Ky = Xis_aligned * Cis_ilb
    step_Xi = Cis_ilb * Sx * folded_Xos
    step_Yi = Sy * step_Ky
    incr_Xi = kernel_dims.from_steps(1, Cis_ilb * Sx * folded_Xos)
    # x ^ (x - 1) keeps the lowest set bit of x and every bit below it, so the
    # floored log2 is the trailing-zero count of incr_Xi. The count is then
    # offset by -3 (meaning of the offset not visible here; presumably 8-byte
    # alignment granules, confirm).
    step_align = 0 if incr_Xi == 0 else int(math.log2(incr_Xi ^ (incr_Xi - 1))) - 3
    norm_ch_g = 1
    dims_YXi = kernel_dims.from_steps((Xos_ilb, 64 // Xos_ilb), (step_Xi, step_Yi))
    dims_KCi = kernel_dims.from_steps((math.ceil(folded_Kx * Cis_ilb / 64), Ky),
                                      (64, step_Ky, step_Ci))
    outer_time_iters = Co_loop * Y_loop * X_loop if full_iters else Co_loop
    inner_time_iters = Ci_loop
    # NOTE(review): this uses the full Cis, while dims_KCi above uses Cis_ilb.
    # The two differ only when Cis > 64; confirm which is intended.
    inner_loop = Ky * int(math.ceil(folded_Kx * Cis / 64))
    raw_wgt_subv_size = Cos * Ky * Kx * Cis * wgt_bits // 8
    if Ci < 64:
        # For Ci < 64 the size uses 64 in place of Kx * Cis (presumably the
        # padded, folded weight layout; confirm).
        raw_wgt_subv_size = Cos * Ky * 64 * wgt_bits // 8
    raw_bias_size = Cos * bias_bits // 8
    hardened_loop = derive_hardened_loop(Cis, folded_Kx, Ky)
    packed_params = struct.pack(
        '<4I3H2B1I2i2I3i',
        hardened_loop,          # I
        mode,                   # I
        raw_wgt_subv_size,      # I
        raw_bias_size,          # I
        outer_time_iters,       # H
        inner_time_iters,       # H
        inner_loop,             # H
        step_align,             # B
        norm_ch_g,              # B
        dims_YXi['num0'],       # I
        dims_YXi['inc0'],       # i
        dims_YXi['inc1'],       # i
        dims_KCi['num0'],       # I
        dims_KCi['num1'],       # I
        dims_KCi['inc0'],       # i
        dims_KCi['inc1'],       # i
        dims_KCi['inc2'],       # i
    )
    return packed_params
