import struct 
import math

from kernel.common.kernel_params_helper import (
    DimsHelper,
)

from utils.utils_common import (
    log,
    iceil
)


def _align_shift(incr) -> int:
    """Return (number of trailing zero bits of ``incr``) - 3, or 0 if ``incr`` is 0.

    Computed with exact integer arithmetic rather than a float logarithm.
    """
    if incr == 0:
        return 0
    value = int(incr)
    # value ^ (value - 1) has ones in exactly the trailing-zero positions plus
    # the lowest set bit, so its bit length is (trailing zeros + 1).
    return (value ^ (value - 1)).bit_length() - 1 - 3


def _log_dims(name: str, dims, keys) -> None:
    """Log each of ``keys`` of a DimsHelper result as ``name['key']: value``."""
    for key in keys:
        log(f"{name}['{key}']: {dims[key]}")


def setup_conv_qdq_a16w8_params(
    Yos: int, Xos: int, Cos: int,
    Yis: int, Xis: int, Cis: int,
    Ky: int, Kx: int,
    Sy: int, Sx: int,
    Y_loop: int, X_loop: int, Co_loop: int, Ci_loop: int,
    ifm_ch_num: int,
    wgt_size: int,
    bias_size: int,
    core_spill_buf: int,
    core_coeff_tmp_buffer: int,
    full_iters: bool,
) -> bytes:
    """Derive and pack the runtime parameter block for the conv QDQ a16w8 kernel.

    Computes loop bounds, address increments (via ``DimsHelper``), channel
    masks and alignment shifts from the convolution geometry. It logs every
    derived value and packs everything into a little-endian ``struct`` blob
    whose field order is given by the format string at the end.

    Args:
        Yos, Xos, Cos: Output height, width and channel count.
        Yis, Xis, Cis: Input height, width and channel count. ``Yis`` is
            overridden by ``(Yos - 1) * Sy + Ky`` and ``Xis`` is not read
            (an aligned input width is derived from ``Xos`` instead).
        Ky, Kx: Kernel height and width.
        Sy, Sx: Vertical and horizontal stride.
        Y_loop, X_loop, Co_loop, Ci_loop: Iteration counts used for the
            outer/inner time-loop fields.
        ifm_ch_num, wgt_size, bias_size, core_spill_buf, core_coeff_tmp_buffer:
            Packed unchanged as the first five uint32 fields.
        full_iters: If True the outer time loop is Co_loop * Y_loop * X_loop
            iterations; otherwise it is Co_loop.

    Returns:
        The packed parameter bytes.

    Raises:
        AssertionError: If the geometry is outside the supported envelope.
        struct.error: If a derived value does not fit its packed field
            (e.g. ``step_Ci`` above 65535, or a negative alignment shift).
    """
    # Tiling granularities along the spatial (W*H), input-channel and
    # output-channel axes.
    granWH = 32
    granCi = 64
    granCo = 64

    # With Cos < granCo, folded_Xos output columns share one granCo-wide
    # channel block; folded_Kx is the input width those columns span.
    folded_Xos = granCo // Cos
    folded_Kx = Kx + (folded_Xos - 1) * Sx
    # Input width, rounded up by iceil to the max(1, granCi // Cis) grain.
    Xis_aligned = iceil(((Xos - 1) * Sx + Kx), max(1, granCi // Cis))
    # NOTE: the Yis argument is ignored; the height implied by Yos/Sy/Ky wins.
    Yis = (Yos - 1) * Sy + Ky
    Ci_ilb = min(granCi, Cis)

    # Byte strides over the activation tile. The factor 2 is presumably the
    # size of one 16-bit activation (the "a16" in the kernel name).
    step_Ci = 2 * Xis_aligned * Yis * Ci_ilb
    dims0 = DimsHelper(-2 * 64)
    Wo_ilb = min(granWH, Xos // folded_Xos)
    step_Kx = 2 * Ci_ilb
    step_Ky = 2 * Xis_aligned * Ci_ilb
    step_Wi = Sx * step_Kx * folded_Xos
    step_Hi = Sy * step_Ky

    # Strides for the activation-sum pass, which uses samplebytes-wide samples.
    samplebytes = 4
    # NOTE: N here is not batch dimension.
    N = 8
    step_Kh_actv_sum = samplebytes * Xis_aligned
    step_Yis_actv_sum = Sy * step_Kh_actv_sum
    dims_conv2d_inner = DimsHelper(-64)
    dims_conv2d_outer = DimsHelper(-Ky * step_Kh_actv_sum)
    incr_Xi = dims0.from_steps(1, step_Wi)
    incr_Xi_sum = dims0.from_steps(1, step_Kx)

    # Full granWH-sized chunks of the Xis_aligned x Yis plane, and the chunk
    # count including a trailing partial chunk.
    sum_bound, sum_remainder = divmod(Xis_aligned * Yis, granWH)
    n_accus = sum_bound + (sum_remainder != 0)
    # https://gitenterprise.xilinx.com/IPSP/AIE_SOL/blob/main/AIE4/kernel_lib/kernels/activated_conv_qdq_int16x8/activated_conv_qdq_int16x8.json#L74
    max_accus = 16
    sum_actv_base = DimsHelper(-2 * 64)
    sum_outer = max(1, Cis // granCi)

    assert (Xos // folded_Xos) in (8, 16, 32)
    assert Yos == 2048 // (Xos * Cos)
    assert Cis * folded_Kx * Ky >= 64 * 3
    assert Ky * Kx >= 4
    assert Xos >= 8
    assert sum_outer * Ky * Kx >= 3

    # Channel masks: 2 bits per input channel (presumably one per byte of a
    # 16-bit activation) over a 2 * granCi = 128-bit mask stored as two
    # 64-bit words. Bits past 64 go into the high word so neither word
    # overflows its 'Q' field.
    mask_bits = 2 * min(Cis, granCi)
    mask_Ci_low = (1 << min(mask_bits, 64)) - 1
    mask_Ci_high = (1 << max(mask_bits - 64, 0)) - 1
    Co_blk = folded_Xos
    # math.log2 is exact for powers of two (math.log(x, 2) may not be).
    Co_shift = math.log2(Cos // 16)
    outer_time_iters = Co_loop * Y_loop * X_loop if full_iters else Co_loop
    inner_time_iters = Ci_loop
    inner_loop = Ky * math.ceil(folded_Kx * Cis / granCi)

    dims_HWi = dims0.from_steps((Wo_ilb, (granWH // Wo_ilb)), (step_Wi, step_Hi))
    dims_KCi = dims0.from_steps((math.ceil(folded_Kx * Ci_ilb / granCi), Ky), (2 * 64, step_Ky, step_Ci))
    dims_conv2d_sum_inner = dims_conv2d_inner.from_steps((Kx), (samplebytes, step_Kh_actv_sum))
    # NOTE(review): Xos / N is true division (a float); the packed value is
    # int()-truncated below. Confirm Xos is always a multiple of N.
    dims_conv2d_sum_outer = dims_conv2d_outer.from_steps((Xos / N), (samplebytes * N * Sx, step_Yis_actv_sum))
    dims_sum_actv = sum_actv_base.from_steps((Xis_aligned, Yis), (step_Kx, step_Ky, 0))
    # Both alignment shifts use the same base-2 trailing-zero rule.
    align_step = _align_shift(incr_Xi)
    align_step_sum = _align_shift(incr_Xi_sum)
    cf_AxB = 1
    cf_Asum = 1
    reserved = 0

    log(f"mask_Ci_low: {mask_Ci_low}")
    log(f"mask_Ci_high: {mask_Ci_high}")
    log(f"Co_blk: {Co_blk}")
    log(f"Co_shift: {Co_shift}")
    log(f"outer_time_iters: {outer_time_iters}")
    log(f"inner_time_iters: {inner_time_iters}")
    log(f"inner_loop: {inner_loop}")
    log(f"step_Ci: {step_Ci}")

    _log_dims("dims_HWi", dims_HWi, ("num0", "inc0", "inc1"))
    _log_dims("dims_KCi", dims_KCi, ("num0", "num1", "inc0", "inc1", "inc2"))
    _log_dims("dims_conv2d_sum_inner", dims_conv2d_sum_inner, ("num0", "inc0", "inc1"))
    _log_dims("dims_conv2d_sum_outer", dims_conv2d_sum_outer, ("num0", "inc0", "inc1"))
    _log_dims("dims_sum_actv", dims_sum_actv, ("num0", "num1", "inc0", "inc1", "inc2"))

    log(f"align_step: {align_step}")
    log(f"align_step_sum: {align_step_sum}")

    log(f"Sx_g: {Sx}")
    log(f"Sy_g: {Sy}")
    log(f"Kx_g: {Kx}")
    log(f"Ky_g: {Ky}")

    log(f"sum_outer: {sum_outer}")
    log(f"sum_bound: {sum_bound}")
    log(f"n_accus: {n_accus}")
    log(f"max_accus: {max_accus}")
    log(f"cf_AxB: {cf_AxB}")
    log(f"cf_Asum: {cf_Asum}")

    # Little-endian, no padding; the per-argument comments give each field's
    # struct type in order.
    return struct.pack(
        '<5I2Q4b4HI2i2I3iI2iI2i2I3i4B8b2f',
        ifm_ch_num,  # I
        wgt_size,  # I
        bias_size,  # I
        core_spill_buf,  # I
        core_coeff_tmp_buffer,  # I

        mask_Ci_low,  # Q
        mask_Ci_high,  # Q
        Co_blk,  # b
        int(Co_shift),  # b
        reserved,  # b
        reserved,  # b
        outer_time_iters,  # H
        inner_time_iters,  # H
        inner_loop,  # H
        step_Ci,  # H

        int(dims_HWi['num0']),  # I
        int(dims_HWi['inc0']),  # i
        int(dims_HWi['inc1']),  # i

        int(dims_KCi['num0']),  # I
        int(dims_KCi['num1']),  # I
        int(dims_KCi['inc0']),  # i
        int(dims_KCi['inc1']),  # i
        int(dims_KCi['inc2']),  # i

        int(dims_conv2d_sum_inner['num0']),  # I
        int(dims_conv2d_sum_inner['inc0']),  # i
        int(dims_conv2d_sum_inner['inc1']),  # i

        int(dims_conv2d_sum_outer['num0']),  # I
        int(dims_conv2d_sum_outer['inc0']),  # i
        int(dims_conv2d_sum_outer['inc1']),  # i

        int(dims_sum_actv['num0']),  # I
        int(dims_sum_actv['num1']),  # I
        int(dims_sum_actv['inc0']),  # i
        int(dims_sum_actv['inc1']),  # i
        int(dims_sum_actv['inc2']),  # i

        align_step,  # B
        align_step_sum,  # B
        reserved,  # B
        reserved,  # B
        Sx,  # b
        Sy,  # b
        Kx,  # b
        Ky,  # b
        sum_outer,  # b
        sum_bound,  # b
        n_accus,  # b
        max_accus,  # b
        cf_AxB,  # f
        cf_Asum,  # f
    )


def generate_conv_qdq_a16w8_params(
    Yos: int, Xos: int, Cos: int,
    Yis: int, Xis: int, Cis: int,
    Ky: int, Kx: int,
    Sy: int, Sx: int,
    Y_loop: int, X_loop: int, Co_loop: int, Ci_loop: int,
    ifm_ch_num:int,
    core_spill_buf: int,
    core_coeff_tmp_buffer: int,
    full_iters: bool,
) -> bytes:
    """Build the conv QDQ a16w8 kernel parameter blob from the tile geometry.

    Derives the raw weight and bias buffer sizes in bytes (8-bit weights,
    32-bit bias words) and forwards everything else unchanged to
    ``setup_conv_qdq_a16w8_params``, returning its packed bytes.
    """
    wgt_bit_width, bias_bit_width = 8, 32

    # One wgt_bit_width-bit weight per (Cos, Ky, Kx, Cis) element.
    raw_wgt_bytes = Cos * Ky * Kx * Cis * wgt_bit_width // 8
    # Three Cos-sized runs of bias_bit_width-bit words.
    bias_bytes = 3 * (Cos * bias_bit_width // 8)

    return setup_conv_qdq_a16w8_params(
        Yos=Yos, Xos=Xos, Cos=Cos,
        Yis=Yis, Xis=Xis, Cis=Cis,
        Ky=Ky, Kx=Kx,
        Sy=Sy, Sx=Sx,
        Y_loop=Y_loop, X_loop=X_loop, Co_loop=Co_loop, Ci_loop=Ci_loop,
        ifm_ch_num=ifm_ch_num,
        wgt_size=raw_wgt_bytes,
        bias_size=bias_bytes,
        core_spill_buf=core_spill_buf,
        core_coeff_tmp_buffer=core_coeff_tmp_buffer,
        full_iters=full_iters,
    )
